Jofthomas committed
Commit 9847443 · verified · 1 Parent(s): 309e0ef

Update main.py

Files changed (1)
  1. main.py +75 -31
main.py CHANGED
@@ -25,95 +25,139 @@ step_threshold = 5
 questions_for_api: List[Dict[str, Any]] = []
 ground_truth_answers: Dict[str, str] = {}
 filtered_dataset = None
+
+ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")
+
 # --- Define ErrorResponse if not already defined ---
 class ErrorResponse(BaseModel):
     detail: str
 
+
 def load_questions():
+    """
+    Loads the GAIA dataset, filters questions based on tool/step counts,
+    populates 'questions_for_api' with data for the API (excluding sensitive/internal fields),
+    stores ground truth answers, and maps task IDs to their local file paths on the server.
+    """
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
-    global task_file_paths # Declare modification of global
+    global task_file_paths # Declare modification of global
+
     tempo_filtered = []
-    # Clear existing data
+    # Clear existing data from previous runs or restarts
     questions_for_api.clear()
     ground_truth_answers.clear()
-    task_file_paths.clear() # Clear the mapping too
+    task_file_paths.clear() # Clear the file path mapping
 
     logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
+        # Load the specified split
         dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
         logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
     except Exception as e:
         logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
+        # Depending on requirements, you might want to exit or raise a more specific error
         raise RuntimeError("Could not load the primary GAIA dataset.") from e
 
-    # --- Filtering Logic (remains same) ---
-    # [ ... Same filtering code as before ... ]
+    # --- Filtering Logic based on Annotator Metadata ---
     for item in dataset:
         metadata = item.get('Annotator Metadata')
-        if metadata: # Check if 'Annotator Metadata' exists
+
+        if metadata:
             num_tools_str = metadata.get('Number of tools')
             num_steps_str = metadata.get('Number of steps')
+
             if num_tools_str is not None and num_steps_str is not None:
                 try:
                     num_tools = int(num_tools_str)
                     num_steps = int(num_steps_str)
+                    # Apply filter conditions
                     if num_tools < tool_threshold and num_steps < step_threshold:
-                        tempo_filtered.append(item)
+                        tempo_filtered.append(item) # Add the original item if it matches filter
                 except ValueError:
-                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count.")
-        # else: # Log missing metadata if needed
-        #     logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
-
+                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count in metadata: tools='{num_tools_str}', steps='{num_steps_str}'.")
+            else:
+                logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
+        else:
+            # If metadata is essential for filtering, you might want to skip items without it
+            logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
 
-    filtered_dataset = tempo_filtered
-    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria.")
+    filtered_dataset = tempo_filtered # Store the list of filtered original dataset items
+    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")
 
     processed_count = 0
-    # --- Processing Logic (includes storing file path mapping) ---
+    # --- Process filtered items for API and File Mapping ---
     for item in filtered_dataset:
+        # Extract data from the dataset item
         task_id = item.get('task_id')
         original_question_text = item.get('Question')
         final_answer = item.get('Final answer')
-        local_file_path = item.get('file_path') # Get the local path
-        file_name = item.get('file_name') # Get the filename
+        local_file_path = item.get('file_path') # Server-local path from dataset
+        file_name = item.get('file_name') # Filename from dataset
 
-        # Validate essential fields
+        # Validate essential fields needed for processing & ground truth
+        # Note: We proceed even if file path/name are missing, just won't map the file.
         if task_id and original_question_text and final_answer is not None:
-            # Create the dictionary for the API (WITHOUT file_path)
+
+            # 1. Create the dictionary to be exposed via the API
+            #    (Includes 'file_name' for info, but excludes 'file_path')
             processed_item = {
                 "task_id": str(task_id),
-                "question": str(original_question_text),
+                "question": str(original_question_text), # Rename 'Question' -> 'question'
+                # Include other desired fields, using .get() for safety
                 "Level": item.get("Level"),
-                "file_name": file_name, # Include filename for info
+                "file_name": file_name, # Include filename for client info
             }
-            # Clean None values if you prefer not to send nulls for optional fields
+            # Optional: Remove keys with None values if you prefer cleaner JSON
            processed_item = {k: v for k, v in processed_item.items() if v is not None}
 
            questions_for_api.append(processed_item)
 
-            # Store ground truth
+            # 2. Store the ground truth answer separately
            ground_truth_answers[str(task_id)] = str(final_answer)
 
-            # --- Store the file path mapping ---
-            if local_file_path and file_name: # Only store if path and name exist
-                # Basic check if path looks plausible (optional)
-                if os.path.exists(local_file_path):
-                    task_file_paths[str(task_id)] = local_file_path
-                    logger.debug(f"Stored file path for task_id {task_id}: {local_file_path}")
+            # 3. Store the file path mapping if file details exist and are valid
+            if local_file_path and file_name:
+                # Log if the path from the dataset isn't absolute (might indicate issues)
+                if not os.path.isabs(local_file_path):
+                    logger.warning(f"Task {task_id}: Path '{local_file_path}' from dataset is not absolute. This might cause issues finding the file on the server.")
+                    # Depending on dataset guarantees, you might try making it absolute:
+                    # Assuming WORKDIR is /app as per Dockerfile if paths are relative
+                    # local_file_path = os.path.abspath(os.path.join("/app", local_file_path))
+
+                # Check if the file actually exists at the path ON THE SERVER
+                if os.path.exists(local_file_path) and os.path.isfile(local_file_path):
+                    # Path exists, store the mapping
+                    task_file_paths[str(task_id)] = local_file_path
+                    logger.debug(f"Stored file path mapping for task_id {task_id}: {local_file_path}")
                 else:
-                    logger.warning(f"File path '{local_file_path}' for task_id {task_id} does not exist on server. Mapping skipped.")
+                    # Path does *not* exist or is not a file on server filesystem
+                    logger.warning(f"File path '{local_file_path}' for task_id {task_id} does NOT exist or is not a file on server. Mapping skipped.")
+            # Log if file info was missing in the first place
+            elif task_id: # Log only if we have a task_id to reference
+                # Check which specific part was missing for better debugging
+                if not local_file_path and not file_name:
+                    logger.debug(f"Task {task_id}: No 'file_path' or 'file_name' found in dataset item. No file mapping stored.")
+                elif not local_file_path:
+                    logger.debug(f"Task {task_id}: 'file_path' is missing in dataset item (file_name: '{file_name}'). No file mapping stored.")
+                else: # Not file_name
+                    logger.debug(f"Task {task_id}: 'file_name' is missing in dataset item (file_path: '{local_file_path}'). No file mapping stored.")
 
 
             processed_count += 1
         else:
-            logger.warning(f"Skipping item due to missing essential fields: task_id={task_id}")
+            # Log skipping due to missing core fields (task_id, Question, Final answer)
+            logger.warning(f"Skipping item processing due to missing essential fields: task_id={task_id}, has_question={original_question_text is not None}, has_answer={final_answer is not None}")
 
+    # Final summary logging
     logger.info(f"Successfully processed {processed_count} questions for the API.")
     logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
+
     if not questions_for_api:
-        logger.error("CRITICAL: No valid questions loaded after filtering/processing.")
+        logger.error("CRITICAL: No valid questions were loaded after filtering and processing. API endpoints like /questions will fail.")
+        # Consider raising an error if the application cannot function without questions
+        # raise RuntimeError("Failed to load mandatory question data after filtering.")
 
 
 
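
Note: this hunk introduces ALLOWED_CACHE_BASE and populates task_file_paths, but the diff does not show where they are consumed. As a rough, hypothetical sketch (not part of this commit), the snippet below shows one way a FastAPI file-download route could use the two together; the route path, the stand-in app object, and the prefix check are illustrative assumptions, not the repository's actual endpoints.

import os
from typing import Dict

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()  # stand-in; main.py defines its own app

ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")  # mirrors the constant added in this commit
task_file_paths: Dict[str, str] = {}  # stand-in for the module-level mapping filled by load_questions()

@app.get("/files/{task_id}")
def get_task_file(task_id: str):
    # Look up the server-local path stored for this task by load_questions()
    local_path = task_file_paths.get(task_id)
    if local_path is None:
        raise HTTPException(status_code=404, detail="No file associated with this task_id.")

    # Resolve the path and confine it to the cache directory (assumed purpose of ALLOWED_CACHE_BASE)
    resolved = os.path.abspath(local_path)
    if not resolved.startswith(ALLOWED_CACHE_BASE + os.sep):
        raise HTTPException(status_code=403, detail="File path outside the allowed directory.")
    if not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File missing on server.")

    return FileResponse(resolved, filename=os.path.basename(resolved))

A stricter containment test could compare os.path.commonpath([resolved, ALLOWED_CACHE_BASE]) to ALLOWED_CACHE_BASE; the simple prefix check above just illustrates the idea.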