Ruurd committed on
Commit dbf90db · 1 Parent(s): 719a76f

Use background threading for generation
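The change below swaps a hand-rolled token-by-token decoding loop for the standard transformers streaming pattern: model.generate() runs in a background thread with a TextIteratorStreamer attached, and the caller iterates over the streamer to receive decoded text as it is produced. A minimal standalone sketch of that pattern follows; the prompt, max_new_tokens value, and the direct use of the commit's default model ID are illustrative choices, not code taken from app.py.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative: app.py manages its own current_model / current_tokenizer globals.
model_id = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello!", return_tensors="pt").to(model.device)

# The streamer yields decoded text chunks as generate() emits tokens.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until generation finishes, so it runs in a background
# thread while this thread consumes the streamer.
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=64, do_sample=True, streamer=streamer),
)
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)

thread.join()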

Files changed (1)
  1. app.py +63 -63
app.py CHANGED
@@ -5,41 +5,41 @@ import gradio as gr
 import spaces
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
-# @spaces.GPU
-# def chat_with_model(messages):
-#     global current_model, current_tokenizer
-#     if current_model is None or current_tokenizer is None:
-#         yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
-#         return
+@spaces.GPU
+def chat_with_model(messages):
+    global current_model, current_tokenizer
+    if current_model is None or current_tokenizer is None:
+        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
+        return
 
-#     current_model.to("cuda").half()
+    current_model.to("cuda").half()
 
-#     prompt = format_prompt(messages)
-#     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+    prompt = format_prompt(messages)
+    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
 
-#     streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
-#     generation_kwargs = dict(
-#         **inputs,
-#         max_new_tokens=256,
-#         do_sample=True,
-#         streamer=streamer
-#     )
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
+    generation_kwargs = dict(
+        **inputs,
+        max_new_tokens=256,
+        do_sample=True,
+        streamer=streamer
+    )
 
-#     # Launch generation in a background thread
-#     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
-#     thread.start()
+    # Launch generation in a background thread
+    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
+    thread.start()
 
-#     output_text = ""
-#     messages = messages.copy()
-#     messages.append({"role": "assistant", "content": ""})
+    output_text = ""
+    messages = messages.copy()
+    messages.append({"role": "assistant", "content": ""})
 
-#     for new_text in streamer:
-#         output_text += new_text
-#         messages[-1]["content"] = output_text
-#         yield messages
+    for new_text in streamer:
+        output_text += new_text
+        messages[-1]["content"] = output_text
+        yield messages
 
-#     current_model.to("cpu")
-#     torch.cuda.empty_cache()
+    current_model.to("cpu")
+    torch.cuda.empty_cache()
 
 
 # Globals
@@ -92,41 +92,41 @@ with gr.Blocks() as demo:
 
     default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")
 
-    @spaces.GPU
-    def chat_with_model(messages):
-        global current_model, current_tokenizer
-        if current_model is None or current_tokenizer is None:
-            yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
-            return
-
-        current_model = current_model.to("cuda").half()
-
-        prompt = format_prompt(messages)
-        inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
-
-        output_ids = []
-        messages = messages.copy()
-        messages.append({"role": "assistant", "content": ""})
-
-        for token_id in current_model.generate(
-            **inputs,
-            max_new_tokens=256,
-            do_sample=True,
-            return_dict_in_generate=True,
-            output_scores=False
-        ).sequences[0][inputs['input_ids'].shape[-1]:]: # skip input tokens
-            output_ids.append(token_id.item())
-            decoded = current_tokenizer.decode(output_ids, skip_special_tokens=False)
-            if output_ids[-1] == current_tokenizer.eos_token_id:
-                current_model.to("cpu")
-                torch.cuda.empty_cache()
-                return
-            messages[-1]["content"] = decoded
-            yield messages
-
-        current_model.to("cpu")
-        torch.cuda.empty_cache()
-        return
+    # @spaces.GPU
+    # def chat_with_model(messages):
+    #     global current_model, current_tokenizer
+    #     if current_model is None or current_tokenizer is None:
+    #         yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
+    #         return
+
+    #     current_model = current_model.to("cuda").half()
+
+    #     prompt = format_prompt(messages)
+    #     inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+
+    #     output_ids = []
+    #     messages = messages.copy()
+    #     messages.append({"role": "assistant", "content": ""})
+
+    #     for token_id in current_model.generate(
+    #         **inputs,
+    #         max_new_tokens=256,
+    #         do_sample=True,
+    #         return_dict_in_generate=True,
+    #         output_scores=False
+    #     ).sequences[0][inputs['input_ids'].shape[-1]:]: # skip input tokens
+    #         output_ids.append(token_id.item())
+    #         decoded = current_tokenizer.decode(output_ids, skip_special_tokens=False)
+    #         if output_ids[-1] == current_tokenizer.eos_token_id:
+    #             current_model.to("cpu")
+    #             torch.cuda.empty_cache()
+    #             return
+    #         messages[-1]["content"] = decoded
+    #         yield messages
+
+    #     current_model.to("cpu")
+    #     torch.cuda.empty_cache()
+    #     return
 
     with gr.Row():
         model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
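The diff does not show how chat_with_model is attached to the UI. Because the new function is a generator that yields the full messages list after every streamed chunk, it would typically be bound to a Gradio event whose output is the Chatbot, so each yield re-renders the conversation. A hedged sketch of such wiring; the component and handler names here are assumptions, not names taken from app.py:

    # Hypothetical wiring inside the existing `with gr.Blocks() as demo:` block.
    chatbot = gr.Chatbot(type="messages")   # expects {"role": ..., "content": ...} dicts
    user_box = gr.Textbox(label="Message")

    def add_user_message(text, history):
        # Append the user turn, then clear the textbox.
        return "", history + [{"role": "user", "content": text}]

    user_box.submit(
        add_user_message, [user_box, chatbot], [user_box, chatbot]
    ).then(
        chat_with_model, chatbot, chatbot  # generator: every yield updates the Chatbot
    )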