import os
import logging
import json

import torch
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    AlgoOptions,
    SileroVadOptions,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
)
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available

from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes
load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)

UI_MODE = os.getenv("UI_MODE", "fastapi").lower()  # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower()  # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower()  # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")

attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")

logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")
try:
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attention,
    )
    model.to(device)
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise

processor = AutoProcessor.from_pretrained(MODEL_ID)

transcribe_pipeline = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

# torch.compile is lazy: compilation is triggered by the first forward pass,
# so the warmup below also absorbs the one-off compilation cost.
transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")

# Warm up the model with dummy input (1s of silence)
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype)
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")

async def transcribe(audio: tuple[int, np.ndarray]):
    sample_rate, audio_array = audio
    logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")

    outputs = transcribe_pipeline(
        audio_to_bytes(audio),
        chunk_length_s=3,
        batch_size=2,
        generate_kwargs={
            'task': 'transcribe',
            'language': LANGUAGE,
        },
        # return_timestamps="word"
    )
    yield AdditionalOutputs(outputs["text"].strip())
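
# A minimal offline sanity check (hypothetical, not wired into the app): the
# same pipeline accepts a raw float32 array, which is handy for verifying the
# model transcribes before testing over WebRTC. "sample.wav" is a placeholder
# path and soundfile an assumed extra dependency.
#
#     import soundfile as sf
#     array, sr = sf.read("sample.wav", dtype="float32")
#     print(transcribe_pipeline({"raw": array, "sampling_rate": sr})["text"])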

logger.info("Initializing FastRTC stream")

async def get_credentials():
    # Fetch fresh TURN credentials per connection; a coroutine created by a
    # single module-level call could only be awaited once.
    return await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN"))

stream = Stream(
    handler=ReplyOnPause(
        transcribe,
        algo_options=AlgoOptions(
            # Duration in seconds of audio chunks (default 0.6)
            audio_chunk_duration=0.6,
            # If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
            started_talking_threshold=0.1,
            # If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking (default 0.1)
            speech_threshold=0.1,
        ),
        model_options=SileroVadOptions(
            # Threshold for what is considered speech (default 0.5)
            threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
            min_speech_duration_ms=250,
            # Max duration of speech chunks; longer chunks are split at the timestamp of the last silence
            # that lasts more than 100ms (if any) or just before max_speech_duration_s (default float('inf'))
            max_speech_duration_s=3,
            # Wait min_silence_duration_ms at the end of each speech chunk before separating it (default 2000)
            min_silence_duration_ms=100,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16kHz sample rate (default 1024)
            window_size_samples=512,
            # Final speech chunks are padded by speech_pad_ms on each side (default 400)
            speech_pad_ms=200,
        ),
    ),
    modality="audio",
    # send-receive: bidirectional streaming (default)
    # send: client to server only
    # receive: server to client only
    mode="send",
    additional_outputs=[
        gr.Textbox(label="Transcript"),
    ],
    additional_outputs_handler=lambda current, new: current + " " + new,
    rtc_configuration=get_credentials if APP_MODE == "deployed" else None,
    concurrency_limit=6,
)

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)

@app.get("/")
async def index():
    if UI_TYPE == "screen":
        html_content = open("static/index-screen.html").read()
    else:  # default: base UI
        html_content = open("static/index.html").read()

    rtc_configuration = await get_credentials() if APP_MODE == "deployed" else None
    logger.info(f"RTC configuration: {rtc_configuration}")
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_configuration))
    return HTMLResponse(content=html_content)

@app.get("/transcript")  # path assumed; must match the endpoint the static front-end subscribes to
def _(webrtc_id: str):
    logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")

    async def output_stream():
        try:
            async for output in stream.output_stream(webrtc_id):
                transcript = output.args[0]
                logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
                yield f"event: output\ndata: {transcript}\n\n"
        except Exception as e:
            logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
            raise

    return StreamingResponse(output_stream(), media_type="text/event-stream")
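
# Hypothetical client for the SSE endpoint above (assumes the server runs on
# localhost:7860 and that httpx is installed); each "data:" line carries one
# transcript segment:
#
#     import httpx
#     with httpx.stream("GET", "http://localhost:7860/transcript",
#                       params={"webrtc_id": "<id from the WebRTC handshake>"},
#                       timeout=None) as r:
#         for line in r.iter_lines():
#             if line.startswith("data:"):
#                 print(line[len("data:"):].strip())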

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="localhost", port=7860)
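
# Note: UI_MODE ("gradio" | "fastapi") is read above but not otherwise used in
# this file. A hedged sketch for honoring it, assuming fastrtc's built-in
# Gradio UI (stream.ui) is what "gradio" mode is meant to launch:
#
#     if UI_MODE == "gradio":
#         stream.ui.launch(server_port=7860)
#     else:
#         uvicorn.run(app, host="localhost", port=7860)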