soiz1's picture
Upload 109 files
9aaf513 verified
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
import os
from backend.db.db_instance import get_db_session
from backend.db.task.dao import (
get_task_status_from_db,
get_all_tasks_status_from_db,
delete_task_from_db,
)
from backend.db.task.models import (
TasksResult,
Task,
TaskStatusResponse,
TaskType
)
from backend.common.models import (
Response,
)
from backend.common.compresser import compress_files, find_file_by_hash
from modules.utils.paths import BACKEND_CACHE_DIR
# Router for the task endpoints: every route below is mounted under "/task"
# and listed under the "Tasks" tag in the generated OpenAPI docs.
task_router = APIRouter(
    prefix="/task",
    tags=["Tasks"],
)
@task_router.get(
    "/{identifier}",
    response_model=TaskStatusResponse,
    status_code=status.HTTP_200_OK,
    summary="Retrieve Task by Identifier",
    description="Retrieve the specific task by its identifier.",
)
async def get_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> TaskStatusResponse:
    """
    Look up a single task by its identifier and return its status response.

    Raises:
        HTTPException: status 404 when no task matches ``identifier``.
    """
    found = get_task_status_from_db(identifier=identifier, session=session)
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Identifier not found")
    return found.to_response()
@task_router.get(
    "/file/{identifier}",
    status_code=status.HTTP_200_OK,
    summary="Retrieve FileResponse Task by Identifier",
    description="Retrieve the file response task by its identifier. You can use this endpoint if you need to download"
                " The file as a response",
)
async def get_file_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> FileResponse:
    """
    Retrieve the downloadable file response of a specific task by its identifier.

    Only BGM-separation tasks are supported. The instrumental and vocal outputs
    referenced by the task result are bundled into a ZIP archive written to
    ``BACKEND_CACHE_DIR`` and returned with media type ``application/zip``.

    Raises:
        HTTPException: status 404 when the identifier is unknown, when the task
            is not a BGM-separation task, when the task result does not (yet)
            contain the output hashes, or when an output file cannot be located.
    """
    task = get_task_status_from_db(identifier=identifier, session=session)
    if task is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Identifier not found")

    if task.task_type != TaskType.BGM_SEPARATION:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"File download is only supported for bgm separation."
                   f" The given type is {task.task_type}",
        )

    # The result is indexed by key below. Guard against a result that is empty
    # or missing the hashes (presumably a task that has not finished yet), so
    # the client gets a 404 instead of an unhandled TypeError/KeyError (a 500).
    result = task.result
    if not result or "instrumental_hash" not in result or "vocal_hash" not in result:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task result is not available yet",
        )

    instrumental_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"),
        result["instrumental_hash"],
    )
    vocal_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"),
        result["vocal_hash"],
    )
    # NOTE(review): assumes find_file_by_hash returns a falsy value (e.g. None)
    # when no file matches the hash. Confirm; this guard is harmless otherwise.
    if not instrumental_path or not vocal_path:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Output files for this task were not found",
        )

    output_zip_path = compress_files(
        [instrumental_path, vocal_path],
        os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip"),
    )
    return FileResponse(
        path=output_zip_path,
        status_code=200,
        # Send only the bare file name. FileResponse puts `filename` into the
        # Content-Disposition header, and the full path would expose the
        # server's cache directory layout to the client.
        filename=os.path.basename(output_zip_path),
        media_type="application/zip",
    )
# Delete endpoint. Its route decorator is commented out by default because this endpoint likely requires special permissions.
# @task_router.delete(
# "/{identifier}",
# response_model=Response,
# status_code=status.HTTP_200_OK,
# summary="Delete Task by Identifier",
# description="Delete a task from the system using its identifier.",
# )
async def delete_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> Response:
    """
    Delete the task identified by ``identifier``.

    Returns a ``Response`` confirming the deletion. Raises ``HTTPException``
    with status 404 when ``delete_task_from_db`` returns a falsy value.
    """
    deleted = delete_task_from_db(identifier, session)
    if not deleted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")
    return Response(identifier=identifier, message="Task deleted")
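# Get-all endpoint. Its route decorator is commented out by default because this endpoint likely requires special permissions.
# NOTE: if re-enabled, register "/all" before "/{identifier}". Routes are matched in declaration order,
# so as written GET /task/all would be handled by get_task with identifier="all".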
# @task_router.get(
# "/all",
# response_model=TasksResult,
# status_code=status.HTTP_200_OK,
# summary="Retrieve All Task Statuses",
# description="Retrieve the statuses of all tasks available in the system.",
# )
async def get_all_tasks_status(
    session: Session = Depends(get_db_session),
) -> TasksResult:
    """
    Return every task known to the database, exactly as produced by
    ``get_all_tasks_status_from_db``.
    """
    all_tasks = get_all_tasks_status_from_db(session=session)
    return all_tasks