#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser GUI + Correlated Bits + Brute-Force (Interference)
Author: ChatGPT (adaptat pentru utilizator)

Features:
  - Rulează delayed-choice quantum eraser pe Aer sau IBM backends
  - Salvează fluxuri binare, fluxuri corelate (perechi, S|I=0, S|I=1)
  - Plot live pentru P(s=0|i=±) din experimentul "standard"
  - Calcul vizibilitate interferență
  - Modul de brute-force: caută chei binare (0/1) care generează un I_k sintetic
    din keystream SHA-256 și maximizează interferența:
      score(key) = V_plus + V_minus
  - Suport opțional pentru multiprocessing în brute-force

Dependencies:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"    # doar dacă vreți IBM backends

Usage:
  python quantum_eraser_gui_ibm_corr_bruteforce_parallel_interf.py
"""

import argparse
import csv
import hashlib
import heapq
import math
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
import tkinter as tk
from collections import Counter
from tkinter import ttk, messagebox, filedialog
from typing import Dict, List, Optional, Tuple

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Qiskit imports
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
# IBM runtime optional
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# --------------------------
# Quantum circuit helpers
# --------------------------
def eraser_circuit(phi: float, erase: bool = True) -> QuantumCircuit:
    """
    Build the two-qubit quantum-eraser circuit.

    Qubit 0 is the signal S and qubit 1 is the idler I. The signal gets a
    Hadamard, an Rz(phi) phase and a CNOT onto the idler, followed by a second
    Hadamard. When ``erase`` is true, a Hadamard on the idler erases the
    which-path information. Both qubits are then measured in the Z basis.

    Args:
        phi: Rz rotation angle (radians) applied to the signal qubit.
        erase: Whether to apply the eraser Hadamard to the idler.

    Returns:
        A ``QuantumCircuit`` with 2 qubits and 2 classical bits, where qubit k
        is measured into classical bit k.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)
    circuit.h(signal)
    circuit.rz(phi, signal)
    circuit.cx(signal, idler)
    circuit.h(signal)
    if erase:
        circuit.h(idler)
    # Measure S first, then I, each into the classical bit with the same index.
    for qubit in (signal, idler):
        circuit.measure(qubit, qubit)
    return circuit

def conditional_prob_from_counts(counts: Dict[str, int], i_bit: str = '0') -> float:
    """
    Estimate P(s=0 | i=i_bit) from a Qiskit-style counts mapping.

    Each key is assumed to be an "IS" bitstring: the first character is the
    idler bit I and the second is the signal bit S. Keys shorter than two
    characters are ignored.

    Args:
        counts: Mapping such as {"00": 123, "01": 200, ...}.
        i_bit: Idler value to condition on ('0' or '1').

    Returns:
        The fraction of matching shots with S == '0', or 0.0 when no shot has
        the requested idler value.
    """
    matching = [
        (key[1], n)
        for key, n in counts.items()
        if len(key) >= 2 and key[0] == i_bit
    ]
    total = sum(n for _, n in matching)
    if not total:
        return 0.0
    zeros = sum(n for s_bit, n in matching if s_bit == '0')
    return zeros / total

def visibility(vals: List[float]) -> float:
    """
    Return the fringe visibility V = (max - min) / (max + min) of ``vals``.

    Returns 0.0 for an empty sequence or when max + min is zero.
    """
    if not vals:
        return 0.0
    hi = max(vals)
    lo = min(vals)
    total = hi + lo
    if not total:
        return 0.0
    return (hi - lo) / total

