# dcqe_one_composer_hw.py
# Delayed-Choice Quantum Eraser pe hardware IBM, într-un SINGUR circuit (Composer-friendly),
# actualizat pentru Qiskit >= 2.0: control flow cu `if_test` (nu .c_if), SamplerV2, qasm2.dumps.

import os, sys, math, json, getpass
from pathlib import Path
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt

from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
from qiskit.circuit import Parameter
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit.qasm2 import dumps as qasm2_dumps

# ---------------------- Config ----------------------
# Output directory for QASM files, JSON results and plots (created at import time).
OUT = Path("dcqe_one_composer_outputs")
OUT.mkdir(exist_ok=True)

# Shots per circuit and number of phase points in the φ sweep (overridable via env vars).
SHOTS = int(os.getenv("DCQE_SHOTS", "2000"))
NUM_PHI = int(os.getenv("DCQE_NUM_PHI", "21"))

# When True, a delay is inserted on the idler qubit to emulate the "delayed choice".
USE_DELAYED = True

# Length of the idler delay, passed to `QuantumCircuit.delay` (default unit: dt).
IDLER_DELAY_DT = int(os.getenv("DCQE_IDLER_DELAY_DT", "2048"))

def get_token() -> str:
    """Return the IBM Quantum API token, stripped of surrounding whitespace.

    Lookup order: the IQP_API_TOKEN env var, then IBM_QUANTUM_API_TOKEN
    (empty values are skipped), then an interactive hidden prompt. If the
    hidden prompt cannot be used, it falls back to a visible `input()` prompt.
    """
    for env_name in ("IQP_API_TOKEN", "IBM_QUANTUM_API_TOKEN"):
        candidate = os.getenv(env_name)
        if candidate:
            return candidate.strip()

    try:
        typed = getpass.getpass("Introduceți IBM Quantum API token: ")
    except Exception:
        # getpass may be unusable (e.g. no terminal); echoing the token is the fallback.
        typed = input("Introduceți IBM Quantum API token (se va afișa): ")
    return typed.strip()

def connect_runtime(token: str, instance: str | None = None) -> QiskitRuntimeService:
    """Create a QiskitRuntimeService, trying several channel names in turn.

    Different qiskit-ibm-runtime releases accept different channel names, so each
    candidate is tried in order and the first service that constructs is returned.

    Args:
        token: IBM Quantum API token.
        instance: optional instance identifier forwarded unchanged to the service.

    Returns:
        The first successfully constructed QiskitRuntimeService.

    Raises:
        RuntimeError: if every channel fails. The message lists each channel's
            error, and the exception is chained to the last underlying error.
    """
    failures: list[tuple[str, Exception]] = []
    for channel in ("ibm_quantum_platform", "ibm_quantum", "ibm_cloud"):
        try:
            return QiskitRuntimeService(channel=channel, token=token, instance=instance)
        except Exception as exc:  # a later channel name may be the supported one
            failures.append((channel, exc))

    details = "; ".join(f"{ch}: {type(err).__name__}: {err}" for ch, err in failures)
    raise RuntimeError(
        "Conectare eșuată: verificați qiskit-ibm-runtime și tokenul. " + details
    ) from (failures[-1][1] if failures else None)

# ------------------- Circuit "one-composer" -------------------
def dcqe_single_composer(phi_param: Parameter) -> QuantumCircuit:
    """Build the single ("one-composer") delayed-choice quantum-eraser circuit.

    Qubits: q0 = signal, q1 = idler, q2 = ancilla that selects the mode.

    Classical registers are declared in a FIXED order: c_i, c_s, c_mode (1 bit each).
    Qiskit treats the first declared clbit as the least significant, so in a
    joined bitstring they read "<c_mode><c_s><c_i>": c_mode is leftmost (MSB) and
    c_i is rightmost (LSB). Prefer reading each register by name rather than
    slicing joined bitstrings.

    Flow:
      H(q0); CX(q0,q1); RZ(phi)(q0); H(q0);
      barrier; (optional delay on q1);
      measure q0 -> c_s;
      measure q2 -> c_mode;
      if (c_mode == 1): H(q1)      # eraser (X basis)
      measure q1 -> c_i;

    The ancilla is not prepared here, so it is measured as 0 unless a caller
    prepares q2 in |1> *before* these instructions (eraser mode).

    Args:
        phi_param: phase parameter applied as RZ on the signal qubit.

    Returns:
        The parameterized (unbound) circuit named "DCQE_one_composer".
    """
    q = QuantumRegister(3, "q")
    c_i   = ClassicalRegister(1, "c_i")     # declared first -> least-significant bit
    c_s   = ClassicalRegister(1, "c_s")
    c_mod = ClassicalRegister(1, "c_mode")  # declared last  -> most-significant bit
    qc = QuantumCircuit(q, c_i, c_s, c_mod, name="DCQE_one_composer")

    # "double-slit" + which-path marking (the CX entangles the idler with the signal path)
    qc.h(q[0])
    qc.cx(q[0], q[1])
    qc.rz(phi_param, q[0])
    qc.h(q[0])

    qc.barrier()
    if USE_DELAYED:
        try:
            qc.delay(IDLER_DELAY_DT, q[1])  # delay the idler
        except Exception as exc:
            # Best-effort: run without the delay, but say so instead of hiding it.
            print(f"[Avertisment] delay pe idler omis: {exc}")

    # Measure the signal first (the delayed-choice ordering).
    qc.measure(q[0], c_s[0])
    # Measure the ancilla; its outcome selects the mode.
    qc.measure(q[2], c_mod[0])

    # Modern control flow: if_test instead of the removed .c_if(...)
    with qc.if_test((c_mod, 1)):  # c_mode == 1 -> ERASER (H on the idler)
        qc.h(q[1])

    # Measure the idler.
    qc.measure(q[1], c_i[0])

    return qc

