#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Delayed-Choice Quantum Eraser (Qiskit, AerSimulator)
Author: ChatGPT
Usage examples:
  python quantum_eraser_qiskit.py --shots 6000 --nphi 25 --plot --outdir results
  python quantum_eraser_qiskit.py --shots 4000 --nphi 13

Requires:
  pip install "qiskit==1.2.*" "qiskit-aer==0.15.*" matplotlib
"""
import argparse
import math
import os
import csv
from typing import Dict, List, Tuple

import numpy as np
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator

# -------------------------- Core circuit --------------------------
def eraser_counts_qiskit(phi: float, erase: bool = True, shots: int = 8000, seed: int = 1234) -> Tuple[Dict[str, int], List[str]]:
    """
    Simulate the two-qubit quantum-eraser circuit at path phase ``phi``.

    Layout: qubit 0 is the signal S (measured into clbit 0) and qubit 1 is
    the idler I (measured into clbit 1).  When ``erase`` is True the idler
    gets an extra H before measurement, i.e. it is read out in the X basis
    (eraser); otherwise it is read out in Z (which-path).

    Args:
        phi: phase applied to S via RZ between the two beamsplitters.
        erase: choose the idler measurement basis (X if True, Z if False).
        shots: number of simulator shots.
        seed: ``seed_simulator`` passed to AerSimulator for reproducibility.

    Returns:
        ``(counts, memory)`` from the Aer result.  Bitstrings are ordered
        c1 c0 (idler bit first), e.g. '10' means i=1, s=0.
    """
    signal, idler = 0, 1
    circuit = QuantumCircuit(2, 2)
    circuit.h(signal)             # first beamsplitter on S
    circuit.rz(phi, signal)       # path phase picked up by S
    circuit.cx(signal, idler)     # entangle: idler records which path S took
    circuit.h(signal)             # second beamsplitter on S
    if erase:
        circuit.h(idler)          # rotate so the Z measurement reads I in the X basis
    circuit.measure(signal, 0)    # S -> c0
    circuit.measure(idler, 1)     # I -> c1

    simulator = AerSimulator(seed_simulator=seed)
    compiled = transpile(circuit, backend=simulator, optimization_level=3)
    outcome = simulator.run(compiled, shots=shots, memory=True).result()
    return outcome.get_counts(), outcome.get_memory()

# ---------------------- Post-processing utils ---------------------
def conditional_prob(counts: Dict[str, int], i_bit: str = '0') -> float:
    """
    Return the estimate of P(s=0 | i=i_bit) from joint measurement counts.

    Keys are two-character bitstrings ordered c1 c0 (idler first), so '10'
    means i=1, s=0.  Returns 0.0 when no shot has i == i_bit.
    """
    total = zeros = 0
    for key, n in counts.items():
        idler, signal = key[0], key[1]
        if idler != i_bit:
            continue
        total += n
        if signal == '0':
            zeros += n
    return zeros / total if total else 0.0

def hist_S_uncond(memory: List[str]) -> Dict[str, int]:
    """Count signal outcomes (character index 1 of each bitstring) over all shots."""
    from collections import Counter
    return dict(Counter(shot[1] for shot in memory))

def hist_S_cond(memory: List[str], i_bit: str = '0') -> Dict[str, int]:
    """Count signal outcomes (index 1) over shots whose idler bit (index 0) equals ``i_bit``."""
    from collections import Counter
    tally = Counter(shot[1] for shot in memory if shot[0] == i_bit)
    return dict(tally)

def visibility(pvals: List[float]) -> float:
    """
    Return the fringe visibility V = (max - min) / (max + min) of ``pvals``.

    Yields 0.0 for an empty sequence or when max + min is zero.
    """
    if pvals:
        hi = max(pvals)
        lo = min(pvals)
        total = hi + lo
        if total:
            return (hi - lo) / total
    return 0.0

# ----------------------------- I/O -------------------------------
def save_csv_prob(phi_list: List[float], p_i0: List[float], p_i1: List[float], path: str) -> None:
    """
    Write one CSV row per phase: phi, P(s=0|i=0), P(s=0|i=1).

    Parent directories of ``path`` are created if needed; a bare filename is
    written to the current working directory.

    Raises:
        ValueError: if the three sequences do not have the same length
            (``zip`` would otherwise silently drop the extra rows).
    """
    if not (len(phi_list) == len(p_i0) == len(p_i1)):
        raise ValueError(
            f"length mismatch: phi_list={len(phi_list)}, p_i0={len(p_i0)}, p_i1={len(p_i1)}"
        )
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises FileNotFoundError for a bare filename
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        w.writerow(['phi', 'P(s=0|i=0)', 'P(s=0|i=1)'])
        w.writerows(zip(phi_list, p_i0, p_i1))

def save_memory(memory: List[str], path: str) -> None:
    """
    Write the raw per-shot bitstrings to ``path``, one per line.

    Parent directories of ``path`` are created if needed; a bare filename is
    written to the current working directory.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises FileNotFoundError for a bare filename
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', newline='', encoding='utf-8') as f:
        f.writelines(m + '\n' for m in memory)

