import json
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm, trange

# Load the training text and build the character vocabulary
training_text = open("train_data.txt", encoding="utf-8").read()
chars = sorted(set(training_text))  # Unique characters
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}

# Model hyperparameters
input_size = len(chars)
hidden_size = 32
output_size = len(chars)
sequence_length = 5
epochs = 1000
learning_rate = 0.0001
model_path = "tiny_llm.pth"

# Create training data (input-output pairs): each sample maps a
# `sequence_length`-character window to the character that follows it
train_data = []
for i in range(len(training_text) - sequence_length):
    input_seq = training_text[i : i + sequence_length]
    target_char = training_text[i + sequence_length]
    train_data.append(
        (torch.tensor([char_to_idx[ch] for ch in input_seq]), char_to_idx[target_char])
    )

# Define the simple character-level RNN model
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size  # stored so forward() does not rely on a global
        self.hidden_size = hidden_size
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden):
        # One-hot encode the index sequence and add a batch dimension
        x = F.one_hot(x, num_classes=self.input_size).float()
        out, hidden = self.rnn(x.unsqueeze(0), hidden)
        out = self.fc(out[:, -1, :])  # Take the last time step's output
        return out, hidden

# Load a pre-trained model if one exists, otherwise train from scratch
if os.path.exists(model_path):
    # The file stores the full pickled module, so weights_only=False is required
    model = torch.load(model_path, weights_only=False)
    # Restore the vocabulary the model was trained with
    with open("vocab.json", "r") as f:
        chars = json.load(f)
    char_to_idx = {ch: i for i, ch in enumerate(chars)}
    idx_to_char = {i: ch for i, ch in enumerate(chars)}
    print("Loaded pre-trained model.")
else:
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        try:
            total_loss = 0
            # Fresh hidden state each epoch; it is carried (detached) across
            # the overlapping samples within the epoch
            hidden = torch.zeros(1, 1, hidden_size)
            pbar = tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A")
            count = 0
            for input_seq, target in pbar:
                count += 1
                optimizer.zero_grad()
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                pbar.set_description(f"Epoch={epoch}, Loss={total_loss / count:.12f}")
            pbar.close()
            time.sleep(1)  # brief pause so the finished bar stays readable
        except KeyboardInterrupt:
            break  # Ctrl-C stops training early and falls through to saving
    # Save the trained model and its vocabulary
    torch.save(model, model_path)
    with open("vocab.json", "w") as f:
        json.dump(chars, f)
    print("Model saved.")

# Text generation: greedily predict one character at a time
def generate_text(start_text, length=10000):
    model.eval()
    hidden = torch.zeros(1, 1, hidden_size)
    # Drop any characters that are not in the training vocabulary
    start_indices = [char_to_idx[ch] for ch in start_text if ch in char_to_idx]
    if not start_indices:
        return start_text  # nothing usable to seed the model with
    input_seq = torch.tensor(start_indices)
    generated_text = start_text
    with torch.no_grad():  # no gradients needed at inference time
        for _ in trange(length):
            output, hidden = model(input_seq, hidden)
            predicted_idx = output.argmax().item()
            generated_text += idx_to_char[predicted_idx]
            input_seq = torch.cat((input_seq[1:], torch.tensor([predicted_idx])))
    return generated_text

# Simple interactive loop
while True:
    print("LLM Output: ", generate_text(input("Ask LLM: ")))
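
# Optional extension (not part of the original script): temperature sampling.
# A minimal sketch, assuming the (1, output_size) logits returned by the model;
# swapping it in for the argmax in generate_text gives less repetitive output.
# `sample_next_char` and its `temperature` knob are hypothetical additions:
# lower temperatures approach greedy decoding, higher ones add randomness.
def sample_next_char(logits, temperature=0.8):
    # Scale the logits, convert to a probability distribution, draw one index
    probs = torch.softmax(logits.squeeze(0) / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Usage inside the generation loop, in place of output.argmax().item():
#     predicted_idx = sample_next_char(output)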