def build_param_circuits(num_phi: int):
    """Build the φ-sweep circuits for both modes and export the QASM2 template.

    Side effect: writes `dcqe_one_composer_template.qasm` (the unbound base circuit)
    into OUT.

    Args:
        num_phi: number of φ points, evenly spaced over [0, 2π] (both ends included).

    Returns:
        (phis, mode0, mode1): the numpy array of φ values plus two lists of bound
        circuits.
        - mode0 ("DCQE_mode0"): ancilla left in |0>, so the H is not applied (which-path).
        - mode1 ("DCQE_mode1"): ancilla flipped to |1> before it is measured, so the
          conditional H acts on the idler (eraser).
    """
    phi = Parameter("phi")
    base = dcqe_single_composer(phi)

    # QASM2 template for Composer (a single file, still containing the parameter).
    # NOTE(review): OpenQASM 2 has no free parameters; `phi` may need a numeric value
    # substituted before Composer accepts this file — confirm.
    (OUT / "dcqe_one_composer_template.qasm").write_text(qasm2_dumps(base), encoding="utf-8")

    def with_mode(prep_one: bool) -> QuantumCircuit:
        # The X on the ancilla must come BEFORE the `measure q2` inside `base`.
        # Appending it to a copy of `base` would put it after every measurement:
        # the ancilla would always read 0, and mode1 would equal mode0.
        qc = QuantumCircuit(
            *base.qregs, *base.cregs,
            name="DCQE_mode1" if prep_one else "DCQE_mode0",
        )
        if prep_one:
            qc.x(2)  # q2 = |1>
        qc.compose(base, inplace=True)
        return qc

    mode0_template = with_mode(False)
    mode1_template = with_mode(True)

    phis = np.linspace(0.0, 2.0 * math.pi, num_phi, endpoint=True)
    # `bind_parameters` was removed in Qiskit 1.0.
    # `assign_parameters` returns a new bound copy, leaving the templates untouched.
    mode0 = [mode0_template.assign_parameters({phi: float(val)}) for val in phis]
    mode1 = [mode1_template.assign_parameters({phi: float(val)}) for val in phis]
    return phis, mode0, mode1

# --------------------- Execuție pe hardware ---------------------
def _select_backend(service: QiskitRuntimeService):
    """Return an operational hardware backend with >= 3 qubits, preferring the least busy."""
    try:
        return service.least_busy(operational=True, simulator=False, min_num_qubits=3)
    except Exception:
        # Fallback: rank the candidates manually by their pending-job queue length.
        backs = service.backends(operational=True, simulator=False, min_num_qubits=3)
        if not backs:
            raise RuntimeError("Niciun backend hardware disponibil (>=3 qubiți).")
        return min(backs, key=lambda b: getattr(b.status(), "pending_jobs", 10**9))


def _signal_idler_counts(pub_result):
    """Count shots per (signal, idler) outcome pair in one SamplerV2 pub result.

    Reads the `c_s` and `c_i` registers by name, so the result does not depend on
    where each register lands inside a joined bitstring.
    """
    from collections import Counter

    signal = pub_result.data.c_s.get_bitstrings()
    idler = pub_result.data.c_i.get_bitstrings()
    return Counter(zip(signal, idler))


def _p_signal0(joint, idler=None) -> float:
    """Return P(signal=0), optionally conditioned on idler == `idler` ("0"/"1").

    Returns NaN when no shot matches the condition.
    """
    idlers = ("0", "1") if idler is None else (idler,)
    hits = sum(joint[("0", i)] for i in idlers)
    total = sum(joint[(s, i)] for s in ("0", "1") for i in idlers)
    return hits / total if total else float("nan")


def _bound_mode_circuit(phi_value: float, prep_one: bool) -> QuantumCircuit:
    """Build one bound DCQE circuit at φ = `phi_value` (eraser mode when `prep_one`)."""
    phi = Parameter("phi")
    base = dcqe_single_composer(phi)
    qc = QuantumCircuit(
        *base.qregs, *base.cregs,
        name="DCQE_mode1" if prep_one else "DCQE_mode0",
    )
    if prep_one:
        qc.x(2)  # must precede the `measure q2` contained in `base`
    qc.compose(base, inplace=True)
    return qc.assign_parameters({phi: phi_value})


def _nan_to_none(values):
    """Replace NaN with None; json.dumps would otherwise emit the non-standard token NaN."""
    return [None if math.isnan(v) else v for v in values]


