import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import json
import os
import signal
import sys

# Ctrl+C handler (save on shutdown)
def signal_handler(sig, frame):
    print('\nStopping training... Saving the final model.')
    # Guard against Ctrl+C before model/results exist (the handler is registered early)
    if 'model' in globals():
        torch.save(model.state_dict(), 'hybrid_lstm_final.pth')
    if 'results' in globals():
        with open('hybrid_results.json', 'w') as f:
            json.dump(results, f, indent=4)
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

# 1. Read data
try:
    with open('datebinare.txt', 'r') as f:
        data_str = f.read().strip()
    data = np.array([int(bit) for bit in data_str if bit in ['0', '1']])
    print("Date citite din fișier.")
except FileNotFoundError:
    np.random.seed(42)
    data = np.random.randint(0, 2, 11776)  # matched to the original file's length
    print("File not found - using synthetic data.")

print(f'Bit string length: {len(data)}')
print(f'First 20 bits: {data[:20]}')

# 2. Prepare sequences for the unsupervised stage
sequence_length = 10
X_unsup = []
for i in range(len(data) - sequence_length):
    X_unsup.append(data[i:i + sequence_length])
X_unsup = np.array(X_unsup)
print(f'Number of sequences: {len(X_unsup)}')
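# Quick sanity check (illustrative, not required): every row should be a
# sliding 10-bit window, with row i starting at bit i of the original stream.
assert X_unsup.shape == (len(data) - sequence_length, sequence_length)
assert np.array_equal(X_unsup[0], data[:sequence_length])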

# Normalization
scaler = StandardScaler()
X_unsup_scaled = scaler.fit_transform(X_unsup)

# 3. Unsupervised: KMeans + PCA
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(X_unsup_scaled)

pca = PCA(n_components=2)
pca_features = pca.fit_transform(X_unsup_scaled)
print(f'PCA explains {pca.explained_variance_ratio_.sum():.2%} of the variance.')
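# Optional sketch: n_clusters=5 is a free choice here; the elbow heuristic on
# KMeans inertia is one way to sanity-check it (left commented out since it
# refits KMeans several times and is slow on large inputs):
# for k in range(2, 9):
#     km = KMeans(n_clusters=k, random_state=42, n_init=10).fit(X_unsup_scaled)
#     print(f'k={k}, inertia={km.inertia_:.1f}')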

# One-hot encode the cluster labels
encoder = OneHotEncoder(sparse_output=False)
cluster_onehot = encoder.fit_transform(cluster_labels.reshape(-1, 1))
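# Sanity check (illustrative): each one-hot row has n_clusters columns and
# exactly one active entry.
assert cluster_onehot.shape == (len(cluster_labels), n_clusters)
assert np.allclose(cluster_onehot.sum(axis=1), 1.0)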

# 4. Prepare data for the supervised stage
X_sup, y = [], []
for i in range(len(data) - sequence_length):
    seq = data[i:i + sequence_length].astype(float)
    cluster_feat = cluster_onehot[i]
    pca_feat = pca_features[i]
    combined = np.concatenate([seq, cluster_feat, pca_feat])  # 10 + 5 + 2 = 17 features
    X_sup.append(combined)
    y.append(data[i + sequence_length])
X_sup, y = np.array(X_sup), np.array(y)

# Chronological split: shuffle=False keeps the test set strictly after the
# train set, so random_state has no effect and is dropped
X_train, X_test, y_train, y_test = train_test_split(X_sup, y, test_size=0.2, shuffle=False)
print(f'Train: {len(X_train)}, Test: {len(X_test)}')
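# Baseline (illustrative): always predicting the majority class of the training
# labels gives a floor the LSTM should beat; on near-random bits this sits near 0.5.
majority = int(np.bincount(y_train).argmax())
baseline_acc = accuracy_score(y_test, np.full_like(y_test, majority))
print(f'Majority-class baseline accuracy: {baseline_acc:.4f}')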

# Dataset PyTorch
class HybridDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor(y)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_dataset = HybridDataset(X_train, y_train)
test_dataset = HybridDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
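# Shape check (illustrative): one batch should yield (batch, 17) feature
# vectors and (batch,) labels, matching HybridLSTM's input_size below.
xb, yb = next(iter(train_loader))
print(f'Batch shapes: X={tuple(xb.shape)}, y={tuple(yb.shape)}')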

# 5. Hybrid LSTM model (hardened for long training runs)
class HybridLSTM(nn.Module):
    def __init__(self, input_size=17, hidden_size=64, num_layers=1):
        super(HybridLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.2)  # dropout for regularization
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        if x.dim() == 2:  # if input is flat (batch, features), add a seq_len=1 dimension
            x = x.unsqueeze(1)
        out, _ = self.lstm(x)
        out = self.dropout(out[:, -1, :])  # dropout on the last timestep's output
        out = self.fc(out)
        return self.sigmoid(out)

model = HybridLSTM()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)  # smaller LR for stability
print("Hybrid continuous model built.")

# 6. Continuous training
train_losses = []
epoch = 0
results = {'epochs': [], 'losses': [], 'accuracies': []}
print("Încep antrenarea continuă. Oprește cu Ctrl+C.")

while True:  # infinite loop; exits only via the SIGINT handler above
    epoch += 1
    model.train()
    total_loss = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x).squeeze(-1)  # squeeze(-1) keeps a 1-D shape even for a batch of one
        loss = criterion(outputs, batch_y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)
    
    # Evaluate test accuracy every 5 epochs
    if epoch % 5 == 0:
        model.eval()
        y_pred = []
        with torch.no_grad():
            for batch_x, _ in test_loader:
                outputs = model(batch_x).squeeze(-1)  # squeeze(-1) avoids collapsing a batch of one to a scalar
                preds = (outputs > 0.5).float().numpy()
                y_pred.extend(preds)
        y_pred = np.array(y_pred)
        acc = accuracy_score(y_test, y_pred)
        print(f'Epoch {epoch}, Loss: {avg_loss:.4f}, Test accuracy: {acc:.4f}')
        
        # Update and persist results
        results['epochs'].append(epoch)
        results['losses'].append(float(avg_loss))
        results['accuracies'].append(float(acc))
        with open('hybrid_results.json', 'w') as f:
            json.dump(results, f, indent=4)
    
    # Checkpoint and loss plot every 50 epochs
    if epoch % 50 == 0:
        checkpoint_path = f'hybrid_lstm_epoch_{epoch}.pth'
        torch.save(model.state_dict(), checkpoint_path)
        print(f'Checkpoint saved: {checkpoint_path}')
        
        # Plot loss
        plt.figure(figsize=(10, 6))
        plt.plot(train_losses, label='Train Loss')
        plt.title(f'Training Loss after {epoch} Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(f'loss_plot_epoch_{epoch}.png')  # save as image
        plt.close()  # close the figure so the infinite training loop is not blocked by a window

# (Unreachable: the script stops only via Ctrl+C, handled by signal_handler above.)
