aaabiao committed on
Commit
a42898c
·
verified ·
1 Parent(s): 4e73b84

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -48
app.py CHANGED
@@ -16,8 +16,7 @@ if torch.cuda.is_available():
16
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
17
  tokenizer = AutoTokenizer.from_pretrained(model_id)
18
 
19
- @spaces.GPU
20
- def generate(
21
  message: str,
22
  chat_history: list[tuple[str, str]],
23
  system_prompt: str,
@@ -25,7 +24,7 @@ def generate(
25
  temperature: float = 0.7,
26
  top_p: float = 1.0,
27
  repetition_penalty: float = 1.1,
28
- ) -> Iterator[str]:
29
  conversation = []
30
  if system_prompt:
31
  conversation.append({"role": "system", "content": system_prompt})
@@ -36,7 +35,7 @@ def generate(
36
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
37
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
38
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
39
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
40
  input_ids = input_ids.to(model.device)
41
 
42
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
@@ -52,48 +51,60 @@ def generate(
52
  )
53
 
54
  outputs = []
55
- with torch.no_grad():
56
- model_outputs = model.generate(**generate_kwargs)
57
- for text in streamer.generate_from_iterator(model_outputs):
58
- outputs.append(text)
59
- yield "".join(outputs)
60
 
61
- chat_interface = gr.Interface(
62
- fn=generate,
63
- inputs=[
64
- gr.Textbox(label="User Input", lines=5, placeholder="Enter your message..."),
65
- gr.Textbox(label="System Prompt", lines=5, placeholder="Enter system prompt (optional)..."),
66
- gr.Slider(
67
- label="Max New Tokens",
68
- minimum=1,
69
- maximum=MAX_MAX_NEW_TOKENS,
70
- step=1,
71
- value=DEFAULT_MAX_NEW_TOKENS,
72
- ),
73
- gr.Slider(
74
- label="Temperature",
75
- minimum=0.01,
76
- maximum=1.0,
77
- step=0.01,
78
- value=0.7,
79
- ),
80
- gr.Slider(
81
- label="Top-p (Nucleus Sampling)",
82
- minimum=0.05,
83
- maximum=1.0,
84
- step=0.01,
85
- value=1.0,
86
- ),
87
- gr.Slider(
88
- label="Repetition Penalty",
89
- minimum=1.0,
90
- maximum=2.0,
91
- step=0.05,
92
- value=1.1,
93
- ),
94
- "generate" # This is a placeholder for the button
95
- ],
96
- outputs=gr.Textbox(label="Chat Output", lines=10),
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  title="🦣MAmmoTH2",
98
  description="A simple web interactive chat demo based on gradio.",
99
  examples=[
@@ -105,6 +116,4 @@ chat_interface = gr.Interface(
105
  ],
106
  theme="default",
107
  live=True,
108
- )
109
-
110
- chat_interface.launch()
 
16
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
17
  tokenizer = AutoTokenizer.from_pretrained(model_id)
18
 
19
+ def generate_and_display(
 
20
  message: str,
21
  chat_history: list[tuple[str, str]],
22
  system_prompt: str,
 
24
  temperature: float = 0.7,
25
  top_p: float = 1.0,
26
  repetition_penalty: float = 1.1,
27
+ ) -> str:
28
  conversation = []
29
  if system_prompt:
30
  conversation.append({"role": "system", "content": system_prompt})
 
35
  input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
36
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
37
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
38
+ gr.warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
39
  input_ids = input_ids.to(model.device)
40
 
41
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
51
  )
52
 
53
  outputs = []
54
+ model_outputs = model.generate(**generate_kwargs)
55
+ for text in streamer.generate_from_iterator(model_outputs):
56
+ outputs.append(text)
57
+ return "".join(outputs)
 
58
 
59
+ def generate_response():
60
+ outputs = generate_and_display(
61
+ input_textbox.value,
62
+ chat_history=[],
63
+ system_prompt=system_prompt_textbox.value,
64
+ max_new_tokens=max_new_tokens_slider.value,
65
+ temperature=temperature_slider.value,
66
+ top_p=top_p_slider.value,
67
+ repetition_penalty=repetition_penalty_slider.value,
68
+ )
69
+ chat_output_textbox.value = outputs
70
+
71
+ input_textbox = gr.Textbox(label="User Input", lines=5, placeholder="Enter your message...")
72
+ system_prompt_textbox = gr.Textbox(label="System Prompt", lines=5, placeholder="Enter system prompt (optional)...")
73
+ max_new_tokens_slider = gr.Slider(
74
+ label="Max New Tokens",
75
+ minimum=1,
76
+ maximum=MAX_MAX_NEW_TOKENS,
77
+ step=1,
78
+ value=DEFAULT_MAX_NEW_TOKENS,
79
+ )
80
+ temperature_slider = gr.Slider(
81
+ label="Temperature",
82
+ minimum=0.01,
83
+ maximum=1.0,
84
+ step=0.01,
85
+ value=0.7,
86
+ )
87
+ top_p_slider = gr.Slider(
88
+ label="Top-p (Nucleus Sampling)",
89
+ minimum=0.05,
90
+ maximum=1.0,
91
+ step=0.01,
92
+ value=1.0,
93
+ )
94
+ repetition_penalty_slider = gr.Slider(
95
+ label="Repetition Penalty",
96
+ minimum=1.0,
97
+ maximum=2.0,
98
+ step=0.05,
99
+ value=1.1,
100
+ )
101
+ generate_button = gr.Button(label="Generate Response", command=generate_response)
102
+ chat_output_textbox = gr.Textbox(label="Chat Output", lines=10)
103
+
104
+ gr.Interface(
105
+ generate_and_display,
106
+ inputs=[input_textbox, system_prompt_textbox, max_new_tokens_slider, temperature_slider, top_p_slider, repetition_penalty_slider],
107
+ outputs=chat_output_textbox,
108
  title="🦣MAmmoTH2",
109
  description="A simple web interactive chat demo based on gradio.",
110
  examples=[
 
116
  ],
117
  theme="default",
118
  live=True,
119
+ ).launch()