#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser GUI + Correlated Bits + Brute-Force (parallel)
Author: ChatGPT (adapted for the user)
Features:
  - Run delayed-choice quantum eraser on Aer or IBM backends
  - Save binary streams, correlated streams (pairs, S|I=0, S|I=1)
  - Live plotting of P(s=0|i=±)
  - Correlation / visibility detection
  - Brute-force module: try every binary key of length keylen to extract
    hidden ASCII text, ranking candidates by printable-byte ratio plus a
    common-word heuristic
  - Optional multiprocessing parallel brute-force
Dependencies:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib numpy
  pip install "qiskit-ibm-runtime>=0.25.0"    # only if you want IBM backends
Usage:
  python quantum_eraser_gui_ibm_corr_bruteforce_parallel.py
"""
import argparse
import csv
import hashlib
import heapq
import itertools
import math
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from collections import Counter, deque
from typing import List, Tuple, Dict, Optional

import tkinter as tk
from tkinter import ttk, messagebox, filedialog

import numpy as np
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

# Qiskit imports
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
# Optional IBM Quantum support. on_run() checks IBM_RUNTIME_AVAILABLE before
# allowing the "Use IBM Quantum" path, so the app still works with only
# qiskit-aer installed.
IBM_RUNTIME_AVAILABLE = False
try:
    from qiskit_ibm_runtime import QiskitRuntimeService
    IBM_RUNTIME_AVAILABLE = True
except Exception:
    # Broad catch on purpose: any import-time failure (missing package,
    # broken install) simply disables the IBM path instead of crashing.
    IBM_RUNTIME_AVAILABLE = False

# --------------------------
# Quantum circuit helpers
# --------------------------
def eraser_circuit(phi: float, erase: bool = True) -> QuantumCircuit:
    """Build the 2-qubit delayed-choice eraser circuit for phase ``phi``.

    Qubit 0 is measured into clbit 0 and qubit 1 into clbit 1. With Qiskit's
    bitstring ordering, qubit 0 is therefore the ``s`` bit and qubit 1 the
    ``i`` bit read by ``conditional_prob_from_counts``. When ``erase`` is
    True, an extra Hadamard on qubit 1 erases the which-path marking.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)
    circuit.h(signal)
    circuit.rz(phi, signal)
    # Entangle the path qubit with the marker qubit.
    circuit.cx(signal, idler)
    circuit.h(signal)
    if erase:
        circuit.h(idler)
    for qubit in (signal, idler):
        circuit.measure(qubit, qubit)
    return circuit

def conditional_prob_from_counts(counts: Dict[str,int], i_bit: str = '0') -> float:
    """Return P(s = '0' | i = i_bit) estimated from a Qiskit counts dict.

    Each key is read as ``key[0]`` = i bit and ``key[1]`` = s bit. Keys shorter
    than two characters are ignored. Returns 0.0 when no shot has ``i == i_bit``.
    """
    matching = [(key[1], n) for key, n in counts.items()
                if len(key) >= 2 and key[0] == i_bit]
    total = sum(n for _, n in matching)
    zeros = sum(n for s_bit, n in matching if s_bit == '0')
    return zeros / total if total else 0.0

def visibility(vals: List[float]) -> float:
    """Fringe visibility (max - min) / (max + min) of a probability curve.

    Returns 0.0 for an empty sequence or when max + min is zero.
    """
    if not vals:
        return 0.0
    highest = max(vals)
    lowest = min(vals)
    spread_sum = highest + lowest
    if not spread_sum:
        return 0.0
    return (highest - lowest) / spread_sum

# --------------------------
# I/O helpers
# --------------------------
def save_text(path: str, text: str):
    """Write ``text`` to ``path`` as UTF-8, creating parent directories as needed.

    A bare filename (no directory component) is written relative to the
    current working directory. ``os.makedirs('')`` raises FileNotFoundError,
    so directory creation is skipped in that case.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)

def read_binary_file(path: str) -> str:
    """Read a UTF-8 text file and return only its '0'/'1' characters, in order.

    Everything else (whitespace, separators, stray characters) is discarded.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw = handle.read()
    allowed = {'0', '1'}
    return ''.join(filter(allowed.__contains__, raw))

# --------------------------
# Brute-force utilities
# --------------------------
# Heuristic common words (english + a few ro words). Adjust as needed.
COMMON_WORDS = [
    # English
    "the", "and", "that", "have", "for", "not", "with", "you", "this", "but",
    "was", "are", "from", "they", "his", "she", "which", "will", "one", "all",
    "would", "there", "their",
    # Romanian
    "în", "și", "pe", "la", "care", "să", "este", "nu",
]

