Adapters
import torch
import torch.nn as nn
import torch.optim as optim
import psutil
from tqdm import tqdm

# Load your data
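# Reads the whole training corpus into memory as a single UTF-8 string.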
def load_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return f.read()

# Tokenizer
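# Character-level tokenizer. vocab.pth is assumed to hold a {char: index} dict;
# unknown characters fall back to '<unk>' if present, otherwise index 0.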
class SimpleTokenizer:
    def __init__(self, vocab_path):
        self.char_to_idx = torch.load(vocab_path)
        self.idx_to_char = {i: c for c, i in self.char_to_idx.items()}

    def encode(self, text):
        return [self.char_to_idx.get(c, self.char_to_idx.get('<unk>', 0)) for c in text]

    def decode(self, indices):
        return ''.join([self.idx_to_char.get(i, '') for i in indices])

# Model
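# Encoder-only transformer used as a language model: token embeddings plus a
# learned positional embedding (max sequence length 512), a stack of
# TransformerEncoderLayer blocks, and a linear head projecting back to the vocabulary.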
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, emb_size=256, num_heads=4, num_layers=4, ff_hid_dim=1024):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.pos_embedding = nn.Parameter(torch.zeros(1, 512, emb_size))
        self.transformer_blocks = nn.ModuleList([
            # batch_first=True so the layer accepts the (batch, seq, emb) tensors produced in forward().
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=num_heads, dim_feedforward=ff_hid_dim, batch_first=True)
            for _ in range(num_layers)
        ])
        self.output = nn.Linear(emb_size, vocab_size)

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        # Causal mask so each position only attends to itself and earlier positions
        # (required for next-token prediction).
        mask = torch.triu(torch.full((x.size(1), x.size(1)), float('-inf'), device=x.device), diagonal=1)
        for block in self.transformer_blocks:
            x = block(x, src_mask=mask)
        return self.output(x)

# Batching
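# Slices the token stream into (input, target) pairs where targets are the inputs
# shifted one position to the right (next-token prediction). An incomplete final
# batch is dropped.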
def get_batches(data, batch_size, seq_length):
    inputs, targets = [], []
    for i in range(0, len(data) - seq_length - 1, seq_length):
        x = data[i:i + seq_length]
        y = data[i + 1:i + 1 + seq_length]
        if len(x) == seq_length and len(y) == seq_length:
            inputs.append(x)
            targets.append(y)
        if len(inputs) == batch_size:
            yield (
                torch.tensor(inputs, dtype=torch.long),
                torch.tensor(targets, dtype=torch.long)
            )
            inputs, targets = [], []

# Memory
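# Resident memory (RSS) of the current process, reported via psutil.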
def show_memory():
    process = psutil.Process()
    mem_info = process.memory_info()
    return f"{mem_info.rss / 1024**2:.2f} MB"

# Training
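# Full training loop: loads data.txt and vocab.pth, trains with Adam and
# cross-entropy over next-token targets, and checkpoints the model after each
# epoch. Both files are assumed to exist alongside this script.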
def train():
    batch_size = 64
    seq_length = 64
    num_epochs = 3
    lr = 0.001
    vocab_path = 'vocab.pth'
    data_path = 'data.txt'

    text = load_data(data_path)
    tokenizer = SimpleTokenizer(vocab_path)
    tokens = tokenizer.encode(text)
    # Size the model to the actual vocabulary rather than a hard-coded 30,000.
    vocab_size = len(tokenizer.char_to_idx)
    model = TransformerModel(vocab_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    for epoch in range(num_epochs):
        # Materialize all batches up front so tqdm can report a total; fine for small corpora.
        batches = list(get_batches(tokens, batch_size, seq_length))
        total = len(batches)
        total_loss = 0
        print(f"\n🧠 Epoch {epoch+1}/{num_epochs} | {total} batches")

        with tqdm(total=total, desc="Training", bar_format="{l_bar}{bar} [ time left: {remaining} ]") as pbar:
            for step, (x, y) in enumerate(batches):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                output = model(x)
                loss = criterion(output.view(-1, vocab_size), y.view(-1))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                avg_loss = total_loss / (step + 1)

                if step % 10 == 0:
                    pbar.set_description(f"Loss: {loss.item():.4f} | Avg: {avg_loss:.4f} | RAM: {show_memory()}")
                pbar.update(1)

        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
        print(f"💾 Model saved: model_epoch_{epoch+1}.pth")

if __name__ == "__main__":
    train()