# dcqe_one_composer.py
# DCQE (Delayed-Choice Quantum Eraser) pe hardware IBM cu UN SINGUR circuit QASM2
# - un ancilla "mode" este măsurat devreme -> c_mode
# - dacă c_mode == 1, aplicăm H pe idler (eraser); dacă 0, NU (which-path)
# - semnalul (q0) e măsurat ÎNAINTE (delayed-choice stil)
# - rulăm același circuit de două ori: setăm ancilla în |0> (mode=0) sau |1> (mode=1)
# - sweep φ, plot rezultate, export QASM2 pentru Composer

import os, sys, math, json, getpass
from pathlib import Path
from datetime import datetime

import numpy as np
import matplotlib.pyplot as plt

from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
from qiskit.circuit import Parameter
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit.qasm2 import dumps as qasm2_dumps

OUT = Path("dcqe_one_composer_outputs"); OUT.mkdir(exist_ok=True)
SHOTS = int(os.getenv("DCQE_SHOTS", "2000"))
NUM_PHI = int(os.getenv("DCQE_NUM_PHI", "21"))
USE_DELAYED = True
IDLER_DELAY_DT = int(os.getenv("DCQE_IDLER_DELAY_DT", "2048"))

def get_token():
    tok = os.getenv("IQP_API_TOKEN") or os.getenv("IBM_QUANTUM_API_TOKEN")
    if tok: return tok.strip()
    try:
        return getpass.getpass("Introduceți IBM Quantum API token: ").strip()
    except Exception:
        return input("Introduceți IBM Quantum API token (se va afișa): ").strip()

def connect_runtime(token, instance=None):
    # compat pe versiuni diferite de qiskit-ibm-runtime
    for ch in ("ibm_quantum_platform", "ibm_quantum", "ibm_cloud"):
        try:
            return QiskitRuntimeService(channel=ch, token=token, instance=instance)
        except Exception:
            continue
    raise RuntimeError("Nu m-am putut conecta la IBM Runtime (verificați pachetele și tokenul).")

def dcqe_single_composer(phi_param: Parameter) -> QuantumCircuit:
    """
    Un singur circuit DCQE cu:
      - q0 = signal, q1 = idler, q2 = ancilla (mode-qubit)
      - c_s (1 bit): rezultatul măsurării signal
      - c_i (1 bit): rezultatul măsurării idler
      - c_mode (1 bit): rezultatul măsurării ancilla (decide dacă aplicăm H pe idler)
    Flux:
      1) pregătim superpoziția pe q0; entanglăm which-path în q1 (CNOT)
      2) aplicăm faza variabilă φ pe q0 (RZ(φ)) și recombinăm (H pe q0)
      3) (delayed) măsurăm q0 -> c_s
      4) măsurăm q2 -> c_mode
      5) dacă c_mode == 1: aplicăm H pe q1 (eraser). dacă 0: nimic (which-path)
      6) măsurăm q1 -> c_i
    """
    q = QuantumRegister(3, "q")
    c_s = ClassicalRegister(1, "c_s")
    c_i = ClassicalRegister(1, "c_i")
    c_mode = ClassicalRegister(1, "c_mode")
    qc = QuantumCircuit(q, c_s, c_i, c_mode, name="DCQE_one_composer")

    # double-slit pe semnal (q0), marcare which-path în idler (q1)
    qc.h(q[0])
    qc.cx(q[0], q[1])
    # faza variabilă pe semnal
    qc.rz(phi_param, q[0])
    # recombinare
    qc.h(q[0])

    # (opțional) întârziere pe idler pentru "delayed choice" mai teatral
    qc.barrier()
    if USE_DELAYED:
        try:
            qc.delay(IDLER_DELAY_DT, q[1])
        except Exception:
            pass

    # măsurăm semnalul înainte (delayed-choice)
    qc.measure(q[0], c_s[0])
    # măsurăm ancilla în c_mode (decizia eraser)
    qc.measure(q[2], c_mode[0])

    # aplica H pe idler numai dacă c_mode == 1 (eraser)
    # Qiskit c_if: condiționare pe ÎNTREGUL creg (OpenQASM 2: if (c_mode==1) h q[1];)
    qc.h(q[1]).c_if(c_mode, 1)

    # măsurăm idler
    qc.measure(q[1], c_i[0])

    return qc

def build_param_circuits(num_phi: int):
    phi = Parameter("phi")
    base = dcqe_single_composer(phi)

    # Export QASM2 al circuitului "universal" (cu parametru)
    (OUT / "dcqe_one_composer_template.qasm").write_text(qasm2_dumps(base), encoding="utf-8")

    # Pregătim două versiuni pentru rulat: mode=0 (which-path) și mode=1 (eraser)
    # → setăm ancilla q2 în |0> sau |1> înainte de măsurarea ei.
    # Tip: folosim 2 copii ale circuitului parametrizat și prefixăm pregătirea ancilla.
    def with_mode(prep_one: bool):
        qc = base.copy()
        if prep_one:
            qc.x(2)  # ancilla q2 în |1>
        qc.name = "DCQE_mode1" if prep_one else "DCQE_mode0"
        return qc

    circs_mode0 = []
    circs_mode1 = []
    phis = np.linspace(0.0, 2.0*math.pi, num_phi, endpoint=True)
    for val in phis:
        c0 = with_mode(False).bind_parameters({phi: float(val)})
        c1 = with_mode(True ).bind_parameters({phi: float(val)})
        circs_mode0.append(c0)
        circs_mode1.append(c1)

    return phis, circs_mode0, circs_mode1

