# TopicGPT / app.py
# Import the necessary libraries
import os

import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Get the Hugging Face token from the environment variable
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the tokenizer and model
# (newer transformers versions accept token= in place of use_auth_token=)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', use_auth_token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', use_auth_token=HF_TOKEN)
model.eval()
model.to('cpu')
# Define the function that generates text from a prompt.
# It streams its output: each loop iteration samples one token and yields the
# text generated so far, so the Gradio UI can update incrementally.
def generate_text(prompt):
    input_tokens = tokenizer.encode(prompt, return_tensors='pt')
    input_tokens = input_tokens.to('cpu')
    generated_text = prompt  # Start with the initial prompt
    for _ in range(80):  # Adjust the range to control the number of tokens generated
        with torch.no_grad():
            outputs = model(input_tokens)
            predictions = outputs.logits
        # Sample the next token from the full softmax distribution over the vocabulary
        next_token = torch.multinomial(torch.softmax(predictions[:, -1, :], dim=-1), 1)
        input_tokens = torch.cat((input_tokens, next_token), dim=1)
        decoded_token = tokenizer.decode(next_token.item())
        generated_text += decoded_token  # Append the new token to the generated text
        yield generated_text  # Yield the entire generated text so far
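
# A minimal alternative sketch (not part of the original app): the manual loop
# above amounts to multinomial sampling, which the built-in HF generation API
# can express directly. The function name and max_new_tokens default here are
# illustrative assumptions. top_k=0 disables top-k filtering so that, like the
# loop above, tokens are drawn from the full softmax distribution; unlike the
# loop, this variant returns the finished text in one piece instead of streaming.
def generate_text_builtin(prompt, max_new_tokens=80):
    input_tokens = tokenizer.encode(prompt, return_tensors='pt').to('cpu')
    with torch.no_grad():
        output = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            do_sample=True,   # sample instead of greedy decoding
            top_k=0,          # no top-k truncation: full-vocabulary sampling
            pad_token_id=tokenizer.eos_token_id,  # silence the missing-pad-token warning
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)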
# Create a Gradio interface with a text input and a text output.
# Because generate_text is a generator, Gradio streams each yielded update
# to the output box as it is produced; live=False means generation only
# starts when the user submits, not on every keystroke.
interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=False)
interface.launch()
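
# Usage note (an assumption about typical deployment, not stated in the file):
# on Hugging Face Spaces this script is run directly and launch() starts the
# server on the Space's hardware. For local testing you could instead request
# a temporary public link, e.g.:
# interface.launch(share=True)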