Jofthomas commited on
Commit
34c3c29
·
verified ·
1 Parent(s): 9847443

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +22 -22
main.py CHANGED
@@ -234,28 +234,28 @@ async def startup_event():
234
  # import sys
235
  # sys.exit(1) # Consider exiting if questions are critical
236
 
237
- # --- Add this endpoint definition to your FastAPI app ---
238
-
239
- # Determine a base path for security. This should be the root directory
240
- # where Hugging Face datasets cache is allowed to serve files from.
241
- # IMPORTANT: Adjust this path based on your server's environment or use
242
- # environment variables for configuration.
243
- # Using expanduser handles '~' correctly.
244
- ALLOWED_CACHE_BASE = os.path.abspath(os.path.expanduser("~/.cache/huggingface/datasets"))
245
- logger.info(f"Configured allowed base path for file serving: {ALLOWED_CACHE_BASE}")
246
-
247
- @app.get("/files/{task_id}",
248
- summary="Get Associated File by Task ID",
249
- description="Downloads the file associated with the given task_id, if one exists and is mapped.",
250
- responses={
251
- 200: {
252
- "description": "File content.",
253
- "content": {"*/*": {}} # Indicates response can be any file type
254
- },
255
- 403: {"model": ErrorResponse, "description": "Access denied (e.g., path traversal attempt)."},
256
- 404: {"model": ErrorResponse, "description": "Task ID not found, no file associated, or file missing on server."},
257
- 500: {"model": ErrorResponse, "description": "Server error reading file."}
258
- })
259
  async def get_task_file(task_id: str):
260
  """
261
  Serves the file associated with a specific task ID.
 
234
  # import sys
235
  # sys.exit(1) # Consider exiting if questions are critical
236
 
237
+
238
+ # --- Your Endpoints ---
239
+ @app.get("/files/{task_id}", ...)
240
+ async def get_task_file(task_id: str):
241
+ # ... (endpoint logic) ...
242
+ try:
243
+ # --- Ensure it uses the globally defined variable ---
244
+ abs_base_path = ALLOWED_CACHE_BASE # Uses the variable defined above
245
+ abs_file_path = os.path.abspath(local_file_path)
246
+
247
+ # Add extra debug logging right before the check
248
+ logger.debug(f"Security Check - Comparing: file='{abs_file_path}' against base='{abs_base_path}'")
249
+
250
+ if not abs_file_path.startswith(abs_base_path):
251
+ logger.error(f"SECURITY FAILURE: Path mismatch. File '{abs_file_path}' is NOT within allowed base '{abs_base_path}'.")
252
+ raise HTTPException(status_code=403, detail="File access denied.")
253
+ # ... rest of the endpoint ...
254
+ except Exception as e:
255
+ # ... error handling ...
256
+ # Log the base path again in case of error context
257
+ logger.error(f"Error during file access. Base path check was against: {ALLOWED_CACHE_BASE}")
258
+ raise e # Or handle appropriately
259
  async def get_task_file(task_id: str):
260
  """
261
  Serves the file associated with a specific task ID.