def run_on_backend(service: QiskitRuntimeService, shots: int, num_phi: int):
    try:
        backend = service.least_busy(operational=True, simulator=False, min_num_qubits=3)
    except Exception:
        backs = service.backends(operational=True, simulator=False, min_num_qubits=3)
        if not backs:
            raise RuntimeError("Nu există backend hardware disponibil (≥3 qubiți).")
        backend = sorted(backs, key=lambda b: getattr(b.status(), "pending_jobs", 10**9))[0]
    print(f"[Backend ales] {backend.name}")

    pm = generate_preset_pass_manager(target=backend.target, optimization_level=3)

    phis, mode0, mode1 = build_param_circuits(num_phi)
    mode0_isa = [pm.run(c) for c in mode0]
    mode1_isa = [pm.run(c) for c in mode1]

    sampler = Sampler(mode=backend)

    job0 = sampler.run(mode0_isa, shots=shots); res0 = job0.result()
    job1 = sampler.run(mode1_isa, shots=shots); res1 = job1.result()

    def counts_of(idx, prim_result):
        return prim_result[idx].join_data().get_counts()

    # Extragem probabilități relevante:
    # bitstring ordine: [c_i, c_s, c_mode] -> 3 biți (MSB=c_i)
    # Interes: P(signal=0) (marginal) și condiționate pe idler pentru eraser
    p_s0_mode0 = []  # which-path marginal
    p_s0_id0_mode1 = []  # eraser, condiționat idler=0
    p_s0_id1_mode1 = []  # eraser, condiționat idler=1

    for i in range(num_phi):
        cnt0 = counts_of(i, res0)
        cnt1 = counts_of(i, res1)

        # Mode 0: marginala P(s=0)
        shots0 = sum(cnt0.values())
        p_s0 = sum(v for k, v in cnt0.items() if k[1] == "0") / shots0
        p_s0_mode0.append(p_s0)

        # Mode 1: condiționale pe idler
        shots1 = sum(cnt1.values())
        n_id0 = sum(v for k, v in cnt1.items() if k[0] == "0")
        n_id1 = shots1 - n_id0
        if n_id0 > 0:
            ps0_id0 = sum(v for k, v in cnt1.items() if k[0] == "0" and k[1] == "0") / n_id0
        else:
            ps0_id0 = float("nan")
        if n_id1 > 0:
            ps0_id1 = sum(v for k, v in cnt1.items() if k[0] == "1" and k[1] == "0") / n_id1
        else:
            ps0_id1 = float("nan")
        p_s0_id0_mode1.append(ps0_id0)
        p_s0_id1_mode1.append(ps0_id1)

    # Salvăm rezultate + QASM instanțiat pentru un φ de referință
    ref_phi = math.pi/4
    _, mode0_ref, mode1_ref = build_param_circuits(1)
    (OUT / "dcqe_mode0_phi_pi4.qasm").write_text(qasm2_dumps(mode0_ref[0]), encoding="utf-8")
    (OUT / "dcqe_mode1_phi_pi4.qasm").write_text(qasm2_dumps(mode1_ref[0]), encoding="utf-8")

    results = {
        "timestamp": datetime.utcnow().isoformat() + "Z",
        "backend": backend.name,
        "shots": shots,
        "num_phi": num_phi,
        "phis": list(map(float, phis)),
        "P_s0_mode0_whichpath": p_s0_mode0,
        "P_s0_given_idler0_mode1_eraser": p_s0_id0_mode1,
        "P_s0_given_idler1_mode1_eraser": p_s0_id1_mode1,
        "notes": "Mode0=which-path (fără H pe idler) → marginală plată; Mode1=eraser (H c_if(c_mode==1) pe idler) → franjuri complementare în condiționate."
    }
    (OUT / "dcqe_one_composer_results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")

    # Plot 1: which-path (marginal)
    plt.figure(figsize=(7,4))
    plt.plot(phis, p_s0_mode0, marker="o")
    plt.xlabel("φ [rad]"); plt.ylabel("P(signal=0)")
    plt.title("Mode 0 — which-path (marginal, așteptat: fără franjuri)")
    plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(OUT / "mode0_whichpath_marginal.png", dpi=160); plt.close()

    # Plot 2: eraser condiționat
    plt.figure(figsize=(7,4))
    plt.plot(phis, p_s0_id0_mode1, marker="o", label="idler=0")
    plt.plot(phis, p_s0_id1_mode1, marker="s", label="idler=1")
    plt.xlabel("φ [rad]"); plt.ylabel("P(signal=0 | idler)")
    plt.title("Mode 1 — eraser (franjuri complementare)")
    plt.legend(); plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(OUT / "mode1_eraser_conditionals.png", dpi=160); plt.close()

    print("\n=== DCQE (one-composer) — gata ===")
    print(f"Backend: {backend.name}")
    print(f"Rezultate: {OUT.resolve()}  →  dcqe_one_composer_results.json")
    print("Imagini:   mode0_whichpath_marginal.png, mode1_eraser_conditionals.png")
    print("Composer:  dcqe_one_composer_template.qasm (un singur fișier, ambele moduri).")
    print("           dcqe_mode0_phi_pi4.qasm / dcqe_mode1_phi_pi4.qasm (instanțieri de referință).")

if __name__ == "__main__":
    print("=== DCQE (one-composer) pe IBM Quantum — Qiskit ≥ 1.0 ===")
    token = get_token()
    if not token:
        sys.exit("Token gol. Setați IQP_API_TOKEN sau introduceți tokenul.")
    instance = os.getenv("IBM_QUANTUM_INSTANCE")
    service = connect_runtime(token, instance)
    run_on_backend(service, SHOTS, NUM_PHI)
