aaabiao committed (verified)
Commit 03ce29c · Parent: a42898c

Update app.py

Files changed (1)
  1. app.py +52 -63
app.py CHANGED
@@ -2,8 +2,8 @@ import os
 from threading import Thread
 from typing import Iterator
 
-import spaces
 import gradio as gr
+import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
@@ -16,7 +16,8 @@ if torch.cuda.is_available():
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-def generate_and_display(
+@spaces.GPU
+def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
@@ -24,7 +25,7 @@ def generate_and_display(
     temperature: float = 0.7,
     top_p: float = 1.0,
     repetition_penalty: float = 1.1,
-) -> str:
+) -> Iterator[str]:
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -35,12 +36,12 @@ def generate_and_display(
     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        input_ids=input_ids,
+        {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
@@ -49,71 +50,59 @@ def generate_and_display(
         num_beams=1,
         repetition_penalty=repetition_penalty,
     )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
     outputs = []
-    model_outputs = model.generate(**generate_kwargs)
-    for text in streamer.generate_from_iterator(model_outputs):
+    for text in streamer:
         outputs.append(text)
-    return "".join(outputs)
-
-def generate_response():
-    outputs = generate_and_display(
-        input_textbox.value,
-        chat_history=[],
-        system_prompt=system_prompt_textbox.value,
-        max_new_tokens=max_new_tokens_slider.value,
-        temperature=temperature_slider.value,
-        top_p=top_p_slider.value,
-        repetition_penalty=repetition_penalty_slider.value,
-    )
-    chat_output_textbox.value = outputs
-
-input_textbox = gr.Textbox(label="User Input", lines=5, placeholder="Enter your message...")
-system_prompt_textbox = gr.Textbox(label="System Prompt", lines=5, placeholder="Enter system prompt (optional)...")
-max_new_tokens_slider = gr.Slider(
-    label="Max New Tokens",
-    minimum=1,
-    maximum=MAX_MAX_NEW_TOKENS,
-    step=1,
-    value=DEFAULT_MAX_NEW_TOKENS,
-)
-temperature_slider = gr.Slider(
-    label="Temperature",
-    minimum=0.01,
-    maximum=1.0,
-    step=0.01,
-    value=0.7,
-)
-top_p_slider = gr.Slider(
-    label="Top-p (Nucleus Sampling)",
-    minimum=0.05,
-    maximum=1.0,
-    step=0.01,
-    value=1.0,
-)
-repetition_penalty_slider = gr.Slider(
-    label="Repetition Penalty",
-    minimum=1.0,
-    maximum=2.0,
-    step=0.05,
-    value=1.1,
-)
-generate_button = gr.Button(label="Generate Response", command=generate_response)
-chat_output_textbox = gr.Textbox(label="Chat Output", lines=10)
-
-gr.Interface(
-    generate_and_display,
-    inputs=[input_textbox, system_prompt_textbox, max_new_tokens_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
-    outputs=chat_output_textbox,
-    title="🦣MAmmoTH2",
-    description="A simple web interactive chat demo based on gradio.",
+        yield "".join(outputs)
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Textbox(label="System prompt", lines=6),
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.01,
+            maximum=1.0,
+            step=0.01,
+            value=0.7,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.01,
+            value=1.0,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.1,
+        ),
+    ],
+    stop_btn=None,
     examples=[
         ["Hello there! How are you doing?"],
         ["Can you explain briefly to me what is the Python programming language?"],
         ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a helicopter?"],
+        ["How many hours does it take a man to eat a Helicopter?"],
         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
-    theme="default",
-    live=True,
-).launch()
+)
+
+with gr.Blocks(css="style.css") as demo:
+    chat_interface.render()
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch()
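Note on the change: the previous revision called model.generate(**generate_kwargs) synchronously and then iterated streamer.generate_from_iterator(model_outputs), a method TextIteratorStreamer does not have, so nothing ever streamed. The new code uses the standard transformers streaming pattern: model.generate runs on a background Thread while the handler iterates the streamer and yields the accumulated text, which gr.ChatInterface renders as a progressively growing reply. A minimal self-contained sketch of that pattern, using "gpt2" as a hypothetical stand-in checkpoint (the Space loads its own model_id):

# Minimal sketch of the streaming pattern adopted here (assumes: gpt2 as a
# stand-in checkpoint, CPU execution, no Gradio). model.generate() blocks
# until decoding finishes, so it runs on a background Thread while the
# caller consumes TextIteratorStreamer, which yields decoded text chunks.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # hypothetical stand-in; the Space uses its own model_id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # dict(mapping, **kwargs) merges the input tensor with the sampling
    # options, mirroring the dict({"input_ids": ...}, ...) call in the diff.
    generate_kwargs = dict(
        {"input_ids": input_ids},
        streamer=streamer,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=1.0,
    )
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:           # blocks until the next chunk is ready
        outputs.append(text)
        yield "".join(outputs)      # the full reply so far, not just the delta

for partial in stream_reply("Hello there!"):
    print(partial)

Each yield hands ChatInterface the full reply so far rather than the new chunk, so the UI simply replaces the message text on every update; generator handlers like this are also why the app launches with demo.queue(), since Gradio needs the queue enabled to push incremental events to the browser.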