import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import Callback
import time

# Configurare
FISIER = 'datebinare.txt'
SEQ_LENGTH = 50
BATCH_SIZE = 128
EPOCHS = 30

print("=" * 70)
print("🚀 PROGRAM DE ÎNVĂȚARE PE DATE BINARE".center(70))
print("=" * 70)

# Citește fișierul
print("\n📂 Citesc fișierul datebinare.txt...")
with open(FISIER, 'r') as f:
    date_binare = f.read().strip()
    date_binare = ''.join(c for c in date_binare if c in '01')

print(f"✅ Am citit {len(date_binare):,} biți")

# Pregătește datele
print("🔧 Pregătesc secvențele...")
date = np.array([int(bit) for bit in date_binare])

X, y = [], []
for i in range(len(date) - SEQ_LENGTH):
    X.append(date[i:i+SEQ_LENGTH])
    y.append(date[i+SEQ_LENGTH])

X = np.array(X).reshape(-1, SEQ_LENGTH, 1)
y = np.array(y)

split = int(0.8 * len(X))
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]

print(f"✅ {len(X_train):,} secvențe de antrenament, {len(X_val):,} validare")

# Construiește modelul
print("🧠 Construiesc rețeaua neuronală LSTM...")
model = Sequential([
    LSTM(128, return_sequences=True, input_shape=(SEQ_LENGTH, 1)),
    Dropout(0.2),
    LSTM(64),
    Dropout(0.2),
    Dense(32, activation='relu'),
    Dense(1, activation='sigmoid')
])

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print("✅ Model gata!\n")

# Configurare grafic live
plt.ion()  # Mode interactiv
fig, axes = plt.subplots(2, 1, figsize=(12, 10))
fig.suptitle('🎯 ÎNVĂȚARE LIVE - Predicție Biți Binari', fontsize=18, fontweight='bold')

# Liste pentru date
epoci = []
acc_train = []
acc_val = []
loss_train = []
loss_val = []

