from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
import numpy as np

def eraser_counts_qiskit(phi, erase=True, shots=8000, *, backend=None, seed=None):
    """Simulate a two-qubit quantum-eraser circuit and return measurement counts.

    Qubit 0 is the signal (S) and qubit 1 is the idler (I). The signal goes
    through a Hadamard ("first beam splitter"), picks up a phase ``phi`` via
    ``rz``, is entangled with the idler by a CNOT (which-path marking), and
    goes through a second Hadamard ("second beam splitter").

    Args:
        phi: Phase in radians applied to the signal with ``rz``.
        erase: If True, apply a Hadamard to the idler before measurement,
            i.e. measure the idler in the X basis (the "eraser"). If False,
            the idler is measured in the computational (Z) basis.
        shots: Number of simulator shots.
        backend: Optional simulator instance to reuse across calls; a new
            ``AerSimulator()`` is created when omitted.
        seed: Optional seed forwarded to the transpiler and the simulator so
            repeated calls give reproducible counts.

    Returns:
        The counts dict from ``Result.get_counts()``. Keys are bitstrings
        ``'<c1><c0>'``: idler bit first, signal bit second.
    """
    qc = QuantumCircuit(2, 2)
    qc.h(0)          # first beam splitter on S
    qc.rz(phi, 0)    # phase on S
    qc.cx(0, 1)      # which-path marking: entangle S with I
    qc.h(0)          # second beam splitter on S
    if erase:
        qc.h(1)      # measure I in the X basis (eraser)
    qc.measure(0, 0)  # S -> c0
    qc.measure(1, 1)  # I -> c1

    if backend is None:
        backend = AerSimulator()
    tqc = transpile(
        qc, backend=backend, optimization_level=3, seed_transpiler=seed
    )
    # Only pass seed_simulator when requested, so the default run is unchanged.
    run_options = {} if seed is None else {"seed_simulator": seed}
    res = backend.run(tqc, shots=shots, **run_options).result()
    return res.get_counts()

def conditional_prob(counts, i_bit='0'):
    """Estimate P(signal bit = 0 | idler bit = ``i_bit``) from shot counts.

    Args:
        counts: Mapping from bitstring to shot count. Each key is read as
            ``'<i><s>'`` (c1 = idler, then c0 = signal), e.g. ``'10'``.
            Whitespace register separators such as ``'1 0'`` are ignored.
        i_bit: Idler outcome to condition on, ``'0'`` or ``'1'``. Integers
            ``0``/``1`` are accepted and converted to strings.

    Returns:
        The fraction of shots with idler == ``i_bit`` whose signal bit is
        ``'0'``, or 0.0 when no shot had that idler outcome.
    """
    # Key characters are strings; an int 0/1 would otherwise never match.
    i_bit = str(i_bit)
    total = zeros = 0
    for key, n in counts.items():
        # Qiskit separates classical registers with spaces in count keys.
        bits = key.replace(" ", "")
        i, s = bits[0], bits[1]
        if i != i_bit:
            continue
        total += n
        if s == '0':
            zeros += n
    return zeros / total if total else 0.0

# Phase scan. With erase=True the signal statistics, conditioned on the
# idler's X-basis outcome, should show complementary fringes: ideally
# P(s=0|i=+) = cos^2(phi/2) and P(s=0|i=-) = sin^2(phi/2), up to shot noise.
phis = np.linspace(0, 2*np.pi, 13)
for phi in phis:
    cE = eraser_counts_qiskit(phi, erase=True, shots=4000)
    # After the Hadamard on the idler, outcome '0' corresponds to |+> and
    # outcome '1' to |->.
    p_plus, p_minus = (conditional_prob(cE, i_bit=bit) for bit in ('0', '1'))
    print(f"phi={phi:.2f}  P(s=0|i=+)={p_plus:.3f}  P(s=0|i=-)={p_minus:.3f}")
