import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt  # Pentru plot-uri opționale

# 1. Load or generate the data (replace the synthetic bits with your own file).
# Loading from a text file of '0'/'1' characters (any other character is skipped):
#     with open('datebinare.txt', 'r', encoding='utf-8') as f:
#         data_str = f.read().strip()
#     data = np.array([int(bit) for bit in data_str if bit in ('0', '1')])
SEED = 42
np.random.seed(SEED)     # reproducible synthetic data
torch.manual_seed(SEED)  # reproducible weight init / torch-side randomness; without this,
                         # every run trains a differently-initialized model
data = np.random.randint(0, 2, 7088)  # Synthetic stand-in: 7088 random bits
print(f'Lungime șir: {len(data)}')
print(f'Exemplu primii 20 biți: {data[:20]}')

# 2. Data preparation: turn the bit string into (window -> next bit) samples.
sequence_length = 10  # Window length (can be tuned, e.g. 5-20)
if len(data) <= sequence_length:
    # At least sequence_length + 1 bits are needed for one window plus its target.
    raise ValueError(
        f'Șirul are {len(data)} biți; sunt necesari cel puțin {sequence_length + 1}.'
    )
# Sample i is data[i : i + sequence_length]; its target is data[i + sequence_length].
# Windowing data[:-1] yields exactly len(data) - sequence_length windows, so every
# window has a following bit to predict. .copy() turns the strided view into a
# regular, writable array.
X = np.lib.stride_tricks.sliding_window_view(data[:-1], sequence_length).copy()
y = data[sequence_length:].copy()
X = X.reshape((X.shape[0], X.shape[1], 1))  # (samples, timesteps, features)

# Chronological split: first 80% train, last 20% test, no shuffling. Consecutive
# windows overlap by sequence_length - 1 bits, so a shuffled split would leak test
# bits into training.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=False)
print(f'Train: {len(X_train)} sample-uri, Test: {len(X_test)} sample-uri')

# 3. Dataset PyTorch
class BinaryDataset(Dataset):
    """Map-style dataset over pre-built (window, next-bit) pairs.

    Features are converted once to a float32 tensor and targets to an int64
    tensor; indexing returns the matching ``(features, target)`` pair.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.int64)
        self.X = features
        self.y = targets

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Wrap the splits as datasets and batch them.
train_dataset = BinaryDataset(X_train, y_train)
test_dataset = BinaryDataset(X_test, y_test)
# Training batches are reshuffled every epoch. This does not break temporal order
# where it matters: each window is a self-contained sample (LSTMModel calls
# self.lstm(x) without a carried hidden state), and the train/test split above is
# already chronological. Without shuffling, every mini-batch is 32 consecutive,
# heavily overlapping windows. The fixed-seed generator keeps the order reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)  # evaluation keeps chronological order

# 4. Model LSTM
class LSTMModel(nn.Module):
    """Stacked LSTM mapping a window of bits to P(next bit == 1).

    Input: float tensor of shape (batch, timesteps, input_size) (``batch_first=True``).
    Output: tensor of shape (batch, 1) with values in (0, 1) after the sigmoid.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2):
        super().__init__()
        # The attribute names lstm / fc prefix the state_dict keys used when saving.
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Only the top layer's output at the final timestep feeds the classifier head.
        sequence_out = self.lstm(x)[0]
        last_step = sequence_out[:, -1, :]
        logit = self.fc(last_step)
        return self.sigmoid(logit)

# Build the network (default sizes), its loss and its optimizer.
model = LSTMModel()
# BCELoss consumes probabilities; LSTMModel.forward already ends in a sigmoid.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Prints the heading line, then the module's repr on the following lines.
print("Model construit (sumar):", model, sep="\n")

# 5. Training
epochs = 20  # Raise to 50 for more convergence
train_losses = []  # mean per-sample training loss, one entry per epoch
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    n_seen = 0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        # squeeze(-1) drops only the trailing feature dim: (B, 1) -> (B,).
        # A bare squeeze() also drops the batch dim when the final batch has size 1,
        # and BCELoss then rejects the shape mismatch with the (1,) target.
        outputs = model(batch_x).squeeze(-1)
        loss = criterion(outputs, batch_y.float())
        loss.backward()
        optimizer.step()
        # BCELoss averages over the batch; re-weight by batch size so a short final
        # batch does not count as much as a full one in the epoch mean.
        batch_size = batch_x.size(0)
        total_loss += loss.item() * batch_size
        n_seen += batch_size
    avg_loss = total_loss / n_seen
    train_losses.append(avg_loss)
    if (epoch + 1) % 5 == 0:
        print(f'Epoch {epoch+1}/{epochs}, Loss mediu: {avg_loss:.4f}')

# 6. Evaluation on the held-out (chronologically later) test windows
model.eval()
pred_batches = []
with torch.no_grad():
    for batch_x, _ in test_loader:
        # squeeze(-1): (B, 1) -> (B,), which stays 1-D even when B == 1
        probs = model(batch_x).squeeze(-1)
        pred_batches.append((probs > 0.5).long())  # 0.5 threshold for the binary decision
y_pred = torch.cat(pred_batches).numpy()
acc = accuracy_score(y_test, y_pred)
print(f'\nAcuratețe pe test: {acc:.4f} (peste 0.5 înseamnă pattern-uri capturate)')
# Majority-class baseline: accuracy of always predicting the more frequent bit.
# On imbalanced data this already exceeds 0.5, so compare acc against it.
positive_rate = float(np.mean(y_test))
baseline = max(positive_rate, 1.0 - positive_rate)
print(f'Baseline (clasa majoritară): {baseline:.4f}')

# 7. Example predictions (first few test windows)
print('\nExemplu 5 predicții:')
n_examples = min(5, len(X_test))  # the test split may hold fewer than 5 windows
model.eval()  # inference mode, regardless of what ran before
with torch.no_grad():
    for i in range(n_examples):
        input_seq = X_test[i]
        # unsqueeze(0) adds the batch dim: (timesteps, 1) -> (1, timesteps, 1)
        pred_prob = model(torch.FloatTensor(input_seq).unsqueeze(0)).item()
        pred = 1 if pred_prob > 0.5 else 0
        real = y_test[i]
        print(f'Input: {input_seq.flatten()} -> Predicție: {pred} (prob: {pred_prob:.3f}, Real: {real})')

# 8. Training-loss curve (optional)
# Epochs are numbered from 1 on the x-axis, matching the "Epoch k/N" training log.
epoch_numbers = range(1, len(train_losses) + 1)
plt.figure(figsize=(6, 4))
plt.plot(epoch_numbers, train_losses, label='Train Loss')
plt.title('Evoluție Loss în Antrenare')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Optional: persist only the learned weights (state_dict), not the module object.
model_path = 'lstm_binary_model.pth'
torch.save(model.state_dict(), model_path)
print(f"\nModel salvat ca '{model_path}'. Încărcare: model.load_state_dict(torch.load('{model_path}'))")
