lvwerra (HF Staff) committed
Commit c3fdac2 · Parent: 5504eb2

add eval script

Files changed (1): eval.py (+296, -0)
eval.py ADDED
@@ -0,0 +1,296 @@
+ import os
+ import json
+ import argparse
+ import subprocess
+ import threading
+ import concurrent.futures
+ from datetime import datetime
+ from e2b_desktop import Sandbox
+
+ from e2bqwen import QwenVLAPIModel, E2BVisionAgent
+
+ # Environment variables and constants
+ E2B_API_KEY = os.getenv("E2B_API_KEY")
+ # Try to get the token dynamically, falling back to the environment variable.
+ # The import lives inside the try block so that old huggingface_hub versions
+ # without get_token actually trigger the ImportError fallback.
+ try:
+     from huggingface_hub import get_token
+     HUGGINGFACE_API_KEY = get_token()
+ except ImportError:
+     HUGGINGFACE_API_KEY = None
+ if not HUGGINGFACE_API_KEY:
+     HUGGINGFACE_API_KEY = os.getenv("HUGGINGFACE_API_KEY")
+ if not HUGGINGFACE_API_KEY:
+     raise ValueError(
+         "No Hugging Face token found. Please login with `huggingface-cli login` "
+         "or set the HUGGINGFACE_API_KEY environment variable"
+     )
+ WIDTH = 1024
+ HEIGHT = 768
+ SANDBOX_TIMEOUT = 600  # 10 minutes
+
+ # Thread lock for print statements to avoid garbled output
+ print_lock = threading.Lock()
+
+ def thread_safe_print(*args, **kwargs):
+     """Thread-safe print function"""
+     with print_lock:
+         print(*args, **kwargs)
+
+ # Get the short git hash for folder naming
+ def get_git_hash():
+     try:
+         result = subprocess.run(
+             ["git", "rev-parse", "--short", "HEAD"],
+             stdout=subprocess.PIPE,
+             stderr=subprocess.PIPE,
+             text=True,
+         )
+         if result.returncode == 0:
+             return result.stdout.strip()
+         return "nogit"
+     except Exception:
+         return "nogit"
+
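+ # Note on get_git_hash (sketch): inside a git checkout it returns a short hash
+ # such as "c3fdac2"; without git installed, or outside a repo, it falls back
+ # to "nogit".
+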
+ def create_agent(data_dir, desktop):
+     """Create an agent with the E2B desktop sandbox"""
+     model = QwenVLAPIModel(
+         model_id="Qwen/Qwen2.5-VL-72B-Instruct",
+         hf_token=HUGGINGFACE_API_KEY,
+     )
+     return E2BVisionAgent(
+         model=model,
+         data_dir=data_dir,
+         desktop=desktop,
+         max_steps=200,
+         verbosity_level=2,
+         planning_interval=10,
+     )
+
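+ # Note: QwenVLAPIModel and E2BVisionAgent are project-local classes from
+ # e2bqwen, so their exact signatures are assumed from this call site;
+ # planning_interval=10 makes the agent refresh its plan every ten steps.
+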
+ def get_agent_summary_erase_images(agent):
+     """Get agent summary and erase images to save space"""
+     for memory_step in agent.memory.steps:
+         if getattr(memory_step, "observations_images", None):
+             memory_step.observations_images = None
+     return agent.memory.get_succinct_steps()
+
+ def chat_message_to_json(obj):
+     """Custom JSON serializer for ChatMessage and related objects"""
+     if hasattr(obj, "__dict__"):
+         # Copy the object's __dict__ to avoid modifying the original
+         result = obj.__dict__.copy()
+
+         # Remove the 'raw' field, which may contain non-serializable data
+         if "raw" in result:
+             del result["raw"]
+
+         # Recurse into content and tool_calls if they exist
+         if "content" in result and result["content"] is not None:
+             if hasattr(result["content"], "__dict__"):
+                 result["content"] = chat_message_to_json(result["content"])
+
+         if "tool_calls" in result and result["tool_calls"] is not None:
+             result["tool_calls"] = [chat_message_to_json(tc) for tc in result["tool_calls"]]
+
+         return result
+     elif isinstance(obj, (list, tuple)):
+         return [chat_message_to_json(item) for item in obj]
+     else:
+         return obj
+
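+ # Usage sketch: passed as the `default` hook to json.dumps (see
+ # save_final_status below), chat_message_to_json converts nested ChatMessage
+ # objects into plain dicts, dropping the non-serializable 'raw' payload.
+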
+ def save_final_status(folder, status: str, summary, error_message=None) -> None:
+     """Save metadata about the run"""
+     metadata_path = os.path.join(folder, "metadata.json")
+     with open(metadata_path, "w") as output_file:
+         output_file.write(json.dumps({
+             "status": status,
+             "summary": summary,
+             "error_message": error_message,
+         }, default=chat_message_to_json))
+
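+ # Resulting metadata.json shape (sketch):
+ #   {"status": "completed" | "failed", "summary": [...], "error_message": null | "..."}
+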
+ def run_example_once(example_name, example_text, run_index, example_dir):
+     """Run a single example once and return the result"""
+     run_dir = os.path.join(example_dir, f"run_{run_index}")
+     os.makedirs(run_dir, exist_ok=True)
+
+     # Save the example text
+     with open(os.path.join(run_dir, "task.txt"), "w") as f:
+         f.write(example_text)
+
+     thread_safe_print(f"  Starting run {run_index} for example '{example_name}'")
+
+     # Create a new sandbox for this run
+     desktop = None
+     try:
+         desktop = Sandbox(
+             api_key=E2B_API_KEY,
+             resolution=(WIDTH, HEIGHT),
+             dpi=96,
+             timeout=SANDBOX_TIMEOUT,
+         )
+
+         # Initialize the desktop environment: suppress Firefox's first-run pages
+         setup_cmd = """sudo mkdir -p /usr/lib/firefox-esr/distribution && echo '{"policies":{"OverrideFirstRunPage":"","OverridePostUpdatePage":"","DisableProfileImport":true,"DontCheckDefaultBrowser":true}}' | sudo tee /usr/lib/firefox-esr/distribution/policies.json > /dev/null"""
+         desktop.commands.run(setup_cmd)
+
+         # Create and run the agent
+         agent = create_agent(data_dir=run_dir, desktop=desktop)
+         try:
+             agent.run(task=example_text)
+             summary = get_agent_summary_erase_images(agent)
+             save_final_status(run_dir, "completed", summary=summary)
+             thread_safe_print(f"  ✓ Example '{example_name}' run {run_index} completed successfully")
+             result = {"status": "completed", "run_dir": run_dir}
+         except Exception as e:
+             error_message = f"Error in agent execution: {str(e)}"
+             thread_safe_print(f"  ✗ Example '{example_name}' run {run_index} failed: {error_message}")
+             summary = get_agent_summary_erase_images(agent) if hasattr(agent, "memory") else None
+             save_final_status(run_dir, "failed", summary=summary, error_message=error_message)
+             result = {"status": "failed", "run_dir": run_dir, "error": error_message}
+     except Exception as e:
+         error_message = f"Error setting up sandbox: {str(e)}"
+         thread_safe_print(f"  ✗ Example '{example_name}' run {run_index} failed: {error_message}")
+         save_final_status(run_dir, "failed", summary=None, error_message=error_message)
+         result = {"status": "failed", "run_dir": run_dir, "error": error_message}
+     finally:
+         # Always clean up the sandbox
+         if desktop:
+             try:
+                 desktop.kill()
+             except Exception:
+                 pass
+
+     return result
+
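+ # Each run_<i> directory ends up holding task.txt, metadata.json, and whatever
+ # the agent writes under its data_dir (the exact artifacts depend on
+ # E2BVisionAgent in e2bqwen).
+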
+ def run_example(example_name, example_text, num_runs, example_dir):
+     """Run a single example multiple times, using one thread per run"""
+     thread_safe_print(f"\nRunning example '{example_name}': '{example_text[:50]}...'")
+
+     results = []
+     with concurrent.futures.ThreadPoolExecutor(max_workers=num_runs) as executor:
+         # Submit all runs to the executor
+         future_to_run = {
+             executor.submit(run_example_once, example_name, example_text, j, example_dir): j
+             for j in range(num_runs)
+         }
+
+         # Collect results as they complete
+         for future in concurrent.futures.as_completed(future_to_run):
+             run_index = future_to_run[future]
+             try:
+                 result = future.result()
+                 results.append(result)
+             except Exception as exc:
+                 thread_safe_print(f"  ✗ Run {run_index} for '{example_name}' generated an exception: {exc}")
+                 results.append({
+                     "status": "error",
+                     "run_index": run_index,
+                     "error": str(exc),
+                 })
+
+     return results
+
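+ # Note: runs of a single example are parallelized above, and run_evaluation
+ # below parallelizes across examples, so up to max_parallel * num_runs
+ # sandboxes can be alive at once; budget E2B concurrency limits accordingly.
+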
+ def run_evaluation(examples, num_runs, output_dir, max_parallel):
+     """Run each example num_runs times and save the results"""
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+     git_hash = get_git_hash()
+     eval_dir = os.path.join(output_dir, f"eval_{timestamp}_{git_hash}")
+     os.makedirs(eval_dir, exist_ok=True)
+
+     thread_safe_print(f"Starting evaluation. Results will be saved to: {eval_dir}")
+     thread_safe_print(f"Will run {len(examples)} examples, {num_runs} times each, with {max_parallel} examples in parallel")
+
+     # Save the examples to the evaluation directory
+     with open(os.path.join(eval_dir, "examples.json"), "w") as f:
+         json.dump(examples, f, indent=2)
+
+     all_results = {}
+
+     # Run examples in parallel, limiting the number of examples in flight
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_parallel) as executor:
+         # Prepare the example directories first
+         example_dirs = {}
+         for example_name in examples:
+             example_dir = os.path.join(eval_dir, f"example_{example_name}")
+             os.makedirs(example_dir, exist_ok=True)
+             example_dirs[example_name] = example_dir
+
+         # Submit all examples to the executor
+         future_to_example = {
+             executor.submit(run_example, example_name, example_text, num_runs, example_dirs[example_name]): example_name
+             for example_name, example_text in examples.items()
+         }
+
+         # Collect results as they complete
+         for future in concurrent.futures.as_completed(future_to_example):
+             example_name = future_to_example[future]
+             try:
+                 results = future.result()
+                 all_results[example_name] = results
+
+                 # Calculate the success rate for this example
+                 success_count = sum(1 for r in results if r["status"] == "completed")
+                 thread_safe_print(f"Example '{example_name}' complete: {success_count}/{num_runs} successful runs ({success_count/num_runs*100:.1f}%)")
+             except Exception as exc:
+                 thread_safe_print(f"Example '{example_name}' generated an exception: {exc}")
+                 all_results[example_name] = [{"status": "error", "error": str(exc)}]
+
+     # Calculate overall results and success rates
+     success_counts = {
+         example_name: sum(1 for r in results if r["status"] == "completed")
+         for example_name, results in all_results.items()
+     }
+
+     total_runs = sum(len(results) for results in all_results.values())
+     total_successes = sum(success_counts.values())
+
+     # Save the summary to the evaluation directory
+     summary = {
+         "total_runs": total_runs,
+         "total_successes": total_successes,
+         "success_rate": total_successes / total_runs if total_runs > 0 else 0,
+         "example_success_rates": {
+             example_name: success_counts[example_name] / len(all_results[example_name])
+             for example_name in examples
+         },
+     }
+
+     with open(os.path.join(eval_dir, "summary.json"), "w") as f:
+         json.dump(summary, f, indent=2)
+
+     thread_safe_print(f"\nEvaluation complete. Results saved to: {eval_dir}")
+     thread_safe_print(f"Overall success rate: {summary['success_rate']*100:.1f}% ({total_successes}/{total_runs})")
+     for example_name in examples:
+         success_rate = summary["example_success_rates"][example_name] * 100
+         thread_safe_print(f"Example '{example_name}': {success_rate:.1f}% success")
+
+     return eval_dir
+
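+ # Resulting directory layout (sketch):
+ #   eval_<timestamp>_<githash>/
+ #     examples.json
+ #     summary.json
+ #     example_<name>/
+ #       run_<i>/
+ #         task.txt
+ #         metadata.json
+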
+ def main():
+     parser = argparse.ArgumentParser(description="Evaluate the computer agent on examples")
+     parser.add_argument("--num-runs", type=int, default=3, help="Number of runs per example")
+     parser.add_argument("--output-dir", type=str, default="./eval_results", help="Output directory for evaluation results")
+     parser.add_argument("--max-parallel", type=int, default=2, help="Maximum number of examples to run in parallel")
+     args = parser.parse_args()
+
+     # Examples taken from the original app code
+     examples = {
+         "puppies": "Find me pictures of cute puppies",
+         "commute": "Check the commuting time between Bern and Zurich on Google maps",
+         "hello": "Write 'Hello World' in a text editor",
+         "wiki": "When was Temple Grandin introduced to the American Academy of Arts and Sciences, according to Wikipedia?",
+         "flight": "Search a flight Rome - Berlin for tomorrow",
+         "pond": "What's the name of the pond just south of Château de Fontainebleau in Google maps?",
+         "flux": "Go generate a picture of the Golden Gate bridge on a FLUX1.dev space",
+         "hf": "Download me a picture of a puppy from Google, then head to Hugging Face, find a Space dedicated to background removal, and use it to remove the puppy picture's background",
+     }
+
+     # Create the output directory if it doesn't exist
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # Run the evaluation
+     run_evaluation(examples, args.num_runs, args.output_dir, args.max_parallel)
+
+ if __name__ == "__main__":
+     main()
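+
+ # Usage sketch (flags per the argparse definition above; values shown are the defaults):
+ #   python eval.py --num-runs 3 --output-dir ./eval_results --max-parallel 2
+ # Requires E2B_API_KEY in the environment and a Hugging Face token via
+ # `huggingface-cli login` or HUGGINGFACE_API_KEY.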