Implement richtextiteratorstreamer
app.py CHANGED
@@ -9,6 +9,34 @@ import threading
 from transformers import TextIteratorStreamer
 import threading
 
+from transformers import TextIteratorStreamer
+import queue
+
+class RichTextStreamer(TextIteratorStreamer):
+    def __init__(self, tokenizer, **kwargs):
+        super().__init__(tokenizer, **kwargs)
+        self.token_queue = queue.Queue()
+
+    def put(self, value):
+        # Instead of just decoding here, we emit full info per token
+        token_id = value.item() if hasattr(value, "item") else value
+        token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
+        is_special = token_id in self.tokenizer.all_special_ids
+        self.token_queue.put({
+            "token_id": token_id,
+            "token": token_str,
+            "is_special": is_special
+        })
+
+    def __iter__(self):
+        while True:
+            try:
+                token_info = self.token_queue.get(timeout=self.timeout)
+                yield token_info
+            except queue.Empty:
+                if self.end_of_generation.is_set():
+                    break
+
 
 @spaces.GPU
 def chat_with_model(messages):
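A note on the streamer contract this class plugs into: in current transformers releases, generate() first calls put() with the prompt ids, then with one tensor of newly generated ids per decoding step, and signals completion by calling end(); the stock TextIteratorStreamer terminates iteration with a stop sentinel rather than exposing an end_of_generation event. Below is a minimal end-aware sketch under those assumptions; the class name and the sentinel are illustrative and not part of this commit.

import queue
from transformers import TextIteratorStreamer

class EndAwareRichTextStreamer(TextIteratorStreamer):
    """Sketch of a variant that honors skip_prompt and terminates cleanly on end()."""

    def __init__(self, tokenizer, skip_prompt=True, timeout=None, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs)
        self.token_queue = queue.Queue()
        self._stop = object()  # sentinel pushed by end() to mark end of generation

    def put(self, value):
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False  # the first put() carries the prompt ids
            return
        for token_id in value.reshape(-1).tolist():  # flatten whatever tensor generate() pushed
            self.token_queue.put({
                "token_id": token_id,
                "token": self.tokenizer.decode([token_id], **self.decode_kwargs),
                "is_special": token_id in self.tokenizer.all_special_ids,
            })

    def end(self):
        self.token_queue.put(self._stop)

    def __iter__(self):
        while True:
            item = self.token_queue.get(timeout=self.timeout)
            if item is self._stop:
                break
            yield item

A streamer like this is consumed exactly like the committed class, so the loop in chat_with_model would not need to change.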
@@ -31,7 +59,9 @@ def chat_with_model(messages):
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    # streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+
 
     generation_kwargs = dict(
         **inputs,
@@ -49,16 +79,16 @@ def chat_with_model(messages):
     messages = messages.copy()
     messages.append({"role": "assistant", "content": ""})
 
-    for
-
-
-
-        messages[-1]["content"] = output_text
-        yield messages
-        break
+    for token_info in streamer:
+        token_str = token_info["token"]
+        is_special = token_info["is_special"]
+        output_text += token_str
         messages[-1]["content"] = output_text
         yield messages
 
+        if is_special and token_info["token_id"] == current_tokenizer.eos_token_id:
+            break
+
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
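For context, a rough end-to-end sketch of how the pieces fit together once the streamer is in place: generation runs on a worker thread (app.py already imports threading) while the handler drains the per-token queue. The stream_chat name and the max_new_tokens value are stand-ins; current_model, current_tokenizer, inputs, and messages are the objects the surrounding code already holds.

import threading

def stream_chat(current_model, current_tokenizer, inputs, messages):
    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
    generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=512)  # settings assumed

    # generate() blocks, so it runs on a worker thread while this generator drains the queue
    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
    thread.start()

    output_text = ""
    messages = messages.copy()
    messages.append({"role": "assistant", "content": ""})
    for token_info in streamer:
        output_text += token_info["token"]
        messages[-1]["content"] = output_text
        yield messages  # the UI receives the growing assistant message each step
        if token_info["is_special"] and token_info["token_id"] == current_tokenizer.eos_token_id:
            break
    thread.join()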