def run_on_backend(service: QiskitRuntimeService, shots: int, num_phi: int):
    """Run the DCQE φ sweep on an IBM QPU and save JSON, plots and QASM files.

    Args:
        service: authenticated runtime service used to pick the backend.
        shots: shots per circuit.
        num_phi: number of φ points in [0, 2π].

    Side effects: submits two Sampler jobs, writes results/plots/QASM into OUT,
    and prints a summary.
    """
    from datetime import timezone

    backend = _select_backend(service)
    print(f"[Backend ales] {backend.name}")

    pm = generate_preset_pass_manager(target=backend.target, optimization_level=3)

    phis, mode0_circs, mode1_circs = build_param_circuits(num_phi)
    mode0_isa = [pm.run(c) for c in mode0_circs]
    mode1_isa = [pm.run(c) for c in mode1_circs]

    sampler = Sampler(mode=backend)
    # Submit both jobs before blocking on either, so they queue back-to-back.
    job0 = sampler.run(mode0_isa, shots=shots)
    job1 = sampler.run(mode1_isa, shots=shots)
    res0, res1 = job0.result(), job1.result()

    # Quantities of interest:
    #  - Mode 0 (which-path): marginal P(s=0)
    #  - Mode 1 (eraser): conditionals P(s=0 | idler=0/1)
    p_s0_mode0 = []
    p_s0_id0_mode1 = []
    p_s0_id1_mode1 = []
    for pub0, pub1 in zip(res0, res1):
        joint0 = _signal_idler_counts(pub0)
        joint1 = _signal_idler_counts(pub1)
        p_s0_mode0.append(_p_signal0(joint0))
        p_s0_id0_mode1.append(_p_signal0(joint1, "0"))
        p_s0_id1_mode1.append(_p_signal0(joint1, "1"))

    # Export two QASM2 instances for Composer, really bound at φ = π/4.
    phi_ref = math.pi / 4
    for prep_one, fname in ((False, "dcqe_mode0_phi_pi4.qasm"), (True, "dcqe_mode1_phi_pi4.qasm")):
        ref = _bound_mode_circuit(phi_ref, prep_one)
        (OUT / fname).write_text(qasm2_dumps(ref), encoding="utf-8")

    # Save JSON results (NaN -> null so the file is standard JSON).
    results = {
        "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z",
        "backend": backend.name,
        "shots": shots,
        "num_phi": num_phi,
        "phis": list(map(float, phis)),
        "P_s0_mode0_whichpath": _nan_to_none(p_s0_mode0),
        "P_s0_given_idler0_mode1_eraser": _nan_to_none(p_s0_id0_mode1),
        "P_s0_given_idler1_mode1_eraser": _nan_to_none(p_s0_id1_mode1),
        "notes": "Mode0=which-path (fără H pe idler) → marginală plată; Mode1=eraser (H condițional) → franjuri complementare în condiționate."
    }
    (OUT / "dcqe_one_composer_results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")

    # Plot 1: which-path (marginal)
    plt.figure(figsize=(7,4))
    plt.plot(phis, p_s0_mode0, marker="o")
    plt.xlabel("φ [rad]"); plt.ylabel("P(signal=0)")
    plt.title("Mode 0 — which-path (marginal, așteptat: fără franjuri)")
    plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(OUT / "mode0_whichpath_marginal.png", dpi=160); plt.close()

    # Plot 2: conditional eraser
    plt.figure(figsize=(7,4))
    plt.plot(phis, p_s0_id0_mode1, marker="o", label="idler=0")
    plt.plot(phis, p_s0_id1_mode1, marker="s", label="idler=1")
    plt.xlabel("φ [rad]"); plt.ylabel("P(signal=0 | idler)")
    plt.title("Mode 1 — eraser (franjuri complementare)")
    plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(OUT / "mode1_eraser_conditionals.png", dpi=160); plt.close()

    print("\n=== DCQE (one-composer) — gata ===")
    print(f"Backend: {backend.name}")
    print(f"Rezultate: {OUT.resolve()} → dcqe_one_composer_results.json")
    print("Imagini:   mode0_whichpath_marginal.png, mode1_eraser_conditionals.png")
    print("Composer:  dcqe_one_composer_template.qasm  (un singur fișier, ambele moduri)")
    print("           dcqe_mode0_phi_pi4.qasm / dcqe_mode1_phi_pi4.qasm")

if __name__ == "__main__":
    # Script entry point: banner -> token -> runtime connection -> hardware run.
    print("=== DCQE (one-composer) pe IBM Quantum — Qiskit ≥ 2.0 ===")
    api_token = get_token()
    if not api_token:
        sys.exit("Token gol. Setați IQP_API_TOKEN sau introduceți tokenul.")
    # IBM_QUANTUM_INSTANCE is optional; when unset, None is forwarded as the instance.
    runtime_service = connect_runtime(api_token, instance=os.getenv("IBM_QUANTUM_INSTANCE"))
    run_on_backend(runtime_service, shots=SHOTS, num_phi=NUM_PHI)
