Athspi committed on
Commit ce67cd9 · verified · 1 Parent(s): f0fbb06

Update app.py

Files changed (1): app.py +19 -20
app.py CHANGED
@@ -67,7 +67,7 @@ def initialize_model():
     model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
     logging.info(model_status)
     try:
-        # FIX: Remove explicit DeviceType. Let the library infer or use string if needed by constructor.
+        # FIX: Removed explicit DeviceType. Let the library infer or use string if needed by constructor.
         # The simple constructor often works by detecting the installed ORT package.
         logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
         model = og.Model(model_path) # Simplified model loading
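
For reference, the simplified loading path this hunk settles on amounts to the sketch below. It is a minimal reconstruction, not the full app.py: the `model_path` argument, the global `model`/`tokenizer` pair, and the error handling are assumptions inferred from this diff; onnxruntime-genai selects the execution provider from whichever ORT package is installed.

```python
import logging
import onnxruntime_genai as og

model = None
tokenizer = None

def initialize_model(model_path: str):
    """Load model and tokenizer; the provider follows the installed ORT package."""
    global model, tokenizer
    try:
        model = og.Model(model_path)      # no explicit DeviceType required
        tokenizer = og.Tokenizer(model)   # tokenizer is derived from the loaded model
        logging.info("Model and tokenizer ready.")
    except Exception as e:
        logging.error(f"Model load failed: {e}")
        raise
```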
@@ -105,19 +105,18 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

     logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
-    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")

     try:
         input_tokens = tokenizer.encode(full_prompt)

+        # FIX: Removed eos_token_id and pad_token_id as they are not attributes
+        # of onnxruntime_genai.Tokenizer and likely handled internally by the generator.
         search_options = {
             "max_length": max_length,
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
             "do_sample": True,
-            "eos_token_id": tokenizer.eos_token_id,
-            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
         }

         params = og.GeneratorParams(model)
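
How the trimmed-down `search_options` dict reaches the generator is not shown in this hunk; a plausible sketch of the next few lines, assuming the pre-0.4 onnxruntime-genai API that the rest of this diff uses (`params.input_ids`, `compute_logits`):

```python
# Sketch (assumed, not part of the diff): wiring search_options into the generator.
# Newer onnxruntime-genai releases replace params.input_ids with generator.append_tokens.
params = og.GeneratorParams(model)
params.set_search_options(**search_options)  # max_length, temperature, top_p, top_k, do_sample
params.input_ids = input_tokens              # from tokenizer.encode(full_prompt)
generator = og.Generator(model, params)      # EOS/padding are handled internally here
```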
@@ -131,29 +130,23 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):

         first_token_time = None
         token_count = 0
+        # Rely primarily on generator.is_done()
         while not generator.is_done():
             generator.compute_logits()
             generator.generate_next_token()
             if first_token_time is None:
                 first_token_time = time.time() # Record time to first token
-            next_token = generator.get_next_tokens()[0]

-            if next_token == search_options["eos_token_id"]:
-                logging.info("EOS token encountered.")
-                break
+            next_token = generator.get_next_tokens()[0]

             decoded_chunk = tokenizer.decode([next_token])
             token_count += 1

-            # Handle potential decoding issues or special tokens if necessary
-            if decoded_chunk == "<|end|>": # Example: Stop if assistant outputs end token explicitly
-                logging.info("Assistant explicitly generated <|end|> token.")
-                break
-            if decoded_chunk == tokenizer.eos_token: # Check against tokenizer's eos_token string
-                logging.info("Assistant generated EOS token string.")
+            # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
+            if decoded_chunk == "<|end|>":
+                logging.info("Assistant explicitly generated <|end|> token string.")
                 break

-
             yield decoded_chunk # Yield just the text chunk

         end_time = time.time()
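
One caveat with calling `tokenizer.decode([next_token])` per token: characters that span several tokens can decode incorrectly. onnxruntime-genai provides an incremental decoder for exactly this streaming case; a sketch of the loop using it, assuming `create_stream()` exists in the installed version:

```python
# Alternative streaming loop using the library's TokenizerStream, which buffers
# partial byte sequences until a printable chunk is complete.
tokenizer_stream = tokenizer.create_stream()
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()
    next_token = generator.get_next_tokens()[0]
    decoded_chunk = tokenizer_stream.decode(next_token)  # may be "" until a piece completes
    if decoded_chunk == "<|end|>":
        break
    yield decoded_chunk
```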
@@ -169,13 +162,18 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
         model_status = f"Error during generation: {e}"
         yield f"\n\nSorry, an error occurred during generation: {e}" # Yield error message

+
 # --- Gradio Interface Functions ---

 # 1. Function to add user message to chat history
 def add_user_message(user_message, history):
     """Adds the user's message to the chat history for display."""
     if not user_message:
-        raise gr.Error("Please enter a message.")
+        # Returning original history prevents adding empty message
+        # Use gr.Warning or gr.Info for user feedback? Or raise gr.Error?
+        # gr.Warning("Please enter a message.") # Shows warning toast
+        return "", history # Clear input, return unchanged history
+        # raise gr.Error("Please enter a message.") # Stops execution, shows error
     history = history + [[user_message, None]] # Append user message, leave bot response None
     return "", history # Clear input textbox, return updated history

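Net effect of the change: an empty submission now returns the history unchanged instead of raising `gr.Error`. A quick illustration of the new contract:

```python
history = [["hi", "hello!"]]

# Empty input: textbox is cleared, history is returned untouched.
assert add_user_message("", history) == ("", [["hi", "hello!"]])

# Normal input: a new [user, None] turn is appended for the bot to fill in.
_, updated = add_user_message("how are you?", history)
assert updated[-1] == ["how are you?", None]
```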
@@ -183,7 +181,8 @@ def add_user_message(user_message, history):
 def generate_bot_response(history, max_length, temperature, top_p, top_k):
     """Generates the bot's response based on the history and streams it."""
     if not history or history[-1][1] is not None:
-        # This shouldn't happen in the normal flow, but good practice
+        # This case means user submitted empty message or something went wrong
+        # No need to generate if the last turn isn't user's pending turn
         return history

     user_prompt = history[-1][0] # Get the latest user prompt
@@ -196,7 +195,7 @@ def generate_bot_response(history, max_length, temperature, top_p, top_k):
     )

     # Stream the response chunks back to Gradio
-    history[-1][1] = "" # Initialize the bot response string
+    history[-1][1] = "" # Initialize the bot response string in the history
     for chunk in response_stream:
         history[-1][1] += chunk # Append the chunk to the bot's message in history
         yield history # Yield the *entire updated history* back to Chatbot
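
This split (a quick, non-streaming `add_user_message`, then a generator that re-yields the whole history) is Gradio's standard streaming-chatbot recipe. The event wiring is outside this diff; a plausible sketch with the component names assumed:

```python
# Hypothetical wiring inside the gr.Blocks() context; msg, chatbot, and the
# slider components are assumptions, not shown in this diff.
msg.submit(
    add_user_message, [msg, chatbot], [msg, chatbot], queue=False
).then(
    generate_bot_response,                            # generator: each yield repaints the Chatbot
    [chatbot, max_length, temperature, top_p, top_k],
    chatbot,
)
```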
@@ -207,9 +206,9 @@ def clear_chat():
     global model_status # Keep model status indicator updated
     # Reset status only if it was showing an error from generation maybe?
     # Or just always reset to Ready if model is loaded.
-    if model and tokenizer:
+    if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
-    # Keep the original error if init failed
+    # Keep the original error if init failed, otherwise show ready status
     return None, [], model_status # Clear Textbox, Chatbot history, and update status display

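With the stricter guard, `clear_chat` only resets the status to "Model Ready" when the model loaded and no "Error"/"FATAL" status is on record, so initialization failures stay visible after a clear. Its three return values map onto three components; a hypothetical hookup (names assumed):

```python
# Hypothetical clear-button wiring; clear_btn and status_display are assumed names.
clear_btn.click(clear_chat, inputs=None, outputs=[msg, chatbot, status_display])
```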