# TopicGPT / app.py
# Hugging Face Space by skylersterling (revision 22fb6e4, verified)
# Import the libraries
import gradio as gr
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os
# Read the Hugging Face access token from the environment (set as a Space secret).
# May be None, in which case the Hub is accessed anonymously.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Load the tokenizer and model from the Hub.
# `token=` replaces the deprecated `use_auth_token=` argument (removed in transformers v5).
tokenizer = GPT2Tokenizer.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained('skylersterling/TopicGPT', token=HF_TOKEN)
model.eval()  # inference mode: disables dropout
# Generate a sampled continuation of a user-supplied prompt.
def generate_text(prompt):
    """Return *prompt* continued by up to 80 sampled tokens.

    Uses the module-level GPT-2 tokenizer and model loaded above;
    sampling (do_sample=True) makes the output non-deterministic.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt')
    generated = model.generate(prompt_ids, max_new_tokens=80, do_sample=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Build and launch the Gradio demo: one text box in, generated text out.
gr.Interface(fn=generate_text, inputs='text', outputs='text').launch()