File size: 4,451 Bytes
9aaf513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
import os

from backend.db.db_instance import get_db_session
from backend.db.task.dao import (
    get_task_status_from_db,
    get_all_tasks_status_from_db,
    delete_task_from_db,
)
from backend.db.task.models import (
    TasksResult,
    Task,
    TaskStatusResponse,
    TaskType
)
from backend.common.models import (
    Response,
)
from backend.common.compresser import compress_files, find_file_by_hash
from modules.utils.paths import BACKEND_CACHE_DIR

# Router for the task endpoints below; every route is mounted under "/task".
task_router = APIRouter(tags=["Tasks"], prefix="/task")


@task_router.get(
    "/{identifier}",
    response_model=TaskStatusResponse,
    status_code=status.HTTP_200_OK,
    summary="Retrieve Task by Identifier",
    description="Retrieve the specific task by its identifier.",
)
async def get_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> TaskStatusResponse:
    """
    Retrieve the specific task by its identifier.

    Args:
        identifier: Identifier of the task to look up.
        session: Database session injected via ``get_db_session``.

    Returns:
        The task converted with its ``to_response()`` method.

    Raises:
        HTTPException: 404 when no task matches ``identifier``.
    """
    found = get_task_status_from_db(identifier=identifier, session=session)
    if found is None:
        raise HTTPException(status_code=404, detail="Identifier not found")
    return found.to_response()


@task_router.get(
    "/file/{identifier}",
    status_code=status.HTTP_200_OK,
    summary="Retrieve FileResponse Task by Identifier",
    description="Retrieve the file response task by its identifier. You can use this endpoint if you need to download"
                " the file as a response",
)
async def get_file_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> FileResponse:
    """
    Retrieve the downloadable file response of a specific task by its identifier.

    Only BGM separation tasks are supported. The instrumental and vocal files
    referenced by the task result hashes are compressed into a single ZIP
    archive, which is returned as the response body.

    Args:
        identifier: Identifier of the task to download.
        session: Database session injected via ``get_db_session``.

    Returns:
        A ``FileResponse`` streaming the ZIP archive (``application/zip``).

    Raises:
        HTTPException: 404 when the identifier is unknown, when the task type
            does not support file download, or when the task has no usable
            result (e.g. it has not produced the expected hashes yet).
    """
    task = get_task_status_from_db(identifier=identifier, session=session)
    if task is None:
        raise HTTPException(status_code=404, detail="Identifier not found")

    if task.task_type != TaskType.BGM_SEPARATION:
        raise HTTPException(status_code=404, detail="File download is only supported for bgm separation."
                                                    f" The given type is {task.task_type}")

    # A task without a result (None) or missing one of the hashes would otherwise
    # raise TypeError/KeyError here and surface to the client as a 500.
    try:
        instrumental_hash = task.result["instrumental_hash"]
        vocal_hash = task.result["vocal_hash"]
    except (KeyError, TypeError):
        raise HTTPException(status_code=404, detail="Task result is not available") from None

    # NOTE(review): what find_file_by_hash does when no file matches is not visible
    # here; confirm it raises or returns something compress_files can handle.
    instrumental_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"),
        instrumental_hash,
    )
    vocal_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"),
        vocal_hash,
    )

    output_zip_path = os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip")
    output_zip_path = compress_files([instrumental_path, vocal_path], output_zip_path)
    return FileResponse(
        path=output_zip_path,
        status_code=200,
        # Only the base name goes into Content-Disposition; the full path would
        # expose the server's cache directory layout to the client.
        filename=os.path.basename(output_zip_path),
        media_type="application/zip",
    )


# Delete endpoint: its route decorator is commented out by default because this endpoint likely requires special permissions.
# @task_router.delete(
#     "/{identifier}",
#     response_model=Response,
#     status_code=status.HTTP_200_OK,
#     summary="Delete Task by Identifier",
#     description="Delete a task from the system using its identifier.",
# )
async def delete_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> Response:
    """
    Delete a task by its identifier.

    Args:
        identifier: Identifier of the task to delete.
        session: Database session injected via ``get_db_session``.

    Returns:
        A ``Response`` confirming the deletion.

    Raises:
        HTTPException: 404 when ``delete_task_from_db`` returns a falsy value.
    """
    deleted = delete_task_from_db(identifier, session)
    if not deleted:
        raise HTTPException(status_code=404, detail="Task not found")
    return Response(identifier=identifier, message="Task deleted")


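# Get-all endpoint: its route decorator is commented out by default because this endpoint likely requires special permissions.
# If you enable it, register it before the "/{identifier}" route above; routes match in registration order,
# so otherwise a request to /task/all is handled by get_task with identifier="all".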
# @task_router.get(
#     "/all",
#     response_model=TasksResult,
#     status_code=status.HTTP_200_OK,
#     summary="Retrieve All Task Statuses",
#     description="Retrieve the statuses of all tasks available in the system.",
# )
async def get_all_tasks_status(
    session: Session = Depends(get_db_session),
) -> TasksResult:
    """
    Retrieve all tasks.

    Args:
        session: Database session injected via ``get_db_session``.

    Returns:
        The value produced by ``get_all_tasks_status_from_db`` for this session.
    """
    all_tasks = get_all_tasks_status_from_db(session=session)
    return all_tasks