Ruurd committed
Commit 205d52f · 1 Parent(s): 80f8fa5

Implement reasoning blocks and fix eos token showing

Files changed (1)
1. app.py +24 -9
app.py CHANGED
@@ -45,8 +45,6 @@ def chat_with_model(messages):
         yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
         return
 
-
-
     pad_id = current_tokenizer.pad_token_id
     if pad_id is None:
         pad_id = current_tokenizer.unk_token_id or 0
@@ -58,11 +56,8 @@ def chat_with_model(messages):
     inputs = current_tokenizer(prompt, return_tensors="pt")
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
-
-    # streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
     streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
 
-
     generation_kwargs = dict(
         **inputs,
         max_new_tokens=256,
@@ -78,22 +73,42 @@ def chat_with_model(messages):
     output_text = ""
     messages = messages.copy()
     messages.append({"role": "assistant", "content": ""})
+    in_think = False
 
     for token_info in streamer:
         token_str = token_info["token"]
+        token_id = token_info["token_id"]
         is_special = token_info["is_special"]
-        output_text += token_str
+
+        # Skip appending the EOS token to output
+        if token_id == current_tokenizer.eos_token_id:
+            break
+
+        # Detect reasoning block
+        if "<think>" in token_str:
+            in_think = True
+            token_str = token_str.replace("<think>", "")
+            output_text += "*"
+
+        if "</think>" in token_str:
+            in_think = False
+            token_str = token_str.replace("</think>", "")
+            output_text += token_str + "*"
+        else:
+            output_text += token_str
+
         messages[-1]["content"] = output_text
         yield messages
 
-        if is_special and token_info["token_id"] == current_tokenizer.eos_token_id:
-            break
+    if in_think:
+        output_text += "*"
+        messages[-1]["content"] = output_text
+        yield messages
 
     current_model.to("cpu")
     torch.cuda.empty_cache()
 
 
-
 # Globals
 current_model = None
 current_tokenizer = None
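
For context, here is a minimal, self-contained sketch of what the new streaming loop does, separate from the app itself: it stops before the EOS token instead of appending it, and it renders any <think>...</think> reasoning span as markdown italics, closing the span if generation ends while still inside it. The render_stream helper, the hard-coded EOS_ID, and the plain token dicts standing in for RichTextStreamer output are illustrative assumptions, not part of the commit.

# Minimal sketch (not the app's actual code) of the reasoning-block handling
# added in this commit, applied to a plain list of token dicts instead of a
# RichTextStreamer. render_stream, EOS_ID, and the sample tokens are assumptions.

EOS_ID = 2  # assumed eos_token_id for this sketch

def render_stream(token_infos):
    """Accumulate streamed tokens, wrapping <think>...</think> spans in markdown italics."""
    output_text = ""
    in_think = False
    for info in token_infos:
        token_str = info["token"]
        # Stop before the EOS token so it never shows up in the chat output
        if info["token_id"] == EOS_ID:
            break
        # Opening tag: drop it and start an italic span
        if "<think>" in token_str:
            in_think = True
            token_str = token_str.replace("<think>", "")
            output_text += "*"
        # Closing tag: drop it and close the italic span
        if "</think>" in token_str:
            in_think = False
            token_str = token_str.replace("</think>", "")
            output_text += token_str + "*"
        else:
            output_text += token_str
        yield output_text
    # If generation ended mid-reasoning, close the dangling italic span
    if in_think:
        yield output_text + "*"

tokens = [
    {"token": "<think>", "token_id": 10},
    {"token": "plan the answer", "token_id": 11},
    {"token": "</think>", "token_id": 12},
    {"token": "Hello!", "token_id": 13},
    {"token": "</s>", "token_id": EOS_ID},
]
print(list(render_stream(tokens))[-1])  # -> *plan the answer*Hello!

Running the example prints "*plan the answer*Hello!": the reasoning text comes out italicized and the EOS token never reaches the chat output, which matches the two behaviors named in the commit message.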