#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser – GUI (Aer / IBM Quantum) with:
- Live plots
- Binary outputs for S
- *Correlated bits* exports: pairs, S|I=0, S|I=1
- Correlation/visibility checker
- Optional IBM Runtime support (token / instance / backend)

Install:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"   # only if using IBM backends

Run:
  python quantum_eraser_gui_ibm_corr.py
"""
import csv
import math
import os
import queue
import threading
import tkinter as tk
from collections import Counter
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

# Optional IBM Runtime imports
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# ------------------- Core helpers -------------------
def eraser_circuit(phi: float, erase: bool=True) -> QuantumCircuit:
    """Build the two-qubit quantum-eraser circuit.

    Qubit 0 plays the signal S and qubit 1 the idler I. Classical bit 0
    receives S and classical bit 1 receives I.

    Args:
        phi: phase (radians) applied to S between its two Hadamards.
        erase: when True, a Hadamard on I before measurement reads the idler
            in the X basis instead of the computational basis.

    Returns:
        A ``QuantumCircuit(2, 2)`` ending in both measurements.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)
    circuit.h(signal)             # first beamsplitter
    circuit.rz(phi, signal)       # tunable phase
    circuit.cx(signal, idler)     # copy path information onto the idler
    circuit.h(signal)             # second beamsplitter
    if erase:
        circuit.h(idler)          # X-basis readout of the idler
    # Measure S first, then I, each into the classical bit of the same index.
    for qubit in (signal, idler):
        circuit.measure(qubit, qubit)
    return circuit

def conditional_prob_from_counts(counts, i_bit='0'):
    """Estimate P(s=0 | i=i_bit) from a two-bit counts mapping.

    Each key is a bitstring whose character 0 is the idler bit and character
    1 is the signal bit. Returns 0.0 when no shot has the requested idler
    value.
    """
    triples = [(key[0], key[1], n) for key, n in counts.items()]
    selected = [(s_bit, n) for idler_bit, s_bit, n in triples if idler_bit == i_bit]
    total = sum(n for _, n in selected)
    zeros = sum(n for s_bit, n in selected if s_bit == '0')
    return zeros / total if total else 0.0

def visibility(vals):
    """Return the visibility (max - min) / (max + min) of *vals*.

    Accepts any finite iterable of numbers: list, tuple, NumPy array or
    generator. The input is materialised once, so arrays do not trigger an
    ambiguous truth-value test and generators are not exhausted by ``max``
    before ``min`` runs. Returns 0.0 for empty input or when max + min is 0.
    """
    values = list(vals)
    if not values:
        return 0.0
    vmax, vmin = max(values), min(values)
    denom = vmax + vmin
    return (vmax - vmin) / denom if denom else 0.0

