Jofthomas committed (verified)
Commit e1889c4 · 1 Parent(s): a3e0aee

Update main.py

Files changed (1)
main.py: +95 -92
main.py CHANGED
@@ -1,117 +1,124 @@
  import os
  import pandas as pd
  from fastapi import FastAPI, HTTPException, Body
  from pydantic import BaseModel, Field
- from typing import List, Dict, Any
  from datasets import load_dataset, Dataset, DatasetDict
  from huggingface_hub import HfApi, hf_hub_download
  from datetime import datetime, timezone
  import logging
- import uvicorn # To run the app
- import random # <-- Added import for random choice

  tool_threshold = 3
  step_threshold = 5
-
- # --- Configuration ---
  HF_DATASET_ID = "agents-course/unit4-students-scores"
- # Ensure you have write access to this dataset repository on Hugging Face
- # and are logged in via `huggingface-cli login` or have HF_TOKEN env var set.
- # Prepare data structures for the API
- questions_for_api: List[Dict[str, str]] = []
  ground_truth_answers: Dict[str, str] = {}

  # --- Logging Setup ---
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
- filtered_dataset=None

  def load_questions():
      global filtered_dataset
      global questions_for_api
      global ground_truth_answers
      tempo_filtered=[]
-     # Clear existing data to prevent duplication if called multiple times
      questions_for_api.clear()
      ground_truth_answers.clear()

      logger.info("Starting to load and filter GAIA dataset...")
      try:
-         dataset=load_dataset("gaia-benchmark/GAIA","2023_level1",trust_remote_code=True)
-         logger.info("GAIA dataset loaded.")
      except Exception as e:
-         logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
-         # Decide how to handle this: maybe raise the error or exit
          raise RuntimeError("Could not load the primary GAIA dataset.") from e

-     for question in dataset['validation']:
-         metadata = question.get('Annotator Metadata') # Use .get() for safety

-         if metadata: # Check if 'Annotator Metadata' exists
              num_tools_str = metadata.get('Number of tools')
              num_steps_str = metadata.get('Number of steps')

-             # Check if both numbers exist before trying to convert
              if num_tools_str is not None and num_steps_str is not None:
                  try:
-                     # Convert values to integers for comparison
                      num_tools = int(num_tools_str)
                      num_steps = int(num_steps_str)

-                     # Apply the filter conditions
                      if num_tools < tool_threshold and num_steps < step_threshold:
-                         # logger.debug(f"MATCH FOUND (Task ID: {question.get('task_id', 'N/A')}) - Tools: {num_tools}, Steps: {num_steps}")
-                         # logger.debug(question) # Print the matching question dictionary
-                         # logger.debug("------------------------------------------------------------------")
-                         tempo_filtered.append(question) # Add to the filtered list
-                     # else: # Optional: Handle items that don't match the filter
-                     # logger.debug(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Tools: {num_tools}, Steps: {num_steps}")
                  except ValueError:
-                     # Handle cases where 'Number of tools' or 'Number of steps' is not a valid integer
-                     logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Could not convert tool/step count to integer: tools='{num_tools_str}', steps='{num_steps_str}'.")
-                     # logger.debug("------------------------------------------------------------------")
          else:
              logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
-             # logger.debug("------------------------------------------------------------------")

-     filtered_dataset=tempo_filtered
      logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")
-     # print(filtered_dataset) # Keep this commented unless debugging

      processed_count = 0
      for item in filtered_dataset:
          task_id = item.get('task_id')
-         question_text = item.get('Question')
          final_answer = item.get('Final answer')
-
-         # Validate required fields
          if task_id and question_text and final_answer is not None:
-             # Create a copy of the item and remove fields we don't want
-             processed_item = item.copy()
-             processed_item.pop('Final answer', None) # Remove Final answer
-             processed_item.pop('Annotator Annotation', None) # Remove Annotator Annotation
-
-             # Ensure the field name matches what's expected by Pydantic models
-             if 'Question' in processed_item and 'question' not in processed_item:
-                 processed_item['question'] = processed_item.pop('Question')
-
-             # Store in questions_for_api
              questions_for_api.append(processed_item)
