#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser – GUI (shots / nphi / save / show)
Author: ChatGPT

Install:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy

Run:
  python quantum_eraser_gui.py

What it does:
- Lets you set shots, nphi, seed and the output directory, and toggle "save" and "show" from a simple GUI.
- Runs the delayed-choice quantum eraser simulation (AerSimulator).
- Shows a live plot of P(s=0|i=+) and P(s=0|i=-).
- If "save" is checked, writes into the chosen output directory ("out" by default):
    out/binary_phi_XX.txt (binary stream for S per phase)
    out/eraser_probs.csv
    out/eraser_plot.png
"""
import csv
import math
import os
import queue
import threading
import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

# ------------------- Core experiment -------------------
def eraser_run(phi: float, erase: bool=True, shots: int=4000, seed: int=1234):
    """Simulate one phase setting of the two-qubit eraser circuit.

    Qubit 0 is the signal (S) and qubit 1 the idler (I); each is measured
    into the classical bit with the same index.

    Args:
        phi: Angle passed to the RZ gate on the signal qubit.
        erase: When true, a Hadamard is applied to the idler before it is
            measured, i.e. the idler is read out in the X basis.
        shots: Number of repetitions requested from the simulator.
        seed: Value given to the simulator as ``seed_simulator``.

    Returns:
        A ``(counts, memory)`` pair taken from the simulation result:
        the aggregated counts and the per-shot list of bitstrings.
        Callers in this file read each bitstring as idler bit first,
        signal bit second.
    """
    signal, idler = 0, 1

    circuit = QuantumCircuit(2, 2)
    circuit.h(signal)
    circuit.rz(phi, signal)
    circuit.cx(signal, idler)
    circuit.h(signal)
    if erase:
        # Rotate the idler so the Z-basis readout below acts as an X-basis one.
        circuit.h(idler)
    for wire in (signal, idler):
        circuit.measure(wire, wire)

    simulator = AerSimulator(seed_simulator=seed)
    compiled = transpile(circuit, backend=simulator, optimization_level=3)
    # memory=True makes the result keep the outcome of every individual shot.
    result = simulator.run(compiled, shots=shots, memory=True).result()
    return result.get_counts(), result.get_memory()

def conditional_prob(counts, i_bit='0'):
    """Return P(s = '0' | i = i_bit) estimated from a counts mapping.

    Args:
        counts: Mapping from bitstring to number of occurrences.  The first
            character of each key is taken as the idler bit and the second
            as the signal bit.
        i_bit: Idler value to condition on, ``'0'`` or ``'1'``.

    Returns:
        The fraction of the selected occurrences whose signal bit is
        ``'0'``, or ``0.0`` when no occurrence has the requested idler bit.
    """
    # Split every key into (idler, signal, occurrences) up front.
    outcomes = [(key[0], key[1], hits) for key, hits in counts.items()]
    selected = [(signal, hits) for idler, signal, hits in outcomes if idler == i_bit]

    total = sum(hits for _, hits in selected)
    zeros = sum(hits for signal, hits in selected if signal == '0')
    if not total:
        return 0.0
    return zeros / total

def save_binary(memory, path):
    """Write the signal bit of every shot to ``path`` as one string of 0/1.

    Args:
        memory: Iterable of per-shot bitstrings; the second character of
            each entry (the signal bit) is kept.
        path: Destination file.  Missing parent directories are created.
            A bare filename (no directory part) is written to the current
            working directory.
    """
    bits = ''.join(shot[1] for shot in memory)  # only S (second bit)
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(bits)

def save_csv(phi_list, p_plus, p_minus, path):
    """Write the phase sweep as a three-column CSV file.

    Args:
        phi_list: Phase values, one per row.
        p_plus: Values for the ``P(s=0|i=+)`` column.
        p_minus: Values for the ``P(s=0|i=-)`` column.
        path: Destination file.  Missing parent directories are created.
            A bare filename (no directory part) is written to the current
            working directory.

    The three sequences are zipped, so rows stop at the shortest one.
    """
    parent = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so skip it for bare filenames.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # newline='' lets the csv module control line endings itself.
    with open(path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
        w.writerows(zip(phi_list, p_plus, p_minus))

def visibility(vals):
    """Return the fringe visibility ``(max - min) / (max + min)`` of ``vals``.

    Yields ``0.0`` for an empty sequence and also when ``max + min`` is
    zero, so the division is never attempted with a zero denominator.
    """
    if not vals:
        return 0.0
    peak = max(vals)
    trough = min(vals)
    total = peak + trough
    if not total:
        return 0.0
    return (peak - trough) / total

# ------------------- GUI Application -------------------
class EraserGUI(tk.Tk):
    """Main window: experiment controls, an embedded live plot and a status bar.

    The phase sweep runs on a background thread so the window stays
    responsive.  Tk variables, widgets, message boxes and the Matplotlib
    figure are only touched from the thread that built the window: the
    worker hands every GUI update to that thread through ``self._ui_queue``,
    which a periodic ``after`` callback drains while a run is in progress.

    The "show" checkbox is kept in the layout but has no effect; the
    embedded plot is always updated.
    """

    # Interval, in milliseconds, between polls of the worker-to-GUI queue.
    _POLL_MS = 50

    def __init__(self):
        """Build the controls, the plot area and the status bar."""
        super().__init__()
        self.title("Quantum Eraser – GUI")
        self.geometry("900x650")

        # Defaults
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi  = tk.IntVar(value=13)
        self.var_save  = tk.BooleanVar(value=True)
        self.var_show  = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")
        self.var_seed   = tk.IntVar(value=1234)

        # Controls
        frm = ttk.Frame(self, padding=10)
        frm.pack(side=tk.TOP, fill=tk.X)

        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=5)

        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=5)

        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=4, padx=10)
        ttk.Checkbutton(frm, text="show", variable=self.var_show).grid(row=0, column=5, padx=10)

        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=5)
        out_entry = ttk.Entry(frm, textvariable=self.var_outdir, width=24)
        out_entry.grid(row=1, column=1, columnspan=3, sticky="we", padx=5, pady=5)
        ttk.Button(frm, text="Browse…", command=self.choose_outdir).grid(row=1, column=4, sticky="w")

        ttk.Label(frm, text="seed:").grid(row=1, column=5, sticky="e")
        ttk.Entry(frm, textvariable=self.var_seed, width=10).grid(row=1, column=6, padx=5)

        self.btn_run = ttk.Button(frm, text="Run", command=self.on_run_clicked)
        self.btn_run.grid(row=0, column=6, padx=10)

        # Figure
        self.fig = plt.Figure(figsize=(7,4.5))
        self.ax  = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        self.line_plus,  = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        self.line_minus, = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend()
        self.ax.grid(True, linestyle='--', alpha=0.4)

        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Status
        self.status = tk.StringVar(value="Ready.")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=8).pack(side=tk.BOTTOM, fill=tk.X)

        self._running = False
        # Thread that constructed this window; GUI work is funnelled to it.
        self._gui_thread = threading.current_thread()
        # Holds (callable, args) pairs posted by the worker thread.
        self._ui_queue = queue.Queue()

    def choose_outdir(self):
        """Ask for a directory and store it as the output directory."""
        d = filedialog.askdirectory(title="Select output directory")
        if d:
            self.var_outdir.set(d)

    def on_run_clicked(self):
        """Validate the inputs and start the sweep on a background thread."""
        if self._running:
            messagebox.showinfo("Info", "Experiment already running.")
            return
        try:
            settings = self._read_settings()
        except ValueError as exc:
            messagebox.showerror("Error", str(exc))
            return

        self._running = True
        self.btn_run.config(state=tk.DISABLED)
        self.status.set("Running…")
        threading.Thread(
            target=self.run_experiment_threadsafe,
            args=(settings,),
            daemon=True,
        ).start()
        # Start polling for the GUI updates the worker will post.
        self.after(self._POLL_MS, self._drain_ui_queue)

    def run_experiment_threadsafe(self, settings=None):
        """Run the sweep, reporting any failure and always re-enabling Run.

        Args:
            settings: Dict produced by ``_read_settings``; when omitted the
                values are read from the controls by ``run_experiment``.
        """
        try:
            self.run_experiment(settings)
        except Exception as e:
            # Top-level boundary of the worker: surface every failure to the user.
            self._call_in_gui_thread(self._report_error, e)
        finally:
            self._call_in_gui_thread(self._finish_run)

    def run_experiment(self, settings=None):
        """Sweep the phase, update the plot after each step and optionally save.

        Args:
            settings: Dict with keys ``shots``, ``nphi``, ``save``,
                ``outdir`` and ``seed``.  When omitted, the values are read
                from the controls.

        Raises:
            ValueError: If ``settings`` is omitted and the controls hold
                invalid values.
        """
        if settings is None:
            settings = self._read_settings()
        shots = settings["shots"]
        nphi = settings["nphi"]
        save = settings["save"]
        outdir = settings["outdir"]
        seed = settings["seed"]

        # Only touch the filesystem when the user asked for files.
        if save:
            os.makedirs(outdir, exist_ok=True)
        phis = np.linspace(0.0, 2.0*math.pi, nphi)
        p_plus, p_minus = [], []

        # clear plot
        self._call_in_gui_thread(self._update_plot, [], [], [])

        for idx, phi in enumerate(phis):
            counts, memory = eraser_run(phi, erase=True, shots=shots, seed=seed)
            p0_plus  = conditional_prob(counts, '0')
            p0_minus = conditional_prob(counts, '1')
            p_plus.append(p0_plus)
            p_minus.append(p0_minus)

            # Pass copies: the lists keep growing while the GUI thread draws.
            self._call_in_gui_thread(
                self._update_plot, phis[:idx+1], list(p_plus), list(p_minus)
            )

            # save binary per phase
            if save:
                save_binary(memory, os.path.join(outdir, f"binary_phi_{idx:02d}.txt"))

            self._call_in_gui_thread(
                self.status.set,
                f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • step {idx+1}/{nphi}",
            )

        png_path = None
        if save:
            save_csv(phis, p_plus, p_minus, os.path.join(outdir, "eraser_probs.csv"))
            png_path = os.path.join(outdir, "eraser_plot.png")

        Vp = visibility(p_plus); Vm = visibility(p_minus)
        # The PNG is written on the GUI thread because it renders self.fig.
        self._call_in_gui_thread(
            self._finalize_run,
            png_path,
            f"Done. Vizibilitate: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}",
        )

    def _read_settings(self):
        """Snapshot and validate the control values; GUI thread only.

        Returns:
            Dict with keys ``shots``, ``nphi``, ``save``, ``outdir``, ``seed``.

        Raises:
            ValueError: With a user-facing message when a field is invalid.
        """
        try:
            shots = int(self.var_shots.get())
            nphi = int(self.var_nphi.get())
        except (tk.TclError, ValueError):
            raise ValueError("Invalid shots or nphi.") from None
        if shots <= 0 or nphi <= 1:
            raise ValueError("Invalid shots or nphi.")
        try:
            seed = int(self.var_seed.get())
        except (tk.TclError, ValueError):
            raise ValueError("Invalid seed.") from None
        return {
            "shots": shots,
            "nphi": nphi,
            "save": bool(self.var_save.get()),
            # An empty entry means "current directory" instead of crashing makedirs.
            "outdir": self.var_outdir.get() or ".",
            "seed": seed,
        }

    def _call_in_gui_thread(self, func, *args):
        """Run ``func(*args)`` now if on the GUI thread, otherwise queue it."""
        if threading.current_thread() is self._gui_thread:
            func(*args)
        else:
            self._ui_queue.put((func, args))

    def _drain_ui_queue(self):
        """Execute queued GUI updates and keep polling while a run is active."""
        while True:
            try:
                func, args = self._ui_queue.get_nowait()
            except queue.Empty:
                break
            try:
                func(*args)
            except Exception as exc:
                # Keep draining so the final "finish" message is never lost.
                self._report_error(exc)
        if self._running:
            self.after(self._POLL_MS, self._drain_ui_queue)

    def _update_plot(self, xs, plus, minus):
        """Replace both curves' data, rescale the axes and request a redraw."""
        self.line_plus.set_data(xs, plus)
        self.line_minus.set_data(xs, minus)
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _finalize_run(self, png_path, done_text):
        """Save the figure when ``png_path`` is set, then show the final status."""
        if png_path:
            self.fig.savefig(png_path, dpi=160)
        self.status.set(done_text)

    def _report_error(self, exc):
        """Show ``exc`` in the status bar and in a modal error dialog."""
        self.status.set(f"Error: {exc}")
        messagebox.showerror("Error", str(exc))

    def _finish_run(self):
        """Mark the run as finished and re-enable the Run button."""
        self._running = False
        self.btn_run.config(state=tk.NORMAL)

def main():
    """Create the application window and enter its Tk event loop."""
    EraserGUI().mainloop()

# Start the GUI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
