Ruurd committed
Commit 719a76f · 1 Parent(s): da4880c

Remove one chat_with_model function

Files changed (1):
  1. app.py +39 -2
app.py CHANGED

@@ -3,7 +3,44 @@ import torch
 import time
 import gradio as gr
 import spaces
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+
+# @spaces.GPU
+# def chat_with_model(messages):
+#     global current_model, current_tokenizer
+#     if current_model is None or current_tokenizer is None:
+#         yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
+#         return
+
+#     current_model.to("cuda").half()
+
+#     prompt = format_prompt(messages)
+#     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+
+#     streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+#     generation_kwargs = dict(
+#         **inputs,
+#         max_new_tokens=256,
+#         do_sample=True,
+#         streamer=streamer
+#     )
+
+#     # Launch generation in a background thread
+#     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
+#     thread.start()
+
+#     output_text = ""
+#     messages = messages.copy()
+#     messages.append({"role": "assistant", "content": ""})
+
+#     for new_text in streamer:
+#         output_text += new_text
+#         messages[-1]["content"] = output_text
+#         yield messages
+
+#     current_model.to("cpu")
+#     torch.cuda.empty_cache()
+
 
 # Globals
 current_model = None
@@ -62,7 +99,7 @@ with gr.Blocks() as demo:
            yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
            return

-        current_model = current_model.half().to("cuda")
+        current_model = current_model.to("cuda").half()

        prompt = format_prompt(messages)
        inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
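
The commented-out chat_with_model above relies on threading (not visible among the shown imports) together with TextIteratorStreamer to stream generated tokens into the chat history. Below is a minimal, self-contained sketch of that pattern, not the Space's actual code: the "sshleifer/tiny-gpt2" checkpoint, the simplified format_prompt, and the stream_chat name are stand-ins for illustration.

# Sketch only: mirrors the streaming pattern from the commented-out chat_with_model.
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "sshleifer/tiny-gpt2"  # placeholder checkpoint, not the model this Space loads

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Same move as the updated line in the diff: GPU + fp16, but only when CUDA is present.
if torch.cuda.is_available():
    model = model.to("cuda").half()


def format_prompt(messages):
    # Simplified stand-in for app.py's format_prompt.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages) + "\nassistant:"


def stream_chat(messages):
    prompt = format_prompt(messages)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # skip_prompt=True keeps the echoed prompt out of the streamed output.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=256, do_sample=True, streamer=streamer)

    # generate() blocks, so it runs in a background thread while this generator
    # consumes partial text from the streamer and yields updated chat histories.
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    history = messages + [{"role": "assistant", "content": ""}]
    for new_text in streamer:
        history[-1]["content"] += new_text
        yield history
    thread.join()


for history in stream_chat([{"role": "user", "content": "Hello!"}]):
    print(history[-1]["content"])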
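Running generate() in a worker thread is what makes the incremental updates possible: TextIteratorStreamer blocks on each iteration until the generating thread produces more text, so every partial history can be handed straight to a Gradio Chatbot for live rendering. The commented-out version additionally moved the model back to CPU and called torch.cuda.empty_cache() after streaming, which frees GPU memory between requests; the @spaces.GPU decorator is presumably there for Hugging Face Spaces' on-demand GPU allocation.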