File size: 2,727 Bytes
c95822b
14ac587
db9d4db
 
 
14ac587
db9d4db
 
 
5e2cdce
db9d4db
22fb6e4
ad4d3e1
db9d4db
 
c4dece9
013060a
 
38c9c31
1a13068
403b4a5
013060a
46d003d
b83610f
14ac587
 
 
3aeba0a
 
c4dece9
 
 
 
 
3aeba0a
 
14ac587
 
 
 
57caaab
c4dece9
 
013060a
db9d4db
c4dece9
 
 
 
ccf312e
56cfefc
4a9dd8e
c4dece9
ccf312e
6532f26
d152efa
c4dece9
 
b0a55f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import os

# Optional Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")

# The tokenizer comes from the stock 'gpt2' repository and the weights from the
# 'skylersterling/TopicGPT' repository; both use the same (possibly None) token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', use_auth_token=HF_TOKEN)
model = GPT2LMHeadModel.from_pretrained(
    'skylersterling/TopicGPT',
    use_auth_token=HF_TOKEN,
)
# Inference only: keep the model on the CPU and switch it to evaluation mode.
model.to('cpu')
model.eval()

# Define the function that generates text from a prompt
def generate_text(prompt, temperature, top_p, max_new_tokens=80):
    """Stream a topic for ``prompt`` using temperature + nucleus (top-p) sampling.

    The prompt is wrapped as ``" #CONTEXT# <prompt> #TOPIC# "`` and the model
    continues from there. Generation stops when:

    - the model emits a ``#`` marker token;
    - the model emits the tokenizer's end-of-text token;
    - ``max_new_tokens`` tokens have been sampled; or
    - the model's position limit is reached.

    Args:
        prompt: Conversation text whose topic should be estimated.
        temperature: Divisor applied to the logits before sampling; must be
            > 0 (the UI slider starts at 0.1). Lower values sample more greedily.
        top_p: Nucleus threshold. Only the smallest set of most-likely tokens
            whose cumulative probability exceeds ``top_p`` is kept, and the
            single most likely token is always kept.
        max_new_tokens: Upper bound on the number of sampled tokens.

    Yields:
        The text generated so far, excluding the wrapped prompt. One value is
        yielded per accepted token. Nothing is yielded if the very first
        sampled token is a stop token.
    """
    prompt_with_markers = " #CONTEXT# " + prompt + " #TOPIC# "
    input_tokens = tokenizer.encode(prompt_with_markers, return_tensors='pt').to('cpu')

    # GPT-2 has a fixed number of position embeddings, and a longer input fails
    # inside the model. Keep the tail of an over-long prompt, which still ends
    # with " #TOPIC# ", and cap new tokens so the total stays within the limit.
    # NOTE(review): truncation drops the leading "#CONTEXT#" marker for very
    # long prompts. It trades output quality for not crashing.
    max_positions = model.config.n_positions
    if input_tokens.shape[1] > max_positions - 1:
        input_tokens = input_tokens[:, -(max_positions - 1):]
    max_new_tokens = min(max_new_tokens, max_positions - input_tokens.shape[1])

    generated_ids = []
    past_key_values = None
    next_input = input_tokens

    for _ in range(max_new_tokens):
        # no_grad is scoped to the computation only. Yielding from inside the
        # context would leave autograd disabled in the caller while suspended.
        with torch.no_grad():
            # With the KV cache, only the newest token needs a forward pass
            # instead of re-encoding the whole sequence on every step.
            outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            predictions = outputs.logits[:, -1, :] / temperature

            # Nucleus filtering: mask out tokens outside the top-p probability mass.
            sorted_logits, sorted_indices = torch.sort(predictions, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept, and
            # always keep the most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            to_remove = sorted_to_remove.scatter(1, sorted_indices, sorted_to_remove)
            predictions = predictions.masked_fill(to_remove, -float('Inf'))
            next_token = torch.multinomial(torch.softmax(predictions, dim=-1), 1)

        token_id = next_token.item()

        # "#" opens the next section marker, and GPT-2's BPE usually encodes it
        # with a leading space (" #"), so compare after stripping whitespace.
        # Also stop on the tokenizer's end-of-text token.
        if token_id == tokenizer.eos_token_id or tokenizer.decode(token_id).strip() == "#":
            break

        generated_ids.append(token_id)
        next_input = next_token
        # Decode the whole continuation rather than one token at a time, so
        # characters split across several byte-level BPE tokens render correctly.
        yield tokenizer.decode(generated_ids)

# Build the Gradio UI. Inputs are a prompt textbox plus sliders for the
# sampling temperature and the nucleus (top-p) threshold. They are passed
# positionally to generate_text, whose streamed yields fill the output textbox.
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.3, label="Temperature"),
        gr.Slider(minimum=0.05, maximum=1.0, value=0.3, label="Top-p")
    ],
    outputs=gr.Textbox(),
    live=False,
    description="TopicGPT processes the input and returns a reasonably accurate estimate of the topic/theme of a given conversation."
)

# Only start the (blocking) web server when run as a script, so importing this
# module does not launch it.
if __name__ == "__main__":
    interface.launch()