File size: 1,659 Bytes
6c044f9
b22cc86
 
6c044f9
 
8154e9b
b22cc86
6c044f9
 
 
3e6a8ce
6c044f9
 
3e6a8ce
6c044f9
 
 
 
 
 
 
3e6a8ce
6c044f9
 
504eb77
6c044f9
 
 
 
504eb77
6c044f9
 
 
 
3e6a8ce
6c044f9
 
 
 
504eb77
6c044f9
 
504eb77
6c044f9
504eb77
6c044f9
 
 
5448b40
6c044f9
504eb77
 
6c044f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
import warnings

import gradio as gr
import tiktoken
import torch
from transformers import GPT2Tokenizer

from model import GPTLanguageModel

# Initialize the GPT-2 tokenizers
# `enc` (tiktoken) is what get_response() uses to encode prompts and decode output.
enc = tiktoken.get_encoding("gpt2")  # Using tiktoken
# NOTE(review): nothing in the visible part of this file reads `tokenizer`;
# kept in case other code relies on it -- confirm before removing.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # Using Hugging Face tokenizer for consistency

# Select the compute device: GPU when CUDA is available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model hyperparameters (should match the training configuration)
vocab_size = 50257
n_heads = 8
n_layers = 6
n_embd = 512
block_size = 128
dropout = 0.1  # NOTE(review): not passed to GPTLanguageModel below -- confirm this is intended

# Create the GPT model instance
model = GPTLanguageModel(vocab_size, n_embd, block_size, n_layers, n_heads).to(device)

# Load the trained model weights
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source.
if os.path.exists("model_weights.pth"):
    model.load_state_dict(torch.load("model_weights.pth", map_location=device))
else:
    # Without a checkpoint the app still starts, but make it obvious that the
    # generated text comes from a model whose weights were never loaded.
    warnings.warn(
        "model_weights.pth not found; running with the model's initial, "
        "untrained weights.",
        RuntimeWarning,
    )
# Switch to evaluation mode for inference
model.eval()

# Function to generate a response based on the user input
def get_response(prompt):
    """Generate text from the module-level model for a user prompt.

    Args:
        prompt: Text entered by the user.

    Returns:
        The decoded text of the first sequence returned by
        ``model.generate`` (presumably the prompt followed by the newly
        generated tokens -- confirm against ``GPTLanguageModel.generate``),
        or an empty string when ``prompt`` is empty.
    """
    # An empty prompt encodes to zero tokens, which would hand the model a
    # zero-length context; answer with an empty string instead.
    if not prompt:
        return ""

    # disallowed_special=() makes tiktoken encode strings such as
    # "<|endoftext|>" as ordinary text instead of raising ValueError when
    # they appear in user input.
    token_ids = enc.encode(prompt, disallowed_special=())
    context = torch.tensor([token_ids], dtype=torch.long, device=device)

    # NOTE(review): prompts longer than block_size tokens are passed through
    # unchanged; confirm that model.generate crops its own context window.
    max_new_tokens = 200  # Number of tokens to generate

    # Inference only: skip autograd bookkeeping while sampling.
    with torch.no_grad():
        generated_text_idx = model.generate(context, max_new_tokens)

    # Decode the generated token IDs into text
    return enc.decode(generated_text_idx[0].tolist())

def main():
    """Build the Gradio text-in/text-out UI around get_response and launch it."""
    gr.Interface(
        fn=get_response,
        inputs="text",
        outputs="text",
        title="StoryCrafterLLM",
    ).launch()


if __name__ == "__main__":
    main()