Ruurd committed
Commit b16f2d9 · 1 Parent(s): 2db3bb3

Try to fix chatbot streaming

Files changed (1): app.py (+28 -28)
app.py CHANGED
@@ -39,34 +39,6 @@ def format_prompt(messages):
     prompt += "Assistant:"
     return prompt

-@spaces.GPU
-def chat_with_model(messages):
-    global current_model, current_tokenizer
-    if current_model is None or current_tokenizer is None:
-        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
-        return
-
-    current_model.to("cuda")
-
-    prompt = format_prompt(messages)
-    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
-
-    output_ids = []
-    messages = messages.copy()
-    messages.append({"role": "assistant", "content": ""})
-
-    for token_id in current_model.generate(
-        **inputs,
-        max_new_tokens=256,
-        do_sample=False,
-        return_dict_in_generate=True,
-        output_scores=False
-    ).sequences[0][inputs['input_ids'].shape[-1]:]:  # skip input tokens
-        output_ids.append(token_id.item())
-        decoded = current_tokenizer.decode(output_ids, skip_special_tokens=True)
-        messages[-1]["content"] = decoded
-        yield messages
-
 def add_user_message(user_input, history):
     return "", history + [{"role": "user", "content": user_input}]

@@ -83,6 +55,34 @@ with gr.Blocks() as demo:

     default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")

+    @spaces.GPU
+    def chat_with_model(messages):
+        global current_model, current_tokenizer
+        if current_model is None or current_tokenizer is None:
+            yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
+            return
+
+        current_model.to("cuda")
+
+        prompt = format_prompt(messages)
+        inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+
+        output_ids = []
+        messages = messages.copy()
+        messages.append({"role": "assistant", "content": ""})
+
+        for token_id in current_model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=False,
+            return_dict_in_generate=True,
+            output_scores=False
+        ).sequences[0][inputs['input_ids'].shape[-1]:]:  # skip input tokens
+            output_ids.append(token_id.item())
+            decoded = current_tokenizer.decode(output_ids, skip_special_tokens=True)
+            messages[-1]["content"] = decoded
+            yield messages
+
     with gr.Row():
         model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
         model_status = gr.Textbox(label="Model Status", interactive=False)
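
Note on the streaming behaviour: the loop above iterates over generate(...).sequences[0], which only exists once generate() has finished, so the handler yields partial text only after the full completion is already available. The UI re-renders look like streaming, but the first token cannot appear before generation ends. A common way to get genuine token-by-token streaming with transformers is TextIteratorStreamer, with generate() running on a background thread. The sketch below is a hypothetical alternative, not the code in this commit; it reuses the app's current_model, current_tokenizer, and format_prompt names as assumptions.

# Hypothetical streaming variant of chat_with_model (not part of this commit).
# Assumes the existing globals current_model / current_tokenizer and format_prompt().
from threading import Thread
from transformers import TextIteratorStreamer

@spaces.GPU
def chat_with_model(messages):
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
        return

    current_model.to("cuda")
    prompt = format_prompt(messages)
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    # The streamer receives decoded text pieces as generate() produces tokens.
    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Run generation in a background thread so this generator can consume
    # the stream and yield intermediate chat states to Gradio as they arrive.
    thread = Thread(
        target=current_model.generate,
        kwargs=dict(**inputs, max_new_tokens=256, do_sample=False, streamer=streamer),
    )
    thread.start()

    messages = messages.copy()
    messages.append({"role": "assistant", "content": ""})
    for new_text in streamer:
        messages[-1]["content"] += new_text
        yield messages
    thread.join()

For each yield to reach the UI, the generator also has to be registered as an event handler whose output is the chatbot component. That wiring is not shown in this diff; with a messages-style gr.Chatbot it would typically chain the submit event into the generator, roughly like the following (component names are assumptions):

# Hypothetical wiring inside the gr.Blocks context (not shown in this diff).
chatbot = gr.Chatbot(type="messages")
user_input = gr.Textbox(label="Message")

user_input.submit(
    add_user_message, [user_input, chatbot], [user_input, chatbot]
).then(
    chat_with_model, chatbot, chatbot
)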