# tiny_llm/train.py
import torch
import torch.nn as nn
import torch.optim as optim
import os
import json
from tqdm import tqdm, trange
import time
# Load the training text and build the character vocabulary
with open("train_data.txt", encoding="utf-8") as f:
    training_text = f.read()
chars = sorted(list(set(training_text))) # Unique characters
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
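# Quick sanity check, added here as an illustrative sketch: char_to_idx and
# idx_to_char should be exact inverses, so a round trip over the vocabulary
# must reproduce every character.
assert all(idx_to_char[char_to_idx[ch]] == ch for ch in chars)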
# Model parameters
input_size = len(chars)
hidden_size = 32
output_size = len(chars)
sequence_length = 5
epochs = 1000
learning_rate = 0.0001
model_path = "tiny_llm.pth"
# Create training data (input-output pairs)
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append((torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char]))
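# Diagnostic sketch (safe to delete): decode the first (input, target) pair
# back to text to confirm the sliding-window construction above.
if train_data:
    sample_in, sample_tgt = train_data[0]
    print("First pair:", repr("".join(idx_to_char[i.item()] for i in sample_in)),
          "->", repr(idx_to_char[sample_tgt]))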
# Define the simple RNN model
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # One-hot encode the index sequence; this relies on the module-level
        # input_size matching the vocabulary the model was built with.
        x = torch.nn.functional.one_hot(x, num_classes=input_size).float()
        out, hidden = self.rnn(x.unsqueeze(0), hidden)
        out = self.fc(out[:, -1, :])  # Take the last time step's output
        return out, hidden
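# Minimal smoke test of the forward pass (an illustrative sketch using a
# throwaway untrained instance; it never touches saved weights). A dummy
# window of sequence_length indices should yield one logit per vocabulary
# character plus an updated hidden state.
_demo = SimpleRNN(input_size, hidden_size, output_size)
_logits, _hidden = _demo(torch.zeros(sequence_length, dtype=torch.long),
                         torch.zeros(1, 1, hidden_size))
assert _logits.shape == (1, output_size)
del _demo, _logits, _hidden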
# Load model if available
if os.path.exists(model_path):
    model = torch.load(model_path, weights_only=False)
    with open("vocab.json", "r") as f:
        chars = json.load(f)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}
    # Keep the sizes consistent with the loaded checkpoint rather than with
    # whatever train_data.txt currently contains.
    input_size = len(chars)
    output_size = len(chars)
    hidden_size = model.hidden_size
    print("Loaded pre-trained model.")
else:
    print("Training new model...")
    # Initialize the model
    model = SimpleRNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        try:
            total_loss = 0
            hidden = torch.zeros(1, 1, hidden_size)
            pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
            count = 0
            for input_seq, target in pbar:
                count += 1
                optimizer.zero_grad()
                # Detach the carried hidden state so gradients do not
                # propagate across sequence boundaries.
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")
            pbar.close()
            time.sleep(1)  # Brief pause so the finished bar stays readable
        except KeyboardInterrupt:
            break
    # Save the trained model
    torch.save(model, model_path)
    with open("vocab.json", "w") as f:
        json.dump(chars, f)
    print("Model saved.")
# Text generation function
def generate_text(start_text, length=10000):
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    # Note: characters absent from the vocabulary will raise a KeyError here.
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])
    generated_text = start_text
    with torch.no_grad():  # Inference only; skip building autograd graphs
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            predicted_idx = output.argmax().item()
            generated_text += idx_to_char[predicted_idx]
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))
    return generated_text
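# Optional variant (a sketch; the temperature value is an assumption, not part
# of the original script): greedy argmax decoding tends to fall into repeating
# loops on tiny models, so sampling from the softmax distribution is a common
# alternative.
def generate_text_sampled(start_text, length=200, temperature=0.8):
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    input_seq = torch.tensor([char_to_idx[ch] for ch in start_text])
    generated_text = start_text
    with torch.no_grad():
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            probs = torch.softmax(output.squeeze(0) / temperature, dim=-1)
            predicted_idx = torch.multinomial(probs, 1).item()
            generated_text += idx_to_char[predicted_idx]
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))
    return generated_text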
# Generate some text
while True:
    prompt = input("Ask LLM: ")
    print("LLM Output: ", generate_text(prompt))