-
-             # Still store the ground truth answers separately
              ground_truth_answers[str(task_id)] = str(final_answer)
              processed_count += 1
          else:
-             logger.warning(f"Skipping item due to missing fields (task_id, Question, or Final answer): {item}")

      if not questions_for_api:
          logger.error("CRITICAL: No valid questions loaded after filtering. API endpoints needing questions will fail.")
-         # Depending on requirements, you might want to exit or raise an error here
          # raise RuntimeError("Failed to load mandatory question data after filtering.")

- # --- Pydantic Models for Data Validation ---
  class Question(BaseModel):
      task_id: str
-     question: str

  class AnswerItem(BaseModel):
      task_id: str
      submitted_answer: str = Field(..., description="The agent's answer for the task_id")
@@ -132,35 +139,30 @@ class ScoreResponse(BaseModel):
  class ErrorResponse(BaseModel):
      detail: str

  # --- FastAPI Application ---
  app = FastAPI(
      title="Agent Evaluation API",
      description="API to fetch questions and submit agent answers for scoring.",
  )

- # --- Startup Event Handler ---
  @app.on_event("startup")
  async def startup_event():
-     """
-     Loads the questions when the FastAPI application starts.
-     """
      logger.info("Application startup: Loading questions...")
      try:
-         load_questions() # Call your loading function here
          if not questions_for_api:
-             logger.error("CRITICAL: No questions were loaded during startup. The /questions and /random-question endpoints might fail.")
-             # Depending on requirements, you might want the app to fail startup
-             # raise RuntimeError("Failed to load mandatory question data.")
          else:
              logger.info(f"Successfully loaded {len(questions_for_api)} questions.")
      except Exception as e:
          logger.error(f"CRITICAL ERROR DURING STARTUP while loading questions: {e}", exc_info=True)
-         # Decide if the app should exit if loading fails
          # import sys
-         # sys.exit(1)

-
- # --- Helper Function to interact with HF Dataset ---
  def update_huggingface_dataset(username: str, score: float):
      """Loads the dataset, updates the score if higher, and pushes back."""
      try:
@@ -242,7 +244,7 @@ def update_huggingface_dataset(username: str, score: float):

          updated_ds = DatasetDict({'train': Dataset.from_pandas(df)})
          logger.info(f"Dataset to push: {updated_ds}") # Log the dataset structure
-         updated_ds.push_to_hub(HF_DATASET_ID) # Token should be picked up from env or login
          logger.warning("Dataset push to hub is currently commented out. Uncomment the line above to enable leaderboard updates.") # REMINDER
          logger.info("Dataset push simulated/attempted.")
          return True
@@ -254,35 +256,35 @@ def update_huggingface_dataset(username: str, score: float):
          # Re-raise the exception to be caught by the endpoint handler
          raise HTTPException(status_code=500, detail=f"Failed to update Hugging Face dataset: {e}")

-
- # --- API Endpoints ---

  @app.get("/questions",
-          response_model=List[Question],
-          summary="Get All Filtered Questions",
-          description="Returns the complete list of questions (task_id and question text only) filtered based on criteria.")
  async def get_questions():
      """
-     Provides the list of questions that agents should answer.
      """
-     # print(f"Returning {len(questions_for_api)} questions.") # Debug log
      if not questions_for_api:
          logger.error("GET /questions requested but no questions are loaded.")
          raise HTTPException(status_code=404, detail="No questions available.")
      return questions_for_api

- # --- NEW ENDPOINT ---
  @app.get("/random-question",
-          response_model=Question,
-          summary="Get One Random Question",
-          description="Returns a single random question from the available filtered set.",
           responses={
-              200: {"description": "A random question."},
               404: {"model": ErrorResponse, "description": "No questions available to choose from."}
           })
  async def get_random_question():
      """
-     Provides a single, randomly selected question from the loaded list.
      """
      if not questions_for_api:
          logger.warning("GET /random-question requested but no questions are loaded.")
@@ -290,11 +292,11 @@ async def get_random_question():

      # Select and return a random question dictionary
      random_question = random.choice(questions_for_api)
-     logger.info(f"Returning random question with task_id: {random_question['task_id']}")
      return random_question