def save_text(path, text):
    """Write *text* to *path* as UTF-8, creating parent directories as needed.

    A bare filename has no directory component, so ``os.path.dirname``
    returns ``''`` and ``os.makedirs('')`` would raise. Directory creation
    is therefore skipped in that case, and the file lands in the current
    working directory.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(text)

def list_to_text(lines):
    """Concatenate the string items of *lines* with no separator.

    Items are joined exactly as given, so any line breaks must already be
    part of the items themselves.
    """
    glue = ""
    return glue.join(lines)

# ------------------- GUI Application -------------------
class EraserGUI(tk.Tk):
    """Main window: sweep controls, IBM settings, live plot and bit exports.

    Threading model: the φ sweep runs on a daemon worker thread, but Tkinter
    widgets and the embedded Matplotlib figure may only be touched from the
    Tk main thread. The worker therefore only executes circuits and writes
    files. It posts ``("step" | "done" | "error", ...)`` tuples to
    ``self._events``. ``_poll_events``, scheduled with ``after``, applies
    them to the plot, status bar and dialogs on the main thread.
    """

    # Interval (ms) at which the main thread drains worker events.
    _POLL_MS = 50

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – GUI (Aer / IBM) + Correlated Bits")
        self.geometry("1040x780")

        # Defaults
        self.var_shots   = tk.IntVar(value=4000)
        self.var_nphi    = tk.IntVar(value=13)
        self.var_save    = tk.BooleanVar(value=True)
        # NOTE(review): var_show is not read anywhere in this class.
        self.var_show    = tk.BooleanVar(value=True)
        self.var_outdir  = tk.StringVar(value="out")
        self.var_seed    = tk.IntVar(value=1234)

        # Extra: save correlated streams
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM Runtime
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token    = tk.StringVar(value="")          # paste token here
        self.var_ibm_instance = tk.StringVar(value="")          # e.g. "ibm-q/open/main"
        self.var_ibm_backend  = tk.StringVar(value="")          # e.g. "ibm_brisbane"
        self.service = None
        self.backend = None

        # Storage for the last run (only mutated on the main thread)
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.last_memory_by_phi = {}  # idx -> list of per-shot bitstrings 'IS'

        # Controls
        top = ttk.Frame(self, padding=10)
        top.pack(side=tk.TOP, fill=tk.X)

        ttk.Label(top, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(top, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=5)

        ttk.Label(top, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(top, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=5)

        ttk.Checkbutton(top, text="save", variable=self.var_save).grid(row=0, column=4, padx=6)
        ttk.Checkbutton(top, text="show", variable=self.var_show).grid(row=0, column=5, padx=6)
        ttk.Checkbutton(top, text="save correlated streams", variable=self.var_save_corr).grid(row=0, column=6, padx=6)

        ttk.Label(top, text="outdir:").grid(row=1, column=0, sticky="w", pady=5)
        ttk.Entry(top, textvariable=self.var_outdir, width=26).grid(row=1, column=1, columnspan=2, sticky="we", padx=5)
        ttk.Button(top, text="Browse…", command=self.choose_outdir).grid(row=1, column=3, sticky="w")

        ttk.Label(top, text="seed:").grid(row=1, column=4, sticky="e")
        ttk.Entry(top, textvariable=self.var_seed, width=10).grid(row=1, column=5, padx=5)

        # IBM frame
        ibm = ttk.LabelFrame(self, text="IBM Quantum (real backend)", padding=8)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=10, pady=4)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=40, show="•").grid(row=0, column=2, padx=5, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=24).grid(row=0, column=4, padx=5)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=5)

        # Buttons
        btns = ttk.Frame(self, padding=(10,4))
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run", command=self.on_run_clicked)
        self.btn_run.pack(side=tk.LEFT, padx=5)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect Interference", command=self.on_correlate_clicked)
        self.btn_correlate.pack(side=tk.LEFT, padx=5)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview_clicked)
        self.btn_preview.pack(side=tk.LEFT, padx=5)

        # Plot (axis labels are user-facing Romanian: "radians", "probability")
        self.fig = plt.Figure(figsize=(7.8,4.8))
        self.ax  = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,)  = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend()
        self.ax.grid(True, linestyle='--', alpha=0.4)

        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Status
        self.status = tk.StringVar(value="Ready.")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=8).pack(side=tk.BOTTOM, fill=tk.X)

        self._running = False
        # Worker -> main-thread event channel (queue.Queue is thread-safe).
        self._events = queue.Queue()
        # Text artist added by the Correlate button; replaced, never stacked.
        self._verdict_text = None

    # ------------- UI helpers -------------
    def choose_outdir(self):
        """Let the user pick the output directory via a native dialog."""
        d = filedialog.askdirectory(title="Select output directory")
        if d:
            self.var_outdir.set(d)

    def set_running(self, running: bool):
        """Record the running flag and enable/disable the Run button (main thread)."""
        self._running = running
        self.btn_run.config(state=(tk.DISABLED if running else tk.NORMAL))

    def _collect_params(self):
        """Snapshot every run setting from the Tk variables (main thread only)."""
        return {
            "shots": int(self.var_shots.get()),
            "nphi": int(self.var_nphi.get()),
            "save": bool(self.var_save.get()),
            "save_corr": bool(self.var_save_corr.get()),
            "outdir": self.var_outdir.get(),
            "seed": int(self.var_seed.get()),
            "use_ibm": bool(self.var_use_ibm.get()),
            "ibm_token": self.var_ibm_token.get().strip(),
            "ibm_instance": self.var_ibm_instance.get().strip() or None,
            "ibm_backend": self.var_ibm_backend.get().strip(),
        }

    def _clear_verdict(self):
        """Remove the Correlate annotation from the axes, if one is shown."""
        if self._verdict_text is not None:
            self._verdict_text.remove()
            self._verdict_text = None

    def _reset_run_state(self):
        """Clear plot lines, annotation and stored results before a new run."""
        self._clear_verdict()
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.last_memory_by_phi = {}

    # ------------- Run -------------
    def on_run_clicked(self):
        """Validate inputs on the main thread, then start the worker thread."""
        if self._running:
            messagebox.showinfo("Info", "Experiment already running.")
            return
        # Validate
        try:
            shots = int(self.var_shots.get())
            nphi  = int(self.var_nphi.get())
            if shots <= 0 or nphi <= 1:
                raise ValueError
        except Exception:
            messagebox.showerror("Error", "Invalid shots or nphi.")
            return
        try:
            int(self.var_seed.get())
        except Exception:
            messagebox.showerror("Error", "Invalid seed.")
            return
        params = self._collect_params()
        if params["use_ibm"]:
            if not IBM_RUNTIME_AVAILABLE:
                messagebox.showerror("Error", "qiskit-ibm-runtime not installed.")
                return
            if not params["ibm_token"] or not params["ibm_backend"]:
                messagebox.showerror("Error", "Please provide IBM API Token and Backend name.")
                return
        if (params["save"] or params["save_corr"]) and not params["outdir"].strip():
            messagebox.showerror("Error", "Please choose an output directory.")
            return
        self._reset_run_state()
        self.set_running(True)
        self.status.set("Running…")
        threading.Thread(target=self.run_experiment_threadsafe, args=(params,), daemon=True).start()
        self.after(self._POLL_MS, self._poll_events)

    def run_experiment_threadsafe(self, params=None):
        """Worker-thread entry point: run the sweep and report any failure.

        Exceptions become an ``("error", message)`` event instead of touching
        widgets from this thread. The main thread then shows the dialog and
        re-enables the Run button.
        """
        try:
            self.run_experiment(params)
        except Exception as e:
            self._events.put(("error", str(e)))

    def run_experiment(self, params=None):
        """Execute the φ sweep and write the requested exports.

        Intended for the worker thread: it makes no Tk or Matplotlib calls.
        Each step posts ``("step", idx, phi, P(+), P(-), memory, nphi)``. A
        completed sweep posts ``("done", V(+), V(-), use_ibm, save, outdir)``.
        All events go to ``self._events``.

        Args:
            params: settings snapshot from ``_collect_params``. When omitted,
                the Tk variables are read here, which is only safe on the
                main thread.
        """
        if params is None:
            params = self._collect_params()
        shots     = params["shots"]
        nphi      = params["nphi"]
        save      = params["save"]
        save_corr = params["save_corr"]
        outdir    = params["outdir"]
        seed      = params["seed"]
        use_ibm   = params["use_ibm"]

        if save or save_corr:
            os.makedirs(outdir, exist_ok=True)
        phis = np.linspace(0.0, 2.0*math.pi, nphi)

        # Choose backend
        if use_ibm:
            backend = self._build_ibm_backend(
                params["ibm_token"], params["ibm_instance"], params["ibm_backend"])
        else:
            backend = self._build_aer_backend(seed)

        p_plus, p_minus = [], []
        for idx, phi in enumerate(phis):
            qc = eraser_circuit(phi, erase=True)
            if use_ibm:
                counts, memory = self._execute_ibm(qc, backend, shots, seed)
            else:
                counts, memory = self._execute_aer(qc, backend, shots)

            # Conditional probabilities (idler bit 0 is plotted as '+', bit 1 as '-')
            p0_plus  = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus); p_minus.append(p0_minus)

            # Hand the step to the main thread for plotting and preview storage.
            self._events.put(("step", idx, float(phi), p0_plus, p0_minus, memory, nphi))

            self._save_step_files(outdir, idx, memory, counts, save, save_corr)

        # Probability table for the whole sweep (the PNG is saved on the main thread)
        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline="", encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
                for phi_value, a, b in zip(phis, p_plus, p_minus):
                    w.writerow([phi_value, a, b])

        self._events.put(("done", visibility(p_plus), visibility(p_minus), use_ibm, save, outdir))

    @staticmethod
    def _execute_aer(qc, backend, shots):
        """Run *qc* on the Aer simulator; return (counts, per-shot memory)."""
        tqc = transpile(qc, backend=backend, optimization_level=3)
        result = backend.run(tqc, shots=shots, memory=True).result()
        return result.get_counts(), result.get_memory()

    @staticmethod
    def _execute_ibm(qc, backend, shots, seed):
        """Run *qc* on an IBM backend via SamplerV2; return (counts, per-shot bitstrings).

        IBM hardware jobs go through the Runtime primitives rather than
        ``backend.run``. SamplerV2 returns one bitstring per shot, so the
        correlated-stream exports keep the true shot order.
        """
        # Imported here because qiskit-ibm-runtime is an optional dependency.
        from qiskit_ibm_runtime import SamplerV2
        tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
        job = SamplerV2(backend).run([tqc], shots=shots)
        data = job.result()[0].data
        # Results are keyed by classical-register name ('c' for QuantumCircuit(2, 2)).
        bits = getattr(data, tqc.cregs[0].name)
        return bits.get_counts(), bits.get_bitstrings()

    @staticmethod
    def _save_step_files(outdir, idx, memory, counts, save, save_corr):
        """Write the per-φ exports for step *idx* (called on the worker thread).

        *memory* holds one bitstring per shot. Character 0 is the idler bit I
        and character 1 is the signal bit S.
        """
        if save:
            # S only stream (unconditional S)
            s_stream = ''.join(m[1] for m in memory)
            save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)

        if save_corr:
            # Save pairs one per line
            save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
            # Save S|I=0 and S|I=1 as pure bit streams
            s_given_i0 = ''.join(m[1] for m in memory if m[0] == '0')
            s_given_i1 = ''.join(m[1] for m in memory if m[0] == '1')
            save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_given_i0)
            save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_given_i1)
            # Save counts CSV for transparency
            counts_path = os.path.join(outdir, f"counts_phi_{idx:02d}.csv")
            with open(counts_path, 'w', newline='', encoding='utf-8') as f:
                w = csv.writer(f)
                w.writerow(['pair','count'])
                for pair in ['00','01','10','11']:
                    w.writerow([pair, counts.get(pair, 0)])

    # ------------- Main-thread event handling -------------
    def _poll_events(self):
        """Apply queued worker events on the Tk main thread.

        Reschedules itself every ``_POLL_MS`` ms until a terminal "done" or
        "error" event has been handled.
        """
        while True:
            try:
                kind, *payload = self._events.get_nowait()
            except queue.Empty:
                break
            if kind == "step":
                self._on_step(*payload)
            elif kind == "done":
                self._on_done(*payload)
                return
            elif kind == "error":
                self._on_error(*payload)
                return
        self.after(self._POLL_MS, self._poll_events)

    def _on_step(self, idx, phi, p0_plus, p0_minus, memory, nphi):
        """Store one sweep point and refresh the live plot and status bar."""
        self.last_phis.append(phi)
        self.last_p_plus.append(p0_plus)
        self.last_p_minus.append(p0_minus)
        self.last_memory_by_phi[idx] = memory
        self.line_plus.set_data(self.last_phis, self.last_p_plus)
        self.line_minus.set_data(self.last_phis, self.last_p_minus)
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()
        self.status.set(f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • step {idx+1}/{nphi}")

    def _on_done(self, Vp, Vm, use_ibm, save, outdir):
        """Finish a successful run: save the plot PNG and report visibilities."""
        if save:
            try:
                self.fig.savefig(os.path.join(outdir, "eraser_plot.png"), dpi=160)
            except Exception as e:
                self._on_error(str(e))
                return
        self.status.set(f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}  (IBM: {use_ibm})")
        self.set_running(False)

    def _on_error(self, message):
        """Report a failed run and re-enable the Run button."""
        self.status.set(f"Error: {message}")
        messagebox.showerror("Error", message)
        self.set_running(False)

    # ------------- Correlate button -------------
    def on_correlate_clicked(self):
        """Annotate the plot with V(+)/V(-) and an interference verdict.

        Any previous annotation is removed first, so repeated clicks do not
        stack text boxes on the axes.
        """
        if not self.last_phis or not self.last_p_plus:
            messagebox.showinfo("Info", "No run data available. Please Run first.")
            return
        Vp = visibility(self.last_p_plus); Vm = visibility(self.last_p_minus)
        # User-facing verdicts in Romanian: "interference detected" / "weak/absent interference".
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"
        self._clear_verdict()
        self._verdict_text = self.ax.text(0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
                                          transform=self.ax.transAxes, fontsize=10,
                                          bbox=dict(boxstyle="round", facecolor="white", alpha=0.7))
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    # ------------- Preview correlated bits -------------
    def on_preview_clicked(self):
        """Popup with a quick preview of *actual correlated bits* for the last phi index."""
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info", "No run data available. Please Run first.")
            return
        idx = max(self.last_memory_by_phi)
        memory = self.last_memory_by_phi[idx]
        # Build counts and small preview
        c = Counter(memory)
        preview_pairs = '\n'.join(memory[:200])  # first 200 pairs
        # Window
        win = tk.Toplevel(self)
        win.title(f"Correlated bits preview (phi index {idx})")
        win.geometry("520x560")
        frm = ttk.Frame(win, padding=8); frm.pack(fill=tk.BOTH, expand=True)
        ttk.Label(frm, text=f"First 200 pairs (I S) for phi[{idx}]:").pack(anchor="w")
        txt = tk.Text(frm, height=14)
        txt.insert("1.0", preview_pairs)
        txt.config(state=tk.DISABLED)
        txt.pack(fill=tk.BOTH, expand=False, pady=6)
        ttk.Label(frm, text="Counts (coincidences):").pack(anchor="w")
        tree = ttk.Treeview(frm, columns=("pair","count"), show="headings", height=6)
        tree.heading("pair", text="pair"); tree.heading("count", text="count")
        for pair in ['00','01','10','11']:
            tree.insert("", "end", values=(pair, c.get(pair,0)))
        tree.pack(fill=tk.BOTH, expand=False, pady=6)
        # Conditional streams preview
        s_i0 = ''.join(m[1] for m in memory if m[0]=='0')[:200]
        s_i1 = ''.join(m[1] for m in memory if m[0]=='1')[:200]
        ttk.Label(frm, text="S | I=0 (first 200 bits):").pack(anchor="w", pady=(8,0))
        txt0 = tk.Text(frm, height=4); txt0.insert("1.0", s_i0); txt0.config(state=tk.DISABLED); txt0.pack(fill=tk.X)
        ttk.Label(frm, text="S | I=1 (first 200 bits):").pack(anchor="w", pady=(8,0))
        txt1 = tk.Text(frm, height=4); txt1.insert("1.0", s_i1); txt1.config(state=tk.DISABLED); txt1.pack(fill=tk.X)

    # ------------- Backend builders -------------
    def _build_aer_backend(self, seed=None):
        """Return an AerSimulator seeded with *seed*.

        When *seed* is None, it is read from ``var_seed`` (main thread only).
        """
        if seed is None:
            seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=seed)

    def _build_ibm_backend(self, token=None, instance=None, backend_name=None):
        """Connect to IBM Quantum and return the named backend.

        When *token* is None, all three settings are read from the Tk
        variables, which is only safe on the main thread. The service and
        backend are also kept on ``self.service`` / ``self.backend``.
        """
        if token is None:
            token = self.var_ibm_token.get().strip()
            instance = self.var_ibm_instance.get().strip() or None
            backend_name = self.var_ibm_backend.get().strip()
        # instance=None is QiskitRuntimeService's default, so it is always safe to pass.
        # NOTE(review): newer qiskit-ibm-runtime releases favour the
        # "ibm_quantum_platform" channel; confirm against the installed version.
        self.service = QiskitRuntimeService(channel="ibm_quantum", token=token, instance=instance)
        self.backend = self.service.backend(backend_name)
        return self.backend

def main():
    """Create the eraser window and hand control to the Tk event loop."""
    EraserGUI().mainloop()

if __name__ == "__main__":
    main()
