import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, Conv1D, MaxPooling1D, Flatten, MultiHeadAttention, LayerNormalization, GlobalAveragePooling1D
from tensorflow.keras.callbacks import Callback
import tensorflow as tf

# Configurare
FISIER = 'datebinare.txt'
SEQ_LENGTH = 50
BATCH_SIZE = 128
EPOCI_VERIFICARE = 1000  # După câte epoci verifică dacă învață

print("=" * 70)
print("🚀 PROGRAM DE ÎNVĂȚARE ADAPTIVĂ PE DATE BINARE/CUANTICE".center(70))
print("=" * 70)

# Citește fișierul
print("\n📂 Citesc fișierul datebinare.txt...")
with open(FISIER, 'r') as f:
    date_binare = f.read().strip()
    date_binare = ''.join(c for c in date_binare if c in '01')

print(f"✅ Am citit {len(date_binare):,} biți")

# Pregătește datele
print("🔧 Pregătesc secvențele...")
date = np.array([int(bit) for bit in date_binare])

X, y = [], []
for i in range(len(date) - SEQ_LENGTH):
    X.append(date[i:i+SEQ_LENGTH])
    y.append(date[i+SEQ_LENGTH])

X = np.array(X).reshape(-1, SEQ_LENGTH, 1)
y = np.array(y)

split = int(0.8 * len(X))
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]

print(f"✅ {len(X_train):,} secvențe de antrenament, {len(X_val):,} validare")

