torVik committed
Commit 655cd27 · verified · 1 Parent(s): ebddce7

Update app.py

Files changed (1):
  1. app.py  +250 -75
app.py CHANGED
@@ -9,21 +9,39 @@ import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

+ # Debugging: Start script
+ print("Starting script...")
+
HF_TOKEN = os.environ.get("HF_TOKEN")
+ if HF_TOKEN is None:
+     print("Warning: HF_TOKEN is not set!")

DESCRIPTION = "# Mistral-7B v0.2"

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+     print("Warning: No GPU available. This model cannot run on CPU.")
+ else:
+     print("GPU is available!")

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ # Debugging: GPU check passed, loading model
if torch.cuda.is_available():
    model_id = "mistralai/Mistral-7B-Instruct-v0.2"
-     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
-     tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+     try:
+         print("Loading model...")
+         model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
+         print("Model loaded successfully!")
+
+         print("Loading tokenizer...")
+         tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+         print("Tokenizer loaded successfully!")
+     except Exception as e:
+         print(f"Error loading model or tokenizer: {e}")
+         raise e  # Re-raise the error after logging it


@spaces.GPU
@@ -36,36 +54,54 @@ def generate(
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
+     print(f"Received message: {message}")
+     print(f"Chat history: {chat_history}")
+
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-     input_ids = input_ids.to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
-     t = Thread(target=model.generate, kwargs=generate_kwargs)
-     t.start()
-
-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
+     try:
+         print("Tokenizing input...")
+         input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+         print(f"Input tokenized: {input_ids.shape}")
+
+         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+             print("Trimmed input tokens due to length.")
+
+         input_ids = input_ids.to(model.device)
+         print("Input moved to the model's device.")
+
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+         generate_kwargs = dict(
+             {"input_ids": input_ids},
+             streamer=streamer,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             top_p=top_p,
+             top_k=top_k,
+             temperature=temperature,
+             num_beams=1,
+             repetition_penalty=repetition_penalty,
+         )
+
+         print("Starting generation...")
+         t = Thread(target=model.generate, kwargs=generate_kwargs)
+         t.start()
+         print("Thread started for model generation.")
+
+         outputs = []
+         for text in streamer:
+             outputs.append(text)
+             print(f"Generated text so far: {''.join(outputs)}")
+             yield "".join(outputs)
+
+     except Exception as e:
+         print(f"Error during generation: {e}")
+         raise e  # Re-raise the error after logging it


chat_interface = gr.ChatInterface(
@@ -117,6 +153,9 @@ chat_interface = gr.ChatInterface(
    ],
)

+ # Debugging: Interface setup
+ print("Setting up interface...")
+
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
@@ -126,57 +165,193 @@ with gr.Blocks(css="style.css") as demo:
    )
    chat_interface.render()

+ # Debugging: Starting queue and launching the demo
+ print("Launching demo...")
+
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)

- gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     stop_btn=None,
-     examples=[
-         ["Hello there! How are you doing?"],
-         ["Can you explain briefly to me what is the Python programming language?"],
-         ["Explain the plot of Cinderella in a sentence."],
-         ["How many hours does it take a man to eat a Helicopter?"],
-         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-     ],
- ).launch(share=True)
+
+
+ #!/usr/bin/env python
+
+ # import os
+ # from threading import Thread
+ # from typing import Iterator
+
+ # import gradio as gr
+ # import spaces
+ # import torch
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ # HF_TOKEN = os.environ.get("HF_TOKEN")
+
+ # DESCRIPTION = "# Mistral-7B v0.2"
+
+ # if not torch.cuda.is_available():
+ #     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+ # MAX_MAX_NEW_TOKENS = 2048
+ # DEFAULT_MAX_NEW_TOKENS = 1024
+ # MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ # if torch.cuda.is_available():
+ #     model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+ #     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN)
+ #     tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+
+
+ # @spaces.GPU
+ # def generate(
+ #     message: str,
+ #     chat_history: list[tuple[str, str]],
+ #     max_new_tokens: int = 1024,
+ #     temperature: float = 0.6,
+ #     top_p: float = 0.9,
+ #     top_k: int = 50,
+ #     repetition_penalty: float = 1.2,
+ # ) -> Iterator[str]:
+ #     conversation = []
+ #     for user, assistant in chat_history:
+ #         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+ #     conversation.append({"role": "user", "content": message})
+
+ #     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+ #     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+ #         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+ #         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+ #     input_ids = input_ids.to(model.device)
+
+ #     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+ #     generate_kwargs = dict(
+ #         {"input_ids": input_ids},
+ #         streamer=streamer,
+ #         max_new_tokens=max_new_tokens,
+ #         do_sample=True,
+ #         top_p=top_p,
+ #         top_k=top_k,
+ #         temperature=temperature,
+ #         num_beams=1,
+ #         repetition_penalty=repetition_penalty,
+ #     )
+ #     t = Thread(target=model.generate, kwargs=generate_kwargs)
+ #     t.start()
+
+ #     outputs = []
+ #     for text in streamer:
+ #         outputs.append(text)
+ #         yield "".join(outputs)
+
+
+ # chat_interface = gr.ChatInterface(
+ #     fn=generate,
+ #     additional_inputs=[
+ #         gr.Slider(
+ #             label="Max new tokens",
+ #             minimum=1,
+ #             maximum=MAX_MAX_NEW_TOKENS,
+ #             step=1,
+ #             value=DEFAULT_MAX_NEW_TOKENS,
+ #         ),
+ #         gr.Slider(
+ #             label="Temperature",
+ #             minimum=0.1,
+ #             maximum=4.0,
+ #             step=0.1,
+ #             value=0.6,
+ #         ),
+ #         gr.Slider(
+ #             label="Top-p (nucleus sampling)",
+ #             minimum=0.05,
+ #             maximum=1.0,
+ #             step=0.05,
+ #             value=0.9,
+ #         ),
+ #         gr.Slider(
+ #             label="Top-k",
+ #             minimum=1,
+ #             maximum=1000,
+ #             step=1,
+ #             value=50,
+ #         ),
+ #         gr.Slider(
+ #             label="Repetition penalty",
+ #             minimum=1.0,
+ #             maximum=2.0,
+ #             step=0.05,
+ #             value=1.2,
+ #         ),
+ #     ],
+ #     stop_btn=None,
+ #     examples=[
+ #         ["Hello there! How are you doing?"],
+ #         ["Can you explain briefly to me what is the Python programming language?"],
+ #         ["Explain the plot of Cinderella in a sentence."],
+ #         ["How many hours does it take a man to eat a Helicopter?"],
+ #         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+ #     ],
+ # )
+
+ # with gr.Blocks(css="style.css") as demo:
+ #     gr.Markdown(DESCRIPTION)
+ #     gr.DuplicateButton(
+ #         value="Duplicate Space for private use",
+ #         elem_id="duplicate-button",
+ #         visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+ #     )
+ #     chat_interface.render()
+
+ # if __name__ == "__main__":
+ #     demo.queue(max_size=20).launch(share=True)
+
+ # gr.ChatInterface(
+ #     fn=generate,
+ #     additional_inputs=[
+ #         gr.Slider(
+ #             label="Max new tokens",
+ #             minimum=1,
+ #             maximum=MAX_MAX_NEW_TOKENS,
+ #             step=1,
+ #             value=DEFAULT_MAX_NEW_TOKENS,
+ #         ),
+ #         gr.Slider(
+ #             label="Temperature",
+ #             minimum=0.1,
+ #             maximum=4.0,
+ #             step=0.1,
+ #             value=0.6,
+ #         ),
+ #         gr.Slider(
+ #             label="Top-p (nucleus sampling)",
+ #             minimum=0.05,
+ #             maximum=1.0,
+ #             step=0.05,
+ #             value=0.9,
+ #         ),
+ #         gr.Slider(
+ #             label="Top-k",
+ #             minimum=1,
+ #             maximum=1000,
+ #             step=1,
+ #             value=50,
+ #         ),
+ #         gr.Slider(
+ #             label="Repetition penalty",
+ #             minimum=1.0,
+ #             maximum=2.0,
+ #             step=0.05,
+ #             value=1.2,
+ #         ),
+ #     ],
+ #     stop_btn=None,
+ #     examples=[
+ #         ["Hello there! How are you doing?"],
+ #         ["Can you explain briefly to me what is the Python programming language?"],
+ #         ["Explain the plot of Cinderella in a sentence."],
+ #         ["How many hours does it take a man to eat a Helicopter?"],
+ #         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+ #     ],
+ # ).launch(share=True)

