|
import torch |
|
from model import TransformerModel |
|
from tokenizer import SimpleTokenizer |
|
|
|
|
|
# Character-level tokenizer. NOTE(review): "vocab_path" looks like a literal
# placeholder rather than a real file path — confirm against SimpleTokenizer.
tokenizer = SimpleTokenizer("vocab_path")

# Model hyperparameters. These must match the values used when "model.pth"
# was trained, or load_state_dict below will fail with a shape mismatch.
vocab_size = len(tokenizer.char_to_idx)  # number of characters in the vocabulary

embed_size = 64   # token embedding dimension

num_heads = 2     # attention heads per transformer layer

hidden_dim = 128  # feed-forward hidden dimension

num_layers = 2    # number of stacked transformer layers

max_len = 32      # maximum sequence length the model accepts
|
|
|
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model with the same hyperparameters used at training time,
# then restore the trained weights onto the chosen device.
model = TransformerModel(vocab_size, embed_size, num_heads, hidden_dim, num_layers, max_len).to(device)
# weights_only=True restricts unpickling to tensors and plain containers,
# preventing arbitrary-code execution from a tampered checkpoint file.
# (A state_dict contains only tensors, so this is a safe restriction.)
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval()  # inference mode: disables dropout etc.
|
|
|
|
|
# Interactive chat loop: read a line, run greedy (argmax) decoding over the
# model's per-position logits, and print the decoded characters.
while True:
    try:
        user_input = input("You: ")
    except (EOFError, KeyboardInterrupt):
        # Exit cleanly on Ctrl-D / Ctrl-C or when piped stdin is exhausted,
        # instead of crashing with a traceback.
        print()
        break

    if user_input.lower() in ["quit", "exit"]:
        break
    if not user_input:
        # Nothing to encode — an empty tensor would break the forward pass.
        continue

    input_ids = tokenizer.encode(user_input)
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        # [0] drops the batch dimension: logits of shape (seq_len, vocab_size).
        output = model(input_tensor)[0]

    # Greedy decoding: highest-scoring token at each position. No squeeze():
    # for a 1-character input, squeeze() would collapse the (1,) tensor to a
    # scalar and tolist() would return a bare int, breaking decode().
    prediction = torch.argmax(output, dim=-1).tolist()

    response = tokenizer.decode(prediction)
    print("AI:", response)
|
|