skylersterling committed on
Commit
e8fc1ff
·
verified ·
1 Parent(s): 6a64f56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -41
app.py CHANGED
@@ -1,44 +1,97 @@
 
1
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer once at import time so every request reuses them.
model_name = "skylersterling/TopicGPT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def generate_text(context, max_tokens):
    """Greedily generate up to ``max_tokens`` tokens continuing the prompt.

    The context is wrapped as ``#CONTEXT# ... #TOPIC#`` and generation stops
    early when the model emits the first token of ``#TOPIC#``.

    Args:
        context: Free-form input text from the UI.
        max_tokens: Maximum number of new tokens to generate (slider value,
            may arrive as a float from Gradio).

    Returns:
        The full decoded string (prompt plus generated continuation).
    """
    input_text = f"#CONTEXT# {context} #TOPIC#"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # BUG FIX: the original used torch.cat without ever importing torch.
    # Hoist the stop-token id out of the loop — it never changes.
    stop_id = tokenizer.encode("#TOPIC#", add_special_tokens=False)[0]

    generated_ids = input_ids
    # Inference only: no_grad avoids accumulating autograd state per step.
    with torch.no_grad():
        for _ in range(int(max_tokens)):
            outputs = model(generated_ids)
            next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
            generated_ids = torch.cat([generated_ids, next_token_id.unsqueeze(0)], dim=1)
            # .item() compares a plain int, not a 1-element tensor.
            if next_token_id.item() == stop_id:
                break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
25
-
26
# Create Gradio interface
def gradio_interface():
    """Build and launch the Gradio UI for TopicGPT text generation.

    Wires ``generate_text`` to a context textbox and a max-tokens slider.
    Blocks until the server is stopped.
    """
    # gr.inputs / gr.outputs and Slider(default=...) were removed in Gradio 3+;
    # use the top-level components with value= instead.
    context_input = gr.Textbox(lines=5, placeholder="Enter the context here...")
    max_tokens_input = gr.Slider(minimum=1, maximum=200, value=50, step=1)
    output_textbox = gr.Textbox()

    interface = gr.Interface(
        fn=generate_text,
        inputs=[context_input, max_tokens_input],
        outputs=output_textbox,
        title="TopicGPT Text Generation",
        description="Generate text token-by-token using the TopicGPT model. The input should start with #CONTEXT# and end with #TOPIC#."
    )

    interface.launch()


if __name__ == "__main__":
    gradio_interface()
44
-
 
1
from huggingface_hub import InferenceClient
import gradio as gr
import random

# Base inference-API endpoint (not referenced below; kept for reference —
# presumably retained for future direct-URL use, TODO confirm).
API_URL = "https://api-inference.huggingface.co/models/"

# Hosted inference client bound to the TopicGPT model.
client = InferenceClient("skylersterling/TopicGPT")
10
+
11
def format_prompt(message, history):
    """Render chat history plus the new message in the [INST] prompt format.

    Args:
        message: The latest user message.
        history: Iterable of (user_prompt, bot_response) pairs, oldest first.

    Returns:
        A single prompt string: ``<s>`` followed by each past turn as
        ``[INST] user [/INST] bot</s> `` and finally the new message as
        ``[INST] message [/INST]``.
    """
    past_turns = [
        f"[INST] {user_msg} [/INST] {bot_msg}</s> "
        for user_msg, bot_msg in history
    ]
    return "<s>" + "".join(past_turns) + f"[INST] {message} [/INST]"
18
+
19
def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    """Stream a model reply for ``prompt`` via the hosted inference API.

    Yields the accumulated output string after each streamed token, so the
    Gradio ChatInterface can render it incrementally.

    Args:
        prompt: Latest user message.
        history: List of (user, bot) turn pairs from the chat UI.
        temperature: Sampling temperature; clamped to a minimum of 1e-2
            because the API rejects zero.
        max_new_tokens: Token budget for the reply.
        top_p: Nucleus-sampling cutoff.
        repetition_penalty: Penalty applied to repeated tokens.
    """
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        # Gradio sliders deliver floats; the inference API expects ints.
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        # Fresh random seed per request so regeneration gives a new answer.
        seed=random.randint(0, 10**7),
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    # Accumulate streamed tokens, yielding the growing string each time.
    # (The original had a dead `return output` after the loop — a generator's
    # return value is discarded by the ChatInterface consumer.)
    for response in stream:
        output += response.token.text
        yield output
43
+
44
+
45
# Extra generation controls shown under the chat box.
additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# BUG FIX: CSS has no '#' line comments — the original '#'-style comments
# made the rule invalid. Use /* ... */ instead.
customCSS = """
#component-7 { /* default element ID of the chat component */
    height: 800px; /* adjust the height as needed */
    flex-grow: 1;
}
"""

with gr.Blocks(css=customCSS) as demo:
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
    )

demo.queue().launch(debug=True)