# bell_chsh_hardware.py
# Verificare CHSH pe hardware IBM Quantum, cu control separabil, 2000 shots/caz.
# Generează scheme PNG + OpenQASM pentru reproducere în IBM Quantum Composer.

import os, math, json, sys
import getpass
from pathlib import Path
from datetime import datetime

from qiskit import QuantumCircuit
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager

from qiskit_ibm_runtime import QiskitRuntimeService
from qiskit_ibm_runtime import SamplerV2 as Sampler  # shots pe hardware suportat

# Directory (relative to the current working directory) that receives every
# generated artifact. It is created as a module-level side effect, so it exists
# before any function below writes into it.
OUTDIR = Path("chsh_outputs")
OUTDIR.mkdir(exist_ok=True)

# --- OpenQASM export: QASM3 first, then QASM2, then the legacy method ---
def dump_qasm(circ: QuantumCircuit) -> str:
    """Serialize ``circ`` to OpenQASM source text.

    Exporters are tried in this order:

    1. ``qiskit.qasm3.dumps`` (OpenQASM 3);
    2. ``qiskit.qasm2.dumps`` (OpenQASM 2), when the QASM3 exporter cannot be
       imported or fails on this circuit;
    3. the legacy ``circ.qasm()`` method, only when ``qiskit.qasm2.dumps``
       cannot be imported either (the method no longer exists in Qiskit 1.x,
       so it cannot serve as the general fallback).

    Args:
        circ: Circuit to export.

    Returns:
        The OpenQASM program as a string.

    Raises:
        Exception: Whatever the last exporter that was tried raises; errors
            from the QASM3 exporter alone are not propagated.
    """
    try:
        from qiskit.qasm3 import dumps as qasm3_dumps
    except ImportError:
        qasm3_dumps = None

    if qasm3_dumps is not None:
        try:
            return qasm3_dumps(circ)
        except Exception:
            # Deliberate best-effort: a circuit the QASM3 exporter rejects may
            # still be expressible in QASM2, so fall through instead of failing.
            pass

    try:
        from qiskit.qasm2 import dumps as qasm2_dumps
    except ImportError:
        # Releases that predate qiskit.qasm2.dumps still expose the method form.
        return circ.qasm()
    return qasm2_dumps(circ)

def save_artifacts(circ: QuantumCircuit, stem: str):
    """Write a PNG drawing and an OpenQASM dump of ``circ`` into ``OUTDIR``.

    Both exports are best-effort: a failure (for example matplotlib not being
    installed) is reported on stderr and never raised, so a missing artifact
    cannot abort the caller. The two exports are independent; one failing does
    not prevent the other.

    Args:
        circ: Circuit to export.
        stem: File name without extension; ``<stem>.png`` and ``<stem>.qasm``
            are written inside ``OUTDIR``.
    """
    # PNG drawing (needs the optional matplotlib-based drawer).
    try:
        figure = circ.draw("mpl")
        try:
            figure.savefig(OUTDIR / f"{stem}.png", dpi=180, bbox_inches="tight")
        finally:
            # Release the figure so repeated calls do not accumulate open
            # matplotlib figures. matplotlib is importable here because the
            # drawer above just used it.
            import matplotlib.pyplot as plt

            plt.close(figure)
    except Exception as exc:
        print(f"[Avertisment] Nu s-a putut salva {stem}.png: {exc}", file=sys.stderr)

    # OpenQASM text.
    try:
        (OUTDIR / f"{stem}.qasm").write_text(dump_qasm(circ), encoding="utf-8")
    except Exception as exc:
        print(f"[Avertisment] Nu s-a putut salva {stem}.qasm: {exc}", file=sys.stderr)

