#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser GUI + Correlated Bits + Brute-Force (parallel)
Author: ChatGPT (adapted for the user)
Features:
  - Run delayed-choice quantum eraser on Aer or IBM backends
  - Save binary streams, correlated streams (pairs, S|I=0, S|I=1)
  - Live plotting of P(s=0|i=±)
  - Correlation / visibility detection
  - Brute-force module: try binary keys (0/1) to extract hidden ASCII text,
    scoring by printable ratio + word heuristics
  - Optional multiprocessing parallel brute-force
Dependencies:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"    # only if you want IBM backends
Usage:
  python quantum_eraser_gui_ibm_corr_bruteforce_parallel.py
"""

import argparse
import hashlib
import heapq
import math
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from collections import Counter
from typing import List, Tuple, Dict, Optional

import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Qiskit imports
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
# IBM runtime optionally
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    IBM_RUNTIME_AVAILABLE = False

# --------------------------
# Quantum circuit helpers
# --------------------------
def eraser_circuit(phi: float, erase: bool = True) -> QuantumCircuit:
    """Build the two-qubit, two-clbit eraser circuit for phase ``phi``.

    Qubit 0 gets H, RZ(phi), acts as the control of a CX onto qubit 1, and
    then gets a second H. When ``erase`` is true an extra H is applied to
    qubit 1 before measurement. Each qubit is measured into the classical
    bit with the same index.

    NOTE(review): qubit 0 is presumably the "signal" and qubit 1 the "idler"
    of the experiment; the naming below follows that reading.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)

    circuit.h(signal)
    circuit.rz(phi, signal)
    circuit.cx(signal, idler)
    circuit.h(signal)

    if erase:
        circuit.h(idler)

    for wire in (signal, idler):
        circuit.measure(wire, wire)
    return circuit

def conditional_prob_from_counts(counts: Dict[str,int], i_bit: str = '0') -> float:
    """Return P(second char == '0' | first char == ``i_bit``) from a counts dict.

    Keys shorter than two characters are ignored. Returns 0.0 when no key
    matches ``i_bit`` (or all matching counts are zero).
    """
    # (second character, count) for every key whose first character matches.
    selected = [
        (key[1], n)
        for key, n in counts.items()
        if len(key) >= 2 and key[0] == i_bit
    ]
    total = sum(n for _, n in selected)
    if not total:
        return 0.0
    zeros = sum(n for s_char, n in selected if s_char == '0')
    return zeros / total

def visibility(vals: List[float]) -> float:
    """Return the fringe visibility (max - min) / (max + min) of ``vals``.

    Yields 0.0 for an empty input or when max + min is zero.
    """
    if not vals:
        return 0.0
    highest = max(vals)
    lowest = min(vals)
    total = highest + lowest
    if not total:
        return 0.0
    return (highest - lowest) / total

# --------------------------
# I/O helpers
# --------------------------
def save_text(path: str, text: str):
    """Write ``text`` to ``path`` as UTF-8, creating parent directories.

    Args:
        path: Destination file; may be a bare filename in the current
            working directory.
        text: Content to write (the file is overwritten).

    Raises:
        OSError: If a directory cannot be created or the file cannot be
            written.
    """
    parent = os.path.dirname(path)
    # dirname() is '' for a bare filename and os.makedirs('') raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

def read_binary_file(path: str) -> str:
    """Read a UTF-8 text file and return only its '0' and '1' characters."""
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read().strip()
    # Drop everything that is not a binary digit (whitespace, newlines, noise).
    return ''.join(filter(lambda ch: ch in '01', raw))

# --------------------------
# Brute-force utilities
# --------------------------
# Lowercase word list used by score_message_bytes, which sums str.count()
# of each entry over the lowercased candidate text (substring matches count).
# The first two rows are English words; the last row looks like short
# Romanian words.
COMMON_WORDS = [
    "the","and","that","have","for","not","with","you","this","but","was","are",
    "from","they","his","she","which","will","one","all","would","there","their",
    "în","și","pe","la","care","să","este","nu"
]

