|
import torch |
|
import gradio as gr |
|
import tiktoken |
|
import os |
|
from torch.nn import functional as F |
|
from model import GPT, GPTConfig |
|
|
|
|
|
# Cached GPT instance; stays None until initialize_model() loads it
# successfully, after which every call reuses the same object.
model = None


def initialize_model(model_path='model/model_state_dict.pth'):
    """Load the GPT model once and return the cached instance.

    On the first successful call the weights are read from ``model_path``
    onto the CPU, loaded into ``GPT(GPTConfig())`` and the model is switched
    to eval mode. Later calls return the cached model immediately and ignore
    ``model_path``.

    Args:
        model_path: Path to a saved ``state_dict`` for ``GPT(GPTConfig())``.
            Relative paths resolve against the current working directory.

    Returns:
        The loaded ``GPT`` instance.

    Raises:
        FileNotFoundError: If no model is cached yet and ``model_path`` does
            not exist.
        RuntimeError: If building the model or loading the weights fails
            (the underlying error is chained as ``__cause__``).
    """
    global model

    if model is not None:
        return model

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file not found at {model_path}")

    try:
        # Build into a local first: assigning the global before the weights
        # load succeeds would cache a randomly initialised model, and every
        # later call would silently return it instead of retrying.
        loaded = GPT(GPTConfig())
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; on torch versions that support it, weights_only=True
        # restricts loading to tensors.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        loaded.load_state_dict(state_dict)
        loaded.eval()
    except Exception as e:
        # Chain the original error so its traceback is not lost.
        raise RuntimeError(f"Error loading model: {str(e)}") from e

    model = loaded
    return model
|
|
|
def generate_shakespeare(prompt, max_length=100, temperature=0.8):
    """Generate Shakespeare-style text continuing ``prompt``.

    Tokens are sampled one at a time from the model's temperature-scaled
    distribution. Sampling stops when the sequence (prompt included) reaches
    ``max_length`` tokens or when a newline token is sampled.

    Args:
        prompt: Text to continue, encoded with tiktoken's ``gpt2`` encoding.
        max_length: Upper bound on the total number of tokens, *including*
            the prompt tokens.
        temperature: Positive divisor applied to the logits before softmax;
            lower values make sampling more deterministic.

    Returns:
        The decoded prompt plus continuation. If the input is rejected or
        generation fails, a human-readable message is returned instead; this
        function is the UI callback, so errors are surfaced as text rather
        than raised.
    """
    try:
        model = initialize_model()
        block_size = model.config.block_size

        enc = tiktoken.get_encoding('gpt2')
        prompt_tokens = enc.encode(prompt)

        if not prompt_tokens:
            return "Please enter a prompt."
        if len(prompt_tokens) > block_size:
            return f"Prompt too long. Please limit to {block_size} tokens."
        if temperature <= 0:
            # Zero gives inf/NaN probabilities; negative inverts the ranking.
            return "Temperature must be greater than 0."

        # Loop-invariant: the token id that ends generation.
        newline_token = enc.encode('\n')[0]

        # Explicit integer dtype: token ids are used as embedding indices.
        x = torch.tensor(prompt_tokens, dtype=torch.long).unsqueeze(0)

        with torch.no_grad():
            while x.size(1) < max_length:
                # Feed at most the last block_size tokens so a max_length
                # larger than the context window cannot overflow the model.
                logits, _ = model(x[:, -block_size:])
                logits = logits[:, -1, :] / temperature

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                x = torch.cat((x, next_token), dim=1)

                if next_token.item() == newline_token:
                    break

        return enc.decode(x[0].tolist())

    except Exception as e:
        return f"Error generating text: {str(e)}"
|
|
|
|
|
def _build_demo():
    """Assemble the Gradio interface that wraps ``generate_shakespeare``.

    The three input components are passed positionally to the callback as
    ``prompt``, ``max_length`` and ``temperature``. Each example row supplies
    values in that same order.
    """
    prompt_input = gr.Textbox(
        label="Enter your prompt",
        placeholder="Enter some Shakespeare-style text...",
        lines=2,
    )
    length_input = gr.Slider(
        minimum=10,
        maximum=200,
        value=100,
        step=1,
        label="Max Length",
        info="Maximum length of generated text",
    )
    temperature_input = gr.Slider(
        minimum=0.1,
        maximum=2.0,
        value=0.8,
        step=0.1,
        label="Temperature",
        info="Higher values make the output more random, lower values make it more focused",
    )

    # Every example runs with the default length (100) and temperature (0.8).
    example_prompts = (
        "To be, or not to be,",
        "All the world's a stage,",
        "Romeo, Romeo,",
    )
    example_rows = [[text, 100, 0.8] for text in example_prompts]

    intro = (
        "Generate Shakespeare-style text based on your prompt using a fine-tuned GPT model.\n"
        "Enter a prompt and adjust the parameters to control the generation."
    )

    return gr.Interface(
        fn=generate_shakespeare,
        inputs=[prompt_input, length_input, temperature_input],
        outputs=gr.Textbox(label="Generated Text", lines=5),
        title="Shakespeare Text Generator",
        description=intro,
        examples=example_rows,
        # NOTE(review): cached example outputs are computed ahead of time, so
        # with random sampling an example shows the same text on every click.
        cache_examples=True,
    )


demo = _build_demo()
|
|
|
|
|
if __name__ == "__main__":
    try:
        # Load the weights eagerly so a missing or broken checkpoint is
        # reported at startup instead of on the first user request.
        initialize_model()
        demo.launch()
    except Exception as e:
        # SystemExit with a string argument writes it to stderr and exits
        # with status 1, so shells and process supervisors see the failure
        # (a plain print() would leave the exit status at 0).
        raise SystemExit(f"Error starting the application: {str(e)}")