- # --- END NEW ENDPOINT ---
-

  @app.post("/submit",
            response_model=ScoreResponse,
            summary="Submit Agent Answers",
@@ -358,17 +360,22 @@ async def submit_answers(submission: Submission = Body(...)):
              logger.debug(f"Incorrect answer for {task_id} from {submission.username}. Submitted: '{submitted}', Expected: '{ground_truth}'")

-     # Calculate score based on valid attempts
      if valid_attempted_count == 0:
          score = 0.0
          message = f"Submission received, but no valid/matching task IDs were found in the {total_attempted_in_payload} answers provided."
          logger.warning(f"No valid answers processed for {submission.username} out of {total_attempted_in_payload} submitted.")
      else:
          score = round((correct_count / len(ground_truth_answers)) * 100, 2)
-         message = f"Score calculated successfully: {correct_count}/{valid_attempted_count} correct answers for valid tasks."
          if valid_attempted_count < total_attempted_in_payload:
              message += f" ({total_attempted_in_payload - valid_attempted_count} submitted answers had invalid or duplicate task IDs)."
-         logger.info(f"Score for {submission.username}: {score}% ({correct_count}/{valid_attempted_count})")

      # Update Hugging Face dataset
@@ -401,22 +408,18 @@ async def submit_answers(submission: Submission = Body(...)):
      )

  # --- Run the application ---
- # This part is mainly for local development without Docker.
- # Docker uses the CMD instruction in the Dockerfile.
  if __name__ == "__main__":
      logger.info("Starting FastAPI server for local development...")
-     # Explicitly call load_questions here for local run,
-     # as the @app.on_event("startup") might not trigger reliably
-     # depending on how uvicorn is invoked directly.
      try:
-         load_questions()
          if not questions_for_api:
              logger.error("EXITING: Cannot start server without loaded questions.")
          else:
-             # Read port from environment variable for consistency, default to 8000 for local if not set
              local_port = int(os.getenv("PORT", "8000"))
              logger.info(f"Running Uvicorn locally on http://127.0.0.1:{local_port}")
-             # Note: host='127.0.0.1' is usually fine for local runs outside docker
              uvicorn.run(app, host="127.0.0.1", port=local_port, log_level="info")
      except Exception as e:
          logger.error(f"Failed to start server: {e}", exc_info=True)
 
+ # Import necessary libraries (ensure all required imports are at the top)
  import os
  import pandas as pd
  from fastapi import FastAPI, HTTPException, Body
  from pydantic import BaseModel, Field
+ from typing import List, Dict, Any #<-- Make sure Any is imported
  from datasets import load_dataset, Dataset, DatasetDict
  from huggingface_hub import HfApi, hf_hub_download
  from datetime import datetime, timezone
  import logging
+ import uvicorn
+ import random

+ # --- Constants and Config ---
  tool_threshold = 3
  step_threshold = 5

  HF_DATASET_ID = "agents-course/unit4-students-scores"
+
+ # --- Data Structures ---
+ # questions_for_api will now hold richer dictionaries
+ questions_for_api: List[Dict[str, Any]] = []
  ground_truth_answers: Dict[str, str] = {}
+
  # --- Logging Setup ---
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)
+
+ # --- Filtered Dataset Placeholder ---
+ # Note: Making filtered_dataset global might not be ideal in larger apps,
+ # but keeping it as is based on the original code.
+ filtered_dataset = None
+
+ # --- Modified load_questions Function ---
  def load_questions():
      global filtered_dataset
      global questions_for_api
      global ground_truth_answers
      tempo_filtered=[]
+     # Clear existing data
      questions_for_api.clear()
      ground_truth_answers.clear()

      logger.info("Starting to load and filter GAIA dataset...")
      try:
+         # Load the 'validation' split specifically if that's intended
+         dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split='validation', trust_remote_code=True)
+         logger.info("GAIA dataset validation split loaded.")
      except Exception as e:
+         logger.error(f"Failed to load GAIA dataset validation split: {e}", exc_info=True)
          raise RuntimeError("Could not load the primary GAIA dataset.") from e

+     # --- Filtering Logic (remains the same) ---
+     for question in dataset: # Iterate directly over the loaded split
+         metadata = question.get('Annotator Metadata')

