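"""Gradio demo for a Shakespeare-style GPT text generator.

Loads a fine-tuned GPT model (see model.py for GPT / GPTConfig) and serves an
interactive interface that autoregressively samples text from a user prompt.
"""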
import torch
import gradio as gr
import tiktoken
import os
from torch.nn import functional as F
from model import GPT, GPTConfig
# Keep the model in a module-level global so it is loaded only once per process
model = None

def initialize_model():
    global model
    if model is None:
        model_path = 'model/model_state_dict.pth'
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")
        try:
            model = GPT(GPTConfig())
            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Error loading model: {e}") from e
    return model
def generate_shakespeare(prompt, max_length=100, temperature=0.8):
    """Generate Shakespeare-style text from a prompt."""
    try:
        # Initialize the model if not already done
        model = initialize_model()
        # Encode the prompt with the GPT-2 BPE tokenizer
        enc = tiktoken.get_encoding('gpt2')
        prompt_tokens = enc.encode(prompt)
        # Safety check for prompt length
        if len(prompt_tokens) > model.config.block_size:
            return f"Prompt too long. Please limit to {model.config.block_size} tokens."
        x = torch.tensor(prompt_tokens).unsqueeze(0)
        with torch.no_grad():
            # Note: max_length is the total sequence length, prompt included
            while x.size(1) < max_length:
                # Crop the context to the model's block size before the forward pass
                x_cond = x if x.size(1) <= model.config.block_size else x[:, -model.config.block_size:]
                # Get next-token logits and rescale by temperature
                logits, _ = model(x_cond)
                logits = logits[:, -1, :] / temperature
                # Sample from the resulting distribution
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                # Append the sampled token to the sequence
                x = torch.cat((x, next_token), dim=1)
                # Stop early if the model emits a newline
                if next_token.item() == enc.encode('\n')[0]:
                    break
        # Decode the generated token sequence back to text
        generated_tokens = x[0].tolist()
        generated_text = enc.decode(generated_tokens)
        return generated_text
    except Exception as e:
        return f"Error generating text: {e}"
# Create the Gradio interface
demo = gr.Interface(
    fn=generate_shakespeare,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Enter some Shakespeare-style text...",
            lines=2
        ),
        gr.Slider(
            minimum=10,
            maximum=200,
            value=100,
            step=1,
            label="Max Length",
            info="Maximum total length of the generated text in tokens, including the prompt"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1,
            label="Temperature",
            info="Higher values make the output more random; lower values make it more focused"
        )
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="Shakespeare Text Generator",
    description="""Generate Shakespeare-style text from your prompt using a fine-tuned GPT model.
    Enter a prompt and adjust the parameters to control the generation.""",
    examples=[
        ["To be, or not to be,", 100, 0.8],
        ["All the world's a stage,", 100, 0.8],
        ["Romeo, Romeo,", 100, 0.8]
    ],
    # Caching runs each example once at startup, which also loads the model
    cache_examples=True
)
# Fail fast at startup if the model cannot be loaded
if __name__ == "__main__":
    try:
        # Test model loading before launching the interface
        initialize_model()
        demo.launch()
    except Exception as e:
        print(f"Error starting the application: {e}")
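# Typical usage (assuming this file is saved as app.py):
#   python app.py
# Gradio then serves the interface locally, by default at http://127.0.0.1:7860.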