#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser GUI + Correlated Bits + Brute-Force (parallel)
Author: ChatGPT (adaptare pentru utilizator)
Features:
  - Run delayed-choice quantum eraser on Aer or IBM backends
  - Save binary streams, correlated streams (pairs, S|I=0, S|I=1)
  - Live plotting of P(s=0|i=±)
  - Correlation / visibility detection
  - Analiză non-corelație (MI, chi2, HSIC, predictability gain, ΔP)
  - Brute-force module: try binary keys (0/1) to extract hidden ASCII text,
    scoring by printable ratio + word heuristics
  - Optional multiprocessing parallel brute-force
Dependencies:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"    # doar dacă vrei IBM backends
Usage:
  python quantum_eraser_gui_ibm_corr_bruteforce_parallel.py
"""

import os
import math
import threading
import argparse
from collections import Counter
from typing import List, Tuple, Dict, Optional
import hashlib
import multiprocessing

import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Qiskit imports
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
# IBM runtime opțional
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# --------------------------
# Quantum circuit helpers
# --------------------------
def eraser_circuit(phi: float, erase: bool = True) -> QuantumCircuit:
    """Build the two-qubit eraser circuit for relative phase ``phi``.

    Qubit 0 gets H, RZ(phi), a CNOT onto qubit 1 and a second H. When
    ``erase`` is true, qubit 1 also gets an H before measurement, i.e. it is
    measured in the X basis instead of Z. Qubit k is measured into
    classical bit k.
    """
    circ = QuantumCircuit(2, 2)
    circ.h(0)
    circ.rz(phi, 0)
    circ.cx(0, 1)  # entangle qubit 1 with qubit 0
    circ.h(0)
    if erase:
        circ.h(1)  # rotate qubit 1 so the Z measurement becomes an X measurement
    circ.measure([0, 1], [0, 1])  # qubit 0 -> clbit 0, qubit 1 -> clbit 1
    return circ

def conditional_prob_from_counts(counts: Dict[str,int], i_bit: str = '0') -> float:
    """Return P(s='0' | i=i_bit) estimated from a counts dictionary.

    The first character of each key is taken as the i bit and the second as
    the s bit; keys shorter than two characters are ignored. Returns 0.0
    when no shot has the requested i bit.
    """
    matching = 0
    zeros = 0
    for outcome, hits in counts.items():
        if len(outcome) < 2:
            continue
        i_char, s_char = outcome[0], outcome[1]
        if i_char != i_bit:
            continue
        matching += hits
        if s_char == '0':
            zeros += hits
    return zeros / matching if matching else 0.0

def visibility(vals: List[float]) -> float:
    """Return the fringe visibility (max - min) / (max + min) of ``vals``.

    Returns 0.0 for an empty sequence or when max + min is zero.
    """
    if not vals:
        return 0.0
    hi = max(vals)
    lo = min(vals)
    total = hi + lo
    if not total:
        return 0.0
    return (hi - lo) / total

# --------------------------
# Non-correlation metrics
# --------------------------
def _probs_from_counts(counts: Dict[str,int]) -> Tuple[float,float,float,float,Dict[str,float]]:
    """Turn two-bit counts into marginal and joint probabilities.

    Keys are 'is' bitstrings (first char = i, second char = s); missing keys
    count as zero.

    Returns:
        (p_i0, p_i1, p_s0, p_s1, joint) where ``joint`` maps '00', '01',
        '10', '11' to their relative frequencies. All values are 0.0 when
        the counts are empty.
    """
    cells = ('00', '01', '10', '11')
    total = sum(counts.values())
    if total == 0:
        return 0.0, 0.0, 0.0, 0.0, dict.fromkeys(cells, 0.0)
    joint = {cell: counts.get(cell, 0) / total for cell in cells}
    p_i0 = joint['00'] + joint['01']
    p_i1 = joint['10'] + joint['11']
    p_s0 = joint['00'] + joint['10']
    p_s1 = joint['01'] + joint['11']
    return p_i0, p_i1, p_s0, p_s1, joint

def mutual_information_from_counts(counts: Dict[str,int]) -> float:
    """Return the mutual information I(i; s) in nats from two-bit counts.

    Cells with zero joint or marginal probability are skipped. Small 1e-15
    terms guard the ratio and the logarithm argument.
    """
    p_i0, p_i1, p_s0, p_s1, joint = _probs_from_counts(counts)
    marginal_i = (p_i0, p_i1)
    marginal_s = (p_s0, p_s1)
    total = 0.0
    for a, pa in enumerate(marginal_i):
        for b, pb in enumerate(marginal_s):
            pab = joint[f"{a}{b}"]
            if pab > 0 and pa > 0 and pb > 0:
                total += pab * math.log(pab / (pa * pb + 1e-15) + 1e-15)
    return total

def chi2_stat_from_counts(counts: Dict[str,int]) -> float:
    """Return Pearson's chi-square independence statistic for the 2x2 i/s table.

    Computed from proportions as sum over cells of n * (observed - expected)^2 / expected.
    Expected values are the products of the marginals. Cells whose expected
    proportion is <= 1e-12 are skipped. Returns 0.0 for empty counts.
    """
    n = sum(counts.values())
    if n == 0:
        return 0.0
    p_i0, p_i1, p_s0, p_s1, joint = _probs_from_counts(counts)
    row = {'0': p_i0, '1': p_i1}
    col = {'0': p_s0, '1': p_s1}
    stat = 0.0
    for cell in ('00', '01', '10', '11'):
        expected = row[cell[0]] * col[cell[1]]
        if expected <= 1e-12:
            continue
        diff = joint[cell] - expected
        stat += (n * diff * diff) / expected
    return stat

def predictability_gain_from_counts(counts: Dict[str,int]) -> float:
    """Return |P(s=0|i=0) - P(s=0)| + |P(s=0|i=1) - P(s=0)|.

    Returns 0.0 when either value of i never occurs.
    """
    p_i0, p_i1, p_s0, _p_s1, joint = _probs_from_counts(counts)
    if 0 in (p_i0, p_i1):
        return 0.0
    eps = 1e-15
    given_i0 = joint['00'] / (joint['00'] + joint['01'] + eps)
    given_i1 = joint['10'] / (joint['10'] + joint['11'] + eps)
    return abs(given_i0 - p_s0) + abs(given_i1 - p_s0)

def hsic_binary_from_memory(memory: List[str]) -> float:
    """Return the biased HSIC estimate between the two bits of each shot.

    ``memory`` holds per-shot bitstrings. Character 0 is X and character 1
    is Y ('1' -> 1, anything else -> 0). The estimator is
    ``trace(K H L H) / (n - 1)**2`` with delta kernels
    ``K[a, b] = [x_a == x_b]`` and ``L[a, b] = [y_a == y_b]``, and centring
    matrix ``H = I - 11^T / n``.

    For delta kernels this trace equals exactly
    ``n**2 * sum_ab (p_ab - p_a * q_b)**2``, where p_ab is the joint
    frequency and p_a / q_b are the marginals. The value is therefore
    computed from a 2x2 contingency table in O(n) time and O(1) extra memory.
    The previous version built several n x n matrices (~128 MB each at 4000
    shots) and multiplied them in O(n**3).

    Returns 0.0 when fewer than 4 shots are available.
    """
    n = len(memory)
    if n < 4:
        return 0.0
    # 2x2 contingency table: rows = X value, columns = Y value.
    cells = [[0, 0], [0, 0]]
    for shot in memory:
        cells[1 if shot[0] == '1' else 0][1 if shot[1] == '1' else 0] += 1
    joint = np.array(cells, dtype=float) / n
    deviation = joint - np.outer(joint.sum(axis=1), joint.sum(axis=0))
    return float(n * n * np.sum(deviation * deviation) / ((n - 1.0) ** 2))

# --------------------------
# I/O helpers
# --------------------------
def save_text(path: str, text: str):
    """Write ``text`` to ``path`` as UTF-8, creating parent directories if needed.

    A bare file name (no directory part) is written to the current
    directory. Previously ``os.makedirs('')`` raised FileNotFoundError for
    such paths, e.g. when the output-directory field was empty.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

