import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# 1. Load the binary data: read 'datebinare.txt' if present, otherwise
# fall back to a synthetic random bit-stream for quick testing.
try:
    with open('datebinare.txt', 'r') as fh:
        raw_text = fh.read().strip()
    # Keep only valid '0'/'1' characters and convert them to ints.
    bits = [int(ch) for ch in raw_text if ch in ('0', '1')]
    data = np.array(bits)
    print("Date binare citite din fișier.")
except FileNotFoundError:
    data = np.random.randint(0, 2, 65544)  # synthetic fallback for testing
    print("Folosit date sintetice.")

print(f'Lungime șir: {len(data)}')
print(f'Exemplu primii 20 biți: {data[:20]}')

# 2. Build a sliding-window dataset: each sample is `sequence_length`
# consecutive bits, labeled with the bit that immediately follows it.
sequence_length = 256  # long context window for long-range dependencies
num_samples = len(data) - sequence_length
X = np.array([data[i:i + sequence_length] for i in range(num_samples)])
y = np.array([data[i + sequence_length] for i in range(num_samples)])
print(f'Număr sample-uri: {len(X)}')

# Chronological 80/20 split; no shuffling because samples are temporal.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=False)
print(f'Train: {len(X_train)}, Test: {len(X_test)}')

# PyTorch Dataset wrapper
class BinaryDataset(Dataset):
    """Exposes (X, y) arrays of 0/1 values as LongTensor pairs for a DataLoader."""

    def __init__(self, X, y):
        # Bits double as token ids (0 or 1), hence integer tensors.
        self.X = torch.LongTensor(X)
        self.y = torch.LongTensor(y)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Wrap each split in a Dataset and batch it; shuffle=False keeps the
# temporal ordering of the sliding-window samples intact.
train_dataset, test_dataset = BinaryDataset(X_train, y_train), BinaryDataset(X_test, y_test)
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

# 3. Simple binary Transformer model (small GPT-style decoder-only)
class BinaryTransformer(nn.Module):
    """Causal Transformer over a two-token vocabulary (bits 0/1).

    Given a window of up to `max_len` bits, returns logits for the next bit.

    Args:
        vocab_size: number of distinct tokens (2 for bits).
        d_model: embedding / hidden dimension.
        nhead: number of attention heads.
        num_layers: number of stacked decoder layers.
        max_len: maximum supported sequence length (positional table size).
    """

    def __init__(self, vocab_size=2, d_model=128, nhead=4, num_layers=4, max_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional encoding, zero-initialized.
        self.pos_encoding = nn.Parameter(torch.zeros(1, max_len, d_model))
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """x: (batch, seq_len) LongTensor of bits -> (batch, vocab_size) logits."""
        seq_len = x.size(1)
        emb = self.embedding(x) + self.pos_encoding[:, :seq_len, :]
        # Causal mask so no position attends to later positions.
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        # BUG FIX: a TransformerDecoderLayer has BOTH self-attention (tgt_mask)
        # and cross-attention over `memory` (memory_mask). The original only
        # masked self-attention, so the unmasked cross-attention over the same
        # embeddings let every position see the entire sequence, silently
        # defeating the causal mask. Mask both paths.
        out = self.transformer(emb, emb, tgt_mask=causal_mask, memory_mask=causal_mask)
        # Only the last position's representation is used to predict the next bit.
        return self.fc(out[:, -1, :])

# Instantiate the model with its defaults, plus loss and optimizer.
model = BinaryTransformer()
criterion = nn.CrossEntropyLoss()  # logits vs. integer class labels (0/1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print("Model Transformer binar construit.")

# 4. Quick training run (5 epochs, intended as a smoke test)
epochs = 5
train_losses = []
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        logits = model(xb)
        batch_loss = criterion(logits, yb)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    # Mean loss over all batches in this epoch.
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

# 5. Evaluation on the held-out split
model.eval()
y_pred = []
with torch.no_grad():
    for batch_x, _ in test_loader:
        outputs = model(batch_x)
        # BUG FIX: argmax over the class dim already yields a 1-D tensor of
        # length batch_size. The original `.squeeze()` would collapse a
        # size-1 final batch to a 0-d array, which `extend` cannot iterate.
        # Also move to CPU explicitly in case the model ever runs on GPU.
        preds = outputs.argmax(-1).cpu().numpy()
        y_pred.extend(preds.tolist())
y_pred = np.array(y_pred)
acc = (y_pred == y_test).mean()
print(f'Acuratețe finală pe test: {acc:.4f} (peste 0.5 = pattern-uri capturate)')

# 6. Plot the training-loss curve
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set_title('Evoluție Loss Transformer Binar')
ax.set_xlabel('Epoci')
ax.set_ylabel('Loss')
plt.show()

# Persist the trained weights to disk.
checkpoint_path = 'binary_transformer_model.pth'
torch.save(model.state_dict(), checkpoint_path)
print("Model salvat. Pentru predicții: model.load_state_dict(torch.load('binary_transformer_model.pth'))")
