File size: 3,804 Bytes
b51c975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
from tqdm import tqdm, trange
import time

# ---------------------------------------------------------------------------
# Training data and hyper-parameters
# ---------------------------------------------------------------------------

# Read the corpus; `with` guarantees the file handle is closed
# (the original left it open).
with open("train_data.txt", encoding="utf-8") as f:
    training_text = f.read()

chars = sorted(set(training_text))  # unique characters = vocabulary
char_to_idx = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
idx_to_char = {i: ch for i, ch in enumerate(chars)}  # integer id -> char

# Model / training hyper-parameters
input_size = len(chars)       # one-hot width == vocabulary size
hidden_size = 32              # RNN hidden-state size
output_size = len(chars)      # one logit per vocabulary entry
sequence_length = 5           # context window (chars) fed per step
epochs = 1000
learning_rate = 0.0001
model_path = "tiny_llm.pth"

# Build (input indices, target index) training pairs with a sliding window:
# each sample is `sequence_length` chars of context plus the following char.
train_data = [
    (
        torch.tensor([char_to_idx[ch] for ch in training_text[i : i + sequence_length]]),
        char_to_idx[training_text[i + sequence_length]],
    )
    for i in range(len(training_text) - sequence_length)
]

# Define the simple RNN model
class SimpleRNN(nn.Module):
    """Single-layer RNN language model: one-hot chars -> next-char logits.

    Fix: ``forward`` previously one-hot encoded with the module-level
    global ``input_size``, which silently broke when a saved model was
    reloaded alongside a different vocabulary. It now uses the vocab
    size stored on the ``nn.RNN`` submodule, which is also restored
    correctly from old pickled checkpoints.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        """Run one sequence of character indices through the RNN.

        Args:
            x: 1-D LongTensor of character indices (a single sequence).
            hidden: hidden state of shape (1, 1, hidden_size).

        Returns:
            Tuple of (logits with shape (1, output_size), new hidden state).
        """
        # One-hot encode with the model's own vocab size, not a global.
        x = torch.nn.functional.one_hot(x, num_classes=self.rnn.input_size).float()
        out, hidden = self.rnn(x.unsqueeze(0), hidden)  # add batch dimension
        out = self.fc(out[:, -1, :])  # logits from the last time step only
        return out, hidden

# ---------------------------------------------------------------------------
# Load a pre-trained model if one exists, otherwise train from scratch.
# ---------------------------------------------------------------------------
if os.path.exists(model_path):
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects;
    # only load checkpoint files you trust.
    model = torch.load(model_path, weights_only=False)
    with open("vocab.json", "r", encoding="utf-8") as f:
        chars = json.load(f)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}
    # Keep the size globals consistent with the *loaded* vocabulary; the
    # original left them at the size of the current training file, which
    # broke one-hot encoding whenever the two vocabularies differed.
    input_size = output_size = len(chars)
    print("Loaded pre-trained model.")
else:
    print("Training new model...")
    # Initialize the model
    model = SimpleRNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        try:
            total_loss = 0.0
            # Fresh hidden state each epoch; detached per sample so
            # gradients do not flow across sequence boundaries.
            hidden = torch.zeros(1, 1, hidden_size)

            pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
            for count, (input_seq, target) in enumerate(pbar, start=1):
                optimizer.zero_grad()
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.desc = f"Epoch={epoch}, Loss={total_loss / count:.12f}"

            pbar.close()
            time.sleep(1)  # brief pause so the bar stays readable between epochs
        except KeyboardInterrupt:
            break  # Ctrl-C stops training early; the partial model is still saved

    # (Removed a leftover forward pass here: it reused the loop variable
    # `input_seq` after the loop, had no effect on the weights, and raised
    # NameError when train_data was empty.)

    # Save the trained model and its vocabulary.
    torch.save(model, model_path)
    with open("vocab.json", "w", encoding="utf-8") as f:
        json.dump(chars, f)
    print("Model saved.")

# Text generation function
def generate_text(start_text, length=10000):
    """Greedily generate `length` characters continuing `start_text`.

    Args:
        start_text: seed string; every character must be in the vocabulary.
        length: number of characters to append (default 10000).

    Returns:
        `start_text` followed by the generated continuation.

    Raises:
        ValueError: if `start_text` is empty or contains characters outside
            the vocabulary (the original crashed with a raw KeyError, or
            fed an empty tensor to the RNN).
    """
    if not start_text:
        raise ValueError("start_text must be non-empty")
    unknown = [ch for ch in start_text if ch not in char_to_idx]
    if unknown:
        raise ValueError(f"characters not in vocabulary: {unknown!r}")

    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])

    pieces = [start_text]
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            predicted_idx = output.argmax().item()
            pieces.append(idx_to_char[predicted_idx])
            # Slide the window: drop the oldest char, append the prediction.
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))

    # Join once instead of quadratic `str +=` in the loop.
    return "".join(pieces)

# Interactive REPL: read a prompt, print the model's greedy continuation.
if __name__ == "__main__":
    while True:
        try:
            prompt = input("Ask LLM: ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on end-of-input / Ctrl-C instead of a traceback.
            print()
            break
        print("LLM Output: ", generate_text(prompt))