# --- CHSH bases: measure in a basis rotated in the X–Z plane via RY(-2θ) + Z measurement ---
def chsh_rot_to_z(qc: QuantumCircuit, q: int, angle: float):
    """Append the basis-change rotation for one CHSH measurement setting.

    Applies ``RY(-2 * angle)`` to qubit ``q`` of ``qc`` (in place). A Z-basis
    measurement placed after it then reads out the observable
    ``cos(2*angle) * Z + sin(2*angle) * X``.

    Args:
        qc: Circuit to append the gate to.
        q: Index of the qubit to rotate.
        angle: Measurement setting θ in radians; the gate angle is -2θ.
    """
    gate_angle = angle * -2.0
    qc.ry(gate_angle, q)

def build_entangled_chsh_circuit(a_angle: float, b_angle: float, name="E") -> QuantumCircuit:
    """Build the entangled CHSH circuit for one pair of measurement settings.

    Prepares |Φ+> = (|00> + |11>)/√2 with H on qubit 0 followed by CX(0, 1),
    rotates each qubit into its measurement basis, then measures both qubits
    into classical bits 0 and 1.

    Args:
        a_angle: Alice's setting (qubit 0), in radians.
        b_angle: Bob's setting (qubit 1), in radians.
        name: Name given to the circuit.

    Returns:
        A new 2-qubit, 2-clbit circuit.
    """
    circuit = QuantumCircuit(2, 2, name=name)

    # Bell-pair preparation.
    circuit.h(0)
    circuit.cx(0, 1)

    # Qubit 0 is Alice, qubit 1 is Bob.
    for qubit, setting in enumerate((a_angle, b_angle)):
        chsh_rot_to_z(circuit, qubit, setting)

    circuit.measure([0, 1], [0, 1])
    return circuit

def build_separable_chsh_circuit(a_angle: float, b_angle: float, name="S") -> QuantumCircuit:
    """Build the separable control circuit for one pair of measurement settings.

    Identical to the entangled circuit except that no entangling gates are
    applied: the qubits stay in |00>, are rotated into the same measurement
    bases, and are measured into classical bits 0 and 1.

    Args:
        a_angle: Setting for qubit 0, in radians.
        b_angle: Setting for qubit 1, in radians.
        name: Name given to the circuit.

    Returns:
        A new 2-qubit, 2-clbit circuit.
    """
    circuit = QuantumCircuit(2, 2, name=name)

    for qubit, setting in enumerate((a_angle, b_angle)):
        chsh_rot_to_z(circuit, qubit, setting)

    circuit.measure([0, 1], [0, 1])
    return circuit

def expectation_from_counts(counts: dict) -> float:
    """Estimate the two-bit correlator E from a counts dictionary.

    E = P(00) + P(11) - P(01) - P(10), where each probability is the count of
    that bitstring divided by the sum of all counts in ``counts``.

    Args:
        counts: Mapping from bitstring to number of occurrences. Missing
            outcomes count as zero.

    Returns:
        The correlator; 0.0 for an empty mapping.
    """
    total = sum(counts.values())
    if total < 1:
        # Avoid dividing by zero when nothing was recorded.
        total = 1

    def prob(outcome):
        return counts.get(outcome, 0) / total

    agree = prob("00") + prob("11")
    disagree = prob("01") + prob("10")
    return agree - disagree

def compute_chsh_from_four(counts_ab, counts_abp, counts_apb, counts_apbp) -> float:
    """Combine four counts dictionaries into the CHSH value.

    S = E(a,b) + E(a,b') + E(a',b) - E(a',b'), with each E obtained from
    ``expectation_from_counts``.

    Args:
        counts_ab: Counts for settings (a, b).
        counts_abp: Counts for settings (a, b').
        counts_apb: Counts for settings (a', b).
        counts_apbp: Counts for settings (a', b').

    Returns:
        The CHSH value S.
    """
    e_ab, e_abp, e_apb, e_apbp = (
        expectation_from_counts(setting_counts)
        for setting_counts in (counts_ab, counts_abp, counts_apb, counts_apbp)
    )
    return e_ab + e_abp + e_apb - e_apbp

