aaabiao committed on
Commit
30dc26e
·
verified ·
1 Parent(s): cc138c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -20
app.py CHANGED
@@ -3,7 +3,6 @@ from threading import Thread
3
  from typing import Iterator
4
 
5
  import gradio as gr
6
- import spaces
7
  import torch
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
@@ -16,7 +15,6 @@ if torch.cuda.is_available():
16
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
17
  tokenizer = AutoTokenizer.from_pretrained(model_id)
18
 
19
- @spaces.GPU
20
  def generate(
21
  message: str,
22
  chat_history: list[tuple[str, str]],
@@ -41,7 +39,7 @@ def generate(
41
 
42
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
43
  generate_kwargs = dict(
44
- {"input_ids": input_ids},
45
  streamer=streamer,
46
  max_new_tokens=max_new_tokens,
47
  do_sample=True,
@@ -58,51 +56,53 @@ def generate(
58
  outputs.append(text)
59
  yield "".join(outputs)
60
 
61
- chat_interface = gr.ChatInterface(
62
  fn=generate,
63
- additional_inputs=[
64
- gr.Textbox(label="System prompt", lines=6),
 
65
  gr.Slider(
66
- label="Max new tokens",
67
  minimum=1,
68
  maximum=MAX_MAX_NEW_TOKENS,
69
  step=1,
70
- value=DEFAULT_MAX_NEW_TOKENS,
71
  ),
72
  gr.Slider(
73
  label="Temperature",
74
  minimum=0.01,
75
  maximum=1.0,
76
  step=0.01,
77
- value=0.7,
78
  ),
79
  gr.Slider(
80
- label="Top-p (nucleus sampling)",
81
  minimum=0.05,
82
  maximum=1.0,
83
  step=0.01,
84
- value=1.0,
85
  ),
86
  gr.Slider(
87
- label="Repetition penalty",
88
  minimum=1.0,
89
  maximum=2.0,
90
  step=0.05,
91
- value=1.1,
92
  ),
93
  ],
94
- stop_btn=None,
 
 
95
  examples=[
96
  ["Hello there! How are you doing?"],
97
  ["Can you explain briefly to me what is the Python programming language?"],
98
  ["Explain the plot of Cinderella in a sentence."],
99
- ["How many hours does it take a man to eat a Helicopter?"],
100
  ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
101
  ],
 
 
 
102
  )
103
 
104
- with gr.Blocks(css="style.css") as demo:
105
- chat_interface.render()
106
-
107
- if __name__ == "__main__":
108
- demo.queue(max_size=20).launch()
 
3
  from typing import Iterator
4
 
5
  import gradio as gr
 
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
 
 
15
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
17
 
 
18
  def generate(
19
  message: str,
20
  chat_history: list[tuple[str, str]],
 
39
 
40
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
41
  generate_kwargs = dict(
42
+ input_ids=input_ids,
43
  streamer=streamer,
44
  max_new_tokens=max_new_tokens,
45
  do_sample=True,
 
56
  outputs.append(text)
57
  yield "".join(outputs)
58
 
59
def _generate_single_turn(
    message: str,
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
) -> Iterator[str]:
    """Adapt ``generate`` to the flat, positional inputs of ``gr.Interface``.

    ``gr.Interface`` calls its function with one positional argument per
    input component. ``generate`` takes ``chat_history`` as its second
    parameter, so wiring it directly would feed the system prompt into
    ``chat_history`` and shift every later argument by one. This adapter
    supplies an empty history (the form is single-turn) and forwards the rest.

    NOTE(review): the order of the arguments after ``chat_history`` follows
    the order these controls were previously passed in; confirm it against
    the full signature of ``generate``.

    Yields:
        The progressively longer response text produced by ``generate``.
    """
    yield from generate(
        message,
        [],  # no prior turns: this UI has no conversation state
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        repetition_penalty,
    )


# Single-turn form UI: one textbox for the message, one for the system prompt,
# and sliders for the sampling parameters. Slider defaults use ``value=``,
# which is the keyword the top-level ``gr.Slider`` component accepts.
chat_interface = gr.Interface(
    fn=_generate_single_turn,
    inputs=[
        gr.Textbox(label="User Input", lines=5, placeholder="Enter your message..."),
        gr.Textbox(label="System Prompt", lines=5, placeholder="Enter system prompt (optional)..."),
        gr.Slider(
            label="Max New Tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=1.0,
            step=0.01,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (Nucleus Sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.01,
            value=1.0,
        ),
        gr.Slider(
            label="Repetition Penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.1,
        ),
    ],
    outputs=gr.Textbox(label="Chat Output", lines=10),
    title="🦣MAmmoTH2",
    description="A simple web interactive chat demo based on gradio.",
    # Each example fills only the first input (the user message).
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
    # NOTE(review): live=True re-runs the function whenever an input changes,
    # i.e. a new generation per keystroke / slider move — confirm this is wanted.
    live=True,
)

if __name__ == "__main__":
    # The handler is a generator (streams partial text), so the request queue
    # is enabled before launching; max_size bounds the number of waiting jobs.
    chat_interface.queue(max_size=20).launch()