import os
import logging
import json
import torch
import asyncio
import subprocess
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes

load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)

UI_MODE = os.getenv("UI_MODE", "fastapi").lower()   # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower()      # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower()   # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")

try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(MODEL_ID)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")

# Warm up the model with empty audio
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)  # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")


async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=2,
        generate_kwargs={
            'task': 'transcribe',
            'language': LANGUAGE,
        },
        # return_timestamps="word"
    )
    yield AdditionalOutputs(outputs["text"].strip())


logger.info("Initializing FastRTC stream")
stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks (default 0.6)
            audio_chunk_duration=0.6,
            # If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
            started_talking_threshold=0.1,
            # If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech,
            # the user stopped speaking (default 0.1)
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split at the timestamp of the last silence
            # that lasts more than 100ms (if any) or just before max_speech_duration_s (default float('inf'))
            max_speech_duration_s=3,
            # Wait min_silence_duration_ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=100,
            # Chunk size for the VAD model. Can be 512, 1024, or 1536 for a 16kHz sample rate (default 1024)
            window_size_samples=512,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=200,
        ),
    ),
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    modality="audio",
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
    if APP_MODE == "deployed"
    else None,
    concurrency_limit=6,
)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)


@app.get("/")
async def index():
    if UI_TYPE == "base":
        html_content = open("static/index.html").read()
    elif UI_TYPE == "screen":
        html_content = open("static/index-screen.html").read()

    rtc_configuration = (
        await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))
        if APP_MODE == "deployed"
        else None
    )
    logger.info(f"RTC configuration: {rtc_configuration}")

    html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)


@app.get("/transcript")
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=7860)