File size: 1,479 Bytes
14ac587
c95822b
14ac587
db9d4db
 
 
14ac587
db9d4db
 
 
5e2cdce
db9d4db
22fb6e4
ad4d3e1
db9d4db
 
14ac587
 
1a13068
14ac587
57caaab
b83610f
14ac587
 
 
 
 
 
 
 
 
57caaab
 
db9d4db
14ac587
2c4becd
db9d4db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# Import the necessary libraries
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os

# Read the Hugging Face access token from the environment; this is None when
# the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load the base GPT-2 tokenizer and the 'skylersterling/TopicGPT' weights
# (presumably a GPT-2 fine-tune, since it is paired with the stock 'gpt2'
# tokenizer). The model is switched to evaluation mode (dropout disabled) and
# kept on the CPU; both .eval() and .to() return the module itself, so they
# can be chained.
# NOTE(review): recent transformers releases deprecate `use_auth_token` in
# favour of `token`; switch once the installed version is confirmed to
# support it.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', use_auth_token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained(
    'skylersterling/TopicGPT',
    use_auth_token=HF_TOKEN,
).eval().to('cpu')

# Define the function that generates text from a prompt
def generate_text(prompt, max_new_tokens=80):
    """Stream a sampled continuation of *prompt*, one token at a time.

    At each step the next token is drawn with ``torch.multinomial`` from the
    softmax over the model's last-position logits. No temperature, top-k or
    top-p filtering is applied.

    Args:
        prompt: Text to continue. ``None`` and ``""`` are treated as an empty
            prompt, in which case generation is seeded with the tokenizer's
            BOS token.
        max_new_tokens: Number of tokens to sample. Defaults to 80, the value
            previously hard-coded. Nothing is yielded if this is <= 0.

    Yields:
        After every sampled token, the prompt followed by all text generated
        so far.
    """
    prompt = prompt or ""
    input_tokens = tokenizer.encode(prompt, return_tensors='pt').to('cpu')
    if input_tokens.numel() == 0:
        # The model cannot run on a zero-length sequence, so seed generation
        # with BOS. The BOS token itself is not included in the output text.
        input_tokens = torch.tensor([[tokenizer.bos_token_id]], device='cpu')

    # GPT-2 only has learned position embeddings for n_positions tokens. Any
    # longer input fails, so condition on the most recent window instead.
    context_limit = model.config.n_positions
    new_token_ids = []

    for _ in range(max_new_tokens):
        # no_grad is scoped to each step rather than wrapped around the loop.
        # That keeps the grad-mode change from leaking to the caller while
        # this generator is suspended at `yield`.
        with torch.no_grad():
            logits = model(input_tokens[:, -context_limit:]).logits
            probabilities = torch.softmax(logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probabilities, 1)

        input_tokens = torch.cat((input_tokens, next_token), dim=1)
        new_token_ids.append(next_token.item())

        # Decode all generated ids together rather than one token at a time.
        # GPT-2's byte-level BPE can split a multi-byte UTF-8 character across
        # several tokens. Decoding those tokens separately turns each fragment
        # into U+FFFD, whereas a joint decode reassembles the character once
        # all of its bytes have been sampled.
        yield prompt + tokenizer.decode(new_token_ids)

# Create a Gradio interface with a text input and a text output.
# generate_text is a generator, so its partial results are streamed to the
# output box. Gradio 3.x only allows generator functions when the queue is
# enabled; 4.x enables it by default. Calling .queue() explicitly keeps the
# app working on both.
interface = gr.Interface(fn=generate_text, inputs='text', outputs='text', live=False)
interface.queue()
interface.launch()