aaabiao committed on
Commit
f2951f1
·
verified ·
1 Parent(s): c6210fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -18
app.py CHANGED
@@ -5,13 +5,7 @@ from typing import Iterator
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import (
9
- AutoModelForCausalLM,
10
- AutoTokenizer,
11
- StoppingCriteria,
12
- StoppingCriteriaList,
13
- TextIteratorStreamer,
14
- )
15
 
16
  MAX_MAX_NEW_TOKENS = 2048
17
  DEFAULT_MAX_NEW_TOKENS = 1024
@@ -46,19 +40,14 @@ def generate(
46
  input_ids = input_ids.to(model.device)
47
 
48
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
49
-
50
- stop_words = ["</s>"]
51
- stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
52
- stopping_criteria = StoppingCriteriaList([StoppingCriteria(stops=stop_words_ids)])
53
-
54
  generate_kwargs = dict(
55
- input_ids=model_inputs,
56
  streamer=streamer,
57
  max_new_tokens=max_new_tokens,
58
  do_sample=True,
59
  top_p=top_p,
60
  temperature=temperature,
61
- stopping_criteria=stopping_criteria,
62
  repetition_penalty=repetition_penalty,
63
  )
64
  t = Thread(target=model.generate, kwargs=generate_kwargs)
@@ -69,11 +58,10 @@ def generate(
69
  outputs.append(text)
70
  yield "".join(outputs)
71
 
72
- stop_button = gr.Button(text="Stop")
73
  chat_interface = gr.ChatInterface(
74
  fn=generate,
75
  additional_inputs=[
76
- gr.Textbox(label="System prompt", lines=6),
77
  gr.Slider(
78
  label="Max new tokens",
79
  minimum=1,
@@ -103,7 +91,7 @@ chat_interface = gr.ChatInterface(
103
  value=1.1,
104
  ),
105
  ],
106
- stop_btn=stop_button, # Use the created stop button instance
107
  examples=[
108
  ["Hello there! How are you doing?"],
109
  ["Can you explain briefly to me what is the Python programming language?"],
@@ -113,7 +101,6 @@ chat_interface = gr.ChatInterface(
113
  ],
114
  )
115
 
116
-
117
  with gr.Blocks(css="style.css") as demo:
118
  chat_interface.render()
119
 
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
 
 
 
 
9
 
10
  MAX_MAX_NEW_TOKENS = 2048
11
  DEFAULT_MAX_NEW_TOKENS = 1024
 
40
  input_ids = input_ids.to(model.device)
41
 
42
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
43
  generate_kwargs = dict(
44
+ {"input_ids": input_ids},
45
  streamer=streamer,
46
  max_new_tokens=max_new_tokens,
47
  do_sample=True,
48
  top_p=top_p,
49
  temperature=temperature,
50
+ num_beams=1,
51
  repetition_penalty=repetition_penalty,
52
  )
53
  t = Thread(target=model.generate, kwargs=generate_kwargs)
 
58
  outputs.append(text)
59
  yield "".join(outputs)
60
 
 
61
  chat_interface = gr.ChatInterface(
62
  fn=generate,
63
  additional_inputs=[
64
+ gr.Textbox(label="System prompt", lines=6, width=800), # Adjust width here
65
  gr.Slider(
66
  label="Max new tokens",
67
  minimum=1,
 
91
  value=1.1,
92
  ),
93
  ],
94
+ stop_btn=None,
95
  examples=[
96
  ["Hello there! How are you doing?"],
97
  ["Can you explain briefly to me what is the Python programming language?"],
 
101
  ],
102
  )
103
 
 
104
  with gr.Blocks(css="style.css") as demo:
105
  chat_interface.render()
106