# --------------------------
# I/O helpers
# --------------------------
def save_text(path: str, text: str):
    """
    Write ``text`` to ``path`` as UTF-8, creating parent directories as needed.

    A bare file name (no directory component) is written to the current
    working directory. ``os.path.dirname`` returns "" in that case, and
    ``os.makedirs("")`` would raise, so directory creation is skipped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

# --------------------------
# SHA-256 based keystream
# --------------------------
def sha256_stream(key_bin: str, length_bits: int) -> str:
    """
    Return a deterministic pseudo-random bit string derived from ``key_bin``.

    The key is parsed as a base-2 integer and encoded big-endian on
    ``ceil(len(key_bin) / 8)`` bytes. Block ``c`` of the stream is
    ``SHA-256(key_bytes || c.to_bytes(4, 'big'))``. Blocks are concatenated
    and read MSB-first until ``length_bits`` bits are available.

    Keys with the same integer value and the same byte width (e.g. "001" and
    "0001") therefore yield the same stream.

    Args:
        key_bin: Binary key such as "010101".
        length_bits: Number of bits wanted; values <= 0 give "".

    Returns:
        A string of '0'/'1' characters of length ``max(length_bits, 0)``.

    Raises:
        ValueError: if ``key_bin`` is not a valid base-2 literal.
    """
    key_bytes = int(key_bin, 2).to_bytes((len(key_bin) + 7) // 8, 'big')
    if length_bits <= 0:
        return ''
    # This runs once per candidate key in the brute-force. Build all needed
    # digests at once and convert them to bits in C rather than bit by bit.
    n_blocks = (length_bits + 255) // 256  # 256 bits per SHA-256 digest
    stream = b''.join(
        hashlib.sha256(key_bytes + counter.to_bytes(4, 'big')).digest()
        for counter in range(n_blocks)
    )
    bits = format(int.from_bytes(stream, 'big'), f'0{len(stream) * 8}b')
    return bits[:length_bits]

# --------------------------
# Interference brute-force
# --------------------------
def score_key_visibility(key_bin: str,
                         phis: List[float],
                         memory_by_phi: Dict[int, List[str]]) -> Tuple[float, float, float]:
    """
    Score a binary key by the interference it induces on the recorded shots.

    The real idler bit is ignored. Instead, a synthetic idler I_k is taken
    from the SHA-256 keystream of ``key_bin``, and only the signal bit S of
    each shot is used.

    Args:
        key_bin: Candidate binary key.
        phis: Phase values. Only ``len(phis)`` is used, to iterate indices
            0..len(phis)-1.
        memory_by_phi: ``memory_by_phi[idx]`` is the per-shot list of "IS"
            strings recorded at phase index ``idx``.

    Procedure:
        1. One keystream of ``N = total shots`` bits is generated, with N
           counted over all values in ``memory_by_phi``.
        2. Phase indices are visited in order. Each one consumes the next
           ``len(memory)`` keystream bits, and bit j becomes I_k of shot j.
           Shots shorter than two characters are skipped but still consume
           a keystream bit.
        3. P(s=0 | I_k=0) and P(s=0 | I_k=1) are computed for each phase.
           A phase with no shots contributes 0.0 to both series.
        4. The visibility of each series is V_plus and V_minus.

    Returns:
        ``(V_plus + V_minus, V_plus, V_minus)``, or all zeros when there are
        no shots at all.
    """
    total_shots = sum(map(len, memory_by_phi.values()))
    if total_shots == 0:
        return 0.0, 0.0, 0.0

    keystream = sha256_stream(key_bin, total_shots)
    offset = 0  # position of the next unused keystream bit
    p_plus_vals: List[float] = []   # P(s=0 | I_k=0) per phase
    p_minus_vals: List[float] = []  # P(s=0 | I_k=1) per phase

    for phi_idx in range(len(phis)):
        shots = memory_by_phi.get(phi_idx, [])
        n_shots = len(shots)
        if n_shots == 0:
            p_plus_vals.append(0.0)
            p_minus_vals.append(0.0)
            continue

        segment = keystream[offset:offset + n_shots]
        offset += n_shots

        # Pair key = synthetic idler bit followed by the measured signal bit.
        pair_counts = Counter(
            segment[j] + bits[1]
            for j, bits in enumerate(shots)
            if len(bits) >= 2
        )
        p_plus_vals.append(conditional_prob_from_counts(pair_counts, i_bit='0'))
        p_minus_vals.append(conditional_prob_from_counts(pair_counts, i_bit='1'))

    v_plus = visibility(p_plus_vals)
    v_minus = visibility(p_minus_vals)
    return v_plus + v_minus, v_plus, v_minus

def try_key_visibility_worker(args) -> Tuple[str, float, float, float]:
    """
    Score one candidate key; usable directly as a ``multiprocessing`` task.

    Args:
        args: ``(key_bin, phis, memory_by_phi)`` packed into a single tuple,
            so the function can be mapped over an iterable of tasks.

    Returns:
        ``(key_bin, score_total, V_plus, V_minus)``. Any exception raised while
        scoring is swallowed and reported as a zero score, so one bad key
        cannot abort the whole brute-force run.
    """
    key_bin, phis, memory_by_phi = args
    try:
        total, v_plus, v_minus = score_key_visibility(key_bin, phis, memory_by_phi)
    except Exception:
        total = v_plus = v_minus = 0.0
    return key_bin, total, v_plus, v_minus

# Experiment data installed once per worker process by _init_bruteforce_worker.
# Each parallel task then ships only a key string instead of the full shot memory.
_BF_PHIS: List[float] = []
_BF_MEMORY_BY_PHI: Dict[int, List[str]] = {}


def _init_bruteforce_worker(phis: List[float],
                            memory_by_phi: Dict[int, List[str]]) -> None:
    """Pool initializer: store the shared experiment data in this worker process."""
    global _BF_PHIS, _BF_MEMORY_BY_PHI
    _BF_PHIS = phis
    _BF_MEMORY_BY_PHI = memory_by_phi


def _bruteforce_key_worker(key_bin: str) -> Tuple[str, float, float, float]:
    """Score one key against the data installed by _init_bruteforce_worker."""
    return try_key_visibility_worker((key_bin, _BF_PHIS, _BF_MEMORY_BY_PHI))


def run_bruteforce_visibility(phis: List[float],
                              memory_by_phi: Dict[int, List[str]],
                              keylen: int = 6,
                              topN: int = 20,
                              parallel: bool = False,
                              workers: Optional[int] = None) -> List[Tuple[str, float, float, float]]:
    """
    Try every binary key of length ``keylen`` and keep the best scorers.

    Each key is scored with ``try_key_visibility_worker``
    (score_total = V_plus + V_minus) on the shots in ``memory_by_phi``.

    Args:
        phis: Phase values of the experiment (same indexing as memory_by_phi).
        memory_by_phi: Per-phase lists of "IS" shot strings.
        keylen: Key length in bits; 2**keylen keys are tried.
        topN: Number of results to return; values <= 0 return [].
        parallel: Score keys in a ``multiprocessing.Pool``.
        workers: Pool size; a falsy value means ``cpu_count() - 1`` (min. 1).

    Returns:
        Up to ``topN`` tuples ``(key_bin, score_total, V_plus, V_minus)``,
        sorted by descending score_total. Sequentially, ties keep ascending key
        order. In parallel mode, results arrive in completion order, so ties
        are broken arbitrarily.

    Raises:
        ValueError: if more than 2,000,000 keys would be tried without
            ``parallel``.
    """
    total_keys = 2 ** keylen
    if total_keys > 2_000_000 and not parallel:
        raise ValueError("Prea multe chei; activați parallel sau micșorați keylen")
    if topN <= 0:
        return []

    # Lazily enumerate zero-padded keys; only the topN best results are kept,
    # instead of materialising and sorting all 2**keylen of them.
    keys = (format(i, '0{}b'.format(keylen)) for i in range(total_keys))

    def by_score(result):
        return result[1]

    if parallel:
        workers = workers or max(1, multiprocessing.cpu_count() - 1)
        # The shot memory is handed to each worker once, through the initializer,
        # rather than being pickled again with every chunk of tasks.
        with multiprocessing.Pool(processes=workers,
                                  initializer=_init_bruteforce_worker,
                                  initargs=(phis, memory_by_phi)) as pool:
            scored = pool.imap_unordered(_bruteforce_key_worker, keys, chunksize=64)
            # Consume fully inside the `with`: leaving it terminates the pool.
            return heapq.nlargest(topN, scored, key=by_score)

    scored = (try_key_visibility_worker((k, phis, memory_by_phi)) for k in keys)
    return heapq.nlargest(topN, scored, key=by_score)

# --------------------------
# GUI Application
# --------------------------
class EraserBruteGUI(tk.Tk):
    """
    Main window of the quantum-eraser tool.

    It runs the phase sweep, plots P(s=0|i=±), saves the raw and correlated
    bit streams, and runs the interference brute-force on the shots of the
    last completed experiment.

    Threading: the experiment and the brute-force run on daemon threads. Tk
    widgets, Tk variables and matplotlib artists are only touched on the Tk
    main thread. Worker threads enqueue callbacks with ``_ui``, and the main
    thread runs them in ``_drain_ui_queue``.
    """

    # Interval (ms) at which the main thread runs callbacks queued by workers.
    _UI_POLL_MS = 50

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – GUI + Correlated Bits + Interference BruteForce")
        self.geometry("1180x820")

        # Core parameters
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi = tk.IntVar(value=13)
        self.var_seed = tk.IntVar(value=1234)
        self.var_save = tk.BooleanVar(value=True)
        self.var_show = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")

        # Correlated-stream export
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token = tk.StringVar(value="")
        self.var_ibm_instance = tk.StringVar(value="")
        self.var_ibm_backend = tk.StringVar(value="")

        # Brute-force params
        self.var_keylen = tk.IntVar(value=6)
        self.var_topn = tk.IntVar(value=20)
        self.var_parallel = tk.BooleanVar(value=False)
        self.var_workers = tk.IntVar(value=max(1, multiprocessing.cpu_count() - 1))

        # Results of the last *completed* experiment. They are replaced by new
        # objects, never mutated in place, so a brute-force that holds a
        # reference keeps a consistent snapshot.
        self.last_phis: List[float] = []
        self.last_p_plus: List[float] = []
        self.last_p_minus: List[float] = []
        self.last_memory_by_phi: Dict[int, List[str]] = {}

        # Rows shown in the results table, kept as Python strings.
        # ttk.Treeview.item()["values"] converts numeric-looking text to int,
        # which would strip the leading zeros of keys such as "000110".
        self._bruteforce_rows: List[Tuple[str, str, str, str]] = []

        self._running = False    # True while an experiment or brute-force runs
        self._annotation = None  # visibility text box drawn by on_correlate
        self._ui_queue: "queue.Queue" = queue.Queue()  # callbacks from workers

        # Tk's packer serves widgets in packing order. The fixed-height bars are
        # therefore packed before the expanding plot canvas. Otherwise the
        # brute-force panel, packed last, would be clipped whenever the window
        # is shorter than the sum of the requested heights.
        self._build_controls()
        self._build_statusbar()
        self._build_bruteforce_panel()
        self._build_plot()

        self.after(self._UI_POLL_MS, self._drain_ui_queue)

    # ---------------- main-thread plumbing ----------------
    def _ui(self, func, *args):
        """Schedule ``func(*args)`` to run on the Tk main thread (callable from any thread)."""
        self._ui_queue.put((func, args))

    def _drain_ui_queue(self):
        """Run every queued UI callback, then re-arm the polling timer."""
        while True:
            try:
                func, args = self._ui_queue.get_nowait()
            except queue.Empty:
                break
            try:
                func(*args)
            except Exception as exc:  # one failing callback must not stop polling
                self.status.set(f"Error: {exc}")
        self.after(self._UI_POLL_MS, self._drain_ui_queue)

    def _set_busy(self, busy: bool):
        """Set the busy flag and enable/disable the two long-running actions."""
        self._running = busy
        state = tk.DISABLED if busy else tk.NORMAL
        self.btn_run.config(state=state)
        self.btn_bruteforce.config(state=state)

    def _report_error(self, exc: BaseException):
        """Show ``exc`` in the status bar and in an error dialog."""
        self.status.set(f"Error: {exc}")
        messagebox.showerror("Error", str(exc))

    # ---------------- widgets ----------------
    def _build_controls(self):
        frm = ttk.Frame(self, padding=8)
        frm.pack(side=tk.TOP, fill=tk.X)

        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=4)
        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=4)
        ttk.Label(frm, text="seed:").grid(row=0, column=4, sticky="w")
        ttk.Entry(frm, textvariable=self.var_seed, width=8).grid(row=0, column=5, padx=4)
        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=6, padx=6)
        ttk.Checkbutton(frm, text="show", variable=self.var_show).grid(row=0, column=7, padx=6)
        ttk.Checkbutton(frm, text="save correlated streams", variable=self.var_save_corr).grid(row=0, column=8, padx=6)

        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=6)
        ttk.Entry(frm, textvariable=self.var_outdir, width=36).grid(row=1, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(frm, text="Browse…", command=self._choose_outdir).grid(row=1, column=5, sticky="w", padx=4)

        ibm = ttk.LabelFrame(self, text="IBM Quantum (optional)", padding=6)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=8, pady=6)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=44, show="•").grid(row=0, column=2, padx=4, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=20).grid(row=0, column=4, padx=4)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=4)

        btns = ttk.Frame(self, padding=6)
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run Experiment", command=self.on_run)
        self.btn_run.pack(side=tk.LEFT, padx=4)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect", command=self.on_correlate)
        self.btn_correlate.pack(side=tk.LEFT, padx=4)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview)
        self.btn_preview.pack(side=tk.LEFT, padx=4)

    def _build_statusbar(self):
        """Create the one-line status bar at the very bottom of the window."""
        self.status = tk.StringVar(value="Ready")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=6).pack(side=tk.BOTTOM, fill=tk.X)

    def _build_plot(self):
        """Create the P(s=0|i=±) figure; packed last so it takes the leftover space."""
        self.fig = plt.Figure(figsize=(8.5, 5))
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,) = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend()
        self.ax.grid(True, linestyle='--', alpha=0.4)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

    def _build_bruteforce_panel(self):
        panel = ttk.LabelFrame(self, text="Brute-Force Interference (on last experiment)", padding=6)
        panel.pack(side=tk.BOTTOM, fill=tk.X, padx=8, pady=6)

        ttk.Label(panel, text="keylen:").grid(row=0, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_keylen, width=6).grid(row=0, column=1, padx=4)
        ttk.Label(panel, text="topN:").grid(row=0, column=2, sticky="w")
        ttk.Entry(panel, textvariable=self.var_topn, width=6).grid(row=0, column=3, padx=4)

        ttk.Checkbutton(panel, text="parallel", variable=self.var_parallel).grid(row=1, column=0, padx=4, pady=4)
        ttk.Label(panel, text="workers:").grid(row=1, column=1, sticky="e")
        ttk.Entry(panel, textvariable=self.var_workers, width=6).grid(row=1, column=2, padx=4)
        self.btn_bruteforce = ttk.Button(panel, text="Run Brute-Force", command=self.on_run_bruteforce)
        self.btn_bruteforce.grid(row=1, column=3, padx=6)

        # Results table columns: key, score (V_plus + V_minus), V_plus, V_minus.
        cols = ("key", "score", "V_plus", "V_minus")
        self.tree = ttk.Treeview(panel, columns=cols, show="headings", height=8)
        self.tree.heading("key", text="key")
        self.tree.heading("score", text="score (V+ + V-)")
        self.tree.heading("V_plus", text="V_plus")
        self.tree.heading("V_minus", text="V_minus")

        self.tree.column("key", width=160, anchor="w")
        self.tree.column("score", width=140, anchor="w")
        self.tree.column("V_plus", width=140, anchor="w")
        self.tree.column("V_minus", width=140, anchor="w")

        self.tree.grid(row=2, column=0, columnspan=4, sticky="nsew", pady=6)
        panel.grid_rowconfigure(2, weight=1)

        ttk.Button(panel, text="Save Report", command=self._save_report).grid(row=3, column=0, padx=4, pady=6)
        ttk.Button(panel, text="Open Outdir", command=self._open_outdir).grid(row=3, column=1, padx=4, pady=6)

    # ---------------- UI actions ----------------
    def _choose_outdir(self):
        d = filedialog.askdirectory(title="Select outdir")
        if d:
            self.var_outdir.set(d)

    def _open_outdir(self):
        """
        Open the output directory in the platform file manager.

        Falls back to showing the path in a dialog if no launcher is available.
        """
        out = self.var_outdir.get()
        try:
            if os.name == 'nt':
                os.startfile(out)
            elif sys.platform == 'darwin':
                subprocess.Popen(["open", out])
            elif os.name == 'posix':
                # Argument list, not a shell string: the path is never parsed by a shell.
                subprocess.Popen(["xdg-open", out])
            else:
                raise OSError("no file-manager launcher for this platform")
        except OSError:
            messagebox.showinfo("Open", f"Outdir: {out}")

    def _collect_settings(self) -> Dict[str, object]:
        """Snapshot the experiment settings from the Tk variables (call on the main thread)."""
        return {
            "shots": int(self.var_shots.get()),
            "nphi": int(self.var_nphi.get()),
            "seed": int(self.var_seed.get()),
            "save": bool(self.var_save.get()),
            "save_corr": bool(self.var_save_corr.get()),
            "outdir": self.var_outdir.get(),
            "use_ibm": bool(self.var_use_ibm.get()),
            "ibm_token": self.var_ibm_token.get().strip(),
            "ibm_instance": self.var_ibm_instance.get().strip(),
            "ibm_backend": self.var_ibm_backend.get().strip(),
        }

    def on_run(self):
        """Validate the inputs and start the experiment on a worker thread."""
        if self._running:
            messagebox.showinfo("Info", "Already running")
            return
        if self.var_use_ibm.get() and not IBM_RUNTIME_AVAILABLE:
            messagebox.showerror("Error", "qiskit-ibm-runtime not installed")
            return
        try:
            settings = self._collect_settings()
            if settings["shots"] <= 0 or settings["nphi"] <= 1:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error", "Invalid shots/nphi/seed")
            return
        self._set_busy(True)
        self.status.set("Running experiment...")
        threading.Thread(target=self._run_experiment_thread, args=(settings,), daemon=True).start()

    def _build_aer_backend(self, seed: Optional[int] = None):
        """Return a seeded AerSimulator; ``seed`` defaults to the seed entry field."""
        if seed is None:
            seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=seed)

    def _build_ibm_backend(self, token: Optional[str] = None,
                           instance: Optional[str] = None,
                           backend_name: Optional[str] = None):
        """
        Connect to IBM Quantum and return the named backend.

        Arguments default to the stripped values of the IBM entry fields. An
        empty instance is passed as None. The service and backend are also
        stored on ``self.service`` and ``self.backend``.

        Raises:
            ValueError: if no backend name was given.
        """
        if token is None:
            token = self.var_ibm_token.get().strip()
        if instance is None:
            instance = self.var_ibm_instance.get().strip()
        if backend_name is None:
            backend_name = self.var_ibm_backend.get().strip()
        if not backend_name:
            raise ValueError("IBM backend name is empty")
        # NOTE(review): recent qiskit-ibm-runtime releases deprecate the
        # "ibm_quantum" channel and IBMBackend.run() in favour of the
        # "ibm_quantum_platform" channel and the Sampler primitive.
        # Confirm against the installed version.
        service = QiskitRuntimeService(channel="ibm_quantum", token=token, instance=instance or None)
        backend = service.backend(backend_name)
        self.service = service
        self.backend = backend
        return backend

    def _run_experiment_thread(self, settings: Optional[Dict[str, object]] = None):
        """Worker-thread entry point: run the experiment and report errors to the UI."""
        try:
            self._run_experiment(settings)
        except Exception as e:
            self._ui(self._report_error, e)
        finally:
            self._ui(self._set_busy, False)

    def _run_experiment(self, settings: Optional[Dict[str, object]] = None):
        """
        Run the eraser circuit ``shots`` times at each of ``nphi`` phases in [0, 2π].

        Results are streamed to the plot and optionally saved under outdir.
        This runs on the worker thread, so every UI update goes through ``_ui``.
        The results are published to ``last_*`` only once the whole sweep
        has succeeded.

        Args:
            settings: Dict built by ``_collect_settings``; read from the Tk
                variables when omitted.
        """
        cfg = settings if settings is not None else self._collect_settings()
        shots = int(cfg["shots"])
        nphi = int(cfg["nphi"])
        seed = int(cfg["seed"])
        save = bool(cfg["save"])
        save_corr = bool(cfg["save_corr"])
        outdir = str(cfg["outdir"])
        use_ibm = bool(cfg["use_ibm"])

        os.makedirs(outdir, exist_ok=True)
        phis = list(np.linspace(0.0, 2.0 * math.pi, nphi))

        if use_ibm:
            backend = self._build_ibm_backend(cfg["ibm_token"], cfg["ibm_instance"], cfg["ibm_backend"])
        else:
            backend = self._build_aer_backend(seed)

        self._ui(self._reset_plot)

        p_plus: List[float] = []
        p_minus: List[float] = []
        memory_by_phi: Dict[int, List[str]] = {}

        for idx, phi in enumerate(phis):
            counts, memory = self._run_single_phi(backend, phi, shots, seed, use_ibm)
            memory_by_phi[idx] = memory  # kept for the brute-force

            p0_plus = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus)
            p_minus.append(p0_minus)
            # Pass copies: these lists keep growing on this thread while the
            # main thread draws them.
            self._ui(self._update_plot, phis[:idx + 1], list(p_plus), list(p_minus))

            if save:
                self._save_signal_stream(outdir, idx, memory)
            if save_corr:
                self._save_correlated_streams(outdir, idx, memory, counts)

            self._ui(self.status.set, f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • {idx + 1}/{len(phis)}")

        plot_path = None
        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline='', encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow(['phi', 'P(s=0|i=+)', 'P(s=0|i=-)'])
                for phi, a, b in zip(phis, p_plus, p_minus):
                    w.writerow([phi, a, b])
            plot_path = os.path.join(outdir, "eraser_plot.png")

        self._ui(self._finish_experiment, phis, p_plus, p_minus, memory_by_phi, use_ibm, plot_path)

    def _run_single_phi(self, backend, phi: float, shots: int, seed: int, use_ibm: bool):
        """Transpile and run the eraser circuit at ``phi``; return ``(counts, per-shot memory)``."""
        qc = eraser_circuit(phi, erase=True)
        if use_ibm:
            tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
        else:
            tqc = transpile(qc, backend=backend, optimization_level=3)
        result = backend.run(tqc, shots=shots, memory=True).result()
        counts = result.get_counts()
        if not use_ibm:
            return counts, result.get_memory()
        try:
            memory = result.get_memory()
        except Exception:
            # No per-shot memory returned: rebuild a shot list from the counts.
            # The shot order is lost (identical outcomes end up adjacent).
            memory = [bitstr for bitstr, c in counts.items() for _ in range(int(c))]
        return counts, memory

    @staticmethod
    def _save_signal_stream(outdir: str, idx: int, memory: List[str]):
        """Write the S bits (second character of each "IS" shot) for phase index ``idx``."""
        s_stream = ''.join(m[1] for m in memory if len(m) >= 2)
        save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)

    @staticmethod
    def _save_correlated_streams(outdir: str, idx: int, memory: List[str], counts):
        """Write the raw (I S) pairs, the S|I=0 and S|I=1 streams and the pair counts for ``idx``."""
        save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
        s_i0 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '0')
        s_i1 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '1')
        save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_i0)
        save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_i1)
        with open(os.path.join(outdir, f"counts_phi_{idx:02d}.csv"), "w", newline='', encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(['pair', 'count'])
            for pair in ['00', '01', '10', '11']:
                w.writerow([pair, counts.get(pair, 0)])

    def _reset_plot(self):
        """Clear both probability curves and any visibility annotation."""
        self._clear_annotation()
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _update_plot(self, xs, p_plus, p_minus):
        """Redraw both curves with the points measured so far."""
        self.line_plus.set_data(xs, p_plus)
        self.line_minus.set_data(xs, p_minus)
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw_idle()

    def _finish_experiment(self, phis, p_plus, p_minus, memory_by_phi, use_ibm, plot_path=None):
        """Publish a completed sweep, save the plot if requested and report the visibilities."""
        self.last_phis = phis
        self.last_p_plus = p_plus
        self.last_p_minus = p_minus
        self.last_memory_by_phi = memory_by_phi
        if plot_path:
            try:
                self.fig.savefig(plot_path, dpi=160)
            except Exception as e:
                self._report_error(e)
                return
        Vp = visibility(p_plus)
        Vm = visibility(p_minus)
        self.status.set(f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}  (IBM: {use_ibm})")

    def _clear_annotation(self):
        """Remove the verdict text box drawn by on_correlate, if any."""
        if self._annotation is not None:
            self._annotation.remove()
            self._annotation = None

    def on_correlate(self):
        """Annotate the plot with the visibilities of the last completed experiment."""
        if not self.last_p_plus:
            messagebox.showinfo("Info", "Run experiment first")
            return
        Vp = visibility(self.last_p_plus)
        Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"
        # Replace, rather than stack, the previous verdict box.
        self._clear_annotation()
        self._annotation = self.ax.text(0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
                                        transform=self.ax.transAxes, fontsize=10,
                                        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    def on_preview(self):
        """Open a window showing the pairs, counts and S|I streams for the highest phase index."""
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info", "Run first")
            return
        idx = max(self.last_memory_by_phi.keys())
        memory = self.last_memory_by_phi[idx]
        c = Counter(memory)
        preview_pairs = '\n'.join(memory[:200])
        s_i0 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '0')[:400]
        s_i1 = ''.join(m[1] for m in memory if len(m) >= 2 and m[0] == '1')[:400]

        win = tk.Toplevel(self)
        win.title(f"Preview correlations (phi idx {idx})")
        win.geometry("640x720")
        ttk.Label(win, text=f"First 200 pairs (I S) for phi index {idx}:").pack(anchor="w", padx=8, pady=(8, 0))
        txt = tk.Text(win, height=12)
        txt.pack(fill=tk.BOTH, padx=8)
        txt.insert("1.0", preview_pairs)
        txt.config(state=tk.DISABLED)

        ttk.Label(win, text="Counts (coincidences):").pack(anchor="w", padx=8, pady=(8, 0))
        tree = ttk.Treeview(win, columns=("pair", "count"), show="headings", height=4)
        tree.heading("pair", text="pair")
        tree.heading("count", text="count")
        for pair in ['00', '01', '10', '11']:
            tree.insert("", "end", values=(pair, c.get(pair, 0)))
        tree.pack(fill=tk.X, padx=8, pady=4)

        ttk.Label(win, text="S | I=0 (first 400 bits):").pack(anchor="w", padx=8, pady=(8, 0))
        t0 = tk.Text(win, height=4)
        t0.pack(fill=tk.X, padx=8)
        t0.insert("1.0", s_i0)
        t0.config(state=tk.DISABLED)
        ttk.Label(win, text="S | I=1 (first 400 bits):").pack(anchor="w", padx=8, pady=(8, 0))
        t1 = tk.Text(win, height=4)
        t1.pack(fill=tk.X, padx=8)
        t1.insert("1.0", s_i1)
        t1.config(state=tk.DISABLED)

    # ---------------- Brute-force (interference) ----------------
    def on_run_bruteforce(self):
        """
        Start the interference brute-force on the last completed experiment.

        It searches for the keys that maximise V_plus + V_minus. The worker
        thread receives a snapshot of the experiment data and of outdir.
        """
        if self._running:
            messagebox.showinfo("Info", "Already running")
            return
        if not self.last_memory_by_phi or not self.last_phis:
            messagebox.showerror("Error", "Mai întâi rulați experimentul (Run Experiment).")
            return

        try:
            keylen = int(self.var_keylen.get())
            topn = int(self.var_topn.get())
            parallel = bool(self.var_parallel.get())
            workers = int(self.var_workers.get())
            if keylen <= 0 or topn <= 0:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error", "Parametri brute-force invalizi (keylen/topN/workers).")
            return

        self.tree.delete(*self.tree.get_children())
        self._bruteforce_rows = []
        self.status.set("Running interference brute-force...")
        self._set_busy(True)
        threading.Thread(
            target=self._run_bruteforce_thread,
            args=(keylen, topn, parallel, workers,
                  self.last_phis, self.last_memory_by_phi, self.var_outdir.get()),
            daemon=True,
        ).start()

    def _run_bruteforce_thread(self, keylen, topn, parallel, workers,
                               phis=None, memory_by_phi=None, outdir=None):
        """
        Worker-thread entry point for the brute-force.

        It runs the search, queues the table update, writes the text report
        under outdir and reports errors to the UI. Omitted data arguments
        default to the current ``last_*`` attributes and the outdir field.
        """
        try:
            if phis is None:
                phis = self.last_phis
            if memory_by_phi is None:
                memory_by_phi = self.last_memory_by_phi
            if outdir is None:
                outdir = self.var_outdir.get()

            top = run_bruteforce_visibility(
                phis,
                memory_by_phi,
                keylen=keylen,
                topN=topn,
                parallel=parallel,
                workers=workers
            )
            self._ui(self._show_bruteforce_results, top)

            report_path = os.path.join(outdir, "interference_bruteforce_report.txt")
            lines = [
                "Interference brute-force report\n",
                f"keylen={keylen} topN={topn} parallel={parallel} workers={workers}\n\n",
                "Top results (key, score_total, V_plus, V_minus):\n",
            ]
            lines.extend(
                f"{k}  score={score_total:.3f}  V_plus={V_plus:.3f}  V_minus={V_minus:.3f}\n"
                for (k, score_total, V_plus, V_minus) in top
            )
            save_text(report_path, ''.join(lines))

            self._ui(self.status.set, f"Interference brute-force finished. Report: {report_path}")
        except Exception as e:
            self._ui(self._report_error, e)
        finally:
            self._ui(self._set_busy, False)

    def _show_bruteforce_results(self, top):
        """Fill the results table and remember the rows as strings for _save_report."""
        self._bruteforce_rows = [
            (k, f"{score_total:.3f}", f"{V_plus:.3f}", f"{V_minus:.3f}")
            for (k, score_total, V_plus, V_minus) in top
        ]
        for row in self._bruteforce_rows:
            self.tree.insert("", "end", values=row)

    def _save_report(self):
        """Ask for a file name and save the rows currently shown in the results table."""
        out = filedialog.asksaveasfilename(
            title="Save report as",
            defaultextension=".txt",
            filetypes=[("Text files", "*.txt")]
        )
        if not out:
            return
        # Use the Python-side rows: reading the Treeview back would turn keys
        # such as "000110" into the int 110.
        txt = "Interference brute-force candidates:\n" + ''.join(
            f"key={key}  score={score}  V_plus={v_plus}  V_minus={v_minus}\n"
            for key, score, v_plus, v_minus in self._bruteforce_rows
        )
        save_text(out, txt)
        messagebox.showinfo("Saved", f"Report saved to {out}")

# --------------------------
# Main
# --------------------------
def main():
    """Parse the command line and run the Tk event loop until the window closes."""
    cli = argparse.ArgumentParser(description="Quantum Eraser GUI + Interference Brute-Force")
    cli.add_argument("--no-gui", action="store_true", help="Run non-interactive (not used)")
    cli.parse_args()  # validates argv; the --no-gui flag is currently ignored

    EraserBruteGUI().mainloop()


if __name__ == "__main__":
    main()