+         if metadata:
              num_tools_str = metadata.get('Number of tools')
              num_steps_str = metadata.get('Number of steps')

              if num_tools_str is not None and num_steps_str is not None:
                  try:
                      num_tools = int(num_tools_str)
                      num_steps = int(num_steps_str)

                      if num_tools < tool_threshold and num_steps < step_threshold:
+                         tempo_filtered.append(question)
                  except ValueError:
+                     logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Could not convert tool/step count: tools='{num_tools_str}', steps='{num_steps_str}'.")
+             # else: # Optional: Log if numbers are missing
+             #     logger.debug(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing tool/step count in metadata.")
          else:
              logger.warning(f"Skipping Task ID: {question.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")

+     # Store the filtered list (optional, could process directly)
+     filtered_dataset = tempo_filtered
      logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")

+     # --- Processing Logic (Modified) ---
      processed_count = 0
      for item in filtered_dataset:
          task_id = item.get('task_id')
+         question_text = item.get('Question') # Keep original key for now
          final_answer = item.get('Final answer')
+
+         # Validate required fields needed for processing/scoring
          if task_id and question_text and final_answer is not None:
+             # Create a copy to avoid modifying the original item in filtered_dataset
+             processed_item: Dict[str, Any] = item.copy()
+
+             # Remove the fields we explicitly want to exclude
+             processed_item.pop('Final answer', None)
+             processed_item.pop('Annotator Annotation', None)
+             # You could add more fields to pop here if needed later
+             # processed_item.pop('Another field to remove', None)
+
+             # Store the dictionary containing all remaining fields
              questions_for_api.append(processed_item)
+
+             # Store the ground truth answer separately for scoring
              ground_truth_answers[str(task_id)] = str(final_answer)
              processed_count += 1
          else:
+             # Log which required field was missing if possible
+             missing = [k for k, v in {'task_id': task_id, 'Question': question_text, 'Final answer': final_answer}.items() if not v and v is not None]
+             logger.warning(f"Skipping item due to missing required fields ({', '.join(missing)}): task_id={task_id}")
+
+     logger.info(f"Successfully processed {processed_count} questions into API format.")

      if not questions_for_api:
          logger.error("CRITICAL: No valid questions loaded after filtering. API endpoints needing questions will fail.")
          # raise RuntimeError("Failed to load mandatory question data after filtering.")

+ # --- Pydantic Models ---
+ # Keep Question simple for potential internal use or basic validation,
+ # but the API will return Dict[str, Any]
  class Question(BaseModel):
      task_id: str
+     Question: str # Keep original casing if that's what in the data

+ # Keep other models as they are (AnswerItem, Submission, ScoreResponse, ErrorResponse)
+ # ... (rest of the Pydantic models remain the same) ...
  class AnswerItem(BaseModel):
      task_id: str
      submitted_answer: str = Field(..., description="The agent's answer for the task_id")

  class ErrorResponse(BaseModel):
      detail: str

+
  # --- FastAPI Application ---
  app = FastAPI(
      title="Agent Evaluation API",
      description="API to fetch questions and submit agent answers for scoring.",
  )

+ # --- Startup Event ---
  @app.on_event("startup")
  async def startup_event():
      logger.info("Application startup: Loading questions...")
      try:
+         load_questions()
          if not questions_for_api:
+             logger.error("CRITICAL: No questions were loaded during startup.")
          else:
              logger.info(f"Successfully loaded {len(questions_for_api)} questions.")
      except Exception as e:
          logger.error(f"CRITICAL ERROR DURING STARTUP while loading questions: {e}", exc_info=True)
          # import sys
+         # sys.exit(1) # Consider exiting if questions are critical

+ # --- Helper Function (update_huggingface_dataset remains the same) ---
+ # ... (update_huggingface_dataset function code) ...
  def update_huggingface_dataset(username: str, score: float):
      """Loads the dataset, updates the score if higher, and pushes back."""
      try:

          updated_ds = DatasetDict({'train': Dataset.from_pandas(df)})
          logger.info(f"Dataset to push: {updated_ds}") # Log the dataset structure