# Funcții pentru construirea diferitelor modele
def construieste_lstm():
    """Model LSTM - bun pentru secvențe temporale clasice"""
    print("🧠 MODEL 1: LSTM (Rețea Recurentă)")
    model = Sequential([
        LSTM(128, return_sequences=True, input_shape=(SEQ_LENGTH, 1)),
        Dropout(0.3),
        LSTM(64),
        Dropout(0.3),
        Dense(32, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model, "LSTM"

def construieste_transformer():
    """Model Transformer - bun pentru dependențe complexe pe distanțe lungi"""
    print("🔮 MODEL 2: TRANSFORMER (Attention Mechanism)")
    print("   → Mai bun pentru pattern-uri cuantice nelocale")
    
    inputs = tf.keras.Input(shape=(SEQ_LENGTH, 1))
    
    # Expand dimensions pentru attention
    x = Dense(64)(inputs)
    
    # Multi-head attention
    attn_output = MultiHeadAttention(num_heads=8, key_dim=64)(x, x)
    x = LayerNormalization()(x + attn_output)
    
    # Feed forward
    ff_output = Dense(128, activation='relu')(x)
    ff_output = Dense(64)(ff_output)
    x = LayerNormalization()(x + ff_output)
    
    # Global pooling și output
    x = GlobalAveragePooling1D()(x)
    x = Dense(32, activation='relu')(x)
    x = Dropout(0.3)(x)
    outputs = Dense(1, activation='sigmoid')(x)
    
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model, "TRANSFORMER"

def construieste_cnn():
    """Model CNN 1D - bun pentru pattern-uri locale și repetitive"""
    print("🌊 MODEL 3: CNN 1D (Convolutional Neural Network)")
    print("   → Caută pattern-uri fractale și periodicități")
    
    model = Sequential([
        Conv1D(128, kernel_size=3, activation='relu', input_shape=(SEQ_LENGTH, 1)),
        MaxPooling1D(pool_size=2),
        Conv1D(64, kernel_size=3, activation='relu'),
        MaxPooling1D(pool_size=2),
        Conv1D(32, kernel_size=3, activation='relu'),
        Flatten(),
        Dense(64, activation='relu'),
        Dropout(0.3),
        Dense(32, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model, "CNN 1D"

def construieste_hybrid():
    """Model Hibrid - CNN + LSTM pentru pattern-uri multi-nivel"""
    print("⚡ MODEL 4: HIBRID (CNN + LSTM)")
    print("   → Combinație pentru corelații cuantice complexe")
    
    model = Sequential([
        Conv1D(64, kernel_size=3, activation='relu', input_shape=(SEQ_LENGTH, 1)),
        MaxPooling1D(pool_size=2),
        LSTM(128, return_sequences=True),
        Dropout(0.3),
        LSTM(64),
        Dropout(0.3),
        Dense(32, activation='relu'),
        Dense(1, activation='sigmoid')
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model, "HIBRID"

# Inițializare grafic
plt.ion()
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
fig.suptitle('🎯 ÎNVĂȚARE ADAPTIVĂ LIVE - Predicție Date Cuantice', fontsize=18, fontweight='bold')

# Liste pentru date
epoci = []
acc_train = []
acc_val = []
loss_train = []
loss_val = []
schimbari_model = []  # Marchează când s-a schimbat modelul

# Variabile pentru tracking
model_curent = None
nume_model = ""
epoci_fara_progres = 0
ultima_acc_medie = 50.0
numar_schimbari = 0

# Lista de modele de încercat
modele_disponibile = [
    construieste_lstm,
    construieste_transformer,
    construieste_cnn,
    construieste_hybrid
]

# Callback pentru grafic live
class GraficLiveAdaptiv(Callback):
    def on_epoch_begin(self, epoch, logs=None):
        global model_curent, nume_model, epoci_fara_progres, ultima_acc_medie, numar_schimbari
        
        # Verifică dacă trebuie să schimbe modelul
        if len(acc_val) > 0 and len(acc_val) % EPOCI_VERIFICARE == 0:
            acc_medie_curenta = np.mean(acc_val[-100:]) if len(acc_val) >= 100 else np.mean(acc_val)
            
            # Dacă nu a progresat (< 1% îmbunătățire față de random)
            if acc_medie_curenta < 51.0:
                print("\n" + "!" * 70)
                print(f"⚠️  DUPĂ {len(epoci)} EPOCI: Nu învață (Acc: {acc_medie_curenta:.2f}%)")
                print("🔄 SCHIMB STRATEGIA DE ÎNVĂȚARE...")
                print("!" * 70 + "\n")
                
                numar_schimbari += 1
                
                # Alege următorul model
                if numar_schimbari < len(modele_disponibile):
                    model_curent, nume_model = modele_disponibile[numar_schimbari]()
                    self.model = model_curent
                    schimbari_model.append(len(epoci))
                else:
                    print("\n" + "=" * 70)
                    print("🛑 AM ÎNCERCAT TOATE MODELELE DISPONIBILE!".center(70))
                    print("=" * 70)
                    print("\n💡 CONCLUZIE:")
                    print("   Datele sunt CUANTIC ALEATORII (QRN - Quantum Random Numbers)")
                    print("   → Nicio rețea nu poate învăța un șir cu adevărat aleator")
                    print("   → Este matematic imposibil de prezis!")
                    print("\n   Aceasta este dovada că datele sunt autentice cuantice! 🎲")
                    print("\n⏸️  Continui antrenamentul pentru observație...")
                    print("   Apasă Ctrl+C pentru a opri\n")
    
    def on_epoch_end(self, epoch, logs=None):
        # Adaugă date
        epoci.append(len(epoci) + 1)
        acc_train.append(logs['accuracy'] * 100)
        acc_val.append(logs['val_accuracy'] * 100)
        loss_train.append(logs['loss'])
        loss_val.append(logs['val_loss'])
        
        # Curăță graficele
        axes[0].clear()
        axes[1].clear()
        
        # GRAFIC 1: ACURATEȚE
        axes[0].plot(epoci, acc_train, '-', linewidth=2, color='#2ecc71', label='Antrenament', alpha=0.6)
        axes[0].plot(epoci, acc_val, '-', linewidth=3, color='#3498db', label='Validare', alpha=0.9)
        axes[0].axhline(y=50, color='red', linestyle='--', linewidth=2, label='Random (50%)', alpha=0.7)
        
        # Marchează schimbările de model
        for schimbare in schimbari_model:
            axes[0].axvline(x=schimbare, color='orange', linestyle=':', linewidth=2, alpha=0.5)
        
        axes[0].set_xlabel('Epocă', fontsize=14, fontweight='bold')
        axes[0].set_ylabel('Acuratețe (%)', fontsize=14, fontweight='bold')
        axes[0].set_title(f'📊 ACURATEȚE - Model: {nume_model}', fontsize=15, fontweight='bold', pad=15)
        axes[0].legend(fontsize=11, loc='lower right')
        axes[0].grid(True, alpha=0.3, linestyle='--')
        axes[0].set_ylim([45, max(65, max(acc_val) + 5)])
        
        # Text cu acuratețe
        ultima_acc = acc_val[-1]
        culoare_text = '#27ae60' if ultima_acc > 52 else '#e74c3c' if ultima_acc < 48 else '#f39c12'
        axes[0].text(0.02, 0.98, f'ACURATEȚE: {ultima_acc:.1f}%', 
                     transform=axes[0].transAxes, fontsize=18, fontweight='bold',
                     verticalalignment='top', bbox=dict(boxstyle='round', 
                     facecolor=culoare_text, alpha=0.8), color='white')
        
        # Status
        if ultima_acc > 55:
            mesaj = '✅ ÎNVAȚĂ!'
            culoare = '#27ae60'
        elif ultima_acc > 52:
            mesaj = '⚠️ Învață puțin'
            culoare = '#f39c12'
        else:
            mesaj = '❌ Nu învață'
            culoare = '#e74c3c'
        
        axes[0].text(0.98, 0.98, mesaj, transform=axes[0].transAxes, 
                     fontsize=14, fontweight='bold', verticalalignment='top',
                     horizontalalignment='right', bbox=dict(boxstyle='round', 
                     facecolor=culoare, alpha=0.8), color='white')
        
        # GRAFIC 2: LOSS
        axes[1].plot(epoci, loss_train, '-', linewidth=2, color='#e74c3c', label='Antrenament', alpha=0.6)
        axes[1].plot(epoci, loss_val, '-', linewidth=3, color='#9b59b6', label='Validare', alpha=0.9)
        
        # Marchează schimbările de model
        for schimbare in schimbari_model:
            axes[1].axvline(x=schimbare, color='orange', linestyle=':', linewidth=2, alpha=0.5)
        
        axes[1].set_xlabel('Epocă', fontsize=14, fontweight='bold')
        axes[1].set_ylabel('Loss (Eroare)', fontsize=14, fontweight='bold')
        axes[1].set_title('📉 EROARE', fontsize=15, fontweight='bold', pad=15)
        axes[1].legend(fontsize=11, loc='upper right')
        axes[1].grid(True, alpha=0.3, linestyle='--')
        
        ultim_loss = loss_val[-1]
        axes[1].text(0.02, 0.98, f'LOSS: {ultim_loss:.4f}', 
                     transform=axes[1].transAxes, fontsize=18, fontweight='bold',
                     verticalalignment='top', bbox=dict(boxstyle='round', 
                     facecolor='#9b59b6', alpha=0.8), color='white')
        
        # Info în consolă
        if len(epoci) % 10 == 0:  # Afișează la fiecare 10 epoci
            print(f"Epocă {len(epoci)} | Model: {nume_model} | "
                  f"Acc: {ultima_acc:.2f}% | Loss: {ultim_loss:.4f} | "
                  f"{'✅' if ultima_acc > 52 else '❌' if ultima_acc < 48 else '⚠️'}")
        
        plt.tight_layout()
        plt.draw()
        plt.pause(0.01)

# Pornește cu primul model
print("\n" + "=" * 70)
print("🎓 ÎNCEPUT ANTRENAMENT - Rulează LA INFINIT!".center(70))
print("   Apasă Ctrl+C pentru a opri".center(70))
print("=" * 70 + "\n")

model_curent, nume_model = construieste_lstm()

try:
    # Antrenare infinită
    model_curent.fit(
        X_train, y_train,
        batch_size=BATCH_SIZE,
        epochs=999999,  # Practic infinit
        validation_data=(X_val, y_val),
        callbacks=[GraficLiveAdaptiv()],
        verbose=0
    )
except KeyboardInterrupt:
    print("\n\n" + "=" * 70)
    print("⏹️  ANTRENAMENT OPRIT DE UTILIZATOR".center(70))
    print("=" * 70)

plt.ioff()

# Rezultate finale
if len(acc_val) > 0:
    print(f"\n📊 STATISTICI FINALE:")
    print(f"   Total epoci rulate: {len(epoci)}")
    print(f"   Acuratețe finală: {acc_val[-1]:.2f}%")
    print(f"   Loss final: {loss_val[-1]:.4f}")
    print(f"   Modele încercate: {numar_schimbari + 1}")
    
    acc_medie = np.mean(acc_val[-100:]) if len(acc_val) >= 100 else np.mean(acc_val)
    print(f"\n🔍 CONCLUZIE:")
    if acc_medie > 55:
        print("   ✅ Datele CONȚIN PATTERN-URI detectabile!")
        print("   → NU sunt cuantic aleatorii pure")
    elif acc_medie > 52:
        print("   ⚠️ Există pattern-uri FOARTE SLABE")
        print("   → Posibil corelații cuantice subtile")
    else:
        print("   ❌ Datele sunt CUANTIC ALEATORII (QRN)")
        print("   → Imposibil de prezis - autenticitate cuantică confirmată!")
        print("   → Aceasta este comportamentul așteptat pentru date cuantice reale")

# Salvează
model_curent.save('model_final_cuantic.h5')
plt.savefig('evolutie_cuantica_final.png', dpi=300, bbox_inches='tight')
print(f"\n💾 Salvat: model_final_cuantic.h5")
print(f"💾 Salvat: evolutie_cuantica_final.png")

print("\n" + "=" * 70)
plt.show()
