Jofthomas committed on
Commit
309e0ef
·
verified ·
1 Parent(s): 69f1e0d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +70 -71
main.py CHANGED
@@ -114,76 +114,7 @@ def load_questions():
114
  logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
115
  if not questions_for_api:
116
  logger.error("CRITICAL: No valid questions loaded after filtering/processing.")
117
- # --- Add this endpoint definition to your FastAPI app ---
118
-
119
- # Determine a base path for security. This should be the root directory
120
- # where Hugging Face datasets cache is allowed to serve files from.
121
- # IMPORTANT: Adjust this path based on your server's environment or use
122
- # environment variables for configuration.
123
- # Using expanduser handles '~' correctly.
124
- ALLOWED_CACHE_BASE = os.path.abspath(os.path.expanduser("~/.cache/huggingface/datasets"))
125
- logger.info(f"Configured allowed base path for file serving: {ALLOWED_CACHE_BASE}")
126
-
127
- @app.get("/files/{task_id}",
128
- summary="Get Associated File by Task ID",
129
- description="Downloads the file associated with the given task_id, if one exists and is mapped.",
130
- responses={
131
- 200: {
132
- "description": "File content.",
133
- "content": {"*/*": {}} # Indicates response can be any file type
134
- },
135
- 403: {"model": ErrorResponse, "description": "Access denied (e.g., path traversal attempt)."},
136
- 404: {"model": ErrorResponse, "description": "Task ID not found, no file associated, or file missing on server."},
137
- 500: {"model": ErrorResponse, "description": "Server error reading file."}
138
- })
139
- async def get_task_file(task_id: str):
140
- """
141
- Serves the file associated with a specific task ID.
142
- Includes security checks to prevent accessing arbitrary files.
143
- """
144
- logger.info(f"Request received for file associated with task_id: {task_id}")
145
-
146
- if task_id not in task_file_paths:
147
- logger.warning(f"File request failed: task_id '{task_id}' not found in file path mapping.")
148
- raise HTTPException(status_code=404, detail=f"No file path associated with task_id {task_id}.")
149
-
150
- local_file_path = task_file_paths[task_id]
151
- logger.debug(f"Mapped task_id '{task_id}' to local path: {local_file_path}")
152
 
153
- # --- CRUCIAL SECURITY CHECK ---
154
- try:
155
- # Resolve to absolute paths to prevent '..' tricks
156
- abs_file_path = os.path.abspath(local_file_path)
157
- abs_base_path = ALLOWED_CACHE_BASE # Already absolute
158
-
159
- # Check if the resolved file path starts with the allowed base directory
160
- if not abs_file_path.startswith(abs_base_path):
161
- logger.error(f"SECURITY ALERT: Path traversal attempt denied for task_id '{task_id}'. Path '{local_file_path}' resolves outside base '{abs_base_path}'.")
162
- raise HTTPException(status_code=403, detail="File access denied.")
163
-
164
- # Check if the file exists at the resolved, validated path
165
- if not os.path.exists(abs_file_path) or not os.path.isfile(abs_file_path):
166
- logger.error(f"File not found on server for task_id '{task_id}' at expected path: {abs_file_path}")
167
- raise HTTPException(status_code=404, detail=f"File associated with task_id {task_id} not found on server disk.")
168
-
169
- except HTTPException as http_exc:
170
- raise http_exc # Re-raise our own security/404 exceptions
171
- except Exception as path_err:
172
- logger.error(f"Error resolving or checking path '{local_file_path}' for task_id '{task_id}': {path_err}", exc_info=True)
173
- raise HTTPException(status_code=500, detail="Server error validating file path.")
174
- # --- END SECURITY CHECK ---
175
-
176
- # Determine MIME type for the Content-Type header
177
- mime_type, _ = mimetypes.guess_type(abs_file_path)
178
- media_type = mime_type if mime_type else "application/octet-stream" # Default if unknown
179
-
180
- # Extract filename for the Content-Disposition header (suggests filename to browser/client)
181
- file_name_for_download = os.path.basename(abs_file_path)
182
-
183
- logger.info(f"Serving file '{file_name_for_download}' (type: {media_type}) for task_id '{task_id}' from path: {abs_file_path}")
184
-
185
- # Use FileResponse to efficiently stream the file
186
- return FileResponse(path=abs_file_path, media_type=media_type, filename=file_name_for_download)
187
 
188
 
189
  class Question(BaseModel):
@@ -259,8 +190,76 @@ async def startup_event():
259
  # import sys
260
  # sys.exit(1) # Consider exiting if questions are critical
261
 
262
- # --- Helper Function (update_huggingface_dataset remains the same) ---
263
- # ... (update_huggingface_dataset function code) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  def update_huggingface_dataset(username: str, score: float):
265
  """Loads the dataset, updates the score if higher, and pushes back."""
