Dixing (Dex) Xu
:zap: update execution timeout logic more aggresively (#35) (#37)
c92f194
unverified
""" | |
Python interpreter for executing code snippets and capturing their output. | |
Supports: | |
- captures stdout and stderr | |
- captures exceptions and stack traces | |
- limits execution time | |
""" | |
import logging | |
import os | |
import queue | |
import signal | |
import sys | |
import time | |
import traceback | |
from dataclasses import dataclass | |
from multiprocessing import Process, Queue | |
from pathlib import Path | |
import humanize | |
from dataclasses_json import DataClassJsonMixin | |
logger = logging.getLogger("aide") | |
class ExecutionResult(DataClassJsonMixin): | |
""" | |
Result of executing a code snippet in the interpreter. | |
Contains the output, execution time, and exception information. | |
""" | |
term_out: list[str] | |
exec_time: float | |
exc_type: str | None | |
exc_info: dict | None = None | |
exc_stack: list[tuple] | None = None | |
def exception_summary(e, working_dir, exec_file_name, format_tb_ipython): | |
"""Generates a string that summarizes an exception and its stack trace (either in standard python repl or in IPython format).""" | |
if format_tb_ipython: | |
import IPython.core.ultratb | |
# tb_offset = 1 to skip parts of the stack trace in weflow code | |
tb = IPython.core.ultratb.VerboseTB(tb_offset=1, color_scheme="NoColor") | |
tb_str = str(tb.text(*sys.exc_info())) | |
else: | |
tb_lines = traceback.format_exception(e) | |
# skip parts of stack trace in weflow code | |
tb_str = "".join( | |
[ | |
line | |
for line in tb_lines | |
if "aide/" not in line and "importlib" not in line | |
] | |
) | |
# tb_str = "".join([l for l in tb_lines]) | |
# replace whole path to file with just filename (to remove agent workspace dir) | |
tb_str = tb_str.replace(str(working_dir / exec_file_name), exec_file_name) | |
exc_info = {} | |
if hasattr(e, "args"): | |
exc_info["args"] = [str(i) for i in e.args] | |
for att in ["name", "msg", "obj"]: | |
if hasattr(e, att): | |
exc_info[att] = str(getattr(e, att)) | |
tb = traceback.extract_tb(e.__traceback__) | |
exc_stack = [(t.filename, t.lineno, t.name, t.line) for t in tb] | |
return tb_str, e.__class__.__name__, exc_info, exc_stack | |
class RedirectQueue: | |
def __init__(self, queue, timeout=5): | |
self.queue = queue | |
self.timeout = timeout | |
def write(self, msg): | |
try: | |
self.queue.put(msg, timeout=self.timeout) | |
except queue.Full: | |
logger.warning("Queue write timed out") | |
def flush(self): | |
pass | |
class Interpreter: | |
def __init__( | |
self, | |
working_dir: Path | str, | |
timeout: int = 3600, | |
format_tb_ipython: bool = False, | |
agent_file_name: str = "runfile.py", | |
): | |
""" | |
Simulates a standalone Python REPL with an execution time limit. | |
Args: | |
working_dir (Path | str): working directory of the agent | |
timeout (int, optional): Timeout for each code execution step. Defaults to 3600. | |
format_tb_ipython (bool, optional): Whether to use IPython or default python REPL formatting for exceptions. Defaults to False. | |
agent_file_name (str, optional): The name for the agent's code file. Defaults to "runfile.py". | |
""" | |
# this really needs to be a path, otherwise causes issues that don't raise exc | |
self.working_dir = Path(working_dir).resolve() | |
assert ( | |
self.working_dir.exists() | |
), f"Working directory {self.working_dir} does not exist" | |
self.timeout = timeout | |
self.format_tb_ipython = format_tb_ipython | |
self.agent_file_name = agent_file_name | |
self.process: Process = None # type: ignore | |
def child_proc_setup(self, result_outq: Queue) -> None: | |
# disable all warnings (before importing anything) | |
import shutup | |
shutup.mute_warnings() | |
os.chdir(str(self.working_dir)) | |
# this seems to only benecessary because we're exec'ing code from a string, | |
# a .py file should be able to import modules from the cwd anyway | |
sys.path.append(str(self.working_dir)) | |
# capture stdout and stderr | |
# trunk-ignore(mypy/assignment) | |
sys.stdout = sys.stderr = RedirectQueue(result_outq) | |
def _run_session( | |
self, code_inq: Queue, result_outq: Queue, event_outq: Queue | |
) -> None: | |
self.child_proc_setup(result_outq) | |
global_scope: dict = {} | |
while True: | |
code = code_inq.get() | |
os.chdir(str(self.working_dir)) | |
with open(self.agent_file_name, "w") as f: | |
f.write(code) | |
event_outq.put(("state:ready",)) | |
try: | |
exec(compile(code, self.agent_file_name, "exec"), global_scope) | |
except BaseException as e: | |
tb_str, e_cls_name, exc_info, exc_stack = exception_summary( | |
e, | |
self.working_dir, | |
self.agent_file_name, | |
self.format_tb_ipython, | |
) | |
result_outq.put(tb_str) | |
if e_cls_name == "KeyboardInterrupt": | |
e_cls_name = "TimeoutError" | |
event_outq.put(("state:finished", e_cls_name, exc_info, exc_stack)) | |
else: | |
event_outq.put(("state:finished", None, None, None)) | |
# remove the file after execution (otherwise it might be included in the data preview) | |
os.remove(self.agent_file_name) | |
# put EOF marker to indicate that we're done | |
result_outq.put("<|EOF|>") | |
def create_process(self) -> None: | |
# we use three queues to communicate with the child process: | |
# - code_inq: send code to child to execute | |
# - result_outq: receive stdout/stderr from child | |
# - event_outq: receive events from child (e.g. state:ready, state:finished) | |
# trunk-ignore(mypy/var-annotated) | |
self.code_inq, self.result_outq, self.event_outq = Queue(), Queue(), Queue() | |
self.process = Process( | |
target=self._run_session, | |
args=(self.code_inq, self.result_outq, self.event_outq), | |
) | |
self.process.start() | |
def cleanup_session(self): | |
if self.process is None: | |
return | |
try: | |
# Reduce grace period from 2 seconds to 0.5 | |
self.process.terminate() | |
self.process.join(timeout=0.5) | |
if self.process.exitcode is None: | |
logger.warning("Process failed to terminate, killing immediately") | |
self.process.kill() | |
self.process.join(timeout=0.5) | |
if self.process.exitcode is None: | |
logger.error("Process refuses to die, using SIGKILL") | |
os.kill(self.process.pid, signal.SIGKILL) | |
except Exception as e: | |
logger.error(f"Error during process cleanup: {e}") | |
finally: | |
if self.process is not None: | |
self.process.close() | |
self.process = None | |
def run(self, code: str, reset_session=True) -> ExecutionResult: | |
""" | |
Execute the provided Python command in a separate process and return its output. | |
Parameters: | |
code (str): Python code to execute. | |
reset_session (bool, optional): Whether to reset the interpreter session before executing the code. Defaults to True. | |
Returns: | |
ExecutionResult: Object containing the output and metadata of the code execution. | |
""" | |
logger.debug(f"REPL is executing code (reset_session={reset_session})") | |
if reset_session: | |
if self.process is not None: | |
# terminate and clean up previous process | |
self.cleanup_session() | |
self.create_process() | |
else: | |
# reset_session needs to be True on first exec | |
assert self.process is not None | |
assert self.process.is_alive() | |
self.code_inq.put(code) | |
# wait for child to actually start execution (we don't want interrupt child setup) | |
try: | |
state = self.event_outq.get(timeout=10) | |
except queue.Empty: | |
msg = "REPL child process failed to start execution" | |
logger.critical(msg) | |
while not self.result_outq.empty(): | |
logger.error(f"REPL output queue dump: {self.result_outq.get()}") | |
raise RuntimeError(msg) from None | |
assert state[0] == "state:ready", state | |
start_time = time.time() | |
# this flag indicates that the child ahs exceeded the time limit and an interrupt was sent | |
# if the child process dies without this flag being set, it's an unexpected termination | |
child_in_overtime = False | |
while True: | |
try: | |
# check if the child is done | |
state = self.event_outq.get(timeout=1) # wait for state:finished | |
assert state[0] == "state:finished", state | |
exec_time = time.time() - start_time | |
break | |
except queue.Empty: | |
# we haven't heard back from the child -> check if it's still alive (assuming overtime interrupt wasn't sent yet) | |
if not child_in_overtime and not self.process.is_alive(): | |
msg = "REPL child process died unexpectedly" | |
logger.critical(msg) | |
while not self.result_outq.empty(): | |
logger.error( | |
f"REPL output queue dump: {self.result_outq.get()}" | |
) | |
raise RuntimeError(msg) from None | |
# child is alive and still executing -> check if we should sigint.. | |
if self.timeout is None: | |
continue | |
running_time = time.time() - start_time | |
if running_time > self.timeout: | |
logger.warning(f"Execution exceeded timeout of {self.timeout}s") | |
os.kill(self.process.pid, signal.SIGINT) | |
child_in_overtime = True | |
# terminate if we're overtime by more than 5 seconds | |
if running_time > self.timeout + 5: | |
logger.warning("Child failed to terminate, killing it..") | |
self.cleanup_session() | |
state = (None, "TimeoutError", {}, []) | |
exec_time = self.timeout | |
break | |
output: list[str] = [] | |
# read all stdout/stderr from child up to the EOF marker | |
# waiting until the queue is empty is not enough since | |
# the feeder thread in child might still be adding to the queue | |
start_collect = time.time() | |
while not self.result_outq.empty() or not output or output[-1] != "<|EOF|>": | |
try: | |
# Add 5-second timeout for output collection | |
if time.time() - start_collect > 5: | |
logger.warning("Output collection timed out") | |
break | |
output.append(self.result_outq.get(timeout=1)) | |
except queue.Empty: | |
continue | |
output.pop() # remove the EOF marker | |
e_cls_name, exc_info, exc_stack = state[1:] | |
if e_cls_name == "TimeoutError": | |
output.append( | |
f"TimeoutError: Execution exceeded the time limit of {humanize.naturaldelta(self.timeout)}" | |
) | |
else: | |
output.append( | |
f"Execution time: {humanize.naturaldelta(exec_time)} seconds (time limit is {humanize.naturaldelta(self.timeout)})." | |
) | |
return ExecutionResult(output, exec_time, e_cls_name, exc_info, exc_stack) | |