#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser – GUI (Aer or IBM Quantum) with Live Plots, Binary Outputs, and Correlation Check
Author: ChatGPT

Install (local simulator):
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy

For IBM Quantum (real hardware or cloud simulators):
  pip install "qiskit-ibm-runtime>=0.25.0"

Run:
  python quantum_eraser_gui_ibm.py
"""
import os
import math
import queue
import threading
import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

# Optional IBM Runtime imports (loaded lazily)
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# ------------------- Core circuit -------------------
def eraser_circuit(phi: float, erase: bool=True) -> QuantumCircuit:
    """Build the two-qubit quantum-eraser interferometer.

    Qubit 0 is the signal S and qubit 1 the idler I. S is measured into
    classical bit 0 and I into classical bit 1.

    Args:
        phi: phase (radians) applied to S via RZ between the two Hadamard
            "beamsplitters".
        erase: when True, a Hadamard rotates I into the X basis before
            measurement (eraser); when False, I is measured in Z (which-path).

    Returns:
        The unexecuted 2-qubit / 2-clbit circuit.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)

    circuit.h(signal)             # first beamsplitter
    circuit.rz(phi, signal)       # interferometer phase
    circuit.cx(signal, idler)     # copy the path onto the idler
    circuit.h(signal)             # second beamsplitter

    if erase:
        circuit.h(idler)          # read the idler in the X basis

    # Each qubit goes to the classical bit with the same index: S->c0, I->c1.
    for qubit in (signal, idler):
        circuit.measure(qubit, qubit)
    return circuit

def conditional_prob_from_counts(counts, i_bit='0'):
    """Return P(s=0 | i=i_bit) estimated from a Qiskit counts mapping.

    Qiskit places classical bit 0 as the rightmost character of each key
    (for this circuit "c1c0" = "i s") and separates classical registers with
    a space. Spaces are therefore removed and bits are read from the right:
    S is the last character and I the one before it.

    Args:
        counts: mapping of bitstring -> number of shots.
        i_bit: idler outcome to condition on; '0'/'1' (ints 0/1 also accepted).

    Returns:
        Fraction of shots with i == i_bit that also have s == '0', or 0.0 when
        no shot has i == i_bit.
    """
    wanted = str(i_bit)
    total = zeros = 0
    for key, n in counts.items():
        bits = key.replace(' ', '')
        idler, signal = bits[-2], bits[-1]
        if idler != wanted:
            continue
        total += n
        if signal == '0':
            zeros += n
    return (zeros / total) if total else 0.0

def visibility(vals):
    """Return the fringe visibility (max - min) / (max + min) of *vals*.

    Args:
        vals: any one-dimensional iterable of numbers (list, tuple, 1-D NumPy
            array, generator). It is materialised once, so generators work and
            NumPy arrays avoid the ambiguous-truth-value error of ``not arr``.

    Returns:
        The visibility, or 0.0 for empty input or when max + min == 0.
    """
    values = list(vals)
    if not values:
        return 0.0
    vmax, vmin = max(values), min(values)
    denom = vmax + vmin
    return (vmax - vmin) / denom if denom else 0.0