def plot_curves(phi_list: List[float], p_i0: List[float], p_i1: List[float], title: str, outpath: str) -> None:
    """
    Plot P(s=0|i=0) and P(s=0|i=1) against phi and save the figure to ``outpath``.

    Parent directories of ``outpath`` are created if needed; a bare filename
    is written to the current working directory.  The figure is always
    closed, even if drawing or saving fails.
    """
    import matplotlib.pyplot as plt
    parent = os.path.dirname(outpath)
    if parent:  # os.makedirs('') raises FileNotFoundError for a bare filename
        os.makedirs(parent, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(phi_list, p_i0, marker='o', label="P(s=0 | i=0)  (~ '+_X')")
        ax.plot(phi_list, p_i1, marker='s', label="P(s=0 | i=1)  (~ '-_X')")
        # Axis labels are Romanian: "radians" / "probability".
        ax.set_xlabel('φ (radiani)')
        ax.set_ylabel('Probabilitate')
        ax.set_title(title)
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.4)
        fig.tight_layout()
        fig.savefig(outpath, dpi=160)
    finally:
        # Release the figure so repeated calls (or a failed save) don't accumulate open figures.
        plt.close(fig)

# ---------------------------- Runner -----------------------------
def _scan_mode(phi_list: List[float], *, erase: bool, shots: int, seed: int, outdir: str,
               save_mem: bool, row_fmt: str, mem_prefix: str) -> Tuple[List[float], List[float]]:
    """
    Run the circuit at every phase in one idler basis and collect both conditional curves.

    ``row_fmt`` is a ``str.format`` template receiving ``phi``, ``p0`` and ``p1``
    for the per-point progress line.  When ``save_mem`` is set, raw shots go
    to ``<outdir>/<mem_prefix>_phi_<idx>.txt``.

    Returns:
        (P(s=0|i=0) per phi, P(s=0|i=1) per phi)
    """
    p_i0_list: List[float] = []
    p_i1_list: List[float] = []
    for idx, phi in enumerate(phi_list):
        counts, memory = eraser_counts_qiskit(phi, erase=erase, shots=shots, seed=seed)
        p0 = conditional_prob(counts, '0')
        p1 = conditional_prob(counts, '1')
        p_i0_list.append(p0)
        p_i1_list.append(p1)
        print(row_fmt.format(phi=phi, p0=p0, p1=p1))
        if save_mem:
            save_memory(memory, os.path.join(outdir, f"{mem_prefix}_phi_{idx:02d}.txt"))
    return p_i0_list, p_i1_list