def sha256_stream(key_bin: str, length_bits: int) -> str:
    """
    Expand a binary key string into a deterministic '0'/'1' keystream.

    The key is packed big-endian into ceil(len(key_bin)/8) bytes. Successive
    SHA-256 digests of ``key || counter`` (counter as a 4-byte big-endian int,
    starting at 0) are concatenated MSB-first. The result is truncated to
    ``length_bits`` characters.
    """
    seed = int(key_bin, 2).to_bytes((len(key_bin) + 7) // 8, 'big')
    pieces = []
    produced = 0
    block = 0
    while produced < length_bits:
        digest = hashlib.sha256(seed + block.to_bytes(4, 'big')).digest()
        pieces.append(''.join(format(byte, '08b') for byte in digest))
        produced += 8 * len(digest)
        block += 1
    return ''.join(pieces)[:length_bits]

def deterministic_sample_positions(total_bits: int, m: int) -> List[int]:
    """
    Pick ``m`` evenly spread, 0-based positions inside ``[0, total_bits - 1]``.

    Position i is ``round(i * total_bits / m)``. A collision is bumped forward
    (wrapping) to the next unused slot. Returns ``[]`` when ``m <= 0`` and
    every position ``0..total_bits-1`` when ``m >= total_bits``.
    """
    if m <= 0:
        return []
    if m >= total_bits:
        return list(range(total_bits))
    stride = total_bits / m
    taken = set()
    chosen = []
    for i in range(m):
        candidate = int(round(i * stride)) % total_bits
        start = candidate
        while candidate in taken:
            candidate = (candidate + 1) % total_bits
            if candidate == start:
                break
        taken.add(candidate)
        chosen.append(candidate)
    return chosen

def extract_bits_with_key(stego_bits: str, key_bin: str, out_len_bits: int, sample_positions: Optional[List[int]] = None) -> str:
    """
    XOR selected bits of ``stego_bits`` with a keystream derived from ``key_bin``.

    Without ``sample_positions``, the first ``out_len_bits`` bits are used.
    Otherwise the given positions are used, in order, and ``out_len_bits`` is
    ignored. Raises ValueError when the requested bits do not fit inside
    ``stego_bits``. An empty ``sample_positions`` also raises ValueError, from
    ``max()``. Returns the extracted bits as a '0'/'1' string.
    """
    available = len(stego_bits)
    if sample_positions is not None:
        positions = sample_positions
        if max(positions) >= available:
            raise ValueError("Sample positions exceed length of stego bits")
    else:
        if out_len_bits > available:
            raise ValueError("Requested more bits than available in stego")
        positions = list(range(out_len_bits))
    pad = sha256_stream(key_bin, len(positions))
    # Bitwise XOR on characters: equal bits -> '0', different bits -> '1'.
    return ''.join('0' if stego_bits[pos] == kbit else '1'
                   for kbit, pos in zip(pad, positions))

def bits_to_bytes(bits: str) -> bytes:
    """Pack a '0'/'1' string into bytes, MSB first.

    If the length is not a multiple of 8, the last byte is zero-padded.
    """
    padded = bits.ljust(len(bits) + (-len(bits)) % 8, '0')
    return bytes(int(padded[i:i + 8], 2) for i in range(0, len(padded), 8))

def score_message_bytes(b: bytes, words: Optional[List[str]] = None) -> Tuple[float, float]:
    """
    Score a candidate plaintext as ``(printable_ratio, common_word_score)``.

    printable_ratio: fraction of bytes that are printable ASCII (32..126) or
        tab/LF/CR. Other control codes (e.g. 0x0B, 0x0C, 0x0E-0x1F) do not
        count, so random bytes are not mistaken for text.
    common_word_score: total number of (substring) occurrences of each word in
        ``words`` inside the lower-cased UTF-8 decoding of ``b``. Invalid
        sequences are dropped. This is an approximate heuristic: "the" also
        matches inside "there".

    Args:
        b: candidate message bytes.
        words: words to look for; defaults to COMMON_WORDS.

    Returns (0.0, 0.0) for empty input.
    """
    if not b:
        return 0.0, 0.0
    printable = sum(1 for x in b if 32 <= x <= 126 or x in (9, 10, 13))
    printable_ratio = printable / len(b)
    # errors='ignore' never raises, so no try/except is needed here.
    text = b.decode('utf-8', errors='ignore').lower()
    vocabulary = COMMON_WORDS if words is None else words
    common_score = sum(text.count(w) for w in vocabulary)
    return printable_ratio, float(common_score)

# --------------------------
# Brute-force worker (single key)
# --------------------------
def try_key_worker(args) -> Tuple[str, float, float, str]:
    """
    Evaluate one brute-force key. Module-level so multiprocessing can pickle it.

    args: (key_bin, stego_bits, out_chars, sample_positions)
    Returns: (key_bin, printable_ratio, common_score, decoded_preview)

    Any failure during extraction or scoring is treated as a worthless key and
    yields ``(key_bin, 0.0, 0.0, "")`` so one bad key cannot abort the search.
    """
    key_bin, stego_bits, out_chars, sample_positions = args
    wanted_bits = out_chars * 8
    try:
        bits = extract_bits_with_key(stego_bits, key_bin, wanted_bits, sample_positions)
        payload = bits_to_bytes(bits)
        printable_ratio, word_score = score_message_bytes(payload)
        preview = payload.decode('utf-8', errors='replace')[:out_chars]
    except Exception:
        return (key_bin, 0.0, 0.0, "")
    return (key_bin, printable_ratio, word_score, preview)

# --------------------------
# Parallel brute-force runner
# --------------------------
def run_bruteforce(stego_bits: str,
                   keylen: int = 6,
                   out_chars: int = 24,
                   topN: int = 20,
                   parallel: bool = False,
                   workers: Optional[int] = None,
                   word_weight: float = 0.02) -> List[Tuple[str, float, float, str]]:
    """
    Try every binary key of length ``keylen`` and return the ``topN`` best candidates.

    Each candidate is ``(key_bin, printable_ratio, common_word_score, preview)``
    as produced by ``try_key_worker``. Candidates are ranked by
    ``printable_ratio + word_weight * common_word_score``, highest first. Ties
    keep the order in which results arrive, which is key order when
    sequential and arbitrary when parallel.

    Keys are generated lazily and results are streamed into a top-N
    selection, so neither the 2**keylen key strings nor all results are held
    in memory at once.

    Args:
        stego_bits: '0'/'1' string to analyse.
        keylen: key length in bits; 2**keylen keys are tried.
        out_chars: number of characters (bytes) extracted per key.
        topN: number of candidates to return (values <= 0 give []).
        parallel: evaluate keys in a multiprocessing.Pool.
        workers: pool size; a falsy value means ``cpu_count() - 1`` (min 1).
        word_weight: weight of the common-word score in the ranking.

    Raises:
        ValueError: more than 2,000,000 keys requested without ``parallel``.
    """
    total_keys = 2 ** keylen
    if total_keys > 2000000 and not parallel:
        raise ValueError("Very many keys; enable parallel or reduce keylen")

    # Deterministic positions spread over the whole stego stream.
    out_len_bits = out_chars * 8
    sample_positions = deterministic_sample_positions(len(stego_bits), out_len_bits)

    key_fmt = f"0{keylen}b"
    args_iter = ((format(i, key_fmt), stego_bits, out_chars, sample_positions)
                 for i in range(total_keys))

    def _rank(result):
        _key, printable_ratio, common_score, _preview = result
        return printable_ratio + word_weight * common_score

    keep = max(topN, 0)
    if parallel:
        workers = workers or max(1, multiprocessing.cpu_count() - 1)
        with multiprocessing.Pool(processes=workers) as pool:
            results = pool.imap_unordered(try_key_worker, args_iter, chunksize=64)
            # nlargest consumes every result before the pool is torn down.
            return heapq.nlargest(keep, results, key=_rank)
    # heapq.nlargest(n, it, key) == sorted(it, key=key, reverse=True)[:n]
    return heapq.nlargest(keep, map(try_key_worker, args_iter), key=_rank)

# --------------------------
# GUI Application
# --------------------------
class EraserBruteGUI(tk.Tk):
    """Main window: quantum-eraser sweep, correlation tools and brute-force panel.

    Threading model: the experiment and the brute force run on daemon worker
    threads. Tk widgets and the embedded Matplotlib figure are only touched
    on the Tk main-loop thread. Workers post callables through ``_ui``, and
    ``_drain_ui_queue`` runs them from a periodic ``after`` timer.
    """

    # Poll interval (ms) for draining callbacks posted by worker threads.
    _UI_POLL_MS = 50

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – GUI + Correlated Bits + BruteForce")
        self.geometry("1180x820")

        # Callbacks posted by worker threads; drained on the main loop.
        self._ui_queue = queue.Queue()

        # Core parameters
        self.var_shots = tk.IntVar(value=4000)
        self.var_nphi  = tk.IntVar(value=13)
        self.var_seed  = tk.IntVar(value=1234)
        self.var_save  = tk.BooleanVar(value=True)
        # NOTE(review): var_show is not read anywhere in this file.
        self.var_show  = tk.BooleanVar(value=True)
        self.var_outdir = tk.StringVar(value="out")

        # correlated
        self.var_save_corr = tk.BooleanVar(value=True)

        # IBM
        self.var_use_ibm = tk.BooleanVar(value=False)
        self.var_ibm_token = tk.StringVar(value="")
        self.var_ibm_instance = tk.StringVar(value="")
        self.var_ibm_backend = tk.StringVar(value="")

        # Brute-force params
        self.var_keylen = tk.IntVar(value=6)
        self.var_outchars = tk.IntVar(value=24)
        self.var_topn = tk.IntVar(value=20)
        self.var_parallel = tk.BooleanVar(value=False)
        self.var_workers = tk.IntVar(value=max(1, multiprocessing.cpu_count()-1))

        # UI Layout
        self._build_controls()
        self._build_plot()
        self._build_bruteforce_panel()

        # Storage for last-run
        self.last_phis = []
        self.last_p_plus = []
        self.last_p_minus = []
        self.last_memory_by_phi = {}  # idx -> memory list

        self._running = False
        self._bf_running = False
        # Brute-force rows exactly as displayed (all strings). They are kept
        # here because Treeview.item(...)["values"] converts numeric-looking
        # strings to int, which would drop leading zeros from keys like "000101".
        self._bf_results = []
        # Visibility annotation added by on_correlate (None when absent).
        self._corr_annotation = None

        self.after(self._UI_POLL_MS, self._drain_ui_queue)

    def _build_controls(self):
        frm = ttk.Frame(self, padding=8)
        frm.pack(side=tk.TOP, fill=tk.X)

        # quantum options
        ttk.Label(frm, text="shots:").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.var_shots, width=8).grid(row=0, column=1, padx=4)
        ttk.Label(frm, text="nphi:").grid(row=0, column=2, sticky="w")
        ttk.Entry(frm, textvariable=self.var_nphi, width=8).grid(row=0, column=3, padx=4)
        ttk.Label(frm, text="seed:").grid(row=0, column=4, sticky="w")
        ttk.Entry(frm, textvariable=self.var_seed, width=8).grid(row=0, column=5, padx=4)
        ttk.Checkbutton(frm, text="save", variable=self.var_save).grid(row=0, column=6, padx=6)
        ttk.Checkbutton(frm, text="show", variable=self.var_show).grid(row=0, column=7, padx=6)
        ttk.Checkbutton(frm, text="save correlated streams", variable=self.var_save_corr).grid(row=0, column=8, padx=6)

        ttk.Label(frm, text="outdir:").grid(row=1, column=0, sticky="w", pady=6)
        ttk.Entry(frm, textvariable=self.var_outdir, width=36).grid(row=1, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(frm, text="Browse…", command=self._choose_outdir).grid(row=1, column=5, sticky="w", padx=4)

        # IBM frame
        ibm = ttk.LabelFrame(self, text="IBM Quantum (optional)", padding=6)
        ibm.pack(side=tk.TOP, fill=tk.X, padx=8, pady=6)
        ttk.Checkbutton(ibm, text="Use IBM Quantum", variable=self.var_use_ibm).grid(row=0, column=0, sticky="w")
        ttk.Label(ibm, text="API Token:").grid(row=0, column=1, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_token, width=44, show="•").grid(row=0, column=2, padx=4, sticky="we")
        ttk.Label(ibm, text="Instance:").grid(row=0, column=3, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_instance, width=20).grid(row=0, column=4, padx=4)
        ttk.Label(ibm, text="Backend:").grid(row=0, column=5, sticky="e")
        ttk.Entry(ibm, textvariable=self.var_ibm_backend, width=18).grid(row=0, column=6, padx=4)

        # Action buttons
        btns = ttk.Frame(self, padding=6)
        btns.pack(side=tk.TOP, fill=tk.X)
        self.btn_run = ttk.Button(btns, text="Run Experiment", command=self.on_run)
        self.btn_run.pack(side=tk.LEFT, padx=4)
        self.btn_correlate = ttk.Button(btns, text="Correlate / Detect", command=self.on_correlate)
        self.btn_correlate.pack(side=tk.LEFT, padx=4)
        self.btn_preview = ttk.Button(btns, text="Show Last Phi Correlations", command=self.on_preview)
        self.btn_preview.pack(side=tk.LEFT, padx=4)

    def _build_plot(self):
        self.fig = plt.Figure(figsize=(8.5,5))
        self.ax = self.fig.add_subplot(111)
        self.ax.set_title("Quantum Eraser – P(s=0|i=±)")
        self.ax.set_xlabel("φ (radiani)")
        self.ax.set_ylabel("Probabilitate")
        (self.line_plus,) = self.ax.plot([], [], marker='o', label="P(s=0|i=+)")
        (self.line_minus,) = self.ax.plot([], [], marker='s', label="P(s=0|i=-)")
        self.ax.legend(); self.ax.grid(True, linestyle='--', alpha=0.4)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.get_tk_widget().pack(side=tk.TOP, fill=tk.BOTH, expand=True)

        # status bar
        self.status = tk.StringVar(value="Ready")
        ttk.Label(self, textvariable=self.status, anchor="w", padding=6).pack(side=tk.BOTTOM, fill=tk.X)

    def _build_bruteforce_panel(self):
        panel = ttk.LabelFrame(self, text="Brute-Force (analyze binary streams)", padding=6)
        panel.pack(side=tk.BOTTOM, fill=tk.X, padx=8, pady=6)

        ttk.Label(panel, text="Select binary file:").grid(row=0, column=0, sticky="w")
        self.var_target_file = tk.StringVar(value="")
        ttk.Entry(panel, textvariable=self.var_target_file, width=56).grid(row=0, column=1, columnspan=4, sticky="we", padx=4)
        ttk.Button(panel, text="Browse", command=self._choose_target_file).grid(row=0, column=5, padx=4)

        ttk.Label(panel, text="keylen:").grid(row=1, column=0, sticky="w", pady=4)
        ttk.Entry(panel, textvariable=self.var_keylen, width=6).grid(row=1, column=1, padx=4)
        ttk.Label(panel, text="out_chars:").grid(row=1, column=2, sticky="w")
        ttk.Entry(panel, textvariable=self.var_outchars, width=6).grid(row=1, column=3, padx=4)
        ttk.Label(panel, text="topN:").grid(row=1, column=4, sticky="w")
        ttk.Entry(panel, textvariable=self.var_topn, width=6).grid(row=1, column=5, padx=4)

        ttk.Checkbutton(panel, text="parallel", variable=self.var_parallel).grid(row=2, column=0, padx=4, pady=4)
        ttk.Label(panel, text="workers:").grid(row=2, column=1, sticky="e")
        ttk.Entry(panel, textvariable=self.var_workers, width=6).grid(row=2, column=2, padx=4)
        ttk.Button(panel, text="Run Brute-Force", command=self.on_run_bruteforce).grid(row=2, column=3, padx=6)

        # results tree
        cols = ("key","printable","common_words","preview")
        self.tree = ttk.Treeview(panel, columns=cols, show="headings", height=8)
        for c in cols:
            self.tree.heading(c, text=c)
            self.tree.column(c, width=180 if c!="preview" else 360, anchor="w")
        self.tree.grid(row=3, column=0, columnspan=6, sticky="nsew", pady=6)
        panel.grid_rowconfigure(3, weight=1)

        # export buttons
        ttk.Button(panel, text="Save Report", command=self._save_report).grid(row=4, column=0, padx=4, pady=6)
        ttk.Button(panel, text="Open Outdir", command=self._open_outdir).grid(row=4, column=1, padx=4, pady=6)

    # ---------------- cross-thread UI dispatch ----------------
    def _ui(self, func, *args):
        """Schedule ``func(*args)`` on the Tk main loop; callable from any thread."""
        self._ui_queue.put((func, args))

    def _drain_ui_queue(self):
        """Run every pending UI callback, then re-arm the poll timer."""
        try:
            while True:
                try:
                    func, args = self._ui_queue.get_nowait()
                except queue.Empty:
                    break
                func(*args)
        finally:
            # Re-arm even if a callback raised, so later updates still arrive.
            self.after(self._UI_POLL_MS, self._drain_ui_queue)

    def _report_error(self, message):
        """Show an error in the status bar and a dialog (main thread only)."""
        self.status.set(f"Error: {message}")
        messagebox.showerror("Error", message)

    # ---------------- UI actions ----------------
    def _choose_outdir(self):
        d = filedialog.askdirectory(title="Select outdir")
        if d: self.var_outdir.set(d)

    def _choose_target_file(self):
        f = filedialog.askopenfilename(title="Select binary file", filetypes=[("Text files","*.txt"),("All files","*.*")])
        if f: self.var_target_file.set(f)

    def _open_outdir(self):
        """Open the output directory in the platform file browser."""
        out = self.var_outdir.get()
        try:
            if os.name == 'nt':
                os.startfile(out)
            else:
                # Argument list, no shell: odd characters in the path are safe.
                # macOS provides `open`; most other desktops provide `xdg-open`.
                opener = "open" if sys.platform == "darwin" else "xdg-open"
                subprocess.Popen([opener, out])
        except OSError:
            messagebox.showinfo("Open", f"Outdir: {out}")

    def _collect_run_params(self) -> dict:
        """Snapshot the experiment settings from the Tk variables (main thread).

        An empty outdir is normalised to "." so path joins and makedirs work.
        """
        return {
            "shots": int(self.var_shots.get()),
            "nphi": int(self.var_nphi.get()),
            "seed": int(self.var_seed.get()),
            "save": bool(self.var_save.get()),
            "save_corr": bool(self.var_save_corr.get()),
            "outdir": self.var_outdir.get() or ".",
            "use_ibm": bool(self.var_use_ibm.get()),
            "ibm_token": self.var_ibm_token.get().strip(),
            "ibm_instance": self.var_ibm_instance.get().strip(),
            "ibm_backend": self.var_ibm_backend.get().strip(),
        }

    def on_run(self):
        """Validate the inputs and start the phi sweep on a worker thread."""
        if self._running:
            messagebox.showinfo("Info","Already running")
            return
        if self.var_use_ibm.get() and not IBM_RUNTIME_AVAILABLE:
            messagebox.showerror("Error","qiskit-ibm-runtime not installed")
            return
        try:
            params = self._collect_run_params()
            if params["shots"] <= 0 or params["nphi"] <= 1:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error","Invalid shots/nphi/seed")
            return
        if params["use_ibm"] and not params["ibm_backend"]:
            messagebox.showerror("Error", "Enter an IBM backend name")
            return
        self._running = True
        self.btn_run.config(state=tk.DISABLED)
        self.status.set("Running experiment...")
        threading.Thread(target=self._run_experiment_thread, args=(params,), daemon=True).start()

    def _build_aer_backend(self, seed=None):
        """Local Aer simulator seeded with ``seed`` (defaults to the seed field)."""
        if seed is None:
            seed = int(self.var_seed.get())
        return AerSimulator(seed_simulator=seed)

    def _build_ibm_backend(self, params=None):
        """Connect to IBM Quantum and return the named backend.

        The token and instance are passed only when provided, so an account
        saved with QiskitRuntimeService.save_account() can be used instead.
        """
        p = params if params is not None else self._collect_run_params()
        # NOTE(review): IBM has been migrating accounts away from the
        # "ibm_quantum" channel; confirm the channel matches your account.
        service_kwargs = {"channel": "ibm_quantum"}
        if p["ibm_token"]:
            service_kwargs["token"] = p["ibm_token"]
        if p["ibm_instance"]:
            service_kwargs["instance"] = p["ibm_instance"]
        service = QiskitRuntimeService(**service_kwargs)
        backend = service.backend(p["ibm_backend"])
        self.service = service
        self.backend = backend
        return backend

    def _run_experiment_thread(self, params=None):
        """Worker entry point: run the sweep, report errors, then re-enable Run."""
        try:
            self._run_experiment(params)
        except Exception as e:
            self._ui(self._report_error, str(e))
        finally:
            self._ui(self._on_experiment_done)

    def _on_experiment_done(self):
        """Main-thread epilogue of an experiment run."""
        self._running = False
        self.btn_run.config(state=tk.NORMAL)

    def _run_circuit(self, qc, backend, shots, seed, use_ibm):
        """Execute one circuit and return ``(counts, memory)``.

        ``memory`` is the per-shot list of bitstrings, in the same format as
        the ``counts`` keys.
        """
        tqc = transpile(qc, backend=backend, optimization_level=3, seed_transpiler=seed)
        if use_ibm:
            # IBM hardware is driven through the Sampler V2 primitive; IBM
            # Quantum retired the legacy backend.run() execution path.
            from qiskit_ibm_runtime import SamplerV2
            pub_result = SamplerV2(mode=backend).run([tqc], shots=shots).result()[0]
            bits = getattr(pub_result.data, qc.cregs[0].name)
            return dict(bits.get_counts()), list(bits.get_bitstrings())
        result = backend.run(tqc, shots=shots, memory=True).result()
        return result.get_counts(), result.get_memory()

    def _clear_corr_annotation(self):
        """Remove the visibility annotation, if any (main thread)."""
        if self._corr_annotation is not None:
            self._corr_annotation.remove()
            self._corr_annotation = None

    def _reset_plot(self):
        """Empty both curves and drop any annotation (main thread)."""
        self.line_plus.set_data([], [])
        self.line_minus.set_data([], [])
        self._clear_corr_annotation()
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()

    def _update_plot(self, xs, ys_plus, ys_minus):
        """Replace the curve data and request a redraw (main thread)."""
        self.line_plus.set_data(xs, ys_plus)
        self.line_minus.set_data(xs, ys_minus)
        self.ax.relim(); self.ax.autoscale_view(); self.canvas.draw_idle()

    def _save_figure(self, path):
        """Render the figure to ``path`` (main thread); errors are reported."""
        try:
            self.fig.savefig(path, dpi=160)
        except Exception as e:
            self._report_error(str(e))

    def _save_correlated(self, outdir, idx, memory, counts):
        """Write the pairs, S|I=0 and S|I=1 streams plus a counts CSV for one phi."""
        save_text(os.path.join(outdir, f"pairs_phi_{idx:02d}.txt"), '\n'.join(memory))
        s_i0 = ''.join(m[1] for m in memory if m[0]=='0')
        s_i1 = ''.join(m[1] for m in memory if m[0]=='1')
        save_text(os.path.join(outdir, f"S_given_I0_phi_{idx:02d}.txt"), s_i0)
        save_text(os.path.join(outdir, f"S_given_I1_phi_{idx:02d}.txt"), s_i1)
        with open(os.path.join(outdir, f"counts_phi_{idx:02d}.csv"), "w", newline='', encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(['pair','count'])
            for pair in ['00','01','10','11']:
                w.writerow([pair, counts.get(pair, 0)])

    def _run_experiment(self, params=None):
        """Sweep phi over [0, 2π], run the eraser circuit at each point and record results.

        Runs on a worker thread. Every Tk/Matplotlib update goes through
        ``self._ui``. File outputs (streams, CSVs) are written from this thread.
        """
        p = params if params is not None else self._collect_run_params()
        shots = p["shots"]
        nphi = p["nphi"]
        seed = p["seed"]
        save = p["save"]
        save_corr = p["save_corr"]
        outdir = p["outdir"]
        use_ibm = p["use_ibm"]

        os.makedirs(outdir, exist_ok=True)
        phis = list(np.linspace(0.0, 2.0*math.pi, nphi))

        backend = self._build_ibm_backend(p) if use_ibm else self._build_aer_backend(seed)

        self._ui(self._reset_plot)

        p_plus, p_minus = [], []
        self.last_phis = phis
        self.last_p_plus = p_plus
        self.last_p_minus = p_minus
        self.last_memory_by_phi = {}

        for idx, phi in enumerate(phis):
            qc = eraser_circuit(phi, erase=True)
            counts, memory = self._run_circuit(qc, backend, shots, seed, use_ibm)
            self.last_memory_by_phi[idx] = memory

            p0_plus = conditional_prob_from_counts(counts, '0')
            p0_minus = conditional_prob_from_counts(counts, '1')
            p_plus.append(p0_plus); p_minus.append(p0_minus)

            # Copies: the main thread must never see these lists mid-append.
            self._ui(self._update_plot, phis[:idx+1], list(p_plus), list(p_minus))

            if save:
                s_stream = ''.join(m[1] for m in memory)
                save_text(os.path.join(outdir, f"binary_phi_{idx:02d}.txt"), s_stream)
            if save_corr:
                self._save_correlated(outdir, idx, memory, counts)

            self._ui(self.status.set, f"φ={phi:.3f} • P(+)= {p0_plus:.3f} • P(-)= {p0_minus:.3f} • {idx+1}/{len(phis)}")

        # save main CSV and PNG
        if save:
            with open(os.path.join(outdir, "eraser_probs.csv"), "w", newline='') as f:
                w = csv.writer(f)
                w.writerow(['phi','P(s=0|i=+)','P(s=0|i=-)'])
                for phi,a,b in zip(phis, p_plus, p_minus):
                    w.writerow([phi,a,b])
            # Queued after the last plot update, so the PNG shows the full sweep.
            self._ui(self._save_figure, os.path.join(outdir, "eraser_plot.png"))

        Vp = visibility(p_plus); Vm = visibility(p_minus)
        self._ui(self.status.set, f"Done. V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f}  (IBM: {use_ibm})")

    def on_correlate(self):
        """Annotate the plot with the visibilities of both conditional curves."""
        if not self.last_p_plus:
            messagebox.showinfo("Info","Run experiment first")
            return
        Vp = visibility(self.last_p_plus); Vm = visibility(self.last_p_minus)
        verdict = "Interferență detectată" if (Vp > 0.25 and Vm > 0.25) else "Interferență slabă/absentă"
        # Replace, rather than stack, an earlier annotation.
        self._clear_corr_annotation()
        self._corr_annotation = self.ax.text(0.02, 0.96, f"V(+)={Vp:.2f}, V(-)={Vm:.2f}\n{verdict}",
                                             transform=self.ax.transAxes, fontsize=10,
                                             bbox=dict(boxstyle="round", facecolor="white", alpha=0.8))
        self.canvas.draw_idle()
        self.status.set(f"Correlation: V(+)≈{Vp:.3f}, V(-)≈{Vm:.3f} → {verdict}")

    def on_preview(self):
        """Show pairs, coincidence counts and conditional streams for the last phi."""
        # Snapshot: the experiment thread may still be adding entries.
        snapshot = dict(self.last_memory_by_phi)
        if not snapshot:
            messagebox.showinfo("Info","Run first")
            return
        idx = max(snapshot)
        memory = snapshot[idx]
        c = Counter(memory)
        preview_pairs = '\n'.join(memory[:200])
        s_i0 = ''.join(m[1] for m in memory if m[0]=='0')[:400]
        s_i1 = ''.join(m[1] for m in memory if m[0]=='1')[:400]

        win = tk.Toplevel(self)
        win.title(f"Preview correlations (phi idx {idx})")
        win.geometry("640x720")
        ttk.Label(win, text=f"First 200 pairs (I S) for phi index {idx}:").pack(anchor="w", padx=8, pady=(8,0))
        txt = tk.Text(win, height=12); txt.pack(fill=tk.BOTH, padx=8)
        txt.insert("1.0", preview_pairs); txt.config(state=tk.DISABLED)

        ttk.Label(win, text="Counts (coincidences):").pack(anchor="w", padx=8, pady=(8,0))
        tree = ttk.Treeview(win, columns=("pair","count"), show="headings", height=4)
        tree.heading("pair", text="pair"); tree.heading("count", text="count")
        for pair in ['00','01','10','11']:
            tree.insert("", "end", values=(pair, c.get(pair,0)))
        tree.pack(fill=tk.X, padx=8, pady=4)

        ttk.Label(win, text="S | I=0 (first 400 bits):").pack(anchor="w", padx=8, pady=(8,0))
        t0 = tk.Text(win, height=4); t0.pack(fill=tk.X, padx=8); t0.insert("1.0", s_i0); t0.config(state=tk.DISABLED)
        ttk.Label(win, text="S | I=1 (first 400 bits):").pack(anchor="w", padx=8, pady=(8,0))
        t1 = tk.Text(win, height=4); t1.pack(fill=tk.X, padx=8); t1.insert("1.0", s_i1); t1.config(state=tk.DISABLED)

    # ---------------- Brute-force ----------------
    @staticmethod
    def _display_preview(text):
        """Replace non-printable characters (newlines, control codes) with '.'."""
        return ''.join(ch if ch.isprintable() else '.' for ch in text)

    def on_run_bruteforce(self):
        """Validate the brute-force inputs and start the search on a worker thread."""
        if self._bf_running:
            messagebox.showinfo("Info","Brute-force already running")
            return
        target = self.var_target_file.get().strip()
        if not target or not os.path.isfile(target):
            messagebox.showerror("Error","Select a valid binary file")
            return
        try:
            keylen = int(self.var_keylen.get())
            outchars = int(self.var_outchars.get())
            topn = int(self.var_topn.get())
            parallel = bool(self.var_parallel.get())
            workers = int(self.var_workers.get())
            if keylen <= 0 or outchars <= 0 or topn <= 0:
                raise ValueError
        except (tk.TclError, ValueError):
            messagebox.showerror("Error","Invalid brute-force parameters")
            return
        # Read on the main thread; the worker must not touch Tk variables.
        outdir = self.var_outdir.get() or "."

        self._bf_running = True
        self._bf_results = []
        self.tree.delete(*self.tree.get_children())
        self.status.set("Running brute-force...")
        threading.Thread(target=self._run_bruteforce_thread,
                         args=(target, keylen, outchars, topn, parallel, workers, outdir),
                         daemon=True).start()

    def _run_bruteforce_thread(self, target, keylen, outchars, topn, parallel, workers, outdir=None):
        """Worker body: brute-force ``target``, publish the rows and write a report."""
        try:
            if outdir is None:  # legacy callers: fall back to the outdir field
                outdir = self.var_outdir.get() or "."
            stego_bits = read_binary_file(target)
            top = run_bruteforce(stego_bits, keylen, outchars, topn, parallel, workers)
            # Previews are sanitised so each result stays on a single line.
            rows = [(k, f"{pr:.3f}", f"{cs:.1f}", self._display_preview(preview))
                    for (k, pr, cs, preview) in top]
            self._ui(self._show_bruteforce_results, rows)

            report_path = os.path.join(outdir, "bruteforce_report.txt")
            lines = [f"Brute-force report: target={target}\n",
                     f"keylen={keylen} outchars={outchars} parallel={parallel} workers={workers}\n",
                     "Top results:\n"]
            for k, pr, cs, preview in rows:
                lines.append(f"{k}  printable={pr} common_words={cs} preview={preview}\n")
            save_text(report_path, ''.join(lines))
            self._ui(self.status.set, f"Brute-force finished. Report: {report_path}")
        except Exception as e:
            self._ui(self._report_error, str(e))
        finally:
            self._ui(self._on_bruteforce_done)

    def _show_bruteforce_results(self, rows):
        """Remember the display rows and fill the results tree (main thread)."""
        self._bf_results = list(rows)
        for row in rows:
            self.tree.insert("", "end", values=row)

    def _on_bruteforce_done(self):
        """Main-thread epilogue of a brute-force run."""
        self._bf_running = False

    def _save_report(self):
        """Save the displayed brute-force candidates to a user-chosen text file."""
        out = filedialog.asksaveasfilename(title="Save report as", defaultextension=".txt", filetypes=[("Text files","*.txt")])
        if not out: return
        # Use the stored string rows, not Treeview values, to keep leading zeros.
        txt = "Brute-force candidates:\n" + ''.join(
            f"{r[0]} printable={r[1]} common={r[2]} preview={r[3]}\n" for r in self._bf_results)
        save_text(out, txt)
        messagebox.showinfo("Saved", f"Report saved to {out}")

# --------------------------
# Main
# --------------------------
def main():
    """Parse the (currently inert) CLI flags and run the Tk application."""
    cli = argparse.ArgumentParser(description="Quantum Eraser GUI + Brute-Force")
    cli.add_argument("--no-gui", action="store_true", help="Run non-interactive (not used)")
    cli.parse_args()  # validates argv / handles --help; result unused

    EraserBruteGUI().mainloop()

# Import guard: keeps the GUI from launching when this module is imported,
# e.g. by multiprocessing worker processes under the "spawn" start method.
if __name__ == "__main__":
    main()
