#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Quantum Eraser – Detection Attack Ultimate (v3.0 - 2025)
Imbunatatiri:
✓ Multiprocessing (foloseste toate CPU cores)
✓ Algoritm Genetic Real (pentru keylen > 20)
✓ Grafica Live: Reconstructia interferentei in timp real
✓ Convergenta vizuala
"""

import tkinter as tk
from tkinter import ttk, messagebox, filedialog
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
import hashlib
import threading
import time
import random
from concurrent.futures import ProcessPoolExecutor
import multiprocessing

# Incercam import Qiskit, daca nu exista, folosim mock data (pentru demo UI)
try:
    from qiskit import QuantumCircuit, transpile
    from qiskit_aer import AerSimulator
    HAS_QISKIT = True
except ImportError:
    HAS_QISKIT = False

# ==================== CORE LOGIC (Stateless functions for Multiprocessing) ====================

def generate_keystream(key_int: int, key_len: int, length: int) -> np.ndarray:
    """Generate a fast, deterministic pseudo-keystream of *length* bits from a key.

    Demo-only stand-in for a cryptographic keystream (a real attack would use
    SHA-256/AES, see ``generate_sha256_keystream_fast``). A private NumPy legacy
    ``RandomState`` is seeded with the key and draws ``length`` bits in {0, 1}.
    Using a private generator leaves the global ``np.random`` state (used e.g.
    by ``np.random.rand`` when simulating measurements) untouched.

    Args:
        key_int: Non-negative integer key. Keys below 2**32 give exactly the
            stream produced by ``np.random.seed(key_int)`` followed by
            ``np.random.randint``. Larger keys (the UI allows up to 64 bits)
            are split into 32-bit seed words instead of raising ``ValueError``.
        key_len: Key length in bits. Unused here; kept so the signature matches
            ``generate_sha256_keystream_fast``.
        length: Number of bits to produce.

    Returns:
        1-D ``np.int8`` array of 0/1 values.

    Raises:
        ValueError: if *key_int* is negative.
    """
    if key_int < 2**32:
        seed = key_int
    else:
        # RandomState only accepts 32-bit seed words: feed big keys as a word array.
        seed = []
        remaining = key_int
        while remaining:
            seed.append(remaining & 0xFFFFFFFF)
            remaining >>= 32
    rng = np.random.RandomState(seed)
    return rng.randint(0, 2, length, dtype=np.int8)

def generate_sha256_keystream_fast(key_int: int, key_len: int, length: int) -> np.ndarray:
    """Derive *length* keystream bits from a key using SHA-256 in counter mode.

    Block ``c`` is ``SHA256(key_bytes || c)``. Here ``key_bytes`` is the key as a
    big-endian ``ceil(key_len / 8)``-byte string, and ``c`` is a 4-byte
    big-endian counter starting at 0. The digests are concatenated, expanded
    to bits MSB-first and truncated to *length*.

    The bit expansion is vectorized with ``np.unpackbits``. This function runs
    once per candidate key during the attack, so the former per-bit Python
    loop dominated the run time.

    Args:
        key_int: Non-negative key. It must fit in ``ceil(key_len / 8)`` bytes,
            otherwise ``int.to_bytes`` raises ``OverflowError``.
        key_len: Key length in bits; fixes the byte width of the key encoding.
        length: Number of bits to produce.

    Returns:
        1-D ``np.int8`` array of 0/1 values.

    Raises:
        ValueError: if *length* is negative.
    """
    if length < 0:
        raise ValueError("negative dimensions are not allowed")
    key_bytes = key_int.to_bytes((key_len + 7) // 8, 'big')
    n_blocks = -(-length // 256)  # ceil(length / 256): each digest yields 256 bits
    stream = b"".join(
        hashlib.sha256(key_bytes + counter.to_bytes(4, 'big')).digest()
        for counter in range(n_blocks)
    )
    bits = np.unpackbits(np.frombuffer(stream, dtype=np.uint8))  # MSB-first
    return bits[:length].astype(np.int8)

def evaluate_candidate(args):
    """Score one candidate key; worker function for ``ProcessPoolExecutor.map``.

    The candidate's SHA-256 keystream serves as a guess for the erased idler
    bits I. Within every phase segment, the signal bits S are split by the
    guessed I, and the click rates P(S=1 | I=0) and P(S=1 | I=1) are compared.

    - Wrong key: the guess is unrelated to the true idler, so both subsets
      have the same statistics and the rates only differ by sampling noise.
    - Right key: the I=0 rate follows the interference fringe while the I=1
      rate does not, so the rates differ.

    Args:
        args: Tuple ``(key_int, key_len, s_bits, cum_shots)``, packed into one
            argument so the function can be used with ``map``.
            ``s_bits`` is the 0/1 signal array (1 = click). ``cum_shots`` holds
            cumulative segment ends: segment k spans ``cum_shots[k]:cum_shots[k+1]``,
            and the first segment starts at index 0.

    Returns:
        ``(score, key_int)``. ``score`` is twice the mean absolute rate
        difference over all segments where the guess contains both I=0 and I=1
        values, or ``0.0`` if no such segment exists.
    """
    key_int, key_len, s_bits, cum_shots = args

    # Hypothetical idler stream for this key.
    idler_guess = generate_sha256_keystream_fast(key_int, key_len, len(s_bits))

    rate_gaps = []
    bounds = [0, *cum_shots[1:]]
    for start, end in zip(bounds[:-1], bounds[1:]):
        seg_idler = idler_guess[start:end]
        seg_signal = s_bits[start:end]

        mask_i0 = (seg_idler == 0)
        mask_i1 = (seg_idler == 1)
        count_i0 = np.sum(mask_i0)
        count_i1 = np.sum(mask_i1)
        # A segment where the guess is constant gives no contrast to measure.
        if count_i0 == 0 or count_i1 == 0:
            continue

        # s_bits uses 1 = click, so the mean over a subset is the click rate P(S=1 | ...).
        p_click_given_i0 = np.sum(seg_signal[mask_i0]) / count_i0
        p_click_given_i1 = np.sum(seg_signal[mask_i1]) / count_i1
        rate_gaps.append(abs(p_click_given_i0 - p_click_given_i1))

    if not rate_gaps:
        return 0.0, key_int

    # Visibility proxy: doubled, so a gap of 0.5 (rate 0 or 1 against 0.5) maps to 1.
    final_score = np.mean(rate_gaps) * 2
    return final_score, key_int

# ==================== GUI CLASS ====================

class QuantumEraserApp(tk.Tk):
    """Tkinter front-end: simulate the eraser experiment, then attack its key.

    ``run_experiment`` produces the signal bits S. The idler stream I is
    derived from a hidden random key and then discarded.

    The attack searches for a key whose SHA-256 keystream, used as a guess for
    I, splits S into distinguishable interference curves. Candidates are
    scored by ``evaluate_candidate`` on a process pool driven from a
    background thread. All Tk widget access happens on the main thread: the
    worker thread only schedules callbacks via ``after``.
    """

    # Candidates evaluated per generation (GA population / brute-force batch).
    POP_SIZE = 100
    # Best candidates copied unchanged into the next GA generation.
    ELITE_COUNT = 10
    # GA parents are drawn from this many top-ranked candidates.
    PARENT_POOL = 50
    # Probability that a GA child gets one random bit flipped.
    MUTATION_RATE = 0.2
    # Fraction of the ideal (correct-key) score that counts as "key found".
    SUCCESS_FRACTION = 0.8

    def __init__(self):
        super().__init__()
        self.title("Quantum Eraser – Cryptanalytic Attack v3.0 (Multi-Core)")
        self.geometry("1400x850")
        try:
            # 'zoomed' is not supported by every Tk platform/window manager;
            # when Tk rejects it, keep the fixed geometry above.
            self.state('zoomed')
        except tk.TclError:
            pass

        self.style = ttk.Style(self)
        self.style.theme_use('clam')

        # Data holders
        # Set by run_experiment: {'phis', 's_bits', 'cum_shots', 'true_key'}.
        self.experiment_data = None
        # True while an attack runs; only written from the Tk main thread.
        self.is_attacking = False
        # NOTE(review): not used anywhere in this file; kept for compatibility.
        self.attack_queue = multiprocessing.Queue()
        # Stop token of the current attack. Each attack gets a fresh Event, so
        # a manager thread from an earlier, stopped attack cannot be revived.
        self._stop_event = threading.Event()

        self._setup_ui()

    def _setup_ui(self):
        """Build the experiment bar, attack panel, log and the two plot tabs."""
        # --- Top Bar: Experiment Control ---
        top_frame = ttk.LabelFrame(self, text="1. Quantum Experiment Setup")
        top_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=5)

        ttk.Label(top_frame, text="Shots/Angle:").pack(side=tk.LEFT, padx=5)
        self.e_shots = ttk.Entry(top_frame, width=8)
        self.e_shots.insert(0, "4000")
        self.e_shots.pack(side=tk.LEFT)

        ttk.Label(top_frame, text="Angles (N):").pack(side=tk.LEFT, padx=5)
        self.e_angles = ttk.Entry(top_frame, width=5)
        self.e_angles.insert(0, "16")
        self.e_angles.pack(side=tk.LEFT)

        self.btn_run_exp = ttk.Button(top_frame, text="RUN EXPERIMENT (Generate Data)", command=self.run_experiment)
        self.btn_run_exp.pack(side=tk.LEFT, padx=20)

        lbl_status = ttk.Label(top_frame, text="Qiskit: " + ("INSTALLED" if HAS_QISKIT else "MISSING (Simulated Mode)"),
                               foreground="green" if HAS_QISKIT else "red")
        lbl_status.pack(side=tk.RIGHT, padx=10)

        # --- Main Content ---
        content = ttk.Frame(self)
        content.pack(fill=tk.BOTH, expand=True, padx=10, pady=5)

        # Left: attack configuration & stats
        left_panel = ttk.LabelFrame(content, text="2. Attack Configuration")
        left_panel.pack(side=tk.LEFT, fill=tk.Y, padx=5)

        frm_param = ttk.Frame(left_panel)
        frm_param.pack(fill=tk.X, pady=10)

        ttk.Label(frm_param, text="Key Length (bits):").grid(row=0, column=0, sticky='w')
        self.sc_key = tk.Scale(frm_param, from_=8, to=64, orient=tk.HORIZONTAL)
        self.sc_key.set(16)
        self.sc_key.grid(row=0, column=1, sticky='ew')

        ttk.Label(frm_param, text="Method:").grid(row=1, column=0, sticky='w', pady=10)
        self.var_method = tk.StringVar(value="Genetic")
        ttk.Radiobutton(frm_param, text="Brute-Force (Exhaustive)", variable=self.var_method, value="Brute").grid(row=1, column=1, sticky='w')
        ttk.Radiobutton(frm_param, text="Genetic Algorithm (Smart)", variable=self.var_method, value="Genetic").grid(row=2, column=1, sticky='w')

        ttk.Label(frm_param, text="Workers (CPU):").grid(row=3, column=0, sticky='w', pady=5)
        self.lbl_cpu = ttk.Label(frm_param, text=f"{multiprocessing.cpu_count()} threads")
        self.lbl_cpu.grid(row=3, column=1, sticky='w')

        # Start/Stop
        self.btn_attack = ttk.Button(left_panel, text="▶ START ATTACK", command=self.toggle_attack)
        self.btn_attack.pack(fill=tk.X, pady=20, padx=5)

        # Logs
        ttk.Label(left_panel, text="Attack Log:").pack(anchor='w')
        self.log_text = tk.Text(left_panel, height=15, width=40, font=("Consolas", 9))
        self.log_text.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Right: visualizations
        right_panel = ttk.Frame(content)
        right_panel.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)

        self.notebook = ttk.Notebook(right_panel)
        self.notebook.pack(fill=tk.BOTH, expand=True)

        # Tab 1: live reconstruction of the interference curves
        tab1 = ttk.Frame(self.notebook)
        self.notebook.add(tab1, text="Live Interference Reconstruction")

        self.fig_rec = Figure(figsize=(5, 4), dpi=100)
        self.ax_rec = self.fig_rec.add_subplot(111)
        self.ax_rec.set_xlabel("Phase φ (radians)")
        self.ax_rec.set_ylabel("Probability P(S|I)")
        self.ax_rec.set_title("Real-time Physics Recovery")
        self.canvas_rec = FigureCanvasTkAgg(self.fig_rec, tab1)
        self.canvas_rec.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Tab 2: convergence of the best score
        tab2 = ttk.Frame(self.notebook)
        self.notebook.add(tab2, text="Algorithm Convergence")

        self.fig_conv = Figure(figsize=(5, 4), dpi=100)
        self.ax_conv = self.fig_conv.add_subplot(111)
        self.ax_conv.set_xlabel("Generation / Iteration")
        self.ax_conv.set_ylabel("Best Fitness Score")
        self.ax_conv.grid(True)
        self.canvas_conv = FigureCanvasTkAgg(self.fig_conv, tab2)
        self.canvas_conv.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Data containers for the convergence plot
        self.history_scores = []
        self.history_iters = []

    def log(self, msg):
        """Append *msg* as a new line to the attack log and scroll to it.

        Touches Tk widgets, so it must run on the main thread; worker threads
        use ``_log_from_worker`` instead.
        """
        self.log_text.insert(tk.END, f"{msg}\n")
        self.log_text.see(tk.END)

    def _log_from_worker(self, msg):
        """Schedule ``log(msg)`` on the Tk main thread (Tk widgets are not thread-safe)."""
        self.after(0, self.log, msg)

    def run_experiment(self):
        """Simulate the quantum experiment and store the 'encrypted' signal data.

        A random hidden key of the currently selected length drives a SHA-256
        keystream that plays the role of the idler bits I. For each phase φ:

        - shots with I=0 click (S=1) with probability (1 + cos φ) / 2;
        - shots with I=1 click with probability 0.5.

        Only S, the phases and the segment boundaries are meant for the
        attack; the true key is stored for debugging only. Invalid entry
        values, or a running attack, abort with an error dialog.
        """
        if self.is_attacking:
            messagebox.showerror("Error", "Stop the running attack first!")
            return
        try:
            shots = int(self.e_shots.get())
            n_angles = int(self.e_angles.get())
        except ValueError:
            messagebox.showerror("Error", "Shots/Angle and Angles (N) must be integers!")
            return
        if shots < 1 or n_angles < 1:
            messagebox.showerror("Error", "Shots/Angle and Angles (N) must be positive!")
            return

        phis = np.linspace(0, 2*np.pi, n_angles)

        self.log(f"Generare date: {n_angles} unghiuri, {shots} shots...")

        # Hidden "true key" for this experiment.
        true_key_len = self.sc_key.get()
        true_key_int = random.getrandbits(true_key_len)
        self.log(f"DEBUG: True Key (Hidden) = {bin(true_key_int)}")

        # True idler stream I for all shots.
        total_shots = shots * n_angles
        true_idler_stream = generate_sha256_keystream_fast(true_key_int, true_key_len, total_shots)

        segments = []
        for k, phi in enumerate(phis):
            idler_segment = true_idler_stream[k * shots:(k + 1) * shots]
            # Theoretical click probability when interference is present (I=0).
            p_interference = (1 + np.cos(phi)) / 2.0
            rnd = np.random.rand(shots)
            # I is 0/1 only: I=0 -> fringe probability, I=1 -> flat 0.5.
            p_click = np.where(idler_segment == 0, p_interference, 0.5)
            segments.append((rnd < p_click).astype(np.int8))

        self.experiment_data = {
            'phis': phis,
            's_bits': np.concatenate(segments),
            # Cumulative segment ends with a leading 0: [0, shots, 2*shots, ...].
            'cum_shots': np.arange(n_angles + 1) * shots,
            'true_key': true_key_int  # debug only
        }

        # Initial plot: no information about I yet.
        self.ax_rec.clear()
        self.ax_rec.plot(phis, [0.5]*len(phis), 'r--', label="Current Guess (Noise)")
        self.ax_rec.set_ylim(0, 1)
        self.ax_rec.set_title("Waiting for attack... (Only Signal S known)")
        self.canvas_rec.draw()

        self.log("Experiment complet. Datele S sunt inregistrate. I este sters.")

    def toggle_attack(self):
        """Start an attack in a background thread, or request the running one to stop."""
        if self.is_attacking:
            self.is_attacking = False
            self._stop_event.set()
            self.btn_attack.config(text="▶ START ATTACK")
            self.log("Stopping attack...")
            return

        if self.experiment_data is None:
            messagebox.showerror("Error", "Run experiment first!")
            return

        self.is_attacking = True
        # Fresh token: a still-running manager of a previous attack keeps the
        # old (already set) event and exits instead of resuming.
        self._stop_event = threading.Event()
        self.btn_attack.config(text="⏹ STOP ATTACK")
        self.history_scores = []
        self.history_iters = []
        self.ax_conv.clear()

        method = self.var_method.get()
        key_len = self.sc_key.get()

        threading.Thread(target=self._attack_manager,
                         args=(method, key_len, self._stop_event),
                         daemon=True).start()

    def _attack_manager(self, method, key_len, stop_event=None):
        """Run the key search (background thread) until stopped, solved or exhausted.

        Args:
            method: "Brute" for a sequential sweep of the key space in batches
                of ``POP_SIZE``, anything else for the genetic algorithm.
            key_len: Key length in bits.
            stop_event: Stop token of this attack; defaults to the current one.
        """
        if stop_event is None:
            stop_event = self._stop_event
        data = self.experiment_data
        s_bits = data['s_bits']
        cum_shots = data['cum_shots']
        phis = data['phis']

        # Ideal score of the correct key under run_experiment's model: per phase
        # |P(S=1|I=0) - P(S=1|I=1)| = |cos φ| / 2, averaged and doubled by
        # evaluate_candidate -> mean(|cos φ|). Wrong keys only pick up sampling
        # noise. (The old fixed 0.95 threshold was unreachable for > 2 angles.)
        success_threshold = self.SUCCESS_FRACTION * float(np.mean(np.abs(np.cos(phis))))

        best_score = -1.0
        best_key = None
        generation = 0
        workers = multiprocessing.cpu_count()
        key_space = 1 << key_len
        next_brute_key = 0
        population = []
        if method != "Brute":
            population = [random.getrandbits(key_len) for _ in range(self.POP_SIZE)]

        self._log_from_worker(f"Starting {method} with {key_len} bits using {workers} workers...")

        try:
            with ProcessPoolExecutor(max_workers=workers) as pool:
                while not stop_event.is_set():
                    if method == "Brute":
                        if next_brute_key >= key_space:
                            self._log_from_worker("Key space exhausted.")
                            break
                        batch_end = min(next_brute_key + self.POP_SIZE, key_space)
                        current_batch = range(next_brute_key, batch_end)
                        next_brute_key = batch_end
                    else:
                        current_batch = population

                    tasks = [(k, key_len, s_bits, cum_shots) for k in current_batch]
                    # (score, key) pairs, best first.
                    results = sorted(pool.map(evaluate_candidate, tasks),
                                     key=lambda r: r[0], reverse=True)
                    batch_best_score, batch_best_key = results[0]

                    if batch_best_score > best_score:
                        best_score, best_key = batch_best_score, batch_best_key
                        self._log_from_worker(f"Gen {generation}: New Best Score = {best_score:.4f}")
                        # Pass values as arguments (no late-binding closure).
                        self.after(0, self._on_attack_progress, stop_event,
                                   best_key, best_score, generation, key_len)

                    if method == "Genetic":
                        population = self._next_generation(results, key_len)

                    generation += 1

                    if best_score > success_threshold:
                        self._log_from_worker("!!! KEY FOUND OR VERY CLOSE !!!")
                        break
        except Exception as exc:  # thread boundary: report instead of dying silently
            self._log_from_worker(f"Attack aborted: {exc!r}")
        finally:
            self.after(0, self._on_attack_finished, stop_event)
            if best_key is None:
                self._log_from_worker("Stopped. No candidate evaluated.")
            else:
                self._log_from_worker(f"Stopped. Best Key Found: {bin(best_key)}")

    def _next_generation(self, ranked, key_len):
        """Breed the next GA population from (score, key) pairs sorted best-first.

        Keeps the ``ELITE_COUNT`` best keys. The rest, up to ``POP_SIZE``, are
        children of two parents drawn from the top ``PARENT_POOL``:

        - the low ``key_len // 2`` bits come from the first parent and the
          remaining bits from the second;
        - with probability ``MUTATION_RATE`` one random bit is flipped.
        """
        new_pop = [key for _, key in ranked[:self.ELITE_COUNT]]
        parents = ranked[:self.PARENT_POOL]
        mask = (1 << (key_len // 2)) - 1
        while len(new_pop) < self.POP_SIZE:
            p1 = random.choice(parents)[1]
            p2 = random.choice(parents)[1]
            child = (p1 & mask) | (p2 & ~mask)
            if random.random() < self.MUTATION_RATE:
                child ^= 1 << random.randint(0, key_len - 1)
            new_pop.append(child)
        return new_pop

    def _on_attack_progress(self, stop_event, key_int, score, generation, key_len):
        """Main-thread callback: plot a new best key unless its attack was stopped or replaced."""
        if stop_event is self._stop_event and not stop_event.is_set():
            self._update_plots(key_int, score, generation, key_len)

    def _on_attack_finished(self, stop_event):
        """Main-thread callback run when an attack manager thread exits."""
        stop_event.set()
        # Only the current attack may reset shared state and the button; a
        # stale manager from an earlier attack must not touch a newer one.
        if stop_event is self._stop_event:
            self.is_attacking = False
            self.btn_attack.config(text="▶ START ATTACK")

    def _update_plots(self, key_int, score, generation, key_len=None):
        """Redraw the convergence and reconstruction plots for a new best key.

        Args:
            key_int: Best key found so far.
            score: Its fitness score.
            generation: Generation/iteration in which it was found.
            key_len: Key length used by the attack. Defaults to the current
                slider value, which may have moved since the attack started.
        """
        if key_len is None:
            key_len = self.sc_key.get()

        # 1. Convergence plot
        self.history_scores.append(score)
        self.history_iters.append(generation)
        self.ax_conv.clear()
        self.ax_conv.plot(self.history_iters, self.history_scores, 'g-')
        self.ax_conv.set_title(f"Convergence (Best: {score:.4f})")
        self.ax_conv.set_xlabel("Generations")
        self.ax_conv.set_ylabel("Best Fitness Score")
        self.ax_conv.grid(True)
        self.canvas_conv.draw()

        # 2. Reconstruction plot: split S by this key's idler guess, per phase.
        phis = self.experiment_data['phis']
        s_bits = self.experiment_data['s_bits']
        cum_shots = self.experiment_data['cum_shots']

        idler_guess = generate_sha256_keystream_fast(key_int, key_len, len(s_bits))

        p0_vals = []
        p1_vals = []
        bounds = [0, *cum_shots[1:]]
        for start, end in zip(bounds[:-1], bounds[1:]):
            seg_idler = idler_guess[start:end]
            seg_signal = s_bits[start:end]
            m0 = (seg_idler == 0)
            m1 = (seg_idler == 1)
            # Click rate per idler value; 0.5 (no information) if the subset is empty.
            p0_vals.append(np.mean(seg_signal[m0]) if m0.any() else 0.5)
            p1_vals.append(np.mean(seg_signal[m1]) if m1.any() else 0.5)

        self.ax_rec.clear()
        self.ax_rec.plot(phis, p0_vals, 'o-', color='blue', label='P(S|I=0) Reconstructed')
        self.ax_rec.plot(phis, p1_vals, 's-', color='orange', alpha=0.5, label='P(S|I=1) Reconstructed')

        # Theoretical I=0 curve used by run_experiment.
        theory = (1 + np.cos(phis))/2
        self.ax_rec.plot(phis, theory, 'k:', alpha=0.3, label='Theoretical Target')

        self.ax_rec.set_ylim(0, 1)
        self.ax_rec.set_title(f"Interference Reconstruction (Key: {bin(key_int)[:10]}...)")
        self.ax_rec.legend(loc='upper right', fontsize='small')
        self.canvas_rec.draw()

if __name__ == "__main__":
    # Lets frozen Windows executables spawn the ProcessPoolExecutor workers
    # used by the attack; has no effect in other environments.
    multiprocessing.freeze_support()
    QuantumEraserApp().mainloop()