# Callback personalizat pentru grafic live
class GraficLive(Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Adaugă date
        epoci.append(epoch + 1)
        acc_train.append(logs['accuracy'] * 100)
        acc_val.append(logs['val_accuracy'] * 100)
        loss_train.append(logs['loss'])
        loss_val.append(logs['val_loss'])
        
        # Curăță graficele
        axes[0].clear()
        axes[1].clear()
        
        # GRAFIC 1: ACURATEȚE (procentaj)
        axes[0].plot(epoci, acc_train, 'o-', linewidth=3, markersize=8, 
                     color='#2ecc71', label='Antrenament', alpha=0.8)
        axes[0].plot(epoci, acc_val, 's-', linewidth=3, markersize=8, 
                     color='#3498db', label='Validare', alpha=0.8)
        axes[0].axhline(y=50, color='red', linestyle='--', linewidth=2, 
                        label='Random (50%)', alpha=0.7)
        
        axes[0].set_xlabel('Epocă', fontsize=14, fontweight='bold')
        axes[0].set_ylabel('Acuratețe (%)', fontsize=14, fontweight='bold')
        axes[0].set_title('📊 CÂT DE BINE PREZICE? (mai sus = mai bine)', 
                          fontsize=15, fontweight='bold', pad=15)
        axes[0].legend(fontsize=12, loc='lower right')
        axes[0].grid(True, alpha=0.3, linestyle='--')
        axes[0].set_ylim([45, max(65, max(acc_val) + 5)])
        
        # Adaugă text mare cu ultima acuratețe
        ultima_acc = acc_val[-1]
        culoare_text = '#27ae60' if ultima_acc > 52 else '#e74c3c' if ultima_acc < 48 else '#f39c12'
        axes[0].text(0.02, 0.98, f'ACURATEȚE: {ultima_acc:.1f}%', 
                     transform=axes[0].transAxes, fontsize=20, fontweight='bold',
                     verticalalignment='top', bbox=dict(boxstyle='round', 
                     facecolor=culoare_text, alpha=0.8), color='white')
        
        # Interpretare
        if ultima_acc > 55:
            mesaj = '✅ ÎNVAȚĂ! Există pattern-uri!'
            culoare = '#27ae60'
        elif ultima_acc > 52:
            mesaj = '⚠️ Învață puțin...'
            culoare = '#f39c12'
        else:
            mesaj = '❌ Nu învață - date aleatorii'
            culoare = '#e74c3c'
        
        axes[0].text(0.98, 0.98, mesaj, transform=axes[0].transAxes, 
                     fontsize=16, fontweight='bold', verticalalignment='top',
                     horizontalalignment='right', bbox=dict(boxstyle='round', 
                     facecolor=culoare, alpha=0.8), color='white')
        
        # GRAFIC 2: LOSS (eroare)
        axes[1].plot(epoci, loss_train, 'o-', linewidth=3, markersize=8,
                     color='#e74c3c', label='Antrenament', alpha=0.8)
        axes[1].plot(epoci, loss_val, 's-', linewidth=3, markersize=8,
                     color='#9b59b6', label='Validare', alpha=0.8)
        
        axes[1].set_xlabel('Epocă', fontsize=14, fontweight='bold')
        axes[1].set_ylabel('Loss (Eroare)', fontsize=14, fontweight='bold')
        axes[1].set_title('📉 EROARE (mai jos = mai bine)', 
                          fontsize=15, fontweight='bold', pad=15)
        axes[1].legend(fontsize=12, loc='upper right')
        axes[1].grid(True, alpha=0.3, linestyle='--')
        
        # Text cu loss-ul
        ultim_loss = loss_val[-1]
        axes[1].text(0.02, 0.98, f'LOSS: {ultim_loss:.4f}', 
                     transform=axes[1].transAxes, fontsize=20, fontweight='bold',
                     verticalalignment='top', bbox=dict(boxstyle='round', 
                     facecolor='#9b59b6', alpha=0.8), color='white')
        
        # Afișează progres în consolă
        print(f"Epocă {epoch+1}/{EPOCHS} | "
              f"Acuratețe: {ultima_acc:.2f}% | "
              f"Loss: {ultim_loss:.4f} | "
              f"{'✅ ÎNVAȚĂ!' if ultima_acc > 52 else '❌ Nu învață' if ultima_acc < 48 else '⚠️ Incert'}")
        
        plt.tight_layout()
        plt.draw()
        plt.pause(0.1)  # Pauză mică pentru a vedea graficul

print("\n" + "=" * 70)
print("🎓 ÎNCEPE ANTRENAMENTUL - Urmărește graficul!".center(70))
print("=" * 70 + "\n")

# Antrenează cu grafic live
history = model.fit(
    X_train, y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    callbacks=[GraficLive()],
    verbose=0  # Dezactivează output-ul default
)

plt.ioff()  # Dezactivează mode interactiv

# Rezultate finale
print("\n" + "=" * 70)
print("🏁 ANTRENAMENT FINALIZAT!".center(70))
print("=" * 70)

final_acc = acc_val[-1]
final_loss = loss_val[-1]

print(f"\n📊 REZULTATE FINALE:")
print(f"   Acuratețe Validare: {final_acc:.2f}%")
print(f"   Loss Validare: {final_loss:.4f}")

print(f"\n🔍 INTERPRETARE:")
if final_acc > 55:
    print("   ✅ SUCCES! Modelul a învățat pattern-uri clare în date!")
    print("   → Datele NU sunt complet aleatorii")
elif final_acc > 52:
    print("   ⚠️  Modelul a detectat pattern-uri slabe")
    print("   → Există o ușoară structură în date")
else:
    print("   ❌ Modelul nu a învățat nimic (acuratețe ~50%)")
    print("   → Datele sunt aproape complet aleatorii")
    print("   → Nicio rețea nu poate prezice un șir cu adevărat aleator")

# Salvează
model.save('model_predictie_binara.h5')
plt.savefig('evolutie_invatare_final.png', dpi=300, bbox_inches='tight')
print(f"\n💾 Salvat: model_predictie_binara.h5")
print(f"💾 Salvat: evolutie_invatare_final.png")

print("\n" + "=" * 70)
print("Apasă orice tastă pentru a închide graficul...".center(70))
print("=" * 70)
plt.show()
