import os
import logging
import json
import torch
import asyncio
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastrtc import (
AdditionalOutputs,
ReplyOnPause,
Stream,
AlgoOptions,
SileroVadOptions,
audio_to_bytes,
get_cloudflare_turn_credentials_async,
)
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline,
)
from transformers.utils import is_flash_attn_2_available
from utils.logger_config import setup_logging
from utils.device import get_device, get_torch_and_np_dtypes

load_dotenv()
setup_logging()
logger = logging.getLogger(__name__)

UI_MODE = os.getenv("UI_MODE", "fastapi").lower()  # gradio | fastapi
UI_TYPE = os.getenv("UI_TYPE", "base").lower()  # base | screen
APP_MODE = os.getenv("APP_MODE", "local").lower()  # local | deployed
MODEL_ID = os.getenv("MODEL_ID", "openai/whisper-large-v3-turbo")
LANGUAGE = os.getenv("LANGUAGE", "english").lower()

device = get_device(force_cpu=False)
torch_dtype, np_dtype = get_torch_and_np_dtypes(device, use_bfloat16=False)
logger.info(f"Using device: {device}, torch_dtype: {torch_dtype}, np_dtype: {np_dtype}")
attention = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
logger.info(f"Using attention: {attention}")
logger.info(f"Loading Whisper model: {MODEL_ID}")
logger.info(f"Using language: {LANGUAGE}")
try:
model = AutoModelForSpeechSeq2Seq.from_pretrained(
MODEL_ID,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation=attention
)
model.to(device)
except Exception as e:
logger.error(f"Error loading ASR model: {e}")
logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
raise
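
# The processor provides the tokenizer and feature extractor used by the ASR pipeline below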
processor = AutoProcessor.from_pretrained(MODEL_ID)
transcribe_pipeline = pipeline(
task="automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)
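
# torch.compile optimizes the model's forward pass; the first call triggers compilation,
# which is why the model is warmed up below before handling real traffic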
transcribe_pipeline.model = torch.compile(transcribe_pipeline.model, mode="max-autotune")

# Warm up the model with a dummy input so the compiled graph is ready before serving
logger.info("Warming up Whisper model with dummy input")
warmup_audio = np.zeros((16000,), dtype=np_dtype) # 1s of silence
transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")

async def transcribe(audio: tuple[int, np.ndarray]):
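    """Transcribe a single speech segment.

    FastRTC passes audio as a (sample_rate, audio_array) tuple; the transcript text is
    yielded inside AdditionalOutputs so it can be streamed back to the client.
    """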
sample_rate, audio_array = audio
logger.info(f"Sample rate: {sample_rate}Hz, Shape: {audio_array.shape}")
outputs = transcribe_pipeline(
audio_to_bytes(audio),
chunk_length_s=3,
batch_size=2,
generate_kwargs={
'task': 'transcribe',
'language': LANGUAGE,
},
#return_timestamps="word"
)
yield AdditionalOutputs(outputs["text"].strip())
logger.info("Initializing FastRTC stream")
stream = Stream(
handler=ReplyOnPause(
transcribe,
algo_options=AlgoOptions(
# Duration in seconds of audio chunks (default 0.6)
audio_chunk_duration=0.6,
# If the chunk has more than started_talking_threshold seconds of speech, the user started talking (default 0.2)
started_talking_threshold=0.1,
# If, after the user started speaking, there is a chunk with less than speech_threshold seconds of speech, the user stopped speaking. (default 0.1)
speech_threshold=0.1,
),
model_options=SileroVadOptions(
# Threshold for what is considered speech (default 0.5)
threshold=0.5,
            # Final speech chunks shorter than min_speech_duration_ms are thrown out (default 250)
min_speech_duration_ms=250,
# Max duration of speech chunks, longer will be split at the timestamp of the last silence
# that lasts more than 100ms (if any) or just before max_speech_duration_s (default float('inf'))
max_speech_duration_s=3,
            # Wait min_silence_duration_ms at the end of each speech chunk before separating it (default 2000)
min_silence_duration_ms=100,
            # Chunk size for the VAD model; can be 512, 1024, or 1536 for a 16 kHz sample rate (default 1024)
window_size_samples=512,
# Final speech chunks are padded by speech_pad_ms each side (default 400)
speech_pad_ms=200,
),
),
    modality="audio",
    # mode: send-receive = bidirectional streaming (default),
    # send = client to server only, receive = server to client only
    mode="send",
additional_outputs=[
gr.Textbox(label="Transcript"),
],
additional_outputs_handler=lambda current, new: current + " " + new,
rtc_configuration=get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None,
concurrency_limit=6
)
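
# HTTP layer: serve the static web UI, mount the FastRTC WebRTC endpoints on the app,
# and expose the page and transcript feed defined below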
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
stream.mount(app)
@app.get("/")
async def index():
if UI_TYPE == "base":
html_content = open("static/index.html").read()
elif UI_TYPE == "screen":
html_content = open("static/index-screen.html").read()
rtc_configuration = await get_cloudflare_turn_credentials_async(hf_token=os.getenv("HF_TOKEN")) if APP_MODE == "deployed" else None
logger.info(f"RTC configuration: {rtc_configuration}")
html_content = html_content.replace("__INJECTED_RTC_CONFIG__", json.dumps(rtc_configuration))
return HTMLResponse(content=html_content)
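
# Server-Sent Events endpoint that streams each transcript update for a given WebRTC connection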
@app.get("/transcript")
def _(webrtc_id: str):
logger.debug(f"New transcript stream request for webrtc_id: {webrtc_id}")
async def output_stream():
try:
async for output in stream.output_stream(webrtc_id):
transcript = output.args[0]
logger.debug(f"Sending transcript for {webrtc_id}: {transcript[:50]}...")
yield f"event: output\ndata: {transcript}\n\n"
except Exception as e:
logger.error(f"Error in transcript stream for {webrtc_id}: {str(e)}")
raise
return StreamingResponse(output_stream(), media_type="text/event-stream")
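
# A client can follow its transcript feed over SSE, e.g. (hypothetical connection id):
#   curl -N "http://localhost:7860/transcript?webrtc_id=abc123"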
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="localhost", port=7860)