karumati commited on
Commit
01115c6
·
1 Parent(s): 52513b0
.gitignore ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Environment
24
+ .env
25
+ .venv
26
+ env/
27
+ venv/
28
+ ENV/
29
+
30
+ # Model & data files
31
+ models/
32
+ tokenizers/
33
+ voice_memories/
34
+ voice_samples/
35
+ cloned_voices/
36
+ *.pt
37
+ *.pth
38
+ *.bin
39
+ *.ckpt
40
+
41
+ # Audio files
42
+ *.wav
43
+ *.mp3
44
+ *.flac
45
+ *.opus
46
+ *.aac
47
+
48
+ # Logs
49
+ logs/
50
+ *.log
51
+ log.txt
52
+
53
+ # IDE
54
+ .idea/
55
+ .vscode/
56
+ *.swp
57
+ *.swo
58
+ .DS_Store
59
+
60
+ # Temp files
61
+ tmp/
62
+ temp/
63
+ .temp/
64
+ .~*
65
+
66
+ # Project specific
67
+ .cache/
68
+ .neptune/
69
+ MANIFEST.in
70
+ .history/
71
+ .mypy_cache/
72
+ .pytest_cache/
73
+ __pycache__/
74
+ .ipynb_checkpoints/
75
+
76
+ # Ignore Hugging Face credentials
77
+ .huggingface/
Dockerfile ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Multi-stage build for authenticated model downloads
2
+ FROM python:3.10-slim AS model-downloader
3
+ # Install huggingface-cli
4
+ RUN pip install huggingface_hub
5
+ # Set working directory
6
+ WORKDIR /model-downloader
7
+ # Create directory for downloaded models
8
+ RUN mkdir -p /model-downloader/models
9
+ # This will run when building the image
10
+ # You'll need to pass your Hugging Face token at build time
11
+ ARG HF_TOKEN
12
+ ENV HF_TOKEN=${HF_TOKEN}
13
+ # Login and download model
14
+ RUN if [ -n "$HF_TOKEN" ]; then \
15
+ huggingface-cli login --token ${HF_TOKEN}; \
16
+ huggingface-cli download sesame/csm-1b ckpt.pt --local-dir /model-downloader/models; \
17
+ else echo "No HF_TOKEN provided, model download will be skipped"; fi
18
+
19
+ # Now for the main application stage
20
+ FROM nvidia/cuda:12.4.0-base-ubuntu22.04
21
+ # Set environment variables
22
+ ENV PYTHONFAULTHANDLER=1 \
23
+ PYTHONUNBUFFERED=1 \
24
+ PYTHONHASHSEED=random \
25
+ PIP_NO_CACHE_DIR=1 \
26
+ PIP_DISABLE_PIP_VERSION_CHECK=1 \
27
+ PIP_DEFAULT_TIMEOUT=100 \
28
+ NVIDIA_VISIBLE_DEVICES=all \
29
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility \
30
+ TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6" \
31
+ TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
32
+
33
+ # Install system dependencies
34
+ RUN apt-get update && apt-get install -y --no-install-recommends \
35
+ python3 \
36
+ python3-pip \
37
+ python3-dev \
38
+ ffmpeg \
39
+ git \
40
+ build-essential \
41
+ && apt-get clean \
42
+ && rm -rf /var/lib/apt/lists/*
43
+
44
+ # Set working directory
45
+ WORKDIR /app
46
+
47
+ # Copy requirements first for better caching
48
+ COPY requirements.txt .
49
+
50
+ # Create and set up persistent directories with proper permissions
51
+ RUN mkdir -p /app/static /app/models /app/voice_memories /app/voice_references \
52
+ /app/voice_profiles /app/cloned_voices /app/audio_cache /app/tokenizers /app/logs && \
53
+ chmod -R 777 /app/voice_references /app/voice_profiles /app/voice_memories \
54
+ /app/cloned_voices /app/audio_cache /app/static /app/logs /app/tokenizers /app/models
55
+
56
+ # Copy static files
57
+ COPY ./static /app/static
58
+
59
+ # Install Python dependencies
60
+ RUN pip3 install --no-cache-dir --upgrade pip && \
61
+ pip3 install torch torchaudio numpy
62
+
63
+ # Install torchao from source
64
+ RUN pip3 install git+https://github.com/pytorch/ao.git
65
+
66
+ # Install torchtune from source with specific branch for latest features
67
+ RUN git clone https://github.com/pytorch/torchtune.git /tmp/torchtune && \
68
+ cd /tmp/torchtune && \
69
+ # Try to use the main branch, which should have llama3_2
70
+ git checkout main && \
71
+ pip install -e .
72
+
73
+ # Install remaining dependencies
74
+ RUN pip3 install -r requirements.txt
75
+
76
+ # Install additional dependencies for streaming and voice cloning
77
+ RUN pip3 install yt-dlp openai-whisper
78
+
79
+ # Copy application code
80
+ COPY ./app /app/app
81
+
82
+ # Copy downloaded model from the model-downloader stage
83
+ COPY --from=model-downloader /model-downloader/models /app/models
84
+
85
+ # Show available models in torchtune
86
+ RUN python3 -c "import torchtune.models; print('Available models in torchtune:', dir(torchtune.models))"
87
+
88
+ # Expose port
89
+ EXPOSE 8000
90
+
91
+ # Command to run the application
92
+ CMD ["python3", "-m", "app.main"]
README.md CHANGED
@@ -1,10 +1,475 @@
1
  ---
2
- title: Sesame Openai
3
- emoji:
4
- colorFrom: red
5
- colorTo: gray
6
  sdk: docker
 
 
7
  pinned: false
 
 
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: CSM-1B TTS Interface # Choose a descriptive title
3
+ emoji: 🔊 # Choose an emoji (e.g., 🎤, 🔊, ✨)
4
+ colorFrom: blue # Optional: Start color for card gradient
5
+ colorTo: indigo # Optional: End color for card gradient
6
  sdk: docker
7
+ sdk_version: "28.0.1" # <-- IMPORTANT: Check your requirements.txt for the exact gradio version you installed! Update this value.
8
+ app_file: app/main.py # <-- IMPORTANT: Use the EXACT filename of your main Gradio script (the one with `demo.launch()`)
9
  pinned: false
10
+ # Optional: Add other configurations like python_version if needed
11
+ # python_version: "3.10"
12
+ # Optional: Specify hardware if needed (e.g., for GPU)
13
+ # hardware: cpu-upgrade # or gpu-small, gpu-a10g-small etc. Check HF pricing/docs
14
+ # Optional: Specify secrets needed (like HF_TOKEN if model download needs it)
15
+ # secrets:
16
+ # - HF_TOKEN
17
  ---
18
 
19
+
20
+
21
+
22
+
23
+ # CSM-1B TTS API
24
+
25
+ An OpenAI-compatible Text-to-Speech API that harnesses the power of Sesame's Conversational Speech Model (CSM-1B). This API allows you to generate high-quality speech from text using a variety of consistent voices, compatible with systems like OpenWebUI, ChatBot UI, and any platform that supports the OpenAI TTS API format.
26
+
27
+ ## Features
28
+
29
+ - **OpenAI API Compatibility**: Drop-in replacement for OpenAI's TTS API
30
+ - **Multiple Voices**: Six distinct voices (alloy, echo, fable, onyx, nova, shimmer)
31
+ - **Voice Consistency**: Maintains consistent voice characteristics across multiple requests
32
+ - **Voice Cloning**: Clone your own voice from audio samples
33
+ - **Conversational Context**: Supports conversational context for improved naturalness
34
+ - **Multiple Audio Formats**: Supports MP3, OPUS, AAC, FLAC, and WAV
35
+ - **Speed Control**: Adjustable speech speed
36
+ - **CUDA Acceleration**: GPU support for faster generation
37
+ - **Web UI**: Simple interface for voice cloning and speech generation
38
+
39
+ ## Getting Started
40
+
41
+ ### Prerequisites
42
+
43
+ - Docker and Docker Compose
44
+ - NVIDIA GPU with CUDA support (recommended)
45
+ - Hugging Face account with access to `sesame/csm-1b` model
46
+
47
+ ### Installation
48
+
49
+ 1. Clone this repository:
50
+ ```bash
51
+ git clone https://github.com/phildougherty/sesame_csm_openai
52
+ cd sesame_csm_openai
53
+ ```
54
+
55
+ 2. Create a `.env` file in the /app folder with your Hugging Face token:
56
+ ```
57
+ HF_TOKEN=your_hugging_face_token_here
58
+ ```
59
+
60
+ 3. Build and start the container:
61
+ ```bash
62
+ docker compose up -d --build
63
+ ```
64
+
65
+ The server will start on port 8000. First startup may take some time as it downloads the model files.
66
+
67
+ ## Hugging Face Configuration (ONLY NEEDED TO ACCEPT TERMS/DOWNLOAD MODEL)
68
+
69
+ This API requires access to the `sesame/csm-1b` model on Hugging Face:
70
+
71
+ 1. Create a Hugging Face account if you don't have one: [https://huggingface.co/join](https://huggingface.co/join)
72
+ 2. Accept the model license at [https://huggingface.co/sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
73
+ 3. Generate an access token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
74
+ 4. Use this token in your `.env` file or pass it directly when building the container:
75
+
76
+ ```bash
77
+ HF_TOKEN=your_token docker compose up -d --build
78
+ ```
79
+
80
+ ### Required Models
81
+
82
+ The API uses the following models which are downloaded automatically:
83
+
84
+ - **CSM-1B**: The main speech generation model from Sesame
85
+ - **Mimi**: Audio codec for high-quality audio generation
86
+ - **Llama Tokenizer**: Uses the unsloth/Llama-3.2-1B tokenizer for text processing
87
+
88
+ ## Multi-GPU Support
89
+
90
+ The CSM-1B model can be distributed across multiple GPUs to handle larger models or improve performance. To enable multi-GPU support, set the `CSM_DEVICE_MAP` environment variable:
91
+
92
+ ```bash
93
+ # Automatic device mapping (recommended)
94
+ CSM_DEVICE_MAP=auto docker compose up -d
95
+
96
+ # Balanced distribution of layers across GPUs
97
+ CSM_DEVICE_MAP=balanced docker compose up -d
98
+
99
+ # Sequential distribution (backbone on first GPUs, decoder on remaining)
100
+ CSM_DEVICE_MAP=sequential docker compose up -d
101
+
102
+ ## Voice Cloning Guide
103
+
104
+ The CSM-1B TTS API comes with powerful voice cloning capabilities that allow you to create custom voices from audio samples. Here's how to use this feature:
105
+
106
+ ### Method 1: Using the Web Interface
107
+
108
+ 1. Access the voice cloning UI by navigating to `http://your-server-ip:8000/voice-cloning` in your browser.
109
+
110
+ 2. **Clone a Voice**:
111
+ - Go to the "Clone Voice" tab
112
+ - Enter a name for your voice
113
+ - Upload an audio sample (2-3 minutes of clear speech works best)
114
+ - Optionally provide a transcript of the audio for better results
115
+ - Click "Clone Voice"
116
+
117
+ 3. **View Your Voices**:
118
+ - Navigate to the "My Voices" tab to see all your cloned voices
119
+ - You can preview or delete voices from this tab
120
+
121
+ 4. **Generate Speech**:
122
+ - Go to the "Generate Speech" tab
123
+ - Select one of your cloned voices
124
+ - Enter the text you want to synthesize
125
+ - Adjust the temperature slider if needed (lower for more consistent results)
126
+ - Click "Generate Speech" and listen to the result
127
+
128
+ ### Method 2: Using the API
129
+
130
+ 1. **Clone a Voice**:
131
+ ```bash
132
+ curl -X POST http://localhost:8000/v1/voice-cloning/clone \
133
+ -F "name=My Voice" \
134
+ -F "audio_file=@path/to/your/voice_sample.mp3" \
135
+ -F "transcript=Optional transcript of the audio sample" \
136
+ -F "description=A description of this voice"
137
+ ```
138
+
139
+ 2. **List Available Cloned Voices**:
140
+ ```bash
141
+ curl -X GET http://localhost:8000/v1/voice-cloning/voices
142
+ ```
143
+
144
+ 3. **Generate Speech with a Cloned Voice**:
145
+ ```bash
146
+ curl -X POST http://localhost:8000/v1/voice-cloning/generate \
147
+ -H "Content-Type: application/json" \
148
+ -d '{
149
+ "voice_id": "1234567890_my_voice",
150
+ "text": "This is my cloned voice speaking.",
151
+ "temperature": 0.7
152
+ }' \
153
+ --output cloned_speech.mp3
154
+ ```
155
+
156
+ 4. **Generate a Voice Preview**:
157
+ ```bash
158
+ curl -X POST http://localhost:8000/v1/voice-cloning/voices/1234567890_my_voice/preview \
159
+ --output voice_preview.mp3
160
+ ```
161
+
162
+ 5. **Delete a Cloned Voice**:
163
+ ```bash
164
+ curl -X DELETE http://localhost:8000/v1/voice-cloning/voices/1234567890_my_voice
165
+ ```
166
+
167
+ ### Voice Cloning Best Practices
168
+
169
+ For the best voice cloning results:
170
+
171
+ 1. **Use High-Quality Audio**: Record in a quiet environment with minimal background noise and echo.
172
+
173
+ 2. **Provide Sufficient Length**: 2-3 minutes of speech provides better results than shorter samples.
174
+
175
+ 3. **Clear, Natural Speech**: Speak naturally at a moderate pace with clear pronunciation.
176
+
177
+ 4. **Include Various Intonations**: Sample should contain different sentence types (statements, questions) for better expressiveness.
178
+
179
+ 5. **Add a Transcript**: While optional, providing an accurate transcript of your recording helps the model better capture your voice characteristics.
180
+
181
+ 6. **Adjust Temperature**: For more consistent results, use lower temperature values (0.6-0.7). For more expressiveness, use higher values (0.7-0.9).
182
+
183
+ 7. **Try Multiple Samples**: If you're not satisfied with the results, try recording a different sample or adjusting the speaking style.
184
+
185
+ ### Using Cloned Voices with the Standard TTS Endpoint
186
+
187
+ Cloned voices are automatically available through the standard OpenAI-compatible endpoint. Simply use the voice ID or name as the `voice` parameter:
188
+
189
+ ```bash
190
+ curl -X POST http://localhost:8000/v1/audio/speech \
191
+ -H "Content-Type: application/json" \
192
+ -d '{
193
+ "model": "csm-1b",
194
+ "input": "This is my cloned voice speaking through the standard endpoint.",
195
+ "voice": "1234567890_my_voice",
196
+ "response_format": "mp3"
197
+ }' \
198
+ --output cloned_speech.mp3
199
+ ```
200
+
201
+ ## YouTube Voice Cloning
202
+
203
+ The CSM-1B TTS API now includes the ability to clone voices directly from YouTube videos. This feature allows you to extract voice characteristics from any YouTube content and create custom TTS voices without needing to download or prepare audio samples yourself.
204
+
205
+ ## How to Clone a Voice from YouTube
206
+
207
+ ### API Endpoint
208
+
209
+ ```
210
+ POST /v1/audio/speech/voice-cloning/youtube
211
+ ```
212
+
213
+ Parameters:
214
+ - `youtube_url`: URL of the YouTube video
215
+ - `voice_name`: Name for the cloned voice
216
+ - `start_time` (optional): Start time in seconds (default: 0)
217
+ - `duration` (optional): Duration to extract in seconds (default: 180)
218
+ - `description` (optional): Description of the voice
219
+
220
+ Example request:
221
+ ```json
222
+ {
223
+ "youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
224
+ "voice_name": "rick_astley",
225
+ "start_time": 30,
226
+ "duration": 60,
227
+ "description": "Never gonna give you up"
228
+ }
229
+ ```
230
+
231
+ Response:
232
+ ```json
233
+ {
234
+ "voice_id": "1710805983_rick_astley",
235
+ "name": "rick_astley",
236
+ "description": "Never gonna give you up",
237
+ "created_at": "2025-03-18T22:53:03Z",
238
+ "audio_duration": 60.0,
239
+ "sample_count": 1440000
240
+ }
241
+ ```
242
+
243
+ ## How It Works
244
+
245
+ 1. The system downloads the audio from the specified YouTube video
246
+ 2. It extracts the specified segment (start time and duration)
247
+ 3. Whisper ASR generates a transcript of the audio for better voice matching
248
+ 4. The audio is processed to remove noise and silence
249
+ 5. The voice is cloned and made available for TTS generation
250
+
251
+ ## Best Practices for YouTube Voice Cloning
252
+
253
+ For optimal results:
254
+
255
+ 1. **Choose Clear Speech Segments**
256
+ - Select portions of the video with clear, uninterrupted speech
257
+ - Avoid segments with background music, sound effects, or multiple speakers
258
+
259
+ 2. **Optimal Duration**
260
+ - 30-60 seconds of clean speech typically provides the best results
261
+ - Longer isn't always better - quality matters more than quantity
262
+
263
+ 3. **Specify Time Ranges Precisely**
264
+ - Use `start_time` and `duration` to target the exact speech segment
265
+ - Preview the segment in YouTube before cloning to ensure it's suitable
266
+
267
+ 4. **Consider Audio Quality**
268
+ - Higher quality videos generally produce better voice clones
269
+ - Interviews, vlogs, and speeches often work better than highly produced content
270
+
271
+ ## Limitations
272
+
273
+ - YouTube videos with heavy background music may result in lower quality voice clones
274
+ - Very noisy or low-quality audio sources will produce less accurate voice clones
275
+ - The system works best with natural speech rather than singing or exaggerated voices
276
+ - Copyright restrictions apply - only clone voices you have permission to use
277
+
278
+ ## Example Use Cases
279
+
280
+ - Create a voice clone of a public figure for educational content
281
+ - Clone your own YouTube voice for consistent TTS across your applications
282
+ - Create voice clones from historical speeches or interviews (public domain)
283
+ - Develop custom voices for creative projects with proper permissions
284
+
285
+ ## Ethical Considerations
286
+
287
+ Please use YouTube voice cloning responsibly:
288
+ - Only clone voices from content you have permission to use
289
+ - Respect copyright and intellectual property rights
290
+ - Clearly disclose when using AI-generated or cloned voices
291
+ - Do not use cloned voices for impersonation, deception, or harmful content
292
+
293
+ ## How the Voices Work
294
+
295
+ Unlike traditional TTS systems with pre-trained voice models, CSM-1B works differently:
296
+
297
+ - The base CSM-1B model is capable of producing a wide variety of voices but doesn't have fixed voice identities
298
+ - This API creates consistent voices by using acoustic "seed" samples for each named voice
299
+ - When you specify a voice (e.g., "alloy"), the API uses a consistent acoustic seed and speaker ID
300
+ - The most recent generated audio becomes the new reference for that voice, maintaining voice consistency
301
+ - Each voice has unique tonal qualities:
302
+ - **alloy**: Balanced mid-tones with natural inflection
303
+ - **echo**: Resonant with slight reverberance
304
+ - **fable**: Brighter with higher pitch
305
+ - **onyx**: Deep and resonant
306
+ - **nova**: Warm and smooth
307
+ - **shimmer**: Light and airy with higher frequencies
308
+
309
+ The voice system can be extended with your own voice samples by using the voice cloning feature.
310
+
311
+ ## API Usage
312
+
313
+ ### Basic Usage
314
+
315
+ Generate speech with a POST request to `/v1/audio/speech`:
316
+
317
+ ```bash
318
+ curl -X POST http://localhost:8000/v1/audio/speech \
319
+ -H "Content-Type: application/json" \
320
+ -d '{
321
+ "model": "csm-1b",
322
+ "input": "Hello, this is a test of the CSM text to speech system.",
323
+ "voice": "alloy",
324
+ "response_format": "mp3"
325
+ }' \
326
+ --output speech.mp3
327
+ ```
328
+
329
+ ### Available Endpoints
330
+
331
+ #### Standard TTS Endpoints
332
+ - `GET /v1/audio/models` - List available models
333
+ - `GET /v1/audio/voices` - List available voices (including cloned voices)
334
+ - `GET /v1/audio/speech/response-formats` - List available response formats
335
+ - `POST /v1/audio/speech` - Generate speech from text
336
+ - `POST /api/v1/audio/conversation` - Advanced endpoint for conversational speech
337
+
338
+ #### Voice Cloning Endpoints
339
+ - `POST /v1/voice-cloning/clone` - Clone a new voice from an audio sample
340
+ - `GET /v1/voice-cloning/voices` - List all cloned voices
341
+ - `POST /v1/voice-cloning/generate` - Generate speech with a cloned voice
342
+ - `POST /v1/voice-cloning/voices/{voice_id}/preview` - Generate a preview of a cloned voice
343
+ - `DELETE /v1/voice-cloning/voices/{voice_id}` - Delete a cloned voice
344
+
345
+ ### Request Parameters
346
+
347
+ #### Standard TTS
348
+ | Parameter | Description | Type | Default |
349
+ |-----------|-------------|------|---------|
350
+ | `model` | Model ID to use | string | "csm-1b" |
351
+ | `input` | The text to convert to speech | string | Required |
352
+ | `voice` | The voice to use (standard or cloned voice ID) | string | "alloy" |
353
+ | `response_format` | Audio format | string | "mp3" |
354
+ | `speed` | Speech speed multiplier | float | 1.0 |
355
+ | `temperature` | Sampling temperature | float | 0.8 |
356
+ | `max_audio_length_ms` | Maximum audio length in ms | integer | 90000 |
357
+
358
+ #### Voice Cloning
359
+ | Parameter | Description | Type | Default |
360
+ |-----------|-------------|------|---------|
361
+ | `name` | Name for the cloned voice | string | Required |
362
+ | `audio_file` | Audio sample file | file | Required |
363
+ | `transcript` | Transcript of the audio | string | Optional |
364
+ | `description` | Description of the voice | string | Optional |
365
+
366
+ ### Available Voices
367
+
368
+ - `alloy` - Balanced and natural
369
+ - `echo` - Resonant
370
+ - `fable` - Bright and higher-pitched
371
+ - `onyx` - Deep and resonant
372
+ - `nova` - Warm and smooth
373
+ - `shimmer` - Light and airy
374
+ - `[cloned voice ID]` - Any voice you've cloned using the voice cloning feature
375
+
376
+ ### Response Formats
377
+
378
+ - `mp3` - MP3 audio format
379
+ - `opus` - Opus audio format
380
+ - `aac` - AAC audio format
381
+ - `flac` - FLAC audio format
382
+ - `wav` - WAV audio format
383
+
384
+ ## Integration with OpenWebUI
385
+
386
+ OpenWebUI is a popular open-source UI for AI models that supports custom TTS endpoints. Here's how to integrate the CSM-1B TTS API:
387
+
388
+ 1. Access your OpenWebUI settings
389
+ 2. Navigate to the TTS settings section
390
+ 3. Select "Custom TTS Endpoint"
391
+ 4. Enter your CSM-1B TTS API URL: `http://your-server-ip:8000/v1/audio/speech`
392
+ 5. Use the API Key field to add any authentication if you've configured it (not required by default)
393
+ 6. Test the connection
394
+ 7. Save your settings
395
+
396
+ Once configured, OpenWebUI will use your CSM-1B TTS API for all text-to-speech conversion, producing high-quality speech with the selected voice.
397
+
398
+ ### Using Cloned Voices with OpenWebUI
399
+
400
+ Your cloned voices will automatically appear in OpenWebUI's voice selector. Simply choose your cloned voice from the dropdown menu in the TTS settings or chat interface.
401
+
402
+ ## Advanced Usage
403
+
404
+ ### Conversational Context
405
+
406
+ For more natural-sounding speech in a conversation, you can use the conversation endpoint:
407
+
408
+ ```bash
409
+ curl -X POST http://localhost:8000/api/v1/audio/conversation \
410
+ -H "Content-Type: application/json" \
411
+ -d '{
412
+ "text": "Nice to meet you too!",
413
+ "speaker_id": 0,
414
+ "context": [
415
+ {
416
+ "speaker": 1,
417
+ "text": "Hello, nice to meet you.",
418
+ "audio": "BASE64_ENCODED_AUDIO"
419
+ }
420
+ ]
421
+ }' \
422
+ --output response.wav
423
+ ```
424
+
425
+ This allows the model to take into account the previous utterances for more contextually appropriate speech.
426
+
427
+ ### Model Parameters
428
+
429
+ For fine-grained control, you can adjust:
430
+
431
+ - `temperature` (0.0-1.0): Higher values produce more variation but may be less stable
432
+ - `topk` (1-100): Controls diversity of generated speech
433
+ - `max_audio_length_ms`: Maximum length of generated audio in milliseconds
434
+ - `voice_consistency` (0.0-1.0): How strongly to maintain voice characteristics across segments
435
+
436
+ ## Troubleshooting
437
+
438
+ ### API Returns 503 Service Unavailable
439
+
440
+ - Verify your Hugging Face token has access to `sesame/csm-1b`
441
+ - Check if the model downloaded successfully in the logs
442
+ - Ensure you have enough GPU memory (at least 8GB recommended)
443
+
444
+ ### Audio Quality Issues
445
+
446
+ - Try different voices - some may work better for your specific text
447
+ - Adjust temperature (lower for more stable output)
448
+ - For longer texts, the API automatically splits into smaller chunks for better quality
449
+ - For cloned voices, try recording a cleaner audio sample
450
+
451
+ ### Voice Cloning Issues
452
+
453
+ - **Poor Voice Quality**: Try recording in a quieter environment with less background noise
454
+ - **Inconsistent Voice**: Provide a longer and more varied audio sample (2-3 minutes)
455
+ - **Accent Issues**: Make sure your sample contains similar words/sounds to what you'll be generating
456
+ - **Low Volume**: The sample is normalized automatically, but ensure it's not too quiet or distorted
457
+
458
+ ### Voice Inconsistency
459
+
460
+ - The API maintains voice consistency across separate requests
461
+ - However, very long pauses between requests may result in voice drift
462
+ - For critical applications, consider using the same seed audio
463
+
464
+ ## License
465
+
466
+ This project is released under the MIT License. The CSM-1B model is subject to its own license terms defined by Sesame.
467
+
468
+ ## Acknowledgments
469
+
470
+ - [Sesame](https://www.sesame.com) for releasing the CSM-1B model
471
+ - This project is not affiliated with or endorsed by Sesame or OpenAI
472
+
473
+ ---
474
+
475
+ Happy speech generating!
app/.dockerignore ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Version control
2
+ .git
3
+ .gitignore
4
+ .github
5
+
6
+ # Python
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+ *.so
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+
28
+ # Environment
29
+ .env
30
+ .venv
31
+ env/
32
+ venv/
33
+ ENV/
34
+
35
+ # Docker related
36
+ Dockerfile*
37
+ docker-compose*
38
+ .docker
39
+ .dockerignore
40
+
41
+ # Documentation
42
+ README.md
43
+ CHANGELOG.md
44
+ LICENSE
45
+ docs/
46
+ *.md
47
+ examples/
48
+ tests/
49
+
50
+ # IDE
51
+ .idea/
52
+ .vscode/
53
+ *.swp
54
+ *.swo
55
+ .DS_Store
56
+
57
+ # Logs and temp files
58
+ logs/
59
+ *.log
60
+ log.txt
61
+ tmp/
62
+ temp/
63
+ .temp/
64
+ .~*
65
+
66
+ # Model files - we'll download these in the container
67
+ models/
68
+ tokenizers/
69
+ voice_memories/
70
+ voice_samples/
71
+ *.pt
72
+ *.pth
73
+ *.bin
74
+ *.ckpt
75
+
76
+ # Audio files
77
+ *.wav
78
+ *.mp3
79
+ *.flac
80
+ *.opus
81
+ *.aac
82
+
83
+ # Project specific
84
+ .cache/
85
+ .neptune/
86
+ MANIFEST.in
87
+ .history/
88
+ .mypy_cache/
89
+ .pytest_cache/
90
+ __pycache__/
91
+ .ipynb_checkpoints/
92
+
93
+ # Ignore Hugging Face credentials except specified files
94
+ .huggingface/
app/api/init.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """API package for CSM-1B."""
app/api/routes.py ADDED
@@ -0,0 +1,1048 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ API routes for the CSM-1B TTS API.
3
+ """
4
+ import os
5
+ import io
6
+ import base64
7
+ import time
8
+ import tempfile
9
+ import logging
10
+ import asyncio
11
+ from enum import Enum
12
+ from typing import Dict, List, Optional, Any, Union
13
+ import torch
14
+ import torchaudio
15
+ import numpy as np
16
+ from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Body, Response, Query
17
+ from fastapi.responses import StreamingResponse
18
+ from app.api.schemas import SpeechRequest, ResponseFormat, Voice
19
+ from app.models import Segment
20
+ from app.api.streaming import AudioChunker
21
+ from app.prompt_engineering import split_into_segments
22
+
23
+ # Set up logging
24
+ logger = logging.getLogger(__name__)
25
+ router = APIRouter()
26
+
27
+ # Mapping of response_format to MIME types
28
+ MIME_TYPES = {
29
+ "mp3": "audio/mpeg",
30
+ "opus": "audio/opus",
31
+ "aac": "audio/aac",
32
+ "flac": "audio/flac",
33
+ "wav": "audio/wav",
34
+ }
35
+
36
+ def get_speaker_id(app_state, voice):
37
+ """Helper function to get speaker ID from voice name or ID"""
38
+ if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
39
+ return app_state.voice_speaker_map[voice]
40
+
41
+ # Standard voices mapping
42
+ voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
43
+
44
+ if voice in voice_to_speaker:
45
+ return voice_to_speaker[voice]
46
+
47
+ # Try parsing as integer
48
+ try:
49
+ speaker_id = int(voice)
50
+ if 0 <= speaker_id < 6:
51
+ return speaker_id
52
+ except (ValueError, TypeError):
53
+ pass
54
+
55
+ # Check cloned voices if the voice cloner exists
56
+ if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
57
+ # Check by ID
58
+ if voice in app_state.voice_cloner.cloned_voices:
59
+ return app_state.voice_cloner.cloned_voices[voice].speaker_id
60
+
61
+ # Check by name
62
+ for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
63
+ if v_info.name.lower() == voice.lower():
64
+ return v_info.speaker_id
65
+
66
+ # Default to alloy
67
+ return 0
68
+
69
+ @router.post("/audio/speech", tags=["Audio"], response_class=Response)
70
+ async def generate_speech(
71
+ request: Request,
72
+ speech_request: SpeechRequest,
73
+ ):
74
+ """
75
+ Generate audio of text being spoken by a realistic voice.
76
+
77
+ This endpoint is compatible with the OpenAI TTS API.
78
+
79
+ For streaming responses, use `/v1/audio/speech/streaming` instead.
80
+ """
81
+ # Check if model is available
82
+ if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
83
+ raise HTTPException(status_code=503, detail="TTS model not available")
84
+
85
+ # Set default values
86
+ model = speech_request.model
87
+ voice = speech_request.voice
88
+ input_text = speech_request.input
89
+ response_format = speech_request.response_format
90
+ speed = speech_request.speed
91
+ temperature = speech_request.temperature
92
+ max_audio_length_ms = speech_request.max_audio_length_ms
93
+
94
+ # Log request details
95
+ logger.info(f"TTS request: text length={len(input_text)}, voice={voice}, format={response_format}")
96
+
97
+ try:
98
+ # Get speaker ID for the voice
99
+ speaker_id = get_speaker_id(request.app.state, voice)
100
+ if speaker_id is None:
101
+ raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found")
102
+
103
+ # Check if this is a cloned voice
104
+ voice_info = None
105
+ cloned_voice_id = None
106
+
107
+ if hasattr(request.app.state, "get_voice_info"):
108
+ voice_info = request.app.state.get_voice_info(voice)
109
+ if voice_info and voice_info["type"] == "cloned":
110
+ cloned_voice_id = voice_info["voice_id"]
111
+
112
+ # Generate audio based on whether it's a standard or cloned voice
113
+ if cloned_voice_id is not None and hasattr(request.app.state, "voice_cloner"):
114
+ # Generate speech with cloned voice
115
+ logger.info(f"Generating speech with cloned voice ID: {cloned_voice_id}")
116
+ try:
117
+ voice_cloner = request.app.state.voice_cloner
118
+ audio = voice_cloner.generate_speech(
119
+ text=input_text,
120
+ voice_id=cloned_voice_id,
121
+ temperature=temperature,
122
+ topk=speech_request.topk or 30,
123
+ max_audio_length_ms=max_audio_length_ms
124
+ )
125
+ sample_rate = request.app.state.sample_rate
126
+ logger.info(f"Generated speech with cloned voice, length: {len(audio)/sample_rate:.2f}s")
127
+ except Exception as e:
128
+ logger.error(f"Error generating speech with cloned voice: {e}", exc_info=True)
129
+ raise HTTPException(
130
+ status_code=500,
131
+ detail=f"Failed to generate speech with cloned voice: {str(e)}"
132
+ )
133
+ else:
134
+ # Generate speech with standard voice
135
+ # Use voice context from memory if enabled
136
+ if hasattr(request.app.state, "voice_memory_enabled") and request.app.state.voice_memory_enabled:
137
+ from app.voice_memory import get_voice_context
138
+ context = get_voice_context(voice, torch.device(request.app.state.device))
139
+ else:
140
+ context = []
141
+
142
+ # Apply optional text enhancement for better voice consistency
143
+ enhanced_text = input_text
144
+ if hasattr(request.app.state, "prompt_templates"):
145
+ from app.prompt_engineering import format_text_for_voice
146
+ enhanced_text = format_text_for_voice(input_text, voice)
147
+
148
+ # Generate audio
149
+ audio = request.app.state.generator.generate(
150
+ text=enhanced_text,
151
+ speaker=speaker_id,
152
+ context=context,
153
+ temperature=temperature,
154
+ topk=speech_request.topk or 50,
155
+ max_audio_length_ms=max_audio_length_ms
156
+ )
157
+ sample_rate = request.app.state.sample_rate
158
+
159
+ # Process audio for better quality
160
+ if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
161
+ from app.voice_enhancement import process_generated_audio
162
+ audio = process_generated_audio(
163
+ audio=audio,
164
+ voice_name=voice,
165
+ sample_rate=sample_rate,
166
+ text=input_text
167
+ )
168
+
169
+ # Update voice memory if enabled
170
+ if hasattr(request.app.state, "voice_memory_enabled") and request.app.state.voice_memory_enabled:
171
+ from app.voice_memory import update_voice_memory
172
+ update_voice_memory(voice, audio, input_text)
173
+
174
+ # Handle speed adjustments if not 1.0
175
+ if speed != 1.0 and speed > 0:
176
+ try:
177
+ # Adjust speed using torchaudio
178
+ effects = [
179
+ ["tempo", str(speed)]
180
+ ]
181
+ audio_cpu = audio.cpu()
182
+ adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
183
+ audio_cpu.unsqueeze(0),
184
+ sample_rate,
185
+ effects
186
+ )
187
+ audio = adjusted_audio.squeeze(0)
188
+ logger.info(f"Adjusted speech speed to {speed}x")
189
+ except Exception as e:
190
+ logger.warning(f"Failed to adjust speech speed: {e}")
191
+
192
+ # Format the audio according to the requested format
193
+ response_data, content_type = await format_audio(
194
+ audio,
195
+ response_format,
196
+ sample_rate,
197
+ request.app.state
198
+ )
199
+
200
+ # Create and return the response
201
+ return Response(
202
+ content=response_data,
203
+ media_type=content_type,
204
+ headers={"Content-Disposition": f"attachment; filename=speech.{response_format}"}
205
+ )
206
+
207
+ except Exception as e:
208
+ logger.error(f"Error in text_to_speech: {e}", exc_info=True)
209
+ raise HTTPException(status_code=500, detail=str(e))
210
+
211
+ @router.post("/audio/speech/stream", tags=["Audio"])
212
+ async def stream_speech(request: Request, speech_request: SpeechRequest):
213
+ """Stream audio in real-time as it's being generated."""
214
+ # Check if model is loaded
215
+ if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
216
+ raise HTTPException(status_code=503, detail="Model not loaded")
217
+
218
+ # Get request parameters
219
+ input_text = speech_request.input
220
+ voice = speech_request.voice
221
+ response_format = speech_request.response_format
222
+ temperature = speech_request.temperature
223
+
224
+ logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")
225
+
226
+ # Get speaker ID for the voice
227
+ speaker_id = get_speaker_id(request.app.state, voice)
228
+ if speaker_id is None:
229
+ raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found")
230
+
231
+ # Split text into very small segments for incremental generation
232
+ text_segments = split_into_segments(input_text, max_chars=50) # Smaller segments for faster first response
233
+ logger.info(f"Split text into {len(text_segments)} segments")
234
+
235
+ # Create media type based on format
236
+ media_type = {
237
+ "mp3": "audio/mpeg",
238
+ "opus": "audio/opus",
239
+ "aac": "audio/aac",
240
+ "flac": "audio/flac",
241
+ "wav": "audio/wav",
242
+ }.get(response_format, "audio/mpeg")
243
+
244
+ # For streaming, WAV works best
245
+ streaming_format = "wav"
246
+
247
+ # Set up WAV header for streaming
248
+ sample_rate = request.app.state.sample_rate
249
+
250
+ async def generate_streaming_audio():
251
+ # Get context for the voice
252
+ if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
253
+ voice_info = request.app.state.get_voice_info(voice)
254
+ if voice_info and voice_info["type"] == "cloned":
255
+ # Use cloned voice context
256
+ voice_cloner = request.app.state.voice_cloner
257
+ context = voice_cloner.get_voice_context(voice_info["voice_id"])
258
+ else:
259
+ # Standard voice
260
+ from app.voice_enhancement import get_voice_segments
261
+ context = get_voice_segments(voice, request.app.state.device)
262
+ else:
263
+ # Standard voice
264
+ from app.voice_enhancement import get_voice_segments
265
+ context = get_voice_segments(voice, request.app.state.device)
266
+
267
+ # Send WAV header immediately
268
+ if streaming_format == "wav":
269
+ # Create a WAV header for 16-bit mono audio
270
+ header = bytes()
271
+ # RIFF header
272
+ header += b'RIFF'
273
+ header += b'\x00\x00\x00\x00' # Placeholder for file size
274
+ header += b'WAVE'
275
+ # Format chunk
276
+ header += b'fmt '
277
+ header += (16).to_bytes(4, 'little') # Format chunk size
278
+ header += (1).to_bytes(2, 'little') # PCM format
279
+ header += (1).to_bytes(2, 'little') # Mono channel
280
+ header += (sample_rate).to_bytes(4, 'little') # Sample rate
281
+ header += (sample_rate * 2).to_bytes(4, 'little') # Byte rate
282
+ header += (2).to_bytes(2, 'little') # Block align
283
+ header += (16).to_bytes(2, 'little') # Bits per sample
284
+ # Data chunk
285
+ header += b'data'
286
+ header += b'\x00\x00\x00\x00' # Placeholder for data size
287
+ yield header
288
+
289
+ # Process each segment and stream immediately
290
+ for i, segment_text in enumerate(text_segments):
291
+ try:
292
+ logger.info(f"Generating segment {i+1}/{len(text_segments)}")
293
+
294
+ # For cloned voices, use the voice cloner
295
+ if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
296
+ voice_info = request.app.state.get_voice_info(voice)
297
+ if voice_info and voice_info["type"] == "cloned":
298
+ # Use cloned voice
299
+ voice_cloner = request.app.state.voice_cloner
300
+ segment_audio = await asyncio.to_thread(
301
+ voice_cloner.generate_speech,
302
+ segment_text,
303
+ voice_info["voice_id"],
304
+ temperature=temperature,
305
+ topk=30,
306
+ max_audio_length_ms=2000 # Keep it very short for fast generation
307
+ )
308
+ else:
309
+ # Use standard voice with generator
310
+ segment_audio = await asyncio.to_thread(
311
+ request.app.state.generator.generate,
312
+ segment_text,
313
+ speaker_id,
314
+ context,
315
+ max_audio_length_ms=2000, # Short for quicker generation
316
+ temperature=temperature
317
+ )
318
+ else:
319
+ # Use standard voice with generator
320
+ segment_audio = await asyncio.to_thread(
321
+ request.app.state.generator.generate,
322
+ segment_text,
323
+ speaker_id,
324
+ context,
325
+ max_audio_length_ms=2000, # Short for quicker generation
326
+ temperature=temperature
327
+ )
328
+
329
+ # Skip empty or problematic audio
330
+ if segment_audio is None or segment_audio.numel() == 0:
331
+ logger.warning(f"Empty audio for segment {i+1}")
332
+ continue
333
+
334
+ # Convert to bytes and stream immediately
335
+ buf = io.BytesIO()
336
+ audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
337
+ torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=streaming_format)
338
+ buf.seek(0)
339
+
340
+ # For WAV format, skip the header for all segments after the first
341
+ if streaming_format == "wav" and i > 0:
342
+ buf.seek(44) # Skip WAV header
343
+
344
+ segment_bytes = buf.read()
345
+ yield segment_bytes
346
+
347
+ # Update context with this segment for next generation
348
+ context = [
349
+ Segment(
350
+ text=segment_text,
351
+ speaker=speaker_id,
352
+ audio=segment_audio
353
+ )
354
+ ]
355
+
356
+ except Exception as e:
357
+ logger.error(f"Error generating segment {i+1}: {e}")
358
+ # Continue to next segment
359
+
360
+ # Return the streaming response
361
+ return StreamingResponse(
362
+ generate_streaming_audio(),
363
+ media_type=media_type,
364
+ headers={
365
+ "X-Accel-Buffering": "no", # Prevent buffering in nginx
366
+ "Cache-Control": "no-cache, no-store, must-revalidate",
367
+ "Connection": "keep-alive",
368
+ "Transfer-Encoding": "chunked"
369
+ }
370
+ )
371
+
372
+ @router.post("/audio/speech/streaming", tags=["Audio"])
373
+ async def openai_stream_speech(
374
+ request: Request,
375
+ speech_request: SpeechRequest,
376
+ ):
377
+ """
378
+ Stream audio in OpenAI-compatible streaming format.
379
+
380
+ This endpoint is compatible with the OpenAI streaming TTS API.
381
+ """
382
+ # Use the same logic as the stream_speech endpoint but with a different name
383
+ # to maintain the OpenAI API naming convention
384
+ return await stream_speech(request, speech_request)
385
+
386
+ async def format_audio(audio, response_format, sample_rate, app_state):
387
+ """
388
+ Format audio according to requested format.
389
+
390
+ Args:
391
+ audio: Audio tensor from TTS generation
392
+ response_format: Format as string or enum ('mp3', 'opus', 'aac', 'flac', 'wav')
393
+ sample_rate: Sample rate of the audio
394
+ app_state: FastAPI app state with config and cache settings
395
+
396
+ Returns:
397
+ Tuple of (response_data, content_type)
398
+ """
399
+ import io
400
+ import torch
401
+ import torchaudio
402
+ import tempfile
403
+ import os
404
+ import hashlib
405
+ import time
406
+
407
+ # Handle enum or string for response_format
408
+ if hasattr(response_format, 'value'):
409
+ response_format = response_format.value
410
+
411
+ # Normalize response_format to lowercase
412
+ response_format = str(response_format).lower()
413
+
414
+ # Map formats to content types
415
+ format_to_content_type = {
416
+ 'mp3': 'audio/mpeg',
417
+ 'opus': 'audio/opus',
418
+ 'aac': 'audio/aac',
419
+ 'flac': 'audio/flac',
420
+ 'wav': 'audio/wav'
421
+ }
422
+
423
+ # Ensure response format is supported
424
+ if response_format not in format_to_content_type:
425
+ logger.warning(f"Unsupported format: {response_format}, defaulting to mp3")
426
+ response_format = 'mp3'
427
+
428
+ # Generate a cache key based on audio content and format
429
+ cache_enabled = getattr(app_state, "audio_cache_enabled", False)
430
+ cache_key = None
431
+
432
+ if cache_enabled:
433
+ # Generate a hash of the audio tensor for caching
434
+ audio_hash = hashlib.md5(audio.cpu().numpy().tobytes()).hexdigest()
435
+ cache_key = f"{audio_hash}_{response_format}"
436
+ cache_dir = getattr(app_state, "audio_cache_dir", "/app/audio_cache")
437
+ os.makedirs(cache_dir, exist_ok=True)
438
+ cache_path = os.path.join(cache_dir, f"{cache_key}")
439
+
440
+ # Check if we have a cache hit
441
+ if os.path.exists(cache_path):
442
+ try:
443
+ with open(cache_path, "rb") as f:
444
+ cached_data = f.read()
445
+ logger.info(f"Cache hit for {response_format} audio")
446
+ return cached_data, format_to_content_type[response_format]
447
+ except Exception as e:
448
+ logger.warning(f"Error reading from cache: {e}")
449
+
450
+ # Process audio to the required format
451
+ start_time = time.time()
452
+
453
+ # Move audio to CPU before saving
454
+ audio_cpu = audio.cpu()
455
+
456
+ # Use a temporary file for format conversion
457
+ with tempfile.NamedTemporaryFile(suffix=f".{response_format}", delete=False) as temp_file:
458
+ temp_path = temp_file.name
459
+ try:
460
+ if response_format == 'wav':
461
+ # Direct save for WAV
462
+ torchaudio.save(temp_path, audio_cpu.unsqueeze(0), sample_rate)
463
+ else:
464
+ # For other formats, first save as WAV then convert
465
+ wav_path = f"{temp_path}.wav"
466
+ torchaudio.save(wav_path, audio_cpu.unsqueeze(0), sample_rate)
467
+
468
+ # Use ffmpeg via torchaudio for conversion
469
+ if hasattr(torchaudio.backend, 'sox_io_backend'): # New torchaudio structure
470
+ if response_format == 'mp3':
471
+ # For MP3, use higher quality
472
+ sox_effects = torchaudio.sox_effects.SoxEffectsChain()
473
+ sox_effects.set_input_file(wav_path)
474
+ sox_effects.append_effect_to_chain(["rate", f"{sample_rate}"])
475
+ # Higher bitrate for better quality
476
+ sox_effects.append_effect_to_chain(["gain", "-n"]) # Normalize
477
+ out, _ = sox_effects.sox_build_flow_effects()
478
+ torchaudio.save(temp_path, out, sample_rate, format="mp3", compression=128)
479
+ elif response_format == 'opus':
480
+ # Use ffmpeg for opus through a system call
481
+ import subprocess
482
+ subprocess.run([
483
+ "ffmpeg", "-i", wav_path, "-c:a", "libopus",
484
+ "-b:a", "64k", "-vbr", "on", temp_path,
485
+ "-y", "-loglevel", "error"
486
+ ], check=True)
487
+ elif response_format == 'aac':
488
+ # Use ffmpeg for AAC through a system call
489
+ import subprocess
490
+ subprocess.run([
491
+ "ffmpeg", "-i", wav_path, "-c:a", "aac",
492
+ "-b:a", "128k", temp_path,
493
+ "-y", "-loglevel", "error"
494
+ ], check=True)
495
+ elif response_format == 'flac':
496
+ torchaudio.save(temp_path, audio_cpu.unsqueeze(0), sample_rate, format="flac")
497
+ else:
498
+ # Fallback using external command
499
+ import subprocess
500
+ if response_format == 'mp3':
501
+ subprocess.run([
502
+ "ffmpeg", "-i", wav_path, "-codec:a", "libmp3lame",
503
+ "-qscale:a", "2", temp_path,
504
+ "-y", "-loglevel", "error"
505
+ ], check=True)
506
+ elif response_format == 'opus':
507
+ subprocess.run([
508
+ "ffmpeg", "-i", wav_path, "-c:a", "libopus",
509
+ "-b:a", "64k", "-vbr", "on", temp_path,
510
+ "-y", "-loglevel", "error"
511
+ ], check=True)
512
+ elif response_format == 'aac':
513
+ subprocess.run([
514
+ "ffmpeg", "-i", wav_path, "-c:a", "aac",
515
+ "-b:a", "128k", temp_path,
516
+ "-y", "-loglevel", "error"
517
+ ], check=True)
518
+ elif response_format == 'flac':
519
+ subprocess.run([
520
+ "ffmpeg", "-i", wav_path, "-c:a", "flac", temp_path,
521
+ "-y", "-loglevel", "error"
522
+ ], check=True)
523
+
524
+ # Clean up the temporary WAV file
525
+ try:
526
+ os.unlink(wav_path)
527
+ except:
528
+ pass
529
+
530
+ # Read the processed audio file
531
+ with open(temp_path, "rb") as f:
532
+ response_data = f.read()
533
+
534
+ # Store in cache if enabled
535
+ if cache_enabled and cache_key:
536
+ try:
537
+ cache_path = os.path.join(getattr(app_state, "audio_cache_dir", "/app/audio_cache"), f"{cache_key}")
538
+ with open(cache_path, "wb") as f:
539
+ f.write(response_data)
540
+ logger.debug(f"Cached {response_format} audio with key: {cache_key}")
541
+ except Exception as e:
542
+ logger.warning(f"Error writing to cache: {e}")
543
+
544
+ # Log processing time
545
+ processing_time = time.time() - start_time
546
+ logger.info(f"Processed audio to {response_format} in {processing_time:.3f}s")
547
+
548
+ return response_data, format_to_content_type[response_format]
549
+
550
+ except Exception as e:
551
+ logger.error(f"Error converting audio to {response_format}: {e}")
552
+ # Fallback to WAV if conversion fails
553
+ try:
554
+ wav_path = f"{temp_path}.wav"
555
+ torchaudio.save(wav_path, audio_cpu.unsqueeze(0), sample_rate)
556
+ with open(wav_path, "rb") as f:
557
+ response_data = f.read()
558
+ os.unlink(wav_path)
559
+ return response_data, "audio/wav"
560
+ except Exception as fallback_error:
561
+ logger.error(f"Fallback to WAV also failed: {fallback_error}")
562
+ raise RuntimeError(f"Failed to generate audio in any format: {str(e)}")
563
+
564
+ finally:
565
+ # Clean up the temporary file
566
+ try:
567
+ os.unlink(temp_path)
568
+ except:
569
+ pass
570
+
571
+ @router.post("/audio/conversation", tags=["Conversation API"])
572
+ async def conversation_to_speech(
573
+ request: Request,
574
+ text: str = Body(..., description="Text to convert to speech"),
575
+ speaker_id: int = Body(0, description="Speaker ID"),
576
+ context: List[Dict] = Body([], description="Context segments with speaker, text, and audio path"),
577
+ ):
578
+ """
579
+ Custom endpoint for conversational TTS using CSM-1B.
580
+
581
+ This is not part of the OpenAI API but provides the unique conversational
582
+ capability of the CSM model.
583
+ """
584
+ # Get generator from app state
585
+ generator = request.app.state.generator
586
+
587
+ # Validate model availability
588
+ if generator is None:
589
+ raise HTTPException(status_code=503, detail="Model not loaded")
590
+
591
+ try:
592
+ segments = []
593
+
594
+ # Process context if provided
595
+ for ctx in context:
596
+ if 'speaker' not in ctx or 'text' not in ctx or 'audio' not in ctx:
597
+ continue
598
+
599
+ # Audio should be base64-encoded
600
+ audio_data = base64.b64decode(ctx['audio'])
601
+ audio_file = io.BytesIO(audio_data)
602
+
603
+ # Save to temporary file for torchaudio
604
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
605
+ temp.write(audio_file.read())
606
+ temp_path = temp.name
607
+
608
+ # Load audio
609
+ audio_tensor, sample_rate = torchaudio.load(temp_path)
610
+ audio_tensor = torchaudio.functional.resample(
611
+ audio_tensor.squeeze(0),
612
+ orig_freq=sample_rate,
613
+ new_freq=generator.sample_rate
614
+ )
615
+
616
+ # Clean up
617
+ os.unlink(temp_path)
618
+
619
+ # Create segment
620
+ segments.append(
621
+ Segment(
622
+ speaker=ctx['speaker'],
623
+ text=ctx['text'],
624
+ audio=audio_tensor
625
+ )
626
+ )
627
+
628
+ logger.info(f"Conversation request: '{text}' with {len(segments)} context segments")
629
+
630
+ # Format the text for better voice consistency
631
+ from app.prompt_engineering import format_text_for_voice
632
+
633
+ # Determine voice name from speaker_id
634
+ voice_names = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
635
+ voice_name = voice_names[speaker_id] if 0 <= speaker_id < len(voice_names) else "alloy"
636
+
637
+ formatted_text = format_text_for_voice(text, voice_name)
638
+
639
+ # Generate audio with context
640
+ audio = generator.generate(
641
+ text=formatted_text,
642
+ speaker=speaker_id,
643
+ context=segments,
644
+ max_audio_length_ms=20000, # 20 seconds
645
+ temperature=0.7, # Lower temperature for more stable output
646
+ topk=40,
647
+ )
648
+
649
+ # Process audio for better quality
650
+ from app.voice_enhancement import process_generated_audio
651
+
652
+ processed_audio = process_generated_audio(
653
+ audio,
654
+ voice_name,
655
+ generator.sample_rate,
656
+ text
657
+ )
658
+
659
+ # Save to temporary file
660
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
661
+ temp_path = temp.name
662
+
663
+ # Save audio
664
+ torchaudio.save(temp_path, processed_audio.unsqueeze(0).cpu(), generator.sample_rate)
665
+
666
+ # Return audio file
667
+ def iterfile():
668
+ with open(temp_path, 'rb') as f:
669
+ yield from f
670
+ # Clean up
671
+ if os.path.exists(temp_path):
672
+ os.unlink(temp_path)
673
+
674
+ logger.info(f"Generated conversation response, duration: {processed_audio.shape[0]/generator.sample_rate:.2f}s")
675
+
676
+ return StreamingResponse(
677
+ iterfile(),
678
+ media_type="audio/wav",
679
+ headers={'Content-Disposition': 'attachment; filename="speech.wav"'}
680
+ )
681
+
682
+ except Exception as e:
683
+ import traceback
684
+ error_trace = traceback.format_exc()
685
+ logger.error(f"Conversation speech generation failed: {str(e)}\n{error_trace}")
686
+ raise HTTPException(status_code=500, detail=f"Conversation speech generation failed: {str(e)}")
687
+
688
+ @router.get("/audio/voices", tags=["Audio"])
689
+ async def list_voices(request: Request):
690
+ """
691
+ List available voices in a format compatible with OpenAI and OpenWebUI.
692
+ """
693
+ # Use app state's get_all_voices function if available
694
+ if hasattr(request.app.state, "get_all_voices"):
695
+ voices = request.app.state.get_all_voices()
696
+ logger.info(f"Listing {len(voices)} voices")
697
+ return {"voices": voices}
698
+
699
+ # Fallback to standard voices if necessary
700
+ standard_voices = [
701
+ {"voice_id": "alloy", "name": "Alloy"},
702
+ {"voice_id": "echo", "name": "Echo"},
703
+ {"voice_id": "fable", "name": "Fable"},
704
+ {"voice_id": "onyx", "name": "Onyx"},
705
+ {"voice_id": "nova", "name": "Nova"},
706
+ {"voice_id": "shimmer", "name": "Shimmer"}
707
+ ]
708
+
709
+ # Add cloned voices if available
710
+ if hasattr(request.app.state, "voice_cloner") and request.app.state.voice_cloner is not None:
711
+ cloned_voices = request.app.state.voice_cloner.list_voices()
712
+ for voice in cloned_voices:
713
+ standard_voices.append({
714
+ "voice_id": voice.id, # This has to be specifically voice_id
715
+ "name": voice.name # This has to be specifically name
716
+ })
717
+
718
+ logger.info(f"Listing {len(standard_voices)} voices")
719
+ return {"voices": standard_voices}
720
+
721
+ # Add OpenAI-compatible models list endpoint
722
+ @router.get("/audio/models", tags=["Audio"], summary="List available audio models")
723
+ async def list_models():
724
+ """
725
+ OpenAI compatible endpoint that returns a list of available audio models.
726
+ """
727
+ models = [
728
+ {
729
+ "id": "csm-1b",
730
+ "name": "CSM-1B",
731
+ "description": "Conversational Speech Model 1B from Sesame",
732
+ "created": 1716019200, # March 13, 2025 (from the example)
733
+ "object": "audio",
734
+ "owned_by": "sesame",
735
+ "capabilities": {
736
+ "tts": True,
737
+ "voice_generation": True,
738
+ "voice_cloning": hasattr(router.app, "voice_cloner"),
739
+ "streaming": True
740
+ },
741
+ "max_input_length": 4096,
742
+ "price": {"text-to-speech": 0.00}
743
+ },
744
+ {
745
+ "id": "tts-1",
746
+ "name": "CSM-1B (Compatibility Mode)",
747
+ "description": "CSM-1B with OpenAI TTS-1 compatibility",
748
+ "created": 1716019200,
749
+ "object": "audio",
750
+ "owned_by": "sesame",
751
+ "capabilities": {
752
+ "tts": True,
753
+ "voice_generation": True,
754
+ "streaming": True
755
+ },
756
+ "max_input_length": 4096,
757
+ "price": {"text-to-speech": 0.00}
758
+ },
759
+ {
760
+ "id": "tts-1-hd",
761
+ "name": "CSM-1B (HD Mode)",
762
+ "description": "CSM-1B with higher quality settings",
763
+ "created": 1716019200,
764
+ "object": "audio",
765
+ "owned_by": "sesame",
766
+ "capabilities": {
767
+ "tts": True,
768
+ "voice_generation": True,
769
+ "streaming": True
770
+ },
771
+ "max_input_length": 4096,
772
+ "price": {"text-to-speech": 0.00}
773
+ }
774
+ ]
775
+
776
+ return {"data": models, "object": "list"}
777
+
778
+ # Response format options endpoint
779
+ @router.get("/audio/speech/response-formats", tags=["Audio"], summary="List available response formats")
780
+ async def list_response_formats():
781
+ """List available response formats for speech synthesis."""
782
+ formats = [
783
+ {"name": "mp3", "content_type": "audio/mpeg"},
784
+ {"name": "opus", "content_type": "audio/opus"},
785
+ {"name": "aac", "content_type": "audio/aac"},
786
+ {"name": "flac", "content_type": "audio/flac"},
787
+ {"name": "wav", "content_type": "audio/wav"}
788
+ ]
789
+
790
+ return {"response_formats": formats}
791
+
792
+ # Streaming format options endpoint
793
+ @router.get("/audio/speech/stream-formats", tags=["Audio"], summary="List available streaming formats")
794
+ async def list_stream_formats():
795
+ """List available streaming formats for TTS."""
796
+ return {
797
+ "stream_formats": [
798
+ {
799
+ "format": "mp3",
800
+ "content_type": "audio/mpeg",
801
+ "description": "MP3 audio format (streaming)"
802
+ },
803
+ {
804
+ "format": "opus",
805
+ "content_type": "audio/opus",
806
+ "description": "Opus audio format (streaming)"
807
+ },
808
+ {
809
+ "format": "aac",
810
+ "content_type": "audio/aac",
811
+ "description": "AAC audio format (streaming)"
812
+ },
813
+ {
814
+ "format": "flac",
815
+ "content_type": "audio/flac",
816
+ "description": "FLAC audio format (streaming)"
817
+ },
818
+ {
819
+ "format": "wav",
820
+ "content_type": "audio/wav",
821
+ "description": "WAV audio format (streaming)"
822
+ }
823
+ ]
824
+ }
825
+
826
+ # Simple test endpoint
827
+ @router.get("/test", tags=["Utility"], summary="Test endpoint")
828
+ async def test_endpoint():
829
+ """Simple test endpoint that returns a successful response."""
830
+ return {"status": "ok", "message": "API is working"}
831
+
832
+ # Debug endpoint
833
+ @router.get("/debug", tags=["Utility"], summary="Debug endpoint")
834
+ async def debug_info(request: Request):
835
+ """Get debug information about the API."""
836
+ generator = request.app.state.generator
837
+
838
+ # Basic info
839
+ debug_info = {
840
+ "model_loaded": generator is not None,
841
+ "device": generator.device if generator is not None else None,
842
+ "sample_rate": generator.sample_rate if generator is not None else None,
843
+ }
844
+
845
+ # Add voice enhancement info if available
846
+ try:
847
+ from app.voice_enhancement import VOICE_PROFILES
848
+ voice_info = {}
849
+ for name, profile in VOICE_PROFILES.items():
850
+ voice_info[name] = {
851
+ "pitch_range": f"{profile.pitch_range[0]}-{profile.pitch_range[1]}Hz",
852
+ "timbre": profile.timbre,
853
+ "ref_segments": len(profile.reference_segments),
854
+ }
855
+ debug_info["voice_profiles"] = voice_info
856
+ except ImportError:
857
+ debug_info["voice_profiles"] = "Not available"
858
+
859
+ # Add voice cloning info if available
860
+ if hasattr(request.app.state, "voice_cloner"):
861
+ voice_cloner = request.app.state.voice_cloner
862
+ debug_info["voice_cloning"] = {
863
+ "enabled": True,
864
+ "cloned_voices_count": len(voice_cloner.list_voices()),
865
+ "cloned_voices": [v.name for v in voice_cloner.list_voices()]
866
+ }
867
+ else:
868
+ debug_info["voice_cloning"] = {"enabled": False}
869
+
870
+ # Add streaming info
871
+ debug_info["streaming"] = {"enabled": True}
872
+
873
+ # Add memory usage info for CUDA
874
+ if torch.cuda.is_available():
875
+ debug_info["cuda"] = {
876
+ "allocated_memory_gb": torch.cuda.memory_allocated() / 1e9,
877
+ "reserved_memory_gb": torch.cuda.memory_reserved() / 1e9,
878
+ "max_memory_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
879
+ }
880
+
881
+ return debug_info
882
+
883
+ @router.get("/voice-management/info", tags=["Voice Management"])
884
+ async def get_voice_storage_info(request: Request):
885
+ """Get information about voice storage usage and status."""
886
+ from app.utils.voice_manager import get_voice_storage_info
887
+ return get_voice_storage_info()
888
+
889
+ @router.post("/voice-management/backup", tags=["Voice Management"])
890
+ async def create_voice_backup(request: Request):
891
+ """Create a backup of all voice data."""
892
+ from app.utils.voice_manager import backup_voice_data
893
+ backup_path = backup_voice_data()
894
+ return {"status": "success", "backup_path": backup_path}
895
+
896
+ @router.post("/voice-management/reset-voices", tags=["Voice Management"])
897
+ async def reset_voices(request: Request):
898
+ """Reset voices to their default state."""
899
+ from app.utils.voice_manager import restore_default_voices
900
+ backup_path = restore_default_voices()
901
+ return {"status": "success", "backup_path": backup_path, "message": "Voices reset to default state"}
902
+
903
+ @router.get("/voice-management/verify-references", tags=["Voice Management"])
904
+ async def verify_references(request: Request):
905
+ """Check if voice references are complete and valid."""
906
+ from app.utils.voice_manager import verify_voice_references
907
+ return verify_voice_references()
908
+
909
+ # Voice diagnostics endpoint
910
+ @router.get("/debug/voices", tags=["Debug"], summary="Voice diagnostics")
911
+ async def voice_diagnostics():
912
+ """Get diagnostic information about voice references."""
913
+ try:
914
+ from app.voice_enhancement import VOICE_PROFILES
915
+
916
+ diagnostics = {}
917
+ for name, profile in VOICE_PROFILES.items():
918
+ ref_info = []
919
+ for i, ref in enumerate(profile.reference_segments):
920
+ if ref is not None:
921
+ duration = ref.shape[0] / 24000 # Assume 24kHz
922
+ ref_info.append({
923
+ "index": i,
924
+ "duration_seconds": f"{duration:.2f}",
925
+ "samples": ref.shape[0],
926
+ "min": float(ref.min()),
927
+ "max": float(ref.max()),
928
+ "rms": float(torch.sqrt(torch.mean(ref ** 2))),
929
+ })
930
+
931
+ diagnostics[name] = {
932
+ "speaker_id": profile.speaker_id,
933
+ "pitch_range": f"{profile.pitch_range[0]}-{profile.pitch_range[1]}Hz",
934
+ "references": ref_info,
935
+ "reference_count": len(ref_info),
936
+ }
937
+
938
+ return {"diagnostics": diagnostics}
939
+ except ImportError:
940
+ return {"error": "Voice enhancement module not available"}
941
+
942
+ # Specialized debugging endpoint for speech generation
943
+ @router.post("/debug/speech", tags=["Debug"], summary="Debug speech generation")
944
+ async def debug_speech(
945
+ request: Request,
946
+ text: str = Body(..., embed=True),
947
+ voice: str = Body("alloy", embed=True),
948
+ use_enhancement: bool = Body(True, embed=True)
949
+ ):
950
+ """Debug endpoint for speech generation with enhancement options."""
951
+ generator = request.app.state.generator
952
+
953
+ if generator is None:
954
+ return {"error": "Model not loaded"}
955
+
956
+ try:
957
+ # Convert voice name to speaker ID
958
+ voice_map = {
959
+ "alloy": 0,
960
+ "echo": 1,
961
+ "fable": 2,
962
+ "onyx": 3,
963
+ "nova": 4,
964
+ "shimmer": 5
965
+ }
966
+ speaker = voice_map.get(voice, 0)
967
+
968
+ # Format text if using enhancement
969
+ if use_enhancement:
970
+ from app.prompt_engineering import format_text_for_voice
971
+ formatted_text = format_text_for_voice(text, voice)
972
+ logger.info(f"Using formatted text: {formatted_text}")
973
+ else:
974
+ formatted_text = text
975
+
976
+ # Get context if using enhancement
977
+ if use_enhancement:
978
+ from app.voice_enhancement import get_voice_segments
979
+ context = get_voice_segments(voice, generator.device)
980
+ logger.info(f"Using {len(context)} context segments")
981
+ else:
982
+ context = []
983
+
984
+ # Generate audio
985
+ start_time = time.time()
986
+ audio = generator.generate(
987
+ text=formatted_text,
988
+ speaker=speaker,
989
+ context=context,
990
+ max_audio_length_ms=10000, # 10 seconds
991
+ temperature=0.7 if use_enhancement else 0.9,
992
+ topk=40 if use_enhancement else 50,
993
+ )
994
+ generation_time = time.time() - start_time
995
+
996
+ # Process audio if using enhancement
997
+ if use_enhancement:
998
+ from app.voice_enhancement import process_generated_audio
999
+ start_time = time.time()
1000
+ processed_audio = process_generated_audio(audio, voice, generator.sample_rate, text)
1001
+ processing_time = time.time() - start_time
1002
+ else:
1003
+ processed_audio = audio
1004
+ processing_time = 0
1005
+
1006
+ # Save to temporary WAV file
1007
+ temp_path = f"/tmp/debug_speech_{voice}_{int(time.time())}.wav"
1008
+ torchaudio.save(temp_path, processed_audio.unsqueeze(0).cpu(), generator.sample_rate)
1009
+
1010
+ # Also save original if enhanced
1011
+ if use_enhancement:
1012
+ orig_path = f"/tmp/debug_speech_{voice}_original_{int(time.time())}.wav"
1013
+ torchaudio.save(orig_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
1014
+ else:
1015
+ orig_path = temp_path
1016
+
1017
+ # Calculate audio metrics
1018
+ duration = processed_audio.shape[0] / generator.sample_rate
1019
+ rms = float(torch.sqrt(torch.mean(processed_audio ** 2)))
1020
+ peak = float(processed_audio.abs().max())
1021
+
1022
+ return {
1023
+ "status": "success",
1024
+ "message": f"Audio generated successfully and saved to {temp_path}",
1025
+ "audio": {
1026
+ "duration_seconds": f"{duration:.2f}",
1027
+ "samples": processed_audio.shape[0],
1028
+ "sample_rate": generator.sample_rate,
1029
+ "rms_level": f"{rms:.3f}",
1030
+ "peak_level": f"{peak:.3f}",
1031
+ },
1032
+ "processing": {
1033
+ "enhancement_used": use_enhancement,
1034
+ "generation_time_seconds": f"{generation_time:.3f}",
1035
+ "processing_time_seconds": f"{processing_time:.3f}",
1036
+ "original_path": orig_path,
1037
+ "processed_path": temp_path,
1038
+ }
1039
+ }
1040
+ except Exception as e:
1041
+ import traceback
1042
+ error_trace = traceback.format_exc()
1043
+ logger.error(f"Debug speech generation failed: {e}\n{error_trace}")
1044
+ return {
1045
+ "status": "error",
1046
+ "message": str(e),
1047
+ "traceback": error_trace
1048
+ }
app/api/schemas.py ADDED
@@ -0,0 +1,41 @@
1
+ # app/api/schemas.py
2
+ from enum import Enum
3
+ from typing import Optional, List, Dict, Any, Union
4
+ from pydantic import BaseModel, Field
5
+
6
+ # Voice options as a non-restrictive string
7
+ class Voice(str):
8
+ """Voice options for CSM model - allowing any string value"""
9
+ pass
10
+
11
+ class ResponseFormat(str, Enum):
12
+ mp3 = "mp3"
13
+ opus = "opus"
14
+ aac = "aac"
15
+ flac = "flac"
16
+ wav = "wav"
17
+
18
+ # Create SpeechRequest for compatibility with our new code
19
+ class SpeechRequest(BaseModel):
20
+ model: Optional[str] = Field("csm-1b", description="The TTS model to use")
21
+ input: str = Field(..., description="The text to generate audio for")
22
+ voice: Optional[str] = Field("alloy", description="The voice to use for generation")
23
+ response_format: Optional[ResponseFormat] = Field(ResponseFormat.mp3, description="The format of the audio response")
24
+ speed: Optional[float] = Field(1.0, description="The speed of the audio", ge=0.25, le=4.0)
25
+ # CSM-specific parameters
26
+ max_audio_length_ms: Optional[float] = Field(90000, description="Maximum audio length in milliseconds")
27
+ temperature: Optional[float] = Field(0.9, description="Sampling temperature", ge=0.0, le=2.0)
28
+ topk: Optional[int] = Field(50, description="Top-k for sampling", ge=1, le=100)
29
+
30
+ class Config:
31
+ populate_by_name = True
32
+ extra = "ignore" # Allow extra fields without error
33
+
34
+ # Maintain TTSRequest for backward compatibility
35
+ class TTSRequest(SpeechRequest):
36
+ """Legacy alias for SpeechRequest for backward compatibility"""
37
+ pass
38
+
39
+ class TTSResponse(BaseModel):
40
+ """Only used for API documentation"""
41
+ pass
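A small usage sketch of the schema above, assuming pydantic v2 (the populate_by_name config key suggests v2; on v1, use req.dict() instead of req.model_dump()):

from app.api.schemas import SpeechRequest, ResponseFormat

req = SpeechRequest(
    input="Hello from CSM-1B.",
    voice="alloy",
    response_format=ResponseFormat.wav,
    temperature=0.8,
)
print(req.model_dump())  # extra/unknown fields are ignored per the model config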
app/api/streaming.py ADDED
@@ -0,0 +1,315 @@
1
+ """Streaming support for the TTS API."""
2
+ import asyncio
3
+ import io
4
+ import logging
5
+ import time
6
+ from typing import AsyncGenerator, Optional, List
7
+ import torch
8
+ import torchaudio
9
+ from fastapi import APIRouter, Request, HTTPException
10
+ from fastapi.responses import StreamingResponse
11
+ from app.api.schemas import SpeechRequest, ResponseFormat
12
+ from app.prompt_engineering import split_into_segments
13
+ from app.models import Segment
14
+
15
+ logger = logging.getLogger(__name__)
16
+ router = APIRouter()
17
+
18
+ class AudioChunker:
19
+ """Handle audio chunking for streaming responses."""
20
+ def __init__(self,
21
+ sample_rate: int,
22
+ format: str = "mp3",
23
+ chunk_size_ms: int = 200): # Smaller chunks for better streaming
24
+ """
25
+ Initialize audio chunker.
26
+ Args:
27
+ sample_rate: Audio sample rate in Hz
28
+ format: Output audio format (mp3, opus, etc.)
29
+ chunk_size_ms: Size of each chunk in milliseconds
30
+ """
31
+ self.sample_rate = sample_rate
32
+ self.format = format.lower()
33
+ self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000))
34
+ logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")
35
+
36
+ async def chunk_audio(self,
37
+ audio: torch.Tensor,
38
+ delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
39
+ """
40
+ Convert audio tensor to streaming chunks.
41
+ Args:
42
+ audio: Audio tensor to stream
43
+ delay_ms: Artificial delay between chunks (for testing)
44
+ Yields:
45
+ Audio chunks as bytes
46
+ """
47
+ # Ensure audio is on CPU
48
+ if audio.is_cuda:
49
+ audio = audio.cpu()
50
+ # Calculate number of chunks
51
+ num_samples = audio.shape[0]
52
+ num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
53
+ logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")
54
+ for i in range(num_chunks):
55
+ start_idx = i * self.chunk_size_samples
56
+ end_idx = min(start_idx + self.chunk_size_samples, num_samples)
57
+ # Extract chunk
58
+ chunk = audio[start_idx:end_idx]
59
+ # Convert to bytes in requested format
60
+ chunk_bytes = await self._format_chunk(chunk)
61
+ # Add artificial delay if requested (for testing)
62
+ if delay_ms > 0:
63
+ await asyncio.sleep(delay_ms / 1000)
64
+ yield chunk_bytes
65
+
66
+ async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
67
+ """Convert audio chunk to bytes in the specified format."""
68
+ buf = io.BytesIO()
69
+ # Ensure chunk is 1D and on CPU
70
+ if len(chunk.shape) == 1:
71
+ chunk = chunk.unsqueeze(0) # Add channel dimension
72
+ # Ensure chunk is on CPU
73
+ if chunk.is_cuda:
74
+ chunk = chunk.cpu()
75
+ # Save to buffer in specified format
76
+ if self.format == "mp3":
77
+ torchaudio.save(buf, chunk, self.sample_rate, format="mp3")
78
+ elif self.format == "opus":
79
+ torchaudio.save(buf, chunk, self.sample_rate, format="opus")
80
+ elif self.format == "aac":
81
+ torchaudio.save(buf, chunk, self.sample_rate, format="aac")
82
+ elif self.format == "flac":
83
+ torchaudio.save(buf, chunk, self.sample_rate, format="flac")
84
+ elif self.format == "wav":
85
+ torchaudio.save(buf, chunk, self.sample_rate, format="wav")
86
+ else:
87
+ # Default to mp3
88
+ torchaudio.save(buf, chunk, self.sample_rate, format="mp3")
89
+ # Get bytes from buffer
90
+ buf.seek(0)
91
+ return buf.read()
92
+
93
+ # Helper function to get speaker ID for a voice
94
+ def get_speaker_id(app_state, voice):
95
+ """Helper function to get speaker ID from voice name or ID"""
96
+ if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
97
+ return app_state.voice_speaker_map[voice]
98
+ # Standard voices mapping
99
+ voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
100
+ if voice in voice_to_speaker:
101
+ return voice_to_speaker[voice]
102
+ # Try parsing as integer
103
+ try:
104
+ speaker_id = int(voice)
105
+ if 0 <= speaker_id < 6:
106
+ return speaker_id
107
+ except (ValueError, TypeError):
108
+ pass
109
+ # Check cloned voices if the voice cloner exists
110
+ if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
111
+ # Check by ID
112
+ if voice in app_state.voice_cloner.cloned_voices:
113
+ return app_state.voice_cloner.cloned_voices[voice].speaker_id
114
+ # Check by name
115
+ for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
116
+ if v_info.name.lower() == voice.lower():
117
+ return v_info.speaker_id
118
+ # Default to alloy
119
+ return 0
120
+
121
+ @router.post("/audio/speech/stream", tags=["Audio"])
122
+ async def stream_speech(
123
+ request: Request,
124
+ speech_request: SpeechRequest,
125
+ ):
126
+ """
127
+ Stream audio of text being spoken by a realistic voice.
128
+ This endpoint provides an OpenAI-compatible streaming interface for TTS.
129
+ """
130
+ # Check if model is loaded
131
+ if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
132
+ raise HTTPException(
133
+ status_code=503,
134
+ detail="Model not loaded. Please try again later."
135
+ )
136
+
137
+ # Get request parameters
138
+ model = speech_request.model
139
+ input_text = speech_request.input
140
+ voice = speech_request.voice
141
+ response_format = speech_request.response_format
142
+ speed = speech_request.speed
143
+ temperature = speech_request.temperature
144
+ max_audio_length_ms = speech_request.max_audio_length_ms
145
+
146
+ # Log the request
147
+ logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")
148
+
149
+ # Check if text is empty
150
+ if not input_text or len(input_text.strip()) == 0:
151
+ raise HTTPException(
152
+ status_code=400,
153
+ detail="Input text cannot be empty"
154
+ )
155
+
156
+ # Get speaker ID for the voice
157
+ speaker_id = get_speaker_id(request.app.state, voice)
158
+ if speaker_id is None:
159
+ raise HTTPException(
160
+ status_code=400,
161
+ detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
162
+ )
163
+
164
+ try:
165
+ # Create media type based on format
166
+ media_type = {
167
+ "mp3": "audio/mpeg",
168
+ "opus": "audio/opus",
169
+ "aac": "audio/aac",
170
+ "flac": "audio/flac",
171
+ "wav": "audio/wav",
172
+ }.get(response_format, "audio/mpeg")
173
+
174
+ # Create the chunker for streaming
175
+ sample_rate = request.app.state.sample_rate
176
+ chunker = AudioChunker(sample_rate, response_format)
177
+
178
+ # Split text into segments using the imported function
179
+ from app.prompt_engineering import split_into_segments
180
+ text_segments = split_into_segments(input_text, max_chars=50) # Smaller segments for faster first response
181
+
182
+ logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")
183
+
184
+ async def generate_streaming_audio():
185
+ # Check for cloned voice
186
+ voice_info = None
187
+ from_cloned_voice = False
188
+
189
+ if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
190
+ voice_info = request.app.state.get_voice_info(voice)
191
+ from_cloned_voice = voice_info and voice_info["type"] == "cloned"
192
+ if from_cloned_voice:
193
+ # Use cloned voice context for first segment
194
+ voice_cloner = request.app.state.voice_cloner
195
+ context = voice_cloner.get_voice_context(voice_info["voice_id"])
196
+ else:
197
+ # Use standard voice context
198
+ from app.voice_enhancement import get_voice_segments
199
+ context = get_voice_segments(voice, request.app.state.device)
200
+ else:
201
+ # Use standard voice context
202
+ from app.voice_enhancement import get_voice_segments
203
+ context = get_voice_segments(voice, request.app.state.device)
204
+
205
+ # Send an empty chunk to initialize the connection
206
+ yield b''
207
+
208
+ # Process each text segment incrementally and stream in real time
209
+ for i, segment_text in enumerate(text_segments):
210
+ try:
211
+ logger.info(f"Generating segment {i+1}/{len(text_segments)}")
212
+
213
+ # Generate audio for this segment - use async to avoid blocking
214
+ if from_cloned_voice:
215
+ # Generate with cloned voice
216
+ voice_cloner = request.app.state.voice_cloner
217
+
218
+ # Convert to asynchronous with asyncio.to_thread
219
+ segment_audio = await asyncio.to_thread(
220
+ voice_cloner.generate_speech,
221
+ segment_text,
222
+ voice_info["voice_id"],
223
+ temperature=temperature,
224
+ topk=30,
225
+ max_audio_length_ms=2000 # Keep segments short for streaming
226
+ )
227
+ else:
228
+ # Use standard voice with generator
229
+ segment_audio = await asyncio.to_thread(
230
+ request.app.state.generator.generate,
231
+ segment_text,
232
+ speaker_id,
233
+ context,
234
+ max_audio_length_ms=2000, # Short for quicker generation
235
+ temperature=temperature
236
+ )
237
+
238
+ # Process audio quality for this segment
239
+ if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
240
+ from app.voice_enhancement import process_generated_audio
241
+ segment_audio = process_generated_audio(
242
+ audio=segment_audio,
243
+ voice_name=voice,
244
+ sample_rate=sample_rate,
245
+ text=segment_text
246
+ )
247
+
248
+ # Handle speed adjustment
249
+ if speed != 1.0 and speed > 0:
250
+ try:
251
+ # Adjust speed using torchaudio
252
+ effects = [["tempo", str(speed)]]
253
+ audio_cpu = segment_audio.cpu()
254
+ adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
255
+ audio_cpu.unsqueeze(0),
256
+ sample_rate,
257
+ effects
258
+ )
259
+ segment_audio = adjusted_audio.squeeze(0)
260
+ except Exception as e:
261
+ logger.warning(f"Failed to adjust speech speed: {e}")
262
+
263
+ # Convert this segment to bytes and stream immediately
264
+ buf = io.BytesIO()
265
+ audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
266
+ torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format)
267
+ buf.seek(0)
268
+ segment_bytes = buf.read()
269
+
270
+ # Stream this segment immediately
271
+ yield segment_bytes
272
+
273
+ # Update context with this segment for next generation
274
+ context = [
275
+ Segment(
276
+ text=segment_text,
277
+ speaker=speaker_id,
278
+ audio=segment_audio
279
+ )
280
+ ]
281
+
282
+ except Exception as e:
283
+ logger.error(f"Error generating segment {i+1}: {e}")
284
+ # Try to continue with next segment
285
+
286
+ # Return streaming response
287
+ return StreamingResponse(
288
+ generate_streaming_audio(),
289
+ media_type=media_type,
290
+ headers={
291
+ "Content-Disposition": f'attachment; filename="speech.{response_format}"',
292
+ "X-Accel-Buffering": "no", # Prevent buffering in nginx
293
+ "Cache-Control": "no-cache, no-store, must-revalidate", # Prevent caching
294
+ "Pragma": "no-cache",
295
+ "Expires": "0",
296
+ "Connection": "keep-alive",
297
+ "Transfer-Encoding": "chunked"
298
+ }
299
+ )
300
+ except Exception as e:
301
+ logger.error(f"Error in stream_speech: {e}")
302
+ raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
303
+
304
+ @router.post("/audio/speech/streaming", tags=["Audio"])
305
+ async def openai_stream_speech(
306
+ request: Request,
307
+ speech_request: SpeechRequest,
308
+ ):
309
+ """
310
+ Stream audio in OpenAI-compatible streaming format.
311
+ This endpoint is compatible with the OpenAI streaming TTS API.
312
+ """
313
+ # Use the same logic as the stream_speech endpoint but with a different name
314
+ # to maintain the OpenAI API naming convention
315
+ return await stream_speech(request, speech_request)
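A hedged client sketch for the streaming endpoint above, assuming a local server on port 8000 and a /v1 mount prefix; the loop skips the empty initialization chunk the generator yields first:

import requests  # hypothetical client; not part of this commit

payload = {"input": "Streaming test sentence.", "voice": "alloy", "response_format": "mp3"}
with requests.post(
    "http://localhost:8000/v1/audio/speech/stream",
    json=payload,
    stream=True,      # read chunks as the server produces them
    timeout=300,
) as resp:
    resp.raise_for_status()
    with open("speech.mp3", "wb") as f:
        for chunk in resp.iter_content(chunk_size=None):
            if chunk:  # skip the empty initialization chunk
                f.write(chunk)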
app/api/utils.py ADDED
@@ -0,0 +1,25 @@
1
+ def get_all_voices(app_state):
2
+ """
3
+ Get all available voices including standard and cloned voices.
4
+ Returns them in a format compatible with OpenAI's API.
5
+ """
6
+ # Standard voices
7
+ voices = [
8
+ {"voice_id": "alloy", "name": "Alloy"},
9
+ {"voice_id": "echo", "name": "Echo"},
10
+ {"voice_id": "fable", "name": "Fable"},
11
+ {"voice_id": "onyx", "name": "Onyx"},
12
+ {"voice_id": "nova", "name": "Nova"},
13
+ {"voice_id": "shimmer", "name": "Shimmer"}
14
+ ]
15
+
16
+ # Add cloned voices if available
17
+ if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
18
+ cloned_voices = app_state.voice_cloner.list_voices()
19
+ for voice in cloned_voices:
20
+ voices.append({
21
+ "voice_id": voice.id,
22
+ "name": voice.name
23
+ })
24
+
25
+ return voices
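A sketch of how this helper might be exposed on app.state so the /audio/voices route above can call it; the startup hook shown is an assumption, not part of this commit:

from functools import partial
from app.api.utils import get_all_voices

# inside the FastAPI startup/lifespan handler, where `app` is the FastAPI instance:
app.state.get_all_voices = partial(get_all_voices, app.state)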
app/api/voice_cloning_routes.py ADDED
@@ -0,0 +1,289 @@
1
+ """Voice cloning API routes for CSM-1B TTS API."""
2
+ import os
3
+ import io
4
+ import time
5
+ import tempfile
6
+ from typing import Dict, List, Optional, Any
7
+ import torch
8
+ import torchaudio
9
+ from fastapi import APIRouter, Request, Response, HTTPException, UploadFile, File, Form, Body
10
+ from fastapi.responses import StreamingResponse, JSONResponse
11
+ from app.voice_cloning import ClonedVoice
12
+
13
+ # Create router
14
+ router = APIRouter(prefix="/voice-cloning", tags=["Voice Cloning"])
15
+
16
+ @router.post("/clone", summary="Clone a new voice")
17
+ async def clone_voice(
18
+ request: Request,
19
+ audio_file: UploadFile = File(...),
20
+ name: str = Form(...),
21
+ transcript: Optional[str] = Form(None),
22
+ description: Optional[str] = Form(None)
23
+ ):
24
+ """
25
+ Clone a new voice from an audio file.
26
+
27
+ - **audio_file**: Audio file with the voice to clone (MP3, WAV, etc.)
28
+ - **name**: Name for the cloned voice
29
+ - **transcript**: Optional transcript of the audio
30
+ - **description**: Optional description of the voice
31
+ """
32
+ if not hasattr(request.app.state, "voice_cloner"):
33
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
34
+
35
+ voice_cloner = request.app.state.voice_cloner
36
+
37
+ try:
38
+ voice = await voice_cloner.clone_voice(
39
+ audio_file=audio_file,
40
+ voice_name=name,
41
+ transcript=transcript,
42
+ description=description
43
+ )
44
+
45
+ return {
46
+ "status": "success",
47
+ "message": "Voice cloned successfully",
48
+ "voice": voice
49
+ }
50
+ except Exception as e:
51
+ import traceback
52
+ error_trace = traceback.format_exc()
53
+ request.app.state.logger.error(f"Voice cloning failed: {e}\n{error_trace}")
54
+ raise HTTPException(status_code=500, detail=f"Voice cloning failed: {str(e)}")
55
+
56
+ @router.get("/voices", summary="List cloned voices")
57
+ async def list_voices(request: Request):
58
+ """List all cloned voices available in the system."""
59
+ if not hasattr(request.app.state, "voice_cloner"):
60
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
61
+
62
+ voice_cloner = request.app.state.voice_cloner
63
+ voices = voice_cloner.list_voices()
64
+
65
+ return {
66
+ "voices": voices
67
+ }
68
+
69
+ @router.delete("/voices/{voice_id}", summary="Delete a cloned voice")
70
+ async def delete_voice(request: Request, voice_id: str):
71
+ """Delete a cloned voice by ID."""
72
+ if not hasattr(request.app.state, "voice_cloner"):
73
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
74
+
75
+ voice_cloner = request.app.state.voice_cloner
76
+ success = voice_cloner.delete_voice(voice_id)
77
+
78
+ if not success:
79
+ raise HTTPException(status_code=404, detail=f"Voice with ID {voice_id} not found")
80
+
81
+ return {
82
+ "status": "success",
83
+ "message": f"Voice {voice_id} deleted successfully"
84
+ }
85
+
86
+ @router.post("/clone-from-youtube", summary="Clone a voice from YouTube")
87
+ async def clone_voice_from_youtube(
88
+ request: Request,
89
+ data: dict = Body(...) # Use a single body parameter to avoid conflicts
90
+ ):
91
+ """
92
+ Clone a voice from a YouTube video.
93
+
94
+ - **youtube_url**: URL of the YouTube video
95
+ - **voice_name**: Name for the cloned voice
96
+ - **start_time**: Start time in seconds (default: 0)
97
+ - **duration**: Duration to extract in seconds (default: 180)
98
+ - **description**: Optional description of the voice
99
+ """
100
+ if not hasattr(request.app.state, "voice_cloner"):
101
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
102
+
103
+ voice_cloner = request.app.state.voice_cloner
104
+
105
+ # Extract parameters from the request body
106
+ youtube_url = data.get("youtube_url")
107
+ voice_name = data.get("voice_name")
108
+ start_time = data.get("start_time", 0)
109
+ duration = data.get("duration", 180)
110
+ description = data.get("description")
111
+
112
+ # Validate required parameters
113
+ if not youtube_url or not voice_name:
114
+ raise HTTPException(status_code=400, detail="Missing required parameters: youtube_url and voice_name")
115
+
116
+ try:
117
+ voice = await voice_cloner.clone_voice_from_youtube(
118
+ youtube_url=youtube_url,
119
+ voice_name=voice_name,
120
+ start_time=start_time,
121
+ duration=duration,
122
+ description=description
123
+ )
124
+
125
+ return {
126
+ "status": "success",
127
+ "message": "Voice cloned successfully from YouTube",
128
+ "voice": voice
129
+ }
130
+ except Exception as e:
131
+ import traceback
132
+ error_trace = traceback.format_exc()
133
+ request.app.state.logger.error(f"YouTube voice cloning failed: {e}\n{error_trace}")
134
+ raise HTTPException(status_code=500, detail=f"YouTube voice cloning failed: {str(e)}")
135
+
136
+ @router.post("/generate", summary="Generate speech with cloned voice")
137
+ async def generate_speech(
138
+ request: Request,
139
+ voice_id: str = Body(..., embed=True),
140
+ text: str = Body(..., embed=True),
141
+ temperature: float = Body(0.65, embed=True),
142
+ response_format: str = Body("mp3", embed=True)
143
+ ):
144
+ """
145
+ Generate speech using a cloned voice.
146
+
147
+ - **voice_id**: ID of the cloned voice to use
148
+ - **text**: Text to synthesize
149
+ - **temperature**: Sampling temperature (lower = more stable, higher = more varied)
150
+ - **response_format**: Audio format (mp3, wav, etc.)
151
+ """
152
+ if not hasattr(request.app.state, "voice_cloner"):
153
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
154
+
155
+ voice_cloner = request.app.state.voice_cloner
156
+
157
+ # Validate voice ID
158
+ if voice_id not in voice_cloner.cloned_voices:
159
+ raise HTTPException(status_code=404, detail=f"Voice with ID {voice_id} not found")
160
+
161
+ # MIME type mapping
162
+ mime_types = {
163
+ "mp3": "audio/mpeg",
164
+ "wav": "audio/wav",
165
+ "ogg": "audio/ogg",
166
+ "flac": "audio/flac",
167
+ "m4a": "audio/mp4",
168
+ }
169
+
170
+ # Set default if format not specified
171
+ if response_format not in mime_types:
172
+ response_format = "mp3"
173
+
174
+ try:
175
+ # Generate speech with the cloned voice - IMPORTANT: This is a synchronous operation
176
+ # so it is called directly, without await
177
+ audio = voice_cloner.generate_speech(
178
+ text=text,
179
+ voice_id=voice_id,
180
+ temperature=temperature
181
+ )
182
+
183
+ # Create temporary file for audio conversion
184
+ with tempfile.NamedTemporaryFile(suffix=f".{response_format}", delete=False) as temp_file:
185
+ temp_path = temp_file.name
186
+
187
+ # Save to WAV first (direct format for torchaudio)
188
+ wav_path = f"{temp_path}.wav"
189
+ torchaudio.save(wav_path, audio.unsqueeze(0).cpu(), voice_cloner.sample_rate)
190
+
191
+ # Convert to requested format
192
+ import ffmpeg
193
+
194
+ if response_format == "mp3":
195
+ (
196
+ ffmpeg.input(wav_path)
197
+ .output(temp_path, format='mp3', audio_bitrate='128k')
198
+ .run(quiet=True, overwrite_output=True)
199
+ )
200
+ elif response_format == "ogg":
201
+ (
202
+ ffmpeg.input(wav_path)
203
+ .output(temp_path, format='ogg')
204
+ .run(quiet=True, overwrite_output=True)
205
+ )
206
+ elif response_format == "flac":
207
+ (
208
+ ffmpeg.input(wav_path)
209
+ .output(temp_path, format='flac')
210
+ .run(quiet=True, overwrite_output=True)
211
+ )
212
+ elif response_format == "m4a":
213
+ (
214
+ ffmpeg.input(wav_path)
215
+ .output(temp_path, format='mp4')
216
+ .run(quiet=True, overwrite_output=True)
217
+ )
218
+ else: # wav
219
+ temp_path = wav_path
220
+
221
+ # Clean up the temporary WAV file if we created a different format
222
+ if temp_path != wav_path and os.path.exists(wav_path):
223
+ os.unlink(wav_path)
224
+
225
+ # Return audio file as response
226
+ def iterfile():
227
+ with open(temp_path, 'rb') as f:
228
+ yield from f
229
+ # Clean up temp file after streaming
230
+ if os.path.exists(temp_path):
231
+ os.unlink(temp_path)
232
+
233
+ return StreamingResponse(
234
+ iterfile(),
235
+ media_type=mime_types.get(response_format, "application/octet-stream"),
236
+ headers={'Content-Disposition': f'attachment; filename="speech.{response_format}"'}
237
+ )
238
+
239
+ except Exception as e:
240
+ import traceback
241
+ error_trace = traceback.format_exc()
242
+ request.app.state.logger.error(f"Speech generation failed: {e}\n{error_trace}")
243
+ raise HTTPException(status_code=500, detail=f"Speech generation failed: {str(e)}")
244
+
245
+ @router.post("/voices/{voice_id}/preview", summary="Generate a preview of a cloned voice")
246
+ async def generate_preview(
247
+ request: Request,
248
+ voice_id: str,
249
+ text: Optional[str] = Body("This is a preview of my cloned voice.", embed=True),
250
+ response_format: str = Body("mp3", embed=True)
251
+ ):
252
+ """
253
+ Generate a preview of a cloned voice with a standard text.
254
+
255
+ - **voice_id**: ID of the cloned voice to use
256
+ - **text**: Optional custom text for the preview
257
+ - **response_format**: Audio format (mp3, wav, etc.)
258
+ """
259
+ # Use the generate_speech endpoint with a standard text
260
+ return await generate_speech(
261
+ request=request,
262
+ voice_id=voice_id,
263
+ text=text,
264
+ temperature=0.7,
265
+ response_format=response_format
266
+ )
267
+
268
+ @router.get("/openai-compatible-voices", summary="List cloned voices in OpenAI format")
269
+ async def list_voices_openai_format(request: Request):
270
+ """List all cloned voices in OpenAI-compatible format."""
271
+ if not hasattr(request.app.state, "voice_cloner"):
272
+ raise HTTPException(status_code=503, detail="Voice cloning service not available")
273
+
274
+ voice_cloner = request.app.state.voice_cloner
275
+ voices = voice_cloner.list_voices()
276
+
277
+ # Format voices in OpenAI-compatible format
278
+ openai_voices = []
279
+ for voice in voices:
280
+ openai_voices.append({
281
+ "voice_id": voice.id,
282
+ "name": voice.name,
283
+ "preview_url": f"/v1/voice-cloning/voices/{voice.id}/preview",
284
+ "description": voice.description or f"Cloned voice: {voice.name}",
285
+ "languages": [{"language_code": "en", "name": "English"}],
286
+ "cloned": True
287
+ })
288
+
289
+ return {"voices": openai_voices}
app/audio_processing.py ADDED
@@ -0,0 +1,230 @@
1
+ """Audio processing utilities for CSM-1B TTS API."""
2
+ import logging
3
+ import numpy as np
4
+ import torch
5
+ from scipy import signal
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ def remove_long_silences(
10
+ audio: torch.Tensor,
11
+ sample_rate: int,
12
+ min_speech_energy: float = 0.015,
13
+ max_silence_sec: float = 0.4,
14
+ keep_silence_sec: float = 0.1,
15
+ ) -> torch.Tensor:
16
+ """
17
+ Remove uncomfortably long silences from audio while preserving natural pauses.
18
+
19
+ Args:
20
+ audio: Audio tensor
21
+ sample_rate: Sample rate in Hz
22
+ min_speech_energy: Minimum RMS energy to consider as speech
23
+ max_silence_sec: Maximum silence duration to keep in seconds
24
+ keep_silence_sec: Amount of silence to keep at speech boundaries
25
+
26
+ Returns:
27
+ Audio with long silences removed
28
+ """
29
+ # Convert to numpy for processing
30
+ audio_np = audio.cpu().numpy()
31
+
32
+ # Calculate frame size and hop length
33
+ frame_size = int(0.02 * sample_rate) # 20ms frames
34
+ hop_length = int(0.01 * sample_rate) # 10ms hop
35
+
36
+ # Compute frame energy
37
+ frames = []
38
+ for i in range(0, len(audio_np) - frame_size + 1, hop_length):
39
+ frames.append(audio_np[i:i+frame_size])
40
+
41
+ if len(frames) < 2: # If audio is too short for analysis
42
+ return audio
43
+
44
+ frames = np.array(frames)
45
+ # Root mean square energy
46
+ frame_energy = np.sqrt(np.mean(frames**2, axis=1))
47
+
48
+ # Adaptive threshold based on audio content
49
+ # Uses a percentile to adapt to different audio characteristics
50
+ energy_threshold = max(
51
+ min_speech_energy, # Minimum threshold
52
+ np.percentile(frame_energy, 10) # Adapt to audio
53
+ )
54
+
55
+ # Identify speech frames
56
+ is_speech = frame_energy > energy_threshold
57
+
58
+ # Convert frame indices to sample indices considering overlapping frames
59
+ speech_segments = []
60
+ in_speech = False
61
+ speech_start = 0
62
+
63
+ for i in range(len(is_speech)):
64
+ if is_speech[i] and not in_speech:
65
+ # Start of speech
66
+ in_speech = True
67
+ # Calculate start sample including keep_silence
68
+ speech_start = max(0, i * hop_length - int(keep_silence_sec * sample_rate))
69
+
70
+ elif not is_speech[i] and in_speech:
71
+ # Potential end of speech, look ahead to check if silence continues
72
+ silence_length = 0
73
+ for j in range(i, min(len(is_speech), i + int(max_silence_sec * sample_rate / hop_length))):
74
+ if not is_speech[j]:
75
+ silence_length += 1
76
+ else:
77
+ break
78
+
79
+ if silence_length * hop_length >= max_silence_sec * sample_rate:
80
+ # End of speech, long enough silence detected
81
+ in_speech = False
82
+ # Calculate end sample including keep_silence
83
+ speech_end = min(len(audio_np), i * hop_length + int(keep_silence_sec * sample_rate))
84
+ speech_segments.append((speech_start, speech_end))
85
+
86
+ # Handle the case where audio ends during speech
87
+ if in_speech:
88
+ speech_segments.append((speech_start, len(audio_np)))
89
+
90
+ if not speech_segments:
91
+ logger.warning("No speech segments detected, returning original audio")
92
+ return audio
93
+
94
+ # Combine speech segments with controlled silence durations
95
+ result = []
96
+
97
+ # Add initial silence if the first segment doesn't start at the beginning
98
+ if speech_segments[0][0] > 0:
99
+ # Add a short leading silence (100ms)
100
+ silence_samples = min(int(0.1 * sample_rate), speech_segments[0][0])
101
+ if silence_samples > 0:
102
+ result.append(audio_np[speech_segments[0][0] - silence_samples:speech_segments[0][0]])
103
+
104
+ # Process each speech segment
105
+ for i, (start, end) in enumerate(speech_segments):
106
+ # Add this speech segment
107
+ result.append(audio_np[start:end])
108
+
109
+ # Add a controlled silence between segments
110
+ if i < len(speech_segments) - 1:
111
+ next_start = speech_segments[i+1][0]
112
+ # Calculate available silence duration
113
+ available_silence = next_start - end
114
+
115
+ if available_silence > 0:
116
+ # Use either the actual silence (if shorter than max) or the max allowed
117
+ silence_duration = min(available_silence, int(max_silence_sec * sample_rate))
118
+ # Take the first portion of the silence - usually cleaner
119
+ result.append(audio_np[end:end + silence_duration])
120
+
121
+ # Combine all parts
122
+ processed_audio = np.concatenate(result)
123
+
124
+ # Log the results
125
+ original_duration = len(audio_np) / sample_rate
126
+ processed_duration = len(processed_audio) / sample_rate
127
+ logger.info(f"Silence removal: {original_duration:.2f}s -> {processed_duration:.2f}s ({processed_duration/original_duration*100:.1f}%)")
128
+
129
+ # Return as tensor with original device and dtype
130
+ return torch.tensor(processed_audio, device=audio.device, dtype=audio.dtype)
131
+
132
+ def create_high_shelf_filter(audio, sample_rate, frequency=4000, gain_db=3.0):
133
+ """
134
+ Create a high shelf filter to boost frequencies above the given frequency.
135
+
136
+ Args:
137
+ audio: Audio numpy array
138
+ sample_rate: Sample rate in Hz
139
+ frequency: Shelf frequency in Hz
140
+ gain_db: Gain in dB for frequencies above the shelf
141
+
142
+ Returns:
143
+ Filtered audio
144
+ """
145
+ # Convert gain from dB to linear
146
+ gain = 10 ** (gain_db / 20.0)
147
+
148
+ # Normalized frequency (0 to 1, where 1 is Nyquist frequency)
149
+ normalized_freq = 2.0 * frequency / sample_rate
150
+
151
+ # Design a high-shelf biquad filter
152
+ # This is a standard second-order section (SOS) implementation
153
+ b0 = gain
154
+ b1 = 0
155
+ b2 = 0
156
+ a0 = 1
157
+ a1 = 0
158
+ a2 = 0
159
+
160
+ # Second-order (biquad) high-shelf filter coefficients
161
+ alpha = np.sin(np.pi * normalized_freq) / 2 * np.sqrt((gain + 1/gain) * (1/0.5 - 1) + 2)
162
+ cos_w0 = np.cos(np.pi * normalized_freq)
163
+
164
+ b0 = gain * ((gain + 1) + (gain - 1) * cos_w0 + 2 * np.sqrt(gain) * alpha)
165
+ b1 = -2 * gain * ((gain - 1) + (gain + 1) * cos_w0)
166
+ b2 = gain * ((gain + 1) + (gain - 1) * cos_w0 - 2 * np.sqrt(gain) * alpha)
167
+ a0 = (gain + 1) - (gain - 1) * cos_w0 + 2 * np.sqrt(gain) * alpha
168
+ a1 = 2 * ((gain - 1) - (gain + 1) * cos_w0)
169
+ a2 = (gain + 1) - (gain - 1) * cos_w0 - 2 * np.sqrt(gain) * alpha
170
+
171
+ # Normalize coefficients
172
+ b = np.array([b0, b1, b2]) / a0
173
+ a = np.array([1.0, a1/a0, a2/a0])
174
+
175
+ # Apply the filter
176
+ return signal.lfilter(b, a, audio)
177
+
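For reference, the coefficients above form a standard biquad, a sketch of which in the usual Audio EQ Cookbook high-shelf notation is:

H(z) = \frac{b_0 + b_1 z^{-1} + b_2 z^{-2}}{a_0 + a_1 z^{-1} + a_2 z^{-2}}, \qquad \omega_0 = \frac{2\pi f_{\text{shelf}}}{f_s}

Note the cookbook defines the shelf amplitude as A = 10^{G/40}, while the code converts gain_db with 10^{G/20}; if the cookbook convention applies here, the effective boost is roughly twice the requested dB value, which is worth verifying against a measured frequency response.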
178
+ def enhance_audio_quality(audio: torch.Tensor, sample_rate: int) -> torch.Tensor:
179
+ """
180
+ Enhance audio quality by applying various processing techniques.
181
+
182
+ Args:
183
+ audio: Audio tensor
184
+ sample_rate: Sample rate in Hz
185
+
186
+ Returns:
187
+ Enhanced audio tensor
188
+ """
189
+ try:
190
+ audio_np = audio.cpu().numpy()
191
+
192
+ # Remove DC offset
193
+ audio_np = audio_np - np.mean(audio_np)
194
+
195
+ # Apply light compression to improve perceived loudness
196
+ # Compress by reducing peaks and increasing quieter parts slightly
197
+ threshold = 0.5
198
+ ratio = 1.5
199
+ attack = 0.01
200
+ release = 0.1
201
+
202
+ # Simple compression algorithm
203
+ gain = np.ones_like(audio_np)
204
+ for i in range(1, len(audio_np)):
205
+ level = abs(audio_np[i])
206
+ if level > threshold:
207
+ gain[i] = threshold + (level - threshold) / ratio
208
+ gain[i] = gain[i] / level if level > 0 else 1.0
209
+ else:
210
+ gain[i] = 1.0
211
+
212
+ # Smooth gain changes
213
+ gain[i] = gain[i-1] + (gain[i] - gain[i-1]) * (attack if gain[i] < gain[i-1] else release)
214
+
215
+ audio_np = audio_np * gain
216
+
217
+ # Apply high-shelf filter to enhance speech clarity
218
+ # Boost frequencies above 4000 Hz by 3 dB
219
+ audio_np = create_high_shelf_filter(audio_np, sample_rate, frequency=4000, gain_db=3.0)
220
+
221
+ # Normalize to prevent clipping
222
+ max_val = np.max(np.abs(audio_np))
223
+ if max_val > 0:
224
+ audio_np = audio_np * 0.95 / max_val
225
+
226
+ return torch.tensor(audio_np, device=audio.device, dtype=audio.dtype)
227
+
228
+ except Exception as e:
229
+ logger.warning(f"Audio quality enhancement failed: {e}")
230
+ return audio
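A minimal sketch of chaining the two helpers above on a generated waveform; the file names are placeholders and both functions expect a 1-D tensor:

import torchaudio
from app.audio_processing import remove_long_silences, enhance_audio_quality

waveform, sr = torchaudio.load("raw_tts_output.wav")  # placeholder path
audio = waveform.squeeze(0)                           # 1-D tensor expected by the helpers
audio = remove_long_silences(audio, sr)
audio = enhance_audio_quality(audio, sr)
torchaudio.save("post_processed.wav", audio.unsqueeze(0), sr)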
app/custom_transformer.py ADDED
@@ -0,0 +1,339 @@
1
+ """Custom transformer implementation for fallback."""
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+ import logging
6
+
7
+ # Set up logging
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class RMSNorm(nn.Module):
11
+ """Root Mean Square Layer Normalization."""
12
+
13
+ def __init__(self, dim: int, eps: float = 1e-6):
14
+ super().__init__()
15
+ self.eps = eps
16
+ self.weight = nn.Parameter(torch.ones(dim))
17
+
18
+ def forward(self, x):
19
+ # Reciprocal RMS (rsqrt of the mean square)
20
+ rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
21
+ return self.weight * rms * x
22
+
23
+
24
+ class RotaryEmbedding(nn.Module):
25
+ """Rotary positional embedding."""
26
+
27
+ def __init__(self, dim, max_seq_len=2048, base=10000):
28
+ super().__init__()
29
+ self.dim = dim
30
+ self.max_seq_len = max_seq_len
31
+ self.base = base
32
+
33
+ # Generate frequency tensor
34
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
35
+ self.register_buffer("inv_freq", inv_freq)
36
+
37
+ # Generate cos and sin cache
38
+ self._update_cos_sin_cache(max_seq_len)
39
+
40
+ def _update_cos_sin_cache(self, max_seq_len):
41
+ """Update the cache of cos and sin values."""
42
+ self.max_seq_len = max_seq_len
43
+ t = torch.arange(max_seq_len, device=self.inv_freq.device)
44
+
45
+ # Compute cos and sin at each position
46
+ freqs = torch.einsum('i,j->ij', t, self.inv_freq)
47
+ emb = torch.cat((freqs, freqs), dim=-1)  # duplicate so cos/sin span the full head_dim (matches rotate_half)
+ cos = emb.cos()
48
+ sin = emb.sin()
49
+
50
+ self.register_buffer("cos_cache", cos, persistent=False)
51
+ self.register_buffer("sin_cache", sin, persistent=False)
52
+
53
+ def forward(self, x, seq_len=None, pos=None):
54
+ # Get appropriate parts of the cache
55
+ if pos is not None:
56
+ # Handle arbitrary positions
57
+ cos = self.cos_cache[pos]
58
+ sin = self.sin_cache[pos]
59
+ else:
60
+ # Handle sequential positions
61
+ seq_len = x.shape[1] if seq_len is None else seq_len
62
+ cos = self.cos_cache[:seq_len]
63
+ sin = self.sin_cache[:seq_len]
64
+
65
+ return cos, sin
66
+
67
+
68
+ def rotate_half(x):
69
+ """Rotate half the dimensions of the input."""
70
+ x1, x2 = x.chunk(2, dim=-1)
71
+ return torch.cat((-x2, x1), dim=-1)
72
+
73
+
74
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
75
+ """Apply rotary position embedding to q and k."""
76
+ if position_ids is not None:
77
+ # Handle arbitrary positions
78
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
79
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
80
+ else:
81
+ # Handle sequential positions
82
+ cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, dim]
83
+ sin = sin.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, dim]
84
+
85
+ # Apply rotation
86
+ q_embed = (q * cos) + (rotate_half(q) * sin)
87
+ k_embed = (k * cos) + (rotate_half(k) * sin)
88
+
89
+ return q_embed, k_embed
90
+
91
+
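For each position p and frequency \theta_i = \text{base}^{-2i/d}, rotary embedding rotates a pair of feature dimensions by the angle p\theta_i:

\begin{pmatrix} x'_a \\ x'_b \end{pmatrix} =
\begin{pmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{pmatrix}
\begin{pmatrix} x_a \\ x_b \end{pmatrix}

The rotate_half convention above pairs dimension j with dimension j + d/2 rather than adjacent dimensions, which is why the cos/sin cache must span the full head dimension.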
92
+ class CustomAttention(nn.Module):
93
+ """Multi-head attention with support for KV caching."""
94
+
95
+ def __init__(self, dim, num_heads, num_kv_heads=None, dropout=0.0):
96
+ super().__init__()
97
+ self.dim = dim
98
+ self.num_heads = num_heads
99
+ self.num_kv_heads = num_kv_heads or num_heads
100
+ self.head_dim = dim // num_heads
101
+ self.scale = self.head_dim ** -0.5
102
+
103
+ # Attention projections
104
+ self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
105
+ self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
106
+ self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
107
+ self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)
108
+
109
+ # Rotary embedding
110
+ self.rope = RotaryEmbedding(self.head_dim)
111
+
112
+ # Dropout
113
+ self.dropout = nn.Dropout(dropout)
114
+
115
+ def _repeat_kv(self, x):
116
+ """Repeat KV heads to match the number of query heads."""
117
+ if self.num_kv_heads == self.num_heads:
118
+ return x
119
+
120
+ b, n_kv_head, s, head_dim = x.shape  # k/v arrive as [batch, kv_heads, seq, head_dim]
121
+
122
+ # Repeat the KV heads to match the number of query heads
123
+ repeats = self.num_heads // self.num_kv_heads
124
+ x = x.repeat_interleave(repeats, dim=1)  # repeat along the KV-head dimension
125
+
126
+ return x
127
+
128
+ def forward(self, x, mask=None, input_pos=None, kv_cache=None):
129
+ batch_size, seq_len, _ = x.shape
130
+
131
+ # Project to q, k, v
132
+ q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # [b, nh, s, hd]
133
+ k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # [b, nkh, s, hd]
134
+ v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2) # [b, nkh, s, hd]
135
+
136
+ # Apply rotary embeddings
137
+ cos, sin = self.rope.forward(x, seq_len=seq_len, pos=input_pos)
138
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=input_pos)
139
+
140
+ # Handle KV cache
141
+ if kv_cache is not None:
142
+ k_cache, v_cache = kv_cache
143
+
144
+ if input_pos is not None:
145
+ # Update cache at specific positions
146
+ k_cache.index_copy_(2, input_pos, k)
147
+ v_cache.index_copy_(2, input_pos, v)
148
+
149
+ # Use the entire cache
150
+ k, v = k_cache, v_cache
151
+
152
+ # Repeat KV if needed
153
+ k = self._repeat_kv(k)
154
+ v = self._repeat_kv(v)
155
+
156
+ # Calculate attention scores
157
+ attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
158
+
159
+ # Apply mask if provided
160
+ if mask is not None:
161
+ attention_scores = attention_scores.masked_fill(mask == 0, -10000.0)
162
+
163
+ # Apply softmax and dropout
164
+ attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1))
165
+
166
+ # Get context vector
167
+ context = torch.matmul(attention_probs, v)
168
+
169
+ # Reshape and project back
170
+ context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
171
+ output = self.o_proj(context)
172
+
173
+ return output
174
+
175
+
176
+ class FeedForward(nn.Module):
177
+ """Feed-forward network with GELU activation."""
178
+
179
+ def __init__(self, dim, hidden_dim, dropout=0.0):
180
+ super().__init__()
181
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
182
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
183
+ self.dropout = nn.Dropout(dropout)
184
+ self.act = nn.GELU()
185
+
186
+ def forward(self, x):
187
+ x = self.w1(x)
188
+ x = self.act(x)
189
+ x = self.dropout(x)
190
+ x = self.w2(x)
191
+ return x
192
+
193
+
194
+ class TransformerLayer(nn.Module):
195
+ """A single transformer layer."""
196
+
197
+ def __init__(
198
+ self,
199
+ dim,
200
+ num_heads,
201
+ num_kv_heads=None,
202
+ ffn_dim=None,
203
+ dropout=0.0,
204
+ norm_eps=1e-5
205
+ ):
206
+ super().__init__()
207
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
208
+ self.attn = CustomAttention(dim, num_heads, num_kv_heads, dropout)
209
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
210
+ self.ffn = FeedForward(
211
+ dim,
212
+ ffn_dim or 4 * dim,
213
+ dropout
214
+ )
215
+
216
+ def forward(self, x, mask=None, input_pos=None, kv_cache=None):
217
+ # Self-attention with residual
218
+ h = self.norm1(x)
219
+ h = self.attn(h, mask=mask, input_pos=input_pos, kv_cache=kv_cache)
220
+ x = x + h
221
+
222
+ # FFN with residual
223
+ h = self.norm2(x)
224
+ h = self.ffn(h)
225
+ x = x + h
226
+
227
+ return x
228
+
229
+
230
+ class CustomTransformerDecoder(nn.Module):
231
+ """Custom transformer decoder that mimics Llama architecture."""
232
+
233
+ def __init__(
234
+ self,
235
+ vocab_size,
236
+ num_layers,
237
+ num_heads,
238
+ num_kv_heads,
239
+ embed_dim,
240
+ max_seq_len,
241
+ intermediate_dim,
242
+ attn_dropout=0.0,
243
+ norm_eps=1e-5,
244
+ rope_base=10000,
245
+ ):
246
+ super().__init__()
247
+ self.vocab_size = vocab_size
248
+ self.max_seq_len = max_seq_len
249
+ self.embed_dim = embed_dim
250
+
251
+ # Token embeddings
252
+ self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
253
+
254
+ # Transformer layers
255
+ self.layers = nn.ModuleList([
256
+ TransformerLayer(
257
+ embed_dim,
258
+ num_heads,
259
+ num_kv_heads,
260
+ intermediate_dim,
261
+ attn_dropout,
262
+ norm_eps
263
+ )
264
+ for _ in range(num_layers)
265
+ ])
266
+
267
+ # Final normalization and output projection
268
+ self.norm = RMSNorm(embed_dim, eps=norm_eps)
269
+ self.output = nn.Linear(embed_dim, vocab_size, bias=False)
270
+
271
+ # Initialize the KV cache
272
+ self._kv_cache = None
273
+ self._has_cache = False
274
+
275
+ logger.info(f"Initialized CustomTransformerDecoder with {num_layers} layers, {num_heads} heads, {embed_dim} dim")
276
+
277
+ def setup_caches(self, batch_size, dtype, decoder_max_seq_len=None):
278
+ """Set up KV caches for inference."""
279
+ max_seq_len = decoder_max_seq_len or self.max_seq_len
280
+ device = next(self.parameters()).device
281
+
282
+ self._kv_cache = []
283
+ for i, layer in enumerate(self.layers):
284
+ # Create a KV cache for each layer
285
+ k_cache = torch.zeros(
286
+ batch_size,
287
+ layer.attn.num_kv_heads,
288
+ max_seq_len,
289
+ layer.attn.head_dim,
290
+ device=device,
291
+ dtype=dtype
292
+ )
293
+ v_cache = torch.zeros(
294
+ batch_size,
295
+ layer.attn.num_kv_heads,
296
+ max_seq_len,
297
+ layer.attn.head_dim,
298
+ device=device,
299
+ dtype=dtype
300
+ )
301
+ self._kv_cache.append((k_cache, v_cache))
302
+
303
+ self._has_cache = True
304
+ logger.info(f"KV caches set up for {batch_size} batches, {max_seq_len} seq length")
305
+
306
+ def caches_are_enabled(self):
307
+ """Check if caches are enabled."""
308
+ return self._has_cache
309
+
310
+ def reset_caches(self):
311
+ """Reset the KV cache to zeros."""
312
+ if self._has_cache and self._kv_cache:
313
+ for k_cache, v_cache in self._kv_cache:
314
+ k_cache.zero_()
315
+ v_cache.zero_()
316
+
317
+ def forward(self, x, mask=None, input_pos=None):
318
+ batch_size, seq_len = x.shape[:2]
319
+
320
+ # Apply embedding if input is token IDs
321
+ if x.dim() == 2:
322
+ x = self.tok_embeddings(x)
323
+
324
+ # Apply transformer layers
325
+ for i, layer in enumerate(self.layers):
326
+ layer_cache = self._kv_cache[i] if self._has_cache else None
327
+ x = layer(x, mask=mask, input_pos=input_pos, kv_cache=layer_cache)
328
+
329
+ # Apply final norm
330
+ x = self.norm(x)
331
+
332
+ # Skip output projection if using Identity
333
+ if isinstance(self.output, nn.Identity):
334
+ return x
335
+
336
+ # Apply output projection
337
+ logits = self.output(x)
338
+
339
+ return logits
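A small smoke-test sketch for the fallback decoder; the hyperparameters are arbitrary, and num_kv_heads equals num_heads so the KV-head repetition path stays out of the picture:

import torch
from app.custom_transformer import CustomTransformerDecoder

model = CustomTransformerDecoder(
    vocab_size=1024, num_layers=2, num_heads=4, num_kv_heads=4,
    embed_dim=128, max_seq_len=64, intermediate_dim=256,
)
tokens = torch.randint(0, 1024, (1, 16))  # [batch, seq_len] of token IDs
logits = model(tokens)                    # expected shape: [1, 16, 1024]
print(logits.shape)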
app/download_model.py ADDED
@@ -0,0 +1,28 @@
1
+ """Download CSM-1B model from Hugging Face."""
2
+
3
+ import os
4
+ import argparse
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ def download_model(output_dir="models"):
8
+ """Download CSM-1B model from Hugging Face."""
9
+ print("Downloading CSM-1B model...")
10
+ os.makedirs(output_dir, exist_ok=True)
11
+
12
+ # Download model
13
+ model_path = hf_hub_download(
14
+ repo_id="sesame/csm-1b",
15
+ filename="ckpt.pt",
16
+ local_dir=output_dir,
17
+ local_dir_use_symlinks=False
18
+ )
19
+
20
+ print(f"Model downloaded to {model_path}")
21
+ return model_path
22
+
23
+ if __name__ == "__main__":
24
+ parser = argparse.ArgumentParser(description="Download CSM-1B model")
25
+ parser.add_argument("--output", type=str, default="models", help="Output directory")
26
+ args = parser.parse_args()
27
+
28
+ download_model(args.output)
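Typical invocations, with the module path assumed from the file header above:

# From the command line:
#   python app/download_model.py --output models
# Or programmatically:
from app.download_model import download_model
ckpt_path = download_model("models")  # downloads ckpt.pt into ./models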
app/generator.py ADDED
@@ -0,0 +1,834 @@
1
+ # Updated generator.py with proper function order
2
+ from dataclasses import dataclass
3
+ from typing import List, Tuple
4
+ import torch
5
+ import torchaudio
6
+ import logging
7
+ import os
8
+ from huggingface_hub import hf_hub_download
9
+ from transformers import AutoTokenizer
10
+ from tokenizers.processors import TemplateProcessing
11
+ from app.models import Segment
12
+ from app.text_normalizer import clean_text_for_tts
13
+ from app.text_normalizer import TextNormalizer
14
+
15
+
16
+ # Set up logging
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Import the CSM watermarking code
20
+ try:
21
+ from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
22
+ except ImportError:
23
+ # Define stubs for watermarking if the module is not available
24
+ CSM_1B_GH_WATERMARK = "CSM1B"
25
+ def load_watermarker(device="cpu"):
26
+ return None
27
+ def watermark(watermarker, audio, sample_rate, key):
28
+ return audio, sample_rate
29
+
30
+ def load_llama3_tokenizer():
31
+ """
32
+ Load tokenizer for Llama 3.2, using unsloth's open version
33
+ instead of the gated meta-llama version.
34
+ """
35
+ try:
36
+ # Use the unsloth version which is not gated
37
+ tokenizer_name = "unsloth/Llama-3.2-1B"
38
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
39
+ bos = tokenizer.bos_token
40
+ eos = tokenizer.eos_token
41
+ tokenizer._tokenizer.post_processor = TemplateProcessing(
42
+ single=f"{bos}:0 $A:0 {eos}:0",
43
+ pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
44
+ special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
45
+ )
46
+ logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
47
+ return tokenizer
48
+ except Exception as e:
49
+ logger.error(f"Error loading tokenizer from unsloth: {e}")
50
+ # Fallback to a simpler tokenizer if needed
51
+ try:
52
+ from transformers import GPT2Tokenizer
53
+ logger.warning("Falling back to GPT2Tokenizer")
54
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
55
+ tokenizer.pad_token = tokenizer.eos_token
56
+ return tokenizer
57
+ except Exception as fallback_e:
58
+ logger.error(f"Fallback tokenizer also failed: {fallback_e}")
59
+ raise RuntimeError("Could not load any suitable tokenizer")
60
+
61
+ class Generator:
62
+ """Generator class for CSM-1B model."""
63
+ def __init__(self, model):
64
+ """Initialize generator with model."""
65
+ self._model = model
66
+ self._model.setup_caches(1)
67
+ self._text_tokenizer = load_llama3_tokenizer()
68
+ device = next(model.parameters()).device
69
+ # Load Mimi codec for audio tokenization
70
+ try:
71
+ logger.info("Loading Mimi audio codec...")
72
+ from huggingface_hub import hf_hub_download
73
+ # First try to import from moshi
74
+ try:
75
+ from moshi.models import loaders
76
+ DEFAULT_REPO = loaders.DEFAULT_REPO
77
+ MIMI_NAME = loaders.MIMI_NAME
78
+ get_mimi = loaders.get_mimi
79
+ except ImportError:
80
+ logger.warning("moshi.models.loaders not found, using fallback")
81
+ # Fallback values if moshi.models.loaders is not available
82
+ DEFAULT_REPO = "kyutai/mimi"
83
+ MIMI_NAME = "mimi-december.pt"
84
+ # Fallback function to load mimi
85
+ def get_mimi(checkpoint_path, device):
86
+ from moshi.models.vqvae_model import MiMiModule
87
+ checkpoint = torch.load(checkpoint_path, map_location=device)
88
+ model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
89
+ return model
90
+ mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
91
+ mimi = get_mimi(mimi_weight, device=device)
92
+ mimi.set_num_codebooks(32)
93
+ self._audio_tokenizer = mimi
94
+ self.sample_rate = mimi.sample_rate
95
+ logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
96
+ except Exception as e:
97
+ logger.error(f"Error loading Mimi codec: {e}")
98
+ self._audio_tokenizer = None
99
+ self.sample_rate = 24000 # Default sample rate
100
+ logger.warning(f"Using fallback sample rate: {self.sample_rate}")
101
+ raise RuntimeError(f"Failed to load Mimi codec: {e}")
102
+ try:
103
+ self._watermarker = load_watermarker(device=device)
104
+ logger.info("Watermarker loaded successfully")
105
+ except Exception as e:
106
+ logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
107
+ self._watermarker = None
108
+
109
+ self.device = device
110
+ # Optimize for CUDA throughput
111
+ if torch.cuda.is_available():
112
+ torch.backends.cudnn.benchmark = True
113
+ torch.cuda.empty_cache()
114
+ logger.info("CUDA optimizations enabled")
115
+
116
+ def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
117
+ """Tokenize a text segment."""
118
+ frame_tokens = []
119
+ frame_masks = []
120
+ # Strip any voice instructions in square brackets to avoid them being read out
121
+ text = self._clean_text_input(text)
122
+ text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
123
+ text_frame = torch.zeros(len(text_tokens), 33).long()
124
+ text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
125
+ text_frame[:, -1] = torch.tensor(text_tokens)
126
+ text_frame_mask[:, -1] = True
127
+ frame_tokens.append(text_frame.to(self.device))
128
+ frame_masks.append(text_frame_mask.to(self.device))
129
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
130
+
131
+ def _clean_text_input(self, text: str) -> str:
132
+ """Clean and normalize text for TTS."""
133
+ return clean_text_for_tts(text)
134
+
135
+ def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
136
+ """Tokenize audio."""
137
+ if self._audio_tokenizer is None:
138
+ raise RuntimeError("Audio tokenizer not initialized")
139
+ frame_tokens = []
140
+ frame_masks = []
141
+ # (K, T)
142
+ audio = audio.to(self.device)
143
+ audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
144
+ # add EOS frame
145
+ eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
146
+ audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
147
+ audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
148
+ audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
149
+ audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
150
+ audio_frame_mask[:, :-1] = True
151
+ frame_tokens.append(audio_frame)
152
+ frame_masks.append(audio_frame_mask)
153
+ return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
154
+
155
+ def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
156
+ """Tokenize a segment of text and audio."""
157
+ text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
158
+ audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
159
+ return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
160
+
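For reference, a small illustrative sketch (hypothetical token IDs) of the 33-column frame layout implied by the indexing above: columns 0–31 carry the Mimi audio codebook tokens and the last column carries the text token for that frame.

    import torch

    num_text_tokens = 5
    text_frame = torch.zeros(num_text_tokens, 33).long()
    text_frame_mask = torch.zeros(num_text_tokens, 33).bool()
    # Text tokens occupy only the last column of each frame.
    text_frame[:, -1] = torch.tensor([101, 102, 103, 104, 105])
    text_frame_mask[:, -1] = True
    print(text_frame.shape)  # torch.Size([5, 33])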
161
+ def generate_quick(
162
+ self,
163
+ text: str,
164
+ speaker: int,
165
+ context: List[Segment],
166
+ max_audio_length_ms: float = 2000, # Short for quick generation
167
+ temperature: float = 0.7, # Lower for more predictable output
168
+ topk: int = 20, # Lower top-k for faster sampling
169
+ ) -> torch.Tensor:
170
+ """Generate audio quickly for real-time streaming."""
171
+ # Similar to generate() but optimized for speed
172
+ self._model.reset_caches()
173
+
174
+ # Convert max_audio_length_ms to frames - limit for faster generation
175
+ max_audio_frames = min(int(max_audio_length_ms / 80), 128) # Smaller limit
176
+
177
+ # Process text
178
+ cleaned_text = clean_text_for_tts(text)
179
+
180
+ # Prepare tokens
181
+ tokens, tokens_mask = [], []
182
+ # Add context segments (limited to 1 for speed)
183
+ if context:
184
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
185
+ tokens.append(segment_tokens)
186
+ tokens_mask.append(segment_tokens_mask)
187
+ # Add text tokens
188
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
189
+ tokens.append(gen_segment_tokens)
190
+ tokens_mask.append(gen_segment_tokens_mask)
191
+
192
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
193
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
194
+
195
+ # Generate with larger batch size for fewer iterations
196
+ curr_tokens = prompt_tokens.unsqueeze(0)
197
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
198
+ curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
199
+
200
+ # Use larger batch size
201
+ batch_size = 64 # Generate more frames at once
202
+ all_samples = []
203
+ for start_idx in range(0, max_audio_frames, batch_size):
204
+ end_idx = min(start_idx + batch_size, max_audio_frames)
205
+ batch_frames = end_idx - start_idx
206
+ samples_batch = []
207
+ for i in range(batch_frames):
208
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
209
+ samples_batch.append(sample)
210
+ if torch.all(sample == 0):
211
+ break
212
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
213
+ curr_tokens_mask = torch.cat(
214
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
215
+ ).unsqueeze(1)
216
+ curr_pos = curr_pos[:, -1:] + 1
217
+ all_samples.extend(samples_batch)
218
+ if len(samples_batch) < batch_frames:
219
+ break
220
+
221
+ if not all_samples:
222
+ return torch.zeros(10, device=self.device) # Return short empty audio
223
+
224
+ # Decode audio
225
+ audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
226
+ return audio
227
+
228
+ @torch.inference_mode()
229
+ def generate(
230
+ self,
231
+ text: str,
232
+ speaker: int,
233
+ context: List[Segment],
234
+ max_audio_length_ms: float = 90_000,
235
+ temperature: float = 0.9,
236
+ topk: int = 50,
237
+ ) -> torch.Tensor:
238
+ """Generate audio from text."""
239
+ if self._audio_tokenizer is None:
240
+ raise RuntimeError("Audio tokenizer not initialized")
241
+
242
+ # Start timing
243
+ start_time = torch.cuda.Event(enable_timing=True)
244
+ end_time = torch.cuda.Event(enable_timing=True)
245
+ start_time.record()
246
+
247
+ self._model.reset_caches()
248
+
249
+ # Convert max_audio_length_ms to frames - this controls the maximum generation length
250
+ max_audio_frames = min(int(max_audio_length_ms / 80), 1024) # Limit to reasonable size
251
+ max_seq_len = 2048 - max_audio_frames
252
+
253
+ # Check if text is long and should be split
254
+ if len(text) > 200:
255
+ logger.info(f"Long text detected ({len(text)} chars), processing in segments")
256
+ sentences = TextNormalizer.split_into_sentences(text)
257
+ logger.info(f"Split into {len(sentences)} segments")
258
+
259
+ # Process sentences individually and concatenate the results
260
+ all_audio_segments = []
261
+
262
+ # Use the first sentence to establish voice
263
+ first_sentence = sentences[0]
264
+ cleaned_text = clean_text_for_tts(first_sentence)
265
+
266
+ # Generate the first segment
267
+ tokens, tokens_mask = [], []
268
+
269
+ # Add context segments for the first sentence
270
+ for segment in context:
271
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
272
+ tokens.append(segment_tokens)
273
+ tokens_mask.append(segment_tokens_mask)
274
+
275
+ # Add first sentence tokens
276
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
277
+ tokens.append(gen_segment_tokens)
278
+ tokens_mask.append(gen_segment_tokens_mask)
279
+
280
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
281
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
282
+
283
+ # Check context size and truncate if needed
284
+ if prompt_tokens.size(0) >= max_seq_len:
285
+ logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
286
+ prompt_tokens = prompt_tokens[-max_seq_len+50:]
287
+ prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
288
+
289
+ # Generate first sentence audio
290
+ curr_tokens = prompt_tokens.unsqueeze(0)
291
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
292
+ curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
293
+
294
+ # Generate first segment
295
+ first_segment_samples = []
296
+ for start_idx in range(0, max_audio_frames, 32):
297
+ end_idx = min(start_idx + 32, max_audio_frames)
298
+ batch_frames = end_idx - start_idx
299
+ samples_batch = []
300
+
301
+ for i in range(batch_frames):
302
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
303
+ samples_batch.append(sample)
304
+
305
+ if torch.all(sample == 0):
306
+ break
307
+
308
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
309
+ curr_tokens_mask = torch.cat(
310
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
311
+ ).unsqueeze(1)
312
+ curr_pos = curr_pos[:, -1:] + 1
313
+
314
+ first_segment_samples.extend(samples_batch)
315
+
316
+ if len(samples_batch) < batch_frames:
317
+ break
318
+
319
+ if not first_segment_samples:
320
+ raise RuntimeError("No audio generated for first segment")
321
+
322
+ # Decode first segment
323
+ first_segment_audio = self._audio_tokenizer.decode(
324
+ torch.stack(first_segment_samples).permute(1, 2, 0)
325
+ ).squeeze(0).squeeze(0)
326
+
327
+ all_audio_segments.append(first_segment_audio)
328
+
329
+ # Now process remaining sentences using the first as context
330
+ for i, sentence in enumerate(sentences[1:], 1):
331
+ logger.info(f"Generating segment {i+1}/{len(sentences)}")
332
+ cleaned_text = clean_text_for_tts(sentence)
333
+
334
+ # Create a context segment from the previous generation
335
+ prev_segment = Segment(
336
+ speaker=speaker,
337
+ text=sentences[i-1],
338
+ audio=all_audio_segments[-1]
339
+ )
340
+
341
+ # Generate with this segment as context
342
+ segment_tokens, segment_tokens_mask = [], []
343
+ prev_tokens, prev_tokens_mask = self._tokenize_segment(prev_segment)
+ segment_tokens.append(prev_tokens)
+ segment_tokens_mask.append(prev_tokens_mask)
345
+
346
+ # Add current segment tokens
347
+ current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
348
+ segment_tokens.append(current_tokens)
349
+ segment_tokens_mask.append(current_tokens_mask)
350
+
351
+ segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device)
352
+ segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device)
353
+
354
+ # Check length and truncate if needed
355
+ if segment_prompt_tokens.size(0) >= max_seq_len:
356
+ segment_prompt_tokens = segment_prompt_tokens[-max_seq_len+50:]
357
+ segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len+50:]
358
+
359
+ # Generate audio for this segment
360
+ curr_tokens = segment_prompt_tokens.unsqueeze(0)
361
+ curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0)
362
+ curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
363
+
364
+ # Generate segment
365
+ segment_samples = []
366
+ for start_idx in range(0, max_audio_frames, 32):
367
+ end_idx = min(start_idx + 32, max_audio_frames)
368
+ batch_frames = end_idx - start_idx
369
+ samples_batch = []
370
+
371
+ for _ in range(batch_frames):  # use _ to avoid shadowing the sentence index i
372
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
373
+ samples_batch.append(sample)
374
+
375
+ if torch.all(sample == 0):
376
+ break
377
+
378
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
379
+ curr_tokens_mask = torch.cat(
380
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
381
+ ).unsqueeze(1)
382
+ curr_pos = curr_pos[:, -1:] + 1
383
+
384
+ segment_samples.extend(samples_batch)
385
+
386
+ if len(samples_batch) < batch_frames:
387
+ break
388
+
389
+ if not segment_samples:
390
+ logger.warning(f"No audio generated for segment {i+1}")
391
+ continue
392
+
393
+ # Decode segment
394
+ segment_audio = self._audio_tokenizer.decode(
395
+ torch.stack(segment_samples).permute(1, 2, 0)
396
+ ).squeeze(0).squeeze(0)
397
+
398
+ all_audio_segments.append(segment_audio)
399
+
400
+ # Combine all segments with small pauses
401
+ pause_samples = int(0.3 * self.sample_rate) # 300ms pause
402
+ pause = torch.zeros(pause_samples, device=self.device)
403
+
404
+ audio_parts = []
405
+ for i, segment_audio in enumerate(all_audio_segments):
406
+ audio_parts.append(segment_audio)
407
+ if i < len(all_audio_segments) - 1:
408
+ audio_parts.append(pause)
409
+
410
+ audio = torch.cat(audio_parts)
411
+ logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
412
+
413
+ else:
414
+ # For shorter text, standard processing
415
+ tokens, tokens_mask = [], []
416
+
417
+ # Add context segments
418
+ for segment in context:
419
+ segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
420
+ tokens.append(segment_tokens)
421
+ tokens_mask.append(segment_tokens_mask)
422
+
423
+ # Process text
424
+ cleaned_text = clean_text_for_tts(text)
425
+ gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
426
+ tokens.append(gen_segment_tokens)
427
+ tokens_mask.append(gen_segment_tokens_mask)
428
+
429
+ prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
430
+ prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
431
+
432
+ # Check context size
433
+ if prompt_tokens.size(0) >= max_seq_len:
434
+ logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
435
+ prompt_tokens = prompt_tokens[-max_seq_len+50:]
436
+ prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
437
+
438
+ # Generate audio - optimized batch generation
439
+ curr_tokens = prompt_tokens.unsqueeze(0)
440
+ curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
441
+ curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
442
+
443
+ # Using optimized batch generation
444
+ batch_size = 32 # Generate this many frames at once
445
+ all_samples = []
446
+
447
+ for start_idx in range(0, max_audio_frames, batch_size):
448
+ end_idx = min(start_idx + batch_size, max_audio_frames)
449
+ batch_frames = end_idx - start_idx
450
+
451
+ samples_batch = []
452
+
453
+ for i in range(batch_frames):
454
+ sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
455
+ samples_batch.append(sample)
456
+
457
+ if torch.all(sample == 0):
458
+ break
459
+
460
+ curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
461
+ curr_tokens_mask = torch.cat(
462
+ [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
463
+ ).unsqueeze(1)
464
+ curr_pos = curr_pos[:, -1:] + 1
465
+
466
+ all_samples.extend(samples_batch)
467
+
468
+ if len(samples_batch) < batch_frames:
469
+ logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
470
+ break
471
+
472
+ if not all_samples:
473
+ raise RuntimeError("No audio generated - model produced empty output")
474
+
475
+ # Decode audio
476
+ audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
477
+
478
+ # Apply watermark
479
+ if self._watermarker is not None:
480
+ try:
481
+ audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
482
+ audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
483
+ except Exception as e:
484
+ logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
485
+
486
+ # Record execution time
487
+ end_time.record()
488
+ torch.cuda.synchronize()
489
+ execution_ms = start_time.elapsed_time(end_time)
490
+ audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
491
+
492
+ # Calculate real-time factor (RTF)
493
+ rtf = execution_ms / audio_length_ms
494
+ logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
495
+
496
+ return audio
497
+
498
+ # Define helper functions for multi-GPU support
499
+ def _manual_device_map(model, state_dict, strategy="balanced"):
500
+ """Apply manual device mapping for multi-GPU setups.
501
+
502
+ Args:
503
+ model: The model to map
504
+ state_dict: Model state dict
505
+ strategy: Mapping strategy ('balanced', 'sequential')
506
+
507
+ Returns:
508
+ Model with weights distributed across GPUs
509
+ """
510
+ num_gpus = torch.cuda.device_count()
511
+ if num_gpus <= 1:
512
+ # No need for mapping with single GPU
513
+ model.load_state_dict(state_dict)
514
+ model = model.to("cuda")
515
+ return model
516
+
517
+ logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
518
+
519
+ # Get all layer names from state dict
520
+ layer_names = [name for name in state_dict.keys() if "layers" in name]
521
+ backbone_layers = [name for name in layer_names if "backbone.layers" in name]
522
+ decoder_layers = [name for name in layer_names if "decoder.layers" in name]
523
+
524
+ # Count number of backbone and decoder layers
525
+ backbone_layer_indices = set()
526
+ for name in backbone_layers:
527
+ parts = name.split('.')
528
+ if len(parts) > 2:
529
+ try:
530
+ backbone_layer_indices.add(int(parts[2]))
531
+ except ValueError:
532
+ pass
533
+
534
+ decoder_layer_indices = set()
535
+ for name in decoder_layers:
536
+ parts = name.split('.')
537
+ if len(parts) > 2:
538
+ try:
539
+ decoder_layer_indices.add(int(parts[2]))
540
+ except ValueError:
541
+ pass
542
+
543
+ num_backbone_layers = len(backbone_layer_indices)
544
+ num_decoder_layers = len(decoder_layer_indices)
545
+
546
+ # Create device map
547
+ device_map = {}
548
+
549
+ if strategy == "balanced":
550
+ # Distribute layers evenly across GPUs
551
+ layers_per_gpu = (num_backbone_layers + num_decoder_layers) // num_gpus
552
+ remainder = (num_backbone_layers + num_decoder_layers) % num_gpus
553
+
554
+ # Assign backbone layers
555
+ for i in backbone_layer_indices:
556
+ gpu_idx = min(i // layers_per_gpu, num_gpus - 1)
557
+ device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
558
+
559
+ # Assign decoder layers
560
+ for i in decoder_layer_indices:
561
+ gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)
562
+ device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
563
+
564
+ elif strategy == "sequential":
565
+ # Fill each GPU sequentially
566
+ # Backbone layers on first GPU(s)
567
+ backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2))
568
+ for i in backbone_layer_indices:
569
+ gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1)
570
+ device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
571
+
572
+ # Decoder layers on remaining GPU(s)
573
+ decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1))
574
+ for i in decoder_layer_indices:
575
+ gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1)
576
+ device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
577
+
578
+ # Assign embeddings and other components
579
+ device_map["text_embeddings"] = "cuda:0"
580
+ device_map["audio_embeddings"] = "cuda:0"
581
+ device_map["projection"] = "cuda:0"
582
+ device_map["codebook0_head"] = "cuda:0"
583
+ device_map["audio_head"] = "cuda:0"
584
+
585
+ # Load state dict with device mapping
586
+ model.load_state_dict(state_dict)
587
+
588
+ # Move model parts to assigned devices
589
+ for name, device in device_map.items():
590
+ if "backbone.layers" in name:
591
+ layer_idx = int(name.split('.')[-1])
592
+ if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
593
+ model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
594
+ elif "decoder.layers" in name:
595
+ layer_idx = int(name.split('.')[-1])
596
+ if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
597
+ model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
598
+ elif hasattr(model, name):
599
+ setattr(model, name, getattr(model, name).to(device))
600
+
601
+ logger.info(f"Model distributed across GPUs with {strategy} strategy")
602
+ return model
603
+
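To make the "balanced" strategy above concrete, here is a hedged illustration for a hypothetical 2-GPU setup with 16 backbone and 4 decoder layers (the real layer counts come from the checkpoint):

    num_gpus = 2
    num_backbone_layers, num_decoder_layers = 16, 4
    layers_per_gpu = (num_backbone_layers + num_decoder_layers) // num_gpus  # 10

    device_map = {}
    for i in range(num_backbone_layers):
        device_map[f"backbone.layers.{i}"] = f"cuda:{min(i // layers_per_gpu, num_gpus - 1)}"
    for i in range(num_decoder_layers):
        device_map[f"decoder.layers.{i}"] = f"cuda:{min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)}"

    # backbone.layers.0-9 land on cuda:0, backbone.layers.10-15 and all decoder
    # layers land on cuda:1; embeddings and output heads stay on cuda:0.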
604
+ def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: str = None) -> Generator:
605
+ """Load CSM-1B model and create generator with performance optimizations.
606
+
607
+ Args:
608
+ ckpt_path: Path to model checkpoint
609
+ device: Device to load model on ('cuda', 'cpu', or specific CUDA device)
610
+ device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)
611
+
612
+ Returns:
613
+ Generator instance with optimized settings
614
+ """
615
+ try:
616
+ # Import models module for CSM
617
+ from app.torchtune_models import Model, ModelArgs
618
+
619
+ # Create model
620
+ model_args = ModelArgs(
621
+ backbone_flavor="llama-1B",
622
+ decoder_flavor="llama-100M",
623
+ text_vocab_size=128256,
624
+ audio_vocab_size=2051,
625
+ audio_num_codebooks=32,
626
+ )
627
+
628
+ # Load model
629
+ logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
630
+
631
+ # Check for CUDA availability
632
+ cuda_available = device == "cuda" and torch.cuda.is_available()
633
+
634
+ # Set up torch for optimized inference
635
+ if cuda_available:
636
+ # Check if we should enable TF32 (faster but slightly less precise)
637
+ enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
638
+ if enable_tf32:
639
+ logger.info("Enabling TF32 for faster matrix multiplications")
640
+ torch.backends.cuda.matmul.allow_tf32 = True
641
+ torch.backends.cudnn.allow_tf32 = True
642
+
643
+ # Check for available precision modes
644
+ use_bfloat16 = torch.cuda.is_bf16_supported()
645
+ use_float16 = not use_bfloat16 and torch.cuda.is_available() # Fallback to float16
646
+
647
+ if use_bfloat16:
648
+ dtype = torch.bfloat16
649
+ logger.info("Using bfloat16 precision for faster inference")
650
+ elif use_float16:
651
+ dtype = torch.float16
652
+ logger.info("Using float16 precision for faster inference")
653
+ else:
654
+ dtype = torch.float32
655
+ logger.info("Using float32 precision (mixed precision not available)")
656
+
657
+ # Enable Flash Attention if available
658
+ try:
659
+ import flash_attn
660
+ if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
661
+ logger.info("Flash Attention detected - enabling for faster attention")
662
+ os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
663
+ except ImportError:
664
+ logger.info("Flash Attention not available (install flash-attn for faster inference)")
665
+ else:
666
+ # CPU-only mode
667
+ dtype = torch.float32
668
+ logger.info("Using CPU mode with float32 precision")
669
+
670
+ # Check for quantization
671
+ enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
672
+ is_quantized = False
673
+
674
+ # Check for multi-GPU setup
675
+ if device_map and torch.cuda.device_count() > 1:
676
+ logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
677
+
678
+ # Create model with device map
679
+ model = Model(model_args)
680
+
681
+ # Load state dict
682
+ state_dict = torch.load(ckpt_path, map_location='cpu')
683
+
684
+ # Try quantization before device mapping if enabled
685
+ if enable_quantization and cuda_available:
686
+ try:
687
+ from bitsandbytes.nn import Linear8bitLt
688
+
689
+ def replace_with_8bit(model):
690
+ """Replace linear layers with 8-bit quantized versions"""
691
+ for name, module in model.named_modules():
692
+ if isinstance(module, torch.nn.Linear) and module.out_features > 256:
693
+ parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
694
+ parent = model
695
+ if parent_name:
696
+ for attr in parent_name.split('.'):
697
+ parent = getattr(parent, attr)
698
+ child_name = name.rsplit('.', 1)[1] if '.' in name else name
699
+ setattr(parent, child_name, Linear8bitLt.from_float(module))
700
+ return model
701
+
702
+ logger.info("Applying 8-bit quantization to linear layers")
703
+ model = replace_with_8bit(model)
704
+ is_quantized = True
705
+ except ImportError:
706
+ logger.warning("bitsandbytes not available, skipping quantization")
707
+
708
+ # Apply device mapping
709
+ if device_map == "auto":
710
+ # Use accelerate for automatic device mapping
711
+ try:
712
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
713
+
714
+ # Initialize empty model
715
+ with init_empty_weights():
716
+ empty_model = Model(model_args)
717
+
718
+ # Load and dispatch model across GPUs
719
+ model = load_checkpoint_and_dispatch(
720
+ empty_model,
721
+ ckpt_path,
722
+ device_map="auto",
723
+ no_split_module_classes=["TransformerLayer"],
724
+ # Offload to CPU if the model is very large
725
+ offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
726
+ )
727
+ logger.info("Model loaded with automatic device mapping")
728
+ except ImportError:
729
+ logger.warning("accelerate package not found, falling back to manual device mapping")
730
+ model = _manual_device_map(model, state_dict, "balanced")
731
+ except Exception as mapping_error:
732
+ logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
733
+ model = _manual_device_map(model, state_dict, "balanced")
734
+ else:
735
+ # Manual device mapping
736
+ model = _manual_device_map(model, state_dict, device_map or "balanced")
737
+ else:
738
+ # Single GPU or CPU setup
739
+
740
+ # Try quantization before loading if enabled (GPU only)
741
+ if enable_quantization and cuda_available and not is_quantized:
742
+ try:
743
+ # First load to CPU for quantization
744
+ model = Model(model_args).to("cpu")
745
+ state_dict = torch.load(ckpt_path, map_location="cpu")
746
+ model.load_state_dict(state_dict)
747
+
748
+ from bitsandbytes.nn import Linear8bitLt
749
+
750
+ def replace_with_8bit(model):
751
+ """Replace linear layers with 8-bit quantized versions"""
752
+ for name, module in model.named_modules():
753
+ if isinstance(module, torch.nn.Linear) and module.out_features > 256:
754
+ parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
755
+ parent = model
756
+ if parent_name:
757
+ for attr in parent_name.split('.'):
758
+ parent = getattr(parent, attr)
759
+ child_name = name.rsplit('.', 1)[1] if '.' in name else name
760
+ setattr(parent, child_name, Linear8bitLt.from_float(module))
761
+ return model
762
+
763
+ logger.info("Applying 8-bit quantization to linear layers")
764
+ model = replace_with_8bit(model)
765
+ model = model.to(device=device)
766
+ is_quantized = True
767
+ except ImportError:
768
+ logger.warning("bitsandbytes not available, loading without quantization")
769
+ # Load the standard way
770
+ model = Model(model_args).to(device=device, dtype=dtype)
771
+ state_dict = torch.load(ckpt_path, map_location=device)
772
+ model.load_state_dict(state_dict)
773
+ except Exception as quant_error:
774
+ logger.error(f"Quantization failed: {quant_error}, loading without quantization")
775
+ # Load the standard way
776
+ model = Model(model_args).to(device=device, dtype=dtype)
777
+ state_dict = torch.load(ckpt_path, map_location=device)
778
+ model.load_state_dict(state_dict)
779
+ else:
780
+ # Standard load without quantization
781
+ model = Model(model_args).to(device=device, dtype=dtype)
782
+ state_dict = torch.load(ckpt_path, map_location=device)
783
+ model.load_state_dict(state_dict)
784
+
785
+ # Apply torch.compile if available (PyTorch 2.0+)
786
+ compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
787
+ if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
788
+ try:
789
+ logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
790
+ if compile_mode == "default":
791
+ model = torch.compile(model)
792
+ else:
793
+ model = torch.compile(model, mode=compile_mode)
794
+ except Exception as compile_error:
795
+ logger.warning(f"Torch compile failed (requires PyTorch 2.0+): {compile_error}")
796
+
797
+ # Try to optimize CUDA graphs for faster inference (advanced)
798
+ use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
799
+ if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
800
+ try:
801
+ logger.info("Setting up CUDA graphs for repeated inference patterns")
802
+ # This requires custom integration inside the model's forward method
803
+ # Just flagging that CUDA graphs should be used
804
+ model.use_cuda_graphs = True
805
+ except Exception as cuda_graph_error:
806
+ logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
807
+ model.use_cuda_graphs = False
808
+
809
+ # Set optimal settings for CUDA context
810
+ if cuda_available:
811
+ # Set benchmark mode for hardware-specific optimizations
812
+ torch.backends.cudnn.benchmark = True
813
+ # Clean up CUDA cache before creating generator
814
+ torch.cuda.empty_cache()
815
+ # Ensure all CUDA work is completed to avoid launch delays
816
+ torch.cuda.synchronize()
817
+
818
+ # Create generator
819
+ logger.info("Creating generator with optimized settings")
820
+ generator = Generator(model)
821
+
822
+ # Log memory usage if on CUDA
823
+ if cuda_available:
824
+ memory_allocated = torch.cuda.memory_allocated() / (1024**3)
825
+ memory_reserved = torch.cuda.memory_reserved() / (1024**3)
826
+ logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")
827
+
828
+ logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
829
+ return generator
830
+ except Exception as e:
831
+ logger.error(f"Failed to load CSM-1B model: {e}")
832
+ import traceback
833
+ logger.error(traceback.format_exc())
834
+ raise
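A hedged end-to-end usage sketch of the loader and generator defined above (the checkpoint path is illustrative; a CUDA device is assumed since generation timing uses CUDA events):

    import torchaudio
    from app.generator import load_csm_1b

    # Illustrative path; point this at wherever ckpt.pt was downloaded.
    generator = load_csm_1b(ckpt_path="models/ckpt.pt", device="cuda")
    audio = generator.generate(text="Hello from CSM.", speaker=0, context=[])
    torchaudio.save("hello.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)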
app/main.py ADDED
@@ -0,0 +1,647 @@
1
+ """
2
+ CSM-1B TTS API main application.
3
+ Provides an OpenAI-compatible API for the CSM-1B text-to-speech model.
4
+ """
5
+ import os
6
+ import time
7
+ import tempfile
8
+ import logging
9
+ from logging.handlers import RotatingFileHandler
10
+ import traceback
11
+ import asyncio
12
+ import glob
13
+ import torch
14
+ import uvicorn
15
+ from contextlib import asynccontextmanager
16
+ from fastapi import FastAPI, Depends, HTTPException, Request, Response
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import RedirectResponse, FileResponse
19
+ from fastapi.staticfiles import StaticFiles
20
+ from app.api.routes import router as api_router
21
+
22
+ # Setup logging
23
+ os.makedirs("logs", exist_ok=True)
24
+ log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
25
+
26
+ # Console handler
27
+ console_handler = logging.StreamHandler()
28
+ console_handler.setFormatter(logging.Formatter(log_format))
29
+
30
+ # File handler
31
+ file_handler = RotatingFileHandler(
32
+ "logs/csm_tts_api.log",
33
+ maxBytes=10*1024*1024, # 10MB
34
+ backupCount=5
35
+ )
36
+ file_handler.setFormatter(logging.Formatter(log_format))
37
+
38
+ # Configure root logger
39
+ logging.basicConfig(
40
+ level=logging.INFO,
41
+ format=log_format,
42
+ handlers=[console_handler, file_handler]
43
+ )
44
+ logger = logging.getLogger(__name__)
45
+ logger.info("Starting CSM-1B TTS API")
46
+
47
+ @asynccontextmanager
48
+ async def lifespan(app: FastAPI):
49
+ """Application lifespan manager for startup and shutdown events."""
50
+ # STARTUP EVENT
51
+ logger.info("Starting application initialization")
52
+ app.state.startup_time = time.time()
53
+ app.state.generator = None # Will be populated later if model loads
54
+ app.state.logger = logger # Make logger available to routes
55
+
56
+ # Create necessary directories - use persistent locations
57
+ os.makedirs("/app/models", exist_ok=True)
58
+ os.makedirs("/app/tokenizers", exist_ok=True)
59
+ os.makedirs("/app/voice_memories", exist_ok=True)
60
+ os.makedirs("/app/voice_references", exist_ok=True)
61
+ os.makedirs("/app/voice_profiles", exist_ok=True)
62
+ os.makedirs("/app/cloned_voices", exist_ok=True)
63
+ os.makedirs("/app/audio_cache", exist_ok=True)
64
+ os.makedirs("/app/static", exist_ok=True)
65
+
66
+ # Set tokenizer cache
67
+ try:
68
+ os.environ["TRANSFORMERS_CACHE"] = "/app/tokenizers"
69
+ logger.info(f"Set tokenizer cache to: {os.environ['TRANSFORMERS_CACHE']}")
70
+ except Exception as e:
71
+ logger.error(f"Error setting tokenizer cache: {e}")
72
+
73
+ # Install additional dependencies if needed
74
+ try:
75
+ import scipy
76
+ import soundfile
77
+ logger.info("Audio processing dependencies available")
78
+ except ImportError as e:
79
+ logger.warning(f"Audio processing dependency missing: {e}. Some audio enhancements may not work.")
80
+ logger.warning("Consider installing: pip install scipy soundfile")
81
+
82
+ # Check CUDA availability
83
+ cuda_available = torch.cuda.is_available()
84
+ if cuda_available:
85
+ device_count = torch.cuda.device_count()
86
+ device_name = torch.cuda.get_device_name(0) if device_count > 0 else "unknown"
87
+ logger.info(f"CUDA is available: {device_count} device(s). Using {device_name}")
88
+ # Report CUDA memory
89
+ if hasattr(torch.cuda, 'get_device_properties'):
90
+ total_memory = torch.cuda.get_device_properties(0).total_memory
91
+ logger.info(f"Total CUDA memory: {total_memory / (1024**3):.2f} GB")
92
+ else:
93
+ logger.warning("CUDA is not available. Using CPU (this will be slow)")
94
+
95
+ # Determine device and device mapping
96
+ device = "cuda" if cuda_available else "cpu"
97
+ device_map = os.environ.get("CSM_DEVICE_MAP", None) # Options: "auto", "balanced", "sequential"
98
+ if device_map and cuda_available:
99
+ if torch.cuda.device_count() > 1:
100
+ logger.info(f"Using device mapping strategy: {device_map} across {torch.cuda.device_count()} GPUs")
101
+ else:
102
+ logger.info("Device mapping requested but only one GPU available, ignoring device_map")
103
+ device_map = None
104
+ else:
105
+ device_map = None
106
+
107
+ logger.info(f"Using device: {device}")
108
+ app.state.device = device
109
+ app.state.device_map = device_map
110
+
111
+ # Check if model file exists
112
+ model_path = os.path.join("/app/models", "ckpt.pt")
113
+ if not os.path.exists(model_path):
114
+ # Try to download at runtime if not present
115
+ logger.info("Model not found. Attempting to download...")
116
+ try:
117
+ from huggingface_hub import hf_hub_download, login
118
+ # Check for token in environment
119
+ hf_token = os.environ.get("HF_TOKEN")
120
+ if hf_token:
121
+ logger.info("Logging in to Hugging Face using provided token")
122
+ login(token=hf_token)
123
+ logger.info("Downloading CSM-1B model from Hugging Face...")
124
+ download_start = time.time()
125
+ model_path = hf_hub_download(
126
+ repo_id="sesame/csm-1b",
127
+ filename="ckpt.pt",
128
+ local_dir="/app/models"
129
+ )
130
+ download_time = time.time() - download_start
131
+ logger.info(f"Model downloaded to {model_path} in {download_time:.2f} seconds")
132
+ except Exception as e:
133
+ error_stack = traceback.format_exc()
134
+ logger.error(f"Error downloading model: {str(e)}\n{error_stack}")
135
+ logger.error("Please build the image with HF_TOKEN to download the model")
136
+ logger.error("Starting without model - API will return 503 Service Unavailable")
137
+ else:
138
+ logger.info(f"Found existing model at {model_path}")
139
+ logger.info(f"Model size: {os.path.getsize(model_path) / (1024 * 1024):.2f} MB")
140
+
141
+ # Load the model
142
+ try:
143
+ logger.info("Loading CSM-1B model...")
144
+ load_start = time.time()
145
+ from app.generator import load_csm_1b
146
+ app.state.generator = load_csm_1b(model_path, device, device_map)
147
+ load_time = time.time() - load_start
148
+ logger.info(f"Model loaded successfully in {load_time:.2f} seconds")
149
+
150
+ # Store sample rate in app state
151
+ app.state.sample_rate = app.state.generator.sample_rate
152
+ logger.info(f"Model sample rate: {app.state.sample_rate} Hz")
153
+
154
+ # Initialize voice enhancement system (this will create proper voice profiles)
155
+ logger.info("Initializing voice enhancement system...")
156
+ try:
157
+ from app.voice_enhancement import initialize_voice_profiles, save_voice_profiles
158
+ initialize_voice_profiles()
159
+ app.state.voice_enhancement_enabled = True
160
+ logger.info("Voice profiles initialized successfully")
161
+ except Exception as e:
162
+ error_stack = traceback.format_exc()
163
+ logger.error(f"Error initializing voice profiles: {str(e)}\n{error_stack}")
164
+ logger.warning("Voice enhancement features will be limited")
165
+ app.state.voice_enhancement_enabled = False
166
+
167
+ # Initialize voice memory system for consistent generation
168
+ logger.info("Initializing voice memory system...")
169
+ try:
170
+ from app.voice_memory import initialize_voices
171
+ initialize_voices(app.state.sample_rate)
172
+ app.state.voice_memory_enabled = True
173
+ logger.info("Voice memory system initialized")
174
+ except Exception as e:
175
+ logger.warning(f"Error initializing voice memory: {e}")
176
+ app.state.voice_memory_enabled = False
177
+
178
+ # Initialize voice cloning system
179
+ try:
180
+ logger.info("Initializing voice cloning system...")
181
+ from app.voice_cloning import VoiceCloner, CLONED_VOICES_DIR
182
+ # Update the cloned voices directory to use the persistent volume
183
+ app.state.cloned_voices_dir = "/app/cloned_voices" # Store path in app state for access
184
+ os.makedirs(app.state.cloned_voices_dir, exist_ok=True)
185
+ from app import voice_cloning as voice_cloning_module
+ voice_cloning_module.CLONED_VOICES_DIR = app.state.cloned_voices_dir # Update the module-level constant, not just the local name
186
+
187
+ # Initialize the voice cloner with proper device
188
+ app.state.voice_cloner = VoiceCloner(app.state.generator, device=device)
189
+
190
+ # Make sure existing voices are loaded
191
+ app.state.voice_cloner._load_existing_voices()
192
+
193
+ # Log the available voices
194
+ cloned_voices = app.state.voice_cloner.list_voices()
195
+ logger.info(f"Voice cloning system initialized with {len(cloned_voices)} existing voices")
196
+ for voice in cloned_voices:
197
+ logger.info(f" - {voice.name} (ID: {voice.id}, Speaker ID: {voice.speaker_id})")
198
+
199
+ # Flag for voice cloning availability
200
+ app.state.voice_cloning_enabled = True
201
+ except Exception as e:
202
+ error_stack = traceback.format_exc()
203
+ logger.error(f"Error initializing voice cloning: {e}\n{error_stack}")
204
+ logger.warning("Voice cloning features will not be available")
205
+ app.state.voice_cloning_enabled = False
206
+
207
+ # Create prompt templates for consistent generation
208
+ logger.info("Setting up prompt engineering templates...")
209
+ try:
210
+ from app.prompt_engineering import initialize_templates
211
+ app.state.prompt_templates = initialize_templates()
212
+ logger.info("Prompt templates initialized")
213
+ except Exception as e:
214
+ error_stack = traceback.format_exc()
215
+ logger.error(f"Error initializing prompt templates: {e}\n{error_stack}")
216
+ logger.warning("Voice consistency features will be limited")
217
+
218
+ # Generate voice reference samples (runs in background to avoid blocking startup)
219
+ async def generate_samples_async():
220
+ try:
221
+ logger.info("Starting voice reference generation (background task)...")
222
+ from app.voice_enhancement import create_voice_segments
223
+ create_voice_segments(app.state)
224
+ logger.info("Voice reference generation completed")
225
+ except Exception as e:
226
+ error_stack = traceback.format_exc()
227
+ logger.error(f"Error in voice reference generation: {str(e)}\n{error_stack}")
228
+
229
+ # Start as a background task
230
+ asyncio.create_task(generate_samples_async())
231
+
232
+ # Initialize voice cache for all voices (standard + cloned)
233
+ app.state.voice_cache = {}
234
+
235
+ # Add standard voices
236
+ standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
237
+ for voice in standard_voices:
238
+ app.state.voice_cache[voice] = []
239
+
240
+ # Add cloned voices to cache if they exist
241
+ if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
242
+ for voice in app.state.voice_cloner.list_voices():
243
+ app.state.voice_cache[voice.id] = []
244
+ # Also add by name for more flexible lookup
245
+ app.state.voice_cache[voice.name] = []
246
+
247
+ # Create mapping from voice name/id to speaker_id for easy lookup
248
+ app.state.voice_speaker_map = {
249
+ "alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5
250
+ }
251
+
252
+ # Add cloned voices to the speaker map
253
+ if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
254
+ for voice in app.state.voice_cloner.list_voices():
255
+ app.state.voice_speaker_map[voice.id] = voice.speaker_id
256
+ app.state.voice_speaker_map[voice.name] = voice.speaker_id
257
+ app.state.voice_speaker_map[str(voice.speaker_id)] = voice.speaker_id
258
+
259
+ # Compile voice information for API
260
+ app.state.available_voices = standard_voices.copy()
261
+ if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
262
+ for voice in app.state.voice_cloner.list_voices():
263
+ app.state.available_voices.append(voice.id)
264
+ app.state.available_voices.append(voice.name)
265
+
266
+ # Store model information for API endpoints
267
+ app.state.model_info = {
268
+ "name": "CSM-1B",
269
+ "device": device,
270
+ "device_map": device_map,
271
+ "sample_rate": app.state.sample_rate,
272
+ "standard_voices": standard_voices,
273
+ "cloned_voices": [v.id for v in app.state.voice_cloner.list_voices()] if app.state.voice_cloning_enabled else [],
274
+ "voice_enhancement_enabled": app.state.voice_enhancement_enabled,
275
+ "voice_memory_enabled": app.state.voice_memory_enabled,
276
+ "voice_cloning_enabled": app.state.voice_cloning_enabled,
277
+ "streaming_enabled": True
278
+ }
279
+
280
+ # Create a function to access all voices in a standardized format
281
+ def get_all_available_voices():
282
+ """Helper function to get all available voices for API endpoints"""
283
+ # Standard voices with fixed descriptions
284
+ all_voices = [
285
+ {"voice_id": "alloy", "name": "Alloy", "description": "Balanced and natural"},
286
+ {"voice_id": "echo", "name": "Echo", "description": "Resonant and deeper"},
287
+ {"voice_id": "fable", "name": "Fable", "description": "Bright and higher-pitched"},
288
+ {"voice_id": "onyx", "name": "Onyx", "description": "Deep and authoritative"},
289
+ {"voice_id": "nova", "name": "Nova", "description": "Warm and smooth"},
290
+ {"voice_id": "shimmer", "name": "Shimmer", "description": "Light and airy"}
291
+ ]
292
+
293
+ # Add cloned voices if available
294
+ if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
295
+ for voice in app.state.voice_cloner.list_voices():
296
+ all_voices.append({
297
+ "voice_id": voice.id,
298
+ "name": voice.name,
299
+ "description": voice.description or f"Cloned voice: {voice.name}"
300
+ })
301
+
302
+ return all_voices
303
+
304
+ app.state.get_all_voices = get_all_available_voices
305
+
306
+ # Add helper function to lookup voice info
307
+ def get_voice_info(voice_identifier):
308
+ """Look up voice information based on name, ID, or speaker_id"""
309
+ # Check standard voices
310
+ if voice_identifier in standard_voices:
311
+ return {
312
+ "type": "standard",
313
+ "voice_id": voice_identifier,
314
+ "name": voice_identifier,
315
+ "speaker_id": standard_voices.index(voice_identifier)
316
+ }
317
+
318
+ # Look for cloned voice
319
+ if not app.state.voice_cloning_enabled or not hasattr(app.state, "voice_cloner"):
320
+ return None
321
+
322
+ # Check by ID
323
+ if voice_identifier in app.state.voice_cloner.cloned_voices:
324
+ voice = app.state.voice_cloner.cloned_voices[voice_identifier]
325
+ return {
326
+ "type": "cloned",
327
+ "voice_id": voice.id,
328
+ "name": voice.name,
329
+ "speaker_id": voice.speaker_id
330
+ }
331
+
332
+ # Check by name
333
+ for v_id, voice in app.state.voice_cloner.cloned_voices.items():
334
+ if voice.name == voice_identifier:
335
+ return {
336
+ "type": "cloned",
337
+ "voice_id": voice.id,
338
+ "name": voice.name,
339
+ "speaker_id": voice.speaker_id
340
+ }
341
+
342
+ # Check by speaker_id (string representation)
343
+ try:
344
+ speaker_id = int(voice_identifier)
345
+ # Check if any cloned voice has this speaker_id
346
+ for v_id, voice in app.state.voice_cloner.cloned_voices.items():
347
+ if voice.speaker_id == speaker_id:
348
+ return {
349
+ "type": "cloned",
350
+ "voice_id": voice.id,
351
+ "name": voice.name,
352
+ "speaker_id": speaker_id
353
+ }
354
+ except (ValueError, TypeError):
355
+ pass
356
+
357
+ # No match found
358
+ return None
359
+
360
+ app.state.get_voice_info = get_voice_info
361
+
362
+ # Set up audio cache
363
+ app.state.audio_cache_enabled = os.environ.get("ENABLE_AUDIO_CACHE", "true").lower() == "true"
364
+ if app.state.audio_cache_enabled:
365
+ app.state.audio_cache_dir = "/app/audio_cache"
366
+ logger.info(f"Audio cache enabled, cache dir: {app.state.audio_cache_dir}")
367
+
368
+ # Log GPU utilization after model loading
369
+ if cuda_available:
370
+ memory_allocated = torch.cuda.memory_allocated() / (1024**3)
371
+ memory_reserved = torch.cuda.memory_reserved() / (1024**3)
372
+ logger.info(f"GPU memory: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")
373
+
374
+ if torch.cuda.device_count() > 1 and device_map:
375
+ logger.info("Multi-GPU setup active with the following memory usage:")
376
+ for i in range(torch.cuda.device_count()):
377
+ memory_allocated = torch.cuda.memory_allocated(i) / (1024**3)
378
+ memory_reserved = torch.cuda.memory_reserved(i) / (1024**3)
379
+ logger.info(f"GPU {i}: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")
380
+
381
+ # Set up scheduled tasks
382
+ try:
383
+ # Create a background task for periodic voice profile backup
384
+ async def periodic_voice_profile_backup(interval_hours=6):
385
+ """Periodically save voice profiles to persistent storage."""
386
+ while True:
387
+ try:
388
+ # Wait for the specified interval
389
+ await asyncio.sleep(interval_hours * 3600)
390
+
391
+ # Log the backup
392
+ timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
393
+ logger.info(f"Scheduled voice profile backup started at {timestamp}")
394
+
395
+ # Save voice profiles
396
+ if hasattr(app.state, "voice_enhancement_enabled") and app.state.voice_enhancement_enabled:
397
+ from app.voice_enhancement import save_voice_profiles
398
+ save_voice_profiles()
399
+ logger.info("Voice profiles saved successfully")
400
+
401
+ # Save voice memories
402
+ if hasattr(app.state, "voice_memory_enabled") and app.state.voice_memory_enabled:
403
+ from app.voice_memory import VOICE_MEMORIES
404
+ for voice_name, memory in VOICE_MEMORIES.items():
405
+ memory.save()
406
+ logger.info("Voice memories saved successfully")
407
+
408
+ except Exception as e:
409
+ logger.error(f"Error in periodic voice profile backup: {e}")
410
+
411
+ # Start the scheduled task
412
+ asyncio.create_task(periodic_voice_profile_backup(interval_hours=6))
413
+ logger.info("Started scheduled voice profile backup task")
414
+
415
+ except Exception as e:
416
+ logger.warning(f"Failed to set up scheduled tasks: {e}")
417
+
418
+ logger.info(f"CSM-1B TTS API is ready on {device} with sample rate {app.state.sample_rate}")
419
+ logger.info(f"Standard voices: {standard_voices}")
420
+ cloned_count = len(app.state.voice_cloner.list_voices()) if app.state.voice_cloning_enabled else 0
421
+ logger.info(f"Cloned voices: {cloned_count}")
422
+
423
+ except Exception as e:
424
+ error_stack = traceback.format_exc()
425
+ logger.error(f"Error loading model: {str(e)}\n{error_stack}")
426
+ app.state.generator = None
427
+
428
+ # Calculate total startup time
429
+ startup_time = time.time() - app.state.startup_time
430
+ logger.info(f"Application startup completed in {startup_time:.2f} seconds")
431
+
432
+ yield # This is where the application runs
433
+
434
+ # SHUTDOWN EVENT
435
+ logger.info("Application shutdown initiated")
436
+
437
+ # Clean up model resources
438
+ if hasattr(app.state, "generator") and app.state.generator is not None:
439
+ try:
440
+ # Clean up CUDA memory if available
441
+ if torch.cuda.is_available():
442
+ logger.info("Clearing CUDA cache")
443
+ torch.cuda.empty_cache()
444
+ torch.cuda.synchronize()
445
+ except Exception as e:
446
+ logger.error(f"Error during CUDA cleanup: {e}")
447
+
448
+ # Save voice profiles if they've been updated
449
+ try:
450
+ if hasattr(app.state, "voice_enhancement_enabled") and app.state.voice_enhancement_enabled:
451
+ from app.voice_enhancement import save_voice_profiles
452
+ logger.info("Saving voice profiles...")
453
+ save_voice_profiles()
454
+ logger.info("Voice profiles saved successfully")
455
+ except Exception as e:
456
+ logger.error(f"Error saving voice profiles: {e}")
457
+
458
+ # Save voice memories if they've been updated
459
+ try:
460
+ if hasattr(app.state, "voice_memory_enabled") and app.state.voice_memory_enabled:
461
+ from app.voice_memory import VOICE_MEMORIES
462
+ logger.info("Saving voice memories...")
463
+ for voice_name, memory in VOICE_MEMORIES.items():
464
+ memory.save()
465
+ logger.info("Voice memories saved successfully")
466
+ except Exception as e:
467
+ logger.error(f"Error saving voice memories: {e}")
468
+
469
+ # Clean up any temporary files
470
+ try:
471
+ for temp_file in glob.glob(os.path.join(tempfile.gettempdir(), "csm_tts_*")):
472
+ try:
473
+ os.remove(temp_file)
474
+ logger.info(f"Removed temporary file: {temp_file}")
475
+ except Exception:
476
+ pass
477
+ except Exception as e:
478
+ logger.warning(f"Error cleaning up temporary files: {e}")
479
+
480
+ logger.info("Application shutdown complete")
481
+
482
+ # Initialize FastAPI app
483
+ app = FastAPI(
484
+ title="CSM-1B TTS API",
485
+ description="OpenAI-compatible TTS API using the CSM-1B model from Sesame",
486
+ version="1.0.0",
487
+ lifespan=lifespan
488
+ )
489
+
490
+ # Add CORS middleware
491
+ app.add_middleware(
492
+ CORSMiddleware,
493
+ allow_origins=["*"],
494
+ allow_credentials=True,
495
+ allow_methods=["*"],
496
+ allow_headers=["*"],
497
+ )
498
+
499
+ # Create static and other required directories
500
+ os.makedirs("/app/static", exist_ok=True)
501
+ os.makedirs("/app/cloned_voices", exist_ok=True)
502
+
503
+ # Mount the static files directory
504
+ app.mount("/static", StaticFiles(directory="/app/static"), name="static")
505
+
506
+ # Include routers
507
+ app.include_router(api_router, prefix="/api/v1")
508
+
509
+ # Add OpenAI compatible route
510
+ app.include_router(api_router, prefix="/v1")
511
+
512
+ # Add voice cloning routes
513
+ from app.api.voice_cloning_routes import router as voice_cloning_router
514
+ app.include_router(voice_cloning_router, prefix="/api/v1")
515
+ app.include_router(voice_cloning_router, prefix="/v1")
516
+
517
+ # Add streaming routes
518
+ from app.api.streaming import router as streaming_router
519
+ app.include_router(streaming_router, prefix="/api/v1")
520
+ app.include_router(streaming_router, prefix="/v1")
521
+
522
+ # Middleware for request timing
523
+ @app.middleware("http")
524
+ async def add_process_time_header(request: Request, call_next):
525
+ """Middleware to track request processing time."""
526
+ start_time = time.time()
527
+ response = await call_next(request)
528
+ process_time = time.time() - start_time
529
+ response.headers["X-Process-Time"] = str(process_time)
530
+ logger.debug(f"Request to {request.url.path} processed in {process_time:.3f} seconds")
531
+ return response
532
+
533
+ # Health check endpoint
534
+ @app.get("/health", include_in_schema=False)
535
+ async def health_check(request: Request):
536
+ """Health check endpoint that returns the status of the API."""
537
+ model_status = "healthy" if hasattr(request.app.state, "generator") and request.app.state.generator is not None else "unhealthy"
538
+ uptime = time.time() - getattr(request.app.state, "startup_time", time.time())
539
+
540
+ # Get voice information
541
+ standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
542
+ cloned_voices = []
543
+
544
+ if hasattr(request.app.state, "voice_cloner") and request.app.state.voice_cloner is not None:
545
+ cloned_voices = [
546
+ {"id": v.id, "name": v.name, "speaker_id": v.speaker_id}
547
+ for v in request.app.state.voice_cloner.list_voices()
548
+ ]
549
+
550
+ # Get CUDA memory stats if available
551
+ cuda_stats = None
552
+ if torch.cuda.is_available():
553
+ cuda_stats = {
554
+ "allocated_gb": torch.cuda.memory_allocated() / (1024**3),
555
+ "reserved_gb": torch.cuda.memory_reserved() / (1024**3)
556
+ }
557
+
558
+ return {
559
+ "status": model_status,
560
+ "uptime": f"{uptime:.2f} seconds",
561
+ "device": getattr(request.app.state, "device", "unknown"),
562
+ "model": "CSM-1B",
563
+ "standard_voices": standard_voices,
564
+ "cloned_voices": cloned_voices,
565
+ "cloned_voices_count": len(cloned_voices),
566
+ "sample_rate": getattr(request.app.state, "sample_rate", 0),
567
+ "enhancements": "enabled" if hasattr(request.app.state, "model_info") and
568
+ request.app.state.model_info.get("voice_enhancement_enabled", False) else "disabled",
569
+ "streaming": "enabled",
570
+ "cuda": cuda_stats,
571
+ "version": "1.0.0"
572
+ }
573
+
574
+ # Version endpoint
575
+ @app.get("/version", include_in_schema=False)
576
+ async def version():
577
+ """Version endpoint that returns API version information."""
578
+ return {
579
+ "api_version": "1.0.0",
580
+ "model_version": "CSM-1B",
581
+ "compatible_with": "OpenAI TTS v1",
582
+ "enhancements": "voice consistency and audio quality v1.0",
583
+ "voice_cloning": "enabled" if hasattr(app.state, "voice_cloner") else "disabled",
584
+ "streaming": "enabled"
585
+ }
586
+
587
+ # Voice cloning UI endpoint
588
+ @app.get("/voice-cloning", include_in_schema=False)
589
+ async def voice_cloning_ui():
590
+ """Voice cloning UI endpoint."""
591
+ return FileResponse("/app/static/voice-cloning.html")
592
+
593
+ # Streaming demo endpoint
594
+ @app.get("/streaming-demo", include_in_schema=False)
595
+ async def streaming_demo():
596
+ """Streaming TTS demo endpoint."""
597
+ return FileResponse("/app/static/streaming-demo.html")
598
+
599
+ @app.get("/", include_in_schema=False)
600
+ async def root():
601
+ """Root endpoint that redirects to docs."""
602
+ logger.debug("Root endpoint accessed, redirecting to docs")
603
+ return RedirectResponse(url="/docs")
604
+
605
+ if __name__ == "__main__":
606
+ # Get port from environment or use default
607
+ port = int(os.environ.get("PORT", 8000))
608
+
609
+ # Development mode flag
610
+ dev_mode = os.environ.get("DEV_MODE", "false").lower() == "true"
611
+
612
+ # Log level (default to INFO, but can be overridden)
613
+ log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
614
+ logging.getLogger().setLevel(log_level)
615
+
616
+ # Check for audio enhancement and voice cloning flags
617
+ enable_enhancements = os.environ.get("ENABLE_ENHANCEMENTS", "true").lower() == "true"
618
+ enable_voice_cloning = os.environ.get("ENABLE_VOICE_CLONING", "true").lower() == "true"
619
+
620
+ if not enable_enhancements:
621
+ logger.warning("Voice enhancements disabled by environment variable")
622
+ if not enable_voice_cloning:
623
+ logger.warning("Voice cloning disabled by environment variable")
624
+
625
+ logger.info(f"Voice enhancements: {'enabled' if enable_enhancements else 'disabled'}")
626
+ logger.info(f"Voice cloning: {'enabled' if enable_voice_cloning else 'disabled'}")
627
+ logger.info(f"Streaming: enabled")
628
+ logger.info(f"Log level: {log_level}")
629
+
630
+ if dev_mode:
631
+ logger.info(f"Running in development mode with auto-reload enabled on port {port}")
632
+ uvicorn.run(
633
+ "app.main:app",
634
+ host="0.0.0.0",
635
+ port=port,
636
+ reload=True,
637
+ log_level=log_level.lower()
638
+ )
639
+ else:
640
+ logger.info(f"Running in production mode on port {port}")
641
+ uvicorn.run(
642
+ "app.main:app",
643
+ host="0.0.0.0",
644
+ port=port,
645
+ reload=False,
646
+ log_level=log_level.lower()
647
+ )
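
The two unauthenticated endpoints defined above (`/health` and `/version`) make a convenient smoke test once the service is up. The sketch below is illustrative only and assumes the default `PORT` of 8000, no reverse proxy, and the `requests` package being available.

import requests

BASE_URL = "http://localhost:8000"  # assumption: default PORT, service reachable locally

def check_service() -> None:
    # /health reports model status, device, voices, and CUDA memory (see handler above)
    health = requests.get(f"{BASE_URL}/health", timeout=10).json()
    print(f"status: {health['status']}, device: {health['device']}, "
          f"cloned voices: {health['cloned_voices_count']}")
    # /version reports API/model version and feature flags
    version = requests.get(f"{BASE_URL}/version", timeout=10).json()
    print(f"api {version['api_version']} / model {version['model_version']}")

if __name__ == "__main__":
    check_service()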
app/models.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Tuple
3
+
4
+ import torch
5
+
6
+
7
+ @dataclass
8
+ class Segment:
9
+ """A segment of speech with text, speaker, and audio."""
10
+ speaker: int
11
+ text: str
12
+ # (num_samples,), sample_rate = 24_000
13
+ audio: torch.Tensor
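
Since `Segment` is a plain dataclass, constructing one for tests only needs a mono tensor at the 24 kHz rate noted in the comment. A minimal, illustrative example (speaker 0 is an arbitrary choice here):

import torch
from app.models import Segment

SAMPLE_RATE = 24_000
silence = torch.zeros(SAMPLE_RATE)  # one second of mono audio, shape (num_samples,)

segment = Segment(speaker=0, text="Hello world.", audio=silence)
print(segment.speaker, segment.text, segment.audio.shape)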
app/prompt_engineering.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prompt engineering for consistent voice generation."""
2
+ import re
3
+ import random
4
+ from typing import List, Dict, Optional
5
+ import logging
6
+
7
+ # Set up logging
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # Voice style descriptors for consistent prompting
11
+ VOICE_STYLES = {
12
+ "alloy": {
13
+ "adjectives": ["balanced", "natural", "clear", "articulate", "neutral", "conversational"],
14
+ "characteristics": ["medium pitch", "even pacing", "neutral tone", "balanced resonance"],
15
+ "speaking_style": "conversational and balanced"
16
+ },
17
+ "echo": {
18
+ "adjectives": ["resonant", "deep", "reverberant", "rich", "sonorous", "full"],
19
+ "characteristics": ["lower pitch", "deliberate pacing", "resonant tone", "deeper timbre"],
20
+ "speaking_style": "rich and resonant"
21
+ },
22
+ "fable": {
23
+ "adjectives": ["bright", "light", "clear", "energetic", "articulate", "animated"],
24
+ "characteristics": ["higher pitch", "lively pacing", "bright tone", "clear articulation"],
25
+ "speaking_style": "bright and energetic"
26
+ },
27
+ "onyx": {
28
+ "adjectives": ["deep", "authoritative", "powerful", "commanding", "strong", "resolute"],
29
+ "characteristics": ["low pitch", "measured pacing", "authoritative tone", "strong projection"],
30
+ "speaking_style": "deep and authoritative"
31
+ },
32
+ "nova": {
33
+ "adjectives": ["warm", "pleasant", "smooth", "harmonious", "gentle", "comforting"],
34
+ "characteristics": ["medium pitch", "smooth pacing", "warm tone", "pleasant timbre"],
35
+ "speaking_style": "warm and smooth"
36
+ },
37
+ "shimmer": {
38
+ "adjectives": ["light", "airy", "bright", "crystalline", "delicate", "expressive"],
39
+ "characteristics": ["higher pitch", "quick pacing", "light tone", "bright timbre"],
40
+ "speaking_style": "light and expressive"
41
+ },
42
+ "custom": {
43
+ "adjectives": ["clear", "distinct", "authentic", "natural", "personalized", "unique"],
44
+ "characteristics": ["natural rhythm", "authentic tone", "personal inflection", "distinctive sound"],
45
+ "speaking_style": "authentic and natural"
46
+ }
47
+ }
48
+
49
+ def initialize_templates():
50
+ """Initialize prompt templates - placeholder for any future setup."""
51
+ logger.info("Prompt templates initialized")
52
+ return VOICE_STYLES
53
+
54
+ def split_into_segments(text: str, max_chars: int = 150) -> List[str]:
55
+ """Split text into optimal segments for better generation.
56
+ Args:
57
+ text: Text to split
58
+ max_chars: Maximum characters per segment
59
+ Returns:
60
+ List of text segments
61
+ """
62
+ # Handle empty or very short text
63
+ if not text or len(text) <= max_chars:
64
+ return [text]
65
+
66
+ # Split by sentences first
67
+ sentences = re.split(r'(?<=[.!?])\s+', text)
68
+
69
+ # Initialize segments
70
+ segments = []
71
+ current_segment = ""
72
+
73
+ for sentence in sentences:
74
+ # If adding this sentence would exceed max_chars
75
+ if len(current_segment) + len(sentence) > max_chars:
76
+ # If current segment is not empty, add it to segments
77
+ if current_segment:
78
+ segments.append(current_segment.strip())
79
+ current_segment = ""
80
+
81
+ # If this sentence alone exceeds max_chars, split it by phrases
82
+ if len(sentence) > max_chars:
83
+ phrases = re.split(r'(?<=[,;:])\s+', sentence)
84
+ for phrase in phrases:
85
+ if len(phrase) > max_chars:
86
+ # Split long phrases into chunks
87
+ words = phrase.split()
88
+ chunk = ""
89
+ for word in words:
90
+ if len(chunk) + len(word) + 1 <= max_chars:
91
+ chunk += " " + word if chunk else word
92
+ else:
93
+ segments.append(chunk.strip())
94
+ chunk = word
95
+ if chunk:
96
+ segments.append(chunk.strip())
97
+ else:
98
+ if len(current_segment) + len(phrase) <= max_chars:
99
+ current_segment += " " + phrase if current_segment else phrase
100
+ else:
101
+ segments.append(current_segment.strip())
102
+ current_segment = phrase
103
+ else:
104
+ current_segment = sentence
105
+ else:
106
+ current_segment += " " + sentence if current_segment else sentence
107
+
108
+ # Add the last segment
109
+ if current_segment:
110
+ segments.append(current_segment.strip())
111
+
112
+ logger.info(f"Split text into {len(segments)} segments")
113
+ return segments
114
+
115
+ def format_text_for_voice(text: str, voice_name: str, segment_index: int = 0, total_segments: int = 1) -> str:
116
+ """Format text with voice characteristics for more consistent generation.
117
+ Args:
118
+ text: Text to format
119
+ voice_name: Name of the voice
120
+ segment_index: Index of this segment (for multi-segment texts)
121
+ total_segments: Total number of segments
122
+ Returns:
123
+ Formatted text optimized for consistent voice generation
124
+ """
125
+ # IMPORTANT: We no longer add voice instructions in brackets since CSM reads them aloud
126
+ # Instead, we're using speaker IDs to control voice identity which is what the model expects
127
+
128
+ # Just return the unmodified text - the Generator class will handle proper formatting
129
+ return text
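
A quick way to see `split_into_segments` in action: a paragraph longer than `max_chars` should come back as several sentence-aligned chunks, each at or under the limit. The snippet below is illustrative only.

from app.prompt_engineering import split_into_segments

text = (
    "This is the first sentence. Here is a second one that adds more detail. "
    "A third sentence pushes the total length past the limit, so the text "
    "should come back as more than one segment."
)
for i, seg in enumerate(split_into_segments(text, max_chars=150)):
    print(f"segment {i}: {len(seg)} chars -> {seg!r}")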
app/text_normalizer.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text normalization and cleaning utilities for CSM-1B TTS system.
3
+ Handles common issues like contractions, numbers, and special characters.
4
+ """
5
+ import re
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ class TextNormalizer:
11
+ """Text normalization utilities for TTS."""
12
+
13
+ # Common English contractions mapping
14
+ CONTRACTIONS = {
15
+ "don't": "dont",
16
+ "won't": "wont",
17
+ "can't": "cant",
18
+ "isn't": "isnt",
19
+ "he's": "hes",
20
+ "she's": "shes",
21
+ "they're": "theyre",
22
+ "we're": "were",
23
+ "you're": "youre",
24
+ "that's": "thats",
25
+ "it's": "its",
26
+ "what's": "whats",
27
+ "let's": "lets",
28
+ "who's": "whos",
29
+ "how's": "hows",
30
+ "where's": "wheres",
31
+ "there's": "theres",
32
+ "wouldn't": "wouldnt",
33
+ "shouldn't": "shouldnt",
34
+ "couldn't": "couldnt",
35
+ "hasn't": "hasnt",
36
+ "haven't": "havent",
37
+ "hadn't": "hadnt",
38
+ "didn't": "didnt",
39
+ "i'm": "im",
40
+ "i've": "ive",
41
+ "i'd": "id",
42
+ "i'll": "ill",
43
+ "you've": "youve",
44
+ "you'll": "youll",
45
+ "you'd": "youd",
46
+ "we've": "weve",
47
+ "we'll": "well",
48
+ "we'd": "wed",
49
+ "they've": "theyve",
50
+ "they'll": "theyll",
51
+ "they'd": "theyd",
52
+ "aren't": "arent",
53
+ "weren't": "werent",
54
+ "wasn't": "wasnt",
55
+ }
56
+
57
+ # Common abbreviations to expand
58
+ ABBREVIATIONS = {
59
+ "Mr.": "Mister",
60
+ "Mrs.": "Misses",
61
+ "Dr.": "Doctor",
62
+ "Prof.": "Professor",
63
+ "St.": "Street",
64
+ "Rd.": "Road",
65
+ "Ave.": "Avenue",
66
+ "vs.": "versus",
67
+ "etc.": "etcetera",
68
+ "e.g.": "for example",
69
+ "i.e.": "that is",
70
+ "approx.": "approximately",
71
+ }
72
+
73
+ # Simple number words for common numbers
74
+ NUMBER_WORDS = {
75
+ "0": "zero",
76
+ "1": "one",
77
+ "2": "two",
78
+ "3": "three",
79
+ "4": "four",
80
+ "5": "five",
81
+ "6": "six",
82
+ "7": "seven",
83
+ "8": "eight",
84
+ "9": "nine",
85
+ "10": "ten",
86
+ "11": "eleven",
87
+ "12": "twelve",
88
+ "13": "thirteen",
89
+ "14": "fourteen",
90
+ "15": "fifteen",
91
+ "16": "sixteen",
92
+ "17": "seventeen",
93
+ "18": "eighteen",
94
+ "19": "nineteen",
95
+ "20": "twenty",
96
+ "30": "thirty",
97
+ "40": "forty",
98
+ "50": "fifty",
99
+ "60": "sixty",
100
+ "70": "seventy",
101
+ "80": "eighty",
102
+ "90": "ninety",
103
+ "100": "one hundred",
104
+ "1000": "one thousand",
105
+ "1000000": "one million",
106
+ "1000000000": "one billion",
107
+ }
108
+
109
+ @classmethod
110
+ def normalize_text(cls, text: str) -> str:
111
+ """
112
+ Normalize text for TTS: handle contractions, punctuation, and special cases.
113
+
114
+ Args:
115
+ text: Input text to normalize
116
+
117
+ Returns:
118
+ Normalized text ready for TTS
119
+ """
120
+ if not text:
121
+ return text
122
+
123
+ # Log original text for debugging
124
+ logger.debug(f"Normalizing text: '{text}'")
125
+
126
+ # Remove voice instructions in square brackets
127
+ text = re.sub(r'\[.*?\]', '', text)
128
+
129
+ # Handle contractions - preserving case sensitivity
130
+ for contraction, replacement in cls.CONTRACTIONS.items():
131
+ # Case insensitive replacement
132
+ text = re.sub(r'\b' + re.escape(contraction) + r'\b', replacement, text, flags=re.IGNORECASE)
133
+
134
+ # Expand common abbreviations
135
+ for abbr, expanded in cls.ABBREVIATIONS.items():
136
+ text = text.replace(abbr, expanded)
137
+
138
+ # Handle numbers - only convert standalone numbers
139
+ def replace_number(match):
140
+ number = match.group(0)
141
+ if number in cls.NUMBER_WORDS:
142
+ return cls.NUMBER_WORDS[number]
143
+ return number
144
+
145
+ text = re.sub(r'\b\d+\b', replace_number, text)
146
+
147
+ # Replace problematic symbols
148
+ text = text.replace("&", " and ")
149
+ text = text.replace("%", " percent ")
150
+ text = text.replace("@", " at ")
151
+ text = text.replace("#", " number ")
152
+ text = text.replace("$", " dollar ")
153
+ text = text.replace("€", " euro ")
154
+ text = text.replace("£", " pound ")
155
+ text = text.replace("¥", " yen ")
156
+
157
+ # Handle dates in MM/DD/YYYY format
158
+ text = re.sub(r'\b(\d{1,2})/(\d{1,2})/(\d{4})\b', r'\1 \2 \3', text)
159
+
160
+ # Fix excessive spaces
161
+ text = re.sub(r'\s+', ' ', text).strip()
162
+
163
+ # Ensure sentence ends with punctuation
164
+ if not text[-1] in ['.', '!', '?', ';', ':', ',']:
165
+ text = text + '.'
166
+
167
+ logger.debug(f"Normalized text: '{text}'")
168
+ return text
169
+
170
+ @classmethod
171
+ def split_into_sentences(cls, text: str) -> list:
172
+ """
173
+ Split text into sentences for better TTS performance.
174
+
175
+ Args:
176
+ text: Input text to split
177
+
178
+ Returns:
179
+ List of sentences
180
+ """
181
+ # Normalize first
182
+ text = cls.normalize_text(text)
183
+
184
+ # Split on sentence boundaries
185
+ sentences = re.split(r'(?<=[.!?])\s+', text)
186
+
187
+ # Remove empty sentences
188
+ sentences = [s for s in sentences if s.strip()]
189
+
190
+ return sentences
191
+
192
+ def clean_text_for_tts(text: str) -> str:
193
+ """Clean and normalize text for TTS processing."""
194
+ return TextNormalizer.normalize_text(text)
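
Based on the rules above, bracketed voice instructions are stripped, known abbreviations and small numbers are expanded, symbols such as `&` are spelled out, and a trailing period is appended when the text ends without punctuation. A small illustrative check:

from app.text_normalizer import TextNormalizer, clean_text_for_tts

raw = "[calm voice] Dr. Smith has 5 cats & 2 dogs"
print(clean_text_for_tts(raw))
# expected along the lines of: "Doctor Smith has five cats and two dogs."

for sentence in TextNormalizer.split_into_sentences("First sentence. Second one! Third?"):
    print(sentence)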
app/torchtune_models.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Torchtune models for CSM-1B."""
2
+ import logging
3
+ from dataclasses import dataclass
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ # Set up logging
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # First, try to import llama3_2 from torchtune directly
11
+ try:
12
+ import torchtune
13
+ logger.info(f"Torchtune version: {getattr(torchtune, '__version__', 'unknown')}")
14
+
15
+ # Print available modules in torchtune.models
16
+ try:
17
+ import torchtune.models
18
+ logger.info(f"Available modules in torchtune.models: {dir(torchtune.models)}")
19
+ except Exception as e:
20
+ logger.error(f"Error inspecting torchtune.models: {e}")
21
+
22
+ # Try to import llama3_2 model
23
+ try:
24
+ from torchtune.models.llama3_2 import llama3_2
25
+ logger.info("Successfully imported llama3_2 from torchtune")
26
+ except ImportError as e:
27
+ logger.warning(f"Could not import llama3_2: {e}")
28
+ # Try to import regular llama as fallback
29
+ try:
30
+ from torchtune.models.llama import llama
31
+ logger.info("Using llama from torchtune.models.llama as fallback")
32
+ llama3_2 = llama # Alias llama as llama3_2
33
+ except ImportError:
34
+ logger.error("Could not import llama model either. Will use custom implementation.")
35
+ llama3_2 = None
36
+ except ImportError as e:
37
+ logger.error(f"Torchtune not available: {e}")
38
+ torchtune = None
39
+ llama3_2 = None
40
+
41
+
42
+ # Define our own model implementations as fallbacks
43
+ def llama3_2_1B_custom():
44
+ """Create a Llama 3.2 1B model."""
45
+ from app.custom_transformer import CustomTransformerDecoder
46
+ return CustomTransformerDecoder(
47
+ vocab_size=128_256,
48
+ num_layers=16,
49
+ num_heads=32,
50
+ num_kv_heads=8,
51
+ embed_dim=2048,
52
+ max_seq_len=2048,
53
+ intermediate_dim=8192,
54
+ attn_dropout=0.0,
55
+ norm_eps=1e-5,
56
+ )
57
+
58
+
59
+ def llama3_2_100M_custom():
60
+ """Create a Llama 3.2 100M model."""
61
+ from app.custom_transformer import CustomTransformerDecoder
62
+ return CustomTransformerDecoder(
63
+ vocab_size=128_256,
64
+ num_layers=4,
65
+ num_heads=8,
66
+ num_kv_heads=2,
67
+ embed_dim=1024,
68
+ max_seq_len=2048,
69
+ intermediate_dim=8192,
70
+ attn_dropout=0.0,
71
+ norm_eps=1e-5,
72
+ )
73
+
74
+
75
+ # Setup fallback to our own implementations if needed
76
+ if llama3_2 is None:
77
+ logger.warning("Using custom implementations for Llama models")
78
+ FLAVORS = {
79
+ "llama-1B": llama3_2_1B_custom,
80
+ "llama-100M": llama3_2_100M_custom,
81
+ }
82
+ else:
83
+ logger.info("Using torchtune implementations for Llama models")
84
+ FLAVORS = {
85
+ "llama-1B": lambda: llama3_2(
86
+ vocab_size=128_256,
87
+ num_layers=16,
88
+ num_heads=32,
89
+ num_kv_heads=8,
90
+ embed_dim=2048,
91
+ max_seq_len=2048,
92
+ intermediate_dim=8192,
93
+ attn_dropout=0.0,
94
+ norm_eps=1e-5,
95
+ rope_base=500_000,
96
+ scale_factor=32,
97
+ ),
98
+ "llama-100M": lambda: llama3_2(
99
+ vocab_size=128_256,
100
+ num_layers=4,
101
+ num_heads=8,
102
+ num_kv_heads=2,
103
+ embed_dim=1024,
104
+ max_seq_len=2048,
105
+ intermediate_dim=8192,
106
+ attn_dropout=0.0,
107
+ norm_eps=1e-5,
108
+ rope_base=500_000,
109
+ scale_factor=32,
110
+ ),
111
+ }
112
+
113
+
114
+ def _prepare_transformer(model):
115
+ """Prepare transformer for use."""
116
+ embed_dim = model.tok_embeddings.embedding_dim
117
+ model.tok_embeddings = nn.Identity()
118
+ model.output = nn.Identity()
119
+ return model, embed_dim
120
+
121
+
122
+ def _create_causal_mask(seq_len: int, device: torch.device):
123
+ """Create causal mask."""
124
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
125
+
126
+
127
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
128
+ """Index causal mask.
129
+
130
+ Args:
131
+ mask: (max_seq_len, max_seq_len)
132
+ input_pos: (batch_size, seq_len)
133
+
134
+ Returns:
135
+ (batch_size, seq_len, max_seq_len)
136
+ """
137
+ r = mask[input_pos, :]
138
+ return r
139
+
140
+
141
+ def _multinomial_sample_one_no_sync(probs):
142
+ """Do multinomial sampling without a cuda synchronization."""
143
+ q = torch.empty_like(probs).exponential_(1)
144
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
145
+
146
+
147
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
148
+ """Sample from top-k logits."""
149
+ logits = logits / temperature
150
+ filter_value: float = -float("Inf")
151
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
152
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
153
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
154
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
155
+ sample_token = _multinomial_sample_one_no_sync(probs)
156
+ return sample_token
157
+
158
+
159
+ @dataclass
160
+ class ModelArgs:
161
+ """Model arguments."""
162
+ backbone_flavor: str
163
+ decoder_flavor: str
164
+ text_vocab_size: int
165
+ audio_vocab_size: int
166
+ audio_num_codebooks: int
167
+
168
+
169
+ class Model(nn.Module):
170
+ """CSM-1B model."""
171
+
172
+ def __init__(self, args: ModelArgs):
173
+ """Initialize model."""
174
+ super().__init__()
175
+ self.args = args
176
+ logger.info(f"Creating model with backbone: {args.backbone_flavor}, decoder: {args.decoder_flavor}")
177
+
178
+ # Load backbone and decoder
179
+ self.backbone, backbone_dim = _prepare_transformer(FLAVORS[args.backbone_flavor]())
180
+ self.decoder, decoder_dim = _prepare_transformer(FLAVORS[args.decoder_flavor]())
181
+
182
+ # Embeddings
183
+ self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
184
+ self.audio_embeddings = nn.Embedding(args.audio_vocab_size * args.audio_num_codebooks, backbone_dim)
185
+
186
+ # Projection and heads
187
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
188
+ self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
189
+ self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))
190
+
191
+ # Initialize audio head
192
+ nn.init.normal_(self.audio_head, mean=0.0, std=0.02)
193
+
194
+ def setup_caches(self, max_batch_size: int) -> None:
196
+ """Set up KV caches and register causal-mask buffers."""
196
+ dtype = next(self.parameters()).dtype
197
+ device = next(self.parameters()).device
198
+
199
+ with device:
200
+ self.backbone.setup_caches(max_batch_size, dtype)
201
+ self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.args.audio_num_codebooks)
202
+
203
+ self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
204
+ self.register_buffer("decoder_causal_mask", _create_causal_mask(self.args.audio_num_codebooks, device))
205
+
206
+ def generate_frame(
207
+ self,
208
+ tokens: torch.Tensor,
209
+ tokens_mask: torch.Tensor,
210
+ input_pos: torch.Tensor,
211
+ temperature: float,
212
+ topk: int,
213
+ ) -> torch.Tensor:
214
+ """Generate a frame of audio tokens.
215
+
216
+ Args:
217
+ tokens: (batch_size, seq_len, audio_num_codebooks+1)
218
+ tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
219
+ input_pos: (batch_size, seq_len) positions for each token
220
+
221
+ Returns:
222
+ (batch_size, audio_num_codebooks) sampled tokens
223
+ """
224
+ dtype = next(self.parameters()).dtype
225
+ b, s = tokens.size()[:2]
226
+
227
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
228
+
229
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
230
+ embeds = self._embed_tokens(tokens)
231
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
232
+ h = masked_embeds.sum(dim=2)
233
+
234
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
235
+
236
+ last_h = h[:, -1, :]
237
+ c0_logits = self.codebook0_head(last_h)
238
+ c0_sample = sample_topk(c0_logits, topk, temperature)
239
+ c0_embed = self._embed_audio(0, c0_sample)
240
+
241
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
242
+ curr_sample = c0_sample.clone()
243
+ curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
244
+
245
+ # Decoder caches must be reset every frame.
246
+ self.decoder.reset_caches()
247
+
248
+ for i in range(1, self.args.audio_num_codebooks):
249
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
250
+ decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
251
+ dtype=dtype
252
+ )
253
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
254
+ ci_sample = sample_topk(ci_logits, topk, temperature)
255
+ ci_embed = self._embed_audio(i, ci_sample)
256
+
257
+ curr_h = ci_embed
258
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
259
+ curr_pos = curr_pos[:, -1:] + 1
260
+
261
+ return curr_sample
262
+
263
+ def reset_caches(self):
264
+ """Reset KV caches."""
265
+ self.backbone.reset_caches()
266
+ self.decoder.reset_caches()
267
+
268
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
269
+ """Embed audio tokens."""
270
+ return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)
271
+
272
+ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
273
+ """Embed tokens."""
274
+ text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
275
+ audio_tokens = tokens[:, :, :-1] + (
276
+ self.args.audio_vocab_size * torch.arange(self.args.audio_num_codebooks, device=tokens.device)
277
+ )
278
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
279
+ tokens.size(0), tokens.size(1), self.args.audio_num_codebooks, -1
280
+ )
281
+
282
+ return torch.cat([audio_embeds, text_embeds], dim=-2)
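
The decoding loop above relies on `sample_topk`, which can be exercised on its own with random logits; shapes follow its use in `generate_frame`: `(batch, vocab)` in, `(batch, 1)` integer tokens out. The vocabulary size below is arbitrary and chosen only for illustration.

import torch
from app.torchtune_models import sample_topk

torch.manual_seed(0)
logits = torch.randn(2, 1024)                     # (batch=2, vocab=1024), arbitrary sizes
tokens = sample_topk(logits, topk=30, temperature=0.65)
print(tokens.shape, tokens.dtype)                 # torch.Size([2, 1]) torch.int32
assert 0 <= tokens.min() and tokens.max() < logits.size(-1)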
app/utils/audio_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio utilities for CSM-1B API."""
2
+
3
+ import io
4
+ import tempfile
5
+ from typing import Optional
6
+ import os
7
+
8
+ import torch
9
+ import torchaudio
10
+ import ffmpeg
11
+
12
+
13
+ def convert_audio_format(
14
+ audio_tensor: torch.Tensor,
15
+ sample_rate: int,
16
+ format: str = "mp3",
17
+ bit_rate: Optional[str] = "128k",
18
+ ) -> bytes:
19
+ """Convert audio tensor to specified format.
20
+
21
+ Args:
22
+ audio_tensor: Audio tensor (channels, samples)
23
+ sample_rate: Sample rate
24
+ format: Output format (mp3, opus, aac, flac, wav)
25
+ bit_rate: Bit rate for lossy formats
26
+
27
+ Returns:
28
+ Audio bytes in specified format
29
+ """
30
+ # Create temporary files
31
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
32
+ wav_path = temp_wav.name
33
+
34
+ temp_out = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
35
+ out_path = temp_out.name
36
+ temp_out.close()
37
+
38
+ try:
39
+ # Save as WAV first (native format for torchaudio)
40
+ torchaudio.save(wav_path, audio_tensor.unsqueeze(0) if audio_tensor.dim() == 1 else audio_tensor,
41
+ sample_rate)
42
+
43
+ # Convert to desired format using ffmpeg
44
+ if format == "mp3":
45
+ ffmpeg.input(wav_path).output(out_path, format=format, audio_bitrate=bit_rate).run(quiet=True)
46
+ elif format in ["opus", "aac"]:
47
+ ffmpeg.input(wav_path).output(out_path, format=format).run(quiet=True)
48
+ elif format == "flac":
49
+ ffmpeg.input(wav_path).output(out_path, format=format).run(quiet=True)
50
+ elif format == "wav":
51
+ # Already saved as WAV
52
+ pass
53
+
54
+ # Read the output file
55
+ with open(out_path if format != "wav" else wav_path, "rb") as f:
56
+ audio_bytes = f.read()
57
+
58
+ return audio_bytes
59
+
60
+ finally:
61
+ # Clean up temporary files
62
+ for path in [wav_path, out_path]:
63
+ if os.path.exists(path):
64
+ os.unlink(path)
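
A hypothetical round trip through `convert_audio_format`: encode one second of a 440 Hz tone to MP3 bytes. This assumes the `ffmpeg` binary is on PATH (the Dockerfile installs it) along with the `ffmpeg-python` and `torchaudio` packages imported above.

import math
import torch
from app.utils.audio_utils import convert_audio_format

sample_rate = 24_000
t = torch.arange(sample_rate) / sample_rate
tone = 0.5 * torch.sin(2 * math.pi * 440.0 * t)   # mono tensor, shape (num_samples,)

mp3_bytes = convert_audio_format(tone, sample_rate, format="mp3", bit_rate="128k")
with open("tone.mp3", "wb") as f:
    f.write(mp3_bytes)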
app/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Utilities for CSM-1B API."""
app/utils/scheduled_tasks.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Scheduled tasks for the TTS API."""
2
+ import asyncio
3
+ import logging
4
+ import time
5
+ from datetime import datetime
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ async def periodic_voice_profile_backup(app_state, interval_hours=6):
10
+ """
11
+ Periodically save voice profiles to persistent storage.
12
+
13
+ Args:
14
+ app_state: The application state object
15
+ interval_hours: Backup interval in hours
16
+ """
17
+ while True:
18
+ try:
19
+ # Wait for the specified interval
20
+ await asyncio.sleep(interval_hours * 3600)
21
+
22
+ # Log the backup
23
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
24
+ logger.info(f"Scheduled voice profile backup started at {timestamp}")
25
+
26
+ # Save voice profiles
27
+ if hasattr(app_state, "voice_enhancement_enabled") and app_state.voice_enhancement_enabled:
28
+ from app.voice_enhancement import save_voice_profiles
29
+ save_voice_profiles()
30
+ logger.info("Voice profiles saved successfully")
31
+
32
+ # Save voice memories
33
+ if hasattr(app_state, "voice_memory_enabled") and app_state.voice_memory_enabled:
34
+ for voice_name in app_state.voice_cache:
35
+ if voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
36
+ from app.voice_memory import VOICE_MEMORIES
37
+ if voice_name in VOICE_MEMORIES:
38
+ VOICE_MEMORIES[voice_name].save()
39
+ logger.info("Voice memories saved successfully")
40
+
41
+ except Exception as e:
42
+ logger.error(f"Error in periodic voice profile backup: {e}")
app/utils/voice_manager.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for managing voice references and profiles."""
2
+ import os
3
+ import logging
4
+ import torch
5
+ import torchaudio
6
+ import shutil
7
+ from typing import Dict, List, Optional, Any
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ # Define persistent paths
12
+ VOICE_REFERENCES_DIR = "/app/voice_references"
13
+ VOICE_PROFILES_DIR = "/app/voice_profiles"
14
+ VOICE_MEMORIES_DIR = "/app/voice_memories"
15
+
16
+ # Ensure directories exist
17
+ os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
18
+ os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
19
+ os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
20
+
21
+ def backup_voice_data(backup_dir: str = "/app/voice_backups"):
22
+ """Create a backup of all voice data."""
23
+ os.makedirs(backup_dir, exist_ok=True)
24
+ from datetime import datetime
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
25
+ backup_path = os.path.join(backup_dir, f"voice_backup_{timestamp}")
26
+ os.makedirs(backup_path, exist_ok=True)
27
+
28
+ # Backup voice references
29
+ if os.path.exists(VOICE_REFERENCES_DIR):
30
+ refs_backup = os.path.join(backup_path, "voice_references")
31
+ shutil.copytree(VOICE_REFERENCES_DIR, refs_backup)
32
+
33
+ # Backup voice profiles
34
+ if os.path.exists(VOICE_PROFILES_DIR):
35
+ profiles_backup = os.path.join(backup_path, "voice_profiles")
36
+ shutil.copytree(VOICE_PROFILES_DIR, profiles_backup)
37
+
38
+ # Backup voice memories
39
+ if os.path.exists(VOICE_MEMORIES_DIR):
40
+ memories_backup = os.path.join(backup_path, "voice_memories")
41
+ shutil.copytree(VOICE_MEMORIES_DIR, memories_backup)
42
+
43
+ logger.info(f"Voice data backed up to {backup_path}")
44
+ return backup_path
45
+
46
+ def restore_default_voices():
47
+ """Reset voices to their default state by removing existing voice data."""
48
+ for voice_dir in [VOICE_REFERENCES_DIR, VOICE_PROFILES_DIR, VOICE_MEMORIES_DIR]:
49
+ if os.path.exists(voice_dir):
50
+ # Create a backup before deleting
51
+ backup_path = backup_voice_data()
52
+
53
+ # Remove existing data
54
+ for item in os.listdir(voice_dir):
55
+ item_path = os.path.join(voice_dir, item)
56
+ if os.path.isdir(item_path):
57
+ shutil.rmtree(item_path)
58
+ else:
59
+ os.remove(item_path)
60
+
61
+ logger.info(f"Removed existing voice data from {voice_dir}")
62
+
63
+ logger.info(f"Voices reset to default state (backup created at {backup_path})")
64
+ return backup_path
65
+
66
+ def verify_voice_references():
67
+ """Check if voice references are complete and valid."""
68
+ standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
69
+ missing_voices = []
70
+
71
+ for voice in standard_voices:
72
+ voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice)
73
+ # Check if directory exists and contains reference files
74
+ if not os.path.exists(voice_dir) or len(os.listdir(voice_dir)) == 0:
75
+ missing_voices.append(voice)
76
+
77
+ return {
78
+ "complete": len(missing_voices) == 0,
79
+ "missing_voices": missing_voices,
80
+ "references_dir": VOICE_REFERENCES_DIR
81
+ }
82
+
83
+ def get_voice_storage_info() -> Dict[str, Any]:
84
+ """Get information about voice storage usage and status."""
85
+ result = {
86
+ "voice_references": {
87
+ "path": VOICE_REFERENCES_DIR,
88
+ "exists": os.path.exists(VOICE_REFERENCES_DIR),
89
+ "voices": [],
90
+ "total_size_mb": 0
91
+ },
92
+ "voice_profiles": {
93
+ "path": VOICE_PROFILES_DIR,
94
+ "exists": os.path.exists(VOICE_PROFILES_DIR),
95
+ "file_count": 0,
96
+ "total_size_mb": 0
97
+ },
98
+ "voice_memories": {
99
+ "path": VOICE_MEMORIES_DIR,
100
+ "exists": os.path.exists(VOICE_MEMORIES_DIR),
101
+ "voices": [],
102
+ "total_size_mb": 0
103
+ }
104
+ }
105
+
106
+ # Get voice references info
107
+ if result["voice_references"]["exists"]:
108
+ for voice in os.listdir(VOICE_REFERENCES_DIR):
109
+ voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice)
110
+ if os.path.isdir(voice_dir):
111
+ file_count = len([f for f in os.listdir(voice_dir) if f.endswith('.wav')])
112
+ dir_size = sum(os.path.getsize(os.path.join(voice_dir, f)) for f in os.listdir(voice_dir) if os.path.isfile(os.path.join(voice_dir, f)))
113
+ result["voice_references"]["voices"].append({
114
+ "name": voice,
115
+ "file_count": file_count,
116
+ "size_mb": dir_size / (1024 * 1024)
117
+ })
118
+ result["voice_references"]["total_size_mb"] += dir_size / (1024 * 1024)
119
+
120
+ # Get voice profiles info
121
+ if result["voice_profiles"]["exists"]:
122
+ files = [f for f in os.listdir(VOICE_PROFILES_DIR) if os.path.isfile(os.path.join(VOICE_PROFILES_DIR, f))]
123
+ result["voice_profiles"]["file_count"] = len(files)
124
+ result["voice_profiles"]["total_size_mb"] = sum(os.path.getsize(os.path.join(VOICE_PROFILES_DIR, f)) for f in files) / (1024 * 1024)
125
+
126
+ # Get voice memories info
127
+ if result["voice_memories"]["exists"]:
128
+ files = [f for f in os.listdir(VOICE_MEMORIES_DIR) if os.path.isfile(os.path.join(VOICE_MEMORIES_DIR, f))]
129
+ result["voice_memories"]["voices"] = [f.replace('.pt', '') for f in files if f.endswith('.pt')]
130
+ result["voice_memories"]["total_size_mb"] = sum(os.path.getsize(os.path.join(VOICE_MEMORIES_DIR, f)) for f in files) / (1024 * 1024)
131
+
132
+ return result
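
The helpers above can be used for a quick storage audit; since the paths are the fixed `/app/...` directories, the sketch below is meant to run inside the container.

from app.utils.voice_manager import verify_voice_references, get_voice_storage_info

refs = verify_voice_references()
if not refs["complete"]:
    print(f"Missing reference audio for: {', '.join(refs['missing_voices'])}")

info = get_voice_storage_info()
for section in ("voice_references", "voice_profiles", "voice_memories"):
    print(f"{section}: {info[section]['total_size_mb']:.1f} MB at {info[section]['path']}")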
app/voice_cloning.py ADDED
@@ -0,0 +1,705 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Voice cloning module for CSM-1B TTS API.
3
+
4
+ This module provides functionality to clone voices from audio samples,
5
+ with advanced audio preprocessing and voice adaptation techniques.
6
+ """
7
+ import os
8
+ import io
9
+ import time
10
+ import tempfile
11
+ import logging
12
+ import asyncio
13
+ import yt_dlp
14
+ import whisper
15
+ from typing import Dict, List, Optional, Union, Tuple, BinaryIO
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torchaudio
21
+ from pydantic import BaseModel
22
+ from fastapi import UploadFile
23
+
24
+ from app.models import Segment
25
+
26
+ # Set up logging
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Directory for storing cloned voice data
30
+ CLONED_VOICES_DIR = "/app/cloned_voices"
31
+ os.makedirs(CLONED_VOICES_DIR, exist_ok=True)
32
+
33
+ class ClonedVoice(BaseModel):
34
+ """Model representing a cloned voice."""
35
+ id: str
36
+ name: str
37
+ created_at: float
38
+ speaker_id: int
39
+ description: Optional[str] = None
40
+ audio_duration: float
41
+ sample_count: int
42
+
43
+
44
+ class VoiceCloner:
45
+ """Voice cloning utility for CSM-1B model."""
46
+
47
+ def __init__(self, generator, device="cuda"):
48
+ """Initialize the voice cloner with a generator instance."""
49
+ self.generator = generator
50
+ self.device = device
51
+ self.sample_rate = generator.sample_rate
52
+ self.cloned_voices = self._load_existing_voices()
53
+ logger.info(f"Voice cloner initialized with {len(self.cloned_voices)} existing voices")
54
+
55
+ def _load_existing_voices(self) -> Dict[str, ClonedVoice]:
56
+ """Load existing cloned voices from disk."""
57
+ voices = {}
58
+ if not os.path.exists(CLONED_VOICES_DIR):
59
+ return voices
60
+
61
+ for voice_dir in os.listdir(CLONED_VOICES_DIR):
62
+ voice_path = os.path.join(CLONED_VOICES_DIR, voice_dir)
63
+ if not os.path.isdir(voice_path):
64
+ continue
65
+
66
+ info_path = os.path.join(voice_path, "info.json")
67
+ if os.path.exists(info_path):
68
+ try:
69
+ import json
70
+ with open(info_path, "r") as f:
71
+ voice_info = json.load(f)
72
+ voices[voice_dir] = ClonedVoice(**voice_info)
73
+ logger.info(f"Loaded cloned voice: {voice_dir}")
74
+ except Exception as e:
75
+ logger.error(f"Error loading voice {voice_dir}: {e}")
76
+
77
+ return voices
78
+
79
+ async def process_audio_file(
80
+ self,
81
+ file: Union[UploadFile, BinaryIO, str],
82
+ transcript: Optional[str] = None
83
+ ) -> Tuple[torch.Tensor, Optional[str], float]:
84
+ """
85
+ Process an audio file for voice cloning.
86
+
87
+ Args:
88
+ file: The audio file (UploadFile, file-like object, or path)
89
+ transcript: Optional transcript of the audio
90
+
91
+ Returns:
92
+ Tuple of (processed_audio, transcript, duration_seconds)
93
+ """
94
+ temp_path = None
95
+
96
+ try:
97
+ # Handle different input types
98
+ if isinstance(file, str):
99
+ # It's a file path
100
+ audio_path = file
101
+ logger.info(f"Processing audio from file path: {audio_path}")
102
+ else:
103
+ # Create a temporary file
104
+ temp_fd, temp_path = tempfile.mkstemp(suffix=".wav")
105
+ os.close(temp_fd) # Close the file descriptor
106
+
107
+ if isinstance(file, UploadFile):
108
+ # It's a FastAPI UploadFile
109
+ logger.info("Processing audio from UploadFile")
110
+ contents = await file.read()
111
+ with open(temp_path, "wb") as f:
112
+ f.write(contents)
113
+ elif hasattr(file, 'read'):
114
+ # It's a file-like object - check if it's async
115
+ logger.info("Processing audio from file-like object")
116
+ if asyncio.iscoroutinefunction(file.read):
117
+ # It's an async read method
118
+ contents = await file.read()
119
+ else:
120
+ # It's a sync read method
121
+ contents = file.read()
122
+
123
+ with open(temp_path, "wb") as f:
124
+ f.write(contents)
125
+ else:
126
+ raise ValueError(f"Unsupported file type: {type(file)}")
127
+
128
+ audio_path = temp_path
129
+ logger.info(f"Saved uploaded audio to temporary file: {audio_path}")
130
+
131
+ # Load audio
132
+ logger.info(f"Loading audio from {audio_path}")
133
+ audio, sr = torchaudio.load(audio_path)
134
+
135
+ # Convert to mono if stereo
136
+ if audio.shape[0] > 1:
137
+ logger.info(f"Converting {audio.shape[0]} channels to mono")
138
+ audio = torch.mean(audio, dim=0, keepdim=True)
139
+
140
+ # Remove first dimension if it's 1
141
+ if audio.shape[0] == 1:
142
+ audio = audio.squeeze(0)
143
+
144
+ # Resample if necessary
145
+ if sr != self.sample_rate:
146
+ logger.info(f"Resampling from {sr}Hz to {self.sample_rate}Hz")
147
+ audio = torchaudio.functional.resample(
148
+ audio, orig_freq=sr, new_freq=self.sample_rate
149
+ )
150
+
151
+ # Get audio duration
152
+ duration_seconds = len(audio) / self.sample_rate
153
+
154
+ # Process audio for better quality
155
+ logger.info(f"Preprocessing audio for quality enhancement")
156
+ processed_audio = self._preprocess_audio(audio)
157
+ processed_duration = len(processed_audio) / self.sample_rate
158
+
159
+ logger.info(
160
+ f"Processed audio: original duration={duration_seconds:.2f}s, "
161
+ f"processed duration={processed_duration:.2f}s"
162
+ )
163
+
164
+ return processed_audio, transcript, duration_seconds
165
+
166
+ except Exception as e:
167
+ logger.error(f"Error processing audio: {e}", exc_info=True)
168
+ raise RuntimeError(f"Failed to process audio file: {e}")
169
+
170
+ finally:
171
+ # Clean up temp file if we created one
172
+ if temp_path and os.path.exists(temp_path):
173
+ try:
174
+ os.unlink(temp_path)
175
+ logger.debug(f"Deleted temporary file {temp_path}")
176
+ except Exception as e:
177
+ logger.warning(f"Failed to delete temporary file {temp_path}: {e}")
178
+
179
+ def _preprocess_audio(self, audio: torch.Tensor) -> torch.Tensor:
180
+ """
181
+ Preprocess audio for better voice cloning quality.
182
+
183
+ Args:
184
+ audio: Raw audio tensor
185
+
186
+ Returns:
187
+ Processed audio tensor
188
+ """
189
+ # Normalize volume
190
+ if torch.max(torch.abs(audio)) > 0:
191
+ audio = audio / torch.max(torch.abs(audio))
192
+
193
+ # Remove silence with dynamic threshold
194
+ audio = self._remove_silence(audio, threshold=0.02) # Slightly higher threshold to remove more noise
195
+
196
+ # Remove DC offset (very low frequency noise)
197
+ audio = audio - torch.mean(audio)
198
+
199
+ # Apply simple noise reduction
200
+ # This filters out very high frequencies that might contain noise
201
+ try:
202
+ audio_np = audio.cpu().numpy()
203
+ from scipy import signal
204
+
205
+ # Apply a bandpass filter to focus on speech frequencies (80Hz - 8000Hz)
206
+ sos = signal.butter(3, [80, 8000], 'bandpass', fs=self.sample_rate, output='sos')
207
+ filtered = signal.sosfilt(sos, audio_np)
208
+
209
+ # Normalize the filtered audio
210
+ filtered = filtered / (np.max(np.abs(filtered)) + 1e-8)
211
+
212
+ # Convert back to torch tensor
213
+ audio = torch.tensor(filtered, device=audio.device)
214
+ except Exception as e:
215
+ logger.warning(f"Advanced audio filtering failed, using basic processing: {e}")
216
+
217
+ # Ensure audio has correct amplitude
218
+ audio = audio * 0.9 # Slightly reduce volume to prevent clipping
219
+
220
+ return audio
221
+
222
+ def _remove_silence(
223
+ self,
224
+ audio: torch.Tensor,
225
+ threshold: float = 0.015,
226
+ min_silence_duration: float = 0.2
227
+ ) -> torch.Tensor:
228
+ """
229
+ Remove silence from audio while preserving speech rhythm.
230
+
231
+ Args:
232
+ audio: Input audio tensor
233
+ threshold: Energy threshold for silence detection
234
+ min_silence_duration: Minimum silence duration in seconds
235
+
236
+ Returns:
237
+ Audio with silence removed
238
+ """
239
+ # Convert to numpy for easier processing
240
+ audio_np = audio.cpu().numpy()
241
+
242
+ # Calculate energy
243
+ energy = np.abs(audio_np)
244
+
245
+ # Find regions above threshold (speech)
246
+ is_speech = energy > threshold
247
+
248
+ # Convert min_silence_duration to samples
249
+ min_silence_samples = int(min_silence_duration * self.sample_rate)
250
+
251
+ # Find speech segments
252
+ speech_segments = []
253
+ in_speech = False
254
+ speech_start = 0
255
+
256
+ for i in range(len(is_speech)):
257
+ if is_speech[i] and not in_speech:
258
+ # Start of speech segment
259
+ in_speech = True
260
+ speech_start = i
261
+ elif not is_speech[i] and in_speech:
262
+ # Potential end of speech segment
263
+ # Only end if silence is long enough
264
+ silence_count = 0
265
+ for j in range(i, min(len(is_speech), i + min_silence_samples)):
266
+ if not is_speech[j]:
267
+ silence_count += 1
268
+ else:
269
+ break
270
+
271
+ if silence_count >= min_silence_samples:
272
+ # End of speech segment
273
+ in_speech = False
274
+ speech_segments.append((speech_start, i))
275
+
276
+ # Handle case where audio ends during speech
277
+ if in_speech:
278
+ speech_segments.append((speech_start, len(is_speech)))
279
+
280
+ # If no speech segments found, return original audio
281
+ if not speech_segments:
282
+ logger.warning("No speech segments detected, returning original audio")
283
+ return audio
284
+
285
+ # Add small buffer around segments
286
+ buffer_samples = int(0.05 * self.sample_rate) # 50ms buffer
287
+ processed_segments = []
288
+
289
+ for start, end in speech_segments:
290
+ buffered_start = max(0, start - buffer_samples)
291
+ buffered_end = min(len(audio_np), end + buffer_samples)
292
+ processed_segments.append(audio_np[buffered_start:buffered_end])
293
+
294
+ # Concatenate all segments with small pauses between them
295
+ small_pause = np.zeros(int(0.15 * self.sample_rate)) # 150ms pause
296
+ result = processed_segments[0]
297
+
298
+ for segment in processed_segments[1:]:
299
+ result = np.concatenate([result, small_pause, segment])
300
+
301
+ return torch.tensor(result, device=audio.device)
302
+
303
+ def _enhance_speech(self, audio: torch.Tensor) -> torch.Tensor:
304
+ """Enhance speech quality for better cloning results."""
305
+ # This is a placeholder for more advanced speech enhancement
306
+ # In a production implementation, you could add:
307
+ # - Noise reduction
308
+ # - Equalization for speech frequencies
309
+ # - Gentle compression for better dynamics
310
+ return audio
311
+
312
+ async def clone_voice(
313
+ self,
314
+ audio_file: Union[UploadFile, BinaryIO, str],
315
+ voice_name: str,
316
+ transcript: Optional[str] = None,
317
+ description: Optional[str] = None,
318
+ speaker_id: Optional[int] = None # Make this optional
319
+ ) -> ClonedVoice:
320
+ """
321
+ Clone a voice from an audio file.
322
+
323
+ Args:
324
+ audio_file: Audio file with the voice to clone
325
+ voice_name: Name for the cloned voice
326
+ transcript: Transcript of the audio (optional)
327
+ description: Description of the voice (optional)
328
+ speaker_id: Speaker ID to use (default: auto-assigned)
329
+
330
+ Returns:
331
+ ClonedVoice object with voice information
332
+ """
333
+ logger.info(f"Cloning new voice '{voice_name}' from audio file")
334
+
335
+ # Process the audio file
336
+ processed_audio, provided_transcript, duration = await self.process_audio_file(
337
+ audio_file, transcript
338
+ )
339
+
340
+ # Use a better speaker ID assignment - use a small number similar to the built-in voices
341
+ # This prevents issues with the speaker ID being interpreted as speech
342
+ if speaker_id is None:
343
+ # Use a number between 10-20 to avoid conflicts with built-in voices (0-5)
344
+ # but not too large like 999 which might cause issues
345
+ existing_ids = [v.speaker_id for v in self.cloned_voices.values()]
346
+ for potential_id in range(10, 20):
347
+ if potential_id not in existing_ids:
348
+ speaker_id = potential_id
349
+ break
350
+ else:
351
+ # If all IDs in range are taken, use a fallback
352
+ speaker_id = 10
353
+
354
+ # Generate a unique ID for the voice
355
+ voice_id = f"{int(time.time())}_{voice_name.lower().replace(' ', '_')}"
356
+
357
+ # Create directory for the voice
358
+ voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
359
+ os.makedirs(voice_dir, exist_ok=True)
360
+
361
+ # Save the processed audio
362
+ audio_path = os.path.join(voice_dir, "reference.wav")
363
+ torchaudio.save(audio_path, processed_audio.unsqueeze(0).cpu(), self.sample_rate)
364
+
365
+ # Save the transcript if provided
366
+ if provided_transcript:
367
+ transcript_path = os.path.join(voice_dir, "transcript.txt")
368
+ with open(transcript_path, "w") as f:
369
+ f.write(provided_transcript)
370
+
371
+ # Create and save voice info
372
+ voice_info = ClonedVoice(
373
+ id=voice_id,
374
+ name=voice_name,
375
+ created_at=time.time(),
376
+ speaker_id=speaker_id,
377
+ description=description,
378
+ audio_duration=duration,
379
+ sample_count=len(processed_audio)
380
+ )
381
+
382
+ # Save voice info as JSON
383
+ import json
384
+ with open(os.path.join(voice_dir, "info.json"), "w") as f:
385
+ f.write(json.dumps(voice_info.dict()))
386
+
387
+ # Add to cloned voices dictionary
388
+ self.cloned_voices[voice_id] = voice_info
389
+
390
+ logger.info(f"Voice '{voice_name}' cloned successfully with ID: {voice_id} and speaker_id: {speaker_id}")
391
+
392
+ return voice_info
393
+
394
+ def get_voice_context(self, voice_id: str) -> List[Segment]:
395
+ """
396
+ Get context segments for a cloned voice.
397
+
398
+ Args:
399
+ voice_id: ID of the cloned voice
400
+
401
+ Returns:
402
+ List of context segments for the voice
403
+ """
404
+ if voice_id not in self.cloned_voices:
405
+ logger.warning(f"Voice ID {voice_id} not found")
406
+ return []
407
+
408
+ voice = self.cloned_voices[voice_id]
409
+ voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
410
+ audio_path = os.path.join(voice_dir, "reference.wav")
411
+
412
+ if not os.path.exists(audio_path):
413
+ logger.error(f"Audio file for voice {voice_id} not found at {audio_path}")
414
+ return []
415
+
416
+ try:
417
+ # Load the audio
418
+ audio, sr = torchaudio.load(audio_path)
419
+ audio = audio.squeeze(0)
420
+
421
+ # Resample if necessary
422
+ if sr != self.sample_rate:
423
+ audio = torchaudio.functional.resample(
424
+ audio, orig_freq=sr, new_freq=self.sample_rate
425
+ )
426
+
427
+ # Trim to a maximum of 5 seconds to avoid sequence length issues
428
+ # This is a balance between voice quality and model limitations
429
+ max_samples = 5 * self.sample_rate # 5 seconds
430
+ if audio.shape[0] > max_samples:
431
+ logger.info(f"Trimming voice sample from {audio.shape[0]} to {max_samples} samples")
432
+ # Take from beginning for better voice characteristics
433
+ audio = audio[:max_samples]
434
+
435
+ # Load transcript if available
436
+ transcript_path = os.path.join(voice_dir, "transcript.txt")
437
+ transcript = ""
438
+ if os.path.exists(transcript_path):
439
+ with open(transcript_path, "r") as f:
440
+ full_transcript = f.read()
441
+ # Take a portion of transcript that roughly matches our audio portion
442
+ words = full_transcript.split()
443
+ # Estimate 3 words per second as a rough average
444
+ word_count = min(len(words), int(5 * 3)) # 5 seconds * 3 words/second
445
+ transcript = " ".join(words[:word_count])
446
+ else:
447
+ transcript = f"Voice sample for {voice.name}"
448
+
449
+ # Create context segment
450
+ segment = Segment(
451
+ text=transcript,
452
+ speaker=voice.speaker_id,
453
+ audio=audio.to(self.device)
454
+ )
455
+
456
+ logger.info(f"Created voice context segment with {audio.shape[0]/self.sample_rate:.1f}s audio")
457
+ return [segment]
458
+
459
+ except Exception as e:
460
+ logger.error(f"Error getting voice context for {voice_id}: {e}")
461
+ return []
462
+
463
+ def list_voices(self) -> List[ClonedVoice]:
464
+ """List all available cloned voices."""
465
+ return list(self.cloned_voices.values())
466
+
467
+ def delete_voice(self, voice_id: str) -> bool:
468
+ """
469
+ Delete a cloned voice.
470
+
471
+ Args:
472
+ voice_id: ID of the voice to delete
473
+
474
+ Returns:
475
+ True if successful, False otherwise
476
+ """
477
+ if voice_id not in self.cloned_voices:
478
+ return False
479
+
480
+ voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
481
+ if os.path.exists(voice_dir):
482
+ try:
483
+ import shutil
484
+ shutil.rmtree(voice_dir)
485
+ del self.cloned_voices[voice_id]
486
+ return True
487
+ except Exception as e:
488
+ logger.error(f"Error deleting voice {voice_id}: {e}")
489
+ return False
490
+
491
+ return False
492
+
493
+ async def clone_voice_from_youtube(
494
+ self,
495
+ youtube_url: str,
496
+ voice_name: str,
497
+ start_time: int = 0,
498
+ duration: int = 180,
499
+ description: str = None
500
+ ) -> ClonedVoice:
501
+ """
502
+ Clone a voice from a YouTube video.
503
+
504
+ Args:
505
+ youtube_url: URL of the YouTube video
506
+ voice_name: Name for the cloned voice
507
+ start_time: Start time in seconds
508
+ duration: Duration to extract in seconds
509
+ description: Optional description of the voice
510
+
511
+ Returns:
512
+ ClonedVoice object with voice information
513
+ """
514
+ logger.info(f"Cloning voice '{voice_name}' from YouTube: {youtube_url}")
515
+
516
+ # Create temporary directory for processing
517
+ with tempfile.TemporaryDirectory() as temp_dir:
518
+ # Step 1: Download audio from YouTube
519
+ audio_path = await self._download_youtube_audio(youtube_url, temp_dir, start_time, duration)
520
+
521
+ # Step 2: Generate transcript using Whisper
522
+ transcript = await self._generate_transcript(audio_path)
523
+
524
+ # Step 3: Clone the voice using the extracted audio and transcript
525
+ voice = await self.clone_voice(
526
+ audio_file=audio_path,
527
+ voice_name=voice_name,
528
+ transcript=transcript,
529
+ description=description or f"Voice cloned from YouTube: {youtube_url}"
530
+ )
531
+
532
+ return voice
533
+
534
+ async def _download_youtube_audio(
535
+ self,
536
+ url: str,
537
+ output_dir: str,
538
+ start_time: int = 0,
539
+ duration: int = 180
540
+ ) -> str:
541
+ """
542
+ Download audio from a YouTube video.
543
+
544
+ Args:
545
+ url: YouTube URL
546
+ output_dir: Directory to save the audio
547
+ start_time: Start time in seconds
548
+ duration: Duration to extract in seconds
549
+
550
+ Returns:
551
+ Path to the downloaded audio file
552
+ """
553
+ output_path = os.path.join(output_dir, "youtube_audio.wav")
554
+
555
+ # Configure yt-dlp options
556
+ ydl_opts = {
557
+ 'format': 'bestaudio/best',
558
+ 'postprocessors': [{
559
+ 'key': 'FFmpegExtractAudio',
560
+ 'preferredcodec': 'wav',
561
+ 'preferredquality': '192',
562
+ }],
563
+ 'outtmpl': output_path.replace(".wav", ""),
564
+ 'quiet': True,
565
+ 'no_warnings': True
566
+ }
567
+
568
+ # Download the video
569
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
570
+ ydl.download([url])
571
+
572
+ # Trim the audio to the requested start/duration window (with the finite default duration this branch always runs)
573
+ if start_time > 0 or duration < float('inf'):
574
+ import ffmpeg
575
+ trimmed_path = os.path.join(output_dir, "trimmed_audio.wav")
576
+
577
+ # Use ffmpeg to trim the audio
578
+ (
579
+ ffmpeg.input(output_path)
580
+ .audio
581
+ .filter('atrim', start=start_time, duration=duration)
582
+ .output(trimmed_path)
583
+ .run(quiet=True, overwrite_output=True)
584
+ )
585
+
586
+ return trimmed_path
587
+
588
+ return output_path
589
+
590
+ async def _generate_transcript(self, audio_path: str) -> str:
591
+ """
592
+ Generate transcript from audio using Whisper.
593
+
594
+ Args:
595
+ audio_path: Path to the audio file
596
+
597
+ Returns:
598
+ Transcript text
599
+ """
600
+ # Load Whisper model (use small model for faster processing)
601
+ model = whisper.load_model("small")
602
+
603
+ # Transcribe the audio
604
+ result = model.transcribe(audio_path)
605
+
606
+ return result["text"]
607
+
608
+ def generate_speech(
609
+ self,
610
+ text: str,
611
+ voice_id: str,
612
+ temperature: float = 0.65,
613
+ topk: int = 30,
614
+ max_audio_length_ms: int = 15000
615
+ ) -> torch.Tensor:
616
+ """
617
+ Generate speech with a cloned voice.
618
+ Args:
619
+ text: Text to synthesize
620
+ voice_id: ID of the cloned voice to use
621
+ temperature: Sampling temperature (lower = more stable, higher = more varied)
622
+ topk: Top-k sampling parameter
623
+ max_audio_length_ms: Maximum audio length in milliseconds
624
+ Returns:
625
+ Generated audio tensor
626
+ """
627
+ # Note: this method is synchronous, unlike the async cloning helpers above
628
+ if voice_id not in self.cloned_voices:
629
+ raise ValueError(f"Voice ID {voice_id} not found")
630
+ voice = self.cloned_voices[voice_id]
631
+ context = self.get_voice_context(voice_id)
632
+ if not context:
633
+ raise ValueError(f"Could not get context for voice {voice_id}")
634
+ # Preprocess text for better pronunciation
635
+ processed_text = self._preprocess_text(text)
636
+ logger.info(f"Generating speech with voice '{voice.name}' (ID: {voice_id}, speaker: {voice.speaker_id})")
637
+ try:
638
+ # Check if text is too long and should be split
639
+ if len(processed_text) > 200:
640
+ logger.info(f"Text is long ({len(processed_text)} chars), splitting for better quality")
641
+ from app.prompt_engineering import split_into_segments
642
+ # Split text into manageable segments
643
+ segments = split_into_segments(processed_text, max_chars=150)
644
+ logger.info(f"Split text into {len(segments)} segments")
645
+ all_audio_chunks = []
646
+ # Process each segment
647
+ for i, segment_text in enumerate(segments):
648
+ logger.info(f"Generating segment {i+1}/{len(segments)}")
649
+ # Generate this segment - using plain text without formatting
650
+ segment_audio = self.generator.generate(
651
+ text=segment_text, # Use plain text, no formatting
652
+ speaker=voice.speaker_id,
653
+ context=context,
654
+ max_audio_length_ms=min(max_audio_length_ms, 10000),
655
+ temperature=temperature,
656
+ topk=topk,
657
+ )
658
+ all_audio_chunks.append(segment_audio)
659
+ # Use this segment as context for the next one for consistency
660
+ if i < len(segments) - 1:
661
+ context = [
662
+ Segment(
663
+ text=segment_text,
664
+ speaker=voice.speaker_id,
665
+ audio=segment_audio
666
+ )
667
+ ]
668
+ # Combine chunks with small silence between them
669
+ if len(all_audio_chunks) == 1:
670
+ audio = all_audio_chunks[0]
671
+ else:
672
+ silence_samples = int(0.1 * self.sample_rate) # 100ms silence
673
+ silence = torch.zeros(silence_samples, device=all_audio_chunks[0].device)
674
+ # Join segments with silence
675
+ audio_parts = []
676
+ for i, chunk in enumerate(all_audio_chunks):
677
+ audio_parts.append(chunk)
678
+ if i < len(all_audio_chunks) - 1: # Don't add silence after the last chunk
679
+ audio_parts.append(silence)
680
+ # Concatenate all parts
681
+ audio = torch.cat(audio_parts)
682
+ return audio
683
+ else:
684
+ # For short text, generate directly - using plain text without formatting
685
+ audio = self.generator.generate(
686
+ text=processed_text, # Use plain text, no formatting
687
+ speaker=voice.speaker_id,
688
+ context=context,
689
+ max_audio_length_ms=max_audio_length_ms,
690
+ temperature=temperature,
691
+ topk=topk,
692
+ )
693
+ return audio
694
+ except Exception as e:
695
+ logger.error(f"Error generating speech with voice {voice_id}: {e}")
696
+ raise
697
+
698
+ def _preprocess_text(self, text: str) -> str:
699
+ """Preprocess text for better pronunciation and voice cloning."""
700
+ # Make sure text ends with punctuation for better phrasing
701
+ text = text.strip()
702
+ if not text.endswith(('.', '?', '!', ';')):
703
+ text = text + '.'
704
+
705
+ return text
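For reference, the long-text path in generate_speech above boils down to: split the text, generate each segment with the previous segment as context, then join the chunks with roughly 100 ms of silence. Below is a minimal, self-contained sketch of just the joining step; the join_with_silence helper is hypothetical and only mirrors the logic shown above.

    import torch

    def join_with_silence(chunks, sample_rate=24000, gap_s=0.1):
        """Concatenate audio chunks with a short silence between them,
        mirroring the segment-joining logic in generate_speech above."""
        if not chunks:
            return torch.zeros(0)
        silence = torch.zeros(int(gap_s * sample_rate), device=chunks[0].device)
        parts = []
        for i, chunk in enumerate(chunks):
            parts.append(chunk)
            if i < len(chunks) - 1:  # no silence after the last chunk
                parts.append(silence)
        return torch.cat(parts)

    # Three fake 0.5 s chunks at 24 kHz -> 0.5*3 + 0.1*2 = 1.7 s of audio
    chunks = [torch.zeros(12000) for _ in range(3)]
    print(join_with_silence(chunks).shape)  # torch.Size([40800])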
app/voice_embeddings.py ADDED
@@ -0,0 +1,85 @@
1
+ """Voice embeddings for consistent voice generation."""
2
+ import os
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ from typing import Dict
7
+
8
+ # Path to store voice samples
9
+ VOICE_SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "voice_samples")
10
+ os.makedirs(VOICE_SAMPLES_DIR, exist_ok=True)
11
+
12
+ # Dictionary to store voice embeddings/samples
13
+ VOICE_DICT: Dict[str, torch.Tensor] = {}
14
+
15
+
16
+ def initialize_voices(sample_rate: int = 24000):
17
+ """Initialize voice dictionary with consistent samples."""
18
+ # Generate consistent seed audio for each voice
19
+ for voice_id in range(6):
20
+ voice_name = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"][voice_id]
21
+
22
+ # Create deterministic audio sample for each voice
23
+ np.random.seed(voice_id + 42) # Use a fixed seed based on voice ID
24
+
25
+ # Generate 1 second of "seed" audio with deterministic characteristics
26
+ # This differs per voice but remains consistent across runs
27
+ duration = 1.0 # seconds
28
+ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
29
+
30
+ # Create a distinctive waveform for each voice
31
+ if voice_id == 0: # alloy - rich mid tones
32
+ freq1, freq2 = 220, 440
33
+ audio = 0.5 * np.sin(2 * np.pi * freq1 * t) + 0.3 * np.sin(2 * np.pi * freq2 * t)
34
+ elif voice_id == 1: # echo - reverberant
35
+ freq = 330
36
+ audio = np.sin(2 * np.pi * freq * t) * np.exp(-t * 3)
37
+ elif voice_id == 2: # fable - bright, higher pitch
38
+ freq = 523
39
+ audio = 0.7 * np.sin(2 * np.pi * freq * t)
40
+ elif voice_id == 3: # onyx - deep and resonant
41
+ freq = 165
42
+ audio = 0.8 * np.sin(2 * np.pi * freq * t)
43
+ elif voice_id == 4: # nova - warm and smooth
44
+ freq1, freq2 = 392, 196
45
+ audio = 0.4 * np.sin(2 * np.pi * freq1 * t) + 0.4 * np.sin(2 * np.pi * freq2 * t)
46
+ else: # shimmer - airy and light
47
+ freq1, freq2, freq3 = 587, 880, 1174
48
+ audio = 0.3 * np.sin(2 * np.pi * freq1 * t) + 0.2 * np.sin(2 * np.pi * freq2 * t) + 0.1 * np.sin(2 * np.pi * freq3 * t)
49
+
50
+ # Normalize
51
+ audio = audio / np.max(np.abs(audio))
52
+
53
+ # Convert to tensor
54
+ audio_tensor = torch.tensor(audio, dtype=torch.float32)
55
+
56
+ # Store the audio tensor
57
+ VOICE_DICT[voice_name] = audio_tensor
58
+
59
+ # Save as wav for reference
60
+ save_path = os.path.join(VOICE_SAMPLES_DIR, f"{voice_name}_seed.wav")
61
+ torchaudio.save(save_path, audio_tensor.unsqueeze(0), sample_rate)
62
+
63
+ print(f"Initialized voice seed for {voice_name}")
64
+
65
+
66
+ def get_voice_sample(voice_name: str) -> torch.Tensor:
67
+ """Get the voice sample for a given voice name."""
68
+ if not VOICE_DICT:
69
+ initialize_voices()
70
+
71
+ if voice_name in VOICE_DICT:
72
+ return VOICE_DICT[voice_name]
73
+
74
+ # Default to alloy if voice not found
75
+ print(f"Voice {voice_name} not found, defaulting to alloy")
76
+ return VOICE_DICT["alloy"]
77
+
78
+
79
+ def update_voice_sample(voice_name: str, audio: torch.Tensor):
80
+ """Update the voice sample with recently generated audio."""
81
+ # Only update if we've already initialized
82
+ if VOICE_DICT:
83
+ # Take the last second of audio (or whatever is available)
84
+ sample_length = min(24000, audio.shape[0])
85
+ VOICE_DICT[voice_name] = audio[-sample_length:].detach().cpu()
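A short usage sketch for app/voice_embeddings.py as defined above; the shapes in the comments follow from the 24 kHz, 1-second seed samples, and the random tensor is only a stand-in for real generated audio.

    import torch
    from app.voice_embeddings import initialize_voices, get_voice_sample, update_voice_sample

    initialize_voices(sample_rate=24000)    # writes the <voice>_seed.wav files once
    seed = get_voice_sample("onyx")         # deterministic 1 s seed, shape [24000]
    print(seed.shape)

    # After generating new audio for this voice, keep its tail as the new sample
    new_audio = torch.randn(48000)          # stand-in for real generated audio
    update_voice_sample("onyx", new_audio)  # stores the last second, moved to CPU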
app/voice_enhancement.py ADDED
@@ -0,0 +1,581 @@
1
+ """Advanced voice enhancement and consistency system for CSM-1B."""
2
+ import os
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from typing import Dict, List, Optional, Tuple
8
+ import logging
9
+ from dataclasses import dataclass
10
+ from scipy import signal
11
+
12
+ # Setup logging
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Define persistent paths
16
+ VOICE_REFERENCES_DIR = "/app/voice_references"
17
+ VOICE_PROFILES_DIR = "/app/voice_profiles"
18
+
19
+ # Ensure directories exist
20
+ os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
21
+ os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
22
+
23
+ @dataclass
24
+ class VoiceProfile:
25
+ """Detailed voice profile with acoustic characteristics."""
26
+ name: str
27
+ speaker_id: int
28
+ # Acoustic parameters
29
+ pitch_range: Tuple[float, float] # Min/max pitch in Hz
30
+ intensity_range: Tuple[float, float] # Min/max intensity (volume)
31
+ spectral_tilt: float # Brightness vs. darkness
32
+ prosody_pattern: str # Pattern of intonation and rhythm
33
+ speech_rate: float # Relative speech rate (1.0 = normal)
34
+ formant_shift: float # Formant frequency shift (1.0 = no shift)
35
+ # Reference audio
36
+ reference_segments: List[torch.Tensor]
37
+ # Normalization parameters
38
+ target_rms: float = 0.2
39
+ target_peak: float = 0.95
40
+
41
+ def get_enhancement_params(self) -> Dict:
42
+ """Get parameters for enhancing generated audio."""
43
+ return {
44
+ "target_rms": self.target_rms,
45
+ "target_peak": self.target_peak,
46
+ "pitch_range": self.pitch_range,
47
+ "formant_shift": self.formant_shift,
48
+ "speech_rate": self.speech_rate,
49
+ "spectral_tilt": self.spectral_tilt
50
+ }
51
+
52
+ # Voice profiles with carefully tuned parameters
53
+ VOICE_PROFILES = {
54
+ "alloy": VoiceProfile(
55
+ name="alloy",
56
+ speaker_id=0,
57
+ pitch_range=(85, 180), # Hz - balanced range
58
+ intensity_range=(0.15, 0.3), # moderate intensity
59
+ spectral_tilt=0.0, # neutral tilt
60
+ prosody_pattern="balanced",
61
+ speech_rate=1.0, # normal rate
62
+ formant_shift=1.0, # no shift
63
+ reference_segments=[],
64
+ target_rms=0.2,
65
+ target_peak=0.95
66
+ ),
67
+ "echo": VoiceProfile(
68
+ name="echo",
69
+ speaker_id=1,
70
+ pitch_range=(75, 165), # Hz - lower, resonant
71
+ intensity_range=(0.2, 0.35), # slightly stronger
72
+ spectral_tilt=-0.2, # more low frequencies
73
+ prosody_pattern="deliberate",
74
+ speech_rate=0.95, # slightly slower
75
+ formant_shift=0.95, # slightly lower formants
76
+ reference_segments=[],
77
+ target_rms=0.22, # slightly louder
78
+ target_peak=0.95
79
+ ),
80
+ "fable": VoiceProfile(
81
+ name="fable",
82
+ speaker_id=2,
83
+ pitch_range=(120, 250), # Hz - higher range
84
+ intensity_range=(0.15, 0.28), # moderate intensity
85
+ spectral_tilt=0.2, # more high frequencies
86
+ prosody_pattern="animated",
87
+ speech_rate=1.05, # slightly faster
88
+ formant_shift=1.05, # slightly higher formants
89
+ reference_segments=[],
90
+ target_rms=0.19,
91
+ target_peak=0.95
92
+ ),
93
+ "onyx": VoiceProfile(
94
+ name="onyx",
95
+ speaker_id=3,
96
+ pitch_range=(65, 150), # Hz - deeper range
97
+ intensity_range=(0.18, 0.32), # moderate-strong
98
+ spectral_tilt=-0.3, # more low frequencies
99
+ prosody_pattern="authoritative",
100
+ speech_rate=0.93, # slightly slower
101
+ formant_shift=0.9, # lower formants
102
+ reference_segments=[],
103
+ target_rms=0.23, # stronger
104
+ target_peak=0.95
105
+ ),
106
+ "nova": VoiceProfile(
107
+ name="nova",
108
+ speaker_id=4,
109
+ pitch_range=(90, 200), # Hz - warm midrange
110
+ intensity_range=(0.15, 0.27), # moderate
111
+ spectral_tilt=-0.1, # slightly warm
112
+ prosody_pattern="flowing",
113
+ speech_rate=1.0, # normal rate
114
+ formant_shift=1.0, # no shift
115
+ reference_segments=[],
116
+ target_rms=0.2,
117
+ target_peak=0.95
118
+ ),
119
+ "shimmer": VoiceProfile(
120
+ name="shimmer",
121
+ speaker_id=5,
122
+ pitch_range=(140, 280), # Hz - brighter, higher
123
+ intensity_range=(0.15, 0.25), # moderate-light
124
+ spectral_tilt=0.3, # more high frequencies
125
+ prosody_pattern="light",
126
+ speech_rate=1.07, # slightly faster
127
+ formant_shift=1.1, # higher formants
128
+ reference_segments=[],
129
+ target_rms=0.18, # slightly softer
130
+ target_peak=0.95
131
+ )
132
+ }
133
+
134
+ # Voice-specific prompt templates - crafted to establish voice identity clearly
135
+ VOICE_PROMPTS = {
136
+ "alloy": [
137
+ "Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
138
+ "This is Alloy speaking. My voice is designed to be clear and conversational.",
139
+ "Alloy here - I have a neutral, friendly voice with balanced tone qualities."
140
+ ],
141
+ "echo": [
142
+ "Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
143
+ "This is Echo speaking. My voice has a rich, resonant quality with depth.",
144
+ "Echo here - My voice is characterized by its warm, resonant tones."
145
+ ],
146
+ "fable": [
147
+ "Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
148
+ "This is Fable speaking. My voice is characterized by its clear, bright quality.",
149
+ "Fable here - My voice is light, articulate, and slightly higher-pitched."
150
+ ],
151
+ "onyx": [
152
+ "Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
153
+ "This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
154
+ "Onyx here - My voice is characterized by its depth and commanding presence."
155
+ ],
156
+ "nova": [
157
+ "Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
158
+ "This is Nova speaking. My voice has a smooth, harmonious quality.",
159
+ "Nova here - My voice is characterized by its warm, friendly mid-tones."
160
+ ],
161
+ "shimmer": [
162
+ "Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
163
+ "This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
164
+ "Shimmer here - My voice is characterized by its bright, crystalline tones."
165
+ ]
166
+ }
167
+
168
+ def initialize_voice_profiles():
169
+ """Initialize voice profiles with default settings.
170
+
171
+ This function loads existing voice profiles from disk if available,
172
+ or initializes them with default settings.
173
+ """
174
+ global VOICE_PROFILES
175
+
176
+ # Try to load existing profiles from persistent storage
177
+ profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
178
+
179
+ if os.path.exists(profile_path):
180
+ try:
181
+ logger.info(f"Loading voice profiles from {profile_path}")
182
+ saved_profiles = torch.load(profile_path)
183
+
184
+ # Update existing profiles with saved data
185
+ for name, data in saved_profiles.items():
186
+ if name in VOICE_PROFILES:
187
+ VOICE_PROFILES[name].reference_segments = [
188
+ seg.to(torch.device("cpu")) for seg in data.get('reference_segments', [])
189
+ ]
190
+
191
+ logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
192
+ except Exception as e:
193
+ logger.error(f"Error loading voice profiles: {e}")
194
+ logger.info("Using default voice profiles")
195
+ else:
196
+ logger.info("No saved voice profiles found, using defaults")
197
+
198
+ # Ensure all voices have at least empty reference segments
199
+ for name, profile in VOICE_PROFILES.items():
200
+ if not hasattr(profile, 'reference_segments'):
201
+ profile.reference_segments = []
202
+
203
+ logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
204
+ return VOICE_PROFILES
205
+
206
+ def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
207
+ """Apply professional-grade normalization to audio.
208
+
209
+ Args:
210
+ audio: Audio tensor
211
+ target_rms: Target RMS level for normalization
212
+ target_peak: Target peak level for limiting
213
+
214
+ Returns:
215
+ Normalized audio tensor
216
+ """
217
+ # Ensure audio is on CPU for processing
218
+ audio_cpu = audio.detach().cpu()
219
+
220
+ # Handle silent audio
221
+ if audio_cpu.abs().max() < 1e-6:
222
+ logger.warning("Audio is nearly silent, returning original")
223
+ return audio
224
+
225
+ # Calculate current RMS
226
+ current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))
227
+
228
+ # Apply RMS normalization
229
+ if current_rms > 0:
230
+ gain = target_rms / current_rms
231
+ normalized = audio_cpu * gain
232
+ else:
233
+ normalized = audio_cpu
234
+
235
+ # Apply peak limiting
236
+ current_peak = normalized.abs().max()
237
+ if current_peak > target_peak:
238
+ normalized = normalized * (target_peak / current_peak)
239
+
240
+ # Return to original device
241
+ return normalized.to(audio.device)
242
+
243
+ def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
244
+ """Apply anti-muffling to improve clarity.
245
+
246
+ Args:
247
+ audio: Audio tensor
248
+ sample_rate: Audio sample rate
249
+ clarity_boost: Amount of high frequency boost (1.0 = no boost)
250
+
251
+ Returns:
252
+ Processed audio tensor
253
+ """
254
+ # Convert to numpy for filtering
255
+ audio_np = audio.detach().cpu().numpy()
256
+
257
+ try:
258
+ # Boost high frequencies to counter muffling. scipy.signal has no built-in
260
+ # high-shelf filter, so approximate one with a second-order Butterworth
261
+ # high-pass filter whose output is mixed back with the original signal below.
261
+ cutoff = 2000 # Hz
262
+ b, a = signal.butter(2, cutoff/(sample_rate/2), btype='high', analog=False)
263
+
264
+ # Apply the filter with the clarity boost gain
265
+ boosted = signal.filtfilt(b, a, audio_np, axis=0) * clarity_boost
266
+
267
+ # Mix with original to maintain some warmth
268
+ mix_ratio = 0.7 # 70% processed, 30% original
269
+ processed = mix_ratio * boosted + (1-mix_ratio) * audio_np
270
+
271
+ except Exception as e:
272
+ logger.warning(f"Audio enhancement failed, using original: {e}")
273
+ # Return original audio if enhancement fails
274
+ return audio
275
+
276
+ # Convert back to tensor on original device
277
+ return torch.tensor(processed, dtype=audio.dtype, device=audio.device)
278
+
279
+ def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
280
+ """Apply comprehensive audio enhancement based on voice profile.
281
+
282
+ Args:
283
+ audio: Audio tensor
284
+ sample_rate: Audio sample rate
285
+ voice_profile: Voice profile containing enhancement parameters
286
+
287
+ Returns:
288
+ Enhanced audio tensor
289
+ """
290
+ if audio is None or audio.numel() == 0:
291
+ logger.error("Cannot enhance empty audio")
292
+ return audio
293
+
294
+ try:
295
+ # Step 1: Normalize audio levels
296
+ params = voice_profile.get_enhancement_params()
297
+ normalized = normalize_audio(
298
+ audio,
299
+ target_rms=params["target_rms"],
300
+ target_peak=params["target_peak"]
301
+ )
302
+
303
+ # Step 2: Apply anti-muffling based on spectral tilt
304
+ # Positive tilt means brighter voice so less clarity boost needed
305
+ clarity_boost = 1.0 + max(0, -params["spectral_tilt"]) * 0.5
306
+ clarified = apply_anti_muffling(
307
+ normalized,
308
+ sample_rate,
309
+ clarity_boost=clarity_boost
310
+ )
311
+
312
+ # Log the enhancement
313
+ logger.debug(
314
+ f"Enhanced audio for {voice_profile.name}: "
315
+ f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{clarified.pow(2).mean().sqrt().item():.3f}, "
316
+ f"Peak: {audio.abs().max().item():.3f}->{clarified.abs().max().item():.3f}"
317
+ )
318
+
319
+ return clarified
320
+
321
+ except Exception as e:
322
+ logger.error(f"Error in audio enhancement: {e}")
323
+ return audio # Return original audio if enhancement fails
324
+
325
+ def validate_generated_audio(
326
+ audio: torch.Tensor,
327
+ voice_name: str,
328
+ sample_rate: int,
329
+ min_expected_duration: float = 0.5
330
+ ) -> Tuple[bool, torch.Tensor, str]:
331
+ """Validate and fix generated audio.
332
+
333
+ Args:
334
+ audio: Audio tensor to validate
335
+ voice_name: Name of the voice used
336
+ sample_rate: Audio sample rate
337
+ min_expected_duration: Minimum expected duration in seconds
338
+
339
+ Returns:
340
+ Tuple of (is_valid, fixed_audio, message)
341
+ """
342
+ if audio is None:
343
+ return False, torch.zeros(1), "Audio is None"
344
+
345
+ # Check for NaN values
346
+ if torch.isnan(audio).any():
347
+ logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
348
+ audio = torch.where(torch.isnan(audio), torch.zeros_like(audio), audio)
349
+
350
+ # Check audio duration
351
+ duration = audio.shape[0] / sample_rate
352
+ if duration < min_expected_duration:
353
+ logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
354
+ return False, audio, f"Audio too short: {duration:.2f}s"
355
+
356
+ # Check for silent sections - this can indicate generation problems
357
+ rms = torch.sqrt(torch.mean(audio ** 2))
358
+ if rms < 0.01: # Very low RMS indicates near silence
359
+ logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
360
+ return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"
361
+
362
+ # Check if audio suddenly cuts off - this detects premature stopping
363
+ # Calculate RMS in the last 100ms
364
+ last_samples = int(0.1 * sample_rate)
365
+ if audio.shape[0] > last_samples:
366
+ end_rms = torch.sqrt(torch.mean(audio[-last_samples:] ** 2))
367
+ if end_rms > 0.1: # High RMS at the end suggests an abrupt cutoff
368
+ logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
369
+ return True, audio, "Audio may have cut off prematurely"
370
+
371
+ return True, audio, "Audio validation passed"
372
+
373
+ def create_voice_segments(app_state, regenerate: bool = False):
374
+ """Create high-quality voice reference segments.
375
+
376
+ Args:
377
+ app_state: Application state containing generator
378
+ regenerate: Whether to regenerate existing references
379
+ """
380
+ generator = app_state.generator
381
+ if not generator:
382
+ logger.error("Cannot create voice segments: generator not available")
383
+ return
384
+
385
+ # Use persistent directory for voice reference segments
386
+ os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
387
+
388
+ for voice_name, profile in VOICE_PROFILES.items():
389
+ voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
390
+ os.makedirs(voice_dir, exist_ok=True)
391
+
392
+ # Check if we already have references
393
+ if not regenerate and profile.reference_segments:
394
+ logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
395
+ continue
396
+
397
+ # Get prompts for this voice
398
+ prompts = VOICE_PROMPTS[voice_name]
399
+
400
+ # Generate reference segments
401
+ logger.info(f"Generating reference segments for voice: {voice_name}")
402
+ reference_segments = []
403
+
404
+ for i, prompt in enumerate(prompts):
405
+ ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
406
+
407
+ # Skip if file exists and we're not regenerating
408
+ if not regenerate and os.path.exists(ref_path):
409
+ try:
410
+ # Load existing reference
411
+ audio_tensor, sr = torchaudio.load(ref_path)
412
+ if sr != generator.sample_rate:
413
+ audio_tensor = torchaudio.functional.resample(
414
+ audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
415
+ )
416
+ else:
417
+ audio_tensor = audio_tensor.squeeze(0)
418
+ reference_segments.append(audio_tensor.to(generator.device))
419
+ logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
420
+ continue
421
+ except Exception as e:
422
+ logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")
423
+
424
+ try:
425
+ # Use a lower temperature for more stability in reference samples
426
+ logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")
427
+
428
+ # We want references to be as clean as possible
429
+ audio = generator.generate(
430
+ text=prompt,
431
+ speaker=profile.speaker_id,
432
+ context=[], # No context for initial samples to prevent voice bleed
433
+ max_audio_length_ms=6000, # Shorter for more control
434
+ temperature=0.7, # Lower temperature for more stability
435
+ topk=30, # More focused sampling
436
+ )
437
+
438
+ # Validate and enhance the audio
439
+ is_valid, audio, message = validate_generated_audio(
440
+ audio, voice_name, generator.sample_rate
441
+ )
442
+
443
+ if is_valid:
444
+ # Enhance the audio
445
+ audio = enhance_audio(audio, generator.sample_rate, profile)
446
+
447
+ # Save the reference to persistent storage
448
+ torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
449
+ reference_segments.append(audio)
450
+ logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
451
+ else:
452
+ logger.warning(f"Invalid reference for {voice_name}: {message}")
453
+ # Try again with different settings if invalid
454
+ if i < len(prompts) - 1:
455
+ logger.info("Trying again with the next prompt")
456
+ continue
457
+
458
+ except Exception as e:
459
+ logger.error(f"Error generating reference for {voice_name}: {e}")
460
+
461
+ # Update the voice profile with references
462
+ if reference_segments:
463
+ VOICE_PROFILES[voice_name].reference_segments = reference_segments
464
+ logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")
465
+
466
+ # Save the updated profiles to persistent storage
467
+ save_voice_profiles()
468
+
469
+ def get_voice_segments(voice_name: str, device: torch.device) -> List:
470
+ """Get context segments for a given voice.
471
+
472
+ Args:
473
+ voice_name: Name of the voice to use
474
+ device: Device to place tensors on
475
+
476
+ Returns:
477
+ List of context segments
478
+ """
479
+ from app.models import Segment
480
+
481
+ if voice_name not in VOICE_PROFILES:
482
+ logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
483
+ voice_name = "alloy"
484
+
485
+ profile = VOICE_PROFILES[voice_name]
486
+
487
+ # If we don't have reference segments yet, create them
488
+ if not profile.reference_segments:
489
+ try:
490
+ # Try to load from disk - use persistent storage
491
+ voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
492
+
493
+ if os.path.exists(voice_dir):
494
+ reference_segments = []
495
+ prompts = VOICE_PROMPTS[voice_name]
496
+
497
+ for i, prompt in enumerate(prompts):
498
+ ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
499
+ if os.path.exists(ref_path):
500
+ audio_tensor, sr = torchaudio.load(ref_path)
501
+ audio_tensor = audio_tensor.squeeze(0)
502
+ reference_segments.append(audio_tensor)
503
+
504
+ if reference_segments:
505
+ profile.reference_segments = reference_segments
506
+ logger.info(f"Loaded {len(reference_segments)} reference segments for {voice_name}")
507
+ except Exception as e:
508
+ logger.error(f"Error loading reference segments for {voice_name}: {e}")
509
+
510
+ # Create context segments from references
511
+ context = []
512
+ if profile.reference_segments:
513
+ for i, ref_audio in enumerate(profile.reference_segments):
514
+ # Use corresponding prompt if available, otherwise use a generic one
515
+ text = VOICE_PROMPTS[voice_name][i] if i < len(VOICE_PROMPTS[voice_name]) else f"Voice reference for {voice_name}"
516
+
517
+ context.append(
518
+ Segment(
519
+ speaker=profile.speaker_id,
520
+ text=text,
521
+ audio=ref_audio.to(device)
522
+ )
523
+ )
524
+
525
+ logger.info(f"Returning {len(context)} context segments for {voice_name}")
526
+ return context
527
+
528
+ def save_voice_profiles():
529
+ """Save voice profiles to persistent storage."""
530
+ os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
531
+ profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
532
+
533
+ # Create a serializable version of the profiles
534
+ serializable_profiles = {}
535
+ for name, profile in VOICE_PROFILES.items():
536
+ serializable_profiles[name] = {
537
+ 'reference_segments': [seg.cpu() for seg in profile.reference_segments]
538
+ }
539
+
540
+ # Save to persistent storage
541
+ torch.save(serializable_profiles, profile_path)
542
+ logger.info(f"Saved voice profiles to {profile_path}")
543
+
544
+ def process_generated_audio(
545
+ audio: torch.Tensor,
546
+ voice_name: str,
547
+ sample_rate: int,
548
+ text: str
549
+ ) -> torch.Tensor:
550
+ """Process generated audio for consistency and quality.
551
+
552
+ Args:
553
+ audio: Audio tensor
554
+ voice_name: Name of voice used
555
+ sample_rate: Audio sample rate
556
+ text: Text that was spoken
557
+
558
+ Returns:
559
+ Processed audio tensor
560
+ """
561
+ # Validate the audio
562
+ is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
563
+ if not is_valid:
564
+ logger.warning(f"Generated audio validation issue: {message}")
565
+
566
+ # Get voice profile for enhancement
567
+ profile = VOICE_PROFILES.get(voice_name, VOICE_PROFILES["alloy"])
568
+
569
+ # Enhance the audio based on voice profile
570
+ enhanced = enhance_audio(audio, sample_rate, profile)
571
+
572
+ # Log the enhancement
573
+ original_duration = audio.shape[0] / sample_rate
574
+ enhanced_duration = enhanced.shape[0] / sample_rate
575
+ logger.info(
576
+ f"Processed audio for '{voice_name}': "
577
+ f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
578
+ f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{enhanced.pow(2).mean().sqrt().item():.3f}"
579
+ )
580
+
581
+ return enhanced
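To see the two-stage normalization above in isolation (RMS gain followed by peak limiting), here is a small numeric sketch; the input level is arbitrary and chosen only so the gain stage visibly engages.

    import torch
    from app.voice_enhancement import normalize_audio

    audio = 0.05 * torch.randn(24000)       # a quiet 1 s clip at 24 kHz
    out = normalize_audio(audio, target_rms=0.2, target_peak=0.95)

    rms = out.pow(2).mean().sqrt().item()
    peak = out.abs().max().item()
    print(f"RMS ~ {rms:.2f}, peak {peak:.2f}")  # RMS near 0.2 unless the peak limiter engaged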
app/voice_memory.py ADDED
@@ -0,0 +1,474 @@
1
+ """Advanced voice memory system for consistent voice generation."""
2
+ import os
3
+ import torch
4
+ import torchaudio
5
+ import numpy as np
6
+ import random
7
+ import logging
8
+ from typing import Dict, List, Optional
9
+ from dataclasses import dataclass
10
+ from app.models import Segment
11
+
12
+ # Setup logging
13
+ logger = logging.getLogger(__name__)
14
+
15
+ # Path to store voice memories - use persistent location
16
+ VOICE_MEMORIES_DIR = "/app/voice_memories"
17
+ os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
18
+
19
+ @dataclass
20
+ class VoiceMemory:
21
+ """Store voice characteristics for consistent generation."""
22
+ name: str # Voice name (alloy, echo, etc.)
23
+ speaker_id: int # Speaker ID (0-5)
24
+ # Store multiple audio segments for context
25
+ audio_segments: List[torch.Tensor]
26
+ # Store text prompts that produced good results
27
+ text_segments: List[str]
28
+ # Base characteristics for this voice
29
+ pitch_base: float # Base pitch characteristic (Hz)
30
+ timbre: str # Voice quality descriptor
31
+
32
+ def get_context_segments(self, device: torch.device, max_segments: int = 2) -> List[Segment]:
33
+ """Get context segments for this voice."""
34
+ if not self.audio_segments:
35
+ return []
36
+
37
+ # Select a limited number of segments to avoid context overflow
38
+ num_segments = min(len(self.audio_segments), max_segments)
39
+ indices = list(range(len(self.audio_segments)))
40
+ random.shuffle(indices)
41
+ selected_indices = indices[:num_segments]
42
+
43
+ segments = []
44
+ for i in selected_indices:
45
+ segments.append(
46
+ Segment(
47
+ speaker=self.speaker_id,
48
+ text=self.text_segments[i] if i < len(self.text_segments) else f"Voice sample {i}",
49
+ audio=self.audio_segments[i].to(device)
50
+ )
51
+ )
52
+
53
+ return segments
54
+
55
+ def update_with_new_audio(self, audio: torch.Tensor, text: str, max_stored: int = 5):
56
+ """Update voice memory with newly generated audio."""
57
+ # Add new audio and text
58
+ self.audio_segments.append(audio.detach().cpu())
59
+ self.text_segments.append(text)
60
+
61
+ # Keep only the most recent segments
62
+ if len(self.audio_segments) > max_stored:
63
+ self.audio_segments = self.audio_segments[-max_stored:]
64
+ self.text_segments = self.text_segments[-max_stored:]
65
+
66
+ def save(self):
67
+ """Save voice memory to persistent storage."""
68
+ data = {
69
+ "name": self.name,
70
+ "speaker_id": self.speaker_id,
71
+ "audio_segments": self.audio_segments,
72
+ "text_segments": self.text_segments,
73
+ "pitch_base": self.pitch_base,
74
+ "timbre": self.timbre
75
+ }
76
+
77
+ # Save to the persistent directory
78
+ save_path = os.path.join(VOICE_MEMORIES_DIR, f"{self.name}.pt")
79
+ try:
80
+ torch.save(data, save_path)
81
+ logger.info(f"Saved voice memory for {self.name} to {save_path}")
82
+ except Exception as e:
83
+ logger.error(f"Error saving voice memory for {self.name}: {e}")
84
+
85
+ @classmethod
86
+ def load(cls, name: str) -> Optional['VoiceMemory']:
87
+ """Load voice memory from persistent storage."""
88
+ path = os.path.join(VOICE_MEMORIES_DIR, f"{name}.pt")
89
+ if not os.path.exists(path):
90
+ logger.info(f"No saved voice memory found for {name} at {path}")
91
+ return None
92
+
93
+ try:
94
+ data = torch.load(path)
95
+ return cls(
96
+ name=data["name"],
97
+ speaker_id=data["speaker_id"],
98
+ audio_segments=data["audio_segments"],
99
+ text_segments=data["text_segments"],
100
+ pitch_base=data["pitch_base"],
101
+ timbre=data["timbre"]
102
+ )
103
+ except Exception as e:
104
+ logger.error(f"Error loading voice memory for {name}: {e}")
105
+ return None
106
+
107
+ # Dictionary of voice memories
108
+ VOICE_MEMORIES: Dict[str, VoiceMemory] = {}
109
+
110
+ # Voice characteristics
111
+ VOICE_CHARACTERISTICS = {
112
+ "alloy": {"pitch": 220.0, "timbre": "balanced", "description": "A balanced, natural voice with medium pitch"},
113
+ "echo": {"pitch": 330.0, "timbre": "resonant", "description": "A resonant voice with a reverberant quality"},
114
+ "fable": {"pitch": 523.0, "timbre": "bright", "description": "A bright, higher-pitched voice with clear articulation"},
115
+ "onyx": {"pitch": 165.0, "timbre": "deep", "description": "A deep, authoritative voice with lower pitch"},
116
+ "nova": {"pitch": 392.0, "timbre": "warm", "description": "A warm, smooth voice with pleasant midrange tone"},
117
+ "shimmer": {"pitch": 587.0, "timbre": "light", "description": "A light, airy voice with higher frequencies"}
118
+ }
119
+
120
+ # Voice intro texts - carefully crafted to capture voice characteristics
121
+ VOICE_INTROS = {
122
+ "alloy": [
123
+ "Hello, I'm Alloy. My voice is designed to be clear and balanced.",
124
+ "This is the Alloy voice. I aim to sound natural and easy to understand.",
125
+ "Welcome, I'm the voice known as Alloy. I have a balanced, medium-range tone."
126
+ ],
127
+ "echo": [
128
+ "Hello, I'm Echo. My voice has a rich, resonant quality.",
129
+ "This is the Echo voice. Notice my distinctive resonance and depth.",
130
+ "Welcome, I'm the voice known as Echo. My tone is designed to resonate clearly."
131
+ ],
132
+ "fable": [
133
+ "Hello, I'm Fable. My voice is bright and articulate.",
134
+ "This is the Fable voice. I have a higher pitch with clear pronunciation.",
135
+ "Welcome, I'm the voice known as Fable. I speak with a bright, energetic tone."
136
+ ],
137
+ "onyx": [
138
+ "Hello, I'm Onyx. My voice is deep and authoritative.",
139
+ "This is the Onyx voice. I speak with a lower pitch and commanding presence.",
140
+ "Welcome, I'm the voice known as Onyx. My tone is deep and resonant."
141
+ ],
142
+ "nova": [
143
+ "Hello, I'm Nova. My voice is warm and harmonious.",
144
+ "This is the Nova voice. I have a smooth, pleasant mid-range quality.",
145
+ "Welcome, I'm the voice known as Nova. I speak with a warm, friendly tone."
146
+ ],
147
+ "shimmer": [
148
+ "Hello, I'm Shimmer. My voice is light and expressive.",
149
+ "This is the Shimmer voice. I have a higher-pitched, airy quality.",
150
+ "Welcome, I'm the voice known as Shimmer. My tone is bright and crisp."
151
+ ]
152
+ }
153
+
154
+ def initialize_voices(sample_rate: int = 24000):
155
+ """Initialize voice memories with consistent base samples."""
156
+ global VOICE_MEMORIES
157
+
158
+ # Check if persistent directory exists, create if needed
159
+ os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
160
+ logger.info(f"Using voice memories directory: {VOICE_MEMORIES_DIR}")
161
+
162
+ # First try to load existing memories from persistent storage
163
+ for voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
164
+ memory = VoiceMemory.load(voice_name)
165
+ if memory:
166
+ VOICE_MEMORIES[voice_name] = memory
167
+ logger.info(f"Loaded existing voice memory for {voice_name} with {len(memory.audio_segments)} segments")
168
+ continue
169
+
170
+ # If no memory exists, create a new one
171
+ speaker_id = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"].index(voice_name)
172
+ characteristics = VOICE_CHARACTERISTICS[voice_name]
173
+
174
+ # Create deterministic seed audio
175
+ np.random.seed(speaker_id + 42)
176
+ duration = 1.0 # seconds
177
+ t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
178
+
179
+ # Create characteristic waveform
180
+ pitch = characteristics["pitch"]
181
+ if voice_name == "alloy":
182
+ audio = 0.5 * np.sin(2 * np.pi * pitch * t) + 0.3 * np.sin(2 * np.pi * pitch * 2 * t)
183
+ elif voice_name == "echo":
184
+ audio = np.sin(2 * np.pi * pitch * t) * np.exp(-t * 3)
185
+ elif voice_name == "fable":
186
+ audio = 0.7 * np.sin(2 * np.pi * pitch * t)
187
+ elif voice_name == "onyx":
188
+ audio = 0.8 * np.sin(2 * np.pi * pitch * t) + 0.1 * np.sin(2 * np.pi * pitch * 0.5 * t)
189
+ elif voice_name == "nova":
190
+ audio = 0.4 * np.sin(2 * np.pi * pitch * t) + 0.4 * np.sin(2 * np.pi * pitch * 0.5 * t)
191
+ else: # shimmer
192
+ audio = 0.3 * np.sin(2 * np.pi * pitch * t) + 0.2 * np.sin(2 * np.pi * pitch * 1.5 * t) + 0.1 * np.sin(2 * np.pi * pitch * 2 * t)
193
+
194
+ # Normalize
195
+ audio = audio / np.max(np.abs(audio))
196
+
197
+ # Convert to tensor
198
+ audio_tensor = torch.tensor(audio, dtype=torch.float32)
199
+
200
+ # Create voice memory
201
+ memory = VoiceMemory(
202
+ name=voice_name,
203
+ speaker_id=speaker_id,
204
+ audio_segments=[audio_tensor],
205
+ text_segments=[f"This is the voice of {voice_name}"],
206
+ pitch_base=characteristics["pitch"],
207
+ timbre=characteristics["timbre"]
208
+ )
209
+
210
+ # Save the voice memory to persistent storage
211
+ memory.save()
212
+
213
+ # Store in dictionary
214
+ VOICE_MEMORIES[voice_name] = memory
215
+
216
+ # Save as wav for reference
217
+ save_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_seed.wav")
218
+ torchaudio.save(save_path, audio_tensor.unsqueeze(0), sample_rate)
219
+
220
+ logger.info(f"Initialized new voice memory for {voice_name}")
221
+
222
+ def get_voice_context(voice_name: str, device: torch.device, max_segments: int = 2) -> List[Segment]:
223
+ """Get context segments for a given voice."""
224
+ if not VOICE_MEMORIES:
225
+ initialize_voices()
226
+
227
+ if voice_name in VOICE_MEMORIES:
228
+ return VOICE_MEMORIES[voice_name].get_context_segments(device, max_segments=max_segments)
229
+
230
+ # Default to alloy if voice not found
231
+ logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
232
+ return VOICE_MEMORIES["alloy"].get_context_segments(device, max_segments=max_segments)
233
+
234
+ def update_voice_memory(voice_name: str, audio: torch.Tensor, text: str):
235
+ """Update voice memory with newly generated audio and save to persistent storage."""
236
+ if not VOICE_MEMORIES:
237
+ initialize_voices()
238
+
239
+ if voice_name in VOICE_MEMORIES:
240
+ VOICE_MEMORIES[voice_name].update_with_new_audio(audio, text)
241
+ VOICE_MEMORIES[voice_name].save()
242
+ logger.info(f"Updated voice memory for {voice_name}, now has {len(VOICE_MEMORIES[voice_name].audio_segments)} segments")
243
+
244
+ def generate_voice_samples(app_state):
245
+ """Generate high-quality voice samples for each voice.
246
+
247
+ Args:
248
+ app_state: The FastAPI app state containing the generator
249
+ """
250
+ generator = app_state.generator
251
+ if not generator:
252
+ logger.error("Cannot generate voice samples: generator not available")
253
+ return
254
+
255
+ logger.info("Beginning voice sample generation...")
256
+
257
+ # Ensure persistent directory exists
258
+ os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
259
+
260
+ for voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
261
+ speaker_id = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"].index(voice_name)
262
+
263
+ # Get multiple sample texts for this voice
264
+ sample_texts = VOICE_INTROS[voice_name]
265
+
266
+ # Generate a collection of samples for this voice
267
+ logger.info(f"Generating samples for voice: {voice_name}")
268
+ audio_segments = []
269
+ text_segments = []
270
+
271
+ for i, sample_text in enumerate(sample_texts):
272
+ try:
273
+ # Check if we already have a sample
274
+ sample_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_sample_{i}.wav")
275
+ if os.path.exists(sample_path):
276
+ logger.info(f"Found existing sample {i+1} for {voice_name}, loading from {sample_path}")
277
+ audio_tensor, sr = torchaudio.load(sample_path)
278
+ if sr != generator.sample_rate:
279
+ audio_tensor = torchaudio.functional.resample(
280
+ audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
281
+ )
282
+ else:
283
+ audio_tensor = audio_tensor.squeeze(0)
284
+ audio_segments.append(audio_tensor)
285
+ text_segments.append(sample_text)
286
+ continue
287
+
288
+ # Generate without context first for seed samples
289
+ logger.info(f"Generating sample {i+1}/{len(sample_texts)} for {voice_name}: '{sample_text}'")
290
+
291
+ # Use a lower temperature for more stable output
292
+ audio = generator.generate(
293
+ text=sample_text,
294
+ speaker=speaker_id,
295
+ context=[], # No context for initial samples
296
+ max_audio_length_ms=10000,
297
+ temperature=0.7, # Lower temperature for more stable output
298
+ topk=30,
299
+ )
300
+
301
+ # Save this segment
302
+ audio_segments.append(audio.detach().cpu())
303
+ text_segments.append(sample_text)
304
+
305
+ # Save as WAV for reference to persistent storage
306
+ torchaudio.save(sample_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
307
+
308
+ logger.info(f"Generated sample {i+1} for {voice_name}, length: {audio.shape[0]/generator.sample_rate:.2f}s")
309
+
310
+ except Exception as e:
311
+ logger.error(f"Error generating sample {i+1} for {voice_name}: {e}")
312
+
313
+ # Use the generated samples to update the voice memory
314
+ if voice_name in VOICE_MEMORIES and audio_segments:
315
+ # Replace existing samples with these high quality ones
316
+ VOICE_MEMORIES[voice_name].audio_segments = audio_segments
317
+ VOICE_MEMORIES[voice_name].text_segments = text_segments
318
+ VOICE_MEMORIES[voice_name].save()
319
+
320
+ logger.info(f"Updated voice memory for {voice_name} with {len(audio_segments)} high-quality samples")
321
+
322
+ # Now generate a second pass with context from these samples
323
+ if len(audio_segments) >= 2:
324
+ try:
325
+ # Check if we already have a character sample
326
+ character_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_character.wav")
327
+ if os.path.exists(character_path):
328
+ logger.info(f"Found existing character sample for {voice_name}, loading from {character_path}")
329
+ audio_tensor, sr = torchaudio.load(character_path)
330
+ if sr != generator.sample_rate:
331
+ audio_tensor = torchaudio.functional.resample(
332
+ audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
333
+ )
334
+ else:
335
+ audio_tensor = audio_tensor.squeeze(0)
336
+
337
+ character_sample_text = f"I'm the voice assistant known as {voice_name}. I'm designed to have a distinctive voice that you can easily recognize."
338
+ VOICE_MEMORIES[voice_name].audio_segments.append(audio_tensor)
339
+ VOICE_MEMORIES[voice_name].text_segments.append(character_sample_text)
340
+ VOICE_MEMORIES[voice_name].save()
341
+ continue
342
+
343
+ # Get intro and conclusion prompts that build voice consistency
344
+ context = [
345
+ Segment(
346
+ speaker=speaker_id,
347
+ text=text_segments[0],
348
+ audio=audio_segments[0].to(generator.device)
349
+ )
350
+ ]
351
+
352
+ # Create a longer sample with the voice characteristics now established
353
+ character_sample_text = f"I'm the voice assistant known as {voice_name}. I'm designed to have a distinctive voice that you can easily recognize. My speech patterns and tone should remain consistent throughout our conversation."
354
+
355
+ logger.info(f"Generating character sample for {voice_name} with context")
356
+ character_audio = generator.generate(
357
+ text=character_sample_text,
358
+ speaker=speaker_id,
359
+ context=context,
360
+ max_audio_length_ms=15000,
361
+ temperature=0.7,
362
+ topk=30,
363
+ )
364
+
365
+ # Save this comprehensive character sample to persistent storage
366
+ torchaudio.save(character_path, character_audio.unsqueeze(0).cpu(), generator.sample_rate)
367
+
368
+ # Add this to the memory as well
369
+ VOICE_MEMORIES[voice_name].audio_segments.append(character_audio.detach().cpu())
370
+ VOICE_MEMORIES[voice_name].text_segments.append(character_sample_text)
371
+ VOICE_MEMORIES[voice_name].save()
372
+
373
+ logger.info(f"Generated character sample for {voice_name}, length: {character_audio.shape[0]/generator.sample_rate:.2f}s")
374
+
375
+ except Exception as e:
376
+ logger.error(f"Error generating character sample for {voice_name}: {e}")
377
+
378
+ def create_custom_voice(
379
+ app_state,
380
+ name: str,
381
+ initial_text: str,
382
+ speaker_id: int = 0,
383
+ pitch: Optional[float] = None,
384
+ timbre: str = "custom"
385
+ ) -> Dict:
386
+ """Create a new custom voice.
387
+
388
+ Args:
389
+ app_state: The FastAPI app state containing the generator
390
+ name: Name for the new voice
391
+ initial_text: Text for the initial voice sample
392
+ speaker_id: Base speaker ID (0-5)
393
+ pitch: Base pitch in Hz (optional)
394
+ timbre: Voice quality descriptor
395
+
396
+ Returns:
397
+ Dict with creation status and voice info
398
+ """
399
+ generator = app_state.generator
400
+ if not generator:
401
+ return {"status": "error", "message": "Generator not available"}
402
+
403
+ # Check if voice already exists
404
+ if not VOICE_MEMORIES:
405
+ initialize_voices()
406
+
407
+ if name in VOICE_MEMORIES:
408
+ return {"status": "error", "message": f"Voice '{name}' already exists"}
409
+
410
+ # Generate a voice sample
411
+ try:
412
+ logger.info(f"Creating custom voice '{name}' with text: '{initial_text}'")
413
+
414
+ audio = generator.generate(
415
+ text=initial_text,
416
+ speaker=speaker_id,
417
+ context=[],
418
+ max_audio_length_ms=10000,
419
+ temperature=0.7,
420
+ )
421
+
422
+ # Determine base pitch if not provided
423
+ if pitch is None:
424
+ if speaker_id == 0: # alloy
425
+ pitch = 220.0
426
+ elif speaker_id == 1: # echo
427
+ pitch = 330.0
428
+ elif speaker_id == 2: # fable
429
+ pitch = 523.0
430
+ elif speaker_id == 3: # onyx
431
+ pitch = 165.0
432
+ elif speaker_id == 4: # nova
433
+ pitch = 392.0
434
+ else: # shimmer
435
+ pitch = 587.0
436
+
437
+ # Create a new voice memory
438
+ memory = VoiceMemory(
439
+ name=name,
440
+ speaker_id=speaker_id,
441
+ audio_segments=[audio.detach().cpu()],
442
+ text_segments=[initial_text],
443
+ pitch_base=pitch,
444
+ timbre=timbre
445
+ )
446
+
447
+ # Save the voice memory to persistent storage
448
+ memory.save()
449
+ VOICE_MEMORIES[name] = memory
450
+
451
+ # Save sample as WAV for reference to persistent storage
452
+ sample_path = os.path.join(VOICE_MEMORIES_DIR, f"{name}_sample.wav")
453
+ torchaudio.save(sample_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
454
+
455
+ logger.info(f"Created custom voice '{name}' successfully")
456
+
457
+ return {
458
+ "status": "success",
459
+ "message": f"Voice '{name}' created successfully",
460
+ "voice": {
461
+ "name": name,
462
+ "speaker_id": speaker_id,
463
+ "pitch": pitch,
464
+ "timbre": timbre,
465
+ "sample_length_seconds": audio.shape[0] / generator.sample_rate
466
+ }
467
+ }
468
+
469
+ except Exception as e:
470
+ logger.error(f"Error creating custom voice '{name}': {e}")
471
+ return {
472
+ "status": "error",
473
+ "message": f"Error creating voice: {str(e)}"
474
+ }
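A hedged sketch of the intended round trip through app/voice_memory.py as defined above: seed or load the memories, fetch context segments for generation, then fold the generated audio back into the memory. The generator call is commented out because its construction lives elsewhere in this commit; the zero tensor merely stands in for its output.

    import torch
    from app.voice_memory import initialize_voices, get_voice_context, update_voice_memory

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    initialize_voices()                             # load persisted memories or create seeds

    context = get_voice_context("nova", device)     # up to 2 Segment objects for conditioning
    # audio = generator.generate(text="Hi there", speaker=4, context=context, ...)
    audio = torch.zeros(24000)                      # placeholder for the generated audio
    update_voice_memory("nova", audio, "Hi there")  # appends the segment and saves to disk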
app/watermarking.py ADDED
@@ -0,0 +1,22 @@
1
+ """Stub for watermarking module.
2
+
3
+ The original CSM code has a watermarking module that adds
4
+ an imperceptible watermark to generated audio.
5
+ """
6
+
7
+ # Watermark key used by CSM
8
+ CSM_1B_GH_WATERMARK = "CSM1B@GitHub"
9
+
10
+ def load_watermarker(device="cpu"):
11
+ """Stub for watermarker loading.
12
+
13
+ In a real implementation, this would load the actual watermarker.
14
+ """
15
+ return None
16
+
17
+ def watermark(watermarker, audio, sample_rate, key):
18
+ """Stub for watermarking function.
19
+
20
+ In a real implementation, this would add the watermark.
21
+ """
22
+ return audio, sample_rate
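Because the module above is a stub, callers can rely only on its pass-through behaviour. A minimal example of the expected calling convention:

    import torch
    from app.watermarking import load_watermarker, watermark, CSM_1B_GH_WATERMARK

    watermarker = load_watermarker(device="cpu")    # returns None in this stub
    audio = torch.zeros(24000)
    audio, sample_rate = watermark(watermarker, audio, 24000, CSM_1B_GH_WATERMARK)
    # The stub returns the audio unchanged; a real implementation would embed the key.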
docker-compose.yml ADDED
@@ -0,0 +1,23 @@
1
+ services:
2
+ csm-api:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile
6
+ args:
7
+ - HF_TOKEN=${HF_TOKEN}
8
+ ports:
9
+ - "8000:8000"
10
+ volumes:
11
+ - ./models:/app/models
12
+ - ./cloned_voices:/app/cloned_voices
13
+ - ./voice_references:/app/voice_references
14
+ - ./voice_profiles:/app/voice_profiles
15
+ environment:
16
+ - HF_TOKEN=${HF_TOKEN}
17
+ deploy:
18
+ resources:
19
+ reservations:
20
+ devices:
21
+ - driver: nvidia
22
+ count: all
23
+ capabilities: [gpu]
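The deploy block reserves all NVIDIA GPUs for the csm-api service. A quick sanity check that CUDA is actually visible from inside the container; how the check is invoked (for example via docker compose exec) is left to the operator.

    import torch

    # Run inside the csm-api container to confirm the GPU reservation took effect
    print("CUDA available:", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("Device:", torch.cuda.get_device_name(0))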
requirements.txt ADDED
@@ -0,0 +1,24 @@
1
+ fastapi>=0.100.0
2
+ uvicorn>=0.24.0
3
+ pydantic>=2.4.0
4
+ python-multipart>=0.0.6
5
+ huggingface_hub>=0.20.0
6
+ torch>=2.0.0
7
+ torchaudio>=2.0.0
8
+ transformers>=4.35.0
9
+ tokenizers>=0.14.0
10
+ sentencepiece>=0.1.99
11
+ triton>=2.1.0; platform_system != "Windows"
12
+ triton-windows>=2.1.0; platform_system == "Windows"
13
+ torchao>=0.1.0
14
+ torchtune>=0.5.0
15
+ numpy>=1.25.0
16
+ ffmpeg-python>=0.2.0
17
+ moshi>=0.1.0
18
+ soundfile>=0.12.1
19
+ scipy>=1.10.0
20
+ librosa>=0.10.0
21
+ yt-dlp>=2023.3.4
22
+ openai-whisper>=20230314
23
+ ffmpeg-python>=0.2.0
24
+ accelerate>=0.20.0
static/voice-cloning.html ADDED
@@ -0,0 +1,1385 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en" data-theme="light">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
6
+ <title>Voice Cloning</title>
7
+ <style>
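+ /* Theme variables: light values are defined on :root and overridden under
+    [data-theme="dark"]; the script at the bottom of the page switches the
+    data-theme attribute on <html> and persists the choice in localStorage. */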
8
+ :root {
9
+ /* Base colors */
10
+ --bg-color: #ffffff;
11
+ --text-color: #333333;
12
+ --card-bg: #ffffff;
13
+ --card-border: #e2e8f0;
14
+ --card-shadow: rgba(0,0,0,0.1);
15
+
16
+ /* Primary colors */
17
+ --primary-color: #4f46e5;
18
+ --primary-hover: #4338ca;
19
+ --primary-light: rgba(79, 70, 229, 0.1);
20
+
21
+ /* Secondary colors */
22
+ --secondary-color: #6b7280;
23
+ --secondary-hover: #4b5563;
24
+
25
+ /* Accent colors */
26
+ --success-color: #10b981;
27
+ --success-hover: #059669;
28
+ --danger-color: #ef4444;
29
+ --danger-hover: #dc2626;
30
+ --warning-color: #f59e0b;
31
+ --warning-hover: #d97706;
32
+
33
+ /* Input elements */
34
+ --input-bg: #ffffff;
35
+ --input-border: #d1d5db;
36
+ --input-focus-border: #4f46e5;
37
+ --input-focus-shadow: rgba(79, 70, 229, 0.2);
38
+
39
+ /* Voice cards */
40
+ --voice-card-bg: #f9fafb;
41
+ --voice-card-border: #e5e7eb;
42
+ --voice-card-shadow: rgba(0,0,0,0.05);
43
+
44
+ /* Tabs */
45
+ --tab-border: #e5e7eb;
46
+ --tab-text: #4b5563;
47
+ --tab-active: #4f46e5;
48
+ --tab-active-bg: rgba(79, 70, 229, 0.1);
49
+
50
+ /* Toggle */
51
+ --toggle-bg: #e5e7eb;
52
+ --toggle-active: #4f46e5;
53
+ --toggle-circle: #ffffff;
54
+
55
+ /* Status indicators */
56
+ --status-success-bg: #dcfce7;
57
+ --status-success-text: #166534;
58
+ --status-error-bg: #fee2e2;
59
+ --status-error-text: #b91c1c;
60
+ --status-warning-bg: #fff7ed;
61
+ --status-warning-text: #c2410c;
62
+ --status-info-bg: #eff6ff;
63
+ --status-info-text: #1e40af;
64
+ }
65
+
66
+ [data-theme="dark"] {
67
+ /* Base colors */
68
+ --bg-color: #111827;
69
+ --text-color: #f3f4f6;
70
+ --card-bg: #1f2937;
71
+ --card-border: #374151;
72
+ --card-shadow: rgba(0,0,0,0.3);
73
+
74
+ /* Primary colors */
75
+ --primary-color: #6366f1;
76
+ --primary-hover: #4f46e5;
77
+ --primary-light: rgba(99, 102, 241, 0.2);
78
+
79
+ /* Secondary colors */
80
+ --secondary-color: #9ca3af;
81
+ --secondary-hover: #6b7280;
82
+
83
+ /* Accent colors - slightly brighter for dark theme */
84
+ --success-color: #34d399;
85
+ --success-hover: #10b981;
86
+ --danger-color: #f87171;
87
+ --danger-hover: #ef4444;
88
+ --warning-color: #fbbf24;
89
+ --warning-hover: #f59e0b;
90
+
91
+ /* Input elements */
92
+ --input-bg: #374151;
93
+ --input-border: #4b5563;
94
+ --input-focus-border: #6366f1;
95
+ --input-focus-shadow: rgba(99, 102, 241, 0.3);
96
+
97
+ /* Voice cards */
98
+ --voice-card-bg: #1f2937;
99
+ --voice-card-border: #374151;
100
+ --voice-card-shadow: rgba(0,0,0,0.2);
101
+
102
+ /* Tabs */
103
+ --tab-border: #374151;
104
+ --tab-text: #9ca3af;
105
+ --tab-active: #6366f1;
106
+ --tab-active-bg: rgba(99, 102, 241, 0.2);
107
+
108
+ /* Toggle */
109
+ --toggle-bg: #4b5563;
110
+ --toggle-active: #6366f1;
111
+ --toggle-circle: #e5e7eb;
112
+
113
+ /* Status indicators */
114
+ --status-success-bg: rgba(16, 185, 129, 0.2);
115
+ --status-success-text: #34d399;
116
+ --status-error-bg: rgba(239, 68, 68, 0.2);
117
+ --status-error-text: #f87171;
118
+ --status-warning-bg: rgba(245, 158, 11, 0.2);
119
+ --status-warning-text: #fbbf24;
120
+ --status-info-bg: rgba(59, 130, 246, 0.2);
121
+ --status-info-text: #60a5fa;
122
+ }
123
+
124
+ * {
125
+ box-sizing: border-box;
126
+ }
127
+
128
+ html, body {
129
+ margin: 0;
130
+ padding: 0;
131
+ overflow-x: hidden;
132
+ width: 100%;
133
+ }
134
+
135
+ body {
136
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
137
+ max-width: 100%;
138
+ padding: 16px;
139
+ line-height: 1.6;
140
+ background-color: var(--bg-color);
141
+ color: var(--text-color);
142
+ transition: background-color 0.3s, color 0.3s;
143
+ }
144
+
145
+ .container {
146
+ width: 100%;
147
+ max-width: 800px;
148
+ margin: 0 auto;
149
+ padding: 0 8px;
150
+ }
151
+
152
+ h1, h2, h3 {
153
+ color: var(--text-color);
154
+ margin-top: 0;
155
+ }
156
+
157
+ h1 {
158
+ font-size: 1.8rem;
159
+ margin-bottom: 1rem;
160
+ text-align: center;
161
+ }
162
+
163
+ h2 {
164
+ font-size: 1.4rem;
165
+ margin-bottom: 1rem;
166
+ }
167
+
168
+ .card {
169
+ border: 1px solid var(--card-border);
170
+ border-radius: 12px;
171
+ padding: 20px;
172
+ margin-bottom: 20px;
173
+ box-shadow: 0 4px 6px var(--card-shadow);
174
+ background-color: var(--card-bg);
175
+ transition: box-shadow 0.3s ease;
176
+ }
177
+
178
+ .card:hover {
179
+ box-shadow: 0 6px 12px var(--card-shadow);
180
+ }
181
+
182
+ .form-group {
183
+ margin-bottom: 20px;
184
+ }
185
+
186
+ label {
187
+ display: block;
188
+ margin-bottom: 6px;
189
+ font-weight: 500;
190
+ font-size: 0.95rem;
191
+ }
192
+
193
+ input, textarea, select {
194
+ width: 100%;
195
+ padding: 10px 12px;
196
+ border: 1px solid var(--input-border);
197
+ border-radius: 8px;
198
+ background-color: var(--input-bg);
199
+ color: var(--text-color);
200
+ font-size: 1rem;
201
+ transition: border-color 0.3s, box-shadow 0.3s;
202
+ }
203
+
204
+ input:focus, textarea:focus, select:focus {
205
+ outline: none;
206
+ border-color: var(--input-focus-border);
207
+ box-shadow: 0 0 0 3px var(--input-focus-shadow);
208
+ }
209
+
210
+ /* File input styling */
211
+ input[type="file"] {
212
+ padding: 8px;
213
+ background-color: var(--input-bg);
214
+ border: 1px dashed var(--input-border);
215
+ border-radius: 8px;
216
+ cursor: pointer;
217
+ }
218
+
219
+ input[type="file"]:hover {
220
+ border-color: var(--primary-color);
221
+ }
222
+
223
+ button {
224
+ background-color: var(--primary-color);
225
+ color: white;
226
+ border: none;
227
+ border-radius: 8px;
228
+ padding: 12px 20px;
229
+ cursor: pointer;
230
+ font-weight: 600;
231
+ font-size: 1rem;
232
+ transition: background-color 0.3s, transform 0.1s;
233
+ width: 100%;
234
+ }
235
+
236
+ button:hover {
237
+ background-color: var(--primary-hover);
238
+ }
239
+
240
+ button:active {
241
+ transform: translateY(1px);
242
+ }
243
+
244
+ button:disabled {
245
+ opacity: 0.7;
246
+ cursor: not-allowed;
247
+ }
248
+
249
+ .btn-row {
250
+ display: flex;
251
+ gap: 10px;
252
+ }
253
+
254
+ .btn-row button {
255
+ flex: 1;
256
+ }
257
+
258
+ .voice-list {
259
+ display: grid;
260
+ grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
261
+ gap: 16px;
262
+ }
263
+
264
+ .voice-card {
265
+ border: 1px solid var(--voice-card-border);
266
+ border-radius: 10px;
267
+ padding: 16px;
268
+ background-color: var(--voice-card-bg);
269
+ box-shadow: 0 2px 6px var(--voice-card-shadow);
270
+ transition: transform 0.2s ease, box-shadow 0.2s ease;
271
+ }
272
+
273
+ .voice-card:hover {
274
+ transform: translateY(-3px);
275
+ box-shadow: 0 6px 12px var(--voice-card-shadow);
276
+ }
277
+
278
+ .controls {
279
+ display: flex;
280
+ gap: 8px;
281
+ margin-top: 12px;
282
+ }
283
+
284
+ .voice-name {
285
+ font-weight: 600;
286
+ font-size: 18px;
287
+ margin: 0 0 8px 0;
288
+ color: var(--primary-color);
289
+ }
290
+
291
+ .btn-danger {
292
+ background-color: var(--danger-color);
293
+ }
294
+
295
+ .btn-danger:hover {
296
+ background-color: var(--danger-hover);
297
+ }
298
+
299
+ .btn-secondary {
300
+ background-color: var(--secondary-color);
301
+ }
302
+
303
+ .btn-secondary:hover {
304
+ background-color: var(--secondary-hover);
305
+ }
306
+
307
+ #audio-preview {
308
+ margin-top: 20px;
309
+ width: 100%;
310
+ border-radius: 8px;
311
+ background-color: var(--card-bg);
312
+ }
313
+
314
+ /* Status indicators */
315
+ .status-indicator {
316
+ padding: 12px 16px;
317
+ margin: 16px 0;
318
+ border-radius: 8px;
319
+ font-weight: 500;
320
+ display: flex;
321
+ align-items: center;
322
+ opacity: 0;
323
+ transition: opacity 0.3s ease;
324
+ max-height: 0;
325
+ overflow: hidden;
326
+ }
327
+
328
+ .status-indicator.show {
329
+ opacity: 1;
330
+ max-height: 100px;
331
+ }
332
+
333
+ .status-indicator.success {
334
+ background-color: var(--status-success-bg);
335
+ color: var(--status-success-text);
336
+ }
337
+
338
+ .status-indicator.error {
339
+ background-color: var(--status-error-bg);
340
+ color: var(--status-error-text);
341
+ }
342
+
343
+ .status-indicator.warning {
344
+ background-color: var(--status-warning-bg);
345
+ color: var(--status-warning-text);
346
+ }
347
+
348
+ .status-indicator.info {
349
+ background-color: var(--status-info-bg);
350
+ color: var(--status-info-text);
351
+ }
352
+
353
+ .status-indicator svg {
354
+ margin-right: 8px;
355
+ flex-shrink: 0;
356
+ }
357
+
358
+ /* Tabs */
359
+ .tabs {
360
+ display: flex;
361
+ margin-bottom: 20px;
362
+ border-bottom: 1px solid var(--tab-border);
363
+ overflow-x: auto;
364
+ -webkit-overflow-scrolling: touch;
365
+ scrollbar-width: none; /* Hide scrollbar for Firefox */
366
+ }
367
+
368
+ .tabs::-webkit-scrollbar {
369
+ display: none; /* Hide scrollbar for Chrome/Safari */
370
+ }
371
+
372
+ .tabs button {
373
+ background-color: transparent;
374
+ color: var(--tab-text);
375
+ border: none;
376
+ padding: 12px 16px;
377
+ margin-right: 8px;
378
+ cursor: pointer;
379
+ position: relative;
380
+ font-weight: 600;
381
+ border-radius: 8px 8px 0 0;
382
+ white-space: nowrap;
383
+ width: auto;
384
+ flex-shrink: 0;
385
+ }
386
+
387
+ .tabs button.active {
388
+ color: var(--tab-active);
389
+ background-color: var(--tab-active-bg);
390
+ }
391
+
392
+ .tabs button.active::after {
393
+ content: '';
394
+ position: absolute;
395
+ bottom: -1px;
396
+ left: 0;
397
+ right: 0;
398
+ height: 2px;
399
+ background-color: var(--tab-active);
400
+ }
401
+
402
+ .tab-content {
403
+ display: none;
404
+ }
405
+
406
+ .tab-content.active {
407
+ display: block;
408
+ }
409
+
410
+ /* Theme toggle switch */
411
+ .theme-switch-wrapper {
412
+ display: flex;
413
+ align-items: center;
414
+ justify-content: flex-end;
415
+ margin-bottom: 16px;
416
+ }
417
+
418
+ .theme-switch {
419
+ display: inline-block;
420
+ height: 28px;
421
+ position: relative;
422
+ width: 54px;
423
+ }
424
+
425
+ .theme-switch input {
426
+ display: none;
427
+ }
428
+
429
+ .slider {
430
+ background-color: var(--toggle-bg);
431
+ bottom: 0;
432
+ cursor: pointer;
433
+ left: 0;
434
+ position: absolute;
435
+ right: 0;
436
+ top: 0;
437
+ transition: .4s;
438
+ border-radius: 28px;
439
+ }
440
+
441
+ .slider:before {
442
+ background-color: var(--toggle-circle);
443
+ bottom: 4px;
444
+ content: "";
445
+ height: 20px;
446
+ left: 4px;
447
+ position: absolute;
448
+ transition: .4s;
449
+ width: 20px;
450
+ border-radius: 50%;
451
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
452
+ }
453
+
454
+ input:checked + .slider {
455
+ background-color: var(--toggle-active);
456
+ }
457
+
458
+ input:checked + .slider:before {
459
+ transform: translateX(26px);
460
+ }
461
+
462
+ .theme-icon {
463
+ width: 16px;
464
+ height: 16px;
465
+ display: inline-block;
466
+ margin: 0 8px;
467
+ font-size: 16px;
468
+ line-height: 1;
469
+ }
470
+
471
+ /* Progress bar */
472
+ .progress-bar {
473
+ width: 100%;
474
+ height: 12px;
475
+ background-color: var(--input-bg);
476
+ border-radius: 6px;
477
+ margin-top: 12px;
478
+ overflow: hidden;
479
+ box-shadow: inset 0 1px 3px var(--card-shadow);
480
+ }
481
+
482
+ .progress-fill {
483
+ height: 100%;
484
+ background: linear-gradient(to right, var(--primary-color), var(--primary-hover));
485
+ width: 0%;
486
+ transition: width 0.5s ease-in-out;
487
+ border-radius: 6px;
488
+ position: relative;
489
+ }
490
+
491
+ .progress-fill::after {
492
+ content: '';
493
+ position: absolute;
494
+ top: 0;
495
+ left: 0;
496
+ right: 0;
497
+ bottom: 0;
498
+ background: linear-gradient(
499
+ -45deg,
500
+ rgba(255, 255, 255, 0.2) 25%,
501
+ transparent 25%,
502
+ transparent 50%,
503
+ rgba(255, 255, 255, 0.2) 50%,
504
+ rgba(255, 255, 255, 0.2) 75%,
505
+ transparent 75%
506
+ );
507
+ background-size: 16px 16px;
508
+ animation: progress-animation 1s linear infinite;
509
+ border-radius: 6px;
510
+ }
511
+
512
+ @keyframes progress-animation {
513
+ 0% {
514
+ background-position: 0 0;
515
+ }
516
+ 100% {
517
+ background-position: 16px 0;
518
+ }
519
+ }
520
+
521
+ /* Divider */
522
+ .divider {
523
+ margin: 24px 0;
524
+ border-top: 1px solid var(--card-border);
525
+ position: relative;
526
+ }
527
+
528
+ .divider-text {
529
+ position: absolute;
530
+ top: -10px;
531
+ left: 50%;
532
+ transform: translateX(-50%);
533
+ background-color: var(--bg-color);
534
+ padding: 0 12px;
535
+ color: var(--secondary-color);
536
+ font-size: 0.9rem;
537
+ }
538
+
539
+ /* Small text */
540
+ small {
541
+ color: var(--secondary-color);
542
+ display: block;
543
+ margin-top: 6px;
544
+ font-size: 0.85rem;
545
+ }
546
+
547
+ /* Range slider styling */
548
+ input[type="range"] {
549
+ -webkit-appearance: none;
550
+ height: 8px;
551
+ border-radius: 4px;
552
+ background: var(--input-bg);
553
+ outline: none;
554
+ padding: 0;
555
+ margin: 10px 0;
556
+ }
557
+
558
+ input[type="range"]::-webkit-slider-thumb {
559
+ -webkit-appearance: none;
560
+ appearance: none;
561
+ width: 20px;
562
+ height: 20px;
563
+ border-radius: 50%;
564
+ background: var(--primary-color);
565
+ cursor: pointer;
566
+ border: 2px solid white;
567
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
568
+ }
569
+
570
+ input[type="range"]::-moz-range-thumb {
571
+ width: 20px;
572
+ height: 20px;
573
+ border-radius: 50%;
574
+ background: var(--primary-color);
575
+ cursor: pointer;
576
+ border: 2px solid white;
577
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
578
+ }
579
+
580
+ input[type="range"]::-ms-thumb {
581
+ width: 20px;
582
+ height: 20px;
583
+ border-radius: 50%;
584
+ background: var(--primary-color);
585
+ cursor: pointer;
586
+ border: 2px solid white;
587
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
588
+ }
589
+
590
+ #temperature-value {
591
+ display: inline-block;
592
+ width: 40px;
593
+ text-align: center;
594
+ font-weight: 600;
595
+ color: var(--primary-color);
596
+ }
597
+
598
+ /* Toast notifications */
599
+ .toast-container {
600
+ position: fixed;
601
+ bottom: 20px;
602
+ right: 20px;
603
+ z-index: 1000;
604
+ max-width: 100%;
605
+ width: 300px;
606
+ }
607
+
608
+ .toast {
609
+ padding: 12px 16px;
610
+ margin-bottom: 12px;
611
+ border-radius: 8px;
612
+ color: white;
613
+ font-weight: 500;
614
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
615
+ display: flex;
616
+ align-items: center;
617
+ justify-content: space-between;
618
+ animation: toast-in 0.3s ease forwards;
619
+ max-width: 100%;
620
+ }
621
+
622
+ .toast.success {
623
+ background-color: var(--success-color);
624
+ }
625
+
626
+ .toast.error {
627
+ background-color: var(--danger-color);
628
+ }
629
+
630
+ .toast.warning {
631
+ background-color: var(--warning-color);
632
+ }
633
+
634
+ .toast.info {
635
+ background-color: var(--primary-color);
636
+ }
637
+
638
+ .toast-close {
639
+ background: none;
640
+ border: none;
641
+ color: white;
642
+ font-size: 18px;
643
+ cursor: pointer;
644
+ opacity: 0.8;
645
+ width: auto;
646
+ padding: 0 0 0 12px;
647
+ }
648
+
649
+ .toast-close:hover {
650
+ opacity: 1;
651
+ background: none;
652
+ }
653
+
654
+ @keyframes toast-in {
655
+ from {
656
+ transform: translateX(100%);
657
+ opacity: 0;
658
+ }
659
+ to {
660
+ transform: translateX(0);
661
+ opacity: 1;
662
+ }
663
+ }
664
+
665
+ /* Loading spinner */
666
+ .spinner {
667
+ display: inline-block;
668
+ width: 20px;
669
+ height: 20px;
670
+ margin-right: 8px;
671
+ border: 3px solid rgba(255, 255, 255, 0.3);
672
+ border-radius: 50%;
673
+ border-top-color: white;
674
+ animation: spin 1s ease-in-out infinite;
675
+ }
676
+
677
+ @keyframes spin {
678
+ to { transform: rotate(360deg); }
679
+ }
680
+
681
+ /* Mobile responsiveness */
682
+ @media (max-width: 640px) {
683
+ body {
684
+ padding: 12px 8px;
685
+ }
686
+
687
+ h1 {
688
+ font-size: 1.6rem;
689
+ }
690
+
691
+ .card {
692
+ padding: 16px;
693
+ }
694
+
695
+ .voice-list {
696
+ grid-template-columns: 1fr;
697
+ }
698
+
699
+ .controls {
700
+ flex-direction: column;
701
+ }
702
+
703
+ .controls button {
704
+ width: 100%;
705
+ }
706
+
707
+ .tabs button {
708
+ padding: 8px 12px;
709
+ font-size: 0.9rem;
710
+ }
711
+ }
712
+ /* Checkbox styling */
713
+ .checkbox-group {
714
+ margin-bottom: 20px;
715
+ }
716
+ .checkbox-label {
717
+ display: flex;
718
+ align-items: center;
719
+ cursor: pointer;
720
+ font-weight: 500;
721
+ font-size: 0.95rem;
722
+ }
723
+ input[type="checkbox"] {
725
+ margin-right: 10px;
726
+ height: 18px;
727
+ width: 18px;
728
+ cursor: pointer;
729
+ accent-color: var(--primary-color);
730
+ }
731
+ .checkbox-text {
732
+ position: relative;
733
+ top: 1px;
734
+ }
735
+ </style>
736
+ </head>
737
+ <body>
738
+ <div class="container">
739
+ <div class="theme-switch-wrapper">
740
+ <span class="theme-icon">☀️</span>
741
+ <label class="theme-switch" for="checkbox">
742
+ <input type="checkbox" id="checkbox" />
743
+ <div class="slider"></div>
744
+ </label>
745
+ <span class="theme-icon">🌙</span>
746
+ </div>
747
+
748
+ <h1>Voice Cloning</h1>
749
+
750
+ <div class="tabs">
751
+ <button id="tab-clone" class="active">Clone Voice</button>
752
+ <button id="tab-voices">My Voices</button>
753
+ <button id="tab-generate">Generate Speech</button>
754
+ </div>
755
+
756
+ <!-- Status indicator -->
757
+ <div id="status-message" class="status-indicator">
758
+ <!-- Content will be dynamically inserted -->
759
+ </div>
760
+
761
+ <div id="clone-tab" class="tab-content active">
762
+ <div class="card">
763
+ <h2>Clone a New Voice</h2>
764
+ <form id="clone-form">
765
+ <div class="form-group">
766
+ <label for="voice-name">Voice Name</label>
767
+ <input type="text" id="voice-name" name="name" required placeholder="e.g. My Voice">
768
+ </div>
769
+ <div class="form-group">
770
+ <label for="audio-file">Voice Sample (2-3 minute audio recording)</label>
771
+ <input type="file" id="audio-file" name="audio_file" required accept="audio/*">
772
+ <small>For best results, provide a clear recording with minimal background noise.</small>
773
+ </div>
774
+ <div class="form-group">
775
+ <label for="transcript">Transcript (Optional)</label>
776
+ <textarea id="transcript" name="transcript" rows="4" placeholder="Exact transcript of your audio sample..."></textarea>
777
+ <small>Adding a transcript helps improve voice accuracy.</small>
778
+ </div>
779
+ <div class="form-group">
780
+ <label for="description">Description (Optional)</label>
781
+ <input type="text" id="description" name="description" placeholder="A description of this voice">
782
+ </div>
783
+ <button type="submit">Clone Voice</button>
784
+ </form>
785
+ </div>
786
+
787
+ <div class="divider">
788
+ <span class="divider-text">OR</span>
789
+ </div>
790
+
791
+ <div class="card">
792
+ <h2>Clone Voice from YouTube</h2>
793
+ <form id="youtube-clone-form">
794
+ <div class="form-group">
795
+ <label for="youtube-url">YouTube URL</label>
796
+ <input type="url" id="youtube-url" name="youtube_url" required placeholder="https://www.youtube.com/watch?v=...">
797
+ </div>
798
+ <div class="form-group">
799
+ <label for="youtube-voice-name">Voice Name</label>
800
+ <input type="text" id="youtube-voice-name" name="voice_name" required placeholder="e.g. YouTube Voice">
801
+ </div>
802
+ <div class="form-group">
803
+ <label for="start-time">Start Time (seconds)</label>
804
+ <input type="number" id="start-time" name="start_time" min="0" value="0">
805
+ </div>
806
+ <div class="form-group">
807
+ <label for="duration">Duration (seconds)</label>
808
+ <input type="number" id="duration" name="duration" min="10" max="600" value="180">
809
+ <small>Recommended: 2-3 minutes of clear speech</small>
810
+ </div>
811
+ <div class="form-group">
812
+ <label for="youtube-description">Description (Optional)</label>
813
+ <input type="text" id="youtube-description" name="description" placeholder="A description of this voice">
814
+ </div>
815
+ <button type="submit">Clone from YouTube</button>
816
+ </form>
817
+ <div id="youtube-progress" style="display: none; margin-top: 16px;">
818
+ <p>Processing YouTube video... <span id="progress-status">Downloading</span></p>
819
+ <div class="progress-bar">
820
+ <div class="progress-fill"></div>
821
+ </div>
822
+ </div>
823
+ </div>
824
+ </div>
825
+
826
+ <div id="voices-tab" class="tab-content">
827
+ <h2>My Cloned Voices</h2>
828
+ <div id="voice-list" class="voice-list">
829
+ <!-- Voice cards will be added here -->
830
+ </div>
831
+ </div>
832
+
833
+ <div id="generate-tab" class="tab-content">
834
+ <div class="card">
835
+ <h2>Generate Speech with Cloned Voice</h2>
836
+ <form id="generate-form">
837
+ <div class="form-group">
838
+ <label for="voice-select">Select Voice</label>
839
+ <select id="voice-select" name="voice" required>
840
+ <option value="">Select a voice</option>
841
+ <!-- Voice options will be added here -->
842
+ </select>
843
+ </div>
844
+ <div class="form-group">
845
+ <label for="generate-text">Text to Speak</label>
846
+ <textarea id="generate-text" name="text" rows="4" required placeholder="Enter text to synthesize with the selected voice..."></textarea>
847
+ </div>
848
+ <div class="form-group">
849
+ <label>Temperature: <span id="temperature-value">0.7</span></label>
850
+ <input type="range" id="temperature" name="temperature" min="0.5" max="1.0" step="0.05" value="0.7">
851
+ <small>Lower values (0.5-0.7) produce more consistent speech, higher values (0.8-1.0) produce more varied speech.</small>
852
+ </div>
853
+ <div class="form-group checkbox-group">
854
+ <label class="checkbox-label">
855
+ <input type="checkbox" id="use-streaming" name="use_streaming">
856
+ <span class="checkbox-text">Use streaming mode</span>
857
+ </label>
858
+ <small>Stream audio as it's generated for faster start and lower latency.</small>
859
+ </div>
860
+ <button type="submit">Generate Speech</button>
861
+ </form>
862
+ <audio id="audio-preview" controls style="display: none;"></audio>
863
+ </div>
864
+ </div>
865
+ </div>
866
+
867
+ <!-- Toast notifications container -->
868
+ <div class="toast-container" id="toast-container"></div>
869
+
870
+ <script>
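+ // Backend endpoints used by this page (see the handlers below):
+ //   GET    /v1/voice-cloning/voices                 - list cloned voices
+ //   POST   /v1/voice-cloning/clone                  - clone from an uploaded sample (multipart form)
+ //   POST   /v1/voice-cloning/clone-from-youtube     - clone from a YouTube URL (JSON body)
+ //   POST   /v1/voice-cloning/voices/{id}/preview    - synthesize a short preview clip
+ //   DELETE /v1/voice-cloning/voices/{id}            - delete a cloned voice
+ //   POST   /v1/voice-cloning/generate               - synthesize speech with a cloned voice
+ //   POST   /v1/audio/speech/streaming               - streaming synthesis (model/input/voice JSON body)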
871
+ // Theme toggle functionality
872
+ const toggleSwitch = document.querySelector('#checkbox');
873
+ const html = document.querySelector('html');
874
+
875
+ // Check for saved theme preference or use system preference
876
+ function getThemePreference() {
877
+ const savedTheme = localStorage.getItem('theme');
878
+ if (savedTheme) {
879
+ return savedTheme;
880
+ }
881
+ // Check if system prefers dark mode
882
+ return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
883
+ }
884
+
885
+ // Apply the theme
886
+ function setTheme(theme) {
887
+ html.setAttribute('data-theme', theme);
888
+ localStorage.setItem('theme', theme);
889
+ toggleSwitch.checked = theme === 'dark';
890
+ }
891
+
892
+ // Initialize theme
893
+ setTheme(getThemePreference());
894
+
895
+ // Listen for toggle changes
896
+ toggleSwitch.addEventListener('change', function(e) {
897
+ if (e.target.checked) {
898
+ setTheme('dark');
899
+ } else {
900
+ setTheme('light');
901
+ }
902
+ });
903
+
904
+ // Toast notification system
905
+ function showToast(message, type = 'info', duration = 5000) {
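+ // type is one of 'success' | 'error' | 'warning' | 'info' (matching the .toast.* styles above)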
906
+ const container = document.getElementById('toast-container');
907
+ const toast = document.createElement('div');
908
+ toast.className = `toast ${type}`;
909
+ toast.innerHTML = `
910
+ <span>${message}</span>
911
+ <button class="toast-close">&times;</button>
912
+ `;
913
+ container.appendChild(toast);
914
+
915
+ // Auto remove after duration
916
+ const timeout = setTimeout(() => {
917
+ toast.style.opacity = '0';
918
+ setTimeout(() => {
919
+ container.removeChild(toast);
920
+ }, 300);
921
+ }, duration);
922
+
923
+ // Manual close
924
+ toast.querySelector('.toast-close').addEventListener('click', () => {
925
+ clearTimeout(timeout);
926
+ toast.style.opacity = '0';
927
+ setTimeout(() => {
928
+ container.removeChild(toast);
929
+ }, 300);
930
+ });
931
+ }
932
+
933
+ // Status indicator functions
934
+ function showStatus(message, type) {
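+ // type is one of 'success' | 'error' | 'warning' | 'info'; the message auto-hides after 5 seconds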
935
+ const statusElem = document.getElementById('status-message');
936
+ statusElem.className = `status-indicator ${type} show`;
937
+
938
+ let icon = '';
939
+ switch(type) {
940
+ case 'success':
941
+ icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M9 12l2 2 4-4M21 12a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
942
+ break;
943
+ case 'error':
944
+ icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
945
+ break;
946
+ case 'warning':
947
+ icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
948
+ break;
949
+ case 'info':
950
+ icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
951
+ break;
952
+ }
953
+
954
+ statusElem.innerHTML = icon + message;
955
+
956
+ // Auto-hide after 5 seconds
957
+ setTimeout(() => {
958
+ statusElem.className = 'status-indicator';
959
+ }, 5000);
960
+ }
961
+
962
+ function hideStatus() {
963
+ const statusElem = document.getElementById('status-message');
964
+ statusElem.className = 'status-indicator';
965
+ }
966
+
967
+ // Tab functionality
968
+ const tabs = document.querySelectorAll('.tabs button');
969
+ const tabContents = document.querySelectorAll('.tab-content');
970
+
971
+ tabs.forEach(tab => {
972
+ tab.addEventListener('click', () => {
973
+ // Remove active class from all tabs
974
+ tabs.forEach(t => t.classList.remove('active'));
975
+ tabContents.forEach(tc => tc.classList.remove('active'));
976
+
977
+ // Add active class to clicked tab
978
+ tab.classList.add('active');
979
+
980
+ // Show corresponding tab content
981
+ const tabId = tab.id.replace('tab-', '');
982
+ document.getElementById(`${tabId}-tab`).classList.add('active');
983
+
984
+ // Hide any status messages when changing tabs
985
+ hideStatus();
986
+ });
987
+ });
988
+
989
+ // Temperature slider
990
+ const temperatureSlider = document.getElementById('temperature');
991
+ const temperatureValue = document.getElementById('temperature-value');
992
+
993
+ temperatureSlider.addEventListener('input', () => {
994
+ temperatureValue.textContent = temperatureSlider.value;
995
+ });
996
+
997
+ // Load voices
998
+ async function loadVoices() {
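+ // Expects a JSON response shaped like { "voices": [ { id, name, description, created_at (unix seconds) } ] }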
999
+ try {
1000
+ const response = await fetch('/v1/voice-cloning/voices');
1001
+ const data = await response.json();
1002
+ const voiceList = document.getElementById('voice-list');
1003
+ const voiceSelect = document.getElementById('voice-select');
1004
+
1005
+ // Clear existing content
1006
+ voiceList.innerHTML = '';
1007
+
1008
+ // Clear voice select options but keep the first one
1009
+ while (voiceSelect.options.length > 1) {
1010
+ voiceSelect.remove(1);
1011
+ }
1012
+
1013
+ if (data.voices && data.voices.length > 0) {
1014
+ data.voices.forEach(voice => {
1015
+ // Add to voice list
1016
+ const voiceCard = document.createElement('div');
1017
+ voiceCard.className = 'voice-card';
1018
+ voiceCard.innerHTML = `
1019
+ <h3 class="voice-name">${voice.name}</h3>
1020
+ <p>${voice.description || 'No description'}</p>
1021
+ <p>Created: ${new Date(voice.created_at * 1000).toLocaleString()}</p>
1022
+ <div class="controls">
1023
+ <button class="btn-secondary preview-voice" data-id="${voice.id}">Preview</button>
1024
+ <button class="btn-danger delete-voice" data-id="${voice.id}">Delete</button>
1025
+ </div>
1026
+ `;
1027
+ voiceList.appendChild(voiceCard);
1028
+
1029
+ // Add to voice select
1030
+ const option = document.createElement('option');
1031
+ option.value = voice.id;
1032
+ option.textContent = voice.name;
1033
+ voiceSelect.appendChild(option);
1034
+ });
1035
+
1036
+ // Add event listeners for preview and delete buttons
1037
+ document.querySelectorAll('.preview-voice').forEach(button => {
1038
+ button.addEventListener('click', previewVoice);
1039
+ });
1040
+ document.querySelectorAll('.delete-voice').forEach(button => {
1041
+ button.addEventListener('click', deleteVoice);
1042
+ });
1043
+
1044
+ showStatus(`Loaded ${data.voices.length} voices successfully`, 'success');
1045
+ } else {
1046
+ voiceList.innerHTML = '<p>No cloned voices yet. Create one in the "Clone Voice" tab.</p>';
1047
+ }
1048
+ } catch (error) {
1049
+ console.error('Error loading voices:', error);
1050
+ showStatus('Failed to load voices', 'error');
1051
+ }
1052
+ }
1053
+
1054
+ // Preview voice
1055
+ async function previewVoice(event) {
1056
+ const button = event.target;
1057
+ const originalText = button.textContent;
1058
+ button.disabled = true;
1059
+ button.innerHTML = '<div class="spinner"></div> Loading...';
1060
+
1061
+ const voiceId = button.dataset.id;
1062
+ const audioPreview = document.getElementById('audio-preview');
1063
+
1064
+ try {
1065
+ const response = await fetch(`/v1/voice-cloning/voices/${voiceId}/preview`, {
1066
+ method: 'POST',
1067
+ headers: {
1068
+ 'Content-Type': 'application/json'
1069
+ },
1070
+ body: JSON.stringify({
1071
+ text: "This is a preview of my cloned voice. I hope you like how it sounds!"
1072
+ })
1073
+ });
1074
+
1075
+ if (response.ok) {
1076
+ const blob = await response.blob();
1077
+ const url = URL.createObjectURL(blob);
1078
+ audioPreview.src = url;
1079
+ audioPreview.style.display = 'block';
1080
+
1081
+ // Switch to the generate tab
1082
+ document.getElementById('tab-generate').click();
1083
+
1084
+ // Set the voice in the select
1085
+ document.getElementById('voice-select').value = voiceId;
1086
+ audioPreview.play();
1087
+
1088
+ showToast('Voice preview loaded', 'success');
1089
+ } else {
1090
+ showToast('Failed to preview voice', 'error');
1091
+ }
1092
+ } catch (error) {
1093
+ console.error('Error previewing voice:', error);
1094
+ showToast('Error previewing voice', 'error');
1095
+ } finally {
1096
+ button.disabled = false;
1097
+ button.textContent = originalText;
1098
+ }
1099
+ }
1100
+
1101
+ // Delete voice
1102
+ async function deleteVoice(event) {
1103
+ if (!confirm('Are you sure you want to delete this voice? This cannot be undone.')) {
1104
+ return;
1105
+ }
1106
+
1107
+ const button = event.target;
1108
+ const originalText = button.textContent;
1109
+ button.disabled = true;
1110
+ button.innerHTML = '<div class="spinner"></div> Deleting...';
1111
+
1112
+ const voiceId = button.dataset.id;
1113
+
1114
+ try {
1115
+ const response = await fetch(`/v1/voice-cloning/voices/${voiceId}`, {
1116
+ method: 'DELETE'
1117
+ });
1118
+
1119
+ if (response.ok) {
1120
+ showToast('Voice deleted successfully', 'success');
1121
+ loadVoices();
1122
+ } else {
1123
+ showToast('Failed to delete voice', 'error');
1124
+ }
1125
+ } catch (error) {
1126
+ console.error('Error deleting voice:', error);
1127
+ showToast('Error deleting voice', 'error');
1128
+ } finally {
1129
+ button.disabled = false;
1130
+ button.textContent = originalText;
1131
+ }
1132
+ }
1133
+
1134
+ // Clone voice form submission
1135
+ document.getElementById('clone-form').addEventListener('submit', async (event) => {
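+ // Posts the form as multipart/form-data (name, audio_file, transcript, description) to /v1/voice-cloning/clone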
1136
+ event.preventDefault();
1137
+ const formData = new FormData(event.target);
1138
+ const submitButton = event.target.querySelector('button[type="submit"]');
1139
+ const originalText = submitButton.textContent;
1140
+
1141
+ submitButton.disabled = true;
1142
+ submitButton.innerHTML = '<div class="spinner"></div> Cloning Voice...';
1143
+ showStatus('Processing your audio sample...', 'info');
1144
+
1145
+ try {
1146
+ const response = await fetch('/v1/voice-cloning/clone', {
1147
+ method: 'POST',
1148
+ body: formData
1149
+ });
1150
+
1151
+ if (response.ok) {
1152
+ const result = await response.json();
1153
+ showStatus('Voice cloned successfully!', 'success');
1154
+ showToast('Voice cloned successfully!', 'success');
1155
+ event.target.reset();
1156
+
1157
+ // Switch to the voices tab
1158
+ document.getElementById('tab-voices').click();
1159
+ loadVoices();
1160
+ } else {
1161
+ const error = await response.json();
1162
+ showStatus(`Failed to clone voice: ${error.detail}`, 'error');
1163
+ showToast('Failed to clone voice', 'error');
1164
+ }
1165
+ } catch (error) {
1166
+ console.error('Error cloning voice:', error);
1167
+ showStatus('Error processing your request', 'error');
1168
+ showToast('Error cloning voice', 'error');
1169
+ } finally {
1170
+ submitButton.disabled = false;
1171
+ submitButton.textContent = originalText;
1172
+ }
1173
+ });
1174
+
1175
+ // YouTube voice cloning form submission
1176
+ document.getElementById('youtube-clone-form').addEventListener('submit', async (event) => {
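+ // Sends { youtube_url, voice_name, start_time, duration, description } as JSON to /v1/voice-cloning/clone-from-youtube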
1177
+ event.preventDefault();
1178
+ const formData = new FormData(event.target);
1179
+ const youtubeUrl = formData.get('youtube_url');
1180
+ const voiceName = formData.get('voice_name');
1181
+ const startTime = parseInt(formData.get('start_time'));
1182
+ const duration = parseInt(formData.get('duration'));
1183
+ const description = formData.get('description');
1184
+ const progressDiv = document.getElementById('youtube-progress');
1185
+ const progressStatus = document.getElementById('progress-status');
1186
+ const progressFill = document.querySelector('.progress-fill');
1187
+ const submitButton = event.target.querySelector('button[type="submit"]');
1188
+ const originalText = submitButton.textContent;
1189
+
1190
+ submitButton.disabled = true;
1191
+ submitButton.innerHTML = '<div class="spinner"></div> Processing...';
1192
+ showStatus('Starting YouTube download...', 'info');
1193
+
1194
+ // Show progress bar
1195
+ progressDiv.style.display = 'block';
1196
+ progressFill.style.width = '10%';
1197
+ progressStatus.textContent = 'Downloading audio...';
1198
+
1199
+ // Simulate progress updates (since we can't get real-time updates easily)
1200
+ let progress = 10;
1201
+ const progressInterval = setInterval(() => {
1202
+ if (progress < 90) {
1203
+ progress += 5;
1204
+ progressFill.style.width = `${progress}%`;
1205
+ if (progress > 30 && progress < 60) {
1206
+ progressStatus.textContent = 'Generating transcript...';
1207
+ } else if (progress >= 60) {
1208
+ progressStatus.textContent = 'Cloning voice...';
1209
+ }
1210
+ }
1211
+ }, 1000);
1212
+
1213
+ try {
1214
+ const response = await fetch('/v1/voice-cloning/clone-from-youtube', {
1215
+ method: 'POST',
1216
+ headers: {
1217
+ 'Content-Type': 'application/json'
1218
+ },
1219
+ body: JSON.stringify({
1220
+ youtube_url: youtubeUrl,
1221
+ voice_name: voiceName,
1222
+ start_time: startTime,
1223
+ duration: duration,
1224
+ description: description
1225
+ })
1226
+ });
1227
+
1228
+ clearInterval(progressInterval);
1229
+
1230
+ if (response.ok) {
1231
+ progressFill.style.width = '100%';
1232
+ progressStatus.textContent = 'Complete!';
1233
+
1234
+ const result = await response.json();
1235
+ showStatus('Voice cloned successfully from YouTube!', 'success');
1236
+ showToast('Voice cloned from YouTube!', 'success');
1237
+
1238
+ event.target.reset();
1239
+
1240
+ // Switch to the voices tab
1241
+ document.getElementById('tab-voices').click();
1242
+ loadVoices();
1243
+ } else {
1244
+ const error = await response.json();
1245
+ showStatus(`Failed to clone voice from YouTube: ${error.detail}`, 'error');
1246
+ showToast('Failed to clone voice from YouTube', 'error');
1247
+ progressDiv.style.display = 'none';
1248
+ }
1249
+ } catch (error) {
1250
+ console.error('Error cloning voice from YouTube:', error);
1251
+ showStatus('Error processing YouTube video', 'error');
1252
+ showToast('Error cloning voice from YouTube', 'error');
1253
+ progressDiv.style.display = 'none';
1254
+ } finally {
1255
+ clearInterval(progressInterval);
1256
+ submitButton.disabled = false;
1257
+ submitButton.textContent = originalText;
1258
+ }
1259
+ });
1260
+
1261
+ // Generate speech form submission
1262
+ document.getElementById('generate-form').addEventListener('submit', async (event) => {
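+ // Uses /v1/audio/speech/streaming when "Use streaming mode" is checked, otherwise /v1/voice-cloning/generate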
1263
+ event.preventDefault();
1264
+ const formData = new FormData(event.target);
1265
+ const voiceId = formData.get('voice');
1266
+ const text = formData.get('text');
1267
+ const temperature = formData.get('temperature');
1268
+ const useStreaming = formData.get('use_streaming') === 'on';
1269
+
1270
+ if (!voiceId) {
1271
+ showToast('Please select a voice', 'warning');
1272
+ return;
1273
+ }
1274
+
1275
+ const submitButton = event.target.querySelector('button[type="submit"]');
1276
+ const originalText = submitButton.textContent;
1277
+ submitButton.disabled = true;
1278
+ submitButton.innerHTML = '<div class="spinner"></div> Generating...';
1279
+ showStatus(useStreaming ? 'Streaming speech...' : 'Generating speech...', 'info');
1280
+
1281
+ try {
1282
+ const audioPreview = document.getElementById('audio-preview');
1283
+
1284
+ if (useStreaming) {
1285
+ // Streaming mode: fetch from the streaming endpoint; the response is buffered into a blob below, and playback starts once the download completes
1286
+ try {
1287
+ // Reset audio element
1288
+ audioPreview.style.display = 'block';
1289
+
1290
+ // Prepare the request
1291
+ const requestOptions = {
1292
+ method: 'POST',
1293
+ headers: {
1294
+ 'Content-Type': 'application/json'
1295
+ },
1296
+ body: JSON.stringify({
1297
+ model: "csm-1b",
1298
+ input: text,
1299
+ voice: voiceId,
1300
+ response_format: "mp3",
1301
+ temperature: parseFloat(temperature),
1302
+ speed: 1.0
1303
+ })
1304
+ };
1305
+
1306
+ // Create a unique URL for this request to avoid caching issues
1307
+ const timestamp = new Date().getTime();
1308
+ const streamingUrl = `/v1/audio/speech/streaming?t=${timestamp}`;
1309
+
1310
+ // Fetch from streaming endpoint
1311
+ const response = await fetch(streamingUrl, requestOptions);
1312
+
1313
+ if (response.ok) {
1314
+ // Buffer the full response into a blob and create an object URL for playback
1315
+ const blob = await response.blob();
1316
+ const url = URL.createObjectURL(blob);
1317
+
1318
+ // Set the audio source and play immediately
1319
+ audioPreview.src = url;
1320
+ audioPreview.autoplay = true;
1321
+
1322
+ // Event listeners for success/failure
1323
+ audioPreview.onplay = () => {
1324
+ showStatus('Speech streamed successfully', 'success');
1325
+ showToast('Speech streaming playback started', 'success');
1326
+ };
1327
+
1328
+ audioPreview.onerror = (e) => {
1329
+ console.error('Audio playback error:', e);
1330
+ showStatus('Error playing streamed audio', 'error');
1331
+ showToast('Streaming playback error', 'error');
1332
+ };
1333
+ } else {
1334
+ const error = await response.json();
1335
+ showStatus(`Failed to stream speech: ${error.detail || 'Unknown error'}`, 'error');
1336
+ showToast('Failed to stream speech', 'error');
1337
+ }
1338
+ } catch (error) {
1339
+ console.error('Streaming error:', error);
1340
+ showStatus(`Error streaming speech: ${error.message}`, 'error');
1341
+ showToast('Error streaming speech', 'error');
1342
+ }
1343
+ } else {
1344
+ // Non-streaming uses the original endpoint
1345
+ const response = await fetch('/v1/voice-cloning/generate', {
1346
+ method: 'POST',
1347
+ headers: {
1348
+ 'Content-Type': 'application/json'
1349
+ },
1350
+ body: JSON.stringify({
1351
+ voice_id: voiceId,
1352
+ text: text,
1353
+ temperature: parseFloat(temperature)
1354
+ })
1355
+ });
1356
+
1357
+ if (response.ok) {
1358
+ const blob = await response.blob();
1359
+ const url = URL.createObjectURL(blob);
1360
+ audioPreview.src = url;
1361
+ audioPreview.style.display = 'block';
1362
+ audioPreview.play();
1363
+ showStatus('Speech generated successfully', 'success');
1364
+ showToast('Speech generated successfully', 'success');
1365
+ } else {
1366
+ const error = await response.json();
1367
+ showStatus(`Failed to generate speech: ${error.detail}`, 'error');
1368
+ showToast('Failed to generate speech', 'error');
1369
+ }
1370
+ }
1371
+ } catch (error) {
1372
+ console.error('Error generating speech:', error);
1373
+ showStatus('Error generating speech', 'error');
1374
+ showToast('Error generating speech', 'error');
1375
+ } finally {
1376
+ submitButton.disabled = false;
1377
+ submitButton.textContent = originalText;
1378
+ }
1379
+ });
1380
+
1381
+ // Load voices on page load
1382
+ loadVoices();
1383
+ </script>
1384
+ </body>
1385
+ </html>