266
  try:
 
114
  logger.info(f"Stored file path mappings for {len(task_file_paths)} tasks.")
115
  if not questions_for_api:
116
  logger.error("CRITICAL: No valid questions loaded after filtering/processing.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
 
120
  class Question(BaseModel):
 
190
  # import sys
191
  # sys.exit(1) # Consider exiting if questions are critical
192
 
193
+ # --- Add this endpoint definition to your FastAPI app ---
194
+
195
+ # Determine a base path for security. This should be the root directory
196
+ # where Hugging Face datasets cache is allowed to serve files from.
197
+ # IMPORTANT: Adjust this path based on your server's environment or use
198
+ # environment variables for configuration.
199
+ # Using expanduser handles '~' correctly.
200
+ ALLOWED_CACHE_BASE = os.path.abspath(os.path.expanduser("~/.cache/huggingface/datasets"))
201
+ logger.info(f"Configured allowed base path for file serving: {ALLOWED_CACHE_BASE}")
202
+
203
+ @app.get("/files/{task_id}",
204
+ summary="Get Associated File by Task ID",
205
+ description="Downloads the file associated with the given task_id, if one exists and is mapped.",
206
+ responses={
207
+ 200: {
208
+ "description": "File content.",
209
+ "content": {"*/*": {}} # Indicates response can be any file type
210
+ },
211
+ 403: {"model": ErrorResponse, "description": "Access denied (e.g., path traversal attempt)."},
212
+ 404: {"model": ErrorResponse, "description": "Task ID not found, no file associated, or file missing on server."},
213
+ 500: {"model": ErrorResponse, "description": "Server error reading file."}
214
+ })
215
+ async def get_task_file(task_id: str):
216
+ """
217
+ Serves the file associated with a specific task ID.
218
+ Includes security checks to prevent accessing arbitrary files.
219
+ """
220
+ logger.info(f"Request received for file associated with task_id: {task_id}")
221
+
222
+ if task_id not in task_file_paths:
223
+ logger.warning(f"File request failed: task_id '{task_id}' not found in file path mapping.")
224
+ raise HTTPException(status_code=404, detail=f"No file path associated with task_id {task_id}.")
225
+
226
+ local_file_path = task_file_paths[task_id]
227
+ logger.debug(f"Mapped task_id '{task_id}' to local path: {local_file_path}")
228
+
229
+ # --- CRUCIAL SECURITY CHECK ---
230
+ try:
231
+ # Resolve to absolute paths to prevent '..' tricks
232
+ abs_file_path = os.path.abspath(local_file_path)
233
+ abs_base_path = ALLOWED_CACHE_BASE # Already absolute
234
+
235
+ # Check if the resolved file path starts with the allowed base directory
236
+ if not abs_file_path.startswith(abs_base_path):
237
+ logger.error(f"SECURITY ALERT: Path traversal attempt denied for task_id '{task_id}'. Path '{local_file_path}' resolves outside base '{abs_base_path}'.")
238
+ raise HTTPException(status_code=403, detail="File access denied.")
239
+
240
+ # Check if the file exists at the resolved, validated path
241
+ if not os.path.exists(abs_file_path) or not os.path.isfile(abs_file_path):
242
+ logger.error(f"File not found on server for task_id '{task_id}' at expected path: {abs_file_path}")
243
+ raise HTTPException(status_code=404, detail=f"File associated with task_id {task_id} not found on server disk.")
244
+
245
+ except HTTPException as http_exc:
246
+ raise http_exc # Re-raise our own security/404 exceptions
247
+ except Exception as path_err:
248
+ logger.error(f"Error resolving or checking path '{local_file_path}' for task_id '{task_id}': {path_err}", exc_info=True)
249
+ raise HTTPException(status_code=500, detail="Server error validating file path.")
250
+ # --- END SECURITY CHECK ---
251
+
252
+ # Determine MIME type for the Content-Type header
253
+ mime_type, _ = mimetypes.guess_type(abs_file_path)
254
+ media_type = mime_type if mime_type else "application/octet-stream" # Default if unknown
255
+
256
+ # Extract filename for the Content-Disposition header (suggests filename to browser/client)
257
+ file_name_for_download = os.path.basename(abs_file_path)
258
+
259
+ logger.info(f"Serving file '{file_name_for_download}' (type: {media_type}) for task_id '{task_id}' from path: {abs_file_path}")
260
+
261
+ # Use FileResponse to efficiently stream the file
262
+ return FileResponse(path=abs_file_path, media_type=media_type, filename=file_name_for_download)
263
  def update_huggingface_dataset(username: str, score: float):
264
  """Loads the dataset, updates the score if higher, and pushes back."""
265
  try: