Spaces:
Running
Running
File size: 3,653 Bytes
9aaf513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import functools
import uuid
import numpy as np
from fastapi import (
File,
UploadFile,
)
import gradio as gr
from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
from typing import List, Dict
from sqlalchemy.orm import Session
from datetime import datetime
from modules.whisper.data_classes import *
from modules.utils.paths import BACKEND_CACHE_DIR
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from backend.common.audio import read_audio
from backend.common.models import QueueResponse
from backend.common.config_loader import load_server_config
from backend.db.task.dao import (
add_task_to_db,
get_db_session,
update_task_status_in_db
)
from backend.db.task.models import TaskStatus, TaskType
# Router for all /transcription endpoints; mounted into the FastAPI app elsewhere.
transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"])
@functools.lru_cache
def get_pipeline() -> 'FasterWhisperInference':
    """Return the process-wide Whisper inference pipeline.

    The instance is created lazily on first call and memoized by
    ``functools.lru_cache`` (no arguments, so a single cached instance),
    configured from the ``whisper`` section of the server config.
    """
    whisper_cfg = load_server_config()["whisper"]
    pipeline = FasterWhisperInference(output_dir=BACKEND_CACHE_DIR)
    pipeline.update_model(
        model_size=whisper_cfg["model_size"],
        compute_type=whisper_cfg["compute_type"],
    )
    return pipeline
def run_transcription(
    audio: np.ndarray,
    params: TranscriptionPipelineParams,
    identifier: str,
) -> List[Segment]:
    """Run the Whisper pipeline on `audio` and persist progress to the task DB.

    Intended to run as a FastAPI background task. Marks the task
    IN_PROGRESS before inference, COMPLETED (with segments and elapsed
    time) on success, and FAILED on any exception so the task never
    stays IN_PROGRESS forever.

    Args:
        audio: Decoded audio samples.
        params: Full pipeline configuration (whisper/vad/bgm/diarization).
        identifier: Task UUID used as the DB key.

    Returns:
        The transcription segments as plain dicts (model_dump'ed).

    Raises:
        Re-raises whatever the pipeline raises, after recording FAILED.
    """
    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.IN_PROGRESS,
            "updated_at": datetime.utcnow()
        },
    )
    try:
        segments, elapsed_time = get_pipeline().run(
            audio,
            gr.Progress(),
            "SRT",
            False,
            *params.to_list()
        )
    except Exception as e:
        # Without this, a pipeline crash left the task IN_PROGRESS in the
        # DB with no failure record for pollers of the task status.
        # NOTE(review): assumes TaskStatus.FAILED and an "error" field exist
        # on the task model — confirm against backend.db.task.models.
        update_task_status_in_db(
            identifier=identifier,
            update_data={
                "uuid": identifier,
                "status": TaskStatus.FAILED,
                "error": str(e),
                "updated_at": datetime.utcnow()
            },
        )
        raise
    segments = [seg.model_dump() for seg in segments]
    update_task_status_in_db(
        identifier=identifier,
        update_data={
            "uuid": identifier,
            "status": TaskStatus.COMPLETED,
            "result": segments,
            "updated_at": datetime.utcnow(),
            "duration": elapsed_time
        },
    )
    return segments
@transcription_router.post(
    "/",
    response_model=QueueResponse,
    status_code=status.HTTP_201_CREATED,
    summary="Transcribe Audio",
    description="Process the provided audio or video file to generate a transcription.",
)
async def transcription(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(..., description="Audio or video file to transcribe."),
    whisper_params: WhisperParams = Depends(),
    vad_params: VadParams = Depends(),
    bgm_separation_params: BGMSeparationParams = Depends(),
    diarization_params: DiarizationParams = Depends(),
) -> QueueResponse:
    """Queue a transcription task for the uploaded file.

    Decodes the upload, records a QUEUED task in the DB, schedules
    `run_transcription` as a background task, and returns the task
    identifier so the caller can poll for the result.
    """
    # NOTE(review): `file` appears to sometimes arrive as an already-decoded
    # numpy array (internal caller?) — confirm; otherwise this branch is dead.
    if isinstance(file, np.ndarray):
        audio, info = file, None
    else:
        audio, info = await read_audio(file=file)

    pipeline_params = TranscriptionPipelineParams(
        whisper=whisper_params,
        vad=vad_params,
        bgm_separation=bgm_separation_params,
        diarization=diarization_params
    )

    identifier = add_task_to_db(
        status=TaskStatus.QUEUED,
        file_name=file.filename,
        audio_duration=info.duration if info else None,
        language=pipeline_params.whisper.lang,
        task_type=TaskType.TRANSCRIPTION,
        task_params=pipeline_params.to_dict(),
    )

    # Heavy inference runs after the response is sent.
    background_tasks.add_task(
        run_transcription,
        audio=audio,
        params=pipeline_params,
        identifier=identifier,
    )

    return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="Transcription task has queued")
|