Update app.py
app.py CHANGED
@@ -67,7 +67,7 @@ def initialize_model():
     model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
     logging.info(model_status)
     try:
-        # FIX:
+        # FIX: Removed explicit DeviceType. Let the library infer or use string if needed by constructor.
         # The simple constructor often works by detecting the installed ORT package.
         logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
         model = og.Model(model_path) # Simplified model loading
@@ -105,19 +105,18 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

     logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
-    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")

     try:
         input_tokens = tokenizer.encode(full_prompt)

+        # FIX: Removed eos_token_id and pad_token_id as they are not attributes
+        # of onnxruntime_genai.Tokenizer and likely handled internally by the generator.
         search_options = {
             "max_length": max_length,
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
             "do_sample": True,
-            "eos_token_id": tokenizer.eos_token_id,
-            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
         }

         params = og.GeneratorParams(model)
@@ -131,29 +130,23 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to

         first_token_time = None
         token_count = 0
+        # Rely primarily on generator.is_done()
         while not generator.is_done():
             generator.compute_logits()
             generator.generate_next_token()
             if first_token_time is None:
                 first_token_time = time.time() # Record time to first token
-            next_token = generator.get_next_tokens()[0]

-
-                logging.info("EOS token encountered.")
-                break
+            next_token = generator.get_next_tokens()[0]

             decoded_chunk = tokenizer.decode([next_token])
             token_count += 1

-            #
-            if decoded_chunk == "<|end|>":
-                logging.info("Assistant explicitly generated <|end|> token.")
-                break
-            if decoded_chunk == tokenizer.eos_token: # Check against tokenizer's eos_token string
-                logging.info("Assistant generated EOS token string.")
+            # Secondary check: Stop if the model explicitly generates the <|end|> string literal.
+            if decoded_chunk == "<|end|>":
+                logging.info("Assistant explicitly generated <|end|> token string.")
                 break

-
             yield decoded_chunk # Yield just the text chunk

         end_time = time.time()
@@ -169,13 +162,18 @@ def generate_response_stream(prompt, history, max_length, temperature, top_p, to
         model_status = f"Error during generation: {e}"
         yield f"\n\nSorry, an error occurred during generation: {e}" # Yield error message

+
 # --- Gradio Interface Functions ---

 # 1. Function to add user message to chat history
 def add_user_message(user_message, history):
     """Adds the user's message to the chat history for display."""
     if not user_message:
-
+        # Returning original history prevents adding empty message
+        # Use gr.Warning or gr.Info for user feedback? Or raise gr.Error?
+        # gr.Warning("Please enter a message.") # Shows warning toast
+        return "", history # Clear input, return unchanged history
+        # raise gr.Error("Please enter a message.") # Stops execution, shows error
     history = history + [[user_message, None]] # Append user message, leave bot response None
     return "", history # Clear input textbox, return updated history

@@ -183,7 +181,8 @@ def add_user_message(user_message, history):
 def generate_bot_response(history, max_length, temperature, top_p, top_k):
     """Generates the bot's response based on the history and streams it."""
     if not history or history[-1][1] is not None:
-        # This
+        # This case means user submitted empty message or something went wrong
+        # No need to generate if the last turn isn't user's pending turn
         return history

     user_prompt = history[-1][0] # Get the latest user prompt
@@ -196,7 +195,7 @@ def generate_bot_response(history, max_length, temperature, top_p, top_k):
     )

     # Stream the response chunks back to Gradio
-    history[-1][1] = "" # Initialize the bot response string
+    history[-1][1] = "" # Initialize the bot response string in the history
     for chunk in response_stream:
         history[-1][1] += chunk # Append the chunk to the bot's message in history
         yield history # Yield the *entire updated history* back to Chatbot
@@ -207,9 +206,9 @@ def clear_chat():
     global model_status # Keep model status indicator updated
     # Reset status only if it was showing an error from generation maybe?
     # Or just always reset to Ready if model is loaded.
-    if model and tokenizer:
+    if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
         model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
-    # Keep the original error if init failed
+    # Keep the original error if init failed, otherwise show ready status
     return None, [], model_status # Clear Textbox, Chatbot history, and update status display
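
Note: the streaming loop this diff settles on can be exercised outside Gradio with a minimal sketch like the one below. It is an illustration only, not the app itself: it assumes a local onnxruntime-genai model folder at ./model (hypothetical path) and the older generator API that app.py uses (compute_logits / generate_next_token). Exact method names and the way input tokens are passed differ between onnxruntime_genai releases, so treat the params.input_ids assignment in particular as version-dependent.

import onnxruntime_genai as og

# Hypothetical model directory; substitute the path app.py resolves at startup.
model = og.Model("./model")
tokenizer = og.Tokenizer(model)

# Same chat-template markers that app.py builds into full_prompt.
prompt = "<|user|>\nHello!<|end|>\n<|assistant|>\n"
input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(max_length=256, temperature=0.7, top_p=0.9, top_k=50, do_sample=True)
params.input_ids = input_tokens  # older API surface; newer releases feed tokens to the generator instead

generator = og.Generator(model, params)
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()
    chunk = tokenizer.decode([generator.get_next_tokens()[0]])
    if chunk == "<|end|>":  # secondary stop check, mirroring the loop in this commit
        break
    print(chunk, end="", flush=True)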