Ruurd committed
Commit 80f8fa5 · 1 Parent(s): 6183592

Implement RichTextIteratorStreamer

Files changed (1): app.py (+38 -8)
app.py CHANGED
@@ -9,6 +9,34 @@ import threading
 from transformers import TextIteratorStreamer
 import threading
 
+from transformers import TextIteratorStreamer
+import queue
+
+class RichTextStreamer(TextIteratorStreamer):
+    def __init__(self, tokenizer, **kwargs):
+        super().__init__(tokenizer, **kwargs)
+        self.token_queue = queue.Queue()
+
+    def put(self, value):
+        # Instead of just decoding here, we emit full info per token
+        token_id = value.item() if hasattr(value, "item") else value
+        token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
+        is_special = token_id in self.tokenizer.all_special_ids
+        self.token_queue.put({
+            "token_id": token_id,
+            "token": token_str,
+            "is_special": is_special
+        })
+
+    def __iter__(self):
+        while True:
+            try:
+                token_info = self.token_queue.get(timeout=self.timeout)
+                yield token_info
+            except queue.Empty:
+                if self.end_of_generation.is_set():
+                    break
+
 
 @spaces.GPU
 def chat_with_model(messages):
@@ -31,7 +59,9 @@ def chat_with_model(messages):
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
-    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    # streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+
 
     generation_kwargs = dict(
         **inputs,
@@ -49,16 +79,16 @@ def chat_with_model(messages):
     messages = messages.copy()
     messages.append({"role": "assistant", "content": ""})
 
-    for new_text in streamer:
-        output_text += new_text
-        if "\nUser:" in output_text:
-            output_text = output_text.split("\nUser:")[0].rstrip()
-            messages[-1]["content"] = output_text
-            yield messages
-            break
+    for token_info in streamer:
+        token_str = token_info["token"]
+        is_special = token_info["is_special"]
+        output_text += token_str
         messages[-1]["content"] = output_text
         yield messages
 
+        if is_special and token_info["token_id"] == current_tokenizer.eos_token_id:
+            break
+
     current_model.to("cpu")
     torch.cuda.empty_cache()
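
For context, below is a minimal, self-contained sketch of how a per-token streamer in this style plugs into generate(). Two caveats in the committed class are worth flagging: generate() pushes the entire prompt tensor through put() first, so value.item() raises on a multi-token tensor, and skip_prompt is silently ignored once put() is overridden; also, TextIteratorStreamer defines no end_of_generation attribute, so the committed __iter__ would raise AttributeError if the queue ever timed out. The sketch works around both by reusing the base class's skip_prompt bookkeeping and by signaling completion with a sentinel from end(); the "gpt2" checkpoint and the driver code outside the class are stand-ins, not part of the commit.

import queue
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

class RichTextStreamer(TextIteratorStreamer):
    """Yields one dict per generated token instead of decoded text chunks."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.token_queue = queue.Queue()

    def put(self, value):
        # generate() sends the whole prompt tensor first; overriding put()
        # bypasses the base class's skip_prompt handling, so reproduce it here.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        for token_id in value.reshape(-1).tolist():
            self.token_queue.put({
                "token_id": token_id,
                "token": self.tokenizer.decode([token_id], **self.decode_kwargs),
                "is_special": token_id in self.tokenizer.all_special_ids,
            })

    def end(self):
        # Sentinel in place of the nonexistent end_of_generation event.
        self.token_queue.put(None)

    def __iter__(self):
        while True:
            token_info = self.token_queue.get(timeout=self.timeout)
            if token_info is None:
                break
            yield token_info

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
streamer = RichTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)

inputs = tokenizer("Hello, world", return_tensors="pt")
# generate() blocks, so it runs on a worker thread while the main thread
# consumes per-token dicts, mirroring the loop in chat_with_model().
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20),
)
thread.start()

for token_info in streamer:
    print(token_info["token"], end="", flush=True)
    if token_info["is_special"] and token_info["token_id"] == tokenizer.eos_token_id:
        break
thread.join()

The eos_token_id check at the end is what lets the new loop stop on the model's end-of-sequence token instead of scanning the decoded text for a "\nUser:" marker, which is the behavioral change this commit makes.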