import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
from tqdm import tqdm, trange
import time

# Build a character-level vocabulary from the training corpus.
with open("train_data.txt", encoding="utf-8") as f:
    training_text = f.read()
chars = sorted(set(training_text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# Hyperparameters.
input_size = len(chars)
hidden_size = 32
output_size = len(chars)
sequence_length = 5
epochs = 1000
learning_rate = 0.0001
model_path = "tiny_llm.pth"

# Build (input window, next character) training pairs.
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append(
        (torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char])
    )

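# For example, with training_text == "hello world" the first pair maps the
# five-character window "hello" to the target " ".
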
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # One-hot encode the character indices, using the stored vocabulary
        # size rather than the module-level global so a loaded model is
        # self-contained.
        x = torch.nn.functional.one_hot(x, num_classes=self.input_size).float()
        out, hidden = self.rnn(x.unsqueeze(0), hidden)
        # Predict the next character from the last time step only.
        out = self.fc(out[:, -1, :])
        return out, hidden

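# Shape walk-through for a hypothetical 5-character window and 40-character
# vocabulary: x (5,) -> one-hot (1, 5, 40) -> rnn output (1, 5, 32)
# -> fc on the last step (1, 40).
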
if os.path.exists(model_path):
    # weights_only=False is needed to unpickle a full nn.Module
    # (torch.load defaults to weights_only=True since PyTorch 2.6).
    model = torch.load(model_path, weights_only=False)
    # Restore the vocabulary the model was trained with; it may differ
    # from the one just derived from the current train_data.txt.
    with open("vocab.json", "r") as f:
        chars = json.load(f)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}
    print("Loaded pre-trained model.")
else:
    print("Training new model...")

    model = SimpleRNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        try:
            total_loss = 0.0
            # Fresh hidden state at the start of each epoch; it is detached
            # between samples so gradients do not flow across window
            # boundaries.
            hidden = torch.zeros(1, 1, hidden_size)

            pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
            count = 0
            for input_seq, target in pbar:
                count += 1
                optimizer.zero_grad()
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")

            pbar.close()
            time.sleep(1)  # Keep the finished epoch's progress bar visible briefly.
        except KeyboardInterrupt:
            # Ctrl+C stops training early; the model trained so far is still saved.
            break

    torch.save(model, model_path)
    with open("vocab.json", "w") as f:
        json.dump(chars, f)
    print("Model saved.")

def generate_text(start_text, length=10000):
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    # Drop characters the model has never seen so the vocabulary lookup
    # below cannot raise a KeyError on arbitrary user input.
    start_text = "".join(ch for ch in start_text if ch in char_to_idx)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])

    generated_text = start_text
    # no_grad keeps the long generation loop from building an ever-growing
    # autograd graph through the hidden state.
    with torch.no_grad():
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            # Greedy decoding: always pick the most likely next character.
            predicted_idx = output.argmax().item()
            generated_text += idx_to_char[predicted_idx]
            # Slide the context window forward by one character.
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))

    return generated_text

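# A minimal sketch of an alternative decoder (this helper is not part of the
# original script): greedy argmax often gets stuck repeating short phrases,
# so sampling from a temperature-scaled softmax can give more varied output.
def generate_text_sampled(start_text, length=10000, temperature=0.8):
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    start_text = "".join(ch for ch in start_text if ch in char_to_idx)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])

    generated_text = start_text
    with torch.no_grad():
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            # Lower temperature -> closer to argmax; higher -> more random.
            probs = torch.softmax(output.squeeze(0) / temperature, dim=-1)
            predicted_idx = torch.multinomial(probs, 1).item()
            generated_text += idx_to_char[predicted_idx]
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))

    return generated_text
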
while True:
    print("LLM Output: ", generate_text(input("Ask LLM: ")))