def save_binary_stream(memory, out_path):
    """Write the signal-qubit (S) outcome of every shot to *out_path*.

    The file receives one '0'/'1' character per shot, with no separators or
    trailing newline.

    Args:
        memory: per-shot bitstrings in Qiskit order ("c1c0" = "i s", optionally
            with spaces between registers), or None to skip writing.
        out_path: destination file. Missing parent directories are created; a
            bare filename (no directory part) is written relative to the CWD.
    """
    if memory is None:
        return
    # S is classical bit 0, i.e. the rightmost character of each bitstring.
    bits = ''.join(m.replace(' ', '')[-1] for m in memory)
    parent = os.path.dirname(out_path)
    if parent:  # os.makedirs('') would raise FileNotFoundError
        os.makedirs(parent, exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        f.write(bits)

# ------------------- GUI Application -------------------
class EraserGUI(tk.Tk):
    """Main window: sweeps the eraser phase φ and plots P(s=0 | i=±) live.

    The sweep runs on a background thread. Tk widgets, Tk variables and the
    Matplotlib figure are touched only on the Tk main thread (neither library
    is thread-safe). The worker hands UI work to ``_post``, and
    ``_drain_ui_queue``, re-armed with ``after``, executes it.
    """

    # Interval (ms) at which the Tk thread runs UI callbacks queued by the worker.
    _UI_POLL_MS = 50

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – GUI (Aer / IBM Quantum)")
        self.geometry("980x720")

        # Defaults
        self.var_shots   = tk.IntVar(value=4000)
        self.var_nphi    = tk.IntVar(value=13)
        self.var_save    = tk.BooleanVar(value=True)
        self.var_show    = tk.BooleanVar(value=True)   # currently has no effect
        self.var_outdir  = tk.StringVar(value="out")
        self.var_seed    = tk.IntVar(value=1234)

        # IBM Runtime
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token    = tk.StringVar(value="")          # paste token here
        self.var_ibm_instance = tk.StringVar(value="")          # e.g. "ibm-q/open/main"
        self.var_ibm_backend  = tk.StringVar(value="")          # e.g. "ibm_brisbane"
        self.service = None
        self.backend = None

        # Data of the last *completed* run (used by the correlate button)
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []

        # Controls frame
        top = ttk.Frame(self, padding=10)
        top.pack(side=tk.TOP, fill=tk.X)

        # Row 0
        ttk.Label(top, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(top, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=5)

        ttk.Label(top, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(top, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=5)

        ttk.Checkbutton(top, text="save", variable=self.var_save).grid(row=0, column=4, padx=8)
        ttk.Checkbutton(top, text="show", variable=self.var_show).grid(row=0, column=5, padx=8)

        ttk.Label(top, text="outdir:").grid(row=1, column=0, sticky="w", pady=5)
        ttk.Entry(top, textvariable=self.var_outdir, width=24).grid(row=1, column=1, columnspan=2, sticky="we", padx=5)
        ttk.Button(top, text="Browse…", command=self.choose_outdir).grid(row=1, column=3, sticky="w")

        ttk.Label(top, text="seed:").grid(row=1, column=4, sticky="e")
        ttk.Entry(top, textvariable=self.var_seed, width=10).grid(row=1, column=5, padx=5)

        # IBM runtime controls
        ibm = ttk.LabelFrame(self, text="IBM Quantum (real backend)", padding=8)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=10, pady=4)

        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=40, show="•").grid(row=0, column=2, padx=5, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=24).grid(row=0, column=4, padx=5)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=5)

        # Buttons
        btns = ttk.Frame(self, padding=(10,4))
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run", command=self.on_run_clicked)
        self.btn_run.pack(side=tk.LEFT, padx=5)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect Interference", command=self.on_correlate_clicked)
        self.btn_correlate.pack(side=tk.LEFT, padx=5)

        # Figure
        self.fig = plt.Figure(figsize=(7.6,4.8))
        self.ax  = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,)  = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend()
        self.ax.grid(True, linestyle='--', alpha=0.4)

        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Status
        self.status = tk.StringVar(value="Ready.")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=8).pack(side=tk.BOTTOM, fill=tk.X)

        self._running = False
        # Text artist placed by the correlate button (None when absent).
        self._annotation = None
        # (callable, args) pairs queued by the worker for the Tk thread.
        self._ui_queue = queue.Queue()
        self.after(self._UI_POLL_MS, self._drain_ui_queue)

    # ------------- Thread plumbing -------------
    def _post(self, fn, *args):
        """Queue ``fn(*args)`` to run on the Tk main thread; callable from any thread."""
        self._ui_queue.put((fn, args))

    def _drain_ui_queue(self):
        """Run every queued UI callback (Tk thread), then re-arm the poll timer."""
        try:
            while True:
                try:
                    fn, args = self._ui_queue.get_nowait()
                except queue.Empty:
                    break
                try:
                    fn(*args)
                except Exception as exc:
                    # One failing callback must not stop the polling loop.
                    self.status.set(f"Error: {exc}")
        finally:
            self.after(self._UI_POLL_MS, self._drain_ui_queue)

    def _collect_params(self):
        """Snapshot every run setting from the Tk variables into a plain dict.

        Must be called on the Tk main thread. ``IntVar.get`` raises
        ``tk.TclError`` when its entry holds non-numeric text.
        """
        return {
            "shots": int(self.var_shots.get()),
            "nphi": int(self.var_nphi.get()),
            "save": bool(self.var_save.get()),
            "show": bool(self.var_show.get()),
            "outdir": self.var_outdir.get(),
            "seed": int(self.var_seed.get()),
            "use_ibm": bool(self.var_use_ibm.get()),
            "ibm_token": self.var_ibm_token.get().strip(),
            "ibm_instance": self.var_ibm_instance.get().strip(),
            "ibm_backend": self.var_ibm_backend.get().strip(),
        }

    # ------------- UI Helpers -------------
    def choose_outdir(self):
        """Let the user pick the output directory with a folder dialog."""
        d = filedialog.askdirectory(title="Select output directory")
        if d:
            self.var_outdir.set(d)

    def set_running(self, running: bool):
        """Set the busy flag and disable Run/Correlate while busy (Tk thread only)."""
        self._running = running
        state = tk.DISABLED if running else tk.NORMAL
        self.btn_run.config(state=state)
        self.btn_correlate.config(state=state)

    def _clear_annotation(self):
        """Remove the correlate-button text box from the axes, if present."""
        if self._annotation is not None:
            self._annotation.remove()
            self._annotation = None

    def _begin_run_ui(self):
        """Forget the previous run's data and clear the plot (Tk thread)."""
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self._clear_annotation()
        self.ax.relim(); self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _update_plot(self, xs, ys_plus, ys_minus):
        """Replace both curves with the given data and redraw (Tk thread)."""
        self.line_plus.set_data(xs, ys_plus)
        self.line_minus.set_data(xs, ys_minus)
        self.ax.relim(); self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _finish_run(self, phis, p_plus, p_minus, png_path, use_ibm):
        """Publish a completed run, optionally save the PNG, and report visibilities."""
        self.last_phis = list(phis)
        self.last_p_plus = list(p_plus)
        self.last_p_minus = list(p_minus)
        if png_path is not None:
            # Runs after the final _update_plot (FIFO queue), so the PNG is complete.
            self.fig.savefig(png_path, dpi=160)
        Vp = visibility(p_plus); Vm = visibility(p_minus)
        self.status.set(f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} (Use IBM: {use_ibm})")

    # ------------- Run Button -------------
    def on_run_clicked(self):
        """Validate settings on the Tk thread and start the sweep on a worker thread."""
        if self._running:
            messagebox.showinfo("Info", "Experiment already running.")
            return
        # Basic validation
        try:
            shots = int(self.var_shots.get())
            nphi  = int(self.var_nphi.get())
            if shots <= 0 or nphi <= 1:
                raise ValueError
        except (tk.TclError, ValueError, TypeError):
            messagebox.showerror("Error", "Invalid shots or nphi.")
            return
        # shots/nphi are valid here, so a failure means a bad seed.
        try:
            params = self._collect_params()
        except (tk.TclError, ValueError, TypeError):
            messagebox.showerror("Error", "Invalid seed.")
            return
        # If IBM selected, check fields (stripped, so whitespace-only is rejected)
        if params["use_ibm"]:
            if not IBM_RUNTIME_AVAILABLE:
                messagebox.showerror("Error", "qiskit-ibm-runtime not installed.")
                return
            if not params["ibm_token"] or not params["ibm_backend"]:
                messagebox.showerror("Error", "Please provide IBM API Token and Backend name.")
                return

        self.set_running(True)
        self.status.set("Running…")
        threading.Thread(target=self.run_experiment_threadsafe, args=(params,), daemon=True).start()

    def run_experiment_threadsafe(self, params=None):
        """Worker-thread entry point: run the sweep and report any failure via the UI queue.

        Args:
            params: settings snapshot from ``_collect_params``. Pass it whenever
                calling off the Tk thread; None reads the Tk variables directly.
        """
        try:
            self.run_experiment(params)
        except Exception as e:
            # Top-level boundary of the worker thread: surface every failure.
            self._post(self.status.set, f"Error: {e}")
            self._post(messagebox.showerror, "Error", str(e))
        finally:
            self._post(self.set_running, False)

    # ------------- Correlate Button -------------
    def on_correlate_clicked(self):
        """Compute visibility from last run and annotate plot + status."""
        if not self.last_phis or not self.last_p_plus or not self.last_p_minus:
            messagebox.showinfo("Info", "No recent data to correlate. Please run first.")
            return
        Vp = visibility(self.last_p_plus)
        Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă sau absentă"
        # Replace (instead of stacking) any previous annotation.
        self._clear_annotation()
        self._annotation = self.ax.text(0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
                                        transform=self.ax.transAxes, fontsize=10,
                                        bbox=dict(boxstyle="round", facecolor="white", alpha=0.7))
        self.canvas.draw_idle()
        self.status.set(f"Correlation done: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    # ------------- Backend Builders -------------
    def _build_aer_backend(self, seed=None):
        """Return a local AerSimulator seeded with *seed* (read from the UI if None)."""
        if seed is None:
            seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=seed)

    def _build_ibm_backend(self, token=None, instance=None, backend_name=None):
        """Connect to IBM Runtime and return the backend named *backend_name*.

        Arguments left as None are read from the Tk variables (Tk thread only).
        An empty *instance* is omitted from the service call. The service and
        backend are also stored on ``self.service`` / ``self.backend``.
        """
        if token is None:
            token = self.var_ibm_token.get().strip()
        if instance is None:
            instance = self.var_ibm_instance.get().strip()
        if backend_name is None:
            backend_name = self.var_ibm_backend.get().strip()
        # NOTE(review): recent qiskit-ibm-runtime releases retire the
        # "ibm_quantum" channel in favour of "ibm_quantum_platform" — confirm
        # against the installed version.
        service_kwargs = {"channel": "ibm_quantum", "token": token}
        if instance:
            service_kwargs["instance"] = instance
        self.service = QiskitRuntimeService(**service_kwargs)
        self.backend = self.service.backend(backend_name)
        return self.backend

    @staticmethod
    def _execute(backend, qc, shots, seed, use_ibm):
        """Transpile *qc* for *backend*, run it, and return ``(counts, memory)``.

        *memory* is the per-shot bitstring list, or None when an IBM backend
        does not provide it.
        """
        if use_ibm:
            # Transpile to target backend coupling map & run
            tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
            # NOTE(review): newer qiskit-ibm-runtime versions deprecate/remove
            # IBMBackend.run in favour of SamplerV2 — confirm it is available.
            result = backend.run(tqc, shots=shots, memory=True).result()
            counts = result.get_counts()
            try:
                memory = result.get_memory()
            except Exception:
                # Best effort: some backends return counts only.
                memory = None
            return counts, memory
        # Local Aer
        tqc = transpile(qc, backend=backend, optimization_level=3)
        result = backend.run(tqc, shots=shots, memory=True).result()
        return result.get_counts(), result.get_memory()

    # ------------- Main experiment -------------
    def run_experiment(self, params=None):
        """Sweep φ over [0, 2π] in ``nphi`` steps and estimate P(s=0 | i=±).

        Safe to run on a worker thread when *params* is given: every widget
        and figure update is queued through ``_post``. When saving is enabled,
        the following are written to the output directory (``"."`` if empty):
        one ``binary_phi_XX.txt`` per φ, ``eraser_probs.csv``, and, on the Tk
        thread after the last plot update, ``eraser_plot.png``.

        Args:
            params: settings snapshot from ``_collect_params``. None reads the
                Tk variables directly, which is only safe on the Tk thread.
        """
        if params is None:
            params = self._collect_params()
        shots   = params["shots"]
        nphi    = params["nphi"]
        save    = params["save"]
        outdir  = params["outdir"] or "."
        seed    = params["seed"]
        use_ibm = params["use_ibm"]
        # NOTE: the "show" option currently has no effect; the plot is always shown.

        phis = np.linspace(0.0, 2.0*math.pi, nphi)
        self._post(self._begin_run_ui)

        if save:
            os.makedirs(outdir, exist_ok=True)

        if use_ibm:
            backend = self._build_ibm_backend(params["ibm_token"], params["ibm_instance"], params["ibm_backend"])
        else:
            backend = self._build_aer_backend(seed)

        p_plus, p_minus = [], []
        for idx, phi in enumerate(phis):
            qc = eraser_circuit(phi, erase=True)
            counts, memory = self._execute(backend, qc, shots, seed, use_ibm)

            # Probabilities conditioned on idler outcomes
            p0_plus  = conditional_prob_from_counts(counts, '0')  # i=0 ≈ +_X
            p0_minus = conditional_prob_from_counts(counts, '1')  # i=1 ≈ -_X
            p_plus.append(p0_plus)
            p_minus.append(p0_minus)

            # Hand copies to the Tk thread; the worker keeps appending to the originals.
            self._post(self._update_plot, phis[:idx+1], list(p_plus), list(p_minus))

            if save:
                if memory is None:
                    # Counts only: synthesize a shot list from the histogram
                    # (grouped by outcome, not in acquisition order).
                    memory = [bitstr for bitstr, c in counts.items() for _ in range(int(c))]
                save_binary_stream(memory, os.path.join(outdir, f"binary_phi_{idx:02d}.txt"))

            self._post(self.status.set, f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • step {idx+1}/{nphi}")

        # Save CSV here; the PNG is rendered on the Tk thread by _finish_run.
        png_path = None
        if save:
            import csv
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline="") as f:
                w = csv.writer(f)
                w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
                w.writerows(zip(phis, p_plus, p_minus))
            png_path = os.path.join(outdir, "eraser_plot.png")

        self._post(self._finish_run, phis, p_plus, p_minus, png_path, use_ibm)

def main():
    """Open the quantum-eraser window and block in the Tk event loop until it closes."""
    EraserGUI().mainloop()


if __name__ == "__main__":
    main()