+         # updated_ds.push_to_hub(HF_DATASET_ID) # Uncomment this line to enable leaderboard updates
          logger.warning("Dataset push to hub is currently commented out. Uncomment the line above to enable leaderboard updates.") # REMINDER
          logger.info("Dataset push simulated/attempted.")
          return True

          # Re-raise the exception to be caught by the endpoint handler
          raise HTTPException(status_code=500, detail=f"Failed to update Hugging Face dataset: {e}")

+ # --- API Endpoints (Modified response_model) ---

  @app.get("/questions",
+          # Return a list of dictionaries with arbitrary keys/values
+          response_model=List[Dict[str, Any]],
+          summary="Get All Filtered Questions (Full Data)",
+          description="Returns the complete list of questions with all associated data (excluding answer/annotation) filtered based on criteria.")
  async def get_questions():
      """
+     Provides the list of questions (with extended data) that agents should answer.
      """
      if not questions_for_api:
          logger.error("GET /questions requested but no questions are loaded.")
          raise HTTPException(status_code=404, detail="No questions available.")
+     # questions_for_api now contains the richer dictionaries
      return questions_for_api

  @app.get("/random-question",
+          # Return a single dictionary with arbitrary keys/values
+          response_model=Dict[str, Any],
+          summary="Get One Random Question (Full Data)",
+          description="Returns a single random question with all associated data (excluding answer/annotation) from the available filtered set.",
           responses={
+              200: {"description": "A random question with its full data."},
               404: {"model": ErrorResponse, "description": "No questions available to choose from."}
           })
  async def get_random_question():
      """
+     Provides a single, randomly selected question with its extended data.
      """
      if not questions_for_api:
          logger.warning("GET /random-question requested but no questions are loaded.")

      # Select and return a random question dictionary
      random_question = random.choice(questions_for_api)
+     logger.info(f"Returning random question with task_id: {random_question.get('task_id', 'N/A')}")
+     # random_question is already the richer dictionary
      return random_question

+ # --- Submit Endpoint (remains the same, uses ground_truth_answers) ---
  @app.post("/submit",
            response_model=ScoreResponse,
            summary="Submit Agent Answers",

              logger.debug(f"Incorrect answer for {task_id} from {submission.username}. Submitted: '{submitted}', Expected: '{ground_truth}'")

+     # Calculate score based on valid attempts AND total number of questions available
      if valid_attempted_count == 0:
          score = 0.0
          message = f"Submission received, but no valid/matching task IDs were found in the {total_attempted_in_payload} answers provided."
          logger.warning(f"No valid answers processed for {submission.username} out of {total_attempted_in_payload} submitted.")
+     elif not ground_truth_answers: # Prevent division by zero if no questions loaded
+         score = 0.0
+         message = "Score cannot be calculated because no ground truth answers are loaded."
+         logger.error(f"Cannot calculate score for {submission.username}: ground_truth_answers is empty.")
      else:
+         # Score is based on correct answers divided by the TOTAL number of questions in the filtered set
          score = round((correct_count / len(ground_truth_answers)) * 100, 2)
+         message = f"Score calculated successfully: {correct_count}/{len(ground_truth_answers)} total questions answered correctly ({valid_attempted_count} valid tasks attempted)."
          if valid_attempted_count < total_attempted_in_payload:
              message += f" ({total_attempted_in_payload - valid_attempted_count} submitted answers had invalid or duplicate task IDs)."
+         logger.info(f"Score for {submission.username}: {score}% ({correct_count}/{len(ground_truth_answers)} correct, based on {valid_attempted_count} valid attempts)")

      # Update Hugging Face dataset

      )

  # --- Run the application ---
  if __name__ == "__main__":
      logger.info("Starting FastAPI server for local development...")
      try:
+         load_questions() # Load questions before starting server
          if not questions_for_api:
              logger.error("EXITING: Cannot start server without loaded questions.")
+             # Optional: exit if questions are essential
+             # import sys
+             # sys.exit(1)
          else:
              local_port = int(os.getenv("PORT", "8000"))
              logger.info(f"Running Uvicorn locally on http://127.0.0.1:{local_port}")
              uvicorn.run(app, host="127.0.0.1", port=local_port, log_level="info")
      except Exception as e:
          logger.error(f"Failed to start server: {e}", exc_info=True)