def get_token() -> str:
    """Return the IBM Quantum API token, stripped of surrounding whitespace.

    Lookup order:

    1. the ``IQP_API_TOKEN`` environment variable, then
       ``IBM_QUANTUM_API_TOKEN`` (the first non-empty one wins);
    2. a hidden ``getpass`` prompt;
    3. a plain ``input`` prompt (the token is echoed) if ``getpass`` raises.

    Returns:
        The token; may be an empty string if nothing usable was supplied.
    """
    # 1) Environment variables (recommended for Windows/IDE setups).
    for variable in ("IQP_API_TOKEN", "IBM_QUANTUM_API_TOKEN"):
        from_env = os.getenv(variable)
        if from_env:
            return from_env.strip()

    # 2) Interactive prompt (may still echo on some Windows terminals).
    try:
        typed = getpass.getpass("Introduceți IBM Quantum API token: ")
    except Exception:
        typed = input("Introduceți IBM Quantum API token (atenție: se va afișa): ")
    return typed.strip()

# Script entry point: pick a hardware backend, run the four CHSH settings for
# both the entangled state and the separable control, then write results.json
# and print a summary.
if __name__ == "__main__":
    # Local import: the file-level import only brings in `datetime`.
    from datetime import timezone

    print("=== CHSH pe IBM Quantum Hardware (cu control separabil) ===")
    token = get_token()
    if not token:
        sys.exit("Token gol. Reporniți cu un token valid (ideal prin variabila de mediu IQP_API_TOKEN).")

    # Optional instance: set IBM_QUANTUM_INSTANCE in the environment when an
    # enterprise instance is used; None when unset.
    instance = os.getenv("IBM_QUANTUM_INSTANCE")  # e.g. "ibm-q/open/main"

    # Connect on the IBM Quantum Platform channel
    # (ibm_cloud is still accepted, but ibm_quantum_platform is recommended).
    service = QiskitRuntimeService(channel="ibm_quantum_platform", token=token, instance=instance)

    # Real backend: the least busy one (operational, non-simulator, >= 2 qubits).
    try:
        backend = service.least_busy(operational=True, simulator=False, min_num_qubits=2)
    except Exception:
        backends = service.backends(operational=True, simulator=False, min_num_qubits=2)
        if not backends:
            sys.exit("Nu s-a găsit niciun backend hardware disponibil.")

        # Fallback ordering by pending_jobs; backends whose status cannot be
        # read are ranked last.
        def pending(b):
            try:
                return getattr(b.status(), "pending_jobs", 10**9)
            except Exception:
                return 10**9

        # min() returns the first minimal element, same as sorted(...)[0].
        backend = min(backends, key=pending)

    print(f"[Backend ales] {backend.name}")

    # ISA transpilation for the selected backend (optimization level 3).
    pm = generate_preset_pass_manager(target=backend.target, optimization_level=3)

    # Standard CHSH settings for maximal violation:
    # A0 = 0, A1 = π/4 ; B0 = π/8, B1 = −π/8
    A0 = 0.0
    A1 = math.pi / 4.0
    B0 = math.pi / 8.0
    B1 = -math.pi / 8.0

    # Four (A, B) circuits for the ENTANGLED state.
    circ_E_ab    = build_entangled_chsh_circuit(A0, B0, "E_ab")
    circ_E_abp   = build_entangled_chsh_circuit(A0, B1, "E_abp")
    circ_E_apb   = build_entangled_chsh_circuit(A1, B0, "E_apb")
    circ_E_apbp  = build_entangled_chsh_circuit(A1, B1, "E_apbp")

    # Four (A, B) circuits for the SEPARABLE control.
    circ_S_ab    = build_separable_chsh_circuit(A0, B0, "S_ab")
    circ_S_abp   = build_separable_chsh_circuit(A0, B1, "S_abp")
    circ_S_apb   = build_separable_chsh_circuit(A1, B0, "S_apb")
    circ_S_apbp  = build_separable_chsh_circuit(A1, B1, "S_apbp")

    # Save drawings + QASM so the circuits can be reproduced in Composer.
    for stem, c in [
        ("entangled_ab",   circ_E_ab),
        ("entangled_abp",  circ_E_abp),
        ("entangled_apb",  circ_E_apb),
        ("entangled_apbp", circ_E_apbp),
        ("separable_ab",   circ_S_ab),
        ("separable_abp",  circ_S_abp),
        ("separable_apb",  circ_S_apb),
        ("separable_apbp", circ_S_apbp),
    ]:
        save_artifacts(c, stem)

    # Transpile every circuit against the backend target (ISA circuits).
    entangled_list = [pm.run(c) for c in [circ_E_ab, circ_E_abp, circ_E_apb, circ_E_apbp]]
    separable_list = [pm.run(c) for c in [circ_S_ab, circ_S_abp, circ_S_apb, circ_S_apbp]]

    # Run with SamplerV2 on hardware, SHOTS shots per circuit.
    SHOTS = 2000
    sampler = Sampler(mode=backend)

    # Submit both jobs before blocking on either result, so the second job
    # queues while the first one runs instead of waiting for it to finish.
    job_ent = sampler.run(entangled_list, shots=SHOTS)
    job_sep = sampler.run(separable_list, shots=SHOTS)

    res_ent = job_ent.result()  # PrimitiveResult
    res_sep = job_sep.result()

    # Counts extraction for V2 results: join_data().get_counts().
    def counts_of(idx, prim_result):
        try:
            return prim_result[idx].join_data().get_counts()
        except Exception as e:
            # Minimal fallback: tally a `bitstrings` attribute on the data
            # object, if one is exposed.
            d = getattr(prim_result[idx], "data", None)
            bs = getattr(d, "bitstrings", None) if d else None
            if bs is not None:
                acc = {}
                for bits in bs:
                    s = "".join(str(int(b)) for b in bits)
                    acc[s] = acc.get(s, 0) + 1
                return acc
            raise RuntimeError(f"Nu pot obține counts din rezultat: {e}") from e

    E_counts = [counts_of(i, res_ent) for i in range(4)]
    S_counts = [counts_of(i, res_sep) for i in range(4)]

    S_entangled = compute_chsh_from_four(*E_counts)
    S_separable = compute_chsh_from_four(*S_counts)

    out = {
        # Timezone-aware replacement for the deprecated datetime.utcnow();
        # the "+00:00" suffix is rewritten to keep the trailing-"Z" format.
        "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "backend": backend.name,
        "shots": SHOTS,
        "angles_rad": {"A0": A0, "A1": A1, "B0": B0, "B1": B1},
        "S_entangled": S_entangled,
        "S_separable": S_separable,
        "bounds": {"classical": 2.0, "tsirelson": 2.0 * math.sqrt(2.0)},
        "notes": "Așteptat: |S_entangled| > 2 (violare), |S_separable| ≤ 2 (control), cu variații datorate zgomotului hardware."
    }
    (OUTDIR / "results.json").write_text(json.dumps(out, indent=2), encoding="utf-8")

    # The header reports the actual shot count instead of a hard-coded number.
    print(f"\n=== Rezultate CHSH ({SHOTS} shots/circuit) ===")
    print(f"Backend: {backend.name}")
    print(f"S (entangled):  {S_entangled:.3f}")
    print(f"S (separable):  {S_separable:.3f}")
    print("Limite:          |S| ≤ 2 (clasic),  |S| ≤ 2√2 ≈ 2.828 (cuantic)")
    print(f"\nArtefacte salvate în: {OUTDIR.resolve()}")
    print(" - entangled_*.png / *.qasm")
    print(" - separable_*.png / *.qasm")
    print(" - results.json")
