WillHeld committed
Commit 403c2fe · verified · 1 Parent(s): f014ce9

Update app.py

Files changed (1):
  1. app.py +11 -6
app.py CHANGED
@@ -11,17 +11,22 @@ model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 def predict(message, history, temperature, top_p):
     history.append({"role": "user", "content": message})
     input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    outputs = model.generate(
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+    streamer = gr.TelegramStreamer() # Use Gradio's built-in streamer
+
+    # Generate with streaming
+    model.generate(
         inputs,
         max_new_tokens=1024,
         temperature=float(temperature),
         top_p=float(top_p),
-        do_sample=True
+        do_sample=True,
+        streamer=streamer
     )
-    decoded = tokenizer.decode(outputs[0])
-    response = decoded.split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1]
-    return response
+
+    # The streamer will handle returning the tokens
+    return streamer
 
 with gr.Blocks() as demo:
     chatbot = gr.ChatInterface(
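
For reference, the streaming pattern usually used with transformers models behind a Gradio ChatInterface goes through transformers.TextIteratorStreamer: generate() runs in a background thread while the predict function yields partial text as tokens arrive. The sketch below is illustrative only and assumes the tokenizer, model, and device objects defined earlier in app.py.

# Sketch only: token streaming via transformers' TextIteratorStreamer, assuming
# the tokenizer, model, and device defined earlier in app.py.
from threading import Thread

from transformers import TextIteratorStreamer

def predict(message, history, temperature, top_p):
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # The streamer yields decoded text as tokens are produced; skip_prompt drops
    # the input prompt so only the new assistant reply is streamed.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until completion, so run it in a background thread and
    # consume the streamer from this generator function.
    Thread(
        target=model.generate,
        kwargs=dict(
            inputs=inputs,
            max_new_tokens=1024,
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            streamer=streamer,
        ),
    ).start()

    # gr.ChatInterface treats a generator as a stream of partial responses.
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial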