def sha256_stream(key_bin: str, length_bits: int) -> str:
    """Expand a binary-string key into ``length_bits`` keystream bits.

    The key is packed big-endian into ceil(len(key_bin) / 8) bytes. Block
    ``n`` of the stream is SHA-256(key_bytes || n as 4-byte big-endian),
    rendered MSB-first as '0'/'1' characters; blocks are concatenated and
    truncated to ``length_bits``.
    """
    seed = int(key_bin, 2).to_bytes((len(key_bin) + 7) // 8, 'big')
    blocks = []
    produced = 0
    block_index = 0
    while produced < length_bits:
        digest = hashlib.sha256(seed + block_index.to_bytes(4, 'big')).digest()
        blocks.append(''.join(format(byte, '08b') for byte in digest))
        produced += 8 * len(digest)
        block_index += 1
    return ''.join(blocks)[:length_bits]

def deterministic_sample_positions(total_bits: int, m: int) -> List[int]:
    """Pick ``m`` distinct, roughly evenly spaced indices in [0, total_bits).

    Returns [] for m <= 0 and every index when m >= total_bits. Otherwise
    index k is round(k * total_bits / m) modulo total_bits; a collision is
    resolved by probing forward (with wrap-around) to the next unused index.
    """
    if m <= 0:
        return []
    if m >= total_bits:
        return list(range(total_bits))

    stride = total_bits / m
    taken = set()
    chosen = []
    for k in range(m):
        start = int(round(k * stride)) % total_bits
        slot = start
        while slot in taken:
            slot = (slot + 1) % total_bits
            if slot == start:
                # Wrapped all the way around without finding a free index.
                break
        taken.add(slot)
        chosen.append(slot)
    return chosen

def extract_bits_with_key(stego_bits: str, key_bin: str, out_len_bits: int, sample_positions: Optional[List[int]] = None) -> str:
    """XOR selected stego bits with the SHA-256 keystream derived from a key.

    Args:
        stego_bits: String of '0'/'1' characters to read from.
        key_bin: Key as a binary string, passed to ``sha256_stream``.
        out_len_bits: Number of leading bits to extract when
            ``sample_positions`` is None; ignored otherwise.
        sample_positions: Explicit indices into ``stego_bits``. One output
            bit is produced per position, in the given order.

    Returns:
        The extracted bits as a '0'/'1' string; '' when there is nothing to
        extract.

    Raises:
        ValueError: If more bits are requested than are available, or if a
            sample position is negative or beyond the end of ``stego_bits``.
    """
    total = len(stego_bits)
    if sample_positions is None:
        if out_len_bits > total:
            raise ValueError("Requested more bits than available in stego")
        positions = list(range(out_len_bits))
    else:
        positions = list(sample_positions)
        # Validate element-wise: max() on an empty list would raise an
        # unrelated ValueError, and a negative index would silently read
        # from the end of the stream.
        if any(pos >= total for pos in positions):
            raise ValueError("Sample positions exceed length of stego bits")
        if any(pos < 0 for pos in positions):
            raise ValueError("Sample positions must be non-negative")

    if not positions:
        return ''

    keystream = sha256_stream(key_bin, len(positions))
    # Differing characters -> '1', equal -> '0' (bitwise XOR on '0'/'1').
    return ''.join(
        '1' if stego_bits[pos] != kbit else '0'
        for kbit, pos in zip(keystream, positions)
    )

def bits_to_bytes(bits: str) -> bytes:
    """Pack a '0'/'1' string into bytes, MSB first.

    The input is right-padded with '0' to a multiple of eight bits; an empty
    string yields ``b''``.
    """
    padded = bits + '0' * ((-len(bits)) % 8)
    return bytes(
        int(padded[start:start + 8], 2)
        for start in range(0, len(padded), 8)
    )

def score_message_bytes(b: bytes) -> Tuple[float, float]:
    """Score candidate plaintext bytes.

    Returns ``(printable_ratio, common_score)``: the fraction of bytes whose
    value lies in 9..126 inclusive, and the total number of occurrences of
    the COMMON_WORDS entries in the lowercased UTF-8 decoding (undecodable
    bytes are dropped). Empty input scores ``(0.0, 0.0)``.
    """
    if not b:
        return 0.0, 0.0

    in_range = [value for value in b if 9 <= value <= 126]
    printable_ratio = len(in_range) / len(b)

    try:
        text = b.decode('utf-8', errors='ignore').lower()
    except Exception:
        text = ''

    hits = 0
    for word in COMMON_WORDS:
        hits += text.count(word)
    return printable_ratio, float(hits)

def try_key_worker(args) -> Tuple[str, float, float, str]:
    """Evaluate one candidate key; usable as a multiprocessing task.

    ``args`` is ``(key_bin, stego_bits, out_chars, sample_positions)``.
    Returns ``(key_bin, printable_ratio, common_score, preview)`` where the
    preview is the first ``out_chars`` characters of the decoded bytes.
    Any exception during extraction or scoring yields zero scores and an
    empty preview for that key.
    """
    key_bin, stego_bits, out_chars, sample_positions = args
    wanted_bits = out_chars * 8
    try:
        raw = bits_to_bytes(
            extract_bits_with_key(stego_bits, key_bin, wanted_bits, sample_positions)
        )
        printable_ratio, word_hits = score_message_bytes(raw)
        preview = raw.decode('utf-8', errors='replace')[:out_chars]
    except Exception:
        return (key_bin, 0.0, 0.0, "")
    return (key_bin, printable_ratio, word_hits, preview)

def run_bruteforce(stego_bits: str,
                   keylen: int = 6,
                   out_chars: int = 24,
                   topN: int = 20,
                   parallel: bool = False,
                   workers: Optional[int] = None) -> List[Tuple[str, float, float, str]]:
    """Try every ``keylen``-bit key and return the best-scoring candidates.

    Each key is evaluated by ``try_key_worker`` and ranked by
    ``printable_ratio + 0.02 * common_score``. Keys are generated lazily and
    only the ``topN`` best results are kept, so memory use does not grow with
    2**keylen. Ties are resolved in favour of the numerically smaller key in
    both sequential and parallel mode.

    Args:
        stego_bits: String of '0'/'1' characters to analyse.
        keylen: Key length in bits (all 2**keylen keys are tried).
        out_chars: Number of output characters (8 bits each) to extract.
        topN: Maximum number of candidates to return.
        parallel: Evaluate keys in a ``multiprocessing.Pool``.
        workers: Pool size; defaults to ``cpu_count() - 1`` (at least 1).

    Returns:
        Up to ``topN`` tuples ``(key_bin, printable_ratio, common_score,
        preview)``, best first.

    Raises:
        ValueError: If ``keylen`` is negative, ``out_chars`` is not positive,
            ``stego_bits`` is empty, or the key space exceeds 2,000,000 keys
            without ``parallel``.
    """
    if keylen < 0:
        raise ValueError("keylen must be non-negative")
    total_keys = 2 ** keylen
    if total_keys > 2000000 and not parallel:
        raise ValueError("Too many keys; enable parallel or reduce keylen")
    if out_chars <= 0:
        raise ValueError("out_chars must be positive")
    # Without this check every key would fail inside the worker and be
    # reported as an all-zero candidate, hiding the real problem.
    if not stego_bits:
        raise ValueError("No stego bits to analyze (input contains no 0/1 characters)")

    sample_positions = deterministic_sample_positions(len(stego_bits), out_chars * 8)
    key_format = '0{}b'.format(keylen)
    tasks = (
        (format(i, key_format), stego_bits, out_chars, sample_positions)
        for i in range(total_keys)
    )

    def combined_score(result):
        _, printable_ratio, common_score, _ = result
        return printable_ratio + 0.02 * common_score

    if parallel:
        workers = workers or max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(processes=workers) as pool:
            # Ordered imap keeps results in key order so ties are deterministic.
            scored = pool.imap(try_key_worker, tasks, chunksize=64)
            return heapq.nlargest(topN, scored, key=combined_score)
    return heapq.nlargest(topN, map(try_key_worker, tasks), key=combined_score)

# --------------------------
# GUI Application
# --------------------------
class EraserBruteGUI(tk.Tk):
    """Main window: quantum-eraser experiment runner plus brute-force panel.

    Threading model: the experiment and the brute-force search run in daemon
    threads. Tkinter is not thread-safe, so those threads never touch
    widgets, Tk variables, message boxes or the Matplotlib canvas directly.
    Instead they enqueue callables with ``_post``; the Tk main loop drains
    the queue periodically in ``_drain_ui_queue``. All Tk variables are read
    on the main thread before a worker is started.
    """

    # Polling period (ms) for draining UI callbacks queued by worker threads.
    _UI_POLL_MS = 40

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – GUI + Correlated Bits + BruteForce")
        self.geometry("1180x820")

        # Core parameters
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi  = tk.IntVar(value=13)
        self.var_seed  = tk.IntVar(value=1234)
        self.var_save  = tk.BooleanVar(value=True)
        self.var_show  = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")

        # correlated
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token = tk.StringVar(value="")
        self.var_ibm_instance = tk.StringVar(value="")
        self.var_ibm_backend = tk.StringVar(value="")

        # Brute-force params
        self.var_keylen = tk.IntVar(value=6)
        self.var_outchars = tk.IntVar(value=24)
        self.var_topn = tk.IntVar(value=20)
        self.var_parallel = tk.BooleanVar(value=False)
        self.var_workers = tk.IntVar(value=max(1, multiprocessing.cpu_count()-1))

        # Storage for last-run
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.last_memory_by_phi = {}
        # Last brute-force candidates as (key, printable, common, preview).
        self.last_bruteforce = []

        # Text artist added by on_correlate (kept so it can be replaced).
        self._corr_annotation = None
        self._running = False
        self._bf_running = False
        # Callables queued by worker threads for execution on the Tk thread.
        self._ui_queue = queue.Queue()

        self._build_controls()
        self._build_plot()
        self._build_bruteforce_panel()

        self.after(self._UI_POLL_MS, self._drain_ui_queue)

    # ---------------- thread -> UI marshalling ----------------
    def _post(self, func, *args, **kwargs):
        """Queue ``func(*args, **kwargs)`` to run on the Tk main thread."""
        self._ui_queue.put((func, args, kwargs))

    def _drain_ui_queue(self):
        """Run all queued UI callbacks, then reschedule itself."""
        try:
            while True:
                try:
                    func, args, kwargs = self._ui_queue.get_nowait()
                except queue.Empty:
                    break
                try:
                    func(*args, **kwargs)
                except Exception as exc:
                    # One failing callback must not stop the polling loop.
                    self.status.set(f"Error: {exc}")
        finally:
            try:
                self.after(self._UI_POLL_MS, self._drain_ui_queue)
            except tk.TclError:
                # The window is being destroyed; stop polling.
                pass

    def _report_error(self, exc):
        """Show ``exc`` in the status bar and in a modal error dialog."""
        self.status.set(f"Error: {exc}")
        messagebox.showerror("Error", str(exc))

    # ---------------- widget construction ----------------
    def _build_controls(self):
        """Create the parameter row, the IBM settings frame and the buttons."""
        frm = ttk.Frame(self, padding=8)
        frm.pack(side=tk.TOP, fill=tk.X)

        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=4)
        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=4)
        ttk.Label(frm, text="seed:").grid(row=0, column=4, sticky="w")
        ttk.Entry(frm, textvariable=self.var_seed, width=8).grid(row=0, column=5, padx=4)
        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=6, padx=6)
        ttk.Checkbutton(frm, text="show", variable=self.var_show).grid(row=0, column=7, padx=6)
        ttk.Checkbutton(frm, text="save correlated streams", variable=self.var_save_corr).grid(row=0, column=8, padx=6)

        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=6)
        ttk.Entry(frm, textvariable=self.var_outdir, width=36).grid(row=1, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(frm, text="Browse…", command=self._choose_outdir).grid(row=1, column=5, sticky="w", padx=4)

        ibm = ttk.LabelFrame(self, text="IBM Quantum (optional)", padding=6)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=8, pady=6)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=44, show="•").grid(row=0, column=2, padx=4, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=20).grid(row=0, column=4, padx=4)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=4)

        btns = ttk.Frame(self, padding=6)
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run Experiment", command=self.on_run)
        self.btn_run.pack(side=tk.LEFT, padx=4)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect", command=self.on_correlate)
        self.btn_correlate.pack(side=tk.LEFT, padx=4)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview)
        self.btn_preview.pack(side=tk.LEFT, padx=4)

    def _build_plot(self):
        """Create the Matplotlib figure with two empty lines and the status bar."""
        self.fig = plt.Figure(figsize=(8.5,5))
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,) = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend(); self.ax.grid(True, linestyle='--', alpha=0.4)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        self.status = tk.StringVar(value="Ready")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=6).pack(side=tk.BOTTOM, fill=tk.X)

    def _build_bruteforce_panel(self):
        """Create the brute-force inputs, result table and report buttons."""
        panel = ttk.LabelFrame(self, text="Brute-Force (analyze binary streams)", padding=6)
        panel.pack(side=tk.BOTTOM, fill=tk.X, padx=8, pady=6)

        ttk.Label(panel, text="Select binary file:").grid(row=0, column=0, sticky="w")
        self.var_target_file = tk.StringVar(value="")
        ttk.Entry(panel, textvariable=self.var_target_file, width=56).grid(row=0, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(panel, text="Browse", command=self._choose_target_file).grid(row=0, column=5, padx=4)

        ttk.Label(panel, text="keylen:").grid(row=1, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_keylen, width=6).grid(row=1, column=1, padx=4)
        ttk.Label(panel, text="out_chars:").grid(row=1, column=2, sticky="w")
        ttk.Entry(panel, textvariable=self.var_outchars, width=6).grid(row=1, column=3, padx=4)
        ttk.Label(panel, text="topN:").grid(row=1, column=4, sticky="w")
        ttk.Entry(panel, textvariable=self.var_topn, width=6).grid(row=1, column=5, padx=4)

        ttk.Checkbutton(panel, text="parallel", variable=self.var_parallel).grid(row=2, column=0, padx=4, pady=4)
        ttk.Label(panel, text="workers:").grid(row=2, column=1, sticky="e")
        ttk.Entry(panel, textvariable=self.var_workers, width=6).grid(row=2, column=2, padx=4)
        ttk.Button(panel, text="Run Brute-Force", command=self.on_run_bruteforce).grid(row=2, column=3, padx=6)

        cols = ("key","printable","common_words","preview")
        self.tree = ttk.Treeview(panel, columns=cols, show="headings", height=8)
        for c in cols:
            self.tree.heading(c, text=c)
            self.tree.column(c, width=180 if c!="preview" else 360, anchor="w")
        self.tree.grid(row=3, column=0, columnspan=6, sticky="nsew", pady=6)
        panel.grid_rowconfigure(3, weight=1)

        ttk.Button(panel, text="Save Report", command=self._save_report).grid(row=4, column=0, padx=4, pady=6)
        ttk.Button(panel, text="Open Outdir", command=self._open_outdir).grid(row=4, column=1, padx=4, pady=6)

    # ---------------- UI actions ----------------
    def _choose_outdir(self):
        """Let the user pick the output directory."""
        d = filedialog.askdirectory(title="Select outdir")
        if d: self.var_outdir.set(d)

    def _choose_target_file(self):
        """Let the user pick the binary-stream file for the brute-force run."""
        f = filedialog.askopenfilename(title="Select binary file", filetypes=[("Text files","*.txt"),("All files","*.*")])
        if f: self.var_target_file.set(f)

    def _open_outdir(self):
        """Open the output directory in the platform file manager.

        The path is passed as an argument list (never through a shell), and
        macOS uses ``open`` instead of ``xdg-open``. If the opener cannot be
        started, the path is shown in a dialog instead.
        """
        out = self.var_outdir.get()
        try:
            if os.name == 'nt':
                os.startfile(out)
            elif sys.platform == 'darwin':
                subprocess.run(['open', out], check=False)
            else:
                subprocess.run(['xdg-open', out], check=False)
        except OSError:
            messagebox.showinfo("Open", f"Outdir: {out}")

    def _read_experiment_config(self):
        """Snapshot all experiment settings from the Tk variables.

        Intended to be called on the Tk main thread. Raises ``tk.TclError``
        or ``ValueError`` if a numeric field does not hold a valid integer.
        """
        return {
            "shots": int(self.var_shots.get()),
            "nphi": int(self.var_nphi.get()),
            "seed": int(self.var_seed.get()),
            "save": bool(self.var_save.get()),
            "save_corr": bool(self.var_save_corr.get()),
            "outdir": self.var_outdir.get(),
            "use_ibm": bool(self.var_use_ibm.get()),
            "ibm_token": self.var_ibm_token.get(),
            "ibm_instance": self.var_ibm_instance.get(),
            "ibm_backend": self.var_ibm_backend.get(),
        }

    def on_run(self):
        """Validate the settings and start the experiment in a worker thread."""
        if self._running:
            messagebox.showinfo("Info","Already running")
            return
        if self.var_use_ibm.get() and not IBM_RUNTIME_AVAILABLE:
            messagebox.showerror("Error","qiskit-ibm-runtime not installed")
            return
        try:
            cfg = self._read_experiment_config()
            if cfg["shots"] <= 0 or cfg["nphi"] <= 1:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error","Invalid shots/nphi/seed")
            return
        self._running = True
        self.btn_run.config(state=tk.DISABLED)
        self.status.set("Running experiment...")
        threading.Thread(target=self._run_experiment_thread, args=(cfg,), daemon=True).start()

    def _build_aer_backend(self, seed=None):
        """Return an AerSimulator seeded with ``seed`` (default: the seed field)."""
        if seed is None:
            seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=int(seed))

    def _build_ibm_backend(self, token=None, instance=None, backend_name=None):
        """Connect to IBM Quantum and return the requested backend.

        Arguments left as None are read from the corresponding Tk variables.
        The service and backend are also stored on ``self.service`` and
        ``self.backend``.
        """
        if token is None:
            token = self.var_ibm_token.get()
        if instance is None:
            instance = self.var_ibm_instance.get()
        if backend_name is None:
            backend_name = self.var_ibm_backend.get()

        service_kwargs = {"channel": "ibm_quantum", "token": token.strip()}
        instance = instance.strip()
        if instance:
            service_kwargs["instance"] = instance
        service = QiskitRuntimeService(**service_kwargs)
        backend = service.backend(backend_name.strip())
        self.service = service
        self.backend = backend
        return backend

    def _run_experiment_thread(self, cfg=None):
        """Worker-thread entry point: run the experiment and report back."""
        try:
            self._run_experiment(cfg)
        except Exception as e:
            self._post(self._report_error, e)
        finally:
            self._post(self._finish_run)

    def _finish_run(self):
        """Re-enable the Run button once the worker has finished (UI thread)."""
        self._running = False
        self.btn_run.config(state=tk.NORMAL)

    def _redraw(self):
        """Rescale the axes to the current data and schedule a redraw."""
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()

    def _reset_plot(self):
        """Clear both curves and any verdict annotation from a previous run."""
        if self._corr_annotation is not None:
            self._corr_annotation.remove()
            self._corr_annotation = None
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self._redraw()

    def _update_plot(self, xs, plus, minus):
        """Show the probabilities measured so far (UI thread)."""
        self.line_plus.set_data(xs, plus)
        self.line_minus.set_data(xs, minus)
        self._redraw()

    def _finish_experiment(self, plot_path, summary):
        """Save the plot image if requested and show the final status (UI thread)."""
        if plot_path:
            try:
                self.fig.savefig(plot_path, dpi=160)
            except Exception as exc:
                self._report_error(exc)
                return
        self.status.set(summary)

    def _run_experiment(self, cfg=None):
        """Sweep phi over [0, 2*pi], run the circuit at each point, save outputs.

        Runs in the worker thread. ``cfg`` is the dict produced by
        ``_read_experiment_config``; when None it is read here. Plot and
        status updates are queued with ``_post``; file output happens in
        this thread.
        """
        import csv

        if cfg is None:
            cfg = self._read_experiment_config()
        shots = cfg["shots"]
        nphi = cfg["nphi"]
        save = cfg["save"]
        save_corr = cfg["save_corr"]
        outdir = cfg["outdir"]
        seed = cfg["seed"]
        use_ibm = cfg["use_ibm"]

        os.makedirs(outdir, exist_ok=True)
        phis = list(np.linspace(0.0, 2.0*math.pi, nphi))

        if use_ibm:
            backend = self._build_ibm_backend(cfg["ibm_token"], cfg["ibm_instance"], cfg["ibm_backend"])
        else:
            backend = self._build_aer_backend(seed)

        self._post(self._reset_plot)

        p_plus, p_minus = [], []
        self.last_phis = phis
        self.last_p_plus = p_plus
        self.last_p_minus = p_minus
        self.last_memory_by_phi.clear()

        for idx, phi in enumerate(phis):
            qc = eraser_circuit(phi, erase=True)
            if use_ibm:
                tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
                job = backend.run(tqc, shots=shots, memory=True)
                result = job.result()
                counts = result.get_counts()
                try:
                    memory = result.get_memory()
                except Exception:
                    # No per-shot memory available: rebuild an (unordered)
                    # shot list from the aggregated counts.
                    memory = []
                    for bitstr, c in counts.items():
                        memory.extend([bitstr] * int(c))
            else:
                tqc = transpile(qc, backend=backend, optimization_level=3)
                job = backend.run(tqc, shots=shots, memory=True)
                result = job.result()
                counts = result.get_counts()
                memory = result.get_memory()

            self.last_memory_by_phi[idx] = memory

            p0_plus = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus); p_minus.append(p0_minus)

            # Pass copies so the UI thread never sees lists mutated mid-draw.
            self._post(self._update_plot, phis[:idx+1], list(p_plus), list(p_minus))

            if save:
                s_stream = ''.join(m[1] for m in memory)
                save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)
            if save_corr:
                save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
                s_i0 = ''.join(m[1] for m in memory if m[0]=='0')
                s_i1 = ''.join(m[1] for m in memory if m[0]=='1')
                save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_i0)
                save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_i1)
                with open(os.path.join(outdir, f"counts_phi_{idx:02d}.csv"), "w", newline='') as f:
                    w = csv.writer(f)
                    w.writerow(['pair','count'])
                    for pair in ['00','01','10','11']:
                        w.writerow([pair, counts.get(pair, 0)])

            self._post(self.status.set, f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • {idx+1}/{len(phis)}")

        plot_path = None
        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline='') as f:
                w = csv.writer(f)
                w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
                for phi,a,b in zip(phis, p_plus, p_minus):
                    w.writerow([phi,a,b])
            plot_path = os.path.join(outdir, "eraser_plot.png")

        Vp = visibility(p_plus); Vm = visibility(p_minus)
        # The figure is rendered to disk on the UI thread, after the queued
        # plot updates, so the image contains the complete curves.
        self._post(self._finish_experiment, plot_path,
                   f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}  (IBM: {use_ibm})")

    def on_correlate(self):
        """Compute both visibilities and annotate the plot with the verdict."""
        if not self.last_p_plus:
            messagebox.showinfo("Info","Run experiment first")
            return
        Vp = visibility(self.last_p_plus); Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"
        # Replace the previous annotation instead of stacking a new one on top.
        if self._corr_annotation is not None:
            self._corr_annotation.remove()
        self._corr_annotation = self.ax.text(
            0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
            transform=self.ax.transAxes, fontsize=10,
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    def on_preview(self):
        """Open a window showing the shots recorded for the highest phi index."""
        if not self.last_memory_by_phi:
            messagebox.showinfo("Info","Run first")
            return
        idx = max(self.last_memory_by_phi.keys())
        memory = self.last_memory_by_phi[idx]
        c = Counter(memory)
        preview_pairs = '\n'.join(memory[:200])
        s_i0 = ''.join(m[1] for m in memory if m[0]=='0')[:400]
        s_i1 = ''.join(m[1] for m in memory if m[0]=='1')[:400]

        win = tk.Toplevel(self)
        win.title(f"Preview correlations (phi idx {idx})")
        win.geometry("640x720")
        ttk.Label(win, text=f"First 200 pairs (I S) for phi index {idx}:").pack(anchor="w", padx=8, pady=(8,0))
        txt = tk.Text(win, height=12); txt.pack(fill=tk.BOTH, padx=8)
        txt.insert("1.0", preview_pairs); txt.config(state=tk.DISABLED)

        ttk.Label(win, text="Counts (coincidences):").pack(anchor="w", padx=8, pady=(8,0))
        tree = ttk.Treeview(win, columns=("pair","count"), show="headings", height=4)
        tree.heading("pair", text="pair"); tree.heading("count", text="count")
        for pair in ['00','01','10','11']:
            tree.insert("", "end", values=(pair, c.get(pair,0)))
        tree.pack(fill=tk.X, padx=8, pady=4)

        ttk.Label(win, text="S | I=0 (first 400 bits):").pack(anchor="w", padx=8, pady=(8,0))
        t0 = tk.Text(win, height=4); t0.pack(fill=tk.X, padx=8); t0.insert("1.0", s_i0); t0.config(state=tk.DISABLED)
        ttk.Label(win, text="S | I=1 (first 400 bits):").pack(anchor="w", padx=8, pady=(8,0))
        t1 = tk.Text(win, height=4); t1.pack(fill=tk.X, padx=8); t1.insert("1.0", s_i1); t1.config(state=tk.DISABLED)

    # ---------------- Brute-force ----------------
    def on_run_bruteforce(self):
        """Validate the brute-force settings and start the search thread."""
        if self._bf_running:
            messagebox.showinfo("Info","Already running")
            return
        target = self.var_target_file.get().strip()
        if not target or not os.path.isfile(target):
            messagebox.showerror("Error","Select a valid binary file")
            return
        try:
            keylen = int(self.var_keylen.get())
            outchars = int(self.var_outchars.get())
            topn = int(self.var_topn.get())
            parallel = bool(self.var_parallel.get())
            workers = int(self.var_workers.get())
            if keylen <= 0 or outchars <= 0 or topn <= 0:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error","Invalid brute-force parameters")
            return
        # Read on the UI thread; the worker must not touch Tk variables.
        outdir = self.var_outdir.get()

        self.tree.delete(*self.tree.get_children())
        self.last_bruteforce = []
        self._bf_running = True
        self.status.set("Running brute-force...")
        threading.Thread(target=self._run_bruteforce_thread,
                         args=(target,keylen,outchars,topn,parallel,workers,outdir),
                         daemon=True).start()

    def _show_bruteforce_results(self, top):
        """Remember the candidates and list them in the table (UI thread)."""
        self.last_bruteforce = list(top)
        for (k, pr, cs, preview) in self.last_bruteforce:
            self.tree.insert("", "end", values=(k, f"{pr:.3f}", f"{cs:.1f}", preview))

    def _finish_bruteforce(self):
        """Allow a new brute-force run (UI thread)."""
        self._bf_running = False

    def _run_bruteforce_thread(self, target, keylen, outchars, topn, parallel, workers, outdir=None):
        """Worker-thread entry point: search keys and write the report file.

        ``outdir`` is where ``bruteforce_report.txt`` is written; when None
        it falls back to the outdir field.
        """
        try:
            if outdir is None:
                outdir = self.var_outdir.get()
            stego_bits = read_binary_file(target)
            top = run_bruteforce(stego_bits, keylen, outchars, topn, parallel, workers)
            self._post(self._show_bruteforce_results, top)
            report_path = os.path.join(outdir, "bruteforce_report.txt")
            lines = [f"Brute-force report: target={target}\n"]
            lines.append(f"keylen={keylen} outchars={outchars} parallel={parallel} workers={workers}\n")
            lines.append("Top results:\n")
            for k, pr, cs, preview in top:
                lines.append(f"{k}  printable={pr:.3f} common_words={cs:.1f} preview={preview}\n")
            save_text(report_path, ''.join(lines))
            self._post(self.status.set, f"Brute-force finished. Report: {report_path}")
        except Exception as e:
            self._post(self._report_error, e)
        finally:
            self._post(self._finish_bruteforce)

    def _save_report(self):
        """Save the current candidates to a user-chosen text file.

        Rows come from ``self.last_bruteforce`` rather than from the
        Treeview: values read back from the widget are converted by Tcl, so
        a key such as ``000101`` would be written as ``101``.
        """
        out = filedialog.asksaveasfilename(title="Save report as", defaultextension=".txt", filetypes=[("Text files","*.txt")])
        if not out: return
        parts = ["Brute-force candidates:\n"]
        for k, pr, cs, preview in self.last_bruteforce:
            parts.append(f"{k} printable={pr:.3f} common={cs:.1f} preview={preview}\n")
        save_text(out, ''.join(parts))
        messagebox.showinfo("Saved", f"Report saved to {out}")

# --------------------------
# Main
# --------------------------
def main():
    """Parse the command line, then create the window and run its main loop."""
    cli = argparse.ArgumentParser(description="Quantum Eraser GUI + Brute-Force")
    cli.add_argument("--no-gui", action="store_true", help="Run non-interactive (not used)")
    # The flag is accepted for compatibility only; its value is never read.
    cli.parse_args()

    EraserBruteGUI().mainloop()


if __name__ == "__main__":
    main()
