Ruurd committed
Commit d86f9b0 · 1 Parent(s): d6e3337
Files changed (1)
  1. app.py +19 -10
app.py CHANGED
@@ -12,7 +12,6 @@ import threading
 from transformers import TextIteratorStreamer
 import queue
 
-@spaces.GPU
 class RichTextStreamer(TextIteratorStreamer):
     def __init__(self, tokenizer, **kwargs):
         super().__init__(tokenizer, **kwargs)
@@ -54,6 +53,7 @@ def chat_with_model(messages):
         return
 
     pad_id = current_tokenizer.pad_token_id
+    eos_id = current_tokenizer.eos_token_id
     if pad_id is None:
         pad_id = current_tokenizer.unk_token_id or 0
 
@@ -66,35 +66,40 @@ def chat_with_model(messages):
 
     streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
 
+    max_new_tokens = 256
+    generated_tokens = 0
+    output_text = ""
+    in_think = False
+
     generation_kwargs = dict(
         **inputs,
-        max_new_tokens=256,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
         streamer=streamer,
-        eos_token_id=current_tokenizer.eos_token_id,
+        eos_token_id=eos_id,
         pad_token_id=pad_id
     )
 
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    output_text = ""
     messages = messages.copy()
     messages.append({"role": "assistant", "content": ""})
-    in_think = False
 
     for token_info in streamer:
         token_str = token_info["token"]
         token_id = token_info["token_id"]
         is_special = token_info["is_special"]
 
-        if token_id == current_tokenizer.eos_token_id:
-            streamer.end_of_generation.set()  # signal to stop generation thread
+        # Stop immediately at EOS
+        if token_id == eos_id:
             break
 
+        # Optional: skip other special tokens
         if is_special:
             continue
 
+        # Detect reasoning block
         if "<think>" in token_str:
             in_think = True
             token_str = token_str.replace("<think>", "")
@@ -107,8 +112,14 @@ def chat_with_model(messages):
         else:
             output_text += token_str
 
+        # Early stopping if user reappears
         if "\nUser:" in output_text:
            output_text = output_text.split("\nUser:")[0].rstrip()
+            break
+
+        generated_tokens += 1
+        if generated_tokens >= max_new_tokens:
+            break
 
         messages[-1]["content"] = output_text
         yield messages
@@ -118,13 +129,11 @@ def chat_with_model(messages):
     messages[-1]["content"] = output_text
     yield messages
 
-    # Ensure generation thread stops
+    # Wait for thread to finish
     thread.join(timeout=1.0)
-
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
-    return messages
 
 
 # Globals
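
For context on the pattern this commit converges on (a background generate() thread feeding a streamer that the handler consumes as tokens arrive), below is a minimal, self-contained sketch. It is not the Space's code: it assumes the stock TextIteratorStreamer from transformers, which yields decoded text chunks rather than the per-token dicts produced by the custom RichTextStreamer above, and it uses "gpt2" as a stand-in model. The per-token EOS and token-budget checks added in this commit depend on RichTextStreamer and are not reproduced; only the producer thread, the consumer loop, and the early stop when the model starts a new "User:" turn are shown.

# Minimal sketch of the streaming pattern; "gpt2" is a stand-in model and the
# stock TextIteratorStreamer replaces the Space's custom RichTextStreamer.
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("User: Say hello.\nAssistant:", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    streamer=streamer,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 has no dedicated pad token
)

# Producer: generate() runs in a background thread and pushes decoded text into the streamer.
thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Consumer: iterate the streamer as chunks arrive, stopping early if the model
# starts writing a new "User:" turn (mirroring the check added in this commit).
output_text = ""
for chunk in streamer:
    output_text += chunk
    if "\nUser:" in output_text:
        output_text = output_text.split("\nUser:")[0].rstrip()
        break

# The generation thread may still be finishing its token budget; give it a moment.
thread.join(timeout=1.0)
print(output_text)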