aaabiao committed
Commit 4cbf1bf · verified · 1 Parent(s): 6b3c22e

Update app.py

Files changed (1):
  app.py +18 -16
app.py CHANGED

@@ -5,7 +5,7 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteriaSub
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -25,7 +25,6 @@ def generate(
     temperature: float = 0.7,
     top_p: float = 1.0,
     repetition_penalty: float = 1.1,
-    stop_ids: list[list[int]] = [[2, 6, 7, 8]]
 ) -> Iterator[str]:
     conversation = []
     if system_prompt:
@@ -40,18 +39,21 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
+    stop_words = ["</s>"]
+    stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
+    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-        stop_ids=stop_ids
-    )
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "top_p": top_p,
+        "temperature": temperature,
+        "stopping_criteria": stopping_criteria,
+        "repetition_penalty": repetition_penalty,
+    }
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
@@ -63,7 +65,7 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=8),  # Increase lines for larger box
+        gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -93,7 +95,7 @@ chat_interface = gr.ChatInterface(
             value=1.1,
         ),
     ],
-    stop_token=None,  # Remove stop token
+    stop_words=stop_words,  # Set the stop words
     examples=[
         ["Hello there! How are you doing?"],
         ["Can you explain briefly to me what is the Python programming language?"],
@@ -107,4 +109,4 @@ with gr.Blocks(css="style.css") as demo:
     chat_interface.render()
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
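
A caveat for anyone reusing this change: transformers exports StoppingCriteria and StoppingCriteriaList, but no StoppingCriteriaSub, so the new import line fails with an ImportError as written. StoppingCriteriaSub is a conventional name for a user-defined subclass, which this commit does not add; a minimal sketch, assuming stops is the list of token-ID tensors built as stop_words_ids above:

import torch
from transformers import StoppingCriteria

class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any configured stop sequence ends the output."""

    def __init__(self, stops=None):
        super().__init__()
        # reshape(-1) tolerates the 0-d tensor that .squeeze() produces when a
        # stop word encodes to a single token, as "</s>" typically does.
        self.stops = [stop.reshape(-1) for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Return True as soon as the tail of the generated ids matches any stop sequence.
        return any(
            input_ids.shape[1] >= len(stop)
            and torch.equal(input_ids[0, -len(stop):], stop.to(input_ids.device))
            for stop in self.stops
        )

With a class like this defined in app.py, only StoppingCriteriaList needs to come from transformers, and the "stopping_criteria" entry in generate_kwargs behaves as intended: model.generate evaluates the criteria after each decoding step and halts when one returns True.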
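Separately, the new stop_words=stop_words keyword on gr.ChatInterface looks problematic on two counts: stop_words is local to generate, so it is not in scope at the ChatInterface call, and ChatInterface does not appear to accept a stop_words parameter any more than the stop_token it replaces. Since stopping is already enforced inside generate via stopping_criteria, dropping that keyword entirely seems safe.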