def run_scan(nphi: int, shots: int, outdir: str, seed: int, save_mem: bool, do_plot: bool) -> None:
    """
    Scan phi over ``nphi`` evenly spaced points in [0, 2*pi] in both idler bases.

    For each mode (eraser: idler in X; which-path: idler in Z) this prints one
    line per phase, writes a CSV of the two conditional probabilities, prints
    their visibilities and, if ``do_plot`` is set, saves a PNG into ``outdir``.
    """
    # Plain Python floats rather than numpy scalars, so phases passed to the
    # circuit and written to CSV/console are ordinary floats.
    phi_list = [float(x) for x in np.linspace(0.0, 2.0 * math.pi, nphi)]

    # ----- Eraser mode (I in X-basis): i=0 ~ '+_X', i=1 ~ '-_X' -----
    p_plus, p_minus = _scan_mode(
        phi_list, erase=True, shots=shots, seed=seed, outdir=outdir, save_mem=save_mem,
        row_fmt="[ERASE] phi={phi:.3f}  P(s=0|i=+)={p0:.3f}  P(s=0|i=-)={p1:.3f}",
        mem_prefix="memory_erase",
    )
    save_csv_prob(phi_list, p_plus, p_minus, os.path.join(outdir, "eraser_probs.csv"))
    print(f"Vizibilitate ERASE  V(+)≈{visibility(p_plus):.3f}  V(-)≈{visibility(p_minus):.3f}")
    if do_plot:
        plot_curves(phi_list, p_plus, p_minus, "Quantum eraser (I in X)", os.path.join(outdir, "eraser_plot.png"))

    # ----- Which-path mode (I in Z-basis) -----
    p_i0_list, p_i1_list = _scan_mode(
        phi_list, erase=False, shots=shots, seed=seed, outdir=outdir, save_mem=save_mem,
        row_fmt="[Z]     phi={phi:.3f}  P(s=0|i=0)={p0:.3f}  P(s=0|i=1)={p1:.3f}",
        mem_prefix="memory_Z",
    )
    save_csv_prob(phi_list, p_i0_list, p_i1_list, os.path.join(outdir, "whichpath_probs.csv"))
    print(f"Vizibilitate Z     V(i=0)≈{visibility(p_i0_list):.3f}  V(i=1)≈{visibility(p_i1_list):.3f}")
    if do_plot:
        plot_curves(phi_list, p_i0_list, p_i1_list, "Which-path (I in Z)", os.path.join(outdir, "whichpath_plot.png"))

# ----------------------------- Main ------------------------------
def main():
    """Parse command-line options, validate them, and run both phase scans."""
    ap = argparse.ArgumentParser(description="Delayed-choice quantum eraser (Qiskit + Aer)")
    ap.add_argument("--shots", type=int, default=4000, help="Număr de măsurători per punct φ (default: 4000)")
    ap.add_argument("--nphi", type=int, default=13, help="Număr de puncte între 0 și 2π (default: 13)")
    ap.add_argument("--outdir", type=str, default="out", help="Director ieșire (CSV, PNG, memorii)")
    ap.add_argument("--seed", type=int, default=1234, help="Seed pentru simulator (reproducibilitate)")
    ap.add_argument("--plot", action="store_true", help="Generează grafice PNG (matplotlib)")
    ap.add_argument("--memory", action="store_true", help="Salvează bitstring-urile brute pentru fiecare φ")
    args = ap.parse_args()

    # Reject values that would give an empty scan or fail deep inside numpy/the simulator.
    if args.shots < 1:
        ap.error("--shots must be >= 1")
    if args.nphi < 1:
        ap.error("--nphi must be >= 1")

    # An empty --outdir would make os.makedirs('') raise; treat it as the current directory.
    outdir = args.outdir or "."
    os.makedirs(outdir, exist_ok=True)
    run_scan(nphi=args.nphi, shots=args.shots, outdir=outdir, seed=args.seed, save_mem=args.memory, do_plot=args.plot)
    print(f"Fișiere salvate în: {outdir}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
