# computer-agent/model_replay.py
import json
from time import sleep
from typing import Dict, List, Optional

from huggingface_hub import hf_hub_download
from smolagents.models import ChatMessage, MessageRole, Model, Tool


class FakeModelReplayLog(Model):
"""A model class that returns pre-recorded responses from a log file.
This class is useful for testing and debugging purposes, as it doesn't make
actual API calls but instead returns responses from a pre-recorded log file.
Parameters:
log_url (str, optional):
URL to the log file. Defaults to the smolagents example log.
**kwargs: Additional keyword arguments passed to the Model base class.
"""

    def __init__(self, log_folder: str, **kwargs):
        super().__init__(**kwargs)
        self.dataset_name = "smolagents/computer-agent-logs"
        self.log_folder = log_folder
        self.call_counter = 0
        self.model_outputs = self._load_model_outputs()

    def _load_model_outputs(self) -> List[str]:
        """Load model outputs from the run's metadata.json on the Hugging Face Hub."""
        # Download the metadata file from the Hugging Face Hub
        file_path = hf_hub_download(
            repo_id=self.dataset_name,
            filename=self.log_folder + "/metadata.json",
            repo_type="dataset",
        )
        # Load and parse the JSON data
        with open(file_path, "r") as f:
            log_data = json.load(f)
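
        # Structure assumed below, inferred from the access pattern:
        # log_data["summary"] is a list of step records, each holding a
        # "model_output_message" dict with a "content" string.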
        # Extract the model output message content from each step,
        # skipping the first summary entry
        model_outputs = []
        for step in log_data["summary"][1:]:
            model_outputs.append(step["model_output_message"]["content"])

        print(f"Loaded {len(model_outputs)} model outputs from log file")
        return model_outputs

    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        tools_to_call_from: Optional[List[Tool]] = None,
        **kwargs,
    ) -> ChatMessage:
        """Return the next pre-recorded response from the log file.

        Parameters:
            messages: List of input messages (ignored).
            stop_sequences: Optional list of stop sequences (ignored).
            grammar: Optional grammar specification (ignored).
            tools_to_call_from: Optional list of tools (ignored).
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            ChatMessage: The next pre-recorded response.
        """
        sleep(1.0)

        # Get the next model output
        if self.call_counter < len(self.model_outputs):
            content = self.model_outputs[self.call_counter]
            self.call_counter += 1
        else:
            content = "No more pre-recorded responses available."

        # Token counts are simulated
        self.last_input_token_count = len(str(messages)) // 4  # Rough approximation
        self.last_output_token_count = len(content) // 4  # Rough approximation

        # Create and return a ChatMessage
        return ChatMessage(
            role=MessageRole.ASSISTANT,
            content=content,
            tool_calls=None,
            raw={"source": "pre-recorded log", "call_number": self.call_counter},
        )
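

# Minimal usage sketch: replay a recorded run without making real API calls.
# The folder name below is a placeholder; substitute any run folder that
# exists in the `smolagents/computer-agent-logs` dataset.
if __name__ == "__main__":
    model = FakeModelReplayLog(log_folder="example-run")  # hypothetical folder name
    # Messages are ignored; each call simply returns the next recorded assistant message
    reply = model([{"role": "user", "content": "ignored, responses are pre-recorded"}])
    print(reply.content)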