Jofthomas committed (verified)
Commit 69f1e0d · Parent(s): a4e3b45

Update main.py

Files changed (1): main.py (+116, -49)
main.py CHANGED
@@ -2,6 +2,7 @@
 import os
 import pandas as pd
 from fastapi import FastAPI, HTTPException, Body
+from fastapi.responses import FileResponse
 from pydantic import BaseModel, Field
 from typing import List, Dict, Any, Optional
 from datasets import load_dataset, Dataset, DatasetDict
@@ -14,117 +15,183 @@ import random
 # --- Constants and Config ---
 HF_DATASET_ID = "agents-course/unit4-students-scores"
 
-# --- Data Structures ---
-# questions_for_api will now hold richer dictionaries
-questions_for_api: List[Dict[str, Any]] = []
-ground_truth_answers: Dict[str, str] = {}
 
-# --- Logging Setup ---
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-
-logger = logging.getLogger(__name__) # Make sure logger is initialized
+task_file_paths: Dict[str, str] = {}
 tool_threshold = 3
 step_threshold = 5
-questions_for_api: List[Dict[str, Any]] = [] # Use Dict[str, Any] for flexibility before validation
+questions_for_api: List[Dict[str, Any]] = []
 ground_truth_answers: Dict[str, str] = {}
-filtered_dataset = None # Or initialize as empty list: []
+filtered_dataset = None
+# --- Define ErrorResponse if not already defined ---
+class ErrorResponse(BaseModel):
+    detail: str
 
 def load_questions():
     global filtered_dataset
     global questions_for_api
     global ground_truth_answers
+    global task_file_paths # Declare modification of global
     tempo_filtered = []
     # Clear existing data
     questions_for_api.clear()
     ground_truth_answers.clear()
+    task_file_paths.clear() # Clear the mapping too
 
     logger.info("Starting to load and filter GAIA dataset (validation split)...")
     try:
-        # Load the specified split and features
         dataset = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation", trust_remote_code=True)
         logger.info(f"GAIA dataset validation split loaded. Features: {dataset.features}")
-        # You can uncomment below to see the first item's structure if needed
-        # logger.debug(f"First item structure: {dataset[0]}")
    except Exception as e:
         logger.error(f"Failed to load GAIA dataset: {e}", exc_info=True)
         raise RuntimeError("Could not load the primary GAIA dataset.") from e
 
-    # --- Filtering Logic (remains the same) ---
+    # --- Filtering Logic (remains same) ---
+    # [ ... Same filtering code as before ... ]
     for item in dataset:
         metadata = item.get('Annotator Metadata')
-
-        if metadata:
+        if metadata: # Check if 'Annotator Metadata' exists
             num_tools_str = metadata.get('Number of tools')
             num_steps_str = metadata.get('Number of steps')
-
             if num_tools_str is not None and num_steps_str is not None:
                 try:
                     num_tools = int(num_tools_str)
                     num_steps = int(num_steps_str)
-
                     if num_tools < tool_threshold and num_steps < step_threshold:
-                        tempo_filtered.append(item) # Add the original item if it matches filter
+                        tempo_filtered.append(item)
                 except ValueError:
-                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count: tools='{num_tools_str}', steps='{num_steps_str}'.")
-            # else: # If needed: log missing numbers in metadata
-            #     logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - 'Number of tools' or 'Number of steps' missing in Metadata.")
-        else:
-            logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
+                    logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Could not convert tool/step count.")
+        # else: # Log missing metadata if needed
+        #     logger.warning(f"Skipping Task ID: {item.get('task_id', 'N/A')} - Missing 'Annotator Metadata'.")
+
 
-    filtered_dataset = tempo_filtered # Store the list of filtered original items
-    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria (tools < {tool_threshold}, steps < {step_threshold}).")
+    filtered_dataset = tempo_filtered
+    logger.info(f"Found {len(filtered_dataset)} questions matching the criteria.")
 
     processed_count = 0
-    # --- REVISED Processing Logic to match the new Pydantic model ---
+    # --- Processing Logic (includes storing file path mapping) ---
     for item in filtered_dataset:
         task_id = item.get('task_id')
-        original_question_text = item.get('Question') # Get original text
+        original_question_text = item.get('Question')
         final_answer = item.get('Final answer')
+        local_file_path = item.get('file_path') # Get the local path
+        file_name = item.get('file_name') # Get the filename
 
-        # Validate essential fields needed for processing & ground truth
+        # Validate essential fields
         if task_id and original_question_text and final_answer is not None:
-
-            # Create the dictionary for the API, selecting only the desired fields
+            # Create the dictionary for the API (WITHOUT file_path)
             processed_item = {
-                "task_id": str(task_id), # Ensure string type
-                "question": str(original_question_text), # Rename and ensure string type
-                # Include optional fields *if they exist* in the source item
-                "Level": item.get("Level"), # Use .get() for safety, Pydantic handles None
-                "file_name": item.get("file_name"),
-                "file_path": item.get("file_path"),
+                "task_id": str(task_id),
+                "question": str(original_question_text),
+                "Level": item.get("Level"),
+                "file_name": file_name, # Include filename for info
             }
-            # Optional: Clean up None values if Pydantic model doesn't handle them as desired
-            # processed_item = {k: v for k, v in processed_item.items() if v is not None}
-            # However, the Optional[...] fields in Pydantic should handle None correctly.
+            # Clean None values if you prefer not to send nulls for optional fields
+            processed_item = {k: v for k, v in processed_item.items() if v is not None}
 
-            # Append the structured dictionary matching the Pydantic model
             questions_for_api.append(processed_item)
 
-            # Store the ground truth answer separately (as before)
+            # Store ground truth
             ground_truth_answers[str(task_id)] = str(final_answer)
+
+            # --- Store the file path mapping ---
+            if local_file_path and file_name: # Only store if path and name exist
+                # Basic check if path looks plausible (optional)
+                if os.path.exists(local_file_path):
+                    task_file_paths[str(task_id)] = local_file_path
+                    logger.debug(f"Stored file path for task_id {task_id}: {local_file_path}")
+                else:
+                    logger.warning(f"File path '{local_file_path}' for task_id {task_id} does not exist on server. Mapping skipped.")
+
+
             processed_count += 1
         else:
-            logger.warning(f"Skipping item due to missing essential fields (task_id, Question, or Final answer): task_id={task_id}")
+            logger.warning(f"Skipping item due to missing essential fields: task_id={task_id}")
 
-    logger.info(f"Successfully processed {processed_count} questions for the API matching the Pydantic model.")
+    logger.info(f"Successfully processed {processed_count} questions for the API.")
+    logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
     if not questions_for_api:
-        logger.error("CRITICAL: No valid questions loaded after filtering and processing. API endpoints needing questions will fail.")
-        # raise RuntimeError("Failed to load mandatory question data after filtering.")
-    # --- END REVISED Processing Logic ---
-
-# --- Pydantic Models ---
+        logger.error("CRITICAL: No valid questions loaded after filtering/processing.")
+# --- Add this endpoint definition to your FastAPI app ---
+
+# Determine a base path for security. This should be the root directory
+# where Hugging Face datasets cache is allowed to serve files from.
+# IMPORTANT: Adjust this path based on your server's environment or use
+# environment variables for configuration.
+# Using expanduser handles '~' correctly.
+ALLOWED_CACHE_BASE = os.path.abspath(os.path.expanduser("~/.cache/huggingface/datasets"))
+logger.info(f"Configured allowed base path for file serving: {ALLOWED_CACHE_BASE}")
+
+@app.get("/files/{task_id}",
+         summary="Get Associated File by Task ID",
+         description="Downloads the file associated with the given task_id, if one exists and is mapped.",
+         responses={
+             200: {
+                 "description": "File content.",
+                 "content": {"*/*": {}} # Indicates response can be any file type
+             },
+             403: {"model": ErrorResponse, "description": "Access denied (e.g., path traversal attempt)."},
+             404: {"model": ErrorResponse, "description": "Task ID not found, no file associated, or file missing on server."},
+             500: {"model": ErrorResponse, "description": "Server error reading file."}
+         })
+async def get_task_file(task_id: str):
+    """
+    Serves the file associated with a specific task ID.
+    Includes security checks to prevent accessing arbitrary files.
+    """
+    logger.info(f"Request received for file associated with task_id: {task_id}")
+
+    if task_id not in task_file_paths:
+        logger.warning(f"File request failed: task_id '{task_id}' not found in file path mapping.")
+        raise HTTPException(status_code=404, detail=f"No file path associated with task_id {task_id}.")
+
+    local_file_path = task_file_paths[task_id]
+    logger.debug(f"Mapped task_id '{task_id}' to local path: {local_file_path}")
+
+    # --- CRUCIAL SECURITY CHECK ---
+    try:
+        # Resolve to absolute paths to prevent '..' tricks
+        abs_file_path = os.path.abspath(local_file_path)
+        abs_base_path = ALLOWED_CACHE_BASE # Already absolute
+
+        # Check if the resolved file path starts with the allowed base directory
+        if not abs_file_path.startswith(abs_base_path):
+            logger.error(f"SECURITY ALERT: Path traversal attempt denied for task_id '{task_id}'. Path '{local_file_path}' resolves outside base '{abs_base_path}'.")
+            raise HTTPException(status_code=403, detail="File access denied.")
+
+        # Check if the file exists at the resolved, validated path
+        if not os.path.exists(abs_file_path) or not os.path.isfile(abs_file_path):
+            logger.error(f"File not found on server for task_id '{task_id}' at expected path: {abs_file_path}")
+            raise HTTPException(status_code=404, detail=f"File associated with task_id {task_id} not found on server disk.")
+
+    except HTTPException as http_exc:
+        raise http_exc # Re-raise our own security/404 exceptions
+    except Exception as path_err:
+        logger.error(f"Error resolving or checking path '{local_file_path}' for task_id '{task_id}': {path_err}", exc_info=True)
+        raise HTTPException(status_code=500, detail="Server error validating file path.")
+    # --- END SECURITY CHECK ---
+
+    # Determine MIME type for the Content-Type header
+    mime_type, _ = mimetypes.guess_type(abs_file_path)
+    media_type = mime_type if mime_type else "application/octet-stream" # Default if unknown
+
+    # Extract filename for the Content-Disposition header (suggests filename to browser/client)
+    file_name_for_download = os.path.basename(abs_file_path)
+
+    logger.info(f"Serving file '{file_name_for_download}' (type: {media_type}) for task_id '{task_id}' from path: {abs_file_path}")
+
+    # Use FileResponse to efficiently stream the file
+    return FileResponse(path=abs_file_path, media_type=media_type, filename=file_name_for_download)
 
 
 class Question(BaseModel):
     task_id: str
     question: str
     Level: Optional[str] = None
-    file_name: Optional[str] = None
-    file_path: Optional[str] = None
+    file_name: Optional[str] = None # Keep filename for info
+    # file_path: Optional[str] = None # REMOVE file_path from the response model
 
 
 # --- The rest of your Pydantic models remain the same ---
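
For reference, below is a minimal client-side sketch of how the new GET /files/{task_id} endpoint could be exercised once this commit is deployed. The base URL and task_id are placeholders, not part of this commit; the endpoint code shown above also relies on mimetypes and the FastAPI app instance being imported and defined elsewhere in main.py, outside the hunks in this diff.

# Hypothetical usage sketch (not part of this commit): download the file attached
# to a GAIA task via the new GET /files/{task_id} endpoint.
import requests

BASE_URL = "http://localhost:8000"   # placeholder deployment URL
TASK_ID = "example-task-id"          # placeholder GAIA task_id

resp = requests.get(f"{BASE_URL}/files/{TASK_ID}", timeout=30)
if resp.status_code == 200:
    # FileResponse sets Content-Disposition with the original filename on the server
    disposition = resp.headers.get("content-disposition", "")
    filename = disposition.split("filename=")[-1].strip('"') if "filename=" in disposition else f"{TASK_ID}.bin"
    with open(filename, "wb") as f:
        f.write(resp.content)
    print(f"Saved {filename} ({resp.headers.get('content-type')})")
else:
    # 404: unknown task_id or no file mapped; 403: path rejected by the security check
    print(f"No file downloaded: {resp.status_code} - {resp.text}")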