#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser GUI + Correlated Bits + Brute-Force
(Interference, keylen range, live plots, manual keys, explanations)

Author: ChatGPT (adaptat pentru utilizator)

Features:
  - Rulează delayed-choice quantum eraser pe Aer sau IBM backends
  - Salvează fluxuri binare, fluxuri corelate (perechi, S|I=0, S|I=1)
  - Plot live pentru P(s=0|i=±) din experimentul "standard" (I real)
  - Calcul vizibilitate interferență (I real)
  - Modul de brute-force:
      * cheie binară cu lungime L, unde L ∈ [keylen_min, keylen_max]
      * cheia generează un I_k sintetic din keystream SHA-256
      * se caută cheile care maximizează interferența:
          score(key) = V_plus + V_minus
  - Live update:
      * tabelul cu Top N chei se actualizează periodic în timpul brute-force-ului
      * status bar afișează progresul (chei testate / total)
      * fereastră „Live best-key interference”:
          - arată P(s=0|I=0/1) (real) și P(s=0|I_k=0/1) (pentru cheia aleasă)
          - actualizează score, V_plus, V_minus numeric
      * slider care alege rangul din Top N (1 = cea mai bună, 2 = a doua etc.)
  - Manual key:
      * poți introduce orice cheie binară „010101...” și să vezi curbele pentru ea.

Dependencies:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"    # doar dacă vreți IBM backends
"""

import os
import math
import threading
import argparse
from collections import Counter
from typing import List, Tuple, Dict
import hashlib
import multiprocessing

import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Qiskit imports
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

# IBM runtime optional
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# --------------------------
# Quantum circuit helpers
# --------------------------
def eraser_circuit(phi: float, erase: bool = True) -> QuantumCircuit:
    """
    Build the 2-qubit delayed-choice quantum-eraser circuit.

    Qubit 0 is the signal S and qubit 1 the idler I. S gets an H, a phase
    ``phi`` via RZ, drives a CNOT onto I, and then gets a second H. When
    ``erase`` is true an extra H on I erases the which-path information.

    Each qubit is measured in the Z basis into the classical bit with the
    same index (S -> c0, I -> c1), so Qiskit count strings read "IS".

    Args:
        phi: phase in radians applied to S.
        erase: whether to apply the eraser H gate on I.

    Returns:
        A ``QuantumCircuit`` with 2 qubits and 2 classical bits.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)

    # superpose S, imprint the phase, correlate I with S, recombine S
    circuit.h(signal)
    circuit.rz(phi, signal)
    circuit.cx(signal, idler)
    circuit.h(signal)

    if erase:
        circuit.h(idler)  # erase the which-path information carried by I

    for qubit in (signal, idler):
        circuit.measure(qubit, qubit)
    return circuit

def conditional_prob_from_counts(counts: Dict[str, int], i_bit: str = '0') -> float:
    """
    Return P(s=0 | i = i_bit) estimated from a Qiskit-style counts dict.

    Keys are read as "IS": character 0 is the idler bit I and character 1 the
    signal bit S (e.g. {"00": 123, "01": 200, ...}). Keys shorter than two
    characters are ignored. Returns 0.0 when no outcome has I == i_bit.
    """
    conditioned = [
        (outcome[1], hits)
        for outcome, hits in counts.items()
        if len(outcome) >= 2 and outcome[0] == i_bit
    ]
    total = sum(hits for _, hits in conditioned)
    if not total:
        return 0.0
    zeros = sum(hits for s_bit, hits in conditioned if s_bit == '0')
    return zeros / total

def visibility(vals: List[float]) -> float:
    """
    Return the fringe visibility of a sequence of probabilities:

        V = (max - min) / (max + min)

    Returns 0.0 for an empty sequence or when max + min is zero.
    """
    if not vals:
        return 0.0
    peak = max(vals)
    trough = min(vals)
    total = peak + trough
    if not total:
        return 0.0
    return (peak - trough) / total

