Spaces:
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,7 @@ from typing import Iterator
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
import torch
|
8 |
-
from transformers import
|
9 |
-
AutoModelForCausalLM,
|
10 |
-
AutoTokenizer,
|
11 |
-
StoppingCriteria,
|
12 |
-
StoppingCriteriaList,
|
13 |
-
TextIteratorStreamer,
|
14 |
-
)
|
15 |
|
16 |
MAX_MAX_NEW_TOKENS = 2048
|
17 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
@@ -46,19 +40,14 @@ def generate(
|
|
46 |
input_ids = input_ids.to(model.device)
|
47 |
|
48 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
49 |
-
|
50 |
-
stop_words = ["</s>"]
|
51 |
-
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
|
52 |
-
stopping_criteria = StoppingCriteriaList([StoppingCriteria(stops=stop_words_ids)])
|
53 |
-
|
54 |
generate_kwargs = dict(
|
55 |
-
input_ids
|
56 |
streamer=streamer,
|
57 |
max_new_tokens=max_new_tokens,
|
58 |
do_sample=True,
|
59 |
top_p=top_p,
|
60 |
temperature=temperature,
|
61 |
-
|
62 |
repetition_penalty=repetition_penalty,
|
63 |
)
|
64 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
@@ -69,11 +58,10 @@ def generate(
|
|
69 |
outputs.append(text)
|
70 |
yield "".join(outputs)
|
71 |
|
72 |
-
stop_button = gr.Button(text="Stop")
|
73 |
chat_interface = gr.ChatInterface(
|
74 |
fn=generate,
|
75 |
additional_inputs=[
|
76 |
-
gr.Textbox(label="System prompt", lines=6),
|
77 |
gr.Slider(
|
78 |
label="Max new tokens",
|
79 |
minimum=1,
|
@@ -103,7 +91,7 @@ chat_interface = gr.ChatInterface(
|
|
103 |
value=1.1,
|
104 |
),
|
105 |
],
|
106 |
-
stop_btn=
|
107 |
examples=[
|
108 |
["Hello there! How are you doing?"],
|
109 |
["Can you explain briefly to me what is the Python programming language?"],
|
@@ -113,7 +101,6 @@ chat_interface = gr.ChatInterface(
|
|
113 |
],
|
114 |
)
|
115 |
|
116 |
-
|
117 |
with gr.Blocks(css="style.css") as demo:
|
118 |
chat_interface.render()
|
119 |
|
|
|
5 |
import gradio as gr
|
6 |
import spaces
|
7 |
import torch
|
8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
MAX_MAX_NEW_TOKENS = 2048
|
11 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
|
|
40 |
input_ids = input_ids.to(model.device)
|
41 |
|
42 |
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
|
|
|
|
|
|
|
43 |
generate_kwargs = dict(
|
44 |
+
{"input_ids": input_ids},
|
45 |
streamer=streamer,
|
46 |
max_new_tokens=max_new_tokens,
|
47 |
do_sample=True,
|
48 |
top_p=top_p,
|
49 |
temperature=temperature,
|
50 |
+
num_beams=1,
|
51 |
repetition_penalty=repetition_penalty,
|
52 |
)
|
53 |
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
|
|
58 |
outputs.append(text)
|
59 |
yield "".join(outputs)
|
60 |
|
|
|
61 |
chat_interface = gr.ChatInterface(
|
62 |
fn=generate,
|
63 |
additional_inputs=[
|
64 |
+
gr.Textbox(label="System prompt", lines=6, width=800), # Adjust width here
|
65 |
gr.Slider(
|
66 |
label="Max new tokens",
|
67 |
minimum=1,
|
|
|
91 |
value=1.1,
|
92 |
),
|
93 |
],
|
94 |
+
stop_btn=None,
|
95 |
examples=[
|
96 |
["Hello there! How are you doing?"],
|
97 |
["Can you explain briefly to me what is the Python programming language?"],
|
|
|
101 |
],
|
102 |
)
|
103 |
|
|
|
104 |
with gr.Blocks(css="style.css") as demo:
|
105 |
chat_interface.render()
|
106 |
|