import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import json
import os
import signal
import sys

# Ctrl+C handler: persist the current model weights and metrics before exiting.
def signal_handler(sig, frame):
    """Save the model state and accumulated results, then exit with code 0.

    Guards against a Ctrl+C that arrives before `model`/`results` have been
    created (i.e. during data preparation), which previously raised NameError
    instead of exiting cleanly.
    """
    print('\nOprire... Salvând modelul final.')
    current_model = globals().get('model')
    if current_model is not None:
        torch.save(current_model.state_dict(), 'hybrid_lstm_new_final.pth')
    current_results = globals().get('results')
    if current_results is not None:
        with open('hybrid_new_results.json', 'w') as f:
            json.dump(current_results, f, indent=4)
    sys.exit(0)

signal.signal(signal.SIGINT, signal_handler)

# 1. Load the new binary string (synthetic fallback when the file is missing).
try:
    with open('datebinare.txt', 'r') as f:
        raw_text = f.read().strip()
    # Keep only '0'/'1' characters; everything else (whitespace etc.) is dropped.
    data = np.array([int(ch) for ch in raw_text if ch in ('0', '1')])
    print("Date noi citite din fișier.")
except FileNotFoundError:
    # Deterministic synthetic bits, sized to match the new real dataset.
    np.random.seed(42)
    data = np.random.randint(0, 2, 65544)
    print("Fișier nu găsit - folosesc date sintetice.")

print(f'Lungime șir nou: {len(data)}')
print(f'Exemplu primii 20 biți: {data[:20]}')

# 2. Build overlapping windows of `sequence_length` consecutive bits.
sequence_length = 10  # could be raised to 12-15 to capture longer patterns
window_count = len(data) - sequence_length
X_unsup = np.array([data[start:start + sequence_length] for start in range(window_count)])
print(f'Număr secvențe noi: {len(X_unsup)}')

# Standardize each window position to zero mean / unit variance.
scaler = StandardScaler()
X_unsup_scaled = scaler.fit_transform(X_unsup)

# 3. Unsupervised feature extraction over the scaled windows.
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
cluster_labels = kmeans.fit_predict(X_unsup_scaled)

# Two principal components act as continuous summary features.
pca = PCA(n_components=2)
pca_features = pca.fit_transform(X_unsup_scaled)
explained = pca.explained_variance_ratio_.sum()
print(f'PCA explică {explained:.2%} din variație nouă.')

# One-hot encode each window's cluster id (dense output, one column per cluster).
encoder = OneHotEncoder(sparse_output=False)
cluster_onehot = encoder.fit_transform(cluster_labels[:, np.newaxis])

# 4. Supervised dataset: each sample is the raw 10-bit window concatenated with
# its one-hot cluster id (5) and 2 PCA coordinates -> 17 features; the target is
# the bit immediately following the window.
# Vectorized: reuses the windows already built in X_unsup instead of re-slicing
# `data` row by row in a Python loop (same values, dtype and ordering).
X_sup = np.hstack([
    X_unsup.astype(float),  # raw windows, cast to float as the per-row concat did
    cluster_onehot,
    pca_features,
])
y = data[sequence_length:]  # next bit after every window

# Chronological split; shuffle=False preserves temporal order (random_state is
# ignored when shuffle=False).
X_train, X_test, y_train, y_test = train_test_split(X_sup, y, test_size=0.2, random_state=42, shuffle=False)
print(f'Train: {len(X_train)}, Test: {len(X_test)}')

# Dataset pairing each 17-feature row with its next-bit label.
class HybridDataset(Dataset):
    """Torch Dataset over the hybrid feature matrix and binary targets."""

    def __init__(self, X, y):
        # Features as float32 for the LSTM, labels as int64.
        self.features = torch.FloatTensor(X)
        self.labels = torch.LongTensor(y)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Wrap both splits in Datasets and batched loaders; shuffle=False keeps
# the chronological ordering of the sequences.
train_dataset = HybridDataset(X_train, y_train)
test_dataset = HybridDataset(X_test, y_test)
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 5. Improved model
class HybridLSTM(nn.Module):
    """LSTM binary classifier over the 17-dim hybrid feature vector.

    Accepts (batch, features) or (batch, seq, features) input and returns a
    sigmoid probability of shape (batch, 1).
    """

    def __init__(self, input_size=17, hidden_size=64, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.3)  # regularization between LSTM and head
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Promote (batch, features) to a length-1 sequence.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        lstm_out, _ = self.lstm(x)
        last_step = lstm_out[:, -1, :]  # classify from the final time step
        return self.sigmoid(self.fc(self.dropout(last_step)))

model = HybridLSTM()
criterion = nn.BCELoss()  # model already emits sigmoid probabilities
# Adam with weight decay acts as L2 regularization.
optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)
# Shrink the LR by 20% after 20 plateaued evaluations (no `verbose` kwarg).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.8, patience=20)
print("Model hibrid nou construit (corectat).")

# 6. Continuous training with periodic evaluation, checkpoints and early stopping.
train_losses = []
epoch = 0
results = {'epochs': [], 'losses': [], 'accuracies': []}
prev_acc = 0.5       # best test accuracy observed so far
stagnare_count = 0   # consecutive evaluations (one per 10 epochs) without improvement
print("Antrenare continuă pe date noi. Oprește cu Ctrl+C sau early stopping.")

while True:
    epoch += 1
    model.train()
    total_loss = 0.0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        # squeeze(1), not squeeze(): a bare squeeze() collapses a batch of
        # size 1 to a 0-dim tensor and breaks BCELoss shape matching.
        outputs = model(batch_x).squeeze(1)
        loss = criterion(outputs, batch_y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate on the held-out split every 10 epochs.
    if epoch % 10 == 0:
        model.eval()
        y_pred = []
        with torch.no_grad():
            for batch_x, _ in test_loader:
                outputs = model(batch_x).squeeze(1)
                preds = (outputs > 0.5).float().numpy()
                y_pred.extend(preds)
        y_pred = np.array(y_pred)
        acc = accuracy_score(y_test, y_pred)
        print(f'Epoch {epoch}, Loss: {avg_loss:.4f}, Acuratețe test: {acc:.4f}')

        results['epochs'].append(epoch)
        results['losses'].append(float(avg_loss))
        results['accuracies'].append(float(acc))
        # Persist metrics after every evaluation so a crash loses nothing.
        with open('hybrid_new_results.json', 'w') as f:
            json.dump(results, f, indent=4)

        scheduler.step(avg_loss)

        # Early stopping: stop after 30 evaluations (300 epochs) without
        # any accuracy improvement.
        if acc > prev_acc:
            prev_acc = acc
            stagnare_count = 0
        else:
            stagnare_count += 1
        if stagnare_count >= 30:
            print(f'Stagnare detectată după {stagnare_count} evaluări consecutive. Oprire automată.')
            break

    # Checkpoint weights and dump a loss curve every 100 epochs.
    if epoch % 100 == 0:
        checkpoint_path = f'hybrid_lstm_new_epoch_{epoch}.pth'
        torch.save(model.state_dict(), checkpoint_path)
        print(f'Checkpoint salvat: {checkpoint_path}')

        plt.figure(figsize=(10, 6))
        plt.plot(train_losses, label='Train Loss')
        plt.title(f'Evoluție Loss pe Date Noi (Epoch {epoch})')
        plt.xlabel('Epoci')
        plt.ylabel('Loss')
        plt.legend()
        plt.savefig(f'loss_new_plot_epoch_{epoch}.png')
        # close, not show: show() blocks an unattended training loop and the
        # never-closed figures leak memory across checkpoints.
        plt.close()