strickvl committed
Commit 5ea6228 · unverified · 1 parent: f025d0e

Update Gradio UI streaming to use agent.run() with improved error handling


- Refactored the stream_to_gradio function to drive agent.run() with stream=True (see the sketch below)
- Added Gradio-specific error message handling
- Improved token tracking and error logging
- Imported gradio explicitly for error message generation
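At its core, the change swaps a flat message iterator for smolagents' step-level stream. A minimal sketch of the new pattern, reduced from the diff below (the name stream_steps is illustrative; pull_messages_from_step is the helper defined earlier in Gradio_UI.py, and token tracking is omitted here):

```python
# Sketch of the post-commit streaming pattern, reduced from the diff below.
# Assumes a smolagents agent and the pull_messages_from_step helper from
# Gradio_UI.py; stream_steps is an illustrative name, not from this commit.
def stream_steps(agent, task: str, reset_agent_memory: bool = True):
    # agent.run(..., stream=True) yields step logs rather than a final answer
    for step_log in agent.run(task, stream=True, reset=reset_agent_memory):
        # each step log may expand into several Gradio chat messages
        for msg in pull_messages_from_step(step_log):
            yield msg
```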

Files changed (1)
  1. Gradio_UI.py +16 -10
Gradio_UI.py CHANGED

@@ -29,6 +29,7 @@ from smolagents.agents import ActionStep, MultiStepAgent
 from smolagents.memory import MemoryStep
 from smolagents.utils import _is_package_available
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -160,26 +161,31 @@ def pull_messages_from_step(
 
 def stream_to_gradio(agent, task: str, reset_agent_memory: bool = True):
     """Stream agent responses to Gradio interface with better error handling"""
+    import gradio as gr
+
     total_input_tokens = 0
     total_output_tokens = 0
 
     try:
-        for msg in agent.chat(task, reset_memory=reset_agent_memory):
-            # Safely handle token counting
-            if hasattr(agent.model, "last_input_token_count"):
-                input_tokens = agent.model.last_input_token_count or 0
-                total_input_tokens += input_tokens
+        # Use agent.run() instead of agent.chat()
+        for step_log in agent.run(task, stream=True, reset=reset_agent_memory):
+            # Extract messages from the step
+            for msg in pull_messages_from_step(step_log):
+                # Safely handle token counting
+                if hasattr(agent.model, "last_input_token_count"):
+                    input_tokens = agent.model.last_input_token_count or 0
+                    total_input_tokens += input_tokens
 
-            if hasattr(agent.model, "last_output_token_count"):
-                output_tokens = agent.model.last_output_token_count or 0
-                total_output_tokens += output_tokens
+                if hasattr(agent.model, "last_output_token_count"):
+                    output_tokens = agent.model.last_output_token_count or 0
+                    total_output_tokens += output_tokens
 
-            yield msg
+                yield msg
 
     except Exception as e:
         error_msg = f"Error during chat: {str(e)}"
         logger.error(error_msg)
-        yield f"⚠️ {error_msg}"
+        yield gr.ChatMessage(role="assistant", content=f"⚠️ {error_msg}")
 
 
 class GradioUI:
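For completeness, a sketch of how the updated generator might be consumed in a Gradio app. Everything here is an illustrative assumption rather than part of this commit: the `agent` object, the `interact` function name, and the widget wiring are all hypothetical.

```python
# Hypothetical wiring of stream_to_gradio into a minimal Blocks app.
# `agent` is assumed to be an already-constructed smolagents agent;
# the names and layout below are illustrative, not from this commit.
import gradio as gr

def interact(prompt, history):
    # echo the user's message, then stream the agent's messages in
    history = history + [gr.ChatMessage(role="user", content=prompt)]
    yield history
    for msg in stream_to_gradio(agent, task=prompt, reset_agent_memory=False):
        history.append(msg)
        yield history  # re-render the chat after every streamed message

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")  # accepts gr.ChatMessage objects
    box = gr.Textbox(placeholder="Ask the agent...")
    box.submit(interact, [box, chatbot], [chatbot])

demo.launch()
```

Yielding the error as a gr.ChatMessage (rather than the old raw f-string) keeps the failure path consistent with the message objects a type="messages" Chatbot expects, which is presumably why the commit imports gradio inside the function.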