Adrien Dor commited on
Commit
f3304f4
Β·
1 Parent(s): a60081f

new space test

Browse files
Files changed (2) hide show
  1. app.py +1 -273
  2. model.py +0 -76
app.py CHANGED
@@ -1,275 +1,3 @@
1
- from typing import Iterator
2
-
3
  import gradio as gr
4
- import torch
5
-
6
- from model import get_input_token_length, run
7
-
8
- DEFAULT_SYSTEM_PROMPT = """\
9
- You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\
10
- """
11
- MAX_MAX_NEW_TOKENS = 2048
12
- DEFAULT_MAX_NEW_TOKENS = 1024
13
- MAX_INPUT_TOKEN_LENGTH = 4000
14
-
15
- DESCRIPTION = """
16
- # Llama-2 13B Chat
17
- This Space demonstrates model [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta, a Llama 2 model with 13B parameters fine-tuned for chat instructions. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
18
- πŸ”Ž For more details about the Llama 2 family of models and how to use them with `transformers`, take a look [at our blog post](https://huggingface.co/blog/llama2).
19
- πŸ”¨ Looking for an even more powerful model? Check out the large [**70B** model demo](https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI).
20
- πŸ‡ For a smaller model that you can run on many GPUs, check our [7B model demo](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat).
21
- """
22
-
23
- LICENSE = """
24
- <p/>
25
- ---
26
- As a derivate work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
27
- this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
28
- """
29
-
30
- if not torch.cuda.is_available():
31
- DESCRIPTION += '\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>'
32
-
33
-
34
- def clear_and_save_textbox(message: str) -> tuple[str, str]:
35
- return '', message
36
-
37
-
38
- def display_input(message: str,
39
- history: list[tuple[str, str]]) -> list[tuple[str, str]]:
40
- history.append((message, ''))
41
- return history
42
-
43
-
44
- def delete_prev_fn(
45
- history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
46
- try:
47
- message, _ = history.pop()
48
- except IndexError:
49
- message = ''
50
- return history, message or ''
51
-
52
-
53
- def generate(
54
- message: str,
55
- history_with_input: list[tuple[str, str]],
56
- system_prompt: str,
57
- max_new_tokens: int,
58
- temperature: float,
59
- top_p: float,
60
- top_k: int,
61
- ) -> Iterator[list[tuple[str, str]]]:
62
- if max_new_tokens > MAX_MAX_NEW_TOKENS:
63
- raise ValueError
64
-
65
- history = history_with_input[:-1]
66
- generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
67
- try:
68
- first_response = next(generator)
69
- yield history + [(message, first_response)]
70
- except StopIteration:
71
- yield history + [(message, '')]
72
- for response in generator:
73
- yield history + [(message, response)]
74
-
75
-
76
- def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
77
- generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
78
- for x in generator:
79
- pass
80
- return '', x
81
-
82
-
83
- def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
84
- input_token_length = get_input_token_length(message, chat_history, system_prompt)
85
- if input_token_length > MAX_INPUT_TOKEN_LENGTH:
86
- raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
87
-
88
-
89
- with gr.Blocks(css='style.css') as demo:
90
- gr.Markdown(DESCRIPTION)
91
- gr.DuplicateButton(value='Duplicate Space for private use',
92
- elem_id='duplicate-button')
93
-
94
- with gr.Group():
95
- chatbot = gr.Chatbot(label='Chatbot')
96
- with gr.Row():
97
- textbox = gr.Textbox(
98
- container=False,
99
- show_label=False,
100
- placeholder='Type a message...',
101
- scale=10,
102
- )
103
- submit_button = gr.Button('Submit',
104
- variant='primary',
105
- scale=1,
106
- min_width=0)
107
- with gr.Row():
108
- retry_button = gr.Button('πŸ”„ Retry', variant='secondary')
109
- undo_button = gr.Button('↩️ Undo', variant='secondary')
110
- clear_button = gr.Button('πŸ—‘οΈ Clear', variant='secondary')
111
-
112
- saved_input = gr.State()
113
-
114
- with gr.Accordion(label='Advanced options', open=False):
115
- system_prompt = gr.Textbox(label='System prompt',
116
- value=DEFAULT_SYSTEM_PROMPT,
117
- lines=6)
118
- max_new_tokens = gr.Slider(
119
- label='Max new tokens',
120
- minimum=1,
121
- maximum=MAX_MAX_NEW_TOKENS,
122
- step=1,
123
- value=DEFAULT_MAX_NEW_TOKENS,
124
- )
125
- temperature = gr.Slider(
126
- label='Temperature',
127
- minimum=0.1,
128
- maximum=4.0,
129
- step=0.1,
130
- value=1.0,
131
- )
132
- top_p = gr.Slider(
133
- label='Top-p (nucleus sampling)',
134
- minimum=0.05,
135
- maximum=1.0,
136
- step=0.05,
137
- value=0.95,
138
- )
139
- top_k = gr.Slider(
140
- label='Top-k',
141
- minimum=1,
142
- maximum=1000,
143
- step=1,
144
- value=50,
145
- )
146
-
147
- gr.Examples(
148
- examples=[
149
- 'Hello there! How are you doing?',
150
- 'Can you explain briefly to me what is the Python programming language?',
151
- 'Explain the plot of Cinderella in a sentence.',
152
- 'How many hours does it take a man to eat a Helicopter?',
153
- "Write a 100-word article on 'Benefits of Open-Source in AI research'",
154
- ],
155
- inputs=textbox,
156
- outputs=[textbox, chatbot],
157
- fn=process_example,
158
- cache_examples=True,
159
- )
160
-
161
- gr.Markdown(LICENSE)
162
-
163
- textbox.submit(
164
- fn=clear_and_save_textbox,
165
- inputs=textbox,
166
- outputs=[textbox, saved_input],
167
- api_name=False,
168
- queue=False,
169
- ).then(
170
- fn=display_input,
171
- inputs=[saved_input, chatbot],
172
- outputs=chatbot,
173
- api_name=False,
174
- queue=False,
175
- ).then(
176
- fn=check_input_token_length,
177
- inputs=[saved_input, chatbot, system_prompt],
178
- api_name=False,
179
- queue=False,
180
- ).success(
181
- fn=generate,
182
- inputs=[
183
- saved_input,
184
- chatbot,
185
- system_prompt,
186
- max_new_tokens,
187
- temperature,
188
- top_p,
189
- top_k,
190
- ],
191
- outputs=chatbot,
192
- api_name=False,
193
- )
194
-
195
- button_event_preprocess = submit_button.click(
196
- fn=clear_and_save_textbox,
197
- inputs=textbox,
198
- outputs=[textbox, saved_input],
199
- api_name=False,
200
- queue=False,
201
- ).then(
202
- fn=display_input,
203
- inputs=[saved_input, chatbot],
204
- outputs=chatbot,
205
- api_name=False,
206
- queue=False,
207
- ).then(
208
- fn=check_input_token_length,
209
- inputs=[saved_input, chatbot, system_prompt],
210
- api_name=False,
211
- queue=False,
212
- ).success(
213
- fn=generate,
214
- inputs=[
215
- saved_input,
216
- chatbot,
217
- system_prompt,
218
- max_new_tokens,
219
- temperature,
220
- top_p,
221
- top_k,
222
- ],
223
- outputs=chatbot,
224
- api_name=False,
225
- )
226
-
227
- retry_button.click(
228
- fn=delete_prev_fn,
229
- inputs=chatbot,
230
- outputs=[chatbot, saved_input],
231
- api_name=False,
232
- queue=False,
233
- ).then(
234
- fn=display_input,
235
- inputs=[saved_input, chatbot],
236
- outputs=chatbot,
237
- api_name=False,
238
- queue=False,
239
- ).then(
240
- fn=generate,
241
- inputs=[
242
- saved_input,
243
- chatbot,
244
- system_prompt,
245
- max_new_tokens,
246
- temperature,
247
- top_p,
248
- top_k,
249
- ],
250
- outputs=chatbot,
251
- api_name=False,
252
- )
253
-
254
- undo_button.click(
255
- fn=delete_prev_fn,
256
- inputs=chatbot,
257
- outputs=[chatbot, saved_input],
258
- api_name=False,
259
- queue=False,
260
- ).then(
261
- fn=lambda x: x,
262
- inputs=[saved_input],
263
- outputs=textbox,
264
- api_name=False,
265
- queue=False,
266
- )
267
-
268
- clear_button.click(
269
- fn=lambda: ([], ''),
270
- outputs=[chatbot, saved_input],
271
- queue=False,
272
- api_name=False,
273
- )
274
 
275
- demo.queue(max_size=20).launch()
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ gr.Interface.load("models/meta-llama/Llama-2-70b-chat-hf").launch()
model.py CHANGED
@@ -1,76 +0,0 @@
1
- from threading import Thread
2
- from typing import Iterator
3
-
4
- import torch
5
- from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
-
7
- access_token='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'
8
- model_id = 'meta-llama/Llama-2-70b-chat-hf'
9
-
10
- if torch.cuda.is_available():
11
- config = AutoConfig.from_pretrained(model_id)
12
- config.pretraining_tp = 1
13
- model = AutoModelForCausalLM.from_pretrained(
14
- model_id,
15
- config=config,
16
- torch_dtype=torch.float16,
17
- load_in_4bit=True,
18
- device_map='auto',
19
- use_auth_token=access_token
20
- )
21
- else:
22
- model = None
23
- tokenizer = AutoTokenizer.from_pretrained(model_id)
24
- #test
25
-
26
- def get_prompt(message: str, chat_history: list[tuple[str, str]],
27
- system_prompt: str) -> str:
28
- texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
29
- # The first user input is _not_ stripped
30
- do_strip = False
31
- for user_input, response in chat_history:
32
- user_input = user_input.strip() if do_strip else user_input
33
- do_strip = True
34
- texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
35
- message = message.strip() if do_strip else message
36
- texts.append(f'{message} [/INST]')
37
- return ''.join(texts)
38
-
39
-
40
- def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
41
- prompt = get_prompt(message, chat_history, system_prompt)
42
- input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
43
- return input_ids.shape[-1]
44
-
45
-
46
- def run(message: str,
47
- chat_history: list[tuple[str, str]],
48
- system_prompt: str,
49
- max_new_tokens: int = 1024,
50
- temperature: float = 0.8,
51
- top_p: float = 0.95,
52
- top_k: int = 50) -> Iterator[str]:
53
- prompt = get_prompt(message, chat_history, system_prompt)
54
- inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to('cuda')
55
-
56
- streamer = TextIteratorStreamer(tokenizer,
57
- timeout=10.,
58
- skip_prompt=True,
59
- skip_special_tokens=True)
60
- generate_kwargs = dict(
61
- inputs,
62
- streamer=streamer,
63
- max_new_tokens=max_new_tokens,
64
- do_sample=True,
65
- top_p=top_p,
66
- top_k=top_k,
67
- temperature=temperature,
68
- num_beams=1,
69
- )
70
- t = Thread(target=model.generate, kwargs=generate_kwargs)
71
- t.start()
72
-
73
- outputs = []
74
- for text in streamer:
75
- outputs.append(text)
76
- yield ''.join(outputs)