Spaces:
Running
Running
File size: 4,451 Bytes
9aaf513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
import os
from backend.db.db_instance import get_db_session
from backend.db.task.dao import (
get_task_status_from_db,
get_all_tasks_status_from_db,
delete_task_from_db,
)
from backend.db.task.models import (
TasksResult,
Task,
TaskStatusResponse,
TaskType
)
from backend.common.models import (
Response,
)
from backend.common.compresser import compress_files, find_file_by_hash
from modules.utils.paths import BACKEND_CACHE_DIR
# Router that groups every task-related endpoint under the "/task" prefix
# and the "Tasks" OpenAPI tag.
task_router = APIRouter(
    prefix="/task",
    tags=["Tasks"],
)
@task_router.get(
    "/{identifier}",
    response_model=TaskStatusResponse,
    status_code=status.HTTP_200_OK,
    summary="Retrieve Task by Identifier",
    description="Retrieve the specific task by its identifier.",
)
async def get_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> TaskStatusResponse:
    """
    Look up a single task and return its status representation.

    Args:
        identifier: Identifier of the task to fetch.
        session: Database session injected by FastAPI through ``get_db_session``.

    Returns:
        The result of calling ``to_response()`` on the stored task record.

    Raises:
        HTTPException: 404 when no task matches ``identifier``.
    """
    found_task = get_task_status_from_db(identifier=identifier, session=session)
    if found_task is None:
        raise HTTPException(status_code=404, detail="Identifier not found")
    return found_task.to_response()
@task_router.get(
    "/file/{identifier}",
    status_code=status.HTTP_200_OK,
    summary="Retrieve FileResponse Task by Identifier",
    description="Retrieve the file response task by its identifier. You can use this endpoint if you need to download"
                " the file as a response.",
)
async def get_file_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> FileResponse:
    """
    Return the output files of a task as a downloadable ZIP archive.

    Only BGM-separation tasks are supported. Their instrumental and vocal
    outputs are located under ``BACKEND_CACHE_DIR/UVR/instrumental`` and
    ``BACKEND_CACHE_DIR/UVR/vocals`` using the hashes stored in
    ``task.result``. Both files are then compressed into one ZIP file.

    Args:
        identifier: Identifier of the task whose files should be downloaded.
        session: Database session injected by FastAPI through ``get_db_session``.

    Returns:
        A ``FileResponse`` serving the ZIP archive with media type
        ``application/zip``.

    Raises:
        HTTPException: 404 when the task does not exist, when its type does not
            support file download, or when its result files are unavailable.
    """
    task = get_task_status_from_db(identifier=identifier, session=session)
    if task is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Identifier not found")

    if task.task_type != TaskType.BGM_SEPARATION:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"File download is only supported for bgm separation."
                   f" The given type is {task.task_type}",
        )

    # ``task.result`` may be empty or lack the hashes (presumably while the task
    # has not finished yet -- confirm). Indexing it directly would surface as a
    # 500, so report a clean 404 instead.
    result = task.result or {}
    try:
        instrumental_hash = result["instrumental_hash"]
        vocal_hash = result["vocal_hash"]
    except (KeyError, TypeError):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Result files are not available for this task",
        ) from None

    uvr_dir = os.path.join(BACKEND_CACHE_DIR, "UVR")
    instrumental_path = find_file_by_hash(os.path.join(uvr_dir, "instrumental"), instrumental_hash)
    vocal_path = find_file_by_hash(os.path.join(uvr_dir, "vocals"), vocal_hash)
    # NOTE(review): assumes find_file_by_hash returns None when nothing matches;
    # never hand a None path to compress_files.
    if instrumental_path is None or vocal_path is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Result files could not be found in the cache",
        )

    output_zip_path = compress_files(
        [instrumental_path, vocal_path],
        os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip"),
    )
    return FileResponse(
        path=output_zip_path,
        status_code=status.HTTP_200_OK,
        # Advertise only the base name. Passing the full path would leak the
        # server's cache directory layout through the Content-Disposition header.
        filename=os.path.basename(output_zip_path),
        media_type="application/zip",
    )
# Delete endpoint: its route decorator is commented out by default because this endpoint likely requires special permissions.
# @task_router.delete(
# "/{identifier}",
# response_model=Response,
# status_code=status.HTTP_200_OK,
# summary="Delete Task by Identifier",
# description="Delete a task from the system using its identifier.",
# )
async def delete_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> Response:
    """
    Remove the task identified by ``identifier`` from the database.

    Args:
        identifier: Identifier of the task to remove.
        session: Database session injected by FastAPI through ``get_db_session``.

    Returns:
        A ``Response`` carrying the identifier and a confirmation message.

    Raises:
        HTTPException: 404 when ``delete_task_from_db`` returns a falsy value.
    """
    was_deleted = delete_task_from_db(identifier, session)
    if not was_deleted:
        raise HTTPException(status_code=404, detail="Task not found")
    return Response(identifier=identifier, message="Task deleted")
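# Get-all endpoint: its route decorator is commented out by default because this endpoint likely requires special permissions.
# If re-enabled, register it before the "/{identifier}" GET route; otherwise FastAPI matches GET /task/all to get_task with identifier="all".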
# @task_router.get(
# "/all",
# response_model=TasksResult,
# status_code=status.HTTP_200_OK,
# summary="Retrieve All Task Statuses",
# description="Retrieve the statuses of all tasks available in the system.",
# )
async def get_all_tasks_status(
    session: Session = Depends(get_db_session),
) -> TasksResult:
    """
    Return every task stored in the database.

    Args:
        session: Database session injected by FastAPI through ``get_db_session``.

    Returns:
        The value produced by ``get_all_tasks_status_from_db`` (annotated as
        ``TasksResult``).
    """
    all_tasks = get_all_tasks_status_from_db(session=session)
    return all_tasks