# --------------------------
# I/O helpers
# --------------------------
def save_text(path: str, text: str):
    """
    Write ``text`` to ``path`` as UTF-8, creating missing parent directories.

    A bare filename (no directory component) is written relative to the
    current working directory. The directory is only created when the path
    actually has one, because ``os.makedirs('')`` raises FileNotFoundError.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# --------------------------
# SHA-256 based keystream
# --------------------------
def sha256_stream(key_bin: str, length_bits: int) -> str:
    """
    Generate a deterministic pseudo-random bit string from a binary key,
    using SHA-256 in counter mode.

    Block ``c`` is ``SHA-256(key_bin as ASCII || c as 4-byte big-endian)``.
    Blocks are concatenated most-significant bit first and truncated to
    ``length_bits``.

    The key is hashed as its literal text rather than as a packed integer,
    so keys that differ only by leading zeros ("0101" vs "00101") give
    different streams. Packing the key as an integer made those keys
    collide, so a key-length range re-tested the same keystreams.

    Args:
        key_bin: binary key, a non-empty string of '0'/'1'.
        length_bits: number of bits wanted; values <= 0 give "".

    Returns:
        A string of '0'/'1' characters of length ``max(length_bits, 0)``.

    Raises:
        ValueError: if ``key_bin`` is empty or contains other characters.
    """
    if not key_bin or key_bin.strip("01"):
        raise ValueError(f"Cheie binară invalidă: {key_bin!r}")

    seed = key_bin.encode("ascii")
    blocks = []
    produced = 0
    counter = 0
    while produced < length_bits:
        digest = hashlib.sha256(seed + counter.to_bytes(4, "big")).digest()
        blocks.append("".join(f"{byte:08b}" for byte in digest))
        produced += 8 * len(digest)
        counter += 1
    return "".join(blocks)[:length_bits]

# --------------------------
# Interference brute-force
# --------------------------
def score_key_visibility(key_bin: str,
                         phis: List[float],
                         memory_by_phi: Dict[int, List[str]]) -> Tuple[float, float, float]:
    """
    Score one binary key by the interference its synthetic idler produces.

    The key's SHA-256 keystream replaces the measured idler bit I of every
    recorded shot with a synthetic bit I_k. For each phase, P(s=0 | I_k=0)
    and P(s=0 | I_k=1) are computed, and their visibilities across phases
    give V_plus and V_minus. The curve computation is delegated to
    ``curves_for_key``, so the score and the curves plotted for the same key
    always agree.

    Args:
        key_bin: key as a string of '0'/'1'.
        phis: phase values; index ``i`` corresponds to ``memory_by_phi[i]``.
        memory_by_phi: per-phase list of per-shot "IS" strings
            (I first, S second); only S is used here.

    Returns:
        (score_total, V_plus, V_minus) with score_total = V_plus + V_minus.
        Returns (0.0, 0.0, 0.0) when there are no recorded shots.
    """
    _, _, _, v_plus, v_minus = curves_for_key(key_bin, phis, memory_by_phi)
    return v_plus + v_minus, v_plus, v_minus

def curves_for_key(key_bin: str,
                   phis: List[float],
                   memory_by_phi: Dict[int, List[str]]
                   ) -> Tuple[List[float], List[float], List[float], float, float]:
    """
    Compute the synthetic interference curves for one key.

    A keystream with one bit per recorded shot is drawn from ``sha256_stream``
    and consumed phase by phase, in index order 0..len(phis)-1. For each shot
    "IS" of a phase, the measured idler bit is replaced by the next keystream
    bit I_k and the pair (I_k, S) is tallied. The tallies give
    P(s=0 | I_k=0) and P(s=0 | I_k=1) for that phase.

    A phase with no recorded shots contributes 0.0 to both curves and
    consumes no keystream bits. Shot strings shorter than two characters are
    ignored but still consume their keystream bit.

    Returns:
        (phis, p_plus_vals, p_minus_vals, V_plus, V_minus), where ``phis`` is
        the input list itself and V_plus / V_minus are the ``visibility`` of
        each curve. With no shots at all, both curves are zero lists and both
        visibilities are 0.0.
    """
    n_phases = len(phis)
    shot_total = sum(map(len, memory_by_phi.values()))
    if shot_total == 0:
        return phis, [0.0] * n_phases, [0.0] * n_phases, 0.0, 0.0

    stream = sha256_stream(key_bin, shot_total)
    cursor = 0
    plus_curve: List[float] = []
    minus_curve: List[float] = []

    for idx in range(n_phases):
        shots = memory_by_phi.get(idx, [])
        count = len(shots)
        if count == 0:
            plus_curve.append(0.0)
            minus_curve.append(0.0)
            continue

        segment = stream[cursor:cursor + count]
        cursor += count

        # (I_k, S) coincidence counts; S is the second character of "IS".
        pair_counts = Counter(
            ik + shot[1]
            for shot, ik in zip(shots, segment)
            if len(shot) >= 2
        )

        plus_curve.append(conditional_prob_from_counts(pair_counts, i_bit='0'))
        minus_curve.append(conditional_prob_from_counts(pair_counts, i_bit='1'))

    return phis, plus_curve, minus_curve, visibility(plus_curve), visibility(minus_curve)

def try_key_visibility_worker(args) -> Tuple[str, float, float, float]:
    """
    Score one key; this is the unit of work for the brute-force pool.

    ``args`` is ``(key_bin, phis, memory_by_phi)``. Returns
    ``(key_bin, score_total, V_plus, V_minus)``. If scoring raises, the key is
    reported with all-zero scores, so one bad key ranks last instead of
    aborting the whole search. A malformed ``args`` tuple still raises.
    """
    key_bin, phis, memory_by_phi = args
    try:
        total, v_plus, v_minus = score_key_visibility(key_bin, phis, memory_by_phi)
    except Exception:
        return key_bin, 0.0, 0.0, 0.0
    return key_bin, total, v_plus, v_minus

def run_bruteforce_visibility_range(phis: List[float],
                                    memory_by_phi: Dict[int, List[str]],
                                    keylen_min: int,
                                    keylen_max: int,
                                    topN: int = 20,
                                    parallel: bool = False,
                                    workers: int = None,
                                    progress_callback=None
                                    ) -> List[Tuple[str, float, float, float]]:
    """
    Score every binary key of length L, for L in [keylen_min, keylen_max],
    and keep the topN keys with the highest score_total = V_plus + V_minus.

    Args:
        phis, memory_by_phi: the last experiment, in the form taken by
            ``score_key_visibility``.
        keylen_min, keylen_max: inclusive key-length range;
            1 <= keylen_min <= keylen_max.
        topN: how many best keys to keep; must be >= 1.
        parallel: score keys in a ``multiprocessing.Pool``.
        workers: pool size; defaults to cpu_count() - 1 (at least 1).
        progress_callback: optional. Called as
            ``progress_callback(done_keys, total_keys, top_list)`` every 64
            keys, when the last key is scored, and once more at the end.
            ``top_list`` is a copy of the current list of
            (key_bin, score_total, V_plus, V_minus), sorted by score
            descending.

    Returns:
        Up to topN tuples (key_bin, score_total, V_plus, V_minus), sorted by
        score_total descending.

    Raises:
        ValueError: invalid key-length range, topN < 1, or more than
            2,000,000 keys without ``parallel``.
    """
    if keylen_min <= 0 or keylen_max < keylen_min:
        raise ValueError("Interval invalid pentru keylen_min / keylen_max")
    if topN < 1:
        # With topN <= 0 the list stays empty and the replacement branch
        # below would index top[-1], crashing with IndexError on the first key.
        raise ValueError("topN trebuie să fie >= 1")

    # 2**L keys for every length L in the range.
    total_keys = sum(1 << L for L in range(keylen_min, keylen_max + 1))

    if total_keys > 2_000_000 and not parallel:
        raise ValueError("Prea multe chei; activați parallel sau micșorați intervalul de keylen")

    def key_generator():
        # Every zero-padded L-bit string, for each length in the range.
        for L in range(keylen_min, keylen_max + 1):
            for i in range(1 << L):
                yield format(i, f"0{L}b")

    args_iter = ((k, phis, memory_by_phi) for k in key_generator())

    top: List[Tuple[str, float, float, float]] = []  # sorted by score, at most topN entries
    done = 0
    BATCH = 64  # progress_callback cadence, in keys

    def update_top(candidate: Tuple[str, float, float, float]):
        # Insert the candidate when there is room or it beats the current worst entry.
        if len(top) < topN:
            top.append(candidate)
        elif candidate[1] > top[-1][1]:
            top[-1] = candidate
        else:
            return
        top.sort(key=lambda x: x[1], reverse=True)

    def consume(results):
        # Fold scored keys into `top`, reporting progress periodically.
        nonlocal done
        for res in results:
            done += 1
            update_top(res)
            if progress_callback and (done % BATCH == 0 or done == total_keys):
                progress_callback(done, total_keys, list(top))

    if parallel:
        workers_local = workers or max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(processes=workers_local) as pool:
            consume(pool.imap_unordered(try_key_visibility_worker, args_iter, chunksize=64))
    else:
        consume(map(try_key_visibility_worker, args_iter))

    # `top` is already kept sorted; the final callback reports the finished list.
    top.sort(key=lambda x: x[1], reverse=True)
    if progress_callback:
        progress_callback(done, total_keys, list(top))
    return top

# --------------------------
# GUI Application
# --------------------------
class EraserBruteGUI(tk.Tk):
    def __init__(self):
        """
        Create the main window: the Tk variables behind every control, the
        widget panels, and empty storage for the last experiment and for the
        live best-key window.

        The `_build_*` calls pack their frames into this window. Pack order
        decides where each frame ends up (frames packed with side=BOTTOM
        stack upward in call order), so the call order below matters.
        """
        super().__init__()
        self.title("Quantum Eraser – Interference BruteForce (keylen range, live)")
        self.geometry("1200x860")

        # Core experiment parameters
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi  = tk.IntVar(value=13)
        self.var_seed  = tk.IntVar(value=1234)
        self.var_save  = tk.BooleanVar(value=True)
        # NOTE(review): var_show gets no widget in _build_controls; confirm it is used elsewhere.
        self.var_show  = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")

        # Whether to save the correlated streams (pairs, S|I=0, S|I=1, counts)
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM Quantum connection settings (used when var_use_ibm is set)
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token = tk.StringVar(value="")
        self.var_ibm_instance = tk.StringVar(value="")
        self.var_ibm_backend = tk.StringVar(value="")

        # Brute-force parameters (key-length interval)
        self.var_keylen_min = tk.IntVar(value=4)
        self.var_keylen_max = tk.IntVar(value=8)
        self.var_topn = tk.IntVar(value=20)
        self.var_parallel = tk.BooleanVar(value=False)
        # Default pool size: all CPUs but one (at least 1)
        self.var_workers = tk.IntVar(value=max(1, multiprocessing.cpu_count() - 1))

        # Manual key entry
        self.var_manual_key = tk.StringVar(value="")
        # Which TopN rank (1 = best) the live window shows; bound to a slider
        self.var_live_rank = tk.IntVar(value=1)

        # Panels; _build_plot also creates self.status (the status bar variable).
        self._build_controls()
        self._build_plot()
        self._build_bruteforce_panel()
        self._build_notation_panel()

        # Results of the last experiment run (real interference), filled by _run_experiment
        self.last_phis: List[float] = []
        self.last_p_plus: List[float] = []   # P(s=0|I=0) real
        self.last_p_minus: List[float] = []  # P(s=0|I=1) real
        # phi index -> per-shot "IS" bit strings (input for the brute force)
        self.last_memory_by_phi: Dict[int, List[str]] = {}
        self.last_Vp_real: float = 0.0
        self.last_Vm_real: float = 0.0

        # True while an experiment is running (checked by on_run)
        self._running = False

        # Live best-key plot window; widgets are created lazily by _ensure_live_window
        self.live_win = None
        self.live_fig = None
        self.live_ax = None
        self.live_canvas = None
        self.live_line_plus_real = None
        self.live_line_minus_real = None
        self.live_line_plus_syn = None
        self.live_line_minus_syn = None
        self.live_info_var = tk.StringVar(value="No key yet.")

    def _build_controls(self):
        """
        Build the top control area: experiment parameters and output directory,
        the optional IBM Quantum settings, and the action button bar.

        The three frames are packed with side=TOP in creation order, which is
        also their on-screen order.
        """
        frm = ttk.Frame(self, padding=8)
        frm.pack(side=tk.TOP, fill=tk.X)

        # Row 0: shots / nphi / seed entries and the two save checkboxes
        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=4)
        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=4)
        ttk.Label(frm, text="seed:").grid(row=0, column=4, sticky="w")
        ttk.Entry(frm, textvariable=self.var_seed, width=8).grid(row=0, column=5, padx=4)
        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=6, padx=6)
        ttk.Checkbutton(frm, text="save correlated", variable=self.var_save_corr).grid(row=0, column=7, padx=6)

        # Row 1: output directory entry plus a folder picker
        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=6)
        ttk.Entry(frm, textvariable=self.var_outdir, width=36).grid(row=1, column=1, columnspan=3, sticky="we", padx=4)
        ttk.Button(frm, text="Browse…", command=self._choose_outdir).grid(row=1, column=4, sticky="w", padx=4)

        # IBM Quantum settings; _run_experiment only builds an IBM backend when the checkbox is set
        ibm = ttk.LabelFrame(self, text="IBM Quantum (optional)", padding=6)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=8, pady=6)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        # The token entry masks its characters (show=...)
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=44, show="•").grid(row=0, column=2, padx=4, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=20).grid(row=0, column=4, padx=4)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=4)

        # Action buttons; on_run disables btn_run while an experiment is running
        btns = ttk.Frame(self, padding=6)
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run Experiment", command=self.on_run)
        self.btn_run.pack(side=tk.LEFT, padx=4)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect", command=self.on_correlate)
        self.btn_correlate.pack(side=tk.LEFT, padx=4)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview)
        self.btn_preview.pack(side=tk.LEFT, padx=4)

    def _build_plot(self):
        """
        Create the main matplotlib figure embedded in the window, and the
        status bar.

        The figure shows the real P(s=0|I=0) and P(s=0|I=1) against φ. Its two
        lines (`line_plus`, `line_minus`) start empty and are filled by
        `_run_experiment`. `self.status` is created here and is used by the
        rest of the GUI for progress and error messages.
        """
        self.fig = plt.Figure(figsize=(8.5, 5))
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|I=0/1) (real)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,) = self.ax.plot([], [], marker='o', label="P(s=0|I=0) real")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|I=1) real")
        self.ax.legend()
        self.ax.grid(True, linestyle='--', alpha=0.4)
        # expand=True: the plot canvas grows with the window
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # Status bar, packed at the bottom edge of the window
        self.status = tk.StringVar(value="Ready")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=6).pack(side=tk.BOTTOM, fill=tk.X)

    def _build_bruteforce_panel(self):
        """
        Build the brute-force panel.

        The panel contains: the key-length range, topN, parallel/workers and
        the run button; the Top-N results table; the manual-key entry; the
        slider that picks which rank the live window shows; and the report /
        outdir buttons.
        """
        panel = ttk.LabelFrame(self, text="Brute-Force Interference (on last experiment)", padding=6)
        panel.pack(side=tk.BOTTOM, fill=tk.X, padx=8, pady=6)

        # Row 0: key-length range and topN
        ttk.Label(panel, text="keylen min:").grid(row=0, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_keylen_min, width=6).grid(row=0, column=1, padx=4)
        ttk.Label(panel, text="keylen max:").grid(row=0, column=2, sticky="w")
        ttk.Entry(panel, textvariable=self.var_keylen_max, width=6).grid(row=0, column=3, padx=4)

        ttk.Label(panel, text="topN:").grid(row=0, column=4, sticky="w")
        ttk.Entry(panel, textvariable=self.var_topn, width=6).grid(row=0, column=5, padx=4)

        # Row 1: parallel execution settings and the run button
        ttk.Checkbutton(panel, text="parallel", variable=self.var_parallel).grid(row=1, column=0, padx=4, pady=4)
        ttk.Label(panel, text="workers:").grid(row=1, column=1, sticky="e")
        ttk.Entry(panel, textvariable=self.var_workers, width=6).grid(row=1, column=2, padx=4)
        ttk.Button(panel, text="Run Brute-Force", command=self.on_run_bruteforce).grid(row=1, column=3, padx=6)

        # Results table (row 2, the only row given extra weight):
        cols = ("key", "score", "V_plus", "V_minus")
        self.tree = ttk.Treeview(panel, columns=cols, show="headings", height=8)
        self.tree.heading("key", text="key (L:bits)")
        self.tree.heading("score", text="score (V+ + V-)")
        self.tree.heading("V_plus", text="V_plus")
        self.tree.heading("V_minus", text="V_minus")

        self.tree.column("key", width=220, anchor="w")
        self.tree.column("score", width=140, anchor="w")
        self.tree.column("V_plus", width=140, anchor="w")
        self.tree.column("V_minus", width=140, anchor="w")

        self.tree.grid(row=2, column=0, columnspan=6, sticky="nsew", pady=6)
        panel.grid_rowconfigure(2, weight=1)

        # Double-clicking a row shows the curves for that key
        self.tree.bind("<Double-1>", self._on_tree_double_click)

        # Manual key entry (row 3)
        ttk.Label(panel, text="Manual key bits:").grid(row=3, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_manual_key, width=32).grid(row=3, column=1, columnspan=3, sticky="we", padx=4)
        ttk.Button(panel, text="Show curves (manual key)", command=self.on_manual_key).grid(row=3, column=4, padx=4)

        # Slider choosing which TopN rank the live window shows (row 4).
        # NOTE(review): the upper bound is read from var_topn once, here; confirm it
        # is updated elsewhere if the user changes topN later.
        ttk.Label(panel, text="Live rank (1 = best):").grid(row=4, column=0, sticky="w", pady=4)
        self.scale_rank = tk.Scale(panel, from_=1, to=max(1, self.var_topn.get()),
                                   orient=tk.HORIZONTAL, variable=self.var_live_rank, length=200)
        self.scale_rank.grid(row=4, column=1, columnspan=2, sticky="w", padx=4)

        ttk.Button(panel, text="Save Report", command=self._save_report).grid(row=4, column=4, padx=4, pady=6)
        ttk.Button(panel, text="Open Outdir", command=self._open_outdir).grid(row=4, column=5, padx=4, pady=6)

    def _build_notation_panel(self):
        """
        Add a small panel that explains the notation used in the GUI
        (φ, the conditional probabilities, V_plus, V_minus, score).
        """
        frame = ttk.LabelFrame(self, text="Notations (explicații scurte)", padding=6)
        frame.pack(side=tk.BOTTOM, fill=tk.X, padx=8, pady=4)

        # User-facing explanation text (Romanian), shown verbatim
        text = (
            "φ (phi): faza introdusă în qubitul S.\n"
            "P(s=0|I=0/1): probabilitatea ca S=0 când I (idler) este 0 sau 1 (experiment real).\n"
            "P(s=0|I_k=0/1): probabilitatea ca S=0 când bitul sintetic I_k este 0 sau 1 (cheie bruteforce).\n"
            "V_plus: vizibilitatea interferenței pentru condiția I_k=0 (sau I=0).\n"
            "V_minus: vizibilitatea pentru condiția I_k=1 (sau I=1).\n"
            "score = V_plus + V_minus: cât de puternică este interferența totală pentru o cheie."
        )
        lbl = ttk.Label(frame, text=text, justify="left")
        lbl.pack(anchor="w")

    # ---------------- Live best-key window ----------------
    def _ensure_live_window(self):
        """
        Create the live-interference window unless it already exists.

        The window holds one matplotlib axes with four initially empty lines:
        the real P(s=0|I=0/1) curves (dashed) and the synthetic
        P(s=0|I_k=0/1) curves for the selected key. Below the plot is an info
        label bound to `live_info_var`. `_update_live_plot` fills the lines.
        """
        # Reuse the window while it is alive. winfo_exists is false once the
        # window has been destroyed (e.g. closed by the user), and then it is rebuilt.
        if self.live_win is not None and tk.Toplevel.winfo_exists(self.live_win):
            return

        self.live_win = tk.Toplevel(self)
        self.live_win.title("Live interference: real I vs synthetic I_k (TopN key)")
        self.live_win.geometry("860x620")

        self.live_fig = plt.Figure(figsize=(7.5, 4.5))
        self.live_ax = self.live_fig.add_subplot(111)
        self.live_ax.set_title("Real vs synthetic interference")
        self.live_ax.set_xlabel("φ (radiani)")
        self.live_ax.set_ylabel("Probabilitate")

        # Real curves (measured idler I); plot order also fixes the default line colours
        (self.live_line_plus_real,) = self.live_ax.plot([], [], marker='o', linestyle='--', label="P(s=0 | I=0) real")
        (self.live_line_minus_real,) = self.live_ax.plot([], [], marker='s', linestyle='--', label="P(s=0 | I=1) real")

        # Synthetic curves (I_k for the selected key)
        (self.live_line_plus_syn,) = self.live_ax.plot([], [], marker='^', label="P(s=0 | I_k=0) key")
        (self.live_line_minus_syn,) = self.live_ax.plot([], [], marker='v', label="P(s=0 | I_k=1) key")

        self.live_ax.legend()
        self.live_ax.grid(True, linestyle='--', alpha=0.4)

        self.live_canvas = FigureCanvasTkAgg(self.live_fig, master=self.live_win)
        self.live_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        self.live_canvas.draw()

        # Info line: key, its scores and the real visibilities
        frm = ttk.Frame(self.live_win, padding=8)
        frm.pack(side=tk.BOTTOM, fill=tk.X)
        ttk.Label(frm, textvariable=self.live_info_var, anchor="w").pack(side=tk.LEFT)

    def _update_live_plot(self, key_bin: str, score_total: float, V_plus: float, V_minus: float):
        """
        Redraw the live window for one key.

        The window shows the real P(s=0|I=0/1) curves of the last experiment
        next to the synthetic P(s=0|I_k=0/1) curves produced by ``key_bin``.
        Runs on the GUI thread (callers schedule it via ``self.after``). Does
        nothing until an experiment has recorded phases and shots.

        Args:
            key_bin: binary key whose synthetic curves are drawn.
            score_total, V_plus, V_minus: shown as given in the info line;
                they are not recomputed here.
        """
        if not self.last_phis or not self.last_memory_by_phi:
            return

        self._ensure_live_window()

        # Snapshot the experiment data. An experiment may still be running
        # on the worker thread and appending to these containers.
        phis_real = list(self.last_phis)
        p_plus_real = list(self.last_p_plus)
        p_minus_real = list(self.last_p_minus)
        memory = dict(self.last_memory_by_phi)

        phis, p_plus_vals, p_minus_vals, _, _ = curves_for_key(key_bin, phis_real, memory)

        # Real curves. Mid-run, the probability lists are shorter than the
        # phase list, so trim x to the same length; matplotlib refuses to
        # draw a line whose x and y lengths differ.
        self.live_line_plus_real.set_data(phis_real[:len(p_plus_real)], p_plus_real)
        self.live_line_minus_real.set_data(phis_real[:len(p_minus_real)], p_minus_real)

        # Synthetic curves for this key
        self.live_line_plus_syn.set_data(phis, p_plus_vals)
        self.live_line_minus_syn.set_data(phis, p_minus_vals)

        self.live_ax.relim()
        self.live_ax.autoscale_view()
        self.live_canvas.draw()

        self.live_info_var.set(
            f"key={key_bin}  score={score_total:.4f}  "
            f"V_plus={V_plus:.4f}  V_minus={V_minus:.4f}  |  "
            f"V_real(+)= {self.last_Vp_real:.4f}  V_real(-)= {self.last_Vm_real:.4f}"
        )

    # ---------------- UI actions ----------------
    def _choose_outdir(self):
        """Ask for an output directory and store it; a cancelled dialog leaves the current value unchanged."""
        chosen = filedialog.askdirectory(title="Select outdir")
        if not chosen:
            return
        self.var_outdir.set(chosen)

    def _open_outdir(self):
        """
        Open the output directory in the platform's file manager.

        Uses ``os.startfile`` on Windows, ``open`` on macOS and ``xdg-open`` on
        other POSIX systems. The path is passed as an argument list without a
        shell, so quotes or ``$(...)`` in the directory name are not
        interpreted. The directory path is shown in an info dialog instead if
        the opener is missing or exits with an error. Previously, on macOS,
        ``xdg-open`` was tried and its failure went unnoticed.
        """
        import subprocess
        import sys

        out = self.var_outdir.get()
        try:
            if os.name == 'nt':
                os.startfile(out)
            elif sys.platform == 'darwin':
                subprocess.run(["open", out], check=True)
            elif os.name == 'posix':
                subprocess.run(["xdg-open", out], check=True)
        except (OSError, subprocess.CalledProcessError):
            messagebox.showinfo("Open", f"Outdir: {out}")

    def on_run(self):
        """
        Handle 'Run Experiment': validate the inputs, then run
        `_run_experiment_thread` on a daemon thread so the GUI stays responsive.

        Refuses to start while a run is in progress, when IBM is requested but
        qiskit-ibm-runtime is not installed, or when shots <= 0, nphi <= 1, or
        any of shots/nphi/seed is not an integer.
        """
        if getattr(self, "_running", False):
            messagebox.showinfo("Info", "Already running")
            return
        if self.var_use_ibm.get() and not IBM_RUNTIME_AVAILABLE:
            messagebox.showerror("Error", "qiskit-ibm-runtime not installed")
            return

        try:
            shots = int(self.var_shots.get())
            nphi = int(self.var_nphi.get())
            int(self.var_seed.get())  # the seed only has to be a valid integer here
            inputs_ok = shots > 0 and nphi > 1
        except Exception:
            inputs_ok = False
        if not inputs_ok:
            messagebox.showerror("Error", "Invalid shots/nphi/seed")
            return

        self._running = True
        self.btn_run.config(state=tk.DISABLED)
        self.status.set("Running experiment...")
        worker = threading.Thread(target=self._run_experiment_thread, daemon=True)
        worker.start()

    def _build_aer_backend(self):
        """Return a local AerSimulator whose sampling seed comes from the 'seed' field."""
        simulator_seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=simulator_seed)

    def _build_ibm_backend(self):
        """
        Connect to IBM Quantum with the token and instance entered in the GUI,
        and return the named backend.

        The service and backend are also kept on ``self.service`` /
        ``self.backend``.

        Raises:
            ValueError: if the Backend field is empty. This is checked before
                any network call, so the user gets a clear message instead of
                a backend-lookup error after authenticating.
            Whatever ``QiskitRuntimeService`` / ``service.backend`` raise for
                bad credentials or unknown backend names.
        """
        backend_name = self.var_ibm_backend.get().strip()
        if not backend_name:
            raise ValueError("IBM backend name is empty")

        token = self.var_ibm_token.get().strip()
        # An empty Instance field means "no instance". None is also the
        # constructor's default, so one call covers both cases.
        instance = self.var_ibm_instance.get().strip() or None
        # NOTE(review): IBM has been retiring the "ibm_quantum" channel in favour of
        # "ibm_quantum_platform" / "ibm_cloud" — confirm against the installed
        # qiskit-ibm-runtime version.
        service = QiskitRuntimeService(channel="ibm_quantum", token=token, instance=instance)
        backend = service.backend(backend_name)
        self.service = service
        self.backend = backend
        return backend

    def _run_experiment_thread(self):
        """
        Worker-thread entry point started by `on_run`.

        Runs the experiment, reports any error, and always clears the running
        flag and re-enables the Run button. Tk widgets and dialogs must only
        be used from the Tk main thread, so every GUI action here is scheduled
        with ``self.after(0, ...)`` rather than performed on this thread.
        """
        try:
            self._run_experiment()
        except Exception as e:
            msg = str(e)
            self.after(0, self.status.set, f"Error: {msg}")
            self.after(0, messagebox.showerror, "Error", msg)
        finally:
            self._running = False
            self.after(0, lambda: self.btn_run.config(state=tk.NORMAL))

    def _run_experiment(self):
        """
        Run the quantum-eraser phase scan and record its results.

        Each of ``nphi`` phases is spaced evenly over [0, 2π], both ends
        included. For each phase it runs ``eraser_circuit(phi, erase=True)``
        with ``shots`` shots on Aer or the configured IBM backend, then:
          - stores the per-shot "IS" strings in ``self.last_memory_by_phi``
            (the input for the brute-force search);
          - appends P(s=0|I=0) and P(s=0|I=1) to ``self.last_p_plus`` and
            ``self.last_p_minus`` and redraws the main plot;
          - if 'save' is set, writes the S stream; if 'save correlated' is
            set, writes the raw pairs, S split by I, and the coincidence
            counts.
        At the end, if 'save' is set, it writes ``eraser_probs.csv`` and
        ``eraser_plot.png``. It then stores the real visibilities.

        This runs on the worker thread started by `on_run`. Tk widgets and
        the TkAgg canvas may only be touched from the Tk main thread, so plot,
        status and PNG-export work is scheduled with ``self.after(0, ...)``
        using snapshots of the data. An error in the PNG export therefore
        surfaces through Tk's callback error handler rather than
        `_run_experiment_thread`.
        """
        import csv

        shots = int(self.var_shots.get())
        nphi = int(self.var_nphi.get())
        save = bool(self.var_save.get())
        save_corr = bool(self.var_save_corr.get())
        outdir = self.var_outdir.get()
        seed = int(self.var_seed.get())
        use_ibm = bool(self.var_use_ibm.get())

        os.makedirs(outdir, exist_ok=True)
        phis = list(np.linspace(0.0, 2.0 * math.pi, nphi))

        if use_ibm:
            backend = self._build_ibm_backend()
        else:
            backend = self._build_aer_backend()

        def redraw(xs, ys_plus, ys_minus, png_path=None):
            # GUI-thread callback: update both real curves, rescale, repaint,
            # and optionally export the figure after the final data is in.
            self.line_plus.set_data(xs, ys_plus)
            self.line_minus.set_data(xs, ys_minus)
            self.ax.relim()
            self.ax.autoscale_view()
            self.canvas.draw_idle()
            if png_path:
                self.fig.savefig(png_path, dpi=160)

        def set_status(text):
            self.after(0, self.status.set, text)

        # Reset the plot
        self.after(0, redraw, [], [], [])

        p_plus: List[float] = []
        p_minus: List[float] = []
        self.last_phis = phis
        self.last_p_plus = p_plus
        self.last_p_minus = p_minus
        self.last_memory_by_phi.clear()

        for idx, phi in enumerate(phis):
            qc = eraser_circuit(phi, erase=True)
            if use_ibm:
                # NOTE(review): recent qiskit-ibm-runtime releases replaced
                # IBMBackend.run with the SamplerV2 primitive — confirm that
                # backend.run is still available in the installed version.
                tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
                result = backend.run(tqc, shots=shots, memory=True).result()
                counts = result.get_counts()
                try:
                    memory = result.get_memory()
                except Exception:
                    # No per-shot memory: rebuild it from the counts (shot order is lost)
                    memory = [bitstr for bitstr, c in counts.items() for _ in range(int(c))]
            else:
                tqc = transpile(qc, backend=backend, optimization_level=3)
                result = backend.run(tqc, shots=shots, memory=True).result()
                counts = result.get_counts()
                memory = result.get_memory()

            # Also keep the raw shots for the brute force
            self.last_memory_by_phi[idx] = memory

            p0_plus = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus)
            p_minus.append(p0_minus)

            # Copies: the lists keep growing on this thread after scheduling
            self.after(0, redraw, phis[:idx + 1], list(p_plus), list(p_minus))

            if save:
                # S stream only (character 1 of each "IS" shot string)
                s_stream = ''.join(m[1] for m in memory if len(m) >= 2)
                save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)
            if save_corr:
                # Raw "IS" pairs, S split by the idler outcome, and coincidence counts
                save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
                s_i0 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '0')
                s_i1 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '1')
                save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_i0)
                save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_i1)
                with open(os.path.join(outdir, f"counts_phi_{idx:02d}.csv"), "w", newline='') as f:
                    w = csv.writer(f)
                    w.writerow(['pair', 'count'])
                    for pair in ['00', '01', '10', '11']:
                        w.writerow([pair, counts.get(pair, 0)])

            set_status(f"φ={phi:.3f} • P(I=0)= {p0_plus:.3f} • P(I=1)= {p0_minus:.3f} • {idx + 1}/{len(phis)}")

        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline='') as f:
                w = csv.writer(f)
                w.writerow(['phi', 'P(s=0|I=0)', 'P(s=0|I=1)'])
                for phi, a, b in zip(phis, p_plus, p_minus):
                    w.writerow([phi, a, b])
            # Export on the GUI thread, in the same callback as a final
            # redraw, so the image contains every point.
            self.after(0, redraw, list(phis), list(p_plus), list(p_minus),
                       os.path.join(outdir, "eraser_plot.png"))

        Vp = visibility(p_plus)
        Vm = visibility(p_minus)
        self.last_Vp_real = Vp
        self.last_Vm_real = Vm
        set_status(f"Done. V_real(+)= {Vp:.3f}, V_real(-)= {Vm:.3f}  (IBM: {use_ibm})")

    def on_correlate(self):
        """
        Annotate the main plot with the interference visibilities of the last run.

        Computes V(+) and V(-) with ``visibility`` from ``self.last_p_plus`` and
        ``self.last_p_minus``. The run is labelled "interference detected" when
        both exceed 0.25. The annotation box replaces the one drawn by a previous
        click instead of being stacked on top of it. The summary also goes to the
        status bar. Shows an info box if no experiment has been run yet.
        """
        if not self.last_p_plus:
            messagebox.showinfo("Info", "Run experiment first")
            return
        Vp = visibility(self.last_p_plus)
        Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"

        # Drop the annotation from a previous click so repeated clicks do not pile
        # up overlapping text boxes. The artist may already be detached (e.g. if
        # the axes were cleared elsewhere), in which case remove() raises.
        previous = getattr(self, "_correlate_annotation", None)
        if previous is not None:
            try:
                previous.remove()
            except (ValueError, NotImplementedError):
                pass
        self._correlate_annotation = self.ax.text(
            0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
            transform=self.ax.transAxes, fontsize=10,
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    def on_preview(self):
        """
        Open a read-only window summarising the correlations of the last phi point.

        Takes the highest phi index stored in ``self.last_memory_by_phi`` and shows:
          * the first 200 per-shot entries (one per line),
          * coincidence counts for '00', '01', '10' and '11',
          * the first 400 S bits selected by I=0 and by I=1. Character 0 of each
            entry is treated as I and character 1 as S.
        Shows an info box and returns if no experiment has been run yet.
        """
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info", "Run first")
            return

        last_idx = max(self.last_memory_by_phi.keys())
        shots = self.last_memory_by_phi[last_idx]
        coincidences = Counter(shots)

        def s_bits_given(i_bit):
            # S bits of entries whose I bit equals i_bit, truncated to 400 chars
            selected = ''.join(shot[1] for shot in shots if len(shot) >= 2 and shot[0] == i_bit)
            return selected[:400]

        first_pairs = '\n'.join(shots[:200])
        s_given_i0 = s_bits_given('0')
        s_given_i1 = s_bits_given('1')

        popup = tk.Toplevel(self)
        popup.title(f"Preview correlations (phi idx {last_idx})")
        popup.geometry("640x720")

        def add_caption(caption):
            ttk.Label(popup, text=caption).pack(anchor="w", padx=8, pady=(8, 0))

        def add_readonly_text(rows, content, fill):
            box = tk.Text(popup, height=rows)
            box.pack(fill=fill, padx=8)
            box.insert("1.0", content)
            box.config(state=tk.DISABLED)

        add_caption(f"First 200 pairs (I S) for phi index {last_idx}:")
        add_readonly_text(12, first_pairs, tk.BOTH)

        add_caption("Counts (coincidences):")
        counts_view = ttk.Treeview(popup, columns=("pair", "count"), show="headings", height=4)
        for column in ("pair", "count"):
            counts_view.heading(column, text=column)
        for outcome in ('00', '01', '10', '11'):
            counts_view.insert("", "end", values=(outcome, coincidences.get(outcome, 0)))
        counts_view.pack(fill=tk.X, padx=8, pady=4)

        add_caption("S | I=0 (first 400 bits):")
        add_readonly_text(4, s_given_i0, tk.X)
        add_caption("S | I=1 (first 400 bits):")
        add_readonly_text(4, s_given_i1, tk.X)

    def on_manual_key(self):
        """
        Show the interference curves for the key typed into ``var_manual_key``.

        Whitespace of any kind (spaces, tabs, line breaks) is ignored, so pasted
        keys work. A label of the form "L:bits", as displayed in the brute-force
        Top-N table, is also accepted. When the part before the first ':' is a
        decimal number, only the bits after it are used. Error dialogs are shown
        for empty input, characters other than 0/1, or a missing experiment.
        """
        raw = ''.join(self.var_manual_key.get().split())
        # Accept the "L:bits" label format used in the Top-N table.
        prefix, sep, bits = raw.partition(":")
        key = bits if (sep and prefix.isdecimal()) else raw
        if not key:
            messagebox.showerror("Error", "Introduceți o cheie binară (ex: 010101).")
            return
        if any(ch not in "01" for ch in key):
            messagebox.showerror("Error", "Cheia trebuie să conțină doar 0 și 1.")
            return
        if not self.last_phis or not self.last_memory_by_phi:
            messagebox.showerror("Error", "Trebuie mai întâi să rulați experimentul.")
            return
        self._show_key_curves(key)

    # ---------------- Brute-force (interferență, keylen range) ----------------
    def on_run_bruteforce(self):
        """
        Start the interference brute-force on the last experiment.

        Searches keys of length L in [keylen_min, keylen_max] for those that
        maximise V_plus + V_minus. The parameters are read from the GUI
        variables and validated here. The search runs in a daemon thread
        (``_run_bruteforce_thread``) so the Tk main loop stays responsive.

        Only one search runs at a time. While one is in progress, further clicks
        are refused with an info box. Otherwise a second thread would race the
        first on the Top-N table, the live plot and the report file.
        """
        if not self.last_memory_by_phi or not self.last_phis:
            messagebox.showerror("Error", "Mai întâi rulați experimentul (Run Experiment).")
            return
        if getattr(self, "_bruteforce_running", False):
            messagebox.showinfo("Info", "Un brute-force rulează deja. Așteptați să se termine.")
            return

        try:
            keylen_min = int(self.var_keylen_min.get())
            keylen_max = int(self.var_keylen_max.get())
            topn = int(self.var_topn.get())
            parallel = bool(self.var_parallel.get())
            workers = int(self.var_workers.get())
            if keylen_min <= 0 or keylen_max < keylen_min or topn <= 0:
                raise ValueError
        except (ValueError, TypeError, tk.TclError):
            # int() on bad text raises ValueError/TypeError. A Tk variable holding
            # unparsable text raises TclError (or ValueError for BooleanVar).
            messagebox.showerror("Error", "Parametri brute-force invalizi (keylen_min/max/topN/workers).")
            return

        # Rescale the rank slider to the new Top-N size.
        self.scale_rank.configure(to=max(1, topn))
        self.var_live_rank.set(1)

        # Clear the results table before the new run starts filling it.
        self.tree.delete(*self.tree.get_children())
        self.status.set("Running interference brute-force...")

        def run():
            try:
                self._run_bruteforce_thread(keylen_min, keylen_max, topn, parallel, workers)
            finally:
                # Release the guard however the search ends, so a new run can start.
                self._bruteforce_running = False

        self._bruteforce_running = True
        try:
            threading.Thread(target=run, daemon=True).start()
        except RuntimeError:
            # The thread never started, so do not leave the guard stuck on.
            self._bruteforce_running = False
            raise

    def _run_bruteforce_thread(self, keylen_min, keylen_max, topn, parallel, workers):
        """
        Worker-thread body of the interference brute-force.

        Calls ``run_bruteforce_visibility_range`` on the data of the last
        experiment (``self.last_phis`` / ``self.last_memory_by_phi``). Progress is
        streamed into the status bar, the Top-N table and the live plot. At the
        end, ``interference_bruteforce_report.txt`` is written into the directory
        taken from ``self.var_outdir``.

        Widget updates are scheduled onto the Tk loop with ``self.after(0, ...)``.
        Any exception is reported in the status bar and in an error dialog.
        """
        def show_top(top_list):
            # Refill the Top-N table and redraw the live plot for the selected rank.
            self.tree.delete(*self.tree.get_children())
            for (k, score_total, V_plus, V_minus) in top_list:
                key_label = f"{len(k)}:{k}"
                self.tree.insert(
                    "",
                    "end",
                    values=(key_label, f"{score_total:.3f}", f"{V_plus:.3f}", f"{V_minus:.3f}")
                )
            if top_list:
                # Clamp the slider rank into [1, len(top_list)].
                rank = min(max(self.var_live_rank.get(), 1), len(top_list))
                best_k, best_score, best_Vp, best_Vm = top_list[rank - 1]
                self._update_live_plot(best_k, best_score, best_Vp, best_Vm)

        try:
            phis = self.last_phis
            memory_by_phi = self.last_memory_by_phi

            def progress_cb(done, total, top_list):
                # Invoked by the brute-force search, not on the Tk loop. Copy the list
                # now so the deferred GUI update cannot observe later in-place
                # changes the search might make to the object it passed in.
                snapshot = list(top_list) if top_list is not None else []

                def update_gui():
                    try:
                        self.status.set(
                            f"Brute-force: {done}/{total} keys tested "
                            f"({done * 100.0 / max(1, total):.1f}%)"
                        )
                        show_top(snapshot)
                    except Exception:
                        # Best effort: widgets may already be gone (window closed).
                        pass

                self.after(0, update_gui)

            top = run_bruteforce_visibility_range(
                phis,
                memory_by_phi,
                keylen_min=keylen_min,
                keylen_max=keylen_max,
                topN=topn,
                parallel=parallel,
                workers=workers,
                progress_callback=progress_cb
            )
            final_top = list(top)

            def show_final():
                try:
                    show_top(final_top)
                except Exception:
                    pass  # best effort, as for progress updates

            # Display the returned result as well, in case no progress callback
            # followed the last tested key.
            self.after(0, show_final)

            # Finally, write the report.
            report_path = os.path.join(self.var_outdir.get(), "interference_bruteforce_report.txt")
            lines = [
                "Interference brute-force report\n",
                f"keylen_min={keylen_min} keylen_max={keylen_max} "
                f"topN={topn} parallel={parallel} workers={workers}\n\n",
                "Top results (key, score_total, V_plus, V_minus):\n",
            ]
            lines.extend(
                f"{len(k)}:{k}  score={score_total:.3f}  "
                f"V_plus={V_plus:.3f}  V_minus={V_minus:.3f}\n"
                for (k, score_total, V_plus, V_minus) in final_top
            )
            save_text(report_path, ''.join(lines))

            self.after(0, lambda: self.status.set(f"Interference brute-force finished. Report: {report_path}"))
        except Exception as e:
            # Python unbinds `e` when the except clause ends. The callbacks below
            # run later on the Tk loop, so they must capture the message text
            # rather than the exception name, or they would hit a NameError.
            msg = str(e)
            self.after(0, lambda: self.status.set(f"Error: {msg}"))
            self.after(0, lambda: messagebox.showerror("Error", msg))

    def _on_tree_double_click(self, event):
        """
        Open the interference curves for the brute-force row under the cursor.

        The first column holds a label of the form "L:bits", and only the bits
        after the first ':' are used. A label without ':' is used as-is. Does
        nothing if the click is not on a row with values. Shows an info box if
        no experiment data is available. Otherwise the curves open in a new
        window via ``_show_key_curves``.
        """
        item_id = self.tree.identify_row(event.y)
        if not item_id:
            return
        values = self.tree.item(item_id).get("values", [])
        if not values:
            return
        # Treeview.item() can hand back non-str cell values (tkinter turns
        # integer-looking cells into int), so normalise before parsing.
        key_label = str(values[0])  # expected form: "L:bits"
        _, sep, key_bits = key_label.partition(":")
        if not sep:
            key_bits = key_label  # fallback: label without the "L:" prefix

        if not self.last_phis or not self.last_memory_by_phi:
            messagebox.showinfo("Info", "Nu există date de experiment pentru această cheie.")
            return

        self._show_key_curves(key_bits)

    def _show_key_curves(self, key_bin: str):
        """
        Plot P(s=0 | I_k=0) and P(s=0 | I_k=1) versus phi for one key in a new window.

        The curves and the visibilities V_plus / V_minus come from
        ``curves_for_key`` applied to the last experiment (``self.last_phis``,
        ``self.last_memory_by_phi``). If that computation (or unpacking its
        result) fails, an error dialog is shown and no window is opened.
        The visibilities appear both in the plot title and in a footer label.
        """
        try:
            phi_grid, curve_i0, curve_i1, vis_plus, vis_minus = curves_for_key(
                key_bin, self.last_phis, self.last_memory_by_phi
            )
        except Exception as exc:
            messagebox.showerror("Error", f"Nu pot calcula curbele pentru cheie {key_bin}: {exc}")
            return

        top = tk.Toplevel(self)
        top.title(f"Interference curves for key {key_bin}")
        top.geometry("820x600")

        figure = plt.Figure(figsize=(7.5, 4.5))
        axes = figure.add_subplot(111)
        title = (
            f"P(s=0 | I_k=0/1) pentru cheia {key_bin}\n"
            f"V_plus={vis_plus:.3f}, V_minus={vis_minus:.3f}"
        )
        axes.set_title(title)
        axes.set_xlabel("φ (radiani)")
        axes.set_ylabel("Probabilitate")

        series = (
            (curve_i0, 'o', "P(s=0 | I_k=0)"),
            (curve_i1, 's', "P(s=0 | I_k=1)"),
        )
        for values, marker_style, series_label in series:
            axes.plot(phi_grid, values, marker=marker_style, label=series_label)
        axes.legend()
        axes.grid(True, linestyle='--', alpha=0.4)

        plot_canvas = FigureCanvasTkAgg(figure, master=top)
        plot_canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        plot_canvas.draw()

        footer = ttk.Frame(top, padding=8)
        footer.pack(side=tk.BOTTOM, fill=tk.X)
        summary = f"V_plus = {vis_plus:.4f}    V_minus = {vis_minus:.4f}"
        ttk.Label(footer, text=summary, anchor="w").pack(side=tk.LEFT)

    def _save_report(self):
        """
        Save the rows currently shown in the brute-force table to a text file.

        Prompts for the destination with a save dialog; cancelling does nothing.
        Writes a header line plus one "key=... score=... V_plus=... V_minus=..."
        line per table row. If writing fails with an OSError, an error dialog
        is shown instead of the "Saved" confirmation. Without this, the
        exception would escape the Tk callback unnoticed.
        """
        out = filedialog.asksaveasfilename(
            title="Save report as",
            defaultextension=".txt",
            filetypes=[("Text files", "*.txt")]
        )
        if not out:
            return
        lines = ["Interference brute-force candidates:\n"]
        for item in self.tree.get_children():
            # row values: [key_label, score, V_plus, V_minus]
            r = self.tree.item(item)["values"]
            lines.append(f"key={r[0]}  score={r[1]}  V_plus={r[2]}  V_minus={r[3]}\n")
        try:
            save_text(out, ''.join(lines))
        except OSError as exc:
            # presumably save_text writes the file directly, so I/O errors surface here
            messagebox.showerror("Error", f"Could not save report to {out}: {exc}")
            return
        messagebox.showinfo("Saved", f"Report saved to {out}")

# --------------------------
# Main
# --------------------------
def main():
    """
    Parse the command line and run the GUI's Tk event loop.

    ``--no-gui`` is accepted but currently ignored; the GUI is always started.
    """
    cli = argparse.ArgumentParser(
        description="Quantum Eraser GUI + Interference Brute-Force (keylen range, live)"
    )
    cli.add_argument("--no-gui", action="store_true", help="Run non-interactive (not used)")
    # The parsed namespace is discarded. Parsing still validates argv and serves --help.
    cli.parse_args()

    EraserBruteGUI().mainloop()


if __name__ == "__main__":
    main()
