|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from SimpleRNN import SimpleRNN
|
|
import os
|
|
import json
|
|
from tqdm import tqdm, trange
|
|
import time
|
|
|
|
# Load the full training corpus and build the character vocabulary.
# The `with` block guarantees the file handle is closed once read.
with open("train_data.txt", encoding="utf-8") as corpus_file:
    training_text = corpus_file.read()

# Sorting makes the char <-> index mapping deterministic for a given corpus.
chars = sorted(set(training_text))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))
|
|
|
|
# Hyperparameters are read from parameter.json (JSON is UTF-8 by spec, so the
# encoding is explicit rather than locale-dependent). The vocabulary size
# fixes the input and output widths of the network.
with open("parameter.json", encoding="utf-8") as param_file:
    parameters = json.load(param_file)

input_size = len(chars)
hidden_size = parameters["hidden_size"]
output_size = len(chars)
sequence_length = parameters["sequence_length"]
# Configurable like the other hyperparameters; falls back to the previous
# hard-coded value when parameter.json has no "epochs" key.
epochs = parameters.get("epochs", 1000)
learning_rate = parameters["learning_rate"]
model_path = parameters["model_path"]
|
|
|
|
# Build (input window, next-character target) pairs with a stride-1 sliding
# window over the corpus. The text is encoded to indices once; every input is
# a slice (view) of that single int64 tensor instead of a fresh copy per
# window, so memory stays O(N) rather than O(N * sequence_length).
encoded_text = [char_to_idx[ch] for ch in training_text]
encoded_tensor = torch.tensor(encoded_text)

train_data = [
    (encoded_tensor[i : i + sequence_length], encoded_text[i + sequence_length])
    for i in range(len(encoded_text) - sequence_length)
]

# Without at least one window the training loop would do nothing useful, so
# fail fast with an actionable message.
if not train_data:
    raise ValueError(
        f"train_data.txt must contain more than sequence_length={sequence_length} "
        f"characters to form a training example; got {len(training_text)}."
    )
|
|
|
|
# Resume from a previously saved model when one exists; otherwise start fresh.
if not os.path.exists(model_path):
    print("Training new model...")
    model = SimpleRNN(input_size, hidden_size, output_size)
else:
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so
    # only point model_path at files you produced yourself / trust.
    # NOTE(review): if train_data.txt changed since this model was saved, the
    # vocabulary size may no longer match the loaded network - confirm.
    model = torch.load(model_path, weights_only=False)
    print("Loaded pre-trained model. Continue training...")
|
|
|
|
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# A model loaded from disk keeps whatever train/eval mode it was saved in;
# ensure training-mode behaviour for any mode-dependent layers.
model.train()

# Ctrl-C ends training early; execution then falls through to the save step.
try:
    for epoch in range(epochs):
        total_loss = 0.0
        # Fresh hidden state at the start of every pass over the data.
        # NOTE(review): assumes SimpleRNN takes a (1, 1, hidden_size) hidden tensor.
        hidden = torch.zeros(1, 1, hidden_size)

        # The context manager closes the bar even if interrupted mid-epoch.
        with tqdm(train_data, desc=f"Epoch={epoch}, Loss=N/A") as pbar:
            for step, (input_seq, target) in enumerate(pbar, start=1):
                optimizer.zero_grad()
                # The hidden state is carried between windows but detached,
                # so backpropagation stops at the current window.
                output, hidden = model(input_seq, hidden.detach())
                loss = criterion(output, torch.tensor([target]))
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                # Show the running mean loss for this epoch.
                pbar.set_description_str(
                    f"Epoch={epoch}, Loss={total_loss / step:.12f}", refresh=False
                )

        # One-second pause between epochs, kept from the original script.
        time.sleep(1)
except KeyboardInterrupt:
    print("Training interrupted; saving current model...")
|
|
|
|
# Persist the whole model object (pickled by torch.save) together with the
# character vocabulary, which is needed to map indices back to characters.
torch.save(model, model_path)

with open("vocab.json", "w") as vocab_file:
    json.dump(chars, vocab_file)

print("Model saved.")
|
|
|