Update main.py
main.py CHANGED
@@ -25,95 +25,139 @@ step_threshold = 5
Before:

 questions_for_api: List[Dict[str, Any]] = []
 ground_truth_answers: Dict[str, str] = {}
 filtered_dataset = None
 # --- Define ErrorResponse if not already defined ---
 class ErrorResponse(BaseModel):
     detail: str

 def load_questions():
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
-    global task_file_paths
     tempo_filtered = []
-    # Clear existing data
     questions_for_api.clear()
     ground_truth_answers.clear()
-    task_file_paths.clear()  # Clear the mapping

     logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
         dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
         logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
     except Exception as e:
         logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
         raise RuntimeError("Could not load the primary GAIA dataset.") from e

-    # --- Filtering Logic
-    # [ ... Same filtering code as before ... ]
     for item in dataset:
         metadata = item.get('Annotator Metadata')
-
             num_tools_str = metadata.get('Number of tools')
             num_steps_str = metadata.get('Number of steps')
             if num_tools_str is not None and num_steps_str is not None:
                 try:
                     num_tools = int(num_tools_str)
                     num_steps = int(num_steps_str)
                     if num_tools < tool_threshold and num_steps < step_threshold:
-                        tempo_filtered.append(item)
                 except ValueError:
-
-
-
-

-    filtered_dataset = tempo_filtered
-    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria.")

     processed_count = 0
-    # ---
     for item in filtered_dataset:
         task_id = item.get('task_id')
         original_question_text = item.get('Question')
         final_answer = item.get('Final answer')
-        local_file_path = item.get('file_path')  #
-        file_name = item.get('file_name')

-        # Validate essential fields
         if task_id and original_question_text and final_answer is not None:
-
             processed_item = {
                 "task_id": str(task_id),
-                "question": str(original_question_text),
                 "Level": item.get("Level"),
-                "file_name": file_name,  # Include filename for info
             }
-            #
             processed_item = {k: v for k, v in processed_item.items() if v is not None}

             questions_for_api.append(processed_item)

-            # Store ground truth
             ground_truth_answers[str(task_id)] = str(final_answer)

-            #
-            if local_file_path and file_name:
-                #
-                if os.path.
-
-
                 else:
-

             processed_count += 1
         else:
-

     logger.info(f"Successfully processed {processed_count} questions for the API.")
     logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
     if not questions_for_api:
-        logger.error("CRITICAL: No valid questions loaded after filtering
After:

 questions_for_api: List[Dict[str, Any]] = []
 ground_truth_answers: Dict[str, str] = {}
 filtered_dataset = None
+
+ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")
+
 # --- Define ErrorResponse if not already defined ---
 class ErrorResponse(BaseModel):
     detail: str

+
 def load_questions():
+    """
+    Loads the GAIA dataset, filters questions based on tool/step counts,
+    populates 'questions_for_api' with data for the API (excluding sensitive/internal fields),
+    stores ground truth answers, and maps task IDs to their local file paths on the server.
+    """
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
+    global task_file_paths  # Declare modification of global
+
     tempo_filtered = []
+    # Clear existing data from previous runs or restarts
     questions_for_api.clear()
     ground_truth_answers.clear()
+    task_file_paths.clear()  # Clear the file path mapping

     logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
+        # Load the specified split
         dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
         logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
     except Exception as e:
         logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
+        # Depending on requirements, you might want to exit or raise a more specific error
         raise RuntimeError("Could not load the primary GAIA dataset.") from e

+    # --- Filtering Logic based on Annotator Metadata ---
     for item in dataset:
         metadata = item.get('Annotator Metadata')
+
+        if metadata:
             num_tools_str = metadata.get('Number of tools')
             num_steps_str = metadata.get('Number of steps')
+
             if num_tools_str is not None and num_steps_str is not None:
                 try:
                     num_tools = int(num_tools_str)
                     num_steps = int(num_steps_str)
+                    # Apply filter conditions
                     if num_tools < tool_threshold and num_steps < step_threshold:
+                        tempo_filtered.append(item)  # Add the original item if it matches filter
                 except ValueError:
+                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count in metadata: tools='{num_tools_str}', steps='{num_steps_str}'.")
+            else:
+                logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
+        else:
+            # If metadata is essential for filtering, you might want to skip items without it
+            logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")

+    filtered_dataset = tempo_filtered  # Store the list of filtered original dataset items
+    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")

     processed_count = 0
+    # --- Process filtered items for API and File Mapping ---
     for item in filtered_dataset:
+        # Extract data from the dataset item
         task_id = item.get('task_id')
         original_question_text = item.get('Question')
         final_answer = item.get('Final answer')
+        local_file_path = item.get('file_path')  # Server-local path from dataset
+        file_name = item.get('file_name')  # Filename from dataset

+        # Validate essential fields needed for processing & ground truth
+        # Note: We proceed even if file path/name are missing, just won't map the file.
         if task_id and original_question_text and final_answer is not None:
+
+            # 1. Create the dictionary to be exposed via the API
+            # (Includes 'file_name' for info, but excludes 'file_path')
             processed_item = {
                 "task_id": str(task_id),
+                "question": str(original_question_text),  # Rename 'Question' -> 'question'
+                # Include other desired fields, using .get() for safety
                 "Level": item.get("Level"),
+                "file_name": file_name,  # Include filename for client info
             }
+            # Optional: Remove keys with None values if you prefer cleaner JSON
             processed_item = {k: v for k, v in processed_item.items() if v is not None}

             questions_for_api.append(processed_item)

+            # 2. Store the ground truth answer separately
             ground_truth_answers[str(task_id)] = str(final_answer)

+            # 3. Store the file path mapping if file details exist and are valid
+            if local_file_path and file_name:
+                # Log if the path from the dataset isn't absolute (might indicate issues)
+                if not os.path.isabs(local_file_path):
+                    logger.warning(f"Task {task_id}: Path '{local_file_path}' from dataset is not absolute. This might cause issues finding the file on the server.")
+                    # Depending on dataset guarantees, you might try making it absolute:
+                    # Assuming WORKDIR is /app as per Dockerfile if paths are relative
+                    # local_file_path = os.path.abspath(os.path.join("/app", local_file_path))
+
+                # Check if the file actually exists at the path ON THE SERVER
+                if os.path.exists(local_file_path) and os.path.isfile(local_file_path):
+                    # Path exists, store the mapping
+                    task_file_paths[str(task_id)] = local_file_path
+                    logger.debug(f"Stored file path mapping for task_id {task_id}: {local_file_path}")
                 else:
+                    # Path does *not* exist or is not a file on server filesystem
+                    logger.warning(f"File path '{local_file_path}' for task_id {task_id} does NOT exist or is not a file on server. Mapping skipped.")
+            # Log if file info was missing in the first place
+            elif task_id:  # Log only if we have a task_id to reference
+                # Check which specific part was missing for better debugging
+                if not local_file_path and not file_name:
+                    logger.debug(f"Task {task_id}: No 'file_path' or 'file_name' found in dataset item. No file mapping stored.")
+                elif not local_file_path:
+                    logger.debug(f"Task {task_id}: 'file_path' is missing in dataset item (file_name: '{file_name}'). No file mapping stored.")
+                else:  # Not file_name
+                    logger.debug(f"Task {task_id}: 'file_name' is missing in dataset item (file_path: '{local_file_path}'). No file mapping stored.")


             processed_count += 1
         else:
+            # Log skipping due to missing core fields (task_id, Question, Final answer)
+            logger.warning(f"Skipping item processing due to missing essential fields: task_id={task_id}, has_question={original_question_text is not None}, has_answer={final_answer is not None}")

+    # Final summary logging
     logger.info(f"Successfully processed {processed_count} questions for the API.")
     logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
+
     if not questions_for_api:
+        logger.error("CRITICAL: No valid questions were loaded after filtering and processing. API endpoints like /questions will fail.")
+        # Consider raising an error if the application cannot function without questions
+        # raise RuntimeError("Failed to load mandatory question data after filtering.")
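The new ALLOWED_CACHE_BASE constant is defined in this hunk but never referenced by it. One plausible use, not shown in the diff, is to confine the paths stored in task_file_paths to the dataset cache under /app/.cache before serving files. The snippet below is a minimal sketch of such a containment check under that assumption; is_path_allowed is a hypothetical helper, not part of main.py.

import os

ALLOWED_CACHE_BASE = os.path.abspath("/app/.cache")

def is_path_allowed(path: str, base: str = ALLOWED_CACHE_BASE) -> bool:
    """Return True only if 'path' resolves to a location inside 'base'.

    Hypothetical helper illustrating how ALLOWED_CACHE_BASE could be applied
    before storing a file path mapping; not part of the diff above.
    """
    resolved = os.path.abspath(os.path.realpath(path))
    try:
        # commonpath raises ValueError for mixed absolute/relative paths,
        # so treat that case as "not allowed".
        return os.path.commonpath([resolved, base]) == base
    except ValueError:
        return False

# Example use inside load_questions(), guarding the mapping step:
# if os.path.isfile(local_file_path) and is_path_allowed(local_file_path):
#     task_file_paths[str(task_id)] = local_file_path

Resolving symlinks with realpath before the commonpath comparison avoids the pitfalls of a naive startswith check, which '..' segments or links pointing outside the cache directory could bypass.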