jameszokah committed
Commit be2a132 · 1 Parent(s): 7e09504

Integrate WhisperX for improved audio transcription and add real-time conversation support: update requirements to include WhisperX, refactor voice cloning to utilize WhisperX, implement WebSocket endpoints for real-time audio processing, and enhance audio transcription capabilities with alignment options.

app/api/realtime.py ADDED
@@ -0,0 +1,444 @@
+"""Real-time audio conversation with WebSockets.
+
+This module provides WebSocket endpoints for real-time audio conversation
+using the CSM-1B model and WhisperX for transcription.
+"""
+import os
+import io
+import base64
+import json
+import time
+import asyncio
+import logging
+import tempfile
+from enum import Enum
+from typing import Dict, List, Optional, Any, Union
+import numpy as np
+import torch
+import torchaudio
+from pydub import AudioSegment
+import whisperx
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException
+from fastapi.responses import JSONResponse
+
+# Set up logging
+logger = logging.getLogger(__name__)
+router = APIRouter(prefix="/realtime", tags=["Real-time Conversation"])
+
+# Audio processing constants
+SAMPLE_RATE = 16000          # Sample rate for audio processing
+CHUNK_SIZE = 4096            # Chunk size for audio processing
+MAX_AUDIO_DURATION = 10      # Maximum audio duration in seconds
+SILENCE_THRESHOLD = 400      # Threshold for detecting silence (RMS)
+MIN_SILENCE_DURATION = 0.5   # Minimum silence duration to consider a pause
+
+# WebSocket message types
+class MessageType(str, Enum):
+    AUDIO_CHUNK = "audio_chunk"
+    TRANSCRIPT = "transcript"
+    RESPONSE = "response"
+    START_SPEAKING = "start_speaking"
+    STOP_SPEAKING = "stop_speaking"
+    ERROR = "error"
+    STATUS = "status"
+
+# WhisperX model cache for performance
+_whisperx_model = None
+_whisperx_model_lock = asyncio.Lock()
+
+# Connection manager for websockets
+class ConnectionManager:
+    def __init__(self):
+        self.active_connections: Dict[str, WebSocket] = {}
+        self.conversation_contexts: Dict[str, List] = {}
+        self.voice_preferences: Dict[str, int] = {}  # Voice preferences by client_id
+
+    async def connect(self, websocket: WebSocket, client_id: str):
+        """Connect a client to the WebSocket."""
+        await websocket.accept()
+        self.active_connections[client_id] = websocket
+        self.conversation_contexts[client_id] = []
+        self.voice_preferences[client_id] = 1  # Default to the echo voice
+        logger.info(f"Client {client_id} connected, active connections: {len(self.active_connections)}")
+
+    def disconnect(self, client_id: str):
+        """Disconnect a client from the WebSocket."""
+        if client_id in self.active_connections:
+            del self.active_connections[client_id]
+        if client_id in self.conversation_contexts:
+            del self.conversation_contexts[client_id]
+        if client_id in self.voice_preferences:
+            del self.voice_preferences[client_id]
+        logger.info(f"Client {client_id} disconnected, active connections: {len(self.active_connections)}")
+
+    def set_voice_preference(self, client_id: str, speaker_id: int):
+        """Set the voice preference for a client."""
+        self.voice_preferences[client_id] = speaker_id
+
+    def get_voice_preference(self, client_id: str) -> int:
+        """Get the voice preference for a client."""
+        return self.voice_preferences.get(client_id, 1)  # Default to echo (speaker_id=1)
+
+    async def send_message(self, client_id: str, message_type: MessageType, data: Any):
+        """Send a message to a client."""
+        if client_id in self.active_connections:
+            message = {
+                "type": message_type,
+                "data": data,
+                "timestamp": time.time()
+            }
+            await self.active_connections[client_id].send_json(message)
+
+    def add_to_context(self, client_id: str, speaker: int, text: str, audio: Union[torch.Tensor, bytes]):
+        """Add a message to the conversation context."""
+        if client_id in self.conversation_contexts:
+            # Convert audio tensor to base64 if needed
+            if isinstance(audio, torch.Tensor):
+                audio_bytes = convert_tensor_to_wav_bytes(audio)
+                audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+            elif isinstance(audio, bytes):
+                audio_base64 = base64.b64encode(audio).decode('utf-8')
+            else:
+                raise ValueError(f"Unsupported audio type: {type(audio)}")
+
+            self.conversation_contexts[client_id].append({
+                "speaker": speaker,
+                "text": text,
+                "audio": audio_base64
+            })
+
+            # Keep only the last 5 exchanges to prevent the context from growing too large
+            if len(self.conversation_contexts[client_id]) > 5:
+                self.conversation_contexts[client_id] = self.conversation_contexts[client_id][-5:]
+
+    def get_context(self, client_id: str) -> List[Dict]:
+        """Get the conversation context for a client."""
+        return self.conversation_contexts.get(client_id, [])
+
+# Initialize connection manager
+manager = ConnectionManager()
+
+async def load_whisperx_model(compute_type="float16"):
+    """Load the WhisperX model if it is not already loaded."""
+    global _whisperx_model
+
+    # Use a lock to ensure model loading is thread-safe
+    async with _whisperx_model_lock:
+        if _whisperx_model is None:
+            logger.info("Loading WhisperX model for real-time transcription")
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            # Use the small model for lower latency
+            _whisperx_model = whisperx.load_model(
+                "small",  # Small model for faster processing in real time
+                device,
+                compute_type=compute_type,
+                asr_options={"beam_size": 5, "vad_onset": 0.5, "vad_offset": 0.5}
+            )
+            logger.info(f"WhisperX model loaded on {device} with compute_type={compute_type}")
+
+    return _whisperx_model
+
+def convert_tensor_to_wav_bytes(audio_tensor: torch.Tensor) -> bytes:
+    """Convert an audio tensor to WAV bytes."""
+    buf = io.BytesIO()
+    if len(audio_tensor.shape) == 1:
+        audio_tensor = audio_tensor.unsqueeze(0)
+    torchaudio.save(buf, audio_tensor.cpu(), SAMPLE_RATE, format="wav")
+    buf.seek(0)
+    return buf.read()
+
+def convert_audio_data(audio_data: bytes) -> torch.Tensor:
+    """Convert raw audio bytes to a mono tensor at SAMPLE_RATE."""
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp:
+        temp.write(audio_data)
+        temp.flush()
+
+        try:
+            # First try with torchaudio
+            waveform, sample_rate = torchaudio.load(temp.name)
+
+            # Convert to mono if needed
+            if waveform.shape[0] > 1:
+                waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+            # Resample if needed
+            if sample_rate != SAMPLE_RATE:
+                waveform = torchaudio.functional.resample(
+                    waveform, orig_freq=sample_rate, new_freq=SAMPLE_RATE
+                )
+
+            return waveform.squeeze(0)
+        except Exception:
+            # Fall back to pydub if torchaudio fails
+            audio = AudioSegment.from_file(temp.name)
+
+            # Convert to mono if needed
+            if audio.channels > 1:
+                audio = audio.set_channels(1)
+
+            # Resample if needed
+            if audio.frame_rate != SAMPLE_RATE:
+                audio = audio.set_frame_rate(SAMPLE_RATE)
+
+            # Convert to a float tensor in [-1, 1]
+            samples = np.array(audio.get_array_of_samples(), dtype=np.float32) / 32768.0
+            waveform = torch.tensor(samples, dtype=torch.float32)
+            return waveform
+
+async def transcribe_audio(audio_data: bytes, language: Optional[str] = None) -> Dict:
+    """Transcribe audio using WhisperX."""
+    model = await load_whisperx_model()
+
+    # Save audio to a temporary file
+    with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp:
+        temp.write(audio_data)
+        temp.flush()
+
+        # Transcribe with WhisperX
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        result = model.transcribe(
+            temp.name,
+            language=language,
+            batch_size=16 if device == "cuda" else 1
+        )
+
+        return result
+
+async def generate_response(app, text: str, speaker_id: int, context: List[Dict]) -> torch.Tensor:
+    """Generate a response using the CSM-1B model."""
+    generator = app.state.generator
+
+    # Validate model availability
+    if generator is None:
+        raise RuntimeError("TTS model not loaded")
+
+    # Set up context segments
+    segments = []
+    for ctx in context:
+        if 'speaker' not in ctx or 'text' not in ctx or 'audio' not in ctx:
+            continue
+
+        # Decode base64 audio and convert it to a tensor
+        audio_data = base64.b64decode(ctx['audio'])
+        audio_tensor = convert_audio_data(audio_data)
+
+        segments.append({
+            "speaker": ctx['speaker'],
+            "text": ctx['text'],
+            "audio": audio_tensor
+        })
+
+    # Format text for better voice consistency
+    from app.prompt_engineering import format_text_for_voice
+
+    # Determine voice name from speaker_id
+    voice_names = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
+    voice_name = voice_names[speaker_id] if 0 <= speaker_id < len(voice_names) else "alloy"
+
+    formatted_text = format_text_for_voice(text, voice_name)
+
+    # Generate audio with context
+    audio = generator.generate(
+        text=formatted_text,
+        speaker=speaker_id,
+        context=segments,
+        max_audio_length_ms=10000,  # 10 seconds max for low latency
+        temperature=0.65,           # Lower temperature for more stable output
+        topk=40,
+    )
+
+    # Post-process audio for better quality
+    from app.voice_enhancement import process_generated_audio
+
+    processed_audio = process_generated_audio(
+        audio,
+        voice_name,
+        generator.sample_rate,
+        text
+    )
+
+    return processed_audio
+
+def is_silence(audio_data: bytes, threshold=SILENCE_THRESHOLD) -> bool:
+    """Check whether an audio chunk is silence."""
+    with io.BytesIO(audio_data) as buf:
+        try:
+            audio = AudioSegment.from_file(buf)
+            # Compare RMS (root mean square) amplitude against the threshold
+            rms = audio.rms
+            return rms < threshold
+        except Exception:
+            # If the chunk cannot be decoded, assume it is not silent
+            return False
+
+@router.websocket("/conversation/{client_id}")
+async def websocket_conversation(websocket: WebSocket, client_id: str):
+    """WebSocket endpoint for real-time audio conversation."""
+    # A Request object is not available in WebSocket routes, so application
+    # state is reached through the WebSocket connection itself.
+    app = websocket.app
+
+    await manager.connect(websocket, client_id)
+
+    # Validate model availability
+    if not hasattr(app.state, "generator") or app.state.generator is None:
+        await manager.send_message(client_id, MessageType.ERROR,
+                                   {"message": "TTS model not available"})
+        manager.disconnect(client_id)
+        return
+
+    # Initialize audio buffer and state
+    audio_buffer = io.BytesIO()
+    is_speaking = False
+    silence_start = None
+
+    try:
+        # Tell the client we're ready
+        await manager.send_message(client_id, MessageType.STATUS,
+                                   {"status": "ready", "message": "Connection established"})
+
+        # Process messages
+        async for message in websocket.iter_json():
+            message_type = message.get("type")
+
+            if message_type == "audio_chunk":
+                # Get audio data
+                audio_data = base64.b64decode(message["data"])
+
+                # Check if silence or speech
+                current_is_silence = is_silence(audio_data)
+
+                # Handle silence detection for end of speech
+                if current_is_silence:
+                    if not silence_start:
+                        silence_start = time.time()
+                    elif time.time() - silence_start > MIN_SILENCE_DURATION and is_speaking:
+                        # End of speech detected
+                        is_speaking = False
+
+                        # Get audio from the buffer, then reset it
+                        audio_buffer.seek(0)
+                        full_audio = audio_buffer.read()
+                        audio_buffer = io.BytesIO()
+
+                        # Process the complete audio asynchronously
+                        asyncio.create_task(process_complete_audio(
+                            app, client_id, full_audio
+                        ))
+
+                        # Notify the client of end of speech
+                        await manager.send_message(client_id, MessageType.STOP_SPEAKING, {})
+                else:
+                    # Reset silence detection on new speech
+                    silence_start = None
+
+                    # Start of speech if not already speaking
+                    if not is_speaking:
+                        is_speaking = True
+                        await manager.send_message(client_id, MessageType.START_SPEAKING, {})
+
+                # Add the chunk to the buffer if speaking
+                if is_speaking:
+                    audio_buffer.write(audio_data)
+
+            elif message_type == "end_audio":
+                # Explicit end of audio from the client
+                if audio_buffer.tell() > 0:
+                    # Get audio from the buffer, then reset it
+                    audio_buffer.seek(0)
+                    full_audio = audio_buffer.read()
+                    audio_buffer = io.BytesIO()
+                    is_speaking = False
+
+                    # Process the complete audio asynchronously
+                    asyncio.create_task(process_complete_audio(
+                        app, client_id, full_audio
+                    ))
+
+            elif message_type == "set_voice":
+                # Set the voice for the response
+                voice = message.get("voice", "alloy")
+
+                # Map voice name to speaker ID
+                voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
+                speaker_id = voice_to_speaker.get(voice, 0)
+
+                # Store in client state
+                manager.set_voice_preference(client_id, speaker_id)
+
+                # Send confirmation to the client
+                await manager.send_message(client_id, MessageType.STATUS,
+                                           {"status": "voice_set", "voice": voice, "speaker_id": speaker_id})
+
+            elif message_type == "clear_context":
+                # Clear the conversation context
+                if client_id in manager.conversation_contexts:
+                    manager.conversation_contexts[client_id] = []
+                await manager.send_message(client_id, MessageType.STATUS,
+                                           {"status": "context_cleared"})
+
+    except WebSocketDisconnect:
+        logger.info(f"Client {client_id} disconnected")
+    except Exception as e:
+        logger.error(f"Error in websocket conversation: {e}", exc_info=True)
+        try:
+            await manager.send_message(client_id, MessageType.ERROR,
+                                       {"message": str(e)})
+        except Exception:
+            pass
+    finally:
+        manager.disconnect(client_id)
+
+async def process_complete_audio(app, client_id: str, audio_data: bytes):
+    """Process a complete audio utterance received over the WebSocket."""
+    try:
+        # Transcribe audio
+        transcription = await transcribe_audio(audio_data)
+        text = transcription.get("text", "").strip()
+
+        # Send the transcription to the client
+        await manager.send_message(client_id, MessageType.TRANSCRIPT,
+                                   {"text": text, "segments": transcription.get("segments", [])})
+
+        # Skip empty transcriptions
+        if not text:
+            return
+
+        # Add the user message to the context (the user is always speaker 0)
+        manager.add_to_context(client_id, 0, text, audio_data)
+
+        # Get the current context
+        context = manager.get_context(client_id)
+
+        # Generate the response
+        voice_id = manager.get_voice_preference(client_id)
+        response_audio = await generate_response(app, text, voice_id, context)
+
+        # Convert to bytes
+        response_bytes = convert_tensor_to_wav_bytes(response_audio)
+        response_base64 = base64.b64encode(response_bytes).decode('utf-8')
+
+        # Send the response to the client
+        await manager.send_message(client_id, MessageType.RESPONSE, {
+            "audio": response_base64,
+            "speaker_id": voice_id
+        })
+
+        # Add the assistant response to the context
+        manager.add_to_context(client_id, voice_id, text, response_audio)
+
+    except Exception as e:
+        logger.error(f"Error processing audio: {e}", exc_info=True)
+        await manager.send_message(client_id, MessageType.ERROR,
+                                   {"message": f"Error processing audio: {str(e)}"})
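
For reference, here is a minimal client sketch for the WebSocket endpoint added above (mounted under /api/v1 and /v1 in app/main.py, so the full path is /api/v1/realtime/conversation/{client_id}). The host, port, and file names are assumptions, and the third-party `websockets` package is used purely for illustration.

```python
import asyncio
import base64
import json

import websockets  # pip install websockets (illustrative client library)


async def talk(wav_path: str, client_id: str = "demo-client") -> None:
    # Assumes the API is reachable at localhost:8000
    uri = f"ws://localhost:8000/api/v1/realtime/conversation/{client_id}"
    async with websockets.connect(uri) as ws:
        # Pick a response voice, then send one utterance as a single chunk
        await ws.send(json.dumps({"type": "set_voice", "voice": "echo"}))
        with open(wav_path, "rb") as f:
            chunk = base64.b64encode(f.read()).decode("utf-8")
        await ws.send(json.dumps({"type": "audio_chunk", "data": chunk}))
        await ws.send(json.dumps({"type": "end_audio"}))

        # Read server events until the synthesized reply arrives
        while True:
            event = json.loads(await ws.recv())
            print(event["type"])
            if event["type"] == "response":
                with open("reply.wav", "wb") as out:
                    out.write(base64.b64decode(event["data"]["audio"]))
                break


if __name__ == "__main__":
    asyncio.run(talk("question.wav"))
```
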
app/api/routes.py CHANGED
@@ -13,8 +13,9 @@ from typing import Dict, List, Optional, Any, Union
 import torch
 import torchaudio
 import numpy as np
-from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Body, Response, Query
-from fastapi.responses import StreamingResponse
+import whisperx
+from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Body, Response, Query, UploadFile, File
+from fastapi.responses import StreamingResponse, JSONResponse
 from app.api.schemas import SpeechRequest, ResponseFormat, Voice
 from app.model import Segment
 from app.api.streaming import AudioChunker
@@ -33,6 +34,10 @@ MIME_TYPES = {
     "wav": "audio/wav",
 }
 
+# WhisperX model cache for reuse
+whisperx_model = None
+whisperx_model_lock = asyncio.Lock()
+
 def get_speaker_id(app_state, voice):
     """Helper function to get speaker ID from voice name or ID"""
     if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
@@ -1045,4 +1050,127 @@ async def debug_speech(
             "status": "error",
             "message": str(e),
             "traceback": error_trace
-        }
+        }
+
+@router.post("/audio/transcribe", tags=["Audio"], summary="Transcribe audio to text")
+async def transcribe_audio(
+    request: Request,
+    audio: UploadFile = File(...),
+    language: Optional[str] = Query(None, description="Language code (e.g., 'en', 'fr', 'de')"),
+    align_text: bool = Query(False, description="Whether to align text with timestamps"),
+    compute_type: str = Query("float16", description="Compute type for model inference (float16, int8, float32)"),
+):
+    """
+    Transcribe spoken audio to text using WhisperX (faster and more accurate).
+
+    Upload audio as a file in any common format (mp3, wav, etc.).
+
+    **Parameters:**
+    - `audio`: Audio file to transcribe
+    - `language`: Optional language code (auto-detected if not provided)
+    - `align_text`: Whether to include word-level timestamps
+    - `compute_type`: Compute type for model inference (float16, int8, float32)
+
+    **Response:**
+    ```json
+    {
+        "text": "Transcribed text",
+        "segments": [
+            {
+                "start": 0.0,
+                "end": 2.5,
+                "text": "Segment text"
+            }
+        ],
+        "word_timestamps": []  // If align_text is true
+    }
+    ```
+    """
+    global whisperx_model
+
+    # Create temp directory to store uploaded file
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = os.path.join(temp_dir, f"audio_upload{os.path.splitext(audio.filename)[1]}")
+
+        # Save uploaded file to temp directory
+        try:
+            content = await audio.read()
+            with open(temp_path, "wb") as f:
+                f.write(content)
+            logger.info(f"Saved uploaded audio to {temp_path}")
+
+            # Use lock to ensure model loading is thread-safe
+            async with whisperx_model_lock:
+                # Load WhisperX model if not already loaded
+                if whisperx_model is None:
+                    logger.info("Loading WhisperX model (one-time initialization)")
+                    device = "cuda" if torch.cuda.is_available() else "cpu"
+                    # Use medium model for better accuracy, but small can be used for faster processing
+                    whisperx_model = whisperx.load_model("medium", device, compute_type=compute_type, asr_options={"beam_size": 5})
+                    logger.info(f"WhisperX model loaded on {device} with compute_type={compute_type}")
+
+            # Start processing timer
+            start_time = time.time()
+
+            # Specify device for batch processing
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            # Transcribe with WhisperX (much faster than standard whisper)
+            logger.info(f"Transcribing audio with WhisperX on {device}")
+
+            # Process audio file - much faster than standard whisper
+            # and can process batches concurrently on GPU
+            result = whisperx_model.transcribe(
+                temp_path,
+                language=language,
+                batch_size=16 if device == "cuda" else 1  # Larger batch size for GPU
+            )
+
+            # Align word timestamps if requested
+            if align_text and result["segments"]:
+                try:
+                    # Load alignment model
+                    logger.info("Aligning text with timestamps")
+                    alignment_model, metadata = whisperx.load_align_model(
+                        language_code=result["language"] if language is None else language,
+                        device=device
+                    )
+                    # Align
+                    result = whisperx.align(
+                        result["segments"],
+                        alignment_model,
+                        metadata,
+                        temp_path,
+                        device,
+                        return_char_alignments=False
+                    )
+                except Exception as e:
+                    logger.warning(f"Word alignment failed: {e}")
+                    # Continue without alignment if it fails
+
+            # Calculate processing time
+            processing_time = time.time() - start_time
+
+            # Log results
+            logger.info(f"Successfully transcribed audio in {processing_time:.2f}s: {result['text'][:50]}...")
+
+            # Return results
+            response = {
+                "text": result["text"],
+                "segments": result["segments"],
+                "language": result.get("language", language),
+                "processing_time": processing_time
+            }
+
+            # Add word timestamps if available
+            if align_text and "word_segments" in result:
+                response["word_timestamps"] = result["word_segments"]
+
+            return response
+
+        except Exception as e:
+            logger.error(f"Transcription failed: {e}", exc_info=True)
+            raise HTTPException(
+                status_code=500,
+                detail=f"Failed to transcribe audio: {str(e)}"
+            )
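
To exercise the new transcription route, a request like the following should work; the /api/v1 mount point, host, and port are assumptions based on how the other routers are included, and the file name is illustrative.

```python
import requests  # pip install requests

url = "http://localhost:8000/api/v1/audio/transcribe"
with open("clip.mp3", "rb") as f:
    resp = requests.post(
        url,
        files={"audio": ("clip.mp3", f, "audio/mpeg")},
        params={"language": "en", "align_text": "true"},
    )
resp.raise_for_status()
data = resp.json()
print(data["text"])
for seg in data["segments"]:
    print(f'{seg["start"]:.2f}-{seg["end"]:.2f}: {seg["text"]}')
```
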
app/main.py CHANGED
@@ -551,6 +551,11 @@ from app.api.audiobook_routes import router as audiobook_router
 app.include_router(audiobook_router, prefix="/api/v1")
 app.include_router(audiobook_router, prefix="/v1")
 
+# Add realtime conversation routes
+from app.api.realtime import router as realtime_router
+app.include_router(realtime_router, prefix="/api/v1")
+app.include_router(realtime_router, prefix="/v1")
+
 # Middleware for request timing
 @app.middleware("http")
 async def add_process_time_header(request: Request, call_next):
app/voice_cloning.py CHANGED
@@ -11,7 +11,9 @@ import tempfile
 import logging
 import asyncio
 import yt_dlp
-import whisper
+# Replace standard whisper with WhisperX
+# import whisper
+import whisperx
 from typing import Dict, List, Optional, Union, Tuple, BinaryIO
 from pathlib import Path
 
@@ -30,6 +32,10 @@ logger = logging.getLogger(__name__)
 CLONED_VOICES_DIR = "/app/cloned_voices"
 os.makedirs(CLONED_VOICES_DIR, exist_ok=True)
 
+# WhisperX model cache for performance
+_whisperx_model = None
+_whisperx_model_lock = asyncio.Lock()
+
 class ClonedVoice(BaseModel):
     """Model representing a cloned voice."""
     id: str
@@ -518,7 +524,7 @@ class VoiceCloner:
             # Step 1: Download audio from YouTube
             audio_path = await self._download_youtube_audio(youtube_url, temp_dir, start_time, duration)
 
-            # Step 2: Generate transcript using Whisper
+            # Step 2: Generate transcript using WhisperX
             transcript = await self._generate_transcript(audio_path)
 
             # Step 3: Clone the voice using the extracted audio and transcript
@@ -589,7 +595,7 @@
 
     async def _generate_transcript(self, audio_path: str) -> str:
         """
-        Generate transcript from audio using Whisper.
+        Generate transcript from audio using WhisperX (faster than standard Whisper).
 
         Args:
             audio_path: Path to the audio file
@@ -597,13 +603,53 @@
        Returns:
            Transcript text
        """
-        # Load Whisper model (use small model for faster processing)
-        model = whisper.load_model("small")
-
-        # Transcribe the audio
-        result = model.transcribe(audio_path)
+        global _whisperx_model
 
-        return result["text"]
+        try:
+            # Use CUDA if it is available
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+            # Use lock to ensure model loading is thread-safe
+            async with _whisperx_model_lock:
+                # Load WhisperX model if not already loaded
+                if _whisperx_model is None:
+                    logger.info("Loading WhisperX model for transcription (one-time initialization)")
+                    compute_type = "float16" if device == "cuda" else "float32"
+                    _whisperx_model = whisperx.load_model(
+                        "medium",  # Can use "small" for faster processing, "medium" for better quality
+                        device,
+                        compute_type=compute_type,
+                        asr_options={"beam_size": 5}
+                    )
+                    logger.info(f"WhisperX model loaded on {device}")
+
+            # Start processing timer
+            start_time = time.time()
+
+            # Process with WhisperX - much faster than standard whisper,
+            # especially for longer files
+            logger.info(f"Transcribing audio with WhisperX on {device}")
+            result = _whisperx_model.transcribe(
+                audio_path,
+                batch_size=16 if device == "cuda" else 1  # Larger batch size on GPU for faster processing
+            )
+
+            # Calculate and log processing time
+            processing_time = time.time() - start_time
+            logger.info(f"Transcription completed in {processing_time:.2f}s")
+
+            return result["text"]
+        except Exception as e:
+            logger.error(f"WhisperX transcription failed: {e}", exc_info=True)
+            # Fall back to a smaller, CPU-friendly WhisperX configuration if the first attempt fails
+            logger.warning("Falling back to basic transcription method")
+            try:
+                model = whisperx.load_model("small", device, compute_type="float32")
+                result = model.transcribe(audio_path, batch_size=1)
+                return result["text"]
+            except Exception as fallback_error:
+                logger.error(f"Fallback transcription also failed: {fallback_error}")
+                return "Transcription failed. Please provide a transcript manually."
 
     def generate_speech(
         self,
app/watermarking.py CHANGED
@@ -4,19 +4,83 @@ The original CSM code has a watermarking module that adds
 an imperceptible watermark to generated audio.
 """
 
-# Watermark key used by CSM
-CSM_1B_GH_WATERMARK = "CSM1B@GitHub"
-
-def load_watermarker(device="cpu"):
-    """Stub for watermarker loading.
-
-    In a real implementation, this would load the actual watermarker.
-    """
-    return None
-
-def watermark(watermarker, audio, sample_rate, key):
-    """Stub for watermarking function.
-
-    In a real implementation, this would add the watermark.
-    """
-    return audio, sample_rate
+import argparse
+
+import silentcipher
+import torch
+import torchaudio
+
+# This watermark key is public, it is not secure.
+# If using CSM 1B in another application, use a new private key and keep it secret.
+CSM_1B_GH_WATERMARK = [221, 199, 199, 199, 221]
+# [212, 211, 146, 56, 201]
+
+
+def cli_check_audio() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--audio_path", type=str, required=True)
+    args = parser.parse_args()
+
+    check_audio_from_file(args.audio_path)
+
+
+def load_watermarker(device: str = "cuda") -> silentcipher.server.Model:
+    model = silentcipher.get_model(
+        model_type="44.1k",
+        device=device,
+    )
+    return model
+
+
+@torch.inference_mode()
+def watermark(
+    watermarker: silentcipher.server.Model,
+    audio_array: torch.Tensor,
+    sample_rate: int,
+    watermark_key: list[int],
+) -> tuple[torch.Tensor, int]:
+    audio_array_44khz = torchaudio.functional.resample(audio_array, orig_freq=sample_rate, new_freq=44100)
+    encoded, _ = watermarker.encode_wav(audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36)
+
+    output_sample_rate = min(44100, sample_rate)
+    encoded = torchaudio.functional.resample(encoded, orig_freq=44100, new_freq=output_sample_rate)
+    return encoded, output_sample_rate
+
+
+@torch.inference_mode()
+def verify(
+    watermarker: silentcipher.server.Model,
+    watermarked_audio: torch.Tensor,
+    sample_rate: int,
+    watermark_key: list[int],
+) -> bool:
+    watermarked_audio_44khz = torchaudio.functional.resample(watermarked_audio, orig_freq=sample_rate, new_freq=44100)
+    result = watermarker.decode_wav(watermarked_audio_44khz, 44100, phase_shift_decoding=True)
+
+    is_watermarked = result["status"]
+    if is_watermarked:
+        is_csm_watermarked = result["messages"][0] == watermark_key
+    else:
+        is_csm_watermarked = False
+
+    return is_watermarked and is_csm_watermarked
+
+
+def check_audio_from_file(audio_path: str) -> None:
+    watermarker = load_watermarker(device="cuda")
+
+    audio_array, sample_rate = load_audio(audio_path)
+    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
+
+    outcome = "Watermarked" if is_watermarked else "Not watermarked"
+    print(f"{outcome}: {audio_path}")
+
+
+def load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
+    audio_array, sample_rate = torchaudio.load(audio_path)
+    audio_array = audio_array.mean(dim=0)
+    return audio_array, int(sample_rate)
+
+
+if __name__ == "__main__":
+    cli_check_audio()
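
The rewritten watermarking module can also be used programmatically. A small sketch, assuming silentcipher is installed and a CUDA device is available (the default in load_watermarker); the file name is illustrative.

```python
# Check whether a generated file carries the CSM watermark.
# Equivalent CLI entry point: python -m app.watermarking --audio_path output.wav
from app.watermarking import CSM_1B_GH_WATERMARK, load_audio, load_watermarker, verify

watermarker = load_watermarker(device="cuda")
audio, sample_rate = load_audio("output.wav")
print(verify(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK))
```
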
requirements.txt CHANGED
@@ -29,4 +29,5 @@ python-dotenv>=1.0.1
 sqlalchemy>=2.0.0
 alembic>=1.13.0
 psycopg2-binary>=2.9.9
-certifi>=2024.2.2
+certifi>=2024.2.2
+whisperx