import numpy as np
from flask import Flask, jsonify, request, render_template

# Opțional: QuTiP pentru modul „quantum”
try:
    import qutip as qt
    HAS_QUTIP = True
except Exception:
    HAS_QUTIP = False

app = Flask(__name__)


def intensity_simple(x, V, phi):
    """Model interferență simplu: I(x) ~ 1 + V cos(2x + φ)."""
    I0 = 1.0
    I = I0 * (1.0 + V * np.cos(2.0 * x + phi))
    return np.clip(I, 0.0, None)


def add_lab_noise(I, noise_level=0.2, phase_jitter=0.3):
    """
    Adaugă zgomot „laborator”:
      - fază care variază ușor;
      - zgomot multiplicativ și aditiv.
    """
    n = len(I)
    x = np.linspace(-np.pi, np.pi, n)

    # jitter de fază (simulăm vibrații / instabilitate)
    phi_rand = np.random.normal(0.0, phase_jitter)
    I_jitter = intensity_simple(x, 1.0, phi_rand)

    # zgomot multiplicativ (fluctuații de intensitate)
    mult_noise = 1.0 + noise_level * np.random.normal(0.0, 1.0, size=n)

    # zgomot aditiv (dark counts etc)
    add_noise = noise_level * 0.5 * np.random.rand(n)

    I_noisy = I * mult_noise + add_noise + 0.1 * I_jitter
    return np.clip(I_noisy, 0.0, None)


def intensity_quantum_with_qutip(x, V, phi):
    """
    Exemplu minimal de utilizare QuTiP.
    Nu construim un model complet multi-mod,
    ci folosim QuTiP pentru a ilustra cum am trata
    amplitudinile a două căi coerente.

    În esență, rezultatul e tot un I(x) ~ 1 + V cos(2x+φ),
    dar cu amplitudini construite explicit ca vectori de stare.
    """
    if not HAS_QUTIP:
        return intensity_simple(x, V, phi)

    # spațiu Hilbert pentru „două căi” (0 = cristal 1, 1 = cristal 2)
    zero = qt.basis(2, 0)
    one = qt.basis(2, 1)

    # stare superpusă |ψ> = (|0> + e^{iφ}|1>)/√2
    psi = (zero + np.exp(1j * phi) * one).unit()

    # operator de „detecție” pentru interferență:
    # proiectăm pe baza (|0> ± |1>)/√2 în funcție de x
    I = []
    for xx in x:
        # fază dependentă de poziție (2x)
        phase_x = 2.0 * xx
        plus = (zero + np.exp(1j * phase_x) * one).unit()
        # intensitatea este ~ |<plus|ψ>|^2, cu vizibilitate V
        amp = (plus.dag() * psi)[0, 0]
        prob = np.abs(amp) ** 2
        # interpolăm între uniform și prob pentru a impune V
        I_point = 1.0 * (1.0 - V) + V * (2.0 * prob)
        I.append(I_point)

    I = np.array(I)
    return np.clip(I, 0.0, None)


def sample_hits(x, I, n_photons, efficiency=1.0):
    """Eșantionăm „fotoni” după distribuția I(x), cu eficiență de detecție < 1."""
    n_photons = max(10, int(n_photons))
    eff = max(0.01, min(1.0, efficiency))

    prob = I / np.sum(I)
    idx = np.random.choice(len(x), size=n_photons, p=prob)
    hits = x[idx]

    # modelăm eficiența detectorului (pierderi)
    keep_mask = np.random.rand(len(hits)) < eff
    hits_eff = hits[keep_mask]
    return hits_eff


def estimate_visibility(I):
    I_max = float(np.max(I))
    I_min = float(np.min(I))
    if I_max + I_min == 0:
        return 0.0
    return (I_max - I_min) / (I_max + I_min)


@app.route("/")
def index():
    return render_template("index.html", has_qutip=HAS_QUTIP)


@app.route("/simulate", methods=["POST"])
def simulate():
    data = request.get_json(force=True)

    mode = data.get("mode", "simple")  # "simple", "quantum", "lab"
    V_base = float(data.get("visibility", 1.0))
    phi = float(data.get("phase", 0.0))
    n_photons = int(data.get("photons", 2000))
    blocked = bool(data.get("blocked", False))
    noise_level = float(data.get("noise", 0.2))
    efficiency = float(data.get("efficiency", 0.8))

    # Dacă Brațul 2 este blocat ⇒ V_eff = 0
    if blocked:
        V_eff = 0.0
    else:
        V_eff = max(0.0, min(1.0, V_base))

    x = np.linspace(-np.pi, np.pi, 600)

    # Alegem modelul de intensitate
    if mode == "quantum":
        I = intensity_quantum_with_qutip(x, V_eff, phi)
    else:
        I = intensity_simple(x, V_eff, phi)

    # Adăugăm „mod laborator” dacă e ales
    if mode == "lab":
        I = add_lab_noise(I, noise_level=noise_level, phase_jitter=noise_level)

    hits = sample_hits(x, I, n_photons, efficiency=efficiency)
    V_meas = estimate_visibility(I)

    return jsonify({
        "x": x.tolist(),
        "I": I.tolist(),
        "hits": hits.tolist(),
        "V_meas": V_meas,
        "mode": mode,
        "blocked": blocked,
        "V_eff": V_eff,
        "has_qutip": HAS_QUTIP
    })


@app.route("/message_demo", methods=["POST"])
def message_demo():
    data = request.get_json(force=True)
    bits_str = data.get("bits", "10110")
    mode = data.get("mode", "simple")
    V_base = float(data.get("visibility", 1.0))
    phi = float(data.get("phase", 0.0))
    n_photons = int(data.get("photons", 2000))
    noise_level = float(data.get("noise", 0.2))
    efficiency = float(data.get("efficiency", 0.8))
    threshold = float(data.get("threshold", 0.4))

    bits = [b for b in bits_str if b in ("0", "1")]
    if not bits:
        return jsonify({"error": "Niciun bit valid (0/1) în șir."}), 400

    results = []
    decoded_bits = []

    x = np.linspace(-np.pi, np.pi, 600)

    for b in bits:
        blocked = (b == "0")
        V_eff = 0.0 if blocked else V_base

        if mode == "quantum":
            I = intensity_quantum_with_qutip(x, V_eff, phi)
        else:
            I = intensity_simple(x, V_eff, phi)

        if mode == "lab":
            I = add_lab_noise(I, noise_level=noise_level, phase_jitter=noise_level)

        hits = sample_hits(x, I, n_photons, efficiency=efficiency)
        V_meas = estimate_visibility(I)
        decoded = "1" if V_meas >= threshold else "0"
        decoded_bits.append(decoded)

        results.append({
            "bit_sent": b,
            "blocked": blocked,
            "V_eff": V_eff,
            "V_meas": V_meas,
            "decoded": decoded,
            "n_hits": int(len(hits))
        })

    errors = sum(1 for a, b in zip(bits, decoded_bits) if a != b)

    return jsonify({
        "bits_sent": "".join(bits),
        "bits_decoded": "".join(decoded_bits),
        "errors": errors,
        "per_bit": results
    })


if __name__ == "__main__":
    # Rulați local: http://127.0.0.1:5000
    app.run(debug=True)
