suayptalha committed
Commit b5341f2 · verified · 1 Parent(s): a5ec87b

Update app.py

Files changed (1)
  1. app.py +39 -6
app.py CHANGED
@@ -4,6 +4,9 @@ os.system("pip install git+https://github.com/shumingma/transformers.git")
 
 import threading
 import torch
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
+
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -32,6 +35,21 @@ def respond(
     temperature: float,
     top_p: float,
 ):
+    """
+    Generate a chat response using streaming with TextIteratorStreamer.
+
+    Args:
+        message: User's current message.
+        history: List of (user, assistant) tuples from previous turns.
+        system_message: Initial system prompt guiding the assistant.
+        max_tokens: Maximum number of tokens to generate.
+        temperature: Sampling temperature.
+        top_p: Nucleus sampling probability.
+
+    Yields:
+        The growing response text as new tokens are generated.
+    """
+    # Assemble messages
     messages = [{"role": "system", "content": system_message}]
     for user_msg, bot_msg in history:
         if user_msg:
@@ -40,18 +58,33 @@ def respond(
             messages.append({"role": "assistant", "content": bot_msg})
     messages.append({"role": "user", "content": message})
 
-    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    # Prepare prompt and tokenize
+    prompt = tokenizer.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-    outputs = model.generate(
+    # Set up streamer for real-time output
+    streamer = TextIteratorStreamer(
+        tokenizer, skip_prompt=True, skip_special_tokens=True
+    )
+    generate_kwargs = dict(
         **inputs,
+        streamer=streamer,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
-        do_sample=True
+        do_sample=True,
     )
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    yield response
+    # Start generation in a separate thread
+    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    # Stream tokens back to user
+    response = ""
+    for new_text in streamer:
+        response += new_text
+        yield response
 
 # Initialize Gradio chat interface
 
@@ -106,4 +139,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
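
Note: the substantive change in this commit is swapping a single blocking model.generate call for token-by-token streaming. A minimal standalone sketch of that pattern follows; the checkpoint name and generation settings are illustrative assumptions, not values taken from this Space.

# Minimal sketch of streaming generation with TextIteratorStreamer and a background
# thread, the same pattern this commit introduces in respond().
# The model name and sampling settings below are placeholder assumptions.
import threading

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "microsoft/bitnet-b1.58-2B-4T"  # hypothetical checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain token streaming in one sentence."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# skip_prompt drops the echoed input; skip_special_tokens strips markers such as EOS.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=256,
    temperature=0.7,
    top_p=0.95,
    do_sample=True,
)

# model.generate() blocks until generation finishes, so it runs in its own thread;
# iterating the streamer in the main thread yields decoded text as it is produced.
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

response = ""
for new_text in streamer:
    response += new_text
    print(new_text, end="", flush=True)
thread.join()

Accumulating into response and yielding the growing string on each iteration, as the commit does, is what lets gr.ChatInterface repaint the partial reply while generation is still running.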