"""Gradio demo that generates Shakespeare-style text with a local GPT checkpoint."""
import os

import gradio as gr
import tiktoken
import torch
from torch.nn import functional as F

from model import GPT, GPTConfig

MODEL_PATH = 'model/model_state_dict.pth'

# Loaded lazily and cached at module level so repeated requests reuse one model.
model = None


def initialize_model():
    """Load the GPT checkpoint from MODEL_PATH once and return the cached model.

    Returns:
        The GPT model, loaded onto the CPU and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If building the model or loading its weights fails
            (a subclass of Exception, so existing ``except Exception`` callers
            still catch it).
    """
    global model
    if model is not None:
        return model
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
    try:
        # Build into a local first: assigning the global before the weights are
        # loaded would cache a randomly initialized model if loading fails.
        loaded = GPT(GPTConfig())
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'))
        loaded.load_state_dict(state_dict)
        loaded.eval()
    except Exception as e:
        raise RuntimeError(f"Error loading model: {str(e)}") from e
    model = loaded
    return model


def _sample_next_token(logits, temperature):
    """Choose the next token id from last-position logits of shape (batch, vocab).

    Returns a (batch, 1) tensor of token ids.
    """
    if temperature <= 0:
        # Dividing by zero would yield inf/NaN probabilities; use greedy decoding.
        return torch.argmax(logits, dim=-1, keepdim=True)
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def generate_shakespeare(prompt, max_length=100, temperature=0.8):
    """Generate Shakespeare-style text continuing a prompt.

    Args:
        prompt: Seed text, encoded with the GPT-2 BPE tokenizer.
        max_length: Maximum total length, in tokens, of the returned sequence
            (the prompt counts toward this limit).
        temperature: Softmax temperature. Higher values are more random;
            values <= 0 always pick the most likely token.

    Returns:
        The prompt followed by the generated continuation. Generation also
        stops early after the first newline token. On failure, a human-readable
        message is returned instead: this is a Gradio handler and never raises.
    """
    try:
        gpt = initialize_model()

        enc = tiktoken.get_encoding('gpt2')
        prompt_tokens = enc.encode(prompt or '')
        if not prompt_tokens:
            # An empty token list would become a 0-length float tensor and
            # crash the model call.
            return "Please enter a prompt."

        block_size = gpt.config.block_size
        if len(prompt_tokens) > block_size:
            return f"Prompt too long. Please limit to {block_size} tokens."

        # Loop-invariant: the id used to detect end-of-line.
        newline_token = enc.encode('\n')[0]

        x = torch.tensor([prompt_tokens], dtype=torch.long)
        with torch.no_grad():
            while x.size(1) < max_length:
                # Feed at most block_size tokens. The model presumably cannot
                # attend beyond its context window (cf. the prompt check above).
                logits, _ = gpt(x[:, -block_size:])
                next_token = _sample_next_token(logits[:, -1, :], temperature)
                x = torch.cat((x, next_token), dim=1)
                if next_token.item() == newline_token:
                    break

        return enc.decode(x[0].tolist())
    except Exception as e:
        return f"Error generating text: {str(e)}"


# Create Gradio interface
demo = gr.Interface(
    fn=generate_shakespeare,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="Enter some Shakespeare-style text...",
            lines=2
        ),
        gr.Slider(
            minimum=10,
            maximum=200,
            value=100,
            step=1,
            label="Max Length",
            info="Maximum length of generated text"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.8,
            step=0.1,
            label="Temperature",
            info="Higher values make the output more random, lower values make it more focused"
        )
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="Shakespeare Text Generator",
    description="""Generate Shakespeare-style text based on your prompt using a fine-tuned GPT model.
    Enter a prompt and adjust the parameters to control the generation.""",
    examples=[
        ["To be, or not to be,", 100, 0.8],
        ["All the world's a stage,", 100, 0.8],
        ["Romeo, Romeo,", 100, 0.8]
    ],
    cache_examples=True
)


if __name__ == "__main__":
    try:
        # Fail fast at startup if the checkpoint is missing or broken.
        initialize_model()
        demo.launch()
    except Exception as e:
        print(f"Error starting the application: {str(e)}")
        # Exit non-zero so process supervisors see the failure.
        raise SystemExit(1)