import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from tqdm import tqdm


def load_data(file_path):
    """Read the training corpus as a single UTF-8 string."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()


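# Note: SimpleTokenizer below expects `vocab.pth` to hold a plain char -> index
# dict saved with torch.save. If you need to build one, a minimal sketch
# (hypothetical helper, not part of this script) could look like:
#
#     text = load_data('data.txt')
#     char_to_idx = {c: i for i, c in enumerate(sorted(set(text)))}
#     char_to_idx['<unk>'] = len(char_to_idx)
#     torch.save(char_to_idx, 'vocab.pth')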
class SimpleTokenizer:
    """Character-level tokenizer backed by a saved char -> index mapping."""

    def __init__(self, vocab_path):
        self.char_to_idx = torch.load(vocab_path)
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        # Unknown characters fall back to '<unk>' (or index 0 if no '<unk>' entry).
        return [self.char_to_idx.get(c, self.char_to_idx.get('<unk>', 0)) for c in text]

    def decode(self, indices):
        return ''.join(self.idx_to_char.get(i, '') for i in indices)


class TransformerModel(nn.Module):
    def __init__(self, vocab_size, emb_size=256, num_heads=4, num_layers=4,
                 ff_hid_dim=1024, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, emb_size))
        self.transformer_blocks = nn.ModuleList([
            # batch_first=True so inputs stay in (batch, seq, emb) order.
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads,
                                       dim_feedforward=ff_hid_dim, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output = nn.Linear(emb_size, vocab_size)

    def forward(self, x):
        # x: (batch, seq) token indices -> (batch, seq, vocab_size) logits.
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        # Causal mask keeps each position from attending to future tokens,
        # which next-token prediction requires.
        mask = nn.Transformer.generate_square_subsequent_mask(x.size(1)).to(x.device)
        for block in self.transformer_blocks:
            x = block(x, src_mask=mask)
        return self.output(x)


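# Quick shape sanity check (illustrative only; the vocab size of 100 is made up):
#
#     m = TransformerModel(vocab_size=100)
#     dummy = torch.randint(0, 100, (2, 64))    # (batch, seq_length)
#     print(m(dummy).shape)                     # torch.Size([2, 64, 100])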
def get_batches(data, batch_size, seq_length):
    """Yield (inputs, targets) batches where targets are inputs shifted by one token.

    Any final partial batch is dropped.
    """
    inputs, targets = [], []
    for i in range(0, len(data) - seq_length - 1, seq_length):
        x = data[i:i + seq_length]
        y = data[i + 1:i + 1 + seq_length]
        if len(x) == seq_length and len(y) == seq_length:
            inputs.append(x)
            targets.append(y)
        if len(inputs) == batch_size:
            yield (
                torch.tensor(inputs, dtype=torch.long),
                torch.tensor(targets, dtype=torch.long),
            )
            inputs, targets = [], []


def show_memory():
    """Return the current process's resident memory usage as a readable string."""
    process = psutil.Process()
    mem_info = process.memory_info()
    return f"{mem_info.rss / 1024**2:.2f} MB"


def train():
    # Hyperparameters and paths.
    batch_size = 64
    seq_length = 64
    num_epochs = 3
    lr = 0.001
    vocab_path = 'vocab.pth'
    data_path = 'data.txt'

    text = load_data(data_path)
    tokenizer = SimpleTokenizer(vocab_path)
    tokens = tokenizer.encode(text)
    # Size the model to the actual character vocabulary instead of a hard-coded 30000.
    vocab_size = len(tokenizer.char_to_idx)

    model = TransformerModel(vocab_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.train()

    for epoch in range(num_epochs):
        batches = list(get_batches(tokens, batch_size, seq_length))
        total = len(batches)
        total_loss = 0
        print(f"\n🧠 Epoch {epoch + 1}/{num_epochs}: {total} batches")

        with tqdm(total=total, desc="Training",
                  bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
            for step, (x, y) in enumerate(batches):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                output = model(x)
                # Flatten (batch, seq, vocab) logits and (batch, seq) targets for the loss.
                loss = criterion(output.view(-1, vocab_size), y.view(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                avg_loss = total_loss / (step + 1)

                if step % 10 == 0:
                    pbar.set_description(
                        f"Loss: {loss.item():.4f} | Avg: {avg_loss:.4f} | RAM: {show_memory()}"
                    )
                pbar.update(1)

        torch.save(model.state_dict(), f"model_epoch_{epoch + 1}.pth")
        print(f"💾 Model saved: model_epoch_{epoch + 1}.pth")


if __name__ == "__main__":
    train()