File size: 849 Bytes
db9d4db
c95822b
db9d4db
 
 
 
 
 
 
 
 
22fb6e4
db9d4db
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Import the libraries
import gradio as gr
import transformers
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os

# Hugging Face access token read from the environment; None when the variable
# is unset (presumably only required if the model repo is private — TODO confirm).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub repository that holds both the tokenizer files and the GPT-2 LM weights.
MODEL_ID = 'skylersterling/TopicGPT'

# Load the tokenizer and model once at import time so every request reuses them.
# `token=` replaces the deprecated `use_auth_token=` argument of `from_pretrained`.
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained(MODEL_ID, token=HF_TOKEN)
# Inference only: switch off training-time behaviour such as dropout.
model.eval()

# Define the function that generates text from a prompt
def generate_text(prompt, max_new_tokens=80):
    """Return *prompt* followed by a sampled continuation from the model.

    Args:
        prompt: Text typed into the Gradio input box.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Defaults to 80, the value previously hard-coded.

    Returns:
        The decoded sequence (prompt plus generated tokens) with special
        tokens stripped. Because ``do_sample=True`` the result is random and
        differs between calls for the same prompt.
    """
    # Calling the tokenizer (instead of ``encode``) also returns the attention
    # mask, which ``generate`` would otherwise have to infer (with a warning).
    inputs = tokenizer(prompt, return_tensors='pt')
    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        # GPT-2 tokenizers typically define no pad token; reuse EOS explicitly
        # rather than relying on generate()'s implicit fallback, which logs a
        # warning on every call.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Single prompt in, so the batch holds exactly one sequence.
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Expose generate_text as a minimal web UI: one textbox in, one textbox out.
# launch() runs at import time because the hosting environment executes this
# file directly as the app entry point.
interface = gr.Interface(
    fn=generate_text,
    inputs='text',
    outputs='text',
)
interface.launch()