def read_binary_file(path: str) -> str:
    """Read a UTF-8 text file and return only its '0'/'1' characters, in order.

    Whitespace, separators and any other characters are discarded.
    """
    with open(path, encoding="utf-8") as handle:
        raw = handle.read()
    return ''.join(filter(lambda ch: ch in '01', raw))

# --------------------------
# Brute-force utilities
# --------------------------
# Short English and Romanian words used by score_message_bytes as a
# "looks like natural language" hint. That function counts substring
# occurrences, so very short entries (e.g. "la", "pe", "nu") can also match
# inside longer words or random text.
COMMON_WORDS = [
    # English
    "the", "and", "that", "have", "for", "not", "with", "you", "this", "but",
    "was", "are", "from", "they", "his", "she", "which", "will", "one", "all",
    "would", "there", "their",
    # Romanian
    "în", "și", "pe", "la", "care", "să", "este", "nu",
]

def sha256_stream(key_bin: str, length_bits: int) -> str:
    """Expand a binary key into a keystream of ``length_bits`` '0'/'1' characters.

    The key is packed big-endian into ceil(len(key_bin) / 8) bytes. Block j
    of the stream is SHA-256(key_bytes || j as 4-byte big-endian counter),
    emitted most-significant bit first. Blocks are concatenated and
    truncated to the requested length.
    """
    seed = int(key_bin, 2).to_bytes((len(key_bin) + 7) // 8, 'big')
    chunks = []
    produced = 0
    block = 0
    while produced < length_bits:
        digest = hashlib.sha256(seed + block.to_bytes(4, 'big')).digest()
        chunks.append(''.join(format(byte, '08b') for byte in digest))
        produced += 8 * len(digest)
        block += 1
    return ''.join(chunks)[:length_bits]

def deterministic_sample_positions(total_bits: int, m: int) -> List[int]:
    """Pick ``m`` evenly spread, distinct indices in ``range(total_bits)``.

    Index k is ``round(k * total_bits / m) % total_bits``. Python's round()
    uses round-half-to-even. A taken slot is replaced by the next free one
    (wrapping around). Returns [] for m <= 0 and every index when
    m >= total_bits.
    """
    if m <= 0:
        return []
    if m >= total_bits:
        return list(range(total_bits))
    stride = total_bits / m
    taken = set()
    picked = []
    for k in range(m):
        start = int(round(k * stride)) % total_bits
        cand = start
        while cand in taken:
            cand = (cand + 1) % total_bits
            if cand == start:
                break
        taken.add(cand)
        picked.append(cand)
    return picked

def extract_bits_with_key(stego_bits: str, key_bin: str, out_len_bits: int, sample_positions: Optional[List[int]] = None) -> str:
    """XOR selected stego bits with the SHA-256 keystream derived from ``key_bin``.

    Without ``sample_positions`` the first ``out_len_bits`` bits are used.
    With ``sample_positions`` exactly those indices are used and
    ``out_len_bits`` is ignored.

    Raises:
        ValueError: more bits requested than ``stego_bits`` holds, or a
            sample position lies past its end. An empty ``sample_positions``
            list also raises ValueError (from ``max``).
    """
    available = len(stego_bits)
    if sample_positions is not None:
        positions = sample_positions
        if max(positions) >= available:
            raise ValueError("Sample positions exceed length of stego bits")
    else:
        if out_len_bits > available:
            raise ValueError("Requested more bits than available in stego")
        positions = list(range(out_len_bits))
    keystream = sha256_stream(key_bin, len(positions))
    return ''.join(
        '1' if stego_bits[pos] != kbit else '0'  # XOR of two '0'/'1' chars
        for kbit, pos in zip(keystream, positions)
    )

def bits_to_bytes(bits: str) -> bytes:
    """Pack a '0'/'1' string into bytes, MSB first, right-padding with zeros to a byte boundary."""
    padded = bits.ljust(len(bits) + (-len(bits)) % 8, '0')
    return bytes(int(padded[start:start + 8], 2) for start in range(0, len(padded), 8))

def score_message_bytes(b: bytes, words: Optional[List[str]] = None) -> Tuple[float, float]:
    """Score how much a candidate plaintext looks like human text.

    Args:
        b: candidate plaintext bytes.
        words: words whose occurrences are counted; defaults to COMMON_WORDS.

    Returns:
        (printable_ratio, common_score):
        - printable_ratio: fraction of bytes in Python's ``string.printable``
          set (TAB, LF, VT, FF, CR and 0x20-0x7E). The old range 9..126 also
          counted control bytes 0x0E-0x1F as printable, which inflated the
          score of random candidates.
        - common_score: total non-overlapping substring occurrences of each
          word in the lower-cased UTF-8 decoding (invalid bytes dropped).
        Both are 0.0 for empty input.
    """
    if not b:
        return 0.0, 0.0
    vocab = COMMON_WORDS if words is None else words
    printable = sum(1 for x in b if 9 <= x <= 13 or 32 <= x <= 126)
    printable_ratio = printable / len(b)
    # errors='ignore' cannot raise on bytes, so no try/except is needed.
    text = b.decode('utf-8', errors='ignore').lower()
    common_score = sum(text.count(w) for w in vocab)
    return printable_ratio, float(common_score)

def try_key_worker(args) -> Tuple[str, float, float, str]:
    """Score a single brute-force key (top-level, so multiprocessing can pickle it).

    ``args`` is ``(key_bin, stego_bits, out_chars, sample_positions)``.
    Returns ``(key_bin, printable_ratio, common_word_score, preview)``.
    ``preview`` is the UTF-8 decoding (invalid bytes replaced) cut to
    ``out_chars`` characters. Any error during extraction or scoring gives
    ``(key_bin, 0.0, 0.0, "")`` so one bad key cannot abort the search.
    """
    key_bin, stego_bits, out_chars, sample_positions = args
    wanted_bits = out_chars * 8
    try:
        extracted = extract_bits_with_key(stego_bits, key_bin, wanted_bits, sample_positions)
        candidate = bits_to_bytes(extracted)
        ratio, word_score = score_message_bytes(candidate)
        preview = candidate.decode('utf-8', errors='replace')[:out_chars]
    except Exception:
        return (key_bin, 0.0, 0.0, "")
    return (key_bin, ratio, word_score, preview)

def run_bruteforce(stego_bits: str,
                   keylen: int = 6,
                   out_chars: int = 24,
                   topN: int = 20,
                   parallel: bool = False,
                   workers: int = None) -> List[Tuple[str, float, float, str]]:
    """Try every ``keylen``-bit key and return the ``topN`` best candidates.

    Each key is expanded with sha256_stream and XOR-ed with the bits of
    ``stego_bits`` at deterministic_sample_positions, then scored by
    try_key_worker. Candidates are ranked by
    ``printable_ratio + 0.02 * common_word_score``, highest first. Ties keep
    arrival order; in parallel mode that order is nondeterministic.

    Args:
        stego_bits: string of '0'/'1' characters to analyse.
        keylen: key length in bits; all 2**keylen keys are tried.
        out_chars: characters (8 bits each) to extract per key.
        topN: number of candidates to return.
        parallel: score keys in a multiprocessing pool.
        workers: pool size; falsy means ``cpu_count() - 1`` (at least 1).

    Returns:
        List of ``(key, printable_ratio, common_word_score, preview)``.

    Raises:
        ValueError: more than 2,000,000 keys without ``parallel``.

    Keys are generated lazily and only the best ``topN`` results are kept
    (``heapq.nlargest``, equivalent to a stable descending sort + slice).
    Memory for results therefore no longer grows with 2**keylen.
    """
    import heapq

    total_keys = 2 ** keylen
    if total_keys > 2000000 and not parallel:
        raise ValueError("Too many keys; enable parallel or reduce keylen")
    out_len_bits = out_chars * 8
    sample_positions = deterministic_sample_positions(len(stego_bits), out_len_bits)
    key_fmt = f"0{keylen}b"
    tasks = ((format(i, key_fmt), stego_bits, out_chars, sample_positions)
             for i in range(total_keys))

    def combined_score(res):
        return res[1] + 0.02 * res[2]

    def select_top(results):
        if isinstance(topN, int) and topN >= 0:
            return heapq.nlargest(topN, results, key=combined_score)
        # None / negative topN: keep the original sort-then-slice semantics.
        return sorted(results, key=combined_score, reverse=True)[:topN]

    if parallel:
        workers = workers or max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(processes=workers) as pool:
            # Consume the iterator fully before the pool is terminated on exit.
            return select_top(pool.imap_unordered(try_key_worker, tasks, chunksize=64))
    return select_top(map(try_key_worker, tasks))

# --------------------------
# GUI Application
# --------------------------
class EraserBruteGUI(tk.Tk):
    def __init__(self):
        """Create the main window: Tk variables, widgets and per-run storage.

        Construction order matters. ``_build_controls`` creates
        ``self.tabs_container_top``, which the notebook below is packed into.
        ``_build_bruteforce_panel`` then adds its tabs to ``self.nb``.
        """
        super().__init__()
        self.title("Quantum Eraser – GUI + Correlated Bits + BruteForce")
        self.geometry("1180x820")

        # Core parameters
        # NOTE(review): var_show is not read by any method visible here — confirm it is used.
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi  = tk.IntVar(value=13)
        self.var_seed  = tk.IntVar(value=1234)
        self.var_save  = tk.BooleanVar(value=True)
        self.var_show  = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")

        # Whether to save correlated streams (pairs, S|I=0, S|I=1, counts CSV)
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM Quantum connection settings (used only when var_use_ibm is ticked)
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token = tk.StringVar(value="")
        self.var_ibm_instance = tk.StringVar(value="")
        self.var_ibm_backend = tk.StringVar(value="")

        # Brute-force params (default workers: all CPUs but one, at least 1)
        self.var_keylen = tk.IntVar(value=6)
        self.var_outchars = tk.IntVar(value=24)
        self.var_topn = tk.IntVar(value=20)
        self.var_parallel = tk.BooleanVar(value=False)
        self.var_workers = tk.IntVar(value=max(1, multiprocessing.cpu_count()-1))

        # Binary message for the automatic transmission; shown as the "sent"
        # message in the comparison window after an auto-run brute-force.
        self.var_binary_message = tk.StringVar(value="1011001")

        self._build_controls()
        self._build_layout()
        self._build_plot()
        # Notebook holding the tabs (placed at the top, above the button row)
        self.nb = ttk.Notebook(self.tabs_container_top)
        self.nb.pack(side=tk.TOP, fill=tk.X, expand=False, padx=0, pady=0)
        self._build_bruteforce_panel()

        # Results of the most recent experiment run (filled by _run_experiment)
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.last_memory_by_phi = {}

        # Prevents starting a second experiment while one is running
        self._running = False

    def _build_controls(self):
        """Build the parameter row, the IBM settings frame, the tab container and the action buttons.

        Frames are packed top-to-bottom in the order created, so this order is
        the on-screen order. Also creates ``self.tabs_container_top`` (used by
        ``__init__`` for the notebook) and the ``btn_*`` button attributes.
        """
        frm = ttk.Frame(self, padding=8)
        frm.pack(side=tk.TOP, fill=tk.X)

        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=4)
        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=4)
        ttk.Label(frm, text="seed:").grid(row=0, column=4, sticky="w")
        ttk.Entry(frm, textvariable=self.var_seed, width=8).grid(row=0, column=5, padx=4)
        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=6, padx=6)
        ttk.Checkbutton(frm, text="show", variable=self.var_show).grid(row=0, column=7, padx=6)
        ttk.Checkbutton(frm, text="save correlated streams", variable=self.var_save_corr).grid(row=0, column=8, padx=6)

        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=6)
        ttk.Entry(frm, textvariable=self.var_outdir, width=36).grid(row=1, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(frm, text="Browse…", command=self._choose_outdir).grid(row=1, column=5, sticky="w", padx=4)

        # IBM credentials/backend; the token entry masks its characters.
        ibm = ttk.LabelFrame(self, text="IBM Quantum (optional)", padding=6)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=8, pady=6)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=44, show="•").grid(row=0, column=2, padx=4, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=20).grid(row=0, column=4, padx=4)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=4)

        # Container for the notebook tabs, placed right above the button row
        self.tabs_container_top = ttk.Frame(self, padding=0)
        self.tabs_container_top.pack(side=tk.TOP, fill=tk.X, padx=8, pady=(4,4))

        btns = ttk.Frame(self, padding=6)
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run Experiment", command=self.on_run)
        self.btn_run.pack(side=tk.LEFT, padx=4)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect", command=self.on_correlate)
        self.btn_correlate.pack(side=tk.LEFT, padx=4)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview)
        self.btn_preview.pack(side=tk.LEFT, padx=4)
        self.btn_noncorr = ttk.Button(btns, text="Analiză non-corelație", command=self.on_analyze_noncorr)
        self.btn_noncorr.pack(side=tk.LEFT, padx=4)

    def _build_layout(self):
        """Create the expanding body area: a vertical PanedWindow holding ``self.plot_frame``."""
        # Paned window containing only the plot (bottom). The tabs live above, in tabs_container_top.
        body = ttk.Frame(self)
        body.pack(side=tk.TOP, fill=tk.BOTH, expand=True)
        self.paned = tk.PanedWindow(body, orient=tk.VERTICAL)
        self.paned.pack(fill=tk.BOTH, expand=True)
        self.plot_frame = ttk.Frame(self.paned)
        # Only the plot frame is added to the paned window (bottom)
        self.paned.add(self.plot_frame)
        try:
            # Best-effort minimum pane height; any error here is deliberately ignored.
            self.paned.paneconfig(self.plot_frame, minsize=320)
        except Exception:
            pass

    def _build_plot(self):
        """Create the matplotlib figure with the two P(s=0|i) curves, and the status bar.

        The two lines start empty and are filled by ``_run_experiment``. The
        status label is packed at the bottom of the main window and bound to
        ``self.status``.
        """
        self.fig = plt.Figure(figsize=(8.5,5))
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,) = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend(); self.ax.grid(True, linestyle='--', alpha=0.4)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.plot_frame)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        self.status = tk.StringVar(value="Ready")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=6).pack(side=tk.BOTTOM, fill=tk.X)

    def _build_bruteforce_panel(self):
        """Add the "Brute-Force" and "Analiză" tabs to ``self.nb``.

        The brute-force tab holds the target-file picker, parameter entries,
        the results table ``self.tree`` and report buttons. The analysis tab
        holds the binary-message entry and a button that runs
        ``on_analyze_noncorr``.
        """
        # --- Brute-Force tab ---
        bf_tab = ttk.Frame(self.nb)
        self.nb.add(bf_tab, text="Brute-Force")
        panel = ttk.LabelFrame(bf_tab, text="Brute-Force (analyze binary streams)", padding=6)
        panel.pack(fill=tk.BOTH, expand=True)

        ttk.Label(panel, text="Select binary file:").grid(row=0, column=0, sticky="w")
        self.var_target_file = tk.StringVar(value="")
        ttk.Entry(panel, textvariable=self.var_target_file, width=56).grid(row=0, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(panel, text="Browse", command=self._choose_target_file).grid(row=0, column=5, padx=4)

        ttk.Label(panel, text="keylen:").grid(row=1, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_keylen, width=6).grid(row=1, column=1, padx=4)
        ttk.Label(panel, text="out_chars:").grid(row=1, column=2, sticky="w")
        ttk.Entry(panel, textvariable=self.var_outchars, width=6).grid(row=1, column=3, padx=4)
        ttk.Label(panel, text="topN:").grid(row=1, column=4, sticky="w")
        ttk.Entry(panel, textvariable=self.var_topn, width=6).grid(row=1, column=5, padx=4)

        ttk.Checkbutton(panel, text="parallel", variable=self.var_parallel).grid(row=2, column=0, padx=4, pady=4)
        ttk.Label(panel, text="workers:").grid(row=2, column=1, sticky="e")
        ttk.Entry(panel, textvariable=self.var_workers, width=6).grid(row=2, column=2, padx=4)
        ttk.Button(panel, text="Run Brute-Force", command=self.on_run_bruteforce).grid(row=2, column=3, padx=6)

        # Result table; column ids are also used when reading rows back for reports.
        cols = ("key","printable","common_words","preview")
        self.tree = ttk.Treeview(panel, columns=cols, show="headings", height=8)
        for c in cols:
            self.tree.heading(c, text=c)
            self.tree.column(c, width=180 if c!="preview" else 360, anchor="w")
        self.tree.grid(row=3, column=0, columnspan=6, sticky="nsew", pady=6)
        panel.grid_rowconfigure(3, weight=1)

        ttk.Button(panel, text="Save Report", command=self._save_report).grid(row=4, column=0, padx=4, pady=6)
        ttk.Button(panel, text="Open Outdir", command=self._open_outdir).grid(row=4, column=1, padx=4, pady=6)

        # --- Analysis tab ---
        an_tab = ttk.Frame(self.nb)
        self.nb.add(an_tab, text="Analiză")
        ttk.Label(an_tab, text="Analiză non-corelație și comparație cu corelarea", anchor="w").pack(anchor="w", padx=8, pady=(10,4))
        ttk.Label(an_tab, text="Apasă butonul pentru a calcula metrici (MI, chi2, HSIC, ΔP) și a deschide tabelul de rezultate.", wraplength=960).pack(anchor="w", padx=8)

        # Entry for the binary message
        msg_frame = ttk.Frame(an_tab)
        msg_frame.pack(anchor="w", padx=8, pady=4)
        ttk.Label(msg_frame, text="Mesaj binar pentru transmisie automată:").pack(side=tk.LEFT)
        ttk.Entry(msg_frame, textvariable=self.var_binary_message, width=40).pack(side=tk.LEFT, padx=4)

        ttk.Button(an_tab, text="Calculează metrici și deschide tabel", command=self.on_analyze_noncorr).pack(anchor="w", padx=8, pady=8)

    # ---------------- UI actions ----------------
    def _choose_outdir(self):
        """Let the user pick the output directory; cancelling keeps the current value."""
        chosen = filedialog.askdirectory(title="Select outdir")
        if not chosen:
            return
        self.var_outdir.set(chosen)

    def _choose_target_file(self):
        """Let the user pick the binary file to brute-force; cancelling keeps the current value."""
        kinds = [("Text files","*.txt"),("All files","*.*")]
        chosen = filedialog.askopenfilename(title="Select binary file", filetypes=kinds)
        if not chosen:
            return
        self.var_target_file.set(chosen)

    def _open_outdir(self):
        """Open the output directory in the platform's file manager.

        Uses ``os.startfile`` on Windows, ``open`` on macOS and ``xdg-open``
        on other POSIX systems. The opener is started with an argument list
        (no shell), without waiting for it, so directory names with quotes or
        ``$(...)`` cannot inject commands and the GUI does not block. If the
        directory does not exist or no opener can be started, the path is
        shown in an info dialog instead. Previously macOS and a missing
        directory failed silently.
        """
        import subprocess
        import sys

        out = self.var_outdir.get()
        try:
            if not os.path.isdir(out):
                raise FileNotFoundError(out)
            if os.name == 'nt':
                os.startfile(out)
            elif sys.platform == 'darwin':
                subprocess.Popen(['open', out])
            elif os.name == 'posix':
                subprocess.Popen(['xdg-open', out])
            else:
                raise OSError("no known file-manager opener for this platform")
        except (OSError, ValueError):
            messagebox.showinfo("Open", f"Outdir: {out}")

    def on_run(self):
        """Validate the run parameters and start the experiment on a daemon thread.

        Refuses to start when a run is already active, when IBM is requested
        but qiskit-ibm-runtime is missing, or when shots/nphi/seed are not
        integers. Shots must be positive and nphi at least 2.
        """
        if getattr(self, "_running", False):
            messagebox.showinfo("Info","Already running")
            return
        if self.var_use_ibm.get() and not IBM_RUNTIME_AVAILABLE:
            messagebox.showerror("Error","qiskit-ibm-runtime not installed")
            return
        try:
            shots = int(self.var_shots.get())
            nphi = int(self.var_nphi.get())
            int(self.var_seed.get())  # only validated here; re-read by the worker
            valid = shots > 0 and nphi > 1
        except Exception:
            valid = False
        if not valid:
            messagebox.showerror("Error","Invalid shots/nphi/seed")
            return
        self._running = True
        self.btn_run.config(state=tk.DISABLED)
        self.status.set("Running experiment...")
        worker = threading.Thread(target=self._run_experiment_thread, daemon=True)
        worker.start()

    def _build_aer_backend(self):
        """Return a local AerSimulator seeded from the GUI's seed field."""
        seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=seed)

    def _build_ibm_backend(self):
        """Connect to IBM Quantum and return the backend named in the GUI.

        An empty instance field is passed as ``None``, which is the service's
        default. The service and backend are also kept on ``self.service`` and
        ``self.backend``.
        """
        token = self.var_ibm_token.get().strip()
        instance = self.var_ibm_instance.get().strip() or None
        # NOTE(review): recent qiskit-ibm-runtime releases retire the "ibm_quantum"
        # channel in favour of "ibm_quantum_platform"/"ibm_cloud" — confirm against
        # the installed version.
        service = QiskitRuntimeService(channel="ibm_quantum", token=token, instance=instance)
        backend = service.backend(self.var_ibm_backend.get().strip())
        self.service, self.backend = service, backend
        return backend

    def _run_experiment_thread(self):
        """Thread target: run the experiment, report any failure, and always re-enable Run."""
        try:
            self._run_experiment()
        except Exception as exc:
            message = str(exc)
            self.status.set(f"Error: {exc}")
            messagebox.showerror("Error", message)
        finally:
            self._running = False
            self.btn_run.config(state=tk.NORMAL)

    def _run_experiment(self):
        """Sweep phi over [0, 2*pi], run the eraser circuit at each point, plot and save results.

        Reads shots/nphi/seed/save options from the GUI. For each of ``nphi``
        evenly spaced phases (both endpoints included) the circuit runs on
        IBM when "Use IBM Quantum" is ticked, otherwise on AerSimulator. For
        each phase:
        - the per-shot memory is stored in ``last_memory_by_phi``;
        - P(s=0|i=0) and P(s=0|i=1) are appended to ``last_p_plus`` and
          ``last_p_minus``;
        - the plot is refreshed.

        Enabled files go to the output directory. An empty output-directory
        field means the current directory; previously ``os.makedirs('')``
        aborted every run, even with saving disabled. The directory is only
        created when something will be saved.
        """
        import csv

        shots = int(self.var_shots.get())
        nphi = int(self.var_nphi.get())
        save = bool(self.var_save.get())
        save_corr = bool(self.var_save_corr.get())
        outdir = self.var_outdir.get() or "."
        seed = int(self.var_seed.get())
        use_ibm = bool(self.var_use_ibm.get())

        if save or save_corr:
            os.makedirs(outdir, exist_ok=True)
        phis = list(np.linspace(0.0, 2.0*math.pi, nphi))

        backend = self._build_ibm_backend() if use_ibm else self._build_aer_backend()

        # Clear the previous curves before the sweep starts.
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()

        # These lists are shared with self.last_* so other handlers see progress.
        p_plus, p_minus = [], []
        self.last_phis = phis
        self.last_p_plus = p_plus
        self.last_p_minus = p_minus
        self.last_memory_by_phi.clear()

        for idx, phi in enumerate(phis):
            counts, memory = self._execute_eraser(backend, phi, shots, seed, use_ibm)
            self.last_memory_by_phi[idx] = memory

            p0_plus = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus); p_minus.append(p0_minus)

            self.line_plus.set_data(phis[:idx+1], p_plus)
            self.line_minus.set_data(phis[:idx+1], p_minus)
            self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()

            if save:
                s_stream = ''.join(m[1] for m in memory)
                save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)
            if save_corr:
                self._save_correlated_streams(outdir, idx, counts, memory)

            self.status.set(f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • {idx+1}/{len(phis)}")

        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline='') as f:
                w = csv.writer(f)
                w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
                for phi, a, b in zip(phis, p_plus, p_minus):
                    w.writerow([phi, a, b])
            self.fig.savefig(os.path.join(outdir, "eraser_plot.png"), dpi=160)

        Vp = visibility(p_plus); Vm = visibility(p_minus)
        self.status.set(f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}  (IBM: {use_ibm})")

    def _execute_eraser(self, backend, phi, shots, seed, use_ibm):
        """Run one eraser circuit for ``phi`` on ``backend`` and return ``(counts, memory)``.

        The circuit is transpiled with ``seed_transpiler=seed`` on both
        backends for reproducibility. If an IBM result exposes no per-shot
        memory, it is rebuilt from ``counts``. In that case shots are grouped
        by outcome, so shot order is lost.
        """
        qc = eraser_circuit(phi, erase=True)
        tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
        # NOTE(review): newer qiskit-ibm-runtime versions steer users to SamplerV2
        # instead of backend.run — confirm backend.run works with the installed version.
        result = backend.run(tqc, shots=shots, memory=True).result()
        counts = result.get_counts()
        if not use_ibm:
            return counts, result.get_memory()
        try:
            memory = result.get_memory()
        except Exception:
            memory = [bitstr for bitstr, c in counts.items() for _ in range(int(c))]
        return counts, memory

    @staticmethod
    def _save_correlated_streams(outdir, idx, counts, memory):
        """Write the I/S pair list, S|I=0 and S|I=1 streams, and pair counts CSV for phi index ``idx``."""
        import csv

        save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
        s_i0 = ''.join(m[1] for m in memory if m[0]=='0')
        s_i1 = ''.join(m[1] for m in memory if m[0]=='1')
        save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_i0)
        save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_i1)
        with open(os.path.join(outdir, f"counts_phi_{idx:02d}.csv"), "w", newline='') as f:
            w = csv.writer(f)
            w.writerow(['pair','count'])
            for pair in ['00','01','10','11']:
                w.writerow([pair, counts.get(pair, 0)])

    def on_correlate(self):
        """Compute the fringe visibilities of the last run and annotate the plot with a verdict.

        Interference is reported when both V(+) and V(-) exceed 0.25. The
        annotation from a previous click is removed first. Previously every
        click drew a new text box on top of the old ones.
        """
        if not self.last_p_plus:
            messagebox.showinfo("Info","Run experiment first")
            return
        Vp = visibility(self.last_p_plus); Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"
        previous = getattr(self, "_corr_annotation", None)
        if previous is not None:
            try:
                previous.remove()
            except (ValueError, NotImplementedError):
                # Already detached from the axes; nothing left to remove.
                pass
        self._corr_annotation = self.ax.text(0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
                                             transform=self.ax.transAxes, fontsize=10,
                                             bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    def on_preview(self):
        """Open a window with the stored I/S pairs of the highest phi index.

        Shows the first 200 raw pairs, the coincidence counts for 00/01/10/11,
        and the first 400 S bits conditioned on I=0 and on I=1. I is the first
        character of each stored bitstring and S the second.
        """
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info","Run first")
            return
        idx = max(self.last_memory_by_phi)
        memory = self.last_memory_by_phi[idx]
        tally = Counter(memory)
        pairs_text = '\n'.join(memory[:200])
        s_given = {
            bit: ''.join(m[1] for m in memory if m[0] == bit)[:400]
            for bit in ('0', '1')
        }

        win = tk.Toplevel(self)
        win.title(f"Preview correlations (phi idx {idx})")
        win.geometry("640x720")

        def caption(text):
            ttk.Label(win, text=text).pack(anchor="w", padx=8, pady=(8,0))

        def frozen_text(content, height, fill):
            box = tk.Text(win, height=height)
            box.pack(fill=fill, padx=8)
            box.insert("1.0", content)
            box.config(state=tk.DISABLED)

        caption(f"First 200 pairs (I S) for phi index {idx}:")
        frozen_text(pairs_text, 12, tk.BOTH)

        caption("Counts (coincidences):")
        tree = ttk.Treeview(win, columns=("pair","count"), show="headings", height=4)
        for col in ("pair", "count"):
            tree.heading(col, text=col)
        for pair in ('00', '01', '10', '11'):
            tree.insert("", "end", values=(pair, tally.get(pair, 0)))
        tree.pack(fill=tk.X, padx=8, pady=4)

        caption("S | I=0 (first 400 bits):")
        frozen_text(s_given['0'], 4, tk.X)
        caption("S | I=1 (first 400 bits):")
        frozen_text(s_given['1'], 4, tk.X)

    # ---------------- Brute-force ----------------
    def on_run_bruteforce(self, auto_run=False):
        """Validate the brute-force inputs, clear the result table and start the worker thread.

        The target must be an existing file. keylen, out_chars and topN must
        be positive integers; workers must be an integer. ``auto_run`` is
        forwarded to the worker.
        """
        target = self.var_target_file.get().strip()
        if not (target and os.path.isfile(target)):
            messagebox.showerror("Error","Select a valid binary file")
            return
        try:
            params = (
                int(self.var_keylen.get()),
                int(self.var_outchars.get()),
                int(self.var_topn.get()),
                bool(self.var_parallel.get()),
                int(self.var_workers.get()),
            )
        except Exception:
            params = None
        if params is None or min(params[0], params[1], params[2]) <= 0:
            messagebox.showerror("Error","Invalid brute-force parameters")
            return
        keylen, outchars, topn, parallel, workers = params

        self.tree.delete(*self.tree.get_children())
        self.status.set("Running brute-force...")
        worker = threading.Thread(
            target=self._run_bruteforce_thread,
            args=(target, keylen, outchars, topn, parallel, workers, auto_run),
            daemon=True,
        )
        worker.start()

    def _run_bruteforce_thread(self, target, keylen, outchars, topn, parallel, workers, auto_run=False):
        """Worker: brute-force ``target``, fill the table, and write the report and verdict files.

        Success means the printable ratio of the top-ranked candidate exceeds
        0.8. With ``auto_run`` the verdict dialog is skipped. Instead, a window
        comparing the configured binary message with the best preview is
        scheduled on the Tk event loop. Any error is shown in the status bar
        and an error dialog.
        """
        try:
            candidates = run_bruteforce(read_binary_file(target), keylen, outchars, topn, parallel, workers)
            for key, ratio, hits, preview in candidates:
                self.tree.insert("", "end", values=(key, f"{ratio:.3f}", f"{hits:.1f}", preview))

            report_path = os.path.join(self.var_outdir.get(), "bruteforce_report.txt")
            header = [
                f"Brute-force report: target={target}\n",
                f"keylen={keylen} outchars={outchars} parallel={parallel} workers={workers}\n",
                "Top results:\n",
            ]
            body = [
                f"{key}  printable={ratio:.3f} common_words={hits:.1f} preview={preview}\n"
                for key, ratio, hits, preview in candidates
            ]
            save_text(report_path, ''.join(header + body))
            self.status.set(f"Brute-force finished. Report: {report_path}")

            # Final verdict: printable ratio of the top-ranked candidate vs. a fixed threshold.
            best_ratio = candidates[0][1] if candidates else 0
            if best_ratio > 0.8:
                verdict = "Experimentul a fost un succes: informația a fost recuperată fără corelație prin brute-force cu lungimea cheii {}.".format(keylen)
            else:
                verdict = "Experimentul nu a reușit: nu s-a putut recupera informația fără corelație."

            if not auto_run:
                messagebox.showinfo("Verdict Final", verdict)

            # Persist the verdict next to the report.
            save_text(os.path.join(self.var_outdir.get(), "concluzie.txt"), verdict)

            if auto_run:
                sent = self.var_binary_message.get().strip()
                decoded = candidates[0][3] if candidates else "N/A"
                self.after(100, lambda: self._show_comparison_window(sent, decoded))

        except Exception as e:
            self.status.set(f"Error: {e}")
            messagebox.showerror("Error", str(e))

    def _show_comparison_window(self, original, decoded):
        """Show the transmitted message and the best decoded candidate in read-only fields."""
        win = tk.Toplevel(self)
        win.title("Rezultat Transmisie Automată")

        rows = (
            ("Mesaj Transmis:", original),
            ("Mesaj Decodat (Cel mai bun rezultat):", decoded),
        )
        for caption, value in rows:
            ttk.Label(win, text=caption, padding=(10,5,10,0)).pack(anchor="w")
            field = ttk.Entry(win, width=50)
            field.pack(padx=10, fill=tk.X)
            field.insert(0, value)
            field.config(state='readonly')

        ttk.Button(win, text="OK", command=win.destroy).pack(pady=10)

    def _save_report(self):
        """Save the candidates currently listed in ``self.tree`` to a user-chosen text file.

        Each tree row is written as ``<v0> printable=<v1> common=<v2> preview=<v3>``
        from its first four values (v0 is presumably the candidate key — confirm
        against the tree's column definition). Does nothing if the dialog is cancelled.
        """
        target = filedialog.asksaveasfilename(title="Save report as", defaultextension=".txt", filetypes=[("Text files","*.txt")])
        if not target:
            return
        values_per_row = [self.tree.item(row_id)["values"] for row_id in self.tree.get_children()]
        lines = ["Brute-force candidates:\n"]
        lines.extend(
            f"{v[0]} printable={v[1]} common={v[2]} preview={v[3]}\n"
            for v in values_per_row
        )
        save_text(target, "".join(lines))
        messagebox.showinfo("Saved", f"Report saved to {target}")

    # ---------------- Analiză non-corelație ----------------
    def on_analyze_noncorr(self):
        """Analyse S/I dependence per φ, save the results and show them in a new window.

        For every φ index in ``self.last_memory_by_phi`` (the measurement memory
        of the last run — presumably one bitstring per shot) this computes mutual
        information, a chi² statistic, the predictability gain, HSIC, P(S=0) and
        ΔP = |P(S=0|I=0) - P(S=0|I=1)|. It then classifies the φ as

        * detectable *without* correlation when MI, HSIC or |P(S=0) - 0.5|
          exceeds its threshold, and
        * clearly detectable *with* correlation when ΔP exceeds its threshold.

        Side effects:

        * writes ``noncorr_metrics.csv`` and ``concluzie_analiza.txt`` into the
          output directory;
        * opens a Toplevel window with the metrics table (sorted by MI,
          descending), MI(φ)/ΔP(φ) plots and a selector that runs the
          brute-force on ``binary_phi_XX.txt``;
        * shows a verdict dialog;
        * if at least one φ is detectable without correlation, schedules
          ``_run_auto_transmit_decode``.
        """
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info","Run experiment first")
            return
        import csv

        # Decision thresholds. The plots draw the same values as dashed
        # reference lines, so they are kept in one place.
        mi_thr = 0.02      # mutual information
        hsic_thr = 0.05    # HSIC
        bias_thr = 0.05    # |P(S=0) - 0.5|
        dp_thr = 0.10      # ΔP

        def _cond_p_s0(pj, i_bit):
            # P(S=0 | I=i_bit) from the joint table. Keys are read as "<I><S>", as
            # implied by the P(S=0|I=…) column labels. Returns 0.0 when I=i_bit
            # never occurred.
            p_joint_s0 = pj[i_bit + '0']
            p_joint_s1 = pj[i_bit + '1']
            p_joint_i = p_joint_s0 + p_joint_s1
            return p_joint_s0 / (p_joint_i + 1e-15) if p_joint_i > 0 else 0.0

        outdir = self.var_outdir.get()
        idxs = sorted(self.last_memory_by_phi.keys())
        rows = []
        noncorr_yes = 0
        corr_yes = 0
        for idx in idxs:
            memory = self.last_memory_by_phi[idx]
            c = Counter(memory)
            mi = mutual_information_from_counts(c)
            chi2 = chi2_stat_from_counts(c)
            gain = predictability_gain_from_counts(c)
            _p_i0, _p_i1, p_s0, _p_s1, pj = _probs_from_counts(c)
            p_s0_i0 = _cond_p_s0(pj, '0')
            p_s0_i1 = _cond_p_s0(pj, '1')
            hsic = hsic_binary_from_memory(memory)
            delta_p = abs(p_s0_i0 - p_s0_i1)
            noncorr_detect = (mi > mi_thr) or (hsic > hsic_thr) or (abs(p_s0 - 0.5) > bias_thr)
            corr_detect = delta_p > dp_thr
            # Count from the booleans themselves. Substring-matching the verdict
            # text is wrong, because "nedetectabil" contains "detectabil".
            if noncorr_detect:
                noncorr_yes += 1
            if corr_detect:
                corr_yes += 1
            v_noncorr = "detectabil (fără corelare)" if noncorr_detect else "nedetectabil (fără corelare)"
            v_corr = "clar (cu corelare)" if corr_detect else "slab/absent (cu corelare)"
            rows.append((idx, mi, chi2, gain, hsic, delta_p, v_noncorr, v_corr, p_s0, p_s0_i0, p_s0_i1))

        csv_path = os.path.join(outdir, "noncorr_metrics.csv")
        csv_error = None
        try:
            os.makedirs(outdir, exist_ok=True)
            # Explicit UTF-8: the verdict strings contain Romanian diacritics
            # ("ă") that legacy Windows code pages such as cp1252 cannot encode.
            with open(csv_path, "w", newline='', encoding="utf-8") as f:
                w = csv.writer(f)
                w.writerow(["phi_idx","MI","chi2","predict_gain","HSIC","DeltaP","verdict_noncorr","verdict_corr","P(S=0)","P(S=0|I=0)","P(S=0|I=1)"])
                w.writerows(rows)
        except Exception as e:
            # Best effort: the analysis window is still shown, but the failure
            # is reported there instead of claiming the CSV was saved.
            csv_error = e

        win = tk.Toplevel(self)
        win.title("Analiză non-corelație pe φ")
        win.geometry("1060x700")
        cols = ("phi_idx","MI","chi2","gain","HSIC","ΔP","verdict_noncorr","verdict_corr","P(S=0)","P(S=0|I=0)","P(S=0|I=1)")
        tree = ttk.Treeview(win, columns=cols, show="headings", height=12)
        for c in cols:
            tree.heading(c, text=c)
            tree.column(c, width=120 if c not in ("phi_idx","verdict_noncorr","verdict_corr") else (80 if c=="phi_idx" else 160), anchor="w")
        for r in sorted(rows, key=lambda r: r[1], reverse=True):
            tree.insert("", "end", values=(
                r[0], f"{r[1]:.4f}", f"{r[2]:.2f}", f"{r[3]:.4f}", f"{r[4]:.4f}", f"{r[5]:.3f}", r[6], r[7], f"{r[8]:.3f}", f"{r[9]:.3f}", f"{r[10]:.3f}"
            ))
        tree.pack(fill=tk.X, padx=8, pady=8)

        # Plot MI(φ) and ΔP(φ). `rows` is already in ascending φ-index order, so
        # the metrics computed above are reused instead of recomputed.
        plot_frame = ttk.Frame(win)
        plot_frame.pack(fill=tk.BOTH, expand=True, padx=8, pady=8)
        fig = plt.Figure(figsize=(8.8,3.8))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        # Use the stored φ value when available; otherwise fall back to the index.
        phis_vals = [self.last_phis[i] if i < len(self.last_phis) else i for i in idxs]
        mi_list = [r[1] for r in rows]
        dp_list = [r[5] for r in rows]
        ax1.plot(phis_vals, mi_list, marker='o', color='tab:blue', label='MI')
        ax1.axhline(mi_thr, color='gray', linestyle='--', linewidth=1, label='prag MI')
        ax1.set_title('Mutual Information vs φ')
        ax1.set_xlabel('φ (radiani)')
        ax1.set_ylabel('MI')
        ax1.grid(True, linestyle='--', alpha=0.4)
        ax1.legend()
        ax2.plot(phis_vals, dp_list, marker='s', color='tab:orange', label='ΔP')
        ax2.axhline(dp_thr, color='gray', linestyle='--', linewidth=1, label='prag ΔP')
        ax2.set_title('ΔP = |P(S=0|I=0)-P(S=0|I=1)| vs φ')
        ax2.set_xlabel('φ (radiani)')
        ax2.set_ylabel('ΔP')
        ax2.grid(True, linestyle='--', alpha=0.4)
        ax2.legend()
        canvas = FigureCanvasTkAgg(fig, master=plot_frame)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)
        canvas.draw()

        # φ selector + run Brute-Force on binary_phi_XX.txt
        ctl = ttk.Frame(win)
        ctl.pack(fill=tk.X, padx=8, pady=(0,8))
        ttk.Label(ctl, text="Selectează φ idx pentru Brute-Force pe binary_phi_XX:").pack(side=tk.LEFT)
        vals = [str(i) for i in idxs]
        var_phi_sel = tk.StringVar(value=vals[0] if vals else "0")
        cb = ttk.Combobox(ctl, values=vals, textvariable=var_phi_sel, width=6, state='readonly')
        cb.pack(side=tk.LEFT, padx=8)
        def run_bf_for_phi():
            # Point the brute-force target at the saved stream for the selected φ and run it.
            try:
                sel = int(var_phi_sel.get())
            except Exception:
                sel = idxs[0] if idxs else 0
            path = os.path.join(outdir, f"binary_phi_{sel:02d}.txt")
            if not os.path.isfile(path):
                messagebox.showerror("Error", f"Nu există fișierul: {path}. Asigură-te că 'save' este activ la run.")
                return
            self.var_target_file.set(path)
            try:
                self.nb.select(0)  # presumably the Brute-Force tab — confirm notebook tab order
            except Exception:
                pass
            self.on_run_bruteforce()
        ttk.Button(ctl, text="Rulează Brute-Force pe φ selectat", command=run_bf_for_phi).pack(side=tk.LEFT, padx=8)

        total = len(rows)
        ttk.Label(win, text=f"Detectare fără corelare: {noncorr_yes}/{total} • Detectare cu corelare: {corr_yes}/{total}").pack(anchor="w", padx=8)
        if csv_error is None:
            csv_msg = f"CSV salvat în: {csv_path}"
        else:
            csv_msg = f"CSV nu a putut fi salvat în {csv_path}: {csv_error}"
        ttk.Label(win, text=csv_msg).pack(anchor="w", padx=8)

        # Final verdict for the non-correlation analysis
        if noncorr_yes > 0:
            verdict_nc = f"Analiza a identificat cel puțin un caz detectabil fără corelare ({noncorr_yes}/{total}). Experimentul indică existența informației recuperabile fără corelare."
        else:
            verdict_nc = "Analiza nu a identificat cazuri detectabile fără corelare. Informația nu pare recuperabilă fără corelare."
        # Show the verdict after the window and plots have been drawn
        win.after(300, lambda: messagebox.showinfo("Verdict Analiză", verdict_nc))
        # Save the conclusion to a separate file (best effort; the dialog already shows it)
        concl_path = os.path.join(outdir, "concluzie_analiza.txt")
        try:
            save_text(concl_path, verdict_nc)
        except Exception:
            pass

        # Trigger the automatic transmit/decode test only when something was detected
        if noncorr_yes > 0:
            self.after(500, self._run_auto_transmit_decode)

    def _run_auto_transmit_decode(self):
        """Hide the user's binary message in a random carrier and brute-force it back.

        The carrier is a random '0'/'1' string 20 times as long as the message.
        Message bit k overwrites carrier position k * 20. The result is saved as
        ``stego_signal_auto.txt`` in the output directory. That file then becomes
        the brute-force target, and ``on_run_bruteforce(auto_run=True)`` is invoked.
        An error dialog is shown (and nothing else happens) if the message is
        empty or contains characters other than '0'/'1'.
        """
        bits = self.var_binary_message.get().strip()
        if not bits or any(ch not in '01' for ch in bits):
            messagebox.showerror("Eroare Mesaj", "Mesajul binar este invalid sau gol.")
            return

        # Random carrier, 20 carrier bits per message bit.
        total_len = len(bits) * 20
        stego = list(''.join(np.random.choice(['0', '1'], size=total_len)))

        # Overwrite every `stride`-th carrier bit with the next message bit.
        stride = total_len // len(bits)
        for k, bit in enumerate(bits):
            stego[k * stride] = bit

        stego_path = os.path.join(self.var_outdir.get(), "stego_signal_auto.txt")
        save_text(stego_path, "".join(stego))

        # Reuse the regular brute-force flow on the freshly generated signal.
        self.var_target_file.set(stego_path)
        self.on_run_bruteforce(auto_run=True)

# --------------------------
# Main
# --------------------------
def main():
    """Parse the command line, then run the Tk application until its window closes.

    ``--no-gui`` is accepted but ignored: the GUI is always started.
    """
    cli = argparse.ArgumentParser(description="Quantum Eraser GUI + Brute-Force")
    cli.add_argument("--no-gui", action="store_true", help="Run non-interactive (not used)")
    # Parsed only for validation / --help; the resulting namespace is not used.
    cli.parse_args()

    EraserBruteGUI().mainloop()


if __name__ == "__main__":
    main()