- .gitignore +77 -0
- Dockerfile +92 -0
- README.md +470 -5
- app/.dockerignore +94 -0
- app/api/__init__.py +1 -0
- app/api/routes.py +1048 -0
- app/api/schemas.py +41 -0
- app/api/streaming.py +315 -0
- app/api/utils.py +25 -0
- app/api/voice_cloning_routes.py +289 -0
- app/audio_processing.py +230 -0
- app/custom_transformer.py +339 -0
- app/download_model.py +28 -0
- app/generator.py +834 -0
- app/main.py +647 -0
- app/models.py +13 -0
- app/prompt_engineering.py +129 -0
- app/text_normalizer.py +194 -0
- app/torchtune_models.py +282 -0
- app/utils/audio_utils.py +64 -0
- app/utils/__init__.py +1 -0
- app/utils/scheduled_tasks.py +42 -0
- app/utils/voice_manager.py +132 -0
- app/voice_cloning.py +705 -0
- app/voice_embeddings.py +85 -0
- app/voice_enhancement.py +581 -0
- app/voice_memory.py +474 -0
- app/watermarking.py +22 -0
- docker-compose.yml +23 -0
- requirements.txt +24 -0
- static/voice-cloning.html +1385 -0
.gitignore
ADDED
@@ -0,0 +1,77 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Environment
.env
.venv
env/
venv/
ENV/

# Model & data files
models/
tokenizers/
voice_memories/
voice_samples/
cloned_voices/
*.pt
*.pth
*.bin
*.ckpt

# Audio files
*.wav
*.mp3
*.flac
*.opus
*.aac

# Logs
logs/
*.log
log.txt

# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store

# Temp files
tmp/
temp/
.temp/
.~*

# Project specific
.cache/
.neptune/
MANIFEST.in
.history/
.mypy_cache/
.pytest_cache/
__pycache__/
.ipynb_checkpoints/

# Ignore Hugging Face credentials
.huggingface/
Dockerfile
ADDED
@@ -0,0 +1,92 @@
# Multi-stage build for authenticated model downloads
FROM python:3.10-slim AS model-downloader
# Install huggingface-cli
RUN pip install huggingface_hub
# Set working directory
WORKDIR /model-downloader
# Create directory for downloaded models
RUN mkdir -p /model-downloader/models
# This will run when building the image
# You'll need to pass your Hugging Face token at build time
ARG HF_TOKEN
ENV HF_TOKEN=${HF_TOKEN}
# Login and download model
RUN if [ -n "$HF_TOKEN" ]; then \
    huggingface-cli login --token ${HF_TOKEN}; \
    huggingface-cli download sesame/csm-1b ckpt.pt --local-dir /model-downloader/models; \
    else echo "No HF_TOKEN provided, model download will be skipped"; fi

# Now for the main application stage
FROM nvidia/cuda:12.4.0-base-ubuntu22.04
# Set environment variables
ENV PYTHONFAULTHANDLER=1 \
    PYTHONUNBUFFERED=1 \
    PYTHONHASHSEED=random \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1 \
    PIP_DEFAULT_TIMEOUT=100 \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility \
    TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6" \
    TORCH_NVCC_FLAGS="-Xfatbin -compress-all"

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    ffmpeg \
    git \
    build-essential \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Set working directory
WORKDIR /app

# Copy requirements first for better caching
COPY requirements.txt .

# Create and set up persistent directories with proper permissions
RUN mkdir -p /app/static /app/models /app/voice_memories /app/voice_references \
    /app/voice_profiles /app/cloned_voices /app/audio_cache /app/tokenizers /app/logs && \
    chmod -R 777 /app/voice_references /app/voice_profiles /app/voice_memories \
    /app/cloned_voices /app/audio_cache /app/static /app/logs /app/tokenizers /app/models

# Copy static files
COPY ./static /app/static

# Install Python dependencies
RUN pip3 install --no-cache-dir --upgrade pip && \
    pip3 install torch torchaudio numpy

# Install torchao from source
RUN pip3 install git+https://github.com/pytorch/ao.git

# Install torchtune from source with specific branch for latest features
RUN git clone https://github.com/pytorch/torchtune.git /tmp/torchtune && \
    cd /tmp/torchtune && \
    # Try to use the main branch, which should have llama3_2
    git checkout main && \
    pip install -e .

# Install remaining dependencies
RUN pip3 install -r requirements.txt

# Install additional dependencies for streaming and voice cloning
RUN pip3 install yt-dlp openai-whisper

# Copy application code
COPY ./app /app/app

# Copy downloaded model from the model-downloader stage
COPY --from=model-downloader /model-downloader/models /app/models

# Show available models in torchtune
RUN python3 -c "import torchtune.models; print('Available models in torchtune:', dir(torchtune.models))"

# Expose port
EXPOSE 8000

# Command to run the application
CMD ["python3", "-m", "app.main"]
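The model-downloader stage above passes `HF_TOKEN` at build time and pulls `ckpt.pt` with `huggingface-cli`. For reference, a minimal Python sketch of the same download using `huggingface_hub` (which that stage installs); the target directory and the skip-on-missing-token behavior mirror the Dockerfile, and the function name is purely illustrative.

```python
# Rough Python equivalent of the model-downloader stage: authenticate with an
# HF token (if provided) and fetch ckpt.pt from sesame/csm-1b into a local dir.
import os
from typing import Optional
from huggingface_hub import hf_hub_download

def download_csm_checkpoint(local_dir: str = "/model-downloader/models") -> Optional[str]:
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Same behavior as the Dockerfile: no token, no download.
        print("No HF_TOKEN provided, model download will be skipped")
        return None
    # Downloads ckpt.pt from the gated sesame/csm-1b repo, authenticating with the token.
    return hf_hub_download(
        repo_id="sesame/csm-1b",
        filename="ckpt.pt",
        local_dir=local_dir,
        token=token,
    )

if __name__ == "__main__":
    print(download_csm_checkpoint())
```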
README.md
CHANGED
@@ -1,10 +1,475 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: CSM-1B TTS Interface # Choose a descriptive title
+emoji: 🔊 # Choose an emoji (e.g., 🎤, 🔊, ✨)
+colorFrom: blue # Optional: Start color for card gradient
+colorTo: indigo # Optional: End color for card gradient
 sdk: docker
+sdk_version: "28.0.1" # <-- IMPORTANT: Check your requirements.txt for the exact gradio version you installed! Update this value.
+app_file: app/main.py # <-- IMPORTANT: Use the EXACT filename of your main Gradio script (the one with `demo.launch()`)
 pinned: false
+# Optional: Add other configurations like python_version if needed
+# python_version: "3.10"
+# Optional: Specify hardware if needed (e.g., for GPU)
+# hardware: cpu-upgrade # or gpu-small, gpu-a10g-small etc. Check HF pricing/docs
+# Optional: Specify secrets needed (like HF_TOKEN if model download needs it)
+# secrets:
+#   - HF_TOKEN
 ---
 
-

# CSM-1B TTS API

An OpenAI-compatible Text-to-Speech API that harnesses the power of Sesame's Conversational Speech Model (CSM-1B). This API allows you to generate high-quality speech from text using a variety of consistent voices, compatible with systems like OpenWebUI, ChatBot UI, and any platform that supports the OpenAI TTS API format.

## Features

- **OpenAI API Compatibility**: Drop-in replacement for OpenAI's TTS API
- **Multiple Voices**: Six distinct voices (alloy, echo, fable, onyx, nova, shimmer)
- **Voice Consistency**: Maintains consistent voice characteristics across multiple requests
- **Voice Cloning**: Clone your own voice from audio samples
- **Conversational Context**: Supports conversational context for improved naturalness
- **Multiple Audio Formats**: Supports MP3, OPUS, AAC, FLAC, and WAV
- **Speed Control**: Adjustable speech speed
- **CUDA Acceleration**: GPU support for faster generation
- **Web UI**: Simple interface for voice cloning and speech generation

## Getting Started

### Prerequisites

- Docker and Docker Compose
- NVIDIA GPU with CUDA support (recommended)
- Hugging Face account with access to the `sesame/csm-1b` model

### Installation

1. Clone this repository:
```bash
git clone https://github.com/phildougherty/sesame_csm_openai
cd sesame_csm_openai
```

2. Create a `.env` file in the /app folder with your Hugging Face token:
```
HF_TOKEN=your_hugging_face_token_here
```

3. Build and start the container:
```bash
docker compose up -d --build
```

The server will start on port 8000. First startup may take some time as it downloads the model files.

## Hugging Face Configuration (ONLY NEEDED TO ACCEPT TERMS/DOWNLOAD MODEL)

This API requires access to the `sesame/csm-1b` model on Hugging Face:

1. Create a Hugging Face account if you don't have one: [https://huggingface.co/join](https://huggingface.co/join)
2. Accept the model license at [https://huggingface.co/sesame/csm-1b](https://huggingface.co/sesame/csm-1b)
3. Generate an access token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
4. Use this token in your `.env` file or pass it directly when building the container:

```bash
HF_TOKEN=your_token docker compose up -d --build
```

### Required Models

The API uses the following models, which are downloaded automatically:

- **CSM-1B**: The main speech generation model from Sesame
- **Mimi**: Audio codec for high-quality audio generation
- **Llama Tokenizer**: Uses the unsloth/Llama-3.2-1B tokenizer for text processing

## Multi-GPU Support

The CSM-1B model can be distributed across multiple GPUs to handle larger models or improve performance. To enable multi-GPU support, set the `CSM_DEVICE_MAP` environment variable:

```bash
# Automatic device mapping (recommended)
CSM_DEVICE_MAP=auto docker compose up -d

# Balanced distribution of layers across GPUs
CSM_DEVICE_MAP=balanced docker compose up -d

# Sequential distribution (backbone on first GPUs, decoder on remaining)
CSM_DEVICE_MAP=sequential docker compose up -d
```

## Voice Cloning Guide

The CSM-1B TTS API comes with powerful voice cloning capabilities that allow you to create custom voices from audio samples. Here's how to use this feature:

### Method 1: Using the Web Interface

1. Access the voice cloning UI by navigating to `http://your-server-ip:8000/voice-cloning` in your browser.

2. **Clone a Voice**:
   - Go to the "Clone Voice" tab
   - Enter a name for your voice
   - Upload an audio sample (2-3 minutes of clear speech works best)
   - Optionally provide a transcript of the audio for better results
   - Click "Clone Voice"

3. **View Your Voices**:
   - Navigate to the "My Voices" tab to see all your cloned voices
   - You can preview or delete voices from this tab

4. **Generate Speech**:
   - Go to the "Generate Speech" tab
   - Select one of your cloned voices
   - Enter the text you want to synthesize
   - Adjust the temperature slider if needed (lower for more consistent results)
   - Click "Generate Speech" and listen to the result

### Method 2: Using the API

1. **Clone a Voice**:
```bash
curl -X POST http://localhost:8000/v1/voice-cloning/clone \
  -F "name=My Voice" \
  -F "audio_file=@path/to/your/voice_sample.mp3" \
  -F "transcript=Optional transcript of the audio sample" \
  -F "description=A description of this voice"
```

2. **List Available Cloned Voices**:
```bash
curl -X GET http://localhost:8000/v1/voice-cloning/voices
```

3. **Generate Speech with a Cloned Voice**:
```bash
curl -X POST http://localhost:8000/v1/voice-cloning/generate \
  -H "Content-Type: application/json" \
  -d '{
    "voice_id": "1234567890_my_voice",
    "text": "This is my cloned voice speaking.",
    "temperature": 0.7
  }' \
  --output cloned_speech.mp3
```

4. **Generate a Voice Preview**:
```bash
curl -X POST http://localhost:8000/v1/voice-cloning/voices/1234567890_my_voice/preview \
  --output voice_preview.mp3
```

5. **Delete a Cloned Voice**:
```bash
curl -X DELETE http://localhost:8000/v1/voice-cloning/voices/1234567890_my_voice
```

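A minimal Python sketch of the same clone-then-generate flow using the `requests` library. The endpoint paths and field names follow the curl examples above; the sample file name and the assumption that the clone response returns a `voice_id` field are illustrative.

```python
# Clone a voice from a local sample, then synthesize speech with it.
import requests

BASE_URL = "http://localhost:8000"

# 1. Clone a voice (multipart form upload, same fields as the curl example).
with open("voice_sample.mp3", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/v1/voice-cloning/clone",
        data={
            "name": "My Voice",
            "transcript": "Optional transcript of the audio sample",
            "description": "A description of this voice",
        },
        files={"audio_file": f},
    )
resp.raise_for_status()
voice_id = resp.json()["voice_id"]  # assumes the clone response includes a voice_id field

# 2. Generate speech with the cloned voice and save the audio bytes.
resp = requests.post(
    f"{BASE_URL}/v1/voice-cloning/generate",
    json={"voice_id": voice_id, "text": "This is my cloned voice speaking.", "temperature": 0.7},
)
resp.raise_for_status()
with open("cloned_speech.mp3", "wb") as out:
    out.write(resp.content)
```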
### Voice Cloning Best Practices

For the best voice cloning results:

1. **Use High-Quality Audio**: Record in a quiet environment with minimal background noise and echo.

2. **Provide Sufficient Length**: 2-3 minutes of speech provides better results than shorter samples.

3. **Clear, Natural Speech**: Speak naturally at a moderate pace with clear pronunciation.

4. **Include Various Intonations**: The sample should contain different sentence types (statements, questions) for better expressiveness.

5. **Add a Transcript**: While optional, providing an accurate transcript of your recording helps the model better capture your voice characteristics.

6. **Adjust Temperature**: For more consistent results, use lower temperature values (0.6-0.7). For more expressiveness, use higher values (0.7-0.9).

7. **Try Multiple Samples**: If you're not satisfied with the results, try recording a different sample or adjusting the speaking style.

### Using Cloned Voices with the Standard TTS Endpoint

Cloned voices are automatically available through the standard OpenAI-compatible endpoint. Simply use the voice ID or name as the `voice` parameter:

```bash
curl -X POST http://localhost:8000/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{
    "model": "csm-1b",
    "input": "This is my cloned voice speaking through the standard endpoint.",
    "voice": "1234567890_my_voice",
    "response_format": "mp3"
  }' \
  --output cloned_speech.mp3
```

## YouTube Voice Cloning

The CSM-1B TTS API now includes the ability to clone voices directly from YouTube videos. This feature allows you to extract voice characteristics from any YouTube content and create custom TTS voices without needing to download or prepare audio samples yourself.

## How to Clone a Voice from YouTube

### API Endpoint

```
POST /v1/audio/speech/voice-cloning/youtube
```

Parameters:
- `youtube_url`: URL of the YouTube video
- `voice_name`: Name for the cloned voice
- `start_time` (optional): Start time in seconds (default: 0)
- `duration` (optional): Duration to extract in seconds (default: 180)
- `description` (optional): Description of the voice

Example request:
```json
{
  "youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
  "voice_name": "rick_astley",
  "start_time": 30,
  "duration": 60,
  "description": "Never gonna give you up"
}
```

Response:
```json
{
  "voice_id": "1710805983_rick_astley",
  "name": "rick_astley",
  "description": "Never gonna give you up",
  "created_at": "2025-03-18T22:53:03Z",
  "audio_duration": 60.0,
  "sample_count": 1440000
}
```

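A hypothetical `requests` call for this endpoint. The parameter names come from the list above; sending them as a JSON body (rather than form fields) is an assumption based on the example request shown.

```python
# Clone a voice from a YouTube segment via the documented endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech/voice-cloning/youtube",
    json={
        "youtube_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
        "voice_name": "rick_astley",
        "start_time": 30,   # seconds into the video
        "duration": 60,     # seconds of audio to extract
        "description": "Never gonna give you up",
    },
)
resp.raise_for_status()
print(resp.json())  # e.g. {"voice_id": "1710805983_rick_astley", "name": "rick_astley", ...}
```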
## How It Works

1. The system downloads the audio from the specified YouTube video
2. It extracts the specified segment (start time and duration)
3. Whisper ASR generates a transcript of the audio for better voice matching
4. The audio is processed to remove noise and silence
5. The voice is cloned and made available for TTS generation

## Best Practices for YouTube Voice Cloning

For optimal results:

1. **Choose Clear Speech Segments**
   - Select portions of the video with clear, uninterrupted speech
   - Avoid segments with background music, sound effects, or multiple speakers

2. **Optimal Duration**
   - 30-60 seconds of clean speech typically provides the best results
   - Longer isn't always better - quality matters more than quantity

3. **Specify Time Ranges Precisely**
   - Use `start_time` and `duration` to target the exact speech segment
   - Preview the segment in YouTube before cloning to ensure it's suitable

4. **Consider Audio Quality**
   - Higher quality videos generally produce better voice clones
   - Interviews, vlogs, and speeches often work better than highly produced content

## Limitations

- YouTube videos with heavy background music may result in lower quality voice clones
- Very noisy or low-quality audio sources will produce less accurate voice clones
- The system works best with natural speech rather than singing or exaggerated voices
- Copyright restrictions apply - only clone voices you have permission to use

## Example Use Cases

- Create a voice clone of a public figure for educational content
- Clone your own YouTube voice for consistent TTS across your applications
- Create voice clones from historical speeches or interviews (public domain)
- Develop custom voices for creative projects with proper permissions

## Ethical Considerations

Please use YouTube voice cloning responsibly:
- Only clone voices from content you have permission to use
- Respect copyright and intellectual property rights
- Clearly disclose when using AI-generated or cloned voices
- Do not use cloned voices for impersonation, deception, or harmful content

## How the Voices Work

Unlike traditional TTS systems with pre-trained voice models, CSM-1B works differently:

- The base CSM-1B model is capable of producing a wide variety of voices but doesn't have fixed voice identities
- This API creates consistent voices by using acoustic "seed" samples for each named voice
- When you specify a voice (e.g., "alloy"), the API uses a consistent acoustic seed and speaker ID
- The most recent generated audio becomes the new reference for that voice, maintaining voice consistency
- Each voice has unique tonal qualities:
  - **alloy**: Balanced mid-tones with natural inflection
  - **echo**: Resonant with slight reverberance
  - **fable**: Brighter with higher pitch
  - **onyx**: Deep and resonant
  - **nova**: Warm and smooth
  - **shimmer**: Light and airy with higher frequencies

The voice system can be extended with your own voice samples by using the voice cloning feature.

## API Usage

### Basic Usage

Generate speech with a POST request to `/v1/audio/speech`:

```bash
curl -X POST http://localhost:8000/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{
    "model": "csm-1b",
    "input": "Hello, this is a test of the CSM text to speech system.",
    "voice": "alloy",
    "response_format": "mp3"
  }' \
  --output speech.mp3
```

### Available Endpoints

#### Standard TTS Endpoints
- `GET /v1/audio/models` - List available models
- `GET /v1/audio/voices` - List available voices (including cloned voices)
- `GET /v1/audio/speech/response-formats` - List available response formats
- `POST /v1/audio/speech` - Generate speech from text
- `POST /api/v1/audio/conversation` - Advanced endpoint for conversational speech

#### Voice Cloning Endpoints
- `POST /v1/voice-cloning/clone` - Clone a new voice from an audio sample
- `GET /v1/voice-cloning/voices` - List all cloned voices
- `POST /v1/voice-cloning/generate` - Generate speech with a cloned voice
- `POST /v1/voice-cloning/voices/{voice_id}/preview` - Generate a preview of a cloned voice
- `DELETE /v1/voice-cloning/voices/{voice_id}` - Delete a cloned voice

### Request Parameters

#### Standard TTS
| Parameter | Description | Type | Default |
|-----------|-------------|------|---------|
| `model` | Model ID to use | string | "csm-1b" |
| `input` | The text to convert to speech | string | Required |
| `voice` | The voice to use (standard or cloned voice ID) | string | "alloy" |
| `response_format` | Audio format | string | "mp3" |
| `speed` | Speech speed multiplier | float | 1.0 |
| `temperature` | Sampling temperature | float | 0.8 |
| `max_audio_length_ms` | Maximum audio length in ms | integer | 90000 |

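A short `requests` sketch that exercises the optional parameters from the table above; the chosen values are illustrative, and anything left out falls back to the listed defaults.

```python
# Request speech with explicit speed, temperature, and length cap.
import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "csm-1b",
        "input": "Hello, this is a test of the CSM text to speech system.",
        "voice": "nova",
        "response_format": "wav",
        "speed": 1.2,                  # 20% faster than normal
        "temperature": 0.7,            # lower = more stable output
        "max_audio_length_ms": 30000,  # cap generation at 30 seconds
    },
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)
```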
#### Voice Cloning
| Parameter | Description | Type | Default |
|-----------|-------------|------|---------|
| `name` | Name for the cloned voice | string | Required |
| `audio_file` | Audio sample file | file | Required |
| `transcript` | Transcript of the audio | string | Optional |
| `description` | Description of the voice | string | Optional |

### Available Voices

- `alloy` - Balanced and natural
- `echo` - Resonant
- `fable` - Bright and higher-pitched
- `onyx` - Deep and resonant
- `nova` - Warm and smooth
- `shimmer` - Light and airy
- `[cloned voice ID]` - Any voice you've cloned using the voice cloning feature

### Response Formats

- `mp3` - MP3 audio format
- `opus` - Opus audio format
- `aac` - AAC audio format
- `flac` - FLAC audio format
- `wav` - WAV audio format

## Integration with OpenWebUI

OpenWebUI is a popular open-source UI for AI models that supports custom TTS endpoints. Here's how to integrate the CSM-1B TTS API:

1. Access your OpenWebUI settings
2. Navigate to the TTS settings section
3. Select "Custom TTS Endpoint"
4. Enter your CSM-1B TTS API URL: `http://your-server-ip:8000/v1/audio/speech`
5. Use the API Key field to add any authentication if you've configured it (not required by default)
6. Test the connection
7. Save your settings

Once configured, OpenWebUI will use your CSM-1B TTS API for all text-to-speech conversion, producing high-quality speech with the selected voice.

### Using Cloned Voices with OpenWebUI

Your cloned voices will automatically appear in OpenWebUI's voice selector. Simply choose your cloned voice from the dropdown menu in the TTS settings or chat interface.

## Advanced Usage

### Conversational Context

For more natural-sounding speech in a conversation, you can use the conversation endpoint:

```bash
curl -X POST http://localhost:8000/api/v1/audio/conversation \
  -H "Content-Type: application/json" \
  -d '{
    "text": "Nice to meet you too!",
    "speaker_id": 0,
    "context": [
      {
        "speaker": 1,
        "text": "Hello, nice to meet you.",
        "audio": "BASE64_ENCODED_AUDIO"
      }
    ]
  }' \
  --output response.wav
```

This allows the model to take the previous utterances into account for more contextually appropriate speech.

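A small sketch of how the `BASE64_ENCODED_AUDIO` placeholder above can be produced from a local WAV file in Python before calling the endpoint; the input file name is illustrative, and the payload shape follows the curl example.

```python
# Base64-encode a previous utterance and send it as conversational context.
import base64
import requests

with open("previous_utterance.wav", "rb") as f:
    audio_b64 = base64.b64encode(f.read()).decode("ascii")

resp = requests.post(
    "http://localhost:8000/api/v1/audio/conversation",
    json={
        "text": "Nice to meet you too!",
        "speaker_id": 0,
        "context": [
            {"speaker": 1, "text": "Hello, nice to meet you.", "audio": audio_b64}
        ],
    },
)
resp.raise_for_status()
with open("response.wav", "wb") as f:
    f.write(resp.content)
```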
### Model Parameters

For fine-grained control, you can adjust:

- `temperature` (0.0-1.0): Higher values produce more variation but may be less stable
- `topk` (1-100): Controls diversity of generated speech
- `max_audio_length_ms`: Maximum length of generated audio in milliseconds
- `voice_consistency` (0.0-1.0): How strongly to maintain voice characteristics across segments

## Troubleshooting

### API Returns 503 Service Unavailable

- Verify your Hugging Face token has access to `sesame/csm-1b`
- Check if the model downloaded successfully in the logs
- Ensure you have enough GPU memory (at least 8GB recommended)

### Audio Quality Issues

- Try different voices - some may work better for your specific text
- Adjust temperature (lower for more stable output)
- For longer texts, the API automatically splits into smaller chunks for better quality
- For cloned voices, try recording a cleaner audio sample

### Voice Cloning Issues

- **Poor Voice Quality**: Try recording in a quieter environment with less background noise
- **Inconsistent Voice**: Provide a longer and more varied audio sample (2-3 minutes)
- **Accent Issues**: Make sure your sample contains similar words/sounds to what you'll be generating
- **Low Volume**: The sample is normalized automatically, but ensure it's not too quiet or distorted

### Voice Inconsistency

- The API maintains voice consistency across separate requests
- However, very long pauses between requests may result in voice drift
- For critical applications, consider using the same seed audio

## License

This project is released under the MIT License. The CSM-1B model is subject to its own license terms defined by Sesame.

## Acknowledgments

- [Sesame](https://www.sesame.com) for releasing the CSM-1B model
- This project is not affiliated with or endorsed by Sesame or OpenAI

---

Happy speech generating!
app/.dockerignore
ADDED
@@ -0,0 +1,94 @@
# Version control
.git
.gitignore
.github

# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Environment
.env
.venv
env/
venv/
ENV/

# Docker related
Dockerfile*
docker-compose*
.docker
.dockerignore

# Documentation
README.md
CHANGELOG.md
LICENSE
docs/
*.md
examples/
tests/

# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store

# Logs and temp files
logs/
*.log
log.txt
tmp/
temp/
.temp/
.~*

# Model files - we'll download these in the container
models/
tokenizers/
voice_memories/
voice_samples/
*.pt
*.pth
*.bin
*.ckpt

# Audio files
*.wav
*.mp3
*.flac
*.opus
*.aac

# Project specific
.cache/
.neptune/
MANIFEST.in
.history/
.mypy_cache/
.pytest_cache/
__pycache__/
.ipynb_checkpoints/

# Ignore Hugging Face credentials except specified files
.huggingface/
app/api/__init__.py
ADDED
@@ -0,0 +1 @@
"""API package for CSM-1B."""
app/api/routes.py
ADDED
@@ -0,0 +1,1048 @@
"""
API routes for the CSM-1B TTS API.
"""
import os
import io
import base64
import time
import tempfile
import logging
import asyncio
from enum import Enum
from typing import Dict, List, Optional, Any, Union
import torch
import torchaudio
import numpy as np
from fastapi import APIRouter, Request, HTTPException, BackgroundTasks, Body, Response, Query
from fastapi.responses import StreamingResponse
from app.api.schemas import SpeechRequest, ResponseFormat, Voice
from app.models import Segment
from app.api.streaming import AudioChunker
from app.prompt_engineering import split_into_segments

# Set up logging
logger = logging.getLogger(__name__)
router = APIRouter()

# Mapping of response_format to MIME types
MIME_TYPES = {
    "mp3": "audio/mpeg",
    "opus": "audio/opus",
    "aac": "audio/aac",
    "flac": "audio/flac",
    "wav": "audio/wav",
}

def get_speaker_id(app_state, voice):
    """Helper function to get speaker ID from voice name or ID"""
    if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
        return app_state.voice_speaker_map[voice]

    # Standard voices mapping
    voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}

    if voice in voice_to_speaker:
        return voice_to_speaker[voice]

    # Try parsing as integer
    try:
        speaker_id = int(voice)
        if 0 <= speaker_id < 6:
            return speaker_id
    except (ValueError, TypeError):
        pass

    # Check cloned voices if the voice cloner exists
    if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
        # Check by ID
        if voice in app_state.voice_cloner.cloned_voices:
            return app_state.voice_cloner.cloned_voices[voice].speaker_id

        # Check by name
        for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
            if v_info.name.lower() == voice.lower():
                return v_info.speaker_id

    # Default to alloy
    return 0

@router.post("/audio/speech", tags=["Audio"], response_class=Response)
async def generate_speech(
    request: Request,
    speech_request: SpeechRequest,
):
    """
    Generate audio of text being spoken by a realistic voice.

    This endpoint is compatible with the OpenAI TTS API.

    For streaming responses, use `/v1/audio/speech/streaming` instead.
    """
    # Check if model is available
    if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
        raise HTTPException(status_code=503, detail="TTS model not available")

    # Set default values
    model = speech_request.model
    voice = speech_request.voice
    input_text = speech_request.input
    response_format = speech_request.response_format
    speed = speech_request.speed
    temperature = speech_request.temperature
    max_audio_length_ms = speech_request.max_audio_length_ms

    # Log request details
    logger.info(f"TTS request: text length={len(input_text)}, voice={voice}, format={response_format}")

    try:
        # Get speaker ID for the voice
        speaker_id = get_speaker_id(request.app.state, voice)
        if speaker_id is None:
            raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found")

        # Check if this is a cloned voice
        voice_info = None
        cloned_voice_id = None

        if hasattr(request.app.state, "get_voice_info"):
            voice_info = request.app.state.get_voice_info(voice)
            if voice_info and voice_info["type"] == "cloned":
                cloned_voice_id = voice_info["voice_id"]

        # Generate audio based on whether it's a standard or cloned voice
        if cloned_voice_id is not None and hasattr(request.app.state, "voice_cloner"):
            # Generate speech with cloned voice
            logger.info(f"Generating speech with cloned voice ID: {cloned_voice_id}")
            try:
                voice_cloner = request.app.state.voice_cloner
                audio = voice_cloner.generate_speech(
                    text=input_text,
                    voice_id=cloned_voice_id,
                    temperature=temperature,
                    topk=speech_request.topk or 30,
                    max_audio_length_ms=max_audio_length_ms
                )
                sample_rate = request.app.state.sample_rate
                logger.info(f"Generated speech with cloned voice, length: {len(audio)/sample_rate:.2f}s")
            except Exception as e:
                logger.error(f"Error generating speech with cloned voice: {e}", exc_info=True)
                raise HTTPException(
                    status_code=500,
                    detail=f"Failed to generate speech with cloned voice: {str(e)}"
                )
        else:
            # Generate speech with standard voice
            # Use voice context from memory if enabled
            if hasattr(request.app.state, "voice_memory_enabled") and request.app.state.voice_memory_enabled:
                from app.voice_memory import get_voice_context
                context = get_voice_context(voice, torch.device(request.app.state.device))
            else:
                context = []

            # Apply optional text enhancement for better voice consistency
            enhanced_text = input_text
            if hasattr(request.app.state, "prompt_templates"):
                from app.prompt_engineering import format_text_for_voice
                enhanced_text = format_text_for_voice(input_text, voice)

            # Generate audio
            audio = request.app.state.generator.generate(
                text=enhanced_text,
                speaker=speaker_id,
                context=context,
                temperature=temperature,
                topk=speech_request.topk or 50,
                max_audio_length_ms=max_audio_length_ms
            )
            sample_rate = request.app.state.sample_rate

            # Process audio for better quality
            if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
                from app.voice_enhancement import process_generated_audio
                audio = process_generated_audio(
                    audio=audio,
                    voice_name=voice,
                    sample_rate=sample_rate,
                    text=input_text
                )

            # Update voice memory if enabled
            if hasattr(request.app.state, "voice_memory_enabled") and request.app.state.voice_memory_enabled:
                from app.voice_memory import update_voice_memory
                update_voice_memory(voice, audio, input_text)

        # Handle speed adjustments if not 1.0
        if speed != 1.0 and speed > 0:
            try:
                # Adjust speed using torchaudio
                effects = [
                    ["tempo", str(speed)]
                ]
                audio_cpu = audio.cpu()
                adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
                    audio_cpu.unsqueeze(0),
                    sample_rate,
                    effects
                )
                audio = adjusted_audio.squeeze(0)
                logger.info(f"Adjusted speech speed to {speed}x")
            except Exception as e:
                logger.warning(f"Failed to adjust speech speed: {e}")

        # Format the audio according to the requested format
        response_data, content_type = await format_audio(
            audio,
            response_format,
            sample_rate,
            request.app.state
        )

        # Create and return the response
        return Response(
            content=response_data,
            media_type=content_type,
            headers={"Content-Disposition": f"attachment; filename=speech.{response_format}"}
        )

    except Exception as e:
        logger.error(f"Error in text_to_speech: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

@router.post("/audio/speech/stream", tags=["Audio"])
async def stream_speech(request: Request, speech_request: SpeechRequest):
    """Stream audio in real-time as it's being generated."""
    # Check if model is loaded
    if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Get request parameters
    input_text = speech_request.input
    voice = speech_request.voice
    response_format = speech_request.response_format
    temperature = speech_request.temperature

    logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")

    # Get speaker ID for the voice
    speaker_id = get_speaker_id(request.app.state, voice)
    if speaker_id is None:
        raise HTTPException(status_code=400, detail=f"Voice '{voice}' not found")

    # Split text into very small segments for incremental generation
    text_segments = split_into_segments(input_text, max_chars=50)  # Smaller segments for faster first response
    logger.info(f"Split text into {len(text_segments)} segments")

    # Create media type based on format
    media_type = {
        "mp3": "audio/mpeg",
        "opus": "audio/opus",
        "aac": "audio/aac",
        "flac": "audio/flac",
        "wav": "audio/wav",
    }.get(response_format, "audio/mpeg")

    # For streaming, WAV works best
    streaming_format = "wav"

    # Set up WAV header for streaming
    sample_rate = request.app.state.sample_rate

    async def generate_streaming_audio():
        # Get context for the voice
        if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
            voice_info = request.app.state.get_voice_info(voice)
            if voice_info and voice_info["type"] == "cloned":
                # Use cloned voice context
                voice_cloner = request.app.state.voice_cloner
                context = voice_cloner.get_voice_context(voice_info["voice_id"])
            else:
                # Standard voice
                from app.voice_enhancement import get_voice_segments
                context = get_voice_segments(voice, request.app.state.device)
        else:
            # Standard voice
            from app.voice_enhancement import get_voice_segments
            context = get_voice_segments(voice, request.app.state.device)

        # Send WAV header immediately
        if streaming_format == "wav":
            # Create a WAV header for 16-bit mono audio
            header = bytes()
            # RIFF header
            header += b'RIFF'
            header += b'\x00\x00\x00\x00'  # Placeholder for file size
            header += b'WAVE'
            # Format chunk
            header += b'fmt '
            header += (16).to_bytes(4, 'little')  # Format chunk size
            header += (1).to_bytes(2, 'little')   # PCM format
            header += (1).to_bytes(2, 'little')   # Mono channel
            header += (sample_rate).to_bytes(4, 'little')      # Sample rate
            header += (sample_rate * 2).to_bytes(4, 'little')  # Byte rate
            header += (2).to_bytes(2, 'little')   # Block align
            header += (16).to_bytes(2, 'little')  # Bits per sample
            # Data chunk
            header += b'data'
            header += b'\x00\x00\x00\x00'  # Placeholder for data size
            yield header

        # Process each segment and stream immediately
        for i, segment_text in enumerate(text_segments):
            try:
                logger.info(f"Generating segment {i+1}/{len(text_segments)}")

                # For cloned voices, use the voice cloner
                if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
                    voice_info = request.app.state.get_voice_info(voice)
                    if voice_info and voice_info["type"] == "cloned":
                        # Use cloned voice
                        voice_cloner = request.app.state.voice_cloner
                        segment_audio = await asyncio.to_thread(
                            voice_cloner.generate_speech,
                            segment_text,
                            voice_info["voice_id"],
                            temperature=temperature,
                            topk=30,
                            max_audio_length_ms=2000  # Keep it very short for fast generation
                        )
                    else:
                        # Use standard voice with generator
                        segment_audio = await asyncio.to_thread(
                            request.app.state.generator.generate,
                            segment_text,
                            speaker_id,
                            context,
                            max_audio_length_ms=2000,  # Short for quicker generation
                            temperature=temperature
                        )
                else:
                    # Use standard voice with generator
                    segment_audio = await asyncio.to_thread(
                        request.app.state.generator.generate,
                        segment_text,
                        speaker_id,
                        context,
                        max_audio_length_ms=2000,  # Short for quicker generation
                        temperature=temperature
                    )

                # Skip empty or problematic audio
                if segment_audio is None or segment_audio.numel() == 0:
                    logger.warning(f"Empty audio for segment {i+1}")
                    continue

                # Convert to bytes and stream immediately
                buf = io.BytesIO()
                audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
                torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=streaming_format)
                buf.seek(0)

                # For WAV format, skip the header for all segments after the first
                if streaming_format == "wav" and i > 0:
                    buf.seek(44)  # Skip WAV header

                segment_bytes = buf.read()
                yield segment_bytes

                # Update context with this segment for next generation
                context = [
                    Segment(
                        text=segment_text,
                        speaker=speaker_id,
                        audio=segment_audio
                    )
                ]

            except Exception as e:
                logger.error(f"Error generating segment {i+1}: {e}")
                # Continue to next segment

    # Return the streaming response
    return StreamingResponse(
        generate_streaming_audio(),
        media_type=media_type,
        headers={
            "X-Accel-Buffering": "no",  # Prevent buffering in nginx
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Connection": "keep-alive",
            "Transfer-Encoding": "chunked"
        }
    )

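The stream route above emits a WAV header followed by raw segment chunks, with headers that disable proxy buffering. A minimal client sketch, assuming this router is mounted under `/v1` as in the README and using the `requests` library, could consume it incrementally like this (output file name is illustrative):

```python
# Consume the chunked streaming TTS response and write it to disk as it arrives.
import requests

with requests.post(
    "http://localhost:8000/v1/audio/speech/stream",
    json={"input": "Streaming hello from CSM-1B.", "voice": "alloy", "temperature": 0.8},
    stream=True,
) as resp:
    resp.raise_for_status()
    with open("streamed_speech.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                f.write(chunk)
```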
372 |
+
@router.post("/audio/speech/streaming", tags=["Audio"])
|
373 |
+
async def openai_stream_speech(
|
374 |
+
request: Request,
|
375 |
+
speech_request: SpeechRequest,
|
376 |
+
):
|
377 |
+
"""
|
378 |
+
Stream audio in OpenAI-compatible streaming format.
|
379 |
+
|
380 |
+
This endpoint is compatible with the OpenAI streaming TTS API.
|
381 |
+
"""
|
382 |
+
# Use the same logic as the stream_speech endpoint but with a different name
|
383 |
+
# to maintain the OpenAI API naming convention
|
384 |
+
return await stream_speech(request, speech_request)
|
385 |
+
|
386 |
+
async def format_audio(audio, response_format, sample_rate, app_state):
|
387 |
+
"""
|
388 |
+
Format audio according to requested format.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
audio: Audio tensor from TTS generation
|
392 |
+
response_format: Format as string or enum ('mp3', 'opus', 'aac', 'flac', 'wav')
|
393 |
+
sample_rate: Sample rate of the audio
|
394 |
+
app_state: FastAPI app state with config and cache settings
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
Tuple of (response_data, content_type)
|
398 |
+
"""
|
399 |
+
import io
|
400 |
+
import torch
|
401 |
+
import torchaudio
|
402 |
+
import tempfile
|
403 |
+
import os
|
404 |
+
import hashlib
|
405 |
+
import time
|
406 |
+
|
407 |
+
# Handle enum or string for response_format
|
408 |
+
if hasattr(response_format, 'value'):
|
409 |
+
response_format = response_format.value
|
410 |
+
|
411 |
+
# Normalize response_format to lowercase
|
412 |
+
response_format = str(response_format).lower()
|
413 |
+
|
414 |
+
# Map formats to content types
|
415 |
+
format_to_content_type = {
|
416 |
+
'mp3': 'audio/mpeg',
|
417 |
+
'opus': 'audio/opus',
|
418 |
+
'aac': 'audio/aac',
|
419 |
+
'flac': 'audio/flac',
|
420 |
+
'wav': 'audio/wav'
|
421 |
+
}
|
422 |
+
|
423 |
+
# Ensure response format is supported
|
424 |
+
if response_format not in format_to_content_type:
|
425 |
+
logger.warning(f"Unsupported format: {response_format}, defaulting to mp3")
|
426 |
+
response_format = 'mp3'
|
427 |
+
|
428 |
+
# Generate a cache key based on audio content and format
|
429 |
+
cache_enabled = getattr(app_state, "audio_cache_enabled", False)
|
430 |
+
cache_key = None
|
431 |
+
|
432 |
+
if cache_enabled:
|
433 |
+
# Generate a hash of the audio tensor for caching
|
434 |
+
audio_hash = hashlib.md5(audio.cpu().numpy().tobytes()).hexdigest()
|
435 |
+
cache_key = f"{audio_hash}_{response_format}"
|
436 |
+
cache_dir = getattr(app_state, "audio_cache_dir", "/app/audio_cache")
|
437 |
+
os.makedirs(cache_dir, exist_ok=True)
|
438 |
+
cache_path = os.path.join(cache_dir, f"{cache_key}")
|
439 |
+
|
440 |
+
# Check if we have a cache hit
|
441 |
+
if os.path.exists(cache_path):
|
442 |
+
try:
|
443 |
+
with open(cache_path, "rb") as f:
|
444 |
+
cached_data = f.read()
|
445 |
+
logger.info(f"Cache hit for {response_format} audio")
|
446 |
+
return cached_data, format_to_content_type[response_format]
|
447 |
+
except Exception as e:
|
448 |
+
logger.warning(f"Error reading from cache: {e}")
|
449 |
+
|
450 |
+
# Process audio to the required format
|
451 |
+
start_time = time.time()
|
452 |
+
|
453 |
+
# Move audio to CPU before saving
|
454 |
+
audio_cpu = audio.cpu()
|
455 |
+
|
456 |
+
# Use a temporary file for format conversion
|
457 |
+
with tempfile.NamedTemporaryFile(suffix=f".{response_format}", delete=False) as temp_file:
|
458 |
+
temp_path = temp_file.name
|
459 |
+
try:
|
460 |
+
if response_format == 'wav':
|
461 |
+
# Direct save for WAV
|
462 |
+
torchaudio.save(temp_path, audio_cpu.unsqueeze(0), sample_rate)
|
463 |
+
else:
|
464 |
+
# For other formats, first save as WAV then convert
|
465 |
+
wav_path = f"{temp_path}.wav"
|
466 |
+
torchaudio.save(wav_path, audio_cpu.unsqueeze(0), sample_rate)
|
467 |
+
|
468 |
+
# Use ffmpeg via torchaudio for conversion
|
469 |
+
if hasattr(torchaudio.backend, 'sox_io_backend'): # New torchaudio structure
|
470 |
+
if response_format == 'mp3':
|
471 |
+
# For MP3, use higher quality
|
472 |
+
sox_effects = torchaudio.sox_effects.SoxEffectsChain()
|
473 |
+
sox_effects.set_input_file(wav_path)
|
474 |
+
sox_effects.append_effect_to_chain(["rate", f"{sample_rate}"])
|
475 |
+
# Higher bitrate for better quality
|
476 |
+
sox_effects.append_effect_to_chain(["gain", "-n"]) # Normalize
|
477 |
+
out, _ = sox_effects.sox_build_flow_effects()
|
478 |
+
torchaudio.save(temp_path, out, sample_rate, format="mp3", compression=128)
|
479 |
+
elif response_format == 'opus':
|
480 |
+
# Use ffmpeg for opus through a system call
|
481 |
+
import subprocess
|
482 |
+
subprocess.run([
|
483 |
+
"ffmpeg", "-i", wav_path, "-c:a", "libopus",
|
484 |
+
"-b:a", "64k", "-vbr", "on", temp_path,
|
485 |
+
"-y", "-loglevel", "error"
|
486 |
+
], check=True)
|
487 |
+
elif response_format == 'aac':
|
488 |
+
# Use ffmpeg for AAC through a system call
|
489 |
+
import subprocess
|
490 |
+
subprocess.run([
|
491 |
+
"ffmpeg", "-i", wav_path, "-c:a", "aac",
|
492 |
+
"-b:a", "128k", temp_path,
|
493 |
+
"-y", "-loglevel", "error"
|
494 |
+
], check=True)
|
495 |
+
elif response_format == 'flac':
|
496 |
+
torchaudio.save(temp_path, audio_cpu.unsqueeze(0), sample_rate, format="flac")
|
497 |
+
else:
|
498 |
+
# Fallback using external command
|
499 |
+
import subprocess
|
500 |
+
if response_format == 'mp3':
|
501 |
+
subprocess.run([
|
502 |
+
"ffmpeg", "-i", wav_path, "-codec:a", "libmp3lame",
|
503 |
+
"-qscale:a", "2", temp_path,
|
504 |
+
"-y", "-loglevel", "error"
|
505 |
+
], check=True)
|
506 |
+
elif response_format == 'opus':
|
507 |
+
subprocess.run([
|
508 |
+
"ffmpeg", "-i", wav_path, "-c:a", "libopus",
|
509 |
+
"-b:a", "64k", "-vbr", "on", temp_path,
|
510 |
+
"-y", "-loglevel", "error"
|
511 |
+
], check=True)
|
512 |
+
elif response_format == 'aac':
|
513 |
+
subprocess.run([
|
514 |
+
"ffmpeg", "-i", wav_path, "-c:a", "aac",
|
515 |
+
"-b:a", "128k", temp_path,
|
516 |
+
"-y", "-loglevel", "error"
|
517 |
+
], check=True)
|
518 |
+
elif response_format == 'flac':
|
519 |
+
subprocess.run([
|
520 |
+
"ffmpeg", "-i", wav_path, "-c:a", "flac", temp_path,
|
521 |
+
"-y", "-loglevel", "error"
|
522 |
+
], check=True)
|
523 |
+
|
524 |
+
# Clean up the temporary WAV file
|
525 |
+
try:
|
526 |
+
os.unlink(wav_path)
|
527 |
+
except OSError:
|
528 |
+
pass
|
529 |
+
|
530 |
+
# Read the processed audio file
|
531 |
+
with open(temp_path, "rb") as f:
|
532 |
+
response_data = f.read()
|
533 |
+
|
534 |
+
# Store in cache if enabled
|
535 |
+
if cache_enabled and cache_key:
|
536 |
+
try:
|
537 |
+
cache_path = os.path.join(getattr(app_state, "audio_cache_dir", "/app/audio_cache"), f"{cache_key}")
|
538 |
+
with open(cache_path, "wb") as f:
|
539 |
+
f.write(response_data)
|
540 |
+
logger.debug(f"Cached {response_format} audio with key: {cache_key}")
|
541 |
+
except Exception as e:
|
542 |
+
logger.warning(f"Error writing to cache: {e}")
|
543 |
+
|
544 |
+
# Log processing time
|
545 |
+
processing_time = time.time() - start_time
|
546 |
+
logger.info(f"Processed audio to {response_format} in {processing_time:.3f}s")
|
547 |
+
|
548 |
+
return response_data, format_to_content_type[response_format]
|
549 |
+
|
550 |
+
except Exception as e:
|
551 |
+
logger.error(f"Error converting audio to {response_format}: {e}")
|
552 |
+
# Fallback to WAV if conversion fails
|
553 |
+
try:
|
554 |
+
wav_path = f"{temp_path}.wav"
|
555 |
+
torchaudio.save(wav_path, audio_cpu.unsqueeze(0), sample_rate)
|
556 |
+
with open(wav_path, "rb") as f:
|
557 |
+
response_data = f.read()
|
558 |
+
os.unlink(wav_path)
|
559 |
+
return response_data, "audio/wav"
|
560 |
+
except Exception as fallback_error:
|
561 |
+
logger.error(f"Fallback to WAV also failed: {fallback_error}")
|
562 |
+
raise RuntimeError(f"Failed to generate audio in any format: {str(e)}")
|
563 |
+
|
564 |
+
finally:
|
565 |
+
# Clean up the temporary file
|
566 |
+
try:
|
567 |
+
os.unlink(temp_path)
|
568 |
+
except OSError:
|
569 |
+
pass
|
570 |
+
|
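For reference, the cache key used above is simply an MD5 digest of the raw sample bytes combined with the target format, so identical generated audio is converted at most once per format. A minimal, self-contained sketch of that scheme (the helper name is illustrative, not part of the module):

import hashlib
import torch

def make_audio_cache_key(audio: torch.Tensor, response_format: str) -> str:
    # Hash the raw PCM samples; same audio + same format -> same cache file.
    audio_hash = hashlib.md5(audio.cpu().numpy().tobytes()).hexdigest()
    return f"{audio_hash}_{response_format}"

print(make_audio_cache_key(torch.zeros(24000), "mp3"))  # "<32-hex-digest>_mp3"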
571 |
+
@router.post("/audio/conversation", tags=["Conversation API"])
|
572 |
+
async def conversation_to_speech(
|
573 |
+
request: Request,
|
574 |
+
text: str = Body(..., description="Text to convert to speech"),
|
575 |
+
speaker_id: int = Body(0, description="Speaker ID"),
|
576 |
+
context: List[Dict] = Body([], description="Context segments with speaker, text, and audio path"),
|
577 |
+
):
|
578 |
+
"""
|
579 |
+
Custom endpoint for conversational TTS using CSM-1B.
|
580 |
+
|
581 |
+
This is not part of the OpenAI API but provides the unique conversational
|
582 |
+
capability of the CSM model.
|
583 |
+
"""
|
584 |
+
# Get generator from app state
|
585 |
+
generator = request.app.state.generator
|
586 |
+
|
587 |
+
# Validate model availability
|
588 |
+
if generator is None:
|
589 |
+
raise HTTPException(status_code=503, detail="Model not loaded")
|
590 |
+
|
591 |
+
try:
|
592 |
+
segments = []
|
593 |
+
|
594 |
+
# Process context if provided
|
595 |
+
for ctx in context:
|
596 |
+
if 'speaker' not in ctx or 'text' not in ctx or 'audio' not in ctx:
|
597 |
+
continue
|
598 |
+
|
599 |
+
# Audio should be base64-encoded
|
600 |
+
audio_data = base64.b64decode(ctx['audio'])
|
601 |
+
audio_file = io.BytesIO(audio_data)
|
602 |
+
|
603 |
+
# Save to temporary file for torchaudio
|
604 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
|
605 |
+
temp.write(audio_file.read())
|
606 |
+
temp_path = temp.name
|
607 |
+
|
608 |
+
# Load audio
|
609 |
+
audio_tensor, sample_rate = torchaudio.load(temp_path)
|
610 |
+
audio_tensor = torchaudio.functional.resample(
|
611 |
+
audio_tensor.squeeze(0),
|
612 |
+
orig_freq=sample_rate,
|
613 |
+
new_freq=generator.sample_rate
|
614 |
+
)
|
615 |
+
|
616 |
+
# Clean up
|
617 |
+
os.unlink(temp_path)
|
618 |
+
|
619 |
+
# Create segment
|
620 |
+
segments.append(
|
621 |
+
Segment(
|
622 |
+
speaker=ctx['speaker'],
|
623 |
+
text=ctx['text'],
|
624 |
+
audio=audio_tensor
|
625 |
+
)
|
626 |
+
)
|
627 |
+
|
628 |
+
logger.info(f"Conversation request: '{text}' with {len(segments)} context segments")
|
629 |
+
|
630 |
+
# Format the text for better voice consistency
|
631 |
+
from app.prompt_engineering import format_text_for_voice
|
632 |
+
|
633 |
+
# Determine voice name from speaker_id
|
634 |
+
voice_names = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
|
635 |
+
voice_name = voice_names[speaker_id] if 0 <= speaker_id < len(voice_names) else "alloy"
|
636 |
+
|
637 |
+
formatted_text = format_text_for_voice(text, voice_name)
|
638 |
+
|
639 |
+
# Generate audio with context
|
640 |
+
audio = generator.generate(
|
641 |
+
text=formatted_text,
|
642 |
+
speaker=speaker_id,
|
643 |
+
context=segments,
|
644 |
+
max_audio_length_ms=20000, # 20 seconds
|
645 |
+
temperature=0.7, # Lower temperature for more stable output
|
646 |
+
topk=40,
|
647 |
+
)
|
648 |
+
|
649 |
+
# Process audio for better quality
|
650 |
+
from app.voice_enhancement import process_generated_audio
|
651 |
+
|
652 |
+
processed_audio = process_generated_audio(
|
653 |
+
audio,
|
654 |
+
voice_name,
|
655 |
+
generator.sample_rate,
|
656 |
+
text
|
657 |
+
)
|
658 |
+
|
659 |
+
# Save to temporary file
|
660 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp:
|
661 |
+
temp_path = temp.name
|
662 |
+
|
663 |
+
# Save audio
|
664 |
+
torchaudio.save(temp_path, processed_audio.unsqueeze(0).cpu(), generator.sample_rate)
|
665 |
+
|
666 |
+
# Return audio file
|
667 |
+
def iterfile():
|
668 |
+
with open(temp_path, 'rb') as f:
|
669 |
+
yield from f
|
670 |
+
# Clean up
|
671 |
+
if os.path.exists(temp_path):
|
672 |
+
os.unlink(temp_path)
|
673 |
+
|
674 |
+
logger.info(f"Generated conversation response, duration: {processed_audio.shape[0]/generator.sample_rate:.2f}s")
|
675 |
+
|
676 |
+
return StreamingResponse(
|
677 |
+
iterfile(),
|
678 |
+
media_type="audio/wav",
|
679 |
+
headers={'Content-Disposition': 'attachment; filename="speech.wav"'}
|
680 |
+
)
|
681 |
+
|
682 |
+
except Exception as e:
|
683 |
+
import traceback
|
684 |
+
error_trace = traceback.format_exc()
|
685 |
+
logger.error(f"Conversation speech generation failed: {str(e)}\n{error_trace}")
|
686 |
+
raise HTTPException(status_code=500, detail=f"Conversation speech generation failed: {str(e)}")
|
687 |
+
|
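A hedged client-side sketch for the conversation endpoint above; the base URL and the /v1 prefix are assumptions about how this router is mounted, and the JSON body mirrors the Body(...) parameters the route declares:

import base64
import requests

with open("previous_turn.wav", "rb") as f:
    context_audio = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    "http://localhost:8000/v1/audio/conversation",
    json={
        "text": "Nice to meet you too!",
        "speaker_id": 0,
        "context": [
            {"speaker": 1, "text": "Nice to meet you.", "audio": context_audio}
        ],
    },
)
resp.raise_for_status()
with open("reply.wav", "wb") as f:
    f.write(resp.content)  # the endpoint returns a WAV attachment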
688 |
+
@router.get("/audio/voices", tags=["Audio"])
|
689 |
+
async def list_voices(request: Request):
|
690 |
+
"""
|
691 |
+
List available voices in a format compatible with OpenAI and OpenWebUI.
|
692 |
+
"""
|
693 |
+
# Use app state's get_all_voices function if available
|
694 |
+
if hasattr(request.app.state, "get_all_voices"):
|
695 |
+
voices = request.app.state.get_all_voices()
|
696 |
+
logger.info(f"Listing {len(voices)} voices")
|
697 |
+
return {"voices": voices}
|
698 |
+
|
699 |
+
# Fallback to standard voices if necessary
|
700 |
+
standard_voices = [
|
701 |
+
{"voice_id": "alloy", "name": "Alloy"},
|
702 |
+
{"voice_id": "echo", "name": "Echo"},
|
703 |
+
{"voice_id": "fable", "name": "Fable"},
|
704 |
+
{"voice_id": "onyx", "name": "Onyx"},
|
705 |
+
{"voice_id": "nova", "name": "Nova"},
|
706 |
+
{"voice_id": "shimmer", "name": "Shimmer"}
|
707 |
+
]
|
708 |
+
|
709 |
+
# Add cloned voices if available
|
710 |
+
if hasattr(request.app.state, "voice_cloner") and request.app.state.voice_cloner is not None:
|
711 |
+
cloned_voices = request.app.state.voice_cloner.list_voices()
|
712 |
+
for voice in cloned_voices:
|
713 |
+
standard_voices.append({
|
714 |
+
"voice_id": voice.id, # This has to be specifically voice_id
|
715 |
+
"name": voice.name # This has to be specifically name
|
716 |
+
})
|
717 |
+
|
718 |
+
logger.info(f"Listing {len(standard_voices)} voices")
|
719 |
+
return {"voices": standard_voices}
|
720 |
+
|
721 |
+
# Add OpenAI-compatible models list endpoint
|
722 |
+
@router.get("/audio/models", tags=["Audio"], summary="List available audio models")
|
723 |
+
async def list_models(request: Request):
|
724 |
+
"""
|
725 |
+
OpenAI compatible endpoint that returns a list of available audio models.
|
726 |
+
"""
|
727 |
+
models = [
|
728 |
+
{
|
729 |
+
"id": "csm-1b",
|
730 |
+
"name": "CSM-1B",
|
731 |
+
"description": "Conversational Speech Model 1B from Sesame",
|
732 |
+
"created": 1716019200, # March 13, 2025 (from the example)
|
733 |
+
"object": "audio",
|
734 |
+
"owned_by": "sesame",
|
735 |
+
"capabilities": {
|
736 |
+
"tts": True,
|
737 |
+
"voice_generation": True,
|
738 |
+
"voice_cloning": hasattr(router.app, "voice_cloner"),
|
739 |
+
"streaming": True
|
740 |
+
},
|
741 |
+
"max_input_length": 4096,
|
742 |
+
"price": {"text-to-speech": 0.00}
|
743 |
+
},
|
744 |
+
{
|
745 |
+
"id": "tts-1",
|
746 |
+
"name": "CSM-1B (Compatibility Mode)",
|
747 |
+
"description": "CSM-1B with OpenAI TTS-1 compatibility",
|
748 |
+
"created": 1716019200,
|
749 |
+
"object": "audio",
|
750 |
+
"owned_by": "sesame",
|
751 |
+
"capabilities": {
|
752 |
+
"tts": True,
|
753 |
+
"voice_generation": True,
|
754 |
+
"streaming": True
|
755 |
+
},
|
756 |
+
"max_input_length": 4096,
|
757 |
+
"price": {"text-to-speech": 0.00}
|
758 |
+
},
|
759 |
+
{
|
760 |
+
"id": "tts-1-hd",
|
761 |
+
"name": "CSM-1B (HD Mode)",
|
762 |
+
"description": "CSM-1B with higher quality settings",
|
763 |
+
"created": 1716019200,
|
764 |
+
"object": "audio",
|
765 |
+
"owned_by": "sesame",
|
766 |
+
"capabilities": {
|
767 |
+
"tts": True,
|
768 |
+
"voice_generation": True,
|
769 |
+
"streaming": True
|
770 |
+
},
|
771 |
+
"max_input_length": 4096,
|
772 |
+
"price": {"text-to-speech": 0.00}
|
773 |
+
}
|
774 |
+
]
|
775 |
+
|
776 |
+
return {"data": models, "object": "list"}
|
777 |
+
|
778 |
+
# Response format options endpoint
|
779 |
+
@router.get("/audio/speech/response-formats", tags=["Audio"], summary="List available response formats")
|
780 |
+
async def list_response_formats():
|
781 |
+
"""List available response formats for speech synthesis."""
|
782 |
+
formats = [
|
783 |
+
{"name": "mp3", "content_type": "audio/mpeg"},
|
784 |
+
{"name": "opus", "content_type": "audio/opus"},
|
785 |
+
{"name": "aac", "content_type": "audio/aac"},
|
786 |
+
{"name": "flac", "content_type": "audio/flac"},
|
787 |
+
{"name": "wav", "content_type": "audio/wav"}
|
788 |
+
]
|
789 |
+
|
790 |
+
return {"response_formats": formats}
|
791 |
+
|
792 |
+
# Streaming format options endpoint
|
793 |
+
@router.get("/audio/speech/stream-formats", tags=["Audio"], summary="List available streaming formats")
|
794 |
+
async def list_stream_formats():
|
795 |
+
"""List available streaming formats for TTS."""
|
796 |
+
return {
|
797 |
+
"stream_formats": [
|
798 |
+
{
|
799 |
+
"format": "mp3",
|
800 |
+
"content_type": "audio/mpeg",
|
801 |
+
"description": "MP3 audio format (streaming)"
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"format": "opus",
|
805 |
+
"content_type": "audio/opus",
|
806 |
+
"description": "Opus audio format (streaming)"
|
807 |
+
},
|
808 |
+
{
|
809 |
+
"format": "aac",
|
810 |
+
"content_type": "audio/aac",
|
811 |
+
"description": "AAC audio format (streaming)"
|
812 |
+
},
|
813 |
+
{
|
814 |
+
"format": "flac",
|
815 |
+
"content_type": "audio/flac",
|
816 |
+
"description": "FLAC audio format (streaming)"
|
817 |
+
},
|
818 |
+
{
|
819 |
+
"format": "wav",
|
820 |
+
"content_type": "audio/wav",
|
821 |
+
"description": "WAV audio format (streaming)"
|
822 |
+
}
|
823 |
+
]
|
824 |
+
}
|
825 |
+
|
826 |
+
# Simple test endpoint
|
827 |
+
@router.get("/test", tags=["Utility"], summary="Test endpoint")
|
828 |
+
async def test_endpoint():
|
829 |
+
"""Simple test endpoint that returns a successful response."""
|
830 |
+
return {"status": "ok", "message": "API is working"}
|
831 |
+
|
832 |
+
# Debug endpoint
|
833 |
+
@router.get("/debug", tags=["Utility"], summary="Debug endpoint")
|
834 |
+
async def debug_info(request: Request):
|
835 |
+
"""Get debug information about the API."""
|
836 |
+
generator = request.app.state.generator
|
837 |
+
|
838 |
+
# Basic info
|
839 |
+
debug_info = {
|
840 |
+
"model_loaded": generator is not None,
|
841 |
+
"device": generator.device if generator is not None else None,
|
842 |
+
"sample_rate": generator.sample_rate if generator is not None else None,
|
843 |
+
}
|
844 |
+
|
845 |
+
# Add voice enhancement info if available
|
846 |
+
try:
|
847 |
+
from app.voice_enhancement import VOICE_PROFILES
|
848 |
+
voice_info = {}
|
849 |
+
for name, profile in VOICE_PROFILES.items():
|
850 |
+
voice_info[name] = {
|
851 |
+
"pitch_range": f"{profile.pitch_range[0]}-{profile.pitch_range[1]}Hz",
|
852 |
+
"timbre": profile.timbre,
|
853 |
+
"ref_segments": len(profile.reference_segments),
|
854 |
+
}
|
855 |
+
debug_info["voice_profiles"] = voice_info
|
856 |
+
except ImportError:
|
857 |
+
debug_info["voice_profiles"] = "Not available"
|
858 |
+
|
859 |
+
# Add voice cloning info if available
|
860 |
+
if hasattr(request.app.state, "voice_cloner"):
|
861 |
+
voice_cloner = request.app.state.voice_cloner
|
862 |
+
debug_info["voice_cloning"] = {
|
863 |
+
"enabled": True,
|
864 |
+
"cloned_voices_count": len(voice_cloner.list_voices()),
|
865 |
+
"cloned_voices": [v.name for v in voice_cloner.list_voices()]
|
866 |
+
}
|
867 |
+
else:
|
868 |
+
debug_info["voice_cloning"] = {"enabled": False}
|
869 |
+
|
870 |
+
# Add streaming info
|
871 |
+
debug_info["streaming"] = {"enabled": True}
|
872 |
+
|
873 |
+
# Add memory usage info for CUDA
|
874 |
+
if torch.cuda.is_available():
|
875 |
+
debug_info["cuda"] = {
|
876 |
+
"allocated_memory_gb": torch.cuda.memory_allocated() / 1e9,
|
877 |
+
"reserved_memory_gb": torch.cuda.memory_reserved() / 1e9,
|
878 |
+
"max_memory_gb": torch.cuda.get_device_properties(0).total_memory / 1e9,
|
879 |
+
}
|
880 |
+
|
881 |
+
return debug_info
|
882 |
+
|
883 |
+
@router.get("/voice-management/info", tags=["Voice Management"])
|
884 |
+
async def get_voice_storage_info(request: Request):
|
885 |
+
"""Get information about voice storage usage and status."""
|
886 |
+
from app.utils.voice_manager import get_voice_storage_info
|
887 |
+
return get_voice_storage_info()
|
888 |
+
|
889 |
+
@router.post("/voice-management/backup", tags=["Voice Management"])
|
890 |
+
async def create_voice_backup(request: Request):
|
891 |
+
"""Create a backup of all voice data."""
|
892 |
+
from app.utils.voice_manager import backup_voice_data
|
893 |
+
backup_path = backup_voice_data()
|
894 |
+
return {"status": "success", "backup_path": backup_path}
|
895 |
+
|
896 |
+
@router.post("/voice-management/reset-voices", tags=["Voice Management"])
|
897 |
+
async def reset_voices(request: Request):
|
898 |
+
"""Reset voices to their default state."""
|
899 |
+
from app.utils.voice_manager import restore_default_voices
|
900 |
+
backup_path = restore_default_voices()
|
901 |
+
return {"status": "success", "backup_path": backup_path, "message": "Voices reset to default state"}
|
902 |
+
|
903 |
+
@router.get("/voice-management/verify-references", tags=["Voice Management"])
|
904 |
+
async def verify_references(request: Request):
|
905 |
+
"""Check if voice references are complete and valid."""
|
906 |
+
from app.utils.voice_manager import verify_voice_references
|
907 |
+
return verify_voice_references()
|
908 |
+
|
909 |
+
# Voice diagnostics endpoint
|
910 |
+
@router.get("/debug/voices", tags=["Debug"], summary="Voice diagnostics")
|
911 |
+
async def voice_diagnostics():
|
912 |
+
"""Get diagnostic information about voice references."""
|
913 |
+
try:
|
914 |
+
from app.voice_enhancement import VOICE_PROFILES
|
915 |
+
|
916 |
+
diagnostics = {}
|
917 |
+
for name, profile in VOICE_PROFILES.items():
|
918 |
+
ref_info = []
|
919 |
+
for i, ref in enumerate(profile.reference_segments):
|
920 |
+
if ref is not None:
|
921 |
+
duration = ref.shape[0] / 24000 # Assume 24kHz
|
922 |
+
ref_info.append({
|
923 |
+
"index": i,
|
924 |
+
"duration_seconds": f"{duration:.2f}",
|
925 |
+
"samples": ref.shape[0],
|
926 |
+
"min": float(ref.min()),
|
927 |
+
"max": float(ref.max()),
|
928 |
+
"rms": float(torch.sqrt(torch.mean(ref ** 2))),
|
929 |
+
})
|
930 |
+
|
931 |
+
diagnostics[name] = {
|
932 |
+
"speaker_id": profile.speaker_id,
|
933 |
+
"pitch_range": f"{profile.pitch_range[0]}-{profile.pitch_range[1]}Hz",
|
934 |
+
"references": ref_info,
|
935 |
+
"reference_count": len(ref_info),
|
936 |
+
}
|
937 |
+
|
938 |
+
return {"diagnostics": diagnostics}
|
939 |
+
except ImportError:
|
940 |
+
return {"error": "Voice enhancement module not available"}
|
941 |
+
|
942 |
+
# Specialized debugging endpoint for speech generation
|
943 |
+
@router.post("/debug/speech", tags=["Debug"], summary="Debug speech generation")
|
944 |
+
async def debug_speech(
|
945 |
+
request: Request,
|
946 |
+
text: str = Body(..., embed=True),
|
947 |
+
voice: str = Body("alloy", embed=True),
|
948 |
+
use_enhancement: bool = Body(True, embed=True)
|
949 |
+
):
|
950 |
+
"""Debug endpoint for speech generation with enhancement options."""
|
951 |
+
generator = request.app.state.generator
|
952 |
+
|
953 |
+
if generator is None:
|
954 |
+
return {"error": "Model not loaded"}
|
955 |
+
|
956 |
+
try:
|
957 |
+
# Convert voice name to speaker ID
|
958 |
+
voice_map = {
|
959 |
+
"alloy": 0,
|
960 |
+
"echo": 1,
|
961 |
+
"fable": 2,
|
962 |
+
"onyx": 3,
|
963 |
+
"nova": 4,
|
964 |
+
"shimmer": 5
|
965 |
+
}
|
966 |
+
speaker = voice_map.get(voice, 0)
|
967 |
+
|
968 |
+
# Format text if using enhancement
|
969 |
+
if use_enhancement:
|
970 |
+
from app.prompt_engineering import format_text_for_voice
|
971 |
+
formatted_text = format_text_for_voice(text, voice)
|
972 |
+
logger.info(f"Using formatted text: {formatted_text}")
|
973 |
+
else:
|
974 |
+
formatted_text = text
|
975 |
+
|
976 |
+
# Get context if using enhancement
|
977 |
+
if use_enhancement:
|
978 |
+
from app.voice_enhancement import get_voice_segments
|
979 |
+
context = get_voice_segments(voice, generator.device)
|
980 |
+
logger.info(f"Using {len(context)} context segments")
|
981 |
+
else:
|
982 |
+
context = []
|
983 |
+
|
984 |
+
# Generate audio
|
985 |
+
start_time = time.time()
|
986 |
+
audio = generator.generate(
|
987 |
+
text=formatted_text,
|
988 |
+
speaker=speaker,
|
989 |
+
context=context,
|
990 |
+
max_audio_length_ms=10000, # 10 seconds
|
991 |
+
temperature=0.7 if use_enhancement else 0.9,
|
992 |
+
topk=40 if use_enhancement else 50,
|
993 |
+
)
|
994 |
+
generation_time = time.time() - start_time
|
995 |
+
|
996 |
+
# Process audio if using enhancement
|
997 |
+
if use_enhancement:
|
998 |
+
from app.voice_enhancement import process_generated_audio
|
999 |
+
start_time = time.time()
|
1000 |
+
processed_audio = process_generated_audio(audio, voice, generator.sample_rate, text)
|
1001 |
+
processing_time = time.time() - start_time
|
1002 |
+
else:
|
1003 |
+
processed_audio = audio
|
1004 |
+
processing_time = 0
|
1005 |
+
|
1006 |
+
# Save to temporary WAV file
|
1007 |
+
temp_path = f"/tmp/debug_speech_{voice}_{int(time.time())}.wav"
|
1008 |
+
torchaudio.save(temp_path, processed_audio.unsqueeze(0).cpu(), generator.sample_rate)
|
1009 |
+
|
1010 |
+
# Also save original if enhanced
|
1011 |
+
if use_enhancement:
|
1012 |
+
orig_path = f"/tmp/debug_speech_{voice}_original_{int(time.time())}.wav"
|
1013 |
+
torchaudio.save(orig_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
|
1014 |
+
else:
|
1015 |
+
orig_path = temp_path
|
1016 |
+
|
1017 |
+
# Calculate audio metrics
|
1018 |
+
duration = processed_audio.shape[0] / generator.sample_rate
|
1019 |
+
rms = float(torch.sqrt(torch.mean(processed_audio ** 2)))
|
1020 |
+
peak = float(processed_audio.abs().max())
|
1021 |
+
|
1022 |
+
return {
|
1023 |
+
"status": "success",
|
1024 |
+
"message": f"Audio generated successfully and saved to {temp_path}",
|
1025 |
+
"audio": {
|
1026 |
+
"duration_seconds": f"{duration:.2f}",
|
1027 |
+
"samples": processed_audio.shape[0],
|
1028 |
+
"sample_rate": generator.sample_rate,
|
1029 |
+
"rms_level": f"{rms:.3f}",
|
1030 |
+
"peak_level": f"{peak:.3f}",
|
1031 |
+
},
|
1032 |
+
"processing": {
|
1033 |
+
"enhancement_used": use_enhancement,
|
1034 |
+
"generation_time_seconds": f"{generation_time:.3f}",
|
1035 |
+
"processing_time_seconds": f"{processing_time:.3f}",
|
1036 |
+
"original_path": orig_path,
|
1037 |
+
"processed_path": temp_path,
|
1038 |
+
}
|
1039 |
+
}
|
1040 |
+
except Exception as e:
|
1041 |
+
import traceback
|
1042 |
+
error_trace = traceback.format_exc()
|
1043 |
+
logger.error(f"Debug speech generation failed: {e}\n{error_trace}")
|
1044 |
+
return {
|
1045 |
+
"status": "error",
|
1046 |
+
"message": str(e),
|
1047 |
+
"traceback": error_trace
|
1048 |
+
}
|
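A hedged sketch of exercising the /debug/speech route above: the embed=True Body parameters map onto a flat JSON object. The base URL and /v1 prefix are deployment assumptions.

import requests

resp = requests.post(
    "http://localhost:8000/v1/debug/speech",
    json={"text": "Testing one two three.", "voice": "alloy", "use_enhancement": True},
)
info = resp.json()
print(info["status"], info.get("audio", {}).get("duration_seconds"))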
app/api/schemas.py
ADDED
@@ -0,0 +1,41 @@
1 |
+
# app/api/schemas.py
|
2 |
+
from enum import Enum
|
3 |
+
from typing import Optional, List, Dict, Any, Union
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
|
6 |
+
# Voice options as a non-restrictive string
|
7 |
+
class Voice(str):
|
8 |
+
"""Voice options for CSM model - allowing any string value"""
|
9 |
+
pass
|
10 |
+
|
11 |
+
class ResponseFormat(str, Enum):
|
12 |
+
mp3 = "mp3"
|
13 |
+
opus = "opus"
|
14 |
+
aac = "aac"
|
15 |
+
flac = "flac"
|
16 |
+
wav = "wav"
|
17 |
+
|
18 |
+
# Create SpeechRequest for compatibility with our new code
|
19 |
+
class SpeechRequest(BaseModel):
|
20 |
+
model: Optional[str] = Field("csm-1b", description="The TTS model to use")
|
21 |
+
input: str = Field(..., description="The text to generate audio for")
|
22 |
+
voice: Optional[str] = Field("alloy", description="The voice to use for generation")
|
23 |
+
response_format: Optional[ResponseFormat] = Field(ResponseFormat.mp3, description="The format of the audio response")
|
24 |
+
speed: Optional[float] = Field(1.0, description="The speed of the audio", ge=0.25, le=4.0)
|
25 |
+
# CSM-specific parameters
|
26 |
+
max_audio_length_ms: Optional[float] = Field(90000, description="Maximum audio length in milliseconds")
|
27 |
+
temperature: Optional[float] = Field(0.9, description="Sampling temperature", ge=0.0, le=2.0)
|
28 |
+
topk: Optional[int] = Field(50, description="Top-k for sampling", ge=1, le=100)
|
29 |
+
|
30 |
+
class Config:
|
31 |
+
populate_by_name = True
|
32 |
+
extra = "ignore" # Allow extra fields without error
|
33 |
+
|
34 |
+
# Maintain TTSRequest for backward compatibility
|
35 |
+
class TTSRequest(SpeechRequest):
|
36 |
+
"""Legacy alias for SpeechRequest for backward compatibility"""
|
37 |
+
pass
|
38 |
+
|
39 |
+
class TTSResponse(BaseModel):
|
40 |
+
"""Only used for API documentation"""
|
41 |
+
pass
|
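A small sketch of how the schema above behaves when constructed directly; it exercises only the pydantic defaults and constraints declared in the fields:

from app.api.schemas import SpeechRequest, ResponseFormat

req = SpeechRequest(input="Hello there", voice="alloy", response_format="wav")
assert req.response_format is ResponseFormat.wav   # string is coerced to the enum
assert req.speed == 1.0 and req.model == "csm-1b"  # field defaults

try:
    SpeechRequest(input="Hi", speed=9.0)           # outside the 0.25-4.0 bound
except Exception as exc:
    print("rejected:", exc)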
app/api/streaming.py
ADDED
@@ -0,0 +1,315 @@
1 |
+
"""Streaming support for the TTS API."""
|
2 |
+
import asyncio
|
3 |
+
import io
|
4 |
+
import logging
|
5 |
+
import time
|
6 |
+
from typing import AsyncGenerator, Optional, List
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from fastapi import APIRouter, Request, HTTPException
|
10 |
+
from fastapi.responses import StreamingResponse
|
11 |
+
from app.api.schemas import SpeechRequest, ResponseFormat
|
12 |
+
from app.prompt_engineering import split_into_segments
|
13 |
+
from app.models import Segment
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
router = APIRouter()
|
17 |
+
|
18 |
+
class AudioChunker:
|
19 |
+
"""Handle audio chunking for streaming responses."""
|
20 |
+
def __init__(self,
|
21 |
+
sample_rate: int,
|
22 |
+
format: str = "mp3",
|
23 |
+
chunk_size_ms: int = 200): # Smaller chunks for better streaming
|
24 |
+
"""
|
25 |
+
Initialize audio chunker.
|
26 |
+
Args:
|
27 |
+
sample_rate: Audio sample rate in Hz
|
28 |
+
format: Output audio format (mp3, opus, etc.)
|
29 |
+
chunk_size_ms: Size of each chunk in milliseconds
|
30 |
+
"""
|
31 |
+
self.sample_rate = sample_rate
|
32 |
+
self.format = format.lower()
|
33 |
+
self.chunk_size_samples = int(sample_rate * (chunk_size_ms / 1000))
|
34 |
+
logger.info(f"Audio chunker initialized with {chunk_size_ms}ms chunks ({self.chunk_size_samples} samples)")
|
35 |
+
|
36 |
+
async def chunk_audio(self,
|
37 |
+
audio: torch.Tensor,
|
38 |
+
delay_ms: int = 0) -> AsyncGenerator[bytes, None]:
|
39 |
+
"""
|
40 |
+
Convert audio tensor to streaming chunks.
|
41 |
+
Args:
|
42 |
+
audio: Audio tensor to stream
|
43 |
+
delay_ms: Artificial delay between chunks (for testing)
|
44 |
+
Yields:
|
45 |
+
Audio chunks as bytes
|
46 |
+
"""
|
47 |
+
# Ensure audio is on CPU
|
48 |
+
if audio.is_cuda:
|
49 |
+
audio = audio.cpu()
|
50 |
+
# Calculate number of chunks
|
51 |
+
num_samples = audio.shape[0]
|
52 |
+
num_chunks = (num_samples + self.chunk_size_samples - 1) // self.chunk_size_samples
|
53 |
+
logger.info(f"Streaming {num_samples} samples as {num_chunks} chunks")
|
54 |
+
for i in range(num_chunks):
|
55 |
+
start_idx = i * self.chunk_size_samples
|
56 |
+
end_idx = min(start_idx + self.chunk_size_samples, num_samples)
|
57 |
+
# Extract chunk
|
58 |
+
chunk = audio[start_idx:end_idx]
|
59 |
+
# Convert to bytes in requested format
|
60 |
+
chunk_bytes = await self._format_chunk(chunk)
|
61 |
+
# Add artificial delay if requested (for testing)
|
62 |
+
if delay_ms > 0:
|
63 |
+
await asyncio.sleep(delay_ms / 1000)
|
64 |
+
yield chunk_bytes
|
65 |
+
|
66 |
+
async def _format_chunk(self, chunk: torch.Tensor) -> bytes:
|
67 |
+
"""Convert audio chunk to bytes in the specified format."""
|
68 |
+
buf = io.BytesIO()
|
69 |
+
# Ensure chunk is 1D and on CPU
|
70 |
+
if len(chunk.shape) == 1:
|
71 |
+
chunk = chunk.unsqueeze(0) # Add channel dimension
|
72 |
+
# Ensure chunk is on CPU
|
73 |
+
if chunk.is_cuda:
|
74 |
+
chunk = chunk.cpu()
|
75 |
+
# Save to buffer in specified format
|
76 |
+
if self.format == "mp3":
|
77 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="mp3")
|
78 |
+
elif self.format == "opus":
|
79 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="opus")
|
80 |
+
elif self.format == "aac":
|
81 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="aac")
|
82 |
+
elif self.format == "flac":
|
83 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="flac")
|
84 |
+
elif self.format == "wav":
|
85 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="wav")
|
86 |
+
else:
|
87 |
+
# Default to mp3
|
88 |
+
torchaudio.save(buf, chunk, self.sample_rate, format="mp3")
|
89 |
+
# Get bytes from buffer
|
90 |
+
buf.seek(0)
|
91 |
+
return buf.read()
|
92 |
+
|
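A brief usage sketch for AudioChunker (illustrative only, not part of the module): chunk one second of audio into ~200 ms WAV chunks and count the streamed bytes.

import asyncio
import torch

async def _chunker_demo() -> int:
    chunker = AudioChunker(sample_rate=24000, format="wav", chunk_size_ms=200)
    audio = torch.zeros(24000)                  # one second of silence at 24 kHz
    total_bytes = 0
    async for chunk in chunker.chunk_audio(audio):
        total_bytes += len(chunk)
    return total_bytes

# asyncio.run(_chunker_demo())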
93 |
+
# Helper function to get speaker ID for a voice
|
94 |
+
def get_speaker_id(app_state, voice):
|
95 |
+
"""Helper function to get speaker ID from voice name or ID"""
|
96 |
+
if hasattr(app_state, "voice_speaker_map") and voice in app_state.voice_speaker_map:
|
97 |
+
return app_state.voice_speaker_map[voice]
|
98 |
+
# Standard voices mapping
|
99 |
+
voice_to_speaker = {"alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5}
|
100 |
+
if voice in voice_to_speaker:
|
101 |
+
return voice_to_speaker[voice]
|
102 |
+
# Try parsing as integer
|
103 |
+
try:
|
104 |
+
speaker_id = int(voice)
|
105 |
+
if 0 <= speaker_id < 6:
|
106 |
+
return speaker_id
|
107 |
+
except (ValueError, TypeError):
|
108 |
+
pass
|
109 |
+
# Check cloned voices if the voice cloner exists
|
110 |
+
if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
|
111 |
+
# Check by ID
|
112 |
+
if voice in app_state.voice_cloner.cloned_voices:
|
113 |
+
return app_state.voice_cloner.cloned_voices[voice].speaker_id
|
114 |
+
# Check by name
|
115 |
+
for v_id, v_info in app_state.voice_cloner.cloned_voices.items():
|
116 |
+
if v_info.name.lower() == voice.lower():
|
117 |
+
return v_info.speaker_id
|
118 |
+
# Default to alloy
|
119 |
+
return 0
|
120 |
+
|
121 |
+
@router.post("/audio/speech/stream", tags=["Audio"])
|
122 |
+
async def stream_speech(
|
123 |
+
request: Request,
|
124 |
+
speech_request: SpeechRequest,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
Stream audio of text being spoken by a realistic voice.
|
128 |
+
This endpoint provides an OpenAI-compatible streaming interface for TTS.
|
129 |
+
"""
|
130 |
+
# Check if model is loaded
|
131 |
+
if not hasattr(request.app.state, "generator") or request.app.state.generator is None:
|
132 |
+
raise HTTPException(
|
133 |
+
status_code=503,
|
134 |
+
detail="Model not loaded. Please try again later."
|
135 |
+
)
|
136 |
+
|
137 |
+
# Get request parameters
|
138 |
+
model = speech_request.model
|
139 |
+
input_text = speech_request.input
|
140 |
+
voice = speech_request.voice
|
141 |
+
response_format = speech_request.response_format
|
142 |
+
speed = speech_request.speed
|
143 |
+
temperature = speech_request.temperature
|
144 |
+
max_audio_length_ms = speech_request.max_audio_length_ms
|
145 |
+
|
146 |
+
# Log the request
|
147 |
+
logger.info(f"Real-time streaming speech from text ({len(input_text)} chars) with voice '{voice}'")
|
148 |
+
|
149 |
+
# Check if text is empty
|
150 |
+
if not input_text or len(input_text.strip()) == 0:
|
151 |
+
raise HTTPException(
|
152 |
+
status_code=400,
|
153 |
+
detail="Input text cannot be empty"
|
154 |
+
)
|
155 |
+
|
156 |
+
# Get speaker ID for the voice
|
157 |
+
speaker_id = get_speaker_id(request.app.state, voice)
|
158 |
+
if speaker_id is None:
|
159 |
+
raise HTTPException(
|
160 |
+
status_code=400,
|
161 |
+
detail=f"Voice '{voice}' not found. Available voices: {request.app.state.available_voices}"
|
162 |
+
)
|
163 |
+
|
164 |
+
try:
|
165 |
+
# Create media type based on format
|
166 |
+
media_type = {
|
167 |
+
"mp3": "audio/mpeg",
|
168 |
+
"opus": "audio/opus",
|
169 |
+
"aac": "audio/aac",
|
170 |
+
"flac": "audio/flac",
|
171 |
+
"wav": "audio/wav",
|
172 |
+
}.get(response_format, "audio/mpeg")
|
173 |
+
|
174 |
+
# Create the chunker for streaming
|
175 |
+
sample_rate = request.app.state.sample_rate
|
176 |
+
chunker = AudioChunker(sample_rate, response_format)
|
177 |
+
|
178 |
+
# Split text into segments using the imported function
|
179 |
+
from app.prompt_engineering import split_into_segments
|
180 |
+
text_segments = split_into_segments(input_text, max_chars=50) # Smaller segments for faster first response
|
181 |
+
|
182 |
+
logger.info(f"Split text into {len(text_segments)} segments for incremental streaming")
|
183 |
+
|
184 |
+
async def generate_streaming_audio():
|
185 |
+
# Check for cloned voice
|
186 |
+
voice_info = None
|
187 |
+
from_cloned_voice = False
|
188 |
+
|
189 |
+
if hasattr(request.app.state, "voice_cloning_enabled") and request.app.state.voice_cloning_enabled:
|
190 |
+
voice_info = request.app.state.get_voice_info(voice)
|
191 |
+
from_cloned_voice = voice_info and voice_info["type"] == "cloned"
|
192 |
+
if from_cloned_voice:
|
193 |
+
# Use cloned voice context for first segment
|
194 |
+
voice_cloner = request.app.state.voice_cloner
|
195 |
+
context = voice_cloner.get_voice_context(voice_info["voice_id"])
|
196 |
+
else:
|
197 |
+
# Use standard voice context
|
198 |
+
from app.voice_enhancement import get_voice_segments
|
199 |
+
context = get_voice_segments(voice, request.app.state.device)
|
200 |
+
else:
|
201 |
+
# Use standard voice context
|
202 |
+
from app.voice_enhancement import get_voice_segments
|
203 |
+
context = get_voice_segments(voice, request.app.state.device)
|
204 |
+
|
205 |
+
# Send an empty chunk to initialize the connection
|
206 |
+
yield b''
|
207 |
+
|
208 |
+
# Process each text segment incrementally and stream in real time
|
209 |
+
for i, segment_text in enumerate(text_segments):
|
210 |
+
try:
|
211 |
+
logger.info(f"Generating segment {i+1}/{len(text_segments)}")
|
212 |
+
|
213 |
+
# Generate audio for this segment - use async to avoid blocking
|
214 |
+
if from_cloned_voice:
|
215 |
+
# Generate with cloned voice
|
216 |
+
voice_cloner = request.app.state.voice_cloner
|
217 |
+
|
218 |
+
# Convert to asynchronous with asyncio.to_thread
|
219 |
+
segment_audio = await asyncio.to_thread(
|
220 |
+
voice_cloner.generate_speech,
|
221 |
+
segment_text,
|
222 |
+
voice_info["voice_id"],
|
223 |
+
temperature=temperature,
|
224 |
+
topk=30,
|
225 |
+
max_audio_length_ms=2000 # Keep segments short for streaming
|
226 |
+
)
|
227 |
+
else:
|
228 |
+
# Use standard voice with generator
|
229 |
+
segment_audio = await asyncio.to_thread(
|
230 |
+
request.app.state.generator.generate,
|
231 |
+
segment_text,
|
232 |
+
speaker_id,
|
233 |
+
context,
|
234 |
+
max_audio_length_ms=2000, # Short for quicker generation
|
235 |
+
temperature=temperature
|
236 |
+
)
|
237 |
+
|
238 |
+
# Process audio quality for this segment
|
239 |
+
if hasattr(request.app.state, "voice_enhancement_enabled") and request.app.state.voice_enhancement_enabled:
|
240 |
+
from app.voice_enhancement import process_generated_audio
|
241 |
+
segment_audio = process_generated_audio(
|
242 |
+
audio=segment_audio,
|
243 |
+
voice_name=voice,
|
244 |
+
sample_rate=sample_rate,
|
245 |
+
text=segment_text
|
246 |
+
)
|
247 |
+
|
248 |
+
# Handle speed adjustment
|
249 |
+
if speed != 1.0 and speed > 0:
|
250 |
+
try:
|
251 |
+
# Adjust speed using torchaudio
|
252 |
+
effects = [["tempo", str(speed)]]
|
253 |
+
audio_cpu = segment_audio.cpu()
|
254 |
+
adjusted_audio, _ = torchaudio.sox_effects.apply_effects_tensor(
|
255 |
+
audio_cpu.unsqueeze(0),
|
256 |
+
sample_rate,
|
257 |
+
effects
|
258 |
+
)
|
259 |
+
segment_audio = adjusted_audio.squeeze(0)
|
260 |
+
except Exception as e:
|
261 |
+
logger.warning(f"Failed to adjust speech speed: {e}")
|
262 |
+
|
263 |
+
# Convert this segment to bytes and stream immediately
|
264 |
+
buf = io.BytesIO()
|
265 |
+
audio_to_save = segment_audio.unsqueeze(0) if len(segment_audio.shape) == 1 else segment_audio
|
266 |
+
torchaudio.save(buf, audio_to_save.cpu(), sample_rate, format=response_format)
|
267 |
+
buf.seek(0)
|
268 |
+
segment_bytes = buf.read()
|
269 |
+
|
270 |
+
# Stream this segment immediately
|
271 |
+
yield segment_bytes
|
272 |
+
|
273 |
+
# Update context with this segment for next generation
|
274 |
+
context = [
|
275 |
+
Segment(
|
276 |
+
text=segment_text,
|
277 |
+
speaker=speaker_id,
|
278 |
+
audio=segment_audio
|
279 |
+
)
|
280 |
+
]
|
281 |
+
|
282 |
+
except Exception as e:
|
283 |
+
logger.error(f"Error generating segment {i+1}: {e}")
|
284 |
+
# Try to continue with next segment
|
285 |
+
|
286 |
+
# Return streaming response
|
287 |
+
return StreamingResponse(
|
288 |
+
generate_streaming_audio(),
|
289 |
+
media_type=media_type,
|
290 |
+
headers={
|
291 |
+
"Content-Disposition": f'attachment; filename="speech.{response_format}"',
|
292 |
+
"X-Accel-Buffering": "no", # Prevent buffering in nginx
|
293 |
+
"Cache-Control": "no-cache, no-store, must-revalidate", # Prevent caching
|
294 |
+
"Pragma": "no-cache",
|
295 |
+
"Expires": "0",
|
296 |
+
"Connection": "keep-alive",
|
297 |
+
"Transfer-Encoding": "chunked"
|
298 |
+
}
|
299 |
+
)
|
300 |
+
except Exception as e:
|
301 |
+
logger.error(f"Error in stream_speech: {e}")
|
302 |
+
raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")
|
303 |
+
|
304 |
+
@router.post("/audio/speech/streaming", tags=["Audio"])
|
305 |
+
async def openai_stream_speech(
|
306 |
+
request: Request,
|
307 |
+
speech_request: SpeechRequest,
|
308 |
+
):
|
309 |
+
"""
|
310 |
+
Stream audio in OpenAI-compatible streaming format.
|
311 |
+
This endpoint is compatible with the OpenAI streaming TTS API.
|
312 |
+
"""
|
313 |
+
# Use the same logic as the stream_speech endpoint but with a different name
|
314 |
+
# to maintain the OpenAI API naming convention
|
315 |
+
return await stream_speech(request, speech_request)
|
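A hedged client sketch for the streaming route above, reading the chunked response incrementally; the base URL and /v1 prefix are deployment assumptions:

import requests

payload = {"input": "Streaming hello.", "voice": "alloy", "response_format": "mp3"}
with requests.post(
    "http://localhost:8000/v1/audio/speech/stream", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    with open("stream.mp3", "wb") as out:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:                           # skip empty keep-alive chunks
                out.write(chunk)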
app/api/utils.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
def get_all_voices(app_state):
|
2 |
+
"""
|
3 |
+
Get all available voices including standard and cloned voices.
|
4 |
+
Returns them in a format compatible with OpenAI's API.
|
5 |
+
"""
|
6 |
+
# Standard voices
|
7 |
+
voices = [
|
8 |
+
{"voice_id": "alloy", "name": "Alloy"},
|
9 |
+
{"voice_id": "echo", "name": "Echo"},
|
10 |
+
{"voice_id": "fable", "name": "Fable"},
|
11 |
+
{"voice_id": "onyx", "name": "Onyx"},
|
12 |
+
{"voice_id": "nova", "name": "Nova"},
|
13 |
+
{"voice_id": "shimmer", "name": "Shimmer"}
|
14 |
+
]
|
15 |
+
|
16 |
+
# Add cloned voices if available
|
17 |
+
if hasattr(app_state, "voice_cloner") and app_state.voice_cloner is not None:
|
18 |
+
cloned_voices = app_state.voice_cloner.list_voices()
|
19 |
+
for voice in cloned_voices:
|
20 |
+
voices.append({
|
21 |
+
"voice_id": voice.id,
|
22 |
+
"name": voice.name
|
23 |
+
})
|
24 |
+
|
25 |
+
return voices
|
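For illustration, the helper above falls back to the six standard voices when no voice cloner is attached to app state; a stand-in object replaces real app state here:

from types import SimpleNamespace
from app.api.utils import get_all_voices

voices = get_all_voices(SimpleNamespace())      # no voice_cloner attribute set
print(len(voices))                              # 6
print(voices[0])                                # {'voice_id': 'alloy', 'name': 'Alloy'}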
app/api/voice_cloning_routes.py
ADDED
@@ -0,0 +1,289 @@
1 |
+
"""Voice cloning API routes for CSM-1B TTS API."""
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import time
|
5 |
+
import tempfile
|
6 |
+
from typing import Dict, List, Optional, Any
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from fastapi import APIRouter, Request, Response, HTTPException, UploadFile, File, Form, Body
|
10 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
11 |
+
from app.voice_cloning import ClonedVoice
|
12 |
+
|
13 |
+
# Create router
|
14 |
+
router = APIRouter(prefix="/voice-cloning", tags=["Voice Cloning"])
|
15 |
+
|
16 |
+
@router.post("/clone", summary="Clone a new voice")
|
17 |
+
async def clone_voice(
|
18 |
+
request: Request,
|
19 |
+
audio_file: UploadFile = File(...),
|
20 |
+
name: str = Form(...),
|
21 |
+
transcript: Optional[str] = Form(None),
|
22 |
+
description: Optional[str] = Form(None)
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Clone a new voice from an audio file.
|
26 |
+
|
27 |
+
- **audio_file**: Audio file with the voice to clone (MP3, WAV, etc.)
|
28 |
+
- **name**: Name for the cloned voice
|
29 |
+
- **transcript**: Optional transcript of the audio
|
30 |
+
- **description**: Optional description of the voice
|
31 |
+
"""
|
32 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
33 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
34 |
+
|
35 |
+
voice_cloner = request.app.state.voice_cloner
|
36 |
+
|
37 |
+
try:
|
38 |
+
voice = await voice_cloner.clone_voice(
|
39 |
+
audio_file=audio_file,
|
40 |
+
voice_name=name,
|
41 |
+
transcript=transcript,
|
42 |
+
description=description
|
43 |
+
)
|
44 |
+
|
45 |
+
return {
|
46 |
+
"status": "success",
|
47 |
+
"message": "Voice cloned successfully",
|
48 |
+
"voice": voice
|
49 |
+
}
|
50 |
+
except Exception as e:
|
51 |
+
import traceback
|
52 |
+
error_trace = traceback.format_exc()
|
53 |
+
request.app.state.logger.error(f"Voice cloning failed: {e}\n{error_trace}")
|
54 |
+
raise HTTPException(status_code=500, detail=f"Voice cloning failed: {str(e)}")
|
55 |
+
|
56 |
+
@router.get("/voices", summary="List cloned voices")
|
57 |
+
async def list_voices(request: Request):
|
58 |
+
"""List all cloned voices available in the system."""
|
59 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
60 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
61 |
+
|
62 |
+
voice_cloner = request.app.state.voice_cloner
|
63 |
+
voices = voice_cloner.list_voices()
|
64 |
+
|
65 |
+
return {
|
66 |
+
"voices": voices
|
67 |
+
}
|
68 |
+
|
69 |
+
@router.delete("/voices/{voice_id}", summary="Delete a cloned voice")
|
70 |
+
async def delete_voice(request: Request, voice_id: str):
|
71 |
+
"""Delete a cloned voice by ID."""
|
72 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
73 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
74 |
+
|
75 |
+
voice_cloner = request.app.state.voice_cloner
|
76 |
+
success = voice_cloner.delete_voice(voice_id)
|
77 |
+
|
78 |
+
if not success:
|
79 |
+
raise HTTPException(status_code=404, detail=f"Voice with ID {voice_id} not found")
|
80 |
+
|
81 |
+
return {
|
82 |
+
"status": "success",
|
83 |
+
"message": f"Voice {voice_id} deleted successfully"
|
84 |
+
}
|
85 |
+
|
86 |
+
@router.post("/clone-from-youtube", summary="Clone a voice from YouTube")
|
87 |
+
async def clone_voice_from_youtube(
|
88 |
+
request: Request,
|
89 |
+
data: dict = Body(...) # Use a single body parameter to avoid conflicts
|
90 |
+
):
|
91 |
+
"""
|
92 |
+
Clone a voice from a YouTube video.
|
93 |
+
|
94 |
+
- **youtube_url**: URL of the YouTube video
|
95 |
+
- **voice_name**: Name for the cloned voice
|
96 |
+
- **start_time**: Start time in seconds (default: 0)
|
97 |
+
- **duration**: Duration to extract in seconds (default: 180)
|
98 |
+
- **description**: Optional description of the voice
|
99 |
+
"""
|
100 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
101 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
102 |
+
|
103 |
+
voice_cloner = request.app.state.voice_cloner
|
104 |
+
|
105 |
+
# Extract parameters from the request body
|
106 |
+
youtube_url = data.get("youtube_url")
|
107 |
+
voice_name = data.get("voice_name")
|
108 |
+
start_time = data.get("start_time", 0)
|
109 |
+
duration = data.get("duration", 180)
|
110 |
+
description = data.get("description")
|
111 |
+
|
112 |
+
# Validate required parameters
|
113 |
+
if not youtube_url or not voice_name:
|
114 |
+
raise HTTPException(status_code=400, detail="Missing required parameters: youtube_url and voice_name")
|
115 |
+
|
116 |
+
try:
|
117 |
+
voice = await voice_cloner.clone_voice_from_youtube(
|
118 |
+
youtube_url=youtube_url,
|
119 |
+
voice_name=voice_name,
|
120 |
+
start_time=start_time,
|
121 |
+
duration=duration,
|
122 |
+
description=description
|
123 |
+
)
|
124 |
+
|
125 |
+
return {
|
126 |
+
"status": "success",
|
127 |
+
"message": "Voice cloned successfully from YouTube",
|
128 |
+
"voice": voice
|
129 |
+
}
|
130 |
+
except Exception as e:
|
131 |
+
import traceback
|
132 |
+
error_trace = traceback.format_exc()
|
133 |
+
request.app.state.logger.error(f"YouTube voice cloning failed: {e}\n{error_trace}")
|
134 |
+
raise HTTPException(status_code=500, detail=f"YouTube voice cloning failed: {str(e)}")
|
135 |
+
|
136 |
+
@router.post("/generate", summary="Generate speech with cloned voice")
|
137 |
+
async def generate_speech(
|
138 |
+
request: Request,
|
139 |
+
voice_id: str = Body(..., embed=True),
|
140 |
+
text: str = Body(..., embed=True),
|
141 |
+
temperature: float = Body(0.65, embed=True),
|
142 |
+
response_format: str = Body("mp3", embed=True)
|
143 |
+
):
|
144 |
+
"""
|
145 |
+
Generate speech using a cloned voice.
|
146 |
+
|
147 |
+
- **voice_id**: ID of the cloned voice to use
|
148 |
+
- **text**: Text to synthesize
|
149 |
+
- **temperature**: Sampling temperature (lower = more stable, higher = more varied)
|
150 |
+
- **response_format**: Audio format (mp3, wav, etc.)
|
151 |
+
"""
|
152 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
153 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
154 |
+
|
155 |
+
voice_cloner = request.app.state.voice_cloner
|
156 |
+
|
157 |
+
# Validate voice ID
|
158 |
+
if voice_id not in voice_cloner.cloned_voices:
|
159 |
+
raise HTTPException(status_code=404, detail=f"Voice with ID {voice_id} not found")
|
160 |
+
|
161 |
+
# MIME type mapping
|
162 |
+
mime_types = {
|
163 |
+
"mp3": "audio/mpeg",
|
164 |
+
"wav": "audio/wav",
|
165 |
+
"ogg": "audio/ogg",
|
166 |
+
"flac": "audio/flac",
|
167 |
+
"m4a": "audio/mp4",
|
168 |
+
}
|
169 |
+
|
170 |
+
# Set default if format not specified
|
171 |
+
if response_format not in mime_types:
|
172 |
+
response_format = "mp3"
|
173 |
+
|
174 |
+
try:
|
175 |
+
# Generate speech with the cloned voice; generate_speech is synchronous,
|
176 |
+
# so it must not be awaited here
|
177 |
+
audio = voice_cloner.generate_speech(
|
178 |
+
text=text,
|
179 |
+
voice_id=voice_id,
|
180 |
+
temperature=temperature
|
181 |
+
)
|
182 |
+
|
183 |
+
# Create temporary file for audio conversion
|
184 |
+
with tempfile.NamedTemporaryFile(suffix=f".{response_format}", delete=False) as temp_file:
|
185 |
+
temp_path = temp_file.name
|
186 |
+
|
187 |
+
# Save to WAV first (direct format for torchaudio)
|
188 |
+
wav_path = f"{temp_path}.wav"
|
189 |
+
torchaudio.save(wav_path, audio.unsqueeze(0).cpu(), voice_cloner.sample_rate)
|
190 |
+
|
191 |
+
# Convert to requested format
|
192 |
+
import ffmpeg
|
193 |
+
|
194 |
+
if response_format == "mp3":
|
195 |
+
(
|
196 |
+
ffmpeg.input(wav_path)
|
197 |
+
.output(temp_path, format='mp3', audio_bitrate='128k')
|
198 |
+
.run(quiet=True, overwrite_output=True)
|
199 |
+
)
|
200 |
+
elif response_format == "ogg":
|
201 |
+
(
|
202 |
+
ffmpeg.input(wav_path)
|
203 |
+
.output(temp_path, format='ogg')
|
204 |
+
.run(quiet=True, overwrite_output=True)
|
205 |
+
)
|
206 |
+
elif response_format == "flac":
|
207 |
+
(
|
208 |
+
ffmpeg.input(wav_path)
|
209 |
+
.output(temp_path, format='flac')
|
210 |
+
.run(quiet=True, overwrite_output=True)
|
211 |
+
)
|
212 |
+
elif response_format == "m4a":
|
213 |
+
(
|
214 |
+
ffmpeg.input(wav_path)
|
215 |
+
.output(temp_path, format='mp4')
|
216 |
+
.run(quiet=True, overwrite_output=True)
|
217 |
+
)
|
218 |
+
else: # wav
|
219 |
+
temp_path = wav_path
|
220 |
+
|
221 |
+
# Clean up the temporary WAV file if we created a different format
|
222 |
+
if temp_path != wav_path and os.path.exists(wav_path):
|
223 |
+
os.unlink(wav_path)
|
224 |
+
|
225 |
+
# Return audio file as response
|
226 |
+
def iterfile():
|
227 |
+
with open(temp_path, 'rb') as f:
|
228 |
+
yield from f
|
229 |
+
# Clean up temp file after streaming
|
230 |
+
if os.path.exists(temp_path):
|
231 |
+
os.unlink(temp_path)
|
232 |
+
|
233 |
+
return StreamingResponse(
|
234 |
+
iterfile(),
|
235 |
+
media_type=mime_types.get(response_format, "application/octet-stream"),
|
236 |
+
headers={'Content-Disposition': f'attachment; filename="speech.{response_format}"'}
|
237 |
+
)
|
238 |
+
|
239 |
+
except Exception as e:
|
240 |
+
import traceback
|
241 |
+
error_trace = traceback.format_exc()
|
242 |
+
request.app.state.logger.error(f"Speech generation failed: {e}\n{error_trace}")
|
243 |
+
raise HTTPException(status_code=500, detail=f"Speech generation failed: {str(e)}")
|
244 |
+
|
245 |
+
@router.post("/voices/{voice_id}/preview", summary="Generate a preview of a cloned voice")
|
246 |
+
async def generate_preview(
|
247 |
+
request: Request,
|
248 |
+
voice_id: str,
|
249 |
+
text: Optional[str] = Body("This is a preview of my cloned voice.", embed=True),
|
250 |
+
response_format: str = Body("mp3", embed=True)
|
251 |
+
):
|
252 |
+
"""
|
253 |
+
Generate a preview of a cloned voice with a standard text.
|
254 |
+
|
255 |
+
- **voice_id**: ID of the cloned voice to use
|
256 |
+
- **text**: Optional custom text for the preview
|
257 |
+
- **response_format**: Audio format (mp3, wav, etc.)
|
258 |
+
"""
|
259 |
+
# Use the generate_speech endpoint with a standard text
|
260 |
+
return await generate_speech(
|
261 |
+
request=request,
|
262 |
+
voice_id=voice_id,
|
263 |
+
text=text,
|
264 |
+
temperature=0.7,
|
265 |
+
response_format=response_format
|
266 |
+
)
|
267 |
+
|
268 |
+
@router.get("/openai-compatible-voices", summary="List cloned voices in OpenAI format")
|
269 |
+
async def list_voices_openai_format(request: Request):
|
270 |
+
"""List all cloned voices in OpenAI-compatible format."""
|
271 |
+
if not hasattr(request.app.state, "voice_cloner"):
|
272 |
+
raise HTTPException(status_code=503, detail="Voice cloning service not available")
|
273 |
+
|
274 |
+
voice_cloner = request.app.state.voice_cloner
|
275 |
+
voices = voice_cloner.list_voices()
|
276 |
+
|
277 |
+
# Format voices in OpenAI-compatible format
|
278 |
+
openai_voices = []
|
279 |
+
for voice in voices:
|
280 |
+
openai_voices.append({
|
281 |
+
"voice_id": voice.id,
|
282 |
+
"name": voice.name,
|
283 |
+
"preview_url": f"/v1/voice-cloning/voices/{voice.id}/preview",
|
284 |
+
"description": voice.description or f"Cloned voice: {voice.name}",
|
285 |
+
"languages": [{"language_code": "en", "name": "English"}],
|
286 |
+
"cloned": True
|
287 |
+
})
|
288 |
+
|
289 |
+
return {"voices": openai_voices}
|
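A hedged client sketch for the /voice-cloning/clone route above, as a multipart upload matching the File/Form parameters it declares; the /v1 prefix is an assumption about where the router is mounted:

import requests

with open("voice_sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/voice-cloning/clone",
        files={"audio_file": ("voice_sample.wav", f, "audio/wav")},
        data={"name": "my-voice", "description": "Demo clone"},
    )
resp.raise_for_status()
print(resp.json()["voice"])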
app/audio_processing.py
ADDED
@@ -0,0 +1,230 @@
1 |
+
"""Audio processing utilities for CSM-1B TTS API."""
|
2 |
+
import logging
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from scipy import signal
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
def remove_long_silences(
|
10 |
+
audio: torch.Tensor,
|
11 |
+
sample_rate: int,
|
12 |
+
min_speech_energy: float = 0.015,
|
13 |
+
max_silence_sec: float = 0.4,
|
14 |
+
keep_silence_sec: float = 0.1,
|
15 |
+
) -> torch.Tensor:
|
16 |
+
"""
|
17 |
+
Remove uncomfortably long silences from audio while preserving natural pauses.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
audio: Audio tensor
|
        sample_rate: Sample rate in Hz
        min_speech_energy: Minimum RMS energy to consider as speech
        max_silence_sec: Maximum silence duration to keep in seconds
        keep_silence_sec: Amount of silence to keep at speech boundaries

    Returns:
        Audio with long silences removed
    """
    # Convert to numpy for processing
    audio_np = audio.cpu().numpy()

    # Calculate frame size and hop length
    frame_size = int(0.02 * sample_rate)  # 20ms frames
    hop_length = int(0.01 * sample_rate)  # 10ms hop

    # Compute frame energy
    frames = []
    for i in range(0, len(audio_np) - frame_size + 1, hop_length):
        frames.append(audio_np[i:i+frame_size])

    if len(frames) < 2:  # If audio is too short for analysis
        return audio

    frames = np.array(frames)
    # Root mean square energy
    frame_energy = np.sqrt(np.mean(frames**2, axis=1))

    # Adaptive threshold based on audio content
    # Uses a percentile to adapt to different audio characteristics
    energy_threshold = max(
        min_speech_energy,               # Minimum threshold
        np.percentile(frame_energy, 10)  # Adapt to audio
    )

    # Identify speech frames
    is_speech = frame_energy > energy_threshold

    # Convert frame indices to sample indices considering overlapping frames
    speech_segments = []
    in_speech = False
    speech_start = 0

    for i in range(len(is_speech)):
        if is_speech[i] and not in_speech:
            # Start of speech
            in_speech = True
            # Calculate start sample including keep_silence
            speech_start = max(0, i * hop_length - int(keep_silence_sec * sample_rate))

        elif not is_speech[i] and in_speech:
            # Potential end of speech, look ahead to check if silence continues
            silence_length = 0
            for j in range(i, min(len(is_speech), i + int(max_silence_sec * sample_rate / hop_length))):
                if not is_speech[j]:
                    silence_length += 1
                else:
                    break

            if silence_length * hop_length >= max_silence_sec * sample_rate:
                # End of speech, long enough silence detected
                in_speech = False
                # Calculate end sample including keep_silence
                speech_end = min(len(audio_np), i * hop_length + int(keep_silence_sec * sample_rate))
                speech_segments.append((speech_start, speech_end))

    # Handle the case where audio ends during speech
    if in_speech:
        speech_segments.append((speech_start, len(audio_np)))

    if not speech_segments:
        logger.warning("No speech segments detected, returning original audio")
        return audio

    # Combine speech segments with controlled silence durations
    result = []

    # Add initial silence if the first segment doesn't start at the beginning
    if speech_segments[0][0] > 0:
        # Add a short leading silence (100ms)
        silence_samples = min(int(0.1 * sample_rate), speech_segments[0][0])
        if silence_samples > 0:
            result.append(audio_np[speech_segments[0][0] - silence_samples:speech_segments[0][0]])

    # Process each speech segment
    for i, (start, end) in enumerate(speech_segments):
        # Add this speech segment
        result.append(audio_np[start:end])

        # Add a controlled silence between segments
        if i < len(speech_segments) - 1:
            next_start = speech_segments[i+1][0]
            # Calculate available silence duration
            available_silence = next_start - end

            if available_silence > 0:
                # Use either the actual silence (if shorter than max) or the max allowed
                silence_duration = min(available_silence, int(max_silence_sec * sample_rate))
                # Take the first portion of the silence - usually cleaner
                result.append(audio_np[end:end + silence_duration])

    # Combine all parts
    processed_audio = np.concatenate(result)

    # Log the results
    original_duration = len(audio_np) / sample_rate
    processed_duration = len(processed_audio) / sample_rate
    logger.info(f"Silence removal: {original_duration:.2f}s -> {processed_duration:.2f}s ({processed_duration/original_duration*100:.1f}%)")

    # Return as tensor with original device and dtype
    return torch.tensor(processed_audio, device=audio.device, dtype=audio.dtype)

def create_high_shelf_filter(audio, sample_rate, frequency=4000, gain_db=3.0):
    """
    Create a high shelf filter to boost frequencies above the given frequency.

    Args:
        audio: Audio numpy array
        sample_rate: Sample rate in Hz
        frequency: Shelf frequency in Hz
        gain_db: Gain in dB for frequencies above the shelf

    Returns:
        Filtered audio
    """
    # Convert gain from dB to linear
    gain = 10 ** (gain_db / 20.0)

    # Normalized frequency (0 to 1, where 1 is Nyquist frequency)
    normalized_freq = 2.0 * frequency / sample_rate

    # Design a high-shelf biquad filter
    # This is a standard second-order section (SOS) implementation
    b0 = gain
    b1 = 0
    b2 = 0
    a0 = 1
    a1 = 0
    a2 = 0

    # Simple first-order high-shelf filter
    alpha = np.sin(np.pi * normalized_freq) / 2 * np.sqrt((gain + 1/gain) * (1/0.5 - 1) + 2)
    cos_w0 = np.cos(np.pi * normalized_freq)

    b0 = gain * ((gain + 1) + (gain - 1) * cos_w0 + 2 * np.sqrt(gain) * alpha)
    b1 = -2 * gain * ((gain - 1) + (gain + 1) * cos_w0)
    b2 = gain * ((gain + 1) + (gain - 1) * cos_w0 - 2 * np.sqrt(gain) * alpha)
    a0 = (gain + 1) - (gain - 1) * cos_w0 + 2 * np.sqrt(gain) * alpha
    a1 = 2 * ((gain - 1) - (gain + 1) * cos_w0)
    a2 = (gain + 1) - (gain - 1) * cos_w0 - 2 * np.sqrt(gain) * alpha

    # Normalize coefficients
    b = np.array([b0, b1, b2]) / a0
    a = np.array([1.0, a1/a0, a2/a0])

    # Apply the filter
    return signal.lfilter(b, a, audio)

def enhance_audio_quality(audio: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """
    Enhance audio quality by applying various processing techniques.

    Args:
        audio: Audio tensor
        sample_rate: Sample rate in Hz

    Returns:
        Enhanced audio tensor
    """
    try:
        audio_np = audio.cpu().numpy()

        # Remove DC offset
        audio_np = audio_np - np.mean(audio_np)

        # Apply light compression to improve perceived loudness
        # Compress by reducing peaks and increasing quieter parts slightly
        threshold = 0.5
        ratio = 1.5
        attack = 0.01
        release = 0.1

        # Simple compression algorithm
        gain = np.ones_like(audio_np)
        for i in range(1, len(audio_np)):
            level = abs(audio_np[i])
            if level > threshold:
                gain[i] = threshold + (level - threshold) / ratio
                gain[i] = gain[i] / level if level > 0 else 1.0
            else:
                gain[i] = 1.0

            # Smooth gain changes
            gain[i] = gain[i-1] + (gain[i] - gain[i-1]) * (attack if gain[i] < gain[i-1] else release)

        audio_np = audio_np * gain

        # Apply high-shelf filter to enhance speech clarity
        # Boost frequencies above 4000 Hz by 3 dB
        audio_np = create_high_shelf_filter(audio_np, sample_rate, frequency=4000, gain_db=3.0)

        # Normalize to prevent clipping
        max_val = np.max(np.abs(audio_np))
        if max_val > 0:
            audio_np = audio_np * 0.95 / max_val

        return torch.tensor(audio_np, device=audio.device, dtype=audio.dtype)

    except Exception as e:
        logger.warning(f"Audio quality enhancement failed: {e}")
        return audio
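A quick way to sanity-check the two helpers above is to run them on a synthetic tone. The sketch below is illustrative only: it assumes `app/audio_processing.py` is importable from the project root and that the module-level imports it relies on (`numpy as np`, `scipy.signal as signal`, `torch`, and a `logger`) are defined earlier in that file, outside this hunk.

# Hypothetical smoke test for the helpers above; module layout is assumed, not part of the commit.
import numpy as np
import torch

from app.audio_processing import create_high_shelf_filter, enhance_audio_quality

sample_rate = 24000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
# 1 s test tone: 300 Hz fundamental plus a quiet 6 kHz component
tone = 0.5 * np.sin(2 * np.pi * 300 * t) + 0.05 * np.sin(2 * np.pi * 6000 * t)

# Boost everything above 4 kHz by 3 dB (numpy array in, numpy array out)
brightened = create_high_shelf_filter(tone, sample_rate, frequency=4000, gain_db=3.0)

# Full enhancement chain: DC removal, light compression, high shelf, peak normalization
enhanced = enhance_audio_quality(torch.from_numpy(tone).float(), sample_rate)
print(brightened.shape, enhanced.shape)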
app/custom_transformer.py
ADDED
@@ -0,0 +1,339 @@
"""Custom transformer implementation for fallback."""
import torch
import torch.nn as nn
import math
import logging

# Set up logging
logger = logging.getLogger(__name__)

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Calculate RMS
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * rms * x


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding."""

    def __init__(self, dim, max_seq_len=2048, base=10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Generate frequency tensor
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Generate cos and sin cache
        self._update_cos_sin_cache(max_seq_len)

    def _update_cos_sin_cache(self, max_seq_len):
        """Update the cache of cos and sin values."""
        self.max_seq_len = max_seq_len
        t = torch.arange(max_seq_len, device=self.inv_freq.device).float()

        # Compute cos and sin at each position
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        # Duplicate the frequencies so cos/sin span the full head dimension,
        # matching the rotate_half() pairing used in apply_rotary_pos_emb.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()

        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)

    def forward(self, x, seq_len=None, pos=None):
        # Get appropriate parts of the cache
        if pos is not None:
            # Handle arbitrary positions
            cos = self.cos_cache[pos]
            sin = self.sin_cache[pos]
        else:
            # Handle sequential positions
            seq_len = x.shape[1] if seq_len is None else seq_len
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]

        return cos, sin


def rotate_half(x):
    """Rotate half the dimensions of the input."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply rotary position embedding to q and k."""
    if position_ids is not None:
        # Handle arbitrary positions
        cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
        sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    else:
        # Handle sequential positions
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]
        sin = sin.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, dim]

    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed


class CustomAttention(nn.Module):
    """Multi-head attention with support for KV caching."""

    def __init__(self, dim, num_heads, num_kv_heads=None, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Attention projections
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim, bias=False)

        # Rotary embedding
        self.rope = RotaryEmbedding(self.head_dim)

        # Dropout
        self.dropout = nn.Dropout(dropout)

    def _repeat_kv(self, x):
        """Repeat KV heads to match the number of query heads."""
        if self.num_kv_heads == self.num_heads:
            return x

        b, s, n_kv_head, head_dim = x.shape

        # Repeat the KV heads to match the number of query heads
        repeats = self.num_heads // self.num_kv_heads
        x = x.repeat_interleave(repeats, dim=2)

        return x

    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        batch_size, seq_len, _ = x.shape

        # Project to q, k, v
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)      # [b, nh, s, hd]
        k = self.k_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)   # [b, nkh, s, hd]
        v = self.v_proj(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)   # [b, nkh, s, hd]

        # Apply rotary embeddings
        cos, sin = self.rope.forward(x, seq_len=seq_len, pos=input_pos)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=input_pos)

        # Handle KV cache
        if kv_cache is not None:
            k_cache, v_cache = kv_cache

            if input_pos is not None:
                # Update cache at specific positions
                k_cache.index_copy_(2, input_pos, k)
                v_cache.index_copy_(2, input_pos, v)

                # Use the entire cache
                k, v = k_cache, v_cache

        # Repeat KV if needed
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        # Calculate attention scores
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Apply mask if provided
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, -10000.0)

        # Apply softmax and dropout
        attention_probs = self.dropout(torch.softmax(attention_scores, dim=-1))

        # Get context vector
        context = torch.matmul(attention_probs, v)

        # Reshape and project back
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(context)

        return output


class FeedForward(nn.Module):
    """Feed-forward network with GELU activation."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.w1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.w2(x)
        return x


class TransformerLayer(nn.Module):
    """A single transformer layer."""

    def __init__(
        self,
        dim,
        num_heads,
        num_kv_heads=None,
        ffn_dim=None,
        dropout=0.0,
        norm_eps=1e-5
    ):
        super().__init__()
        self.norm1 = RMSNorm(dim, eps=norm_eps)
        self.attn = CustomAttention(dim, num_heads, num_kv_heads, dropout)
        self.norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn = FeedForward(
            dim,
            ffn_dim or 4 * dim,
            dropout
        )

    def forward(self, x, mask=None, input_pos=None, kv_cache=None):
        # Self-attention with residual
        h = self.norm1(x)
        h = self.attn(h, mask=mask, input_pos=input_pos, kv_cache=kv_cache)
        x = x + h

        # FFN with residual
        h = self.norm2(x)
        h = self.ffn(h)
        x = x + h

        return x


class CustomTransformerDecoder(nn.Module):
    """Custom transformer decoder that mimics Llama architecture."""

    def __init__(
        self,
        vocab_size,
        num_layers,
        num_heads,
        num_kv_heads,
        embed_dim,
        max_seq_len,
        intermediate_dim,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=10000,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim

        # Token embeddings
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)

        # Transformer layers
        self.layers = nn.ModuleList([
            TransformerLayer(
                embed_dim,
                num_heads,
                num_kv_heads,
                intermediate_dim,
                attn_dropout,
                norm_eps
            )
            for _ in range(num_layers)
        ])

        # Final normalization and output projection
        self.norm = RMSNorm(embed_dim, eps=norm_eps)
        self.output = nn.Linear(embed_dim, vocab_size, bias=False)

        # Initialize the KV cache
        self._kv_cache = None
        self._has_cache = False

        logger.info(f"Initialized CustomTransformerDecoder with {num_layers} layers, {num_heads} heads, {embed_dim} dim")

    def setup_caches(self, batch_size, dtype, decoder_max_seq_len=None):
        """Set up KV caches for inference."""
        max_seq_len = decoder_max_seq_len or self.max_seq_len
        device = next(self.parameters()).device

        self._kv_cache = []
        for i, layer in enumerate(self.layers):
            # Create a KV cache for each layer
            k_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype
            )
            v_cache = torch.zeros(
                batch_size,
                layer.attn.num_kv_heads,
                max_seq_len,
                layer.attn.head_dim,
                device=device,
                dtype=dtype
            )
            self._kv_cache.append((k_cache, v_cache))

        self._has_cache = True
        logger.info(f"KV caches set up for {batch_size} batches, {max_seq_len} seq length")

    def caches_are_enabled(self):
        """Check if caches are enabled."""
        return self._has_cache

    def reset_caches(self):
        """Reset the KV cache to zeros."""
        if self._has_cache and self._kv_cache:
            for k_cache, v_cache in self._kv_cache:
                k_cache.zero_()
                v_cache.zero_()

    def forward(self, x, mask=None, input_pos=None):
        batch_size, seq_len = x.shape[:2]

        # Apply embedding if input is token IDs
        if x.dim() == 2:
            x = self.tok_embeddings(x)

        # Apply transformer layers
        for i, layer in enumerate(self.layers):
            layer_cache = self._kv_cache[i] if self._has_cache else None
            x = layer(x, mask=mask, input_pos=input_pos, kv_cache=layer_cache)

        # Apply final norm
        x = self.norm(x)

        # Skip output projection if using Identity
        if isinstance(self.output, nn.Identity):
            return x

        # Apply output projection
        logits = self.output(x)

        return logits
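As a rough smoke test of the fallback decoder, the sketch below builds a deliberately tiny configuration and runs a single forward pass on random token IDs. The hyperparameters are made up for illustration and are not the real CSM-1B sizes; the import path assumes the module is reachable as app.custom_transformer.

# Tiny CPU smoke test for the fallback decoder above (toy configuration, assumed import path).
import torch
from app.custom_transformer import CustomTransformerDecoder

decoder = CustomTransformerDecoder(
    vocab_size=1000,       # toy vocabulary for the test, not the real sizes
    num_layers=2,
    num_heads=4,
    num_kv_heads=4,
    embed_dim=64,
    max_seq_len=128,
    intermediate_dim=256,
)

tokens = torch.randint(0, 1000, (1, 16))   # [batch, seq] of token IDs
logits = decoder(tokens)                   # embeddings -> layers -> final norm -> output head
print(logits.shape)                        # expected: torch.Size([1, 16, 1000])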
app/download_model.py
ADDED
@@ -0,0 +1,28 @@
"""Download CSM-1B model from Hugging Face."""

import os
import argparse
from huggingface_hub import hf_hub_download

def download_model(output_dir="models"):
    """Download CSM-1B model from Hugging Face."""
    print("Downloading CSM-1B model...")
    os.makedirs(output_dir, exist_ok=True)

    # Download model
    model_path = hf_hub_download(
        repo_id="sesame/csm-1b",
        filename="ckpt.pt",
        local_dir=output_dir,
        local_dir_use_symlinks=False
    )

    print(f"Model downloaded to {model_path}")
    return model_path

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Download CSM-1B model")
    parser.add_argument("--output", type=str, default="models", help="Output directory")
    args = parser.parse_args()

    download_model(args.output)
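A possible programmatic use of the downloader above is sketched here. It assumes Hugging Face credentials are already configured (for example via `huggingface-cli login` or the `HF_TOKEN` environment variable, as the Dockerfile suggests), since access to the checkpoint repository typically requires accepting its terms first.

# Hypothetical usage; credentials and network access are assumed.
from app.download_model import download_model

ckpt_path = download_model(output_dir="models")   # downloads ckpt.pt into ./models
print(ckpt_path)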
app/generator.py
ADDED
@@ -0,0 +1,834 @@
1 |
+
# Updated generator.py with proper function order
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Tuple
|
4 |
+
import torch
|
5 |
+
import torchaudio
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
+
from transformers import AutoTokenizer
|
10 |
+
from tokenizers.processors import TemplateProcessing
|
11 |
+
from app.models import Segment
|
12 |
+
from app.text_normalizer import clean_text_for_tts
|
13 |
+
from app.text_normalizer import TextNormalizer
|
14 |
+
|
15 |
+
|
16 |
+
# Set up logging
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
# Import the CSM watermarking code
|
20 |
+
try:
|
21 |
+
from app.watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
|
22 |
+
except ImportError:
|
23 |
+
# Define stubs for watermarking if the module is not available
|
24 |
+
CSM_1B_GH_WATERMARK = "CSM1B"
|
25 |
+
def load_watermarker(device="cpu"):
|
26 |
+
return None
|
27 |
+
def watermark(watermarker, audio, sample_rate, key):
|
28 |
+
return audio, sample_rate
|
29 |
+
|
30 |
+
def load_llama3_tokenizer():
|
31 |
+
"""
|
32 |
+
Load tokenizer for Llama 3.2, using unsloth's open version
|
33 |
+
instead of the gated meta-llama version.
|
34 |
+
"""
|
35 |
+
try:
|
36 |
+
# Use the unsloth version which is not gated
|
37 |
+
tokenizer_name = "unsloth/Llama-3.2-1B"
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
39 |
+
bos = tokenizer.bos_token
|
40 |
+
eos = tokenizer.eos_token
|
41 |
+
tokenizer._tokenizer.post_processor = TemplateProcessing(
|
42 |
+
single=f"{bos}:0 $A:0 {eos}:0",
|
43 |
+
pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
|
44 |
+
special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
|
45 |
+
)
|
46 |
+
logger.info("Successfully loaded tokenizer from unsloth/Llama-3.2-1B")
|
47 |
+
return tokenizer
|
48 |
+
except Exception as e:
|
49 |
+
logger.error(f"Error loading tokenizer from unsloth: {e}")
|
50 |
+
# Fallback to a simpler tokenizer if needed
|
51 |
+
try:
|
52 |
+
from transformers import GPT2Tokenizer
|
53 |
+
logger.warning("Falling back to GPT2Tokenizer")
|
54 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
55 |
+
tokenizer.pad_token = tokenizer.eos_token
|
56 |
+
return tokenizer
|
57 |
+
except Exception as fallback_e:
|
58 |
+
logger.error(f"Fallback tokenizer also failed: {fallback_e}")
|
59 |
+
raise RuntimeError("Could not load any suitable tokenizer")
|
60 |
+
|
61 |
+
class Generator:
|
62 |
+
"""Generator class for CSM-1B model."""
|
63 |
+
def __init__(self, model):
|
64 |
+
"""Initialize generator with model."""
|
65 |
+
self._model = model
|
66 |
+
self._model.setup_caches(1)
|
67 |
+
self._text_tokenizer = load_llama3_tokenizer()
|
68 |
+
device = next(model.parameters()).device
|
69 |
+
# Load Mimi codec for audio tokenization
|
70 |
+
try:
|
71 |
+
logger.info("Loading Mimi audio codec...")
|
72 |
+
from huggingface_hub import hf_hub_download
|
73 |
+
# First try to import from moshi
|
74 |
+
try:
|
75 |
+
from moshi.models import loaders
|
76 |
+
DEFAULT_REPO = loaders.DEFAULT_REPO
|
77 |
+
MIMI_NAME = loaders.MIMI_NAME
|
78 |
+
get_mimi = loaders.get_mimi
|
79 |
+
except ImportError:
|
80 |
+
logger.warning("moshi.models.loaders not found, using fallback")
|
81 |
+
# Fallback values if moshi.models.loaders is not available
|
82 |
+
DEFAULT_REPO = "kyutai/mimi"
|
83 |
+
MIMI_NAME = "mimi-december.pt"
|
84 |
+
# Fallback function to load mimi
|
85 |
+
def get_mimi(checkpoint_path, device):
|
86 |
+
from moshi.models.vqvae_model import MiMiModule
|
87 |
+
checkpoint = torch.load(checkpoint_path, map_location=device)
|
88 |
+
model = MiMiModule.init_from_checkpoint(checkpoint, device=device)
|
89 |
+
return model
|
90 |
+
mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
|
91 |
+
mimi = get_mimi(mimi_weight, device=device)
|
92 |
+
mimi.set_num_codebooks(32)
|
93 |
+
self._audio_tokenizer = mimi
|
94 |
+
self.sample_rate = mimi.sample_rate
|
95 |
+
logger.info(f"Mimi codec loaded successfully with sample rate {self.sample_rate}")
|
96 |
+
except Exception as e:
|
97 |
+
logger.error(f"Error loading Mimi codec: {e}")
|
98 |
+
self._audio_tokenizer = None
|
99 |
+
self.sample_rate = 24000 # Default sample rate
|
100 |
+
logger.warning(f"Using fallback sample rate: {self.sample_rate}")
|
101 |
+
raise RuntimeError(f"Failed to load Mimi codec: {e}")
|
102 |
+
try:
|
103 |
+
self._watermarker = load_watermarker(device=device)
|
104 |
+
logger.info("Watermarker loaded successfully")
|
105 |
+
except Exception as e:
|
106 |
+
logger.warning(f"Error loading watermarker: {e}. Watermarking will be disabled.")
|
107 |
+
self._watermarker = None
|
108 |
+
|
109 |
+
self.device = device
|
110 |
+
# Optimize for CUDA throughput
|
111 |
+
if torch.cuda.is_available():
|
112 |
+
torch.backends.cudnn.benchmark = True
|
113 |
+
torch.cuda.empty_cache()
|
114 |
+
logger.info("CUDA optimizations enabled")
|
115 |
+
|
116 |
+
def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
117 |
+
"""Tokenize a text segment."""
|
118 |
+
frame_tokens = []
|
119 |
+
frame_masks = []
|
120 |
+
# Strip any voice instructions in square brackets to avoid them being read out
|
121 |
+
text = self._clean_text_input(text)
|
122 |
+
text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
|
123 |
+
text_frame = torch.zeros(len(text_tokens), 33).long()
|
124 |
+
text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
|
125 |
+
text_frame[:, -1] = torch.tensor(text_tokens)
|
126 |
+
text_frame_mask[:, -1] = True
|
127 |
+
frame_tokens.append(text_frame.to(self.device))
|
128 |
+
frame_masks.append(text_frame_mask.to(self.device))
|
129 |
+
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
|
130 |
+
|
131 |
+
def _clean_text_input(self, text: str) -> str:
|
132 |
+
"""Clean and normalize text for TTS."""
|
133 |
+
return clean_text_for_tts(text)
|
134 |
+
|
135 |
+
def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
136 |
+
"""Tokenize audio."""
|
137 |
+
if self._audio_tokenizer is None:
|
138 |
+
raise RuntimeError("Audio tokenizer not initialized")
|
139 |
+
frame_tokens = []
|
140 |
+
frame_masks = []
|
141 |
+
# (K, T)
|
142 |
+
audio = audio.to(self.device)
|
143 |
+
audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
|
144 |
+
# add EOS frame
|
145 |
+
eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
|
146 |
+
audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
|
147 |
+
audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
|
148 |
+
audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
|
149 |
+
audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
|
150 |
+
audio_frame_mask[:, :-1] = True
|
151 |
+
frame_tokens.append(audio_frame)
|
152 |
+
frame_masks.append(audio_frame_mask)
|
153 |
+
return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
|
154 |
+
|
155 |
+
def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
|
156 |
+
"""Tokenize a segment of text and audio."""
|
157 |
+
text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
|
158 |
+
audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
|
159 |
+
return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
|
160 |
+
|
161 |
+
def generate_quick(
|
162 |
+
self,
|
163 |
+
text: str,
|
164 |
+
speaker: int,
|
165 |
+
context: List[Segment],
|
166 |
+
max_audio_length_ms: float = 2000, # Short for quick generation
|
167 |
+
temperature: float = 0.7, # Lower for more predictable output
|
168 |
+
topk: int = 20, # Lower for faster beam selection
|
169 |
+
) -> torch.Tensor:
|
170 |
+
"""Generate audio quickly for real-time streaming."""
|
171 |
+
# Similar to generate() but optimized for speed
|
172 |
+
self._model.reset_caches()
|
173 |
+
|
174 |
+
# Convert max_audio_length_ms to frames - limit for faster generation
|
175 |
+
max_audio_frames = min(int(max_audio_length_ms / 80), 128) # Smaller limit
|
176 |
+
|
177 |
+
# Process text
|
178 |
+
cleaned_text = clean_text_for_tts(text)
|
179 |
+
|
180 |
+
# Prepare tokens
|
181 |
+
tokens, tokens_mask = [], []
|
182 |
+
# Add context segments (limited to 1 for speed)
|
183 |
+
if context:
|
184 |
+
segment_tokens, segment_tokens_mask = self._tokenize_segment(context[0])
|
185 |
+
tokens.append(segment_tokens)
|
186 |
+
tokens_mask.append(segment_tokens_mask)
|
187 |
+
# Add text tokens
|
188 |
+
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
|
189 |
+
tokens.append(gen_segment_tokens)
|
190 |
+
tokens_mask.append(gen_segment_tokens_mask)
|
191 |
+
|
192 |
+
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
|
193 |
+
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
|
194 |
+
|
195 |
+
# Generate with larger batch size for fewer iterations
|
196 |
+
curr_tokens = prompt_tokens.unsqueeze(0)
|
197 |
+
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
|
198 |
+
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
|
199 |
+
|
200 |
+
# Use larger batch size
|
201 |
+
batch_size = 64 # Generate more frames at once
|
202 |
+
all_samples = []
|
203 |
+
for start_idx in range(0, max_audio_frames, batch_size):
|
204 |
+
end_idx = min(start_idx + batch_size, max_audio_frames)
|
205 |
+
batch_frames = end_idx - start_idx
|
206 |
+
samples_batch = []
|
207 |
+
for i in range(batch_frames):
|
208 |
+
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
|
209 |
+
samples_batch.append(sample)
|
210 |
+
if torch.all(sample == 0):
|
211 |
+
break
|
212 |
+
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
|
213 |
+
curr_tokens_mask = torch.cat(
|
214 |
+
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
|
215 |
+
).unsqueeze(1)
|
216 |
+
curr_pos = curr_pos[:, -1:] + 1
|
217 |
+
all_samples.extend(samples_batch)
|
218 |
+
if len(samples_batch) < batch_frames:
|
219 |
+
break
|
220 |
+
|
221 |
+
if not all_samples:
|
222 |
+
return torch.zeros(10, device=self.device) # Return short empty audio
|
223 |
+
|
224 |
+
# Decode audio
|
225 |
+
audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
|
226 |
+
return audio
|
227 |
+
|
228 |
+
@torch.inference_mode()
|
229 |
+
def generate(
|
230 |
+
self,
|
231 |
+
text: str,
|
232 |
+
speaker: int,
|
233 |
+
context: List[Segment],
|
234 |
+
max_audio_length_ms: float = 90_000,
|
235 |
+
temperature: float = 0.9,
|
236 |
+
topk: int = 50,
|
237 |
+
) -> torch.Tensor:
|
238 |
+
"""Generate audio from text."""
|
239 |
+
if self._audio_tokenizer is None:
|
240 |
+
raise RuntimeError("Audio tokenizer not initialized")
|
241 |
+
|
242 |
+
# Start timing
|
243 |
+
start_time = torch.cuda.Event(enable_timing=True)
|
244 |
+
end_time = torch.cuda.Event(enable_timing=True)
|
245 |
+
start_time.record()
|
246 |
+
|
247 |
+
self._model.reset_caches()
|
248 |
+
|
249 |
+
# Convert max_audio_length_ms to frames - this controls the maximum generation length
|
250 |
+
max_audio_frames = min(int(max_audio_length_ms / 80), 1024) # Limit to reasonable size
|
251 |
+
max_seq_len = 2048 - max_audio_frames
|
252 |
+
|
253 |
+
# Check if text is long and should be split
|
254 |
+
if len(text) > 200:
|
255 |
+
logger.info(f"Long text detected ({len(text)} chars), processing in segments")
|
256 |
+
sentences = TextNormalizer.split_into_sentences(text)
|
257 |
+
logger.info(f"Split into {len(sentences)} segments")
|
258 |
+
|
259 |
+
# Process sentences individually and concatenate the results
|
260 |
+
all_audio_segments = []
|
261 |
+
|
262 |
+
# Use the first sentence to establish voice
|
263 |
+
first_sentence = sentences[0]
|
264 |
+
cleaned_text = clean_text_for_tts(first_sentence)
|
265 |
+
|
266 |
+
# Generate the first segment
|
267 |
+
tokens, tokens_mask = [], []
|
268 |
+
|
269 |
+
# Add context segments for the first sentence
|
270 |
+
for segment in context:
|
271 |
+
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
|
272 |
+
tokens.append(segment_tokens)
|
273 |
+
tokens_mask.append(segment_tokens_mask)
|
274 |
+
|
275 |
+
# Add first sentence tokens
|
276 |
+
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
|
277 |
+
tokens.append(gen_segment_tokens)
|
278 |
+
tokens_mask.append(gen_segment_tokens_mask)
|
279 |
+
|
280 |
+
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
|
281 |
+
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
|
282 |
+
|
283 |
+
# Check context size and truncate if needed
|
284 |
+
if prompt_tokens.size(0) >= max_seq_len:
|
285 |
+
logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
|
286 |
+
prompt_tokens = prompt_tokens[-max_seq_len+50:]
|
287 |
+
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
|
288 |
+
|
289 |
+
# Generate first sentence audio
|
290 |
+
curr_tokens = prompt_tokens.unsqueeze(0)
|
291 |
+
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
|
292 |
+
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
|
293 |
+
|
294 |
+
# Generate first segment
|
295 |
+
first_segment_samples = []
|
296 |
+
for start_idx in range(0, max_audio_frames, 32):
|
297 |
+
end_idx = min(start_idx + 32, max_audio_frames)
|
298 |
+
batch_frames = end_idx - start_idx
|
299 |
+
samples_batch = []
|
300 |
+
|
301 |
+
for i in range(batch_frames):
|
302 |
+
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
|
303 |
+
samples_batch.append(sample)
|
304 |
+
|
305 |
+
if torch.all(sample == 0):
|
306 |
+
break
|
307 |
+
|
308 |
+
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
|
309 |
+
curr_tokens_mask = torch.cat(
|
310 |
+
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
|
311 |
+
).unsqueeze(1)
|
312 |
+
curr_pos = curr_pos[:, -1:] + 1
|
313 |
+
|
314 |
+
first_segment_samples.extend(samples_batch)
|
315 |
+
|
316 |
+
if len(samples_batch) < batch_frames:
|
317 |
+
break
|
318 |
+
|
319 |
+
if not first_segment_samples:
|
320 |
+
raise RuntimeError("No audio generated for first segment")
|
321 |
+
|
322 |
+
# Decode first segment
|
323 |
+
first_segment_audio = self._audio_tokenizer.decode(
|
324 |
+
torch.stack(first_segment_samples).permute(1, 2, 0)
|
325 |
+
).squeeze(0).squeeze(0)
|
326 |
+
|
327 |
+
all_audio_segments.append(first_segment_audio)
|
328 |
+
|
329 |
+
# Now process remaining sentences using the first as context
|
330 |
+
for i, sentence in enumerate(sentences[1:], 1):
|
331 |
+
logger.info(f"Generating segment {i+1}/{len(sentences)}")
|
332 |
+
cleaned_text = clean_text_for_tts(sentence)
|
333 |
+
|
334 |
+
# Create a context segment from the previous generation
|
335 |
+
prev_segment = Segment(
|
336 |
+
speaker=speaker,
|
337 |
+
text=sentences[i-1],
|
338 |
+
audio=all_audio_segments[-1]
|
339 |
+
)
|
340 |
+
|
341 |
+
# Generate with this segment as context
|
342 |
+
segment_tokens, segment_tokens_mask = [], []
|
343 |
+
segment_tokens.append(self._tokenize_segment(prev_segment)[0])
|
344 |
+
segment_tokens_mask.append(self._tokenize_segment(prev_segment)[1])
|
345 |
+
|
346 |
+
# Add current segment tokens
|
347 |
+
current_tokens, current_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
|
348 |
+
segment_tokens.append(current_tokens)
|
349 |
+
segment_tokens_mask.append(current_tokens_mask)
|
350 |
+
|
351 |
+
segment_prompt_tokens = torch.cat(segment_tokens, dim=0).long().to(self.device)
|
352 |
+
segment_prompt_tokens_mask = torch.cat(segment_tokens_mask, dim=0).bool().to(self.device)
|
353 |
+
|
354 |
+
# Check length and truncate if needed
|
355 |
+
if segment_prompt_tokens.size(0) >= max_seq_len:
|
356 |
+
segment_prompt_tokens = segment_prompt_tokens[-max_seq_len+50:]
|
357 |
+
segment_prompt_tokens_mask = segment_prompt_tokens_mask[-max_seq_len+50:]
|
358 |
+
|
359 |
+
# Generate audio for this segment
|
360 |
+
curr_tokens = segment_prompt_tokens.unsqueeze(0)
|
361 |
+
curr_tokens_mask = segment_prompt_tokens_mask.unsqueeze(0)
|
362 |
+
curr_pos = torch.arange(0, segment_prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
|
363 |
+
|
364 |
+
# Generate segment
|
365 |
+
segment_samples = []
|
366 |
+
for start_idx in range(0, max_audio_frames, 32):
|
367 |
+
end_idx = min(start_idx + 32, max_audio_frames)
|
368 |
+
batch_frames = end_idx - start_idx
|
369 |
+
samples_batch = []
|
370 |
+
|
371 |
+
for i in range(batch_frames):
|
372 |
+
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
|
373 |
+
samples_batch.append(sample)
|
374 |
+
|
375 |
+
if torch.all(sample == 0):
|
376 |
+
break
|
377 |
+
|
378 |
+
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
|
379 |
+
curr_tokens_mask = torch.cat(
|
380 |
+
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
|
381 |
+
).unsqueeze(1)
|
382 |
+
curr_pos = curr_pos[:, -1:] + 1
|
383 |
+
|
384 |
+
segment_samples.extend(samples_batch)
|
385 |
+
|
386 |
+
if len(samples_batch) < batch_frames:
|
387 |
+
break
|
388 |
+
|
389 |
+
if not segment_samples:
|
390 |
+
logger.warning(f"No audio generated for segment {i+1}")
|
391 |
+
continue
|
392 |
+
|
393 |
+
# Decode segment
|
394 |
+
segment_audio = self._audio_tokenizer.decode(
|
395 |
+
torch.stack(segment_samples).permute(1, 2, 0)
|
396 |
+
).squeeze(0).squeeze(0)
|
397 |
+
|
398 |
+
all_audio_segments.append(segment_audio)
|
399 |
+
|
400 |
+
# Combine all segments with small pauses
|
401 |
+
pause_samples = int(0.3 * self.sample_rate) # 300ms pause
|
402 |
+
pause = torch.zeros(pause_samples, device=self.device)
|
403 |
+
|
404 |
+
audio_parts = []
|
405 |
+
for i, segment_audio in enumerate(all_audio_segments):
|
406 |
+
audio_parts.append(segment_audio)
|
407 |
+
if i < len(all_audio_segments) - 1:
|
408 |
+
audio_parts.append(pause)
|
409 |
+
|
410 |
+
audio = torch.cat(audio_parts)
|
411 |
+
logger.info(f"Combined {len(all_audio_segments)} segments into final audio")
|
412 |
+
|
413 |
+
else:
|
414 |
+
# For shorter text, standard processing
|
415 |
+
tokens, tokens_mask = [], []
|
416 |
+
|
417 |
+
# Add context segments
|
418 |
+
for segment in context:
|
419 |
+
segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
|
420 |
+
tokens.append(segment_tokens)
|
421 |
+
tokens_mask.append(segment_tokens_mask)
|
422 |
+
|
423 |
+
# Process text
|
424 |
+
cleaned_text = clean_text_for_tts(text)
|
425 |
+
gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(cleaned_text, speaker)
|
426 |
+
tokens.append(gen_segment_tokens)
|
427 |
+
tokens_mask.append(gen_segment_tokens_mask)
|
428 |
+
|
429 |
+
prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
|
430 |
+
prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
|
431 |
+
|
432 |
+
# Check context size
|
433 |
+
if prompt_tokens.size(0) >= max_seq_len:
|
434 |
+
logger.warning(f"Inputs too long ({prompt_tokens.size(0)} tokens), truncating to {max_seq_len - 50}")
|
435 |
+
prompt_tokens = prompt_tokens[-max_seq_len+50:]
|
436 |
+
prompt_tokens_mask = prompt_tokens_mask[-max_seq_len+50:]
|
437 |
+
|
438 |
+
# Generate audio - optimized batch generation
|
439 |
+
curr_tokens = prompt_tokens.unsqueeze(0)
|
440 |
+
curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
|
441 |
+
curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
|
442 |
+
|
443 |
+
# Using optimized batch generation
|
444 |
+
batch_size = 32 # Generate this many frames at once
|
445 |
+
all_samples = []
|
446 |
+
|
447 |
+
for start_idx in range(0, max_audio_frames, batch_size):
|
448 |
+
end_idx = min(start_idx + batch_size, max_audio_frames)
|
449 |
+
batch_frames = end_idx - start_idx
|
450 |
+
|
451 |
+
samples_batch = []
|
452 |
+
|
453 |
+
for i in range(batch_frames):
|
454 |
+
sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
|
455 |
+
samples_batch.append(sample)
|
456 |
+
|
457 |
+
if torch.all(sample == 0):
|
458 |
+
break
|
459 |
+
|
460 |
+
curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
|
461 |
+
curr_tokens_mask = torch.cat(
|
462 |
+
[torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
|
463 |
+
).unsqueeze(1)
|
464 |
+
curr_pos = curr_pos[:, -1:] + 1
|
465 |
+
|
466 |
+
all_samples.extend(samples_batch)
|
467 |
+
|
468 |
+
if len(samples_batch) < batch_frames:
|
469 |
+
logger.info(f"Early stopping at frame {start_idx + len(samples_batch)}/{max_audio_frames}")
|
470 |
+
break
|
471 |
+
|
472 |
+
if not all_samples:
|
473 |
+
raise RuntimeError("No audio generated - model produced empty output")
|
474 |
+
|
475 |
+
# Decode audio
|
476 |
+
audio = self._audio_tokenizer.decode(torch.stack(all_samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
|
477 |
+
|
478 |
+
# Apply watermark
|
479 |
+
if self._watermarker is not None:
|
480 |
+
try:
|
481 |
+
audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
|
482 |
+
audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
|
483 |
+
except Exception as e:
|
484 |
+
logger.warning(f"Error applying watermark: {e}. Continuing without watermark.")
|
485 |
+
|
486 |
+
# Record execution time
|
487 |
+
end_time.record()
|
488 |
+
torch.cuda.synchronize()
|
489 |
+
execution_ms = start_time.elapsed_time(end_time)
|
490 |
+
audio_length_ms = (audio.shape[0] / self.sample_rate) * 1000
|
491 |
+
|
492 |
+
# Calculate real-time factor (RTF)
|
493 |
+
rtf = execution_ms / audio_length_ms
|
494 |
+
logger.info(f"Audio generated in {execution_ms:.2f}ms, length: {audio_length_ms:.2f}ms, RTF: {rtf:.2f}x")
|
495 |
+
|
496 |
+
return audio
|
497 |
+
|
498 |
+
# Define helper functions for multi-GPU support
|
499 |
+
def _manual_device_map(model, state_dict, strategy="balanced"):
|
500 |
+
"""Apply manual device mapping for multi-GPU setups.
|
501 |
+
|
502 |
+
Args:
|
503 |
+
model: The model to map
|
504 |
+
state_dict: Model state dict
|
505 |
+
strategy: Mapping strategy ('balanced', 'sequential')
|
506 |
+
|
507 |
+
Returns:
|
508 |
+
Model with weights distributed across GPUs
|
509 |
+
"""
|
510 |
+
num_gpus = torch.cuda.device_count()
|
511 |
+
if num_gpus <= 1:
|
512 |
+
# No need for mapping with single GPU
|
513 |
+
model.load_state_dict(state_dict)
|
514 |
+
model = model.to("cuda")
|
515 |
+
return model
|
516 |
+
|
517 |
+
logger.info(f"Applying manual {strategy} device mapping across {num_gpus} GPUs")
|
518 |
+
|
519 |
+
# Get all layer names from state dict
|
520 |
+
layer_names = [name for name in state_dict.keys() if "layers" in name]
|
521 |
+
backbone_layers = [name for name in layer_names if "backbone.layers" in name]
|
522 |
+
decoder_layers = [name for name in layer_names if "decoder.layers" in name]
|
523 |
+
|
524 |
+
# Count number of backbone and decoder layers
|
525 |
+
backbone_layer_indices = set()
|
526 |
+
for name in backbone_layers:
|
527 |
+
parts = name.split('.')
|
528 |
+
if len(parts) > 2:
|
529 |
+
try:
|
530 |
+
backbone_layer_indices.add(int(parts[2]))
|
531 |
+
except ValueError:
|
532 |
+
pass
|
533 |
+
|
534 |
+
decoder_layer_indices = set()
|
535 |
+
for name in decoder_layers:
|
536 |
+
parts = name.split('.')
|
537 |
+
if len(parts) > 2:
|
538 |
+
try:
|
539 |
+
decoder_layer_indices.add(int(parts[2]))
|
540 |
+
except ValueError:
|
541 |
+
pass
|
542 |
+
|
543 |
+
num_backbone_layers = len(backbone_layer_indices)
|
544 |
+
num_decoder_layers = len(decoder_layer_indices)
|
545 |
+
|
546 |
+
# Create device map
|
547 |
+
device_map = {}
|
548 |
+
|
549 |
+
if strategy == "balanced":
|
550 |
+
# Distribute layers evenly across GPUs
|
551 |
+
layers_per_gpu = (num_backbone_layers + num_decoder_layers) // num_gpus
|
552 |
+
remainder = (num_backbone_layers + num_decoder_layers) % num_gpus
|
553 |
+
|
554 |
+
# Assign backbone layers
|
555 |
+
for i in backbone_layer_indices:
|
556 |
+
gpu_idx = min(i // layers_per_gpu, num_gpus - 1)
|
557 |
+
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
|
558 |
+
|
559 |
+
# Assign decoder layers
|
560 |
+
for i in decoder_layer_indices:
|
561 |
+
gpu_idx = min((i + num_backbone_layers) // layers_per_gpu, num_gpus - 1)
|
562 |
+
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
|
563 |
+
|
564 |
+
elif strategy == "sequential":
|
565 |
+
# Fill each GPU sequentially
|
566 |
+
# Backbone layers on first GPU(s)
|
567 |
+
backbone_per_gpu = max(1, num_backbone_layers // ((num_gpus + 1) // 2))
|
568 |
+
for i in backbone_layer_indices:
|
569 |
+
gpu_idx = min(i // backbone_per_gpu, (num_gpus + 1) // 2 - 1)
|
570 |
+
device_map[f"backbone.layers.{i}"] = f"cuda:{gpu_idx}"
|
571 |
+
|
572 |
+
# Decoder layers on remaining GPU(s)
|
573 |
+
decoder_per_gpu = max(1, num_decoder_layers // (num_gpus - (num_gpus + 1) // 2 + 1))
|
574 |
+
for i in decoder_layer_indices:
|
575 |
+
gpu_idx = min(i // decoder_per_gpu + (num_gpus + 1) // 2 - 1, num_gpus - 1)
|
576 |
+
device_map[f"decoder.layers.{i}"] = f"cuda:{gpu_idx}"
|
577 |
+
|
578 |
+
# Assign embeddings and other components
|
579 |
+
device_map["text_embeddings"] = "cuda:0"
|
580 |
+
device_map["audio_embeddings"] = "cuda:0"
|
581 |
+
device_map["projection"] = "cuda:0"
|
582 |
+
device_map["codebook0_head"] = "cuda:0"
|
583 |
+
device_map["audio_head"] = "cuda:0"
|
584 |
+
|
585 |
+
# Load state dict with device mapping
|
586 |
+
model.load_state_dict(state_dict)
|
587 |
+
|
588 |
+
# Move model parts to assigned devices
|
589 |
+
for name, device in device_map.items():
|
590 |
+
if "backbone.layers" in name:
|
591 |
+
layer_idx = int(name.split('.')[-1])
|
592 |
+
if hasattr(model.backbone, 'layers') and layer_idx < len(model.backbone.layers):
|
593 |
+
model.backbone.layers[layer_idx] = model.backbone.layers[layer_idx].to(device)
|
594 |
+
elif "decoder.layers" in name:
|
595 |
+
layer_idx = int(name.split('.')[-1])
|
596 |
+
if hasattr(model.decoder, 'layers') and layer_idx < len(model.decoder.layers):
|
597 |
+
model.decoder.layers[layer_idx] = model.decoder.layers[layer_idx].to(device)
|
598 |
+
elif hasattr(model, name):
|
599 |
+
setattr(model, name, getattr(model, name).to(device))
|
600 |
+
|
601 |
+
logger.info(f"Model distributed across GPUs with {strategy} strategy")
|
602 |
+
return model
|
603 |
+
|
604 |
+
def load_csm_1b(ckpt_path: str = "ckpt.pt", device: str = "cuda", device_map: str = None) -> Generator:
|
605 |
+
"""Load CSM-1B model and create generator with performance optimizations.
|
606 |
+
|
607 |
+
Args:
|
608 |
+
ckpt_path: Path to model checkpoint
|
609 |
+
device: Device to load model on ('cuda', 'cpu', or specific CUDA device)
|
610 |
+
device_map: Optional device mapping strategy ('auto', 'balanced', 'sequential', or None)
|
611 |
+
|
612 |
+
Returns:
|
613 |
+
Generator instance with optimized settings
|
614 |
+
"""
|
615 |
+
try:
|
616 |
+
# Import models module for CSM
|
617 |
+
from app.torchtune_models import Model, ModelArgs
|
618 |
+
|
619 |
+
# Create model
|
620 |
+
model_args = ModelArgs(
|
621 |
+
backbone_flavor="llama-1B",
|
622 |
+
decoder_flavor="llama-100M",
|
623 |
+
text_vocab_size=128256,
|
624 |
+
audio_vocab_size=2051,
|
625 |
+
audio_num_codebooks=32,
|
626 |
+
)
|
627 |
+
|
628 |
+
# Load model
|
629 |
+
logger.info(f"Loading CSM-1B model from {ckpt_path} with device={device}, device_map={device_map}")
|
630 |
+
|
631 |
+
# Check for CUDA availability
|
632 |
+
cuda_available = device == "cuda" and torch.cuda.is_available()
|
633 |
+
|
634 |
+
# Set up torch for optimized inference
|
635 |
+
if cuda_available:
|
636 |
+
# Check if we should enable TF32 (faster but slightly less precise)
|
637 |
+
enable_tf32 = os.environ.get("ENABLE_TF32", "true").lower() == "true"
|
638 |
+
if enable_tf32:
|
639 |
+
logger.info("Enabling TF32 for faster matrix multiplications")
|
640 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
641 |
+
torch.backends.cudnn.allow_tf32 = True
|
642 |
+
|
643 |
+
# Check for available precision modes
|
644 |
+
use_bfloat16 = torch.cuda.is_bf16_supported()
|
645 |
+
use_float16 = not use_bfloat16 and torch.cuda.is_available() # Fallback to float16
|
646 |
+
|
647 |
+
if use_bfloat16:
|
648 |
+
dtype = torch.bfloat16
|
649 |
+
logger.info("Using bfloat16 precision for faster inference")
|
650 |
+
elif use_float16:
|
651 |
+
dtype = torch.float16
|
652 |
+
logger.info("Using float16 precision for faster inference")
|
653 |
+
else:
|
654 |
+
dtype = torch.float32
|
655 |
+
logger.info("Using float32 precision (mixed precision not available)")
|
656 |
+
|
657 |
+
# Enable Flash Attention if available
|
658 |
+
try:
|
659 |
+
import flash_attn
|
660 |
+
if os.environ.get("ENABLE_FLASH_ATTN", "true").lower() == "true":
|
661 |
+
logger.info("Flash Attention detected - enabling for faster attention")
|
662 |
+
os.environ["PYTORCH_FLASH_ATTENTION_ENABLED"] = "1"
|
663 |
+
except ImportError:
|
664 |
+
logger.info("Flash Attention not available (install flash-attn for faster inference)")
|
665 |
+
else:
|
666 |
+
# CPU-only mode
|
667 |
+
dtype = torch.float32
|
668 |
+
logger.info("Using CPU mode with float32 precision")
|
669 |
+
|
670 |
+
# Check for quantization
|
671 |
+
enable_quantization = os.environ.get("ENABLE_QUANTIZATION", "false").lower() == "true"
|
672 |
+
is_quantized = False
|
673 |
+
|
674 |
+
# Check for multi-GPU setup
|
675 |
+
if device_map and torch.cuda.device_count() > 1:
|
676 |
+
logger.info(f"Using device_map={device_map} across {torch.cuda.device_count()} GPUs")
|
677 |
+
|
678 |
+
# Create model with device map
|
679 |
+
model = Model(model_args)
|
680 |
+
|
681 |
+
# Load state dict
|
682 |
+
state_dict = torch.load(ckpt_path, map_location='cpu')
|
683 |
+
|
684 |
+
# Try quantization before device mapping if enabled
|
685 |
+
if enable_quantization and cuda_available:
|
686 |
+
try:
|
687 |
+
from bitsandbytes.nn import Linear8bitLt
|
688 |
+
|
689 |
+
def replace_with_8bit(model):
|
690 |
+
"""Replace linear layers with 8-bit quantized versions"""
|
691 |
+
for name, module in model.named_modules():
|
692 |
+
if isinstance(module, torch.nn.Linear) and module.out_features > 256:
|
693 |
+
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, skipping quantization")

            # Apply device mapping
            if device_map == "auto":
                # Use accelerate for automatic device mapping
                try:
                    from accelerate import init_empty_weights, load_checkpoint_and_dispatch

                    # Initialize empty model
                    with init_empty_weights():
                        empty_model = Model(model_args)

                    # Load and dispatch model across GPUs
                    model = load_checkpoint_and_dispatch(
                        empty_model,
                        ckpt_path,
                        device_map="auto",
                        no_split_module_classes=["TransformerLayer"],
                        # Offload to CPU if the model is very large
                        offload_folder="offload" if os.environ.get("OFFLOAD_TO_CPU", "false").lower() == "true" else None
                    )
                    logger.info("Model loaded with automatic device mapping")
                except ImportError:
                    logger.warning("accelerate package not found, falling back to manual device mapping")
                    model = _manual_device_map(model, state_dict, "balanced")
                except Exception as mapping_error:
                    logger.error(f"Auto device mapping failed: {mapping_error}, falling back to manual")
                    model = _manual_device_map(model, state_dict, "balanced")
            else:
                # Manual device mapping
                model = _manual_device_map(model, state_dict, device_map or "balanced")
        else:
            # Single GPU or CPU setup

            # Try quantization before loading if enabled (GPU only)
            if enable_quantization and cuda_available and not is_quantized:
                try:
                    # First load to CPU for quantization
                    model = Model(model_args).to("cpu")
                    state_dict = torch.load(ckpt_path, map_location="cpu")
                    model.load_state_dict(state_dict)

                    from bitsandbytes.nn import Linear8bitLt

                    def replace_with_8bit(model):
                        """Replace linear layers with 8-bit quantized versions"""
                        for name, module in model.named_modules():
                            if isinstance(module, torch.nn.Linear) and module.out_features > 256:
                                parent_name = name.rsplit('.', 1)[0] if '.' in name else ''
                                parent = model
                                if parent_name:
                                    for attr in parent_name.split('.'):
                                        parent = getattr(parent, attr)
                                child_name = name.rsplit('.', 1)[1] if '.' in name else name
                                setattr(parent, child_name, Linear8bitLt.from_float(module))
                        return model

                    logger.info("Applying 8-bit quantization to linear layers")
                    model = replace_with_8bit(model)
                    model = model.to(device=device)
                    is_quantized = True
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
                except Exception as quant_error:
                    logger.error(f"Quantization failed: {quant_error}, loading without quantization")
                    # Load the standard way
                    model = Model(model_args).to(device=device, dtype=dtype)
                    state_dict = torch.load(ckpt_path, map_location=device)
                    model.load_state_dict(state_dict)
            else:
                # Standard load without quantization
                model = Model(model_args).to(device=device, dtype=dtype)
                state_dict = torch.load(ckpt_path, map_location=device)
                model.load_state_dict(state_dict)

        # Apply torch.compile if available (PyTorch 2.0+)
        compile_mode = os.environ.get("TORCH_COMPILE_MODE", "none")
        if hasattr(torch, 'compile') and compile_mode != "none" and cuda_available:
            try:
                logger.info(f"Using torch.compile with mode '{compile_mode}' for faster inference")
                if compile_mode == "default":
                    model = torch.compile(model)
                else:
                    model = torch.compile(model, mode=compile_mode)
            except Exception as compile_error:
                logger.warning(f"Torch compile failed (requires PyTorch 2.0+): {compile_error}")

        # Try to optimize CUDA graphs for faster inference (advanced)
        use_cuda_graphs = os.environ.get("USE_CUDA_GRAPHS", "false").lower() == "true"
        if use_cuda_graphs and cuda_available and hasattr(torch.cuda, 'CUDAGraph'):
            try:
                logger.info("Setting up CUDA graphs for repeated inference patterns")
                # This requires custom integration inside the model's forward method;
                # here we only flag that CUDA graphs should be used
                model.use_cuda_graphs = True
            except Exception as cuda_graph_error:
                logger.warning(f"CUDA graphs setup failed: {cuda_graph_error}")
                model.use_cuda_graphs = False

        # Set optimal settings for the CUDA context
        if cuda_available:
            # Set benchmark mode for hardware-specific optimizations
            torch.backends.cudnn.benchmark = True
            # Clean up the CUDA cache before creating the generator
            torch.cuda.empty_cache()
            # Ensure all CUDA work is completed to avoid launch delays
            torch.cuda.synchronize()

        # Create generator
        logger.info("Creating generator with optimized settings")
        generator = Generator(model)

        # Log memory usage if on CUDA
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"Model loaded, CUDA memory: {memory_allocated:.2f}GB allocated, {memory_reserved:.2f}GB reserved")

        logger.info(f"Generator created successfully: precision={dtype}, quantized={is_quantized}")
        return generator
    except Exception as e:
        logger.error(f"Failed to load CSM-1B model: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise
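
# A minimal, self-contained sketch of the layer-swap pattern used by replace_with_8bit above:
# walk named_modules(), resolve the parent by its dotted path, and setattr() a drop-in
# replacement. The "replacement" here is a plain nn.Linear copy so the sketch runs without
# bitsandbytes; the real code swaps in bitsandbytes' Linear8bitLt on GPU.
import torch
import torch.nn as nn

def swap_linear_layers(model: nn.Module, min_out_features: int = 256) -> nn.Module:
    # Collect targets first so we do not mutate the tree while iterating it.
    targets = [
        (name, module) for name, module in model.named_modules()
        if isinstance(module, nn.Linear) and module.out_features > min_out_features
    ]
    for name, module in targets:
        parent_name, _, child_name = name.rpartition('.')
        parent = model
        if parent_name:
            for attr in parent_name.split('.'):
                parent = getattr(parent, attr)
        replacement = nn.Linear(module.in_features, module.out_features, bias=module.bias is not None)
        replacement.load_state_dict(module.state_dict())
        setattr(parent, child_name, replacement)
    return model

# Usage (toy model): swap_linear_layers(nn.Sequential(nn.Linear(16, 512), nn.ReLU(), nn.Linear(512, 16)))
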
app/main.py
ADDED
@@ -0,0 +1,647 @@
"""
CSM-1B TTS API main application.
Provides an OpenAI-compatible API for the CSM-1B text-to-speech model.
"""
import os
import time
import tempfile
import logging
from logging.handlers import RotatingFileHandler
import traceback
import asyncio
import glob
import torch
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from app.api.routes import router as api_router

# Setup logging
os.makedirs("logs", exist_ok=True)
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Console handler
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(log_format))

# File handler
file_handler = RotatingFileHandler(
    "logs/csm_tts_api.log",
    maxBytes=10*1024*1024,  # 10MB
    backupCount=5
)
file_handler.setFormatter(logging.Formatter(log_format))

# Configure root logger
logging.basicConfig(
    level=logging.INFO,
    format=log_format,
    handlers=[console_handler, file_handler]
)
logger = logging.getLogger(__name__)
logger.info("Starting CSM-1B TTS API")

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events."""
    # STARTUP EVENT
    logger.info("Starting application initialization")
    app.state.startup_time = time.time()
    app.state.generator = None  # Will be populated later if model loads
    app.state.logger = logger  # Make logger available to routes

    # Create necessary directories - use persistent locations
    os.makedirs("/app/models", exist_ok=True)
    os.makedirs("/app/tokenizers", exist_ok=True)
    os.makedirs("/app/voice_memories", exist_ok=True)
    os.makedirs("/app/voice_references", exist_ok=True)
    os.makedirs("/app/voice_profiles", exist_ok=True)
    os.makedirs("/app/cloned_voices", exist_ok=True)
    os.makedirs("/app/audio_cache", exist_ok=True)
    os.makedirs("/app/static", exist_ok=True)

    # Set tokenizer cache
    try:
        os.environ["TRANSFORMERS_CACHE"] = "/app/tokenizers"
        logger.info(f"Set tokenizer cache to: {os.environ['TRANSFORMERS_CACHE']}")
    except Exception as e:
        logger.error(f"Error setting tokenizer cache: {e}")

    # Install additional dependencies if needed
    try:
        import scipy
        import soundfile
        logger.info("Audio processing dependencies available")
    except ImportError as e:
        logger.warning(f"Audio processing dependency missing: {e}. Some audio enhancements may not work.")
        logger.warning("Consider installing: pip install scipy soundfile")

    # Check CUDA availability
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0) if device_count > 0 else "unknown"
        logger.info(f"CUDA is available: {device_count} device(s). Using {device_name}")
        # Report CUDA memory
        if hasattr(torch.cuda, 'get_device_properties'):
            total_memory = torch.cuda.get_device_properties(0).total_memory
            logger.info(f"Total CUDA memory: {total_memory / (1024**3):.2f} GB")
    else:
        logger.warning("CUDA is not available. Using CPU (this will be slow)")

    # Determine device and device mapping
    device = "cuda" if cuda_available else "cpu"
    device_map = os.environ.get("CSM_DEVICE_MAP", None)  # Options: "auto", "balanced", "sequential"
    if device_map and cuda_available:
        if torch.cuda.device_count() > 1:
            logger.info(f"Using device mapping strategy: {device_map} across {torch.cuda.device_count()} GPUs")
        else:
            logger.info("Device mapping requested but only one GPU available, ignoring device_map")
            device_map = None
    else:
        device_map = None

    logger.info(f"Using device: {device}")
    app.state.device = device
    app.state.device_map = device_map

    # Check if model file exists
    model_path = os.path.join("/app/models", "ckpt.pt")
    if not os.path.exists(model_path):
        # Try to download at runtime if not present
        logger.info("Model not found. Attempting to download...")
        try:
            from huggingface_hub import hf_hub_download, login
            # Check for token in environment
            hf_token = os.environ.get("HF_TOKEN")
            if hf_token:
                logger.info("Logging in to Hugging Face using provided token")
                login(token=hf_token)
            logger.info("Downloading CSM-1B model from Hugging Face...")
            download_start = time.time()
            model_path = hf_hub_download(
                repo_id="sesame/csm-1b",
                filename="ckpt.pt",
                local_dir="/app/models"
            )
            download_time = time.time() - download_start
            logger.info(f"Model downloaded to {model_path} in {download_time:.2f} seconds")
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error downloading model: {str(e)}\n{error_stack}")
            logger.error("Please build the image with HF_TOKEN to download the model")
            logger.error("Starting without model - API will return 503 Service Unavailable")
    else:
        logger.info(f"Found existing model at {model_path}")
        logger.info(f"Model size: {os.path.getsize(model_path) / (1024 * 1024):.2f} MB")

    # Load the model
    try:
        logger.info("Loading CSM-1B model...")
        load_start = time.time()
        from app.generator import load_csm_1b
        app.state.generator = load_csm_1b(model_path, device, device_map)
        load_time = time.time() - load_start
        logger.info(f"Model loaded successfully in {load_time:.2f} seconds")

        # Store sample rate in app state
        app.state.sample_rate = app.state.generator.sample_rate
        logger.info(f"Model sample rate: {app.state.sample_rate} Hz")

        # Initialize voice enhancement system (this will create proper voice profiles)
        logger.info("Initializing voice enhancement system...")
        try:
            from app.voice_enhancement import initialize_voice_profiles, save_voice_profiles
            initialize_voice_profiles()
            app.state.voice_enhancement_enabled = True
            logger.info("Voice profiles initialized successfully")
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing voice profiles: {str(e)}\n{error_stack}")
            logger.warning("Voice enhancement features will be limited")
            app.state.voice_enhancement_enabled = False

        # Initialize voice memory system for consistent generation
        logger.info("Initializing voice memory system...")
        try:
            from app.voice_memory import initialize_voices
            initialize_voices(app.state.sample_rate)
            app.state.voice_memory_enabled = True
            logger.info("Voice memory system initialized")
        except Exception as e:
            logger.warning(f"Error initializing voice memory: {e}")
            app.state.voice_memory_enabled = False

        # Initialize voice cloning system
        try:
            logger.info("Initializing voice cloning system...")
            from app.voice_cloning import VoiceCloner, CLONED_VOICES_DIR
            # Update the cloned voices directory to use the persistent volume
            app.state.cloned_voices_dir = "/app/cloned_voices"  # Store path in app state for access
            os.makedirs(app.state.cloned_voices_dir, exist_ok=True)
            CLONED_VOICES_DIR = app.state.cloned_voices_dir  # Update the module constant

            # Initialize the voice cloner with proper device
            app.state.voice_cloner = VoiceCloner(app.state.generator, device=device)

            # Make sure existing voices are loaded
            app.state.voice_cloner._load_existing_voices()

            # Log the available voices
            cloned_voices = app.state.voice_cloner.list_voices()
            logger.info(f"Voice cloning system initialized with {len(cloned_voices)} existing voices")
            for voice in cloned_voices:
                logger.info(f" - {voice.name} (ID: {voice.id}, Speaker ID: {voice.speaker_id})")

            # Flag for voice cloning availability
            app.state.voice_cloning_enabled = True
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing voice cloning: {e}\n{error_stack}")
            logger.warning("Voice cloning features will not be available")
            app.state.voice_cloning_enabled = False

        # Create prompt templates for consistent generation
        logger.info("Setting up prompt engineering templates...")
        try:
            from app.prompt_engineering import initialize_templates
            app.state.prompt_templates = initialize_templates()
            logger.info("Prompt templates initialized")
        except Exception as e:
            error_stack = traceback.format_exc()
            logger.error(f"Error initializing prompt templates: {e}\n{error_stack}")
            logger.warning("Voice consistency features will be limited")

        # Generate voice reference samples (runs in background to avoid blocking startup)
        async def generate_samples_async():
            try:
                logger.info("Starting voice reference generation (background task)...")
                from app.voice_enhancement import create_voice_segments
                create_voice_segments(app.state)
                logger.info("Voice reference generation completed")
            except Exception as e:
                error_stack = traceback.format_exc()
                logger.error(f"Error in voice reference generation: {str(e)}\n{error_stack}")

        # Start as a background task
        asyncio.create_task(generate_samples_async())

        # Initialize voice cache for all voices (standard + cloned)
        app.state.voice_cache = {}

        # Add standard voices
        standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
        for voice in standard_voices:
            app.state.voice_cache[voice] = []

        # Add cloned voices to cache if they exist
        if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
            for voice in app.state.voice_cloner.list_voices():
                app.state.voice_cache[voice.id] = []
                # Also add by name for more flexible lookup
                app.state.voice_cache[voice.name] = []

        # Create mapping from voice name/id to speaker_id for easy lookup
        app.state.voice_speaker_map = {
            "alloy": 0, "echo": 1, "fable": 2, "onyx": 3, "nova": 4, "shimmer": 5
        }

        # Add cloned voices to the speaker map
        if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
            for voice in app.state.voice_cloner.list_voices():
                app.state.voice_speaker_map[voice.id] = voice.speaker_id
                app.state.voice_speaker_map[voice.name] = voice.speaker_id
                app.state.voice_speaker_map[str(voice.speaker_id)] = voice.speaker_id

        # Compile voice information for API
        app.state.available_voices = standard_voices.copy()
        if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
            for voice in app.state.voice_cloner.list_voices():
                app.state.available_voices.append(voice.id)
                app.state.available_voices.append(voice.name)

        # Store model information for API endpoints
        app.state.model_info = {
            "name": "CSM-1B",
            "device": device,
            "device_map": device_map,
            "sample_rate": app.state.sample_rate,
            "standard_voices": standard_voices,
            "cloned_voices": [v.id for v in app.state.voice_cloner.list_voices()] if app.state.voice_cloning_enabled else [],
            "voice_enhancement_enabled": app.state.voice_enhancement_enabled,
            "voice_memory_enabled": app.state.voice_memory_enabled,
            "voice_cloning_enabled": app.state.voice_cloning_enabled,
            "streaming_enabled": True
        }

        # Create a function to access all voices in a standardized format
        def get_all_available_voices():
            """Helper function to get all available voices for API endpoints"""
            # Standard voices with fixed descriptions
            all_voices = [
                {"voice_id": "alloy", "name": "Alloy", "description": "Balanced and natural"},
                {"voice_id": "echo", "name": "Echo", "description": "Resonant and deeper"},
                {"voice_id": "fable", "name": "Fable", "description": "Bright and higher-pitched"},
                {"voice_id": "onyx", "name": "Onyx", "description": "Deep and authoritative"},
                {"voice_id": "nova", "name": "Nova", "description": "Warm and smooth"},
                {"voice_id": "shimmer", "name": "Shimmer", "description": "Light and airy"}
            ]

            # Add cloned voices if available
            if app.state.voice_cloning_enabled and hasattr(app.state, "voice_cloner"):
                for voice in app.state.voice_cloner.list_voices():
                    all_voices.append({
                        "voice_id": voice.id,
                        "name": voice.name,
                        "description": voice.description or f"Cloned voice: {voice.name}"
                    })

            return all_voices

        app.state.get_all_voices = get_all_available_voices

        # Add helper function to lookup voice info
        def get_voice_info(voice_identifier):
            """Look up voice information based on name, ID, or speaker_id"""
            # Check standard voices
            if voice_identifier in standard_voices:
                return {
                    "type": "standard",
                    "voice_id": voice_identifier,
                    "name": voice_identifier,
                    "speaker_id": standard_voices.index(voice_identifier)
                }

            # Look for cloned voice
            if not app.state.voice_cloning_enabled or not hasattr(app.state, "voice_cloner"):
                return None

            # Check by ID
            if voice_identifier in app.state.voice_cloner.cloned_voices:
                voice = app.state.voice_cloner.cloned_voices[voice_identifier]
                return {
                    "type": "cloned",
                    "voice_id": voice.id,
                    "name": voice.name,
                    "speaker_id": voice.speaker_id
                }

            # Check by name
            for v_id, voice in app.state.voice_cloner.cloned_voices.items():
                if voice.name == voice_identifier:
                    return {
                        "type": "cloned",
                        "voice_id": voice.id,
                        "name": voice.name,
                        "speaker_id": voice.speaker_id
                    }

            # Check by speaker_id (string representation)
            try:
                speaker_id = int(voice_identifier)
                # Check if any cloned voice has this speaker_id
                for v_id, voice in app.state.voice_cloner.cloned_voices.items():
                    if voice.speaker_id == speaker_id:
                        return {
                            "type": "cloned",
                            "voice_id": voice.id,
                            "name": voice.name,
                            "speaker_id": speaker_id
                        }
            except (ValueError, TypeError):
                pass

            # No match found
            return None

        app.state.get_voice_info = get_voice_info

        # Set up audio cache
        app.state.audio_cache_enabled = os.environ.get("ENABLE_AUDIO_CACHE", "true").lower() == "true"
        if app.state.audio_cache_enabled:
            app.state.audio_cache_dir = "/app/audio_cache"
            logger.info(f"Audio cache enabled, cache dir: {app.state.audio_cache_dir}")

        # Log GPU utilization after model loading
        if cuda_available:
            memory_allocated = torch.cuda.memory_allocated() / (1024**3)
            memory_reserved = torch.cuda.memory_reserved() / (1024**3)
            logger.info(f"GPU memory: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")

            if torch.cuda.device_count() > 1 and device_map:
                logger.info("Multi-GPU setup active with the following memory usage:")
                for i in range(torch.cuda.device_count()):
                    memory_allocated = torch.cuda.memory_allocated(i) / (1024**3)
                    memory_reserved = torch.cuda.memory_reserved(i) / (1024**3)
                    logger.info(f"GPU {i}: {memory_allocated:.2f} GB allocated, {memory_reserved:.2f} GB reserved")

        # Set up scheduled tasks
        try:
            # Create a background task for periodic voice profile backup
            async def periodic_voice_profile_backup(interval_hours=6):
                """Periodically save voice profiles to persistent storage."""
                while True:
                    try:
                        # Wait for the specified interval
                        await asyncio.sleep(interval_hours * 3600)

                        # Log the backup
                        timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
                        logger.info(f"Scheduled voice profile backup started at {timestamp}")

                        # Save voice profiles
                        if hasattr(app.state, "voice_enhancement_enabled") and app.state.voice_enhancement_enabled:
                            from app.voice_enhancement import save_voice_profiles
                            save_voice_profiles()
                            logger.info("Voice profiles saved successfully")

                        # Save voice memories
                        if hasattr(app.state, "voice_memory_enabled") and app.state.voice_memory_enabled:
                            from app.voice_memory import VOICE_MEMORIES
                            for voice_name, memory in VOICE_MEMORIES.items():
                                memory.save()
                            logger.info("Voice memories saved successfully")

                    except Exception as e:
                        logger.error(f"Error in periodic voice profile backup: {e}")

            # Start the scheduled task
            asyncio.create_task(periodic_voice_profile_backup(interval_hours=6))
            logger.info("Started scheduled voice profile backup task")

        except Exception as e:
            logger.warning(f"Failed to set up scheduled tasks: {e}")

        logger.info(f"CSM-1B TTS API is ready on {device} with sample rate {app.state.sample_rate}")
        logger.info(f"Standard voices: {standard_voices}")
        cloned_count = len(app.state.voice_cloner.list_voices()) if app.state.voice_cloning_enabled else 0
        logger.info(f"Cloned voices: {cloned_count}")

    except Exception as e:
        error_stack = traceback.format_exc()
        logger.error(f"Error loading model: {str(e)}\n{error_stack}")
        app.state.generator = None

    # Calculate total startup time
    startup_time = time.time() - app.state.startup_time
    logger.info(f"Application startup completed in {startup_time:.2f} seconds")

    yield  # This is where the application runs

    # SHUTDOWN EVENT
    logger.info("Application shutdown initiated")

    # Clean up model resources
    if hasattr(app.state, "generator") and app.state.generator is not None:
        try:
            # Clean up CUDA memory if available
            if torch.cuda.is_available():
                logger.info("Clearing CUDA cache")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception as e:
            logger.error(f"Error during CUDA cleanup: {e}")

    # Save voice profiles if they've been updated
    try:
        if hasattr(app.state, "voice_enhancement_enabled") and app.state.voice_enhancement_enabled:
            from app.voice_enhancement import save_voice_profiles
            logger.info("Saving voice profiles...")
            save_voice_profiles()
            logger.info("Voice profiles saved successfully")
    except Exception as e:
        logger.error(f"Error saving voice profiles: {e}")

    # Save voice memories if they've been updated
    try:
        if hasattr(app.state, "voice_memory_enabled") and app.state.voice_memory_enabled:
            from app.voice_memory import VOICE_MEMORIES
            logger.info("Saving voice memories...")
            for voice_name, memory in VOICE_MEMORIES.items():
                memory.save()
            logger.info("Voice memories saved successfully")
    except Exception as e:
        logger.error(f"Error saving voice memories: {e}")

    # Clean up any temporary files
    try:
        for temp_file in glob.glob(os.path.join(tempfile.gettempdir(), "csm_tts_*")):
            try:
                os.remove(temp_file)
                logger.info(f"Removed temporary file: {temp_file}")
            except:
                pass
    except Exception as e:
        logger.warning(f"Error cleaning up temporary files: {e}")

    logger.info("Application shutdown complete")

# Initialize FastAPI app
app = FastAPI(
    title="CSM-1B TTS API",
    description="OpenAI-compatible TTS API using the CSM-1B model from Sesame",
    version="1.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create static and other required directories
os.makedirs("/app/static", exist_ok=True)
os.makedirs("/app/cloned_voices", exist_ok=True)

# Mount the static files directory
app.mount("/static", StaticFiles(directory="/app/static"), name="static")

# Include routers
app.include_router(api_router, prefix="/api/v1")

# Add OpenAI compatible route
app.include_router(api_router, prefix="/v1")

# Add voice cloning routes
from app.api.voice_cloning_routes import router as voice_cloning_router
app.include_router(voice_cloning_router, prefix="/api/v1")
app.include_router(voice_cloning_router, prefix="/v1")

# Add streaming routes
from app.api.streaming import router as streaming_router
app.include_router(streaming_router, prefix="/api/v1")
app.include_router(streaming_router, prefix="/v1")

# Middleware for request timing
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Middleware to track request processing time."""
    start_time = time.time()
    response = await call_next(request)
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.debug(f"Request to {request.url.path} processed in {process_time:.3f} seconds")
    return response

# Health check endpoint
@app.get("/health", include_in_schema=False)
async def health_check(request: Request):
    """Health check endpoint that returns the status of the API."""
    model_status = "healthy" if hasattr(request.app.state, "generator") and request.app.state.generator is not None else "unhealthy"
    uptime = time.time() - getattr(request.app.state, "startup_time", time.time())

    # Get voice information
    standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
    cloned_voices = []

    if hasattr(request.app.state, "voice_cloner") and request.app.state.voice_cloner is not None:
        cloned_voices = [
            {"id": v.id, "name": v.name, "speaker_id": v.speaker_id}
            for v in request.app.state.voice_cloner.list_voices()
        ]

    # Get CUDA memory stats if available
    cuda_stats = None
    if torch.cuda.is_available():
        cuda_stats = {
            "allocated_gb": torch.cuda.memory_allocated() / (1024**3),
            "reserved_gb": torch.cuda.memory_reserved() / (1024**3)
        }

    return {
        "status": model_status,
        "uptime": f"{uptime:.2f} seconds",
        "device": getattr(request.app.state, "device", "unknown"),
        "model": "CSM-1B",
        "standard_voices": standard_voices,
        "cloned_voices": cloned_voices,
        "cloned_voices_count": len(cloned_voices),
        "sample_rate": getattr(request.app.state, "sample_rate", 0),
        "enhancements": "enabled" if hasattr(request.app.state, "model_info") and
                        request.app.state.model_info.get("voice_enhancement_enabled", False) else "disabled",
        "streaming": "enabled",
        "cuda": cuda_stats,
        "version": "1.0.0"
    }

# Version endpoint
@app.get("/version", include_in_schema=False)
async def version():
    """Version endpoint that returns API version information."""
    return {
        "api_version": "1.0.0",
        "model_version": "CSM-1B",
        "compatible_with": "OpenAI TTS v1",
        "enhancements": "voice consistency and audio quality v1.0",
        "voice_cloning": "enabled" if hasattr(app.state, "voice_cloner") else "disabled",
        "streaming": "enabled"
    }

# Voice cloning UI endpoint
@app.get("/voice-cloning", include_in_schema=False)
async def voice_cloning_ui():
    """Voice cloning UI endpoint."""
    return FileResponse("/app/static/voice-cloning.html")

# Streaming demo endpoint
@app.get("/streaming-demo", include_in_schema=False)
async def streaming_demo():
    """Streaming TTS demo endpoint."""
    return FileResponse("/app/static/streaming-demo.html")

@app.get("/", include_in_schema=False)
async def root():
    """Root endpoint that redirects to docs."""
    logger.debug("Root endpoint accessed, redirecting to docs")
    return RedirectResponse(url="/docs")

if __name__ == "__main__":
    # Get port from environment or use default
    port = int(os.environ.get("PORT", 8000))

    # Development mode flag
    dev_mode = os.environ.get("DEV_MODE", "false").lower() == "true"

    # Log level (default to INFO, but can be overridden)
    log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
    logging.getLogger().setLevel(log_level)

    # Check for audio enhancement and voice cloning flags
    enable_enhancements = os.environ.get("ENABLE_ENHANCEMENTS", "true").lower() == "true"
    enable_voice_cloning = os.environ.get("ENABLE_VOICE_CLONING", "true").lower() == "true"

    if not enable_enhancements:
        logger.warning("Voice enhancements disabled by environment variable")
    if not enable_voice_cloning:
        logger.warning("Voice cloning disabled by environment variable")

    logger.info(f"Voice enhancements: {'enabled' if enable_enhancements else 'disabled'}")
    logger.info(f"Voice cloning: {'enabled' if enable_voice_cloning else 'disabled'}")
    logger.info("Streaming: enabled")
    logger.info(f"Log level: {log_level}")

    if dev_mode:
        logger.info(f"Running in development mode with auto-reload enabled on port {port}")
        uvicorn.run(
            "app.main:app",
            host="0.0.0.0",
            port=port,
            reload=True,
            log_level=log_level.lower()
        )
    else:
        logger.info(f"Running in production mode on port {port}")
        uvicorn.run(
            "app.main:app",
            host="0.0.0.0",
            port=port,
            reload=False,
            log_level=log_level.lower()
        )
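
# A quick client-side smoke test for the /health and /version endpoints defined above.
# This is a hypothetical sketch: it assumes the API is running locally on the default PORT=8000.
import requests

BASE_URL = "http://localhost:8000"

health = requests.get(f"{BASE_URL}/health", timeout=10).json()
print(health["status"], health["device"], health["cloned_voices_count"])

version_info = requests.get(f"{BASE_URL}/version", timeout=10).json()
print(version_info["api_version"], version_info["compatible_with"])
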
app/models.py
ADDED
@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass
class Segment:
    """A segment of speech with text, speaker, and audio."""
    speaker: int
    text: str
    # (num_samples,), sample_rate = 24_000
    audio: torch.Tensor
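
# Example of constructing a Segment as defined above; the one-second silent clip is only
# illustrative of the expected shape (a 1-D waveform at the model's 24 kHz sample rate).
import torch
from app.models import Segment

silence = torch.zeros(24_000)  # 1 second of audio at 24 kHz
seg = Segment(speaker=0, text="Hello there.", audio=silence)
print(seg.speaker, seg.text, seg.audio.shape)
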
app/prompt_engineering.py
ADDED
@@ -0,0 +1,129 @@
"""Prompt engineering for consistent voice generation."""
import re
import random
from typing import List, Dict, Optional
import logging

# Set up logging
logger = logging.getLogger(__name__)

# Voice style descriptors for consistent prompting
VOICE_STYLES = {
    "alloy": {
        "adjectives": ["balanced", "natural", "clear", "articulate", "neutral", "conversational"],
        "characteristics": ["medium pitch", "even pacing", "neutral tone", "balanced resonance"],
        "speaking_style": "conversational and balanced"
    },
    "echo": {
        "adjectives": ["resonant", "deep", "reverberant", "rich", "sonorous", "full"],
        "characteristics": ["lower pitch", "deliberate pacing", "resonant tone", "deeper timbre"],
        "speaking_style": "rich and resonant"
    },
    "fable": {
        "adjectives": ["bright", "light", "clear", "energetic", "articulate", "animated"],
        "characteristics": ["higher pitch", "lively pacing", "bright tone", "clear articulation"],
        "speaking_style": "bright and energetic"
    },
    "onyx": {
        "adjectives": ["deep", "authoritative", "powerful", "commanding", "strong", "resolute"],
        "characteristics": ["low pitch", "measured pacing", "authoritative tone", "strong projection"],
        "speaking_style": "deep and authoritative"
    },
    "nova": {
        "adjectives": ["warm", "pleasant", "smooth", "harmonious", "gentle", "comforting"],
        "characteristics": ["medium pitch", "smooth pacing", "warm tone", "pleasant timbre"],
        "speaking_style": "warm and smooth"
    },
    "shimmer": {
        "adjectives": ["light", "airy", "bright", "crystalline", "delicate", "expressive"],
        "characteristics": ["higher pitch", "quick pacing", "light tone", "bright timbre"],
        "speaking_style": "light and expressive"
    },
    "custom": {
        "adjectives": ["clear", "distinct", "authentic", "natural", "personalized", "unique"],
        "characteristics": ["natural rhythm", "authentic tone", "personal inflection", "distinctive sound"],
        "speaking_style": "authentic and natural"
    }
}

def initialize_templates():
    """Initialize prompt templates - placeholder for any future setup."""
    logger.info("Prompt templates initialized")
    return VOICE_STYLES

def split_into_segments(text: str, max_chars: int = 150) -> List[str]:
    """Split text into optimal segments for better generation.

    Args:
        text: Text to split
        max_chars: Maximum characters per segment

    Returns:
        List of text segments
    """
    # Handle empty or very short text
    if not text or len(text) <= max_chars:
        return [text]

    # Split by sentences first
    sentences = re.split(r'(?<=[.!?])\s+', text)

    # Initialize segments
    segments = []
    current_segment = ""

    for sentence in sentences:
        # If adding this sentence would exceed max_chars
        if len(current_segment) + len(sentence) > max_chars:
            # If current segment is not empty, add it to segments
            if current_segment:
                segments.append(current_segment.strip())
                current_segment = ""

            # If this sentence alone exceeds max_chars, split it by phrases
            if len(sentence) > max_chars:
                phrases = re.split(r'(?<=[,;:])\s+', sentence)
                for phrase in phrases:
                    if len(phrase) > max_chars:
                        # Split long phrases into chunks
                        words = phrase.split()
                        chunk = ""
                        for word in words:
                            if len(chunk) + len(word) + 1 <= max_chars:
                                chunk += " " + word if chunk else word
                            else:
                                segments.append(chunk.strip())
                                chunk = word
                        if chunk:
                            segments.append(chunk.strip())
                    else:
                        if len(current_segment) + len(phrase) <= max_chars:
                            current_segment += " " + phrase if current_segment else phrase
                        else:
                            segments.append(current_segment.strip())
                            current_segment = phrase
            else:
                current_segment = sentence
        else:
            current_segment += " " + sentence if current_segment else sentence

    # Add the last segment
    if current_segment:
        segments.append(current_segment.strip())

    logger.info(f"Split text into {len(segments)} segments")
    return segments

def format_text_for_voice(text: str, voice_name: str, segment_index: int = 0, total_segments: int = 1) -> str:
    """Format text with voice characteristics for more consistent generation.

    Args:
        text: Text to format
        voice_name: Name of the voice
        segment_index: Index of this segment (for multi-segment texts)
        total_segments: Total number of segments

    Returns:
        Formatted text optimized for consistent voice generation
    """
    # IMPORTANT: We no longer add voice instructions in brackets, since CSM reads them aloud.
    # Instead, speaker IDs control voice identity, which is what the model expects.

    # Just return the unmodified text - the Generator class will handle proper formatting
    return text
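
# Example use of split_into_segments from the module above: long text is broken at
# sentence and phrase boundaries so each chunk stays under max_chars.
from app.prompt_engineering import split_into_segments, VOICE_STYLES

sample = (
    "CSM-1B generates speech segment by segment. Shorter segments tend to be more stable, "
    "so long passages are split at sentence and phrase boundaries before synthesis."
)
for i, chunk in enumerate(split_into_segments(sample, max_chars=80)):
    print(i, len(chunk), chunk)

print(VOICE_STYLES["alloy"]["speaking_style"])  # "conversational and balanced"
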
app/text_normalizer.py
ADDED
@@ -0,0 +1,194 @@
"""
Text normalization and cleaning utilities for CSM-1B TTS system.
Handles common issues like contractions, numbers, and special characters.
"""
import re
import logging

logger = logging.getLogger(__name__)

class TextNormalizer:
    """Text normalization utilities for TTS."""

    # Common English contractions mapping
    CONTRACTIONS = {
        "don't": "dont",
        "won't": "wont",
        "can't": "cant",
        "isn't": "isnt",
        "he's": "hes",
        "she's": "shes",
        "they're": "theyre",
        "we're": "were",
        "you're": "youre",
        "that's": "thats",
        "it's": "its",
        "what's": "whats",
        "let's": "lets",
        "who's": "whos",
        "how's": "hows",
        "where's": "wheres",
        "there's": "theres",
        "wouldn't": "wouldnt",
        "shouldn't": "shouldnt",
        "couldn't": "couldnt",
        "hasn't": "hasnt",
        "haven't": "havent",
        "hadn't": "hadnt",
        "didn't": "didnt",
        "i'm": "im",
        "i've": "ive",
        "i'd": "id",
        "i'll": "ill",
        "you've": "youve",
        "you'll": "youll",
        "you'd": "youd",
        "we've": "weve",
        "we'll": "well",
        "we'd": "wed",
        "they've": "theyve",
        "they'll": "theyll",
        "they'd": "theyd",
        "aren't": "arent",
        "weren't": "werent",
        "wasn't": "wasnt",
    }

    # Common abbreviations to expand
    ABBREVIATIONS = {
        "Mr.": "Mister",
        "Mrs.": "Misses",
        "Dr.": "Doctor",
        "Prof.": "Professor",
        "St.": "Street",
        "Rd.": "Road",
        "Ave.": "Avenue",
        "vs.": "versus",
        "etc.": "etcetera",
        "e.g.": "for example",
        "i.e.": "that is",
        "approx.": "approximately",
    }

    # Simple number words for common numbers
    NUMBER_WORDS = {
        "0": "zero",
        "1": "one",
        "2": "two",
        "3": "three",
        "4": "four",
        "5": "five",
        "6": "six",
        "7": "seven",
        "8": "eight",
        "9": "nine",
        "10": "ten",
        "11": "eleven",
        "12": "twelve",
        "13": "thirteen",
        "14": "fourteen",
        "15": "fifteen",
        "16": "sixteen",
        "17": "seventeen",
        "18": "eighteen",
        "19": "nineteen",
        "20": "twenty",
        "30": "thirty",
        "40": "forty",
        "50": "fifty",
        "60": "sixty",
        "70": "seventy",
        "80": "eighty",
        "90": "ninety",
        "100": "one hundred",
        "1000": "one thousand",
        "1000000": "one million",
        "1000000000": "one billion",
    }

    @classmethod
    def normalize_text(cls, text: str) -> str:
        """
        Normalize text for TTS: handle contractions, punctuation, and special cases.

        Args:
            text: Input text to normalize

        Returns:
            Normalized text ready for TTS
        """
        if not text:
            return text

        # Log original text for debugging
        logger.debug(f"Normalizing text: '{text}'")

        # Remove voice instructions in square brackets
        text = re.sub(r'\[.*?\]', '', text)

        # Handle contractions - preserving case sensitivity
        for contraction, replacement in cls.CONTRACTIONS.items():
            # Case insensitive replacement
            text = re.sub(r'\b' + re.escape(contraction) + r'\b', replacement, text, flags=re.IGNORECASE)

        # Expand common abbreviations
        for abbr, expanded in cls.ABBREVIATIONS.items():
            text = text.replace(abbr, expanded)

        # Handle numbers - only convert standalone numbers
        def replace_number(match):
            number = match.group(0)
            if number in cls.NUMBER_WORDS:
                return cls.NUMBER_WORDS[number]
            return number

        text = re.sub(r'\b\d+\b', replace_number, text)

        # Replace problematic symbols
        text = text.replace("&", " and ")
        text = text.replace("%", " percent ")
        text = text.replace("@", " at ")
        text = text.replace("#", " number ")
        text = text.replace("$", " dollar ")
        text = text.replace("€", " euro ")
        text = text.replace("£", " pound ")
        text = text.replace("¥", " yen ")

        # Handle dates in MM/DD/YYYY format
        text = re.sub(r'\b(\d{1,2})/(\d{1,2})/(\d{4})\b', r'\1 \2 \3', text)

        # Fix excessive spaces
        text = re.sub(r'\s+', ' ', text).strip()

        # Ensure sentence ends with punctuation (guard against text that became empty
        # after bracket removal, which would otherwise raise an IndexError)
        if text and text[-1] not in ['.', '!', '?', ';', ':', ',']:
            text = text + '.'

        logger.debug(f"Normalized text: '{text}'")
        return text

    @classmethod
    def split_into_sentences(cls, text: str) -> list:
        """
        Split text into sentences for better TTS performance.

        Args:
            text: Input text to split

        Returns:
            List of sentences
        """
        # Normalize first
        text = cls.normalize_text(text)

        # Split on sentence boundaries
        sentences = re.split(r'(?<=[.!?])\s+', text)

        # Remove empty sentences
        sentences = [s for s in sentences if s.strip()]

        return sentences

def clean_text_for_tts(text: str) -> str:
    """Clean and normalize text for TTS processing."""
    return TextNormalizer.normalize_text(text)
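
# Example of the normalizer above: contractions are flattened, abbreviations, symbols and
# small numbers are spelled out, and terminal punctuation is enforced.
from app.text_normalizer import TextNormalizer, clean_text_for_tts

print(clean_text_for_tts("Dr. Smith can't make it at 5 & that's fine"))
# -> "Doctor Smith cant make it at five and thats fine."

print(TextNormalizer.split_into_sentences("First sentence. Second one! A third?"))
# -> ['First sentence.', 'Second one!', 'A third?']
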
app/torchtune_models.py
ADDED
@@ -0,0 +1,282 @@
1 |
+
"""Torchtune models for CSM-1B."""
|
2 |
+
import logging
|
3 |
+
from dataclasses import dataclass
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
# Set up logging
|
8 |
+
logger = logging.getLogger(__name__)
|
9 |
+
|
10 |
+
# First, try to import llama3_2 from torchtune directly
|
11 |
+
try:
|
12 |
+
import torchtune
|
13 |
+
logger.info(f"Torchtune version: {getattr(torchtune, '__version__', 'unknown')}")
|
14 |
+
|
15 |
+
# Print available modules in torchtune.models
|
16 |
+
try:
|
17 |
+
import torchtune.models
|
18 |
+
logger.info(f"Available modules in torchtune.models: {dir(torchtune.models)}")
|
19 |
+
except Exception as e:
|
20 |
+
logger.error(f"Error inspecting torchtune.models: {e}")
|
21 |
+
|
22 |
+
# Try to import llama3_2 model
|
23 |
+
try:
|
24 |
+
from torchtune.models.llama3_2 import llama3_2
|
25 |
+
logger.info("Successfully imported llama3_2 from torchtune")
|
26 |
+
except ImportError as e:
|
27 |
+
logger.warning(f"Could not import llama3_2: {e}")
|
28 |
+
# Try to import regular llama as fallback
|
29 |
+
try:
|
30 |
+
from torchtune.models.llama import llama
|
31 |
+
logger.info("Using llama from torchtune.models.llama as fallback")
|
32 |
+
llama3_2 = llama # Alias llama as llama3_2
|
33 |
+
except ImportError:
|
34 |
+
logger.error("Could not import llama model either. Will use custom implementation.")
|
35 |
+
llama3_2 = None
|
36 |
+
except ImportError as e:
|
37 |
+
logger.error(f"Torchtune not available: {e}")
|
38 |
+
torchtune = None
|
39 |
+
llama3_2 = None
|
40 |
+
|
41 |
+
|
42 |
+
# Define our own model implementations as fallbacks
|
43 |
+
def llama3_2_1B_custom():
|
44 |
+
"""Create a Llama 3.2 1B model."""
|
45 |
+
from app.custom_transformer import CustomTransformerDecoder
|
46 |
+
return CustomTransformerDecoder(
|
47 |
+
vocab_size=128_256,
|
48 |
+
num_layers=16,
|
49 |
+
num_heads=32,
|
50 |
+
num_kv_heads=8,
|
51 |
+
embed_dim=2048,
|
52 |
+
max_seq_len=2048,
|
53 |
+
intermediate_dim=8192,
|
54 |
+
attn_dropout=0.0,
|
55 |
+
norm_eps=1e-5,
|
56 |
+
)
|
57 |
+
|
58 |
+
|
59 |
+
def llama3_2_100M_custom():
|
60 |
+
"""Create a Llama 3.2 100M model."""
|
61 |
+
from app.custom_transformer import CustomTransformerDecoder
|
62 |
+
return CustomTransformerDecoder(
|
63 |
+
vocab_size=128_256,
|
64 |
+
num_layers=4,
|
65 |
+
num_heads=8,
|
66 |
+
num_kv_heads=2,
|
67 |
+
embed_dim=1024,
|
68 |
+
max_seq_len=2048,
|
69 |
+
intermediate_dim=8192,
|
70 |
+
attn_dropout=0.0,
|
71 |
+
norm_eps=1e-5,
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
+
# Setup fallback to our own implementations if needed
|
76 |
+
if llama3_2 is None:
|
77 |
+
logger.warning("Using custom implementations for Llama models")
|
78 |
+
FLAVORS = {
|
79 |
+
"llama-1B": llama3_2_1B_custom,
|
80 |
+
"llama-100M": llama3_2_100M_custom,
|
81 |
+
}
|
82 |
+
else:
|
83 |
+
logger.info("Using torchtune implementations for Llama models")
|
84 |
+
FLAVORS = {
|
85 |
+
"llama-1B": lambda: llama3_2(
|
86 |
+
vocab_size=128_256,
|
87 |
+
num_layers=16,
|
88 |
+
num_heads=32,
|
89 |
+
num_kv_heads=8,
|
90 |
+
embed_dim=2048,
|
91 |
+
max_seq_len=2048,
|
92 |
+
intermediate_dim=8192,
|
93 |
+
attn_dropout=0.0,
|
94 |
+
norm_eps=1e-5,
|
95 |
+
rope_base=500_000,
|
96 |
+
scale_factor=32,
|
97 |
+
),
|
98 |
+
"llama-100M": lambda: llama3_2(
|
99 |
+
vocab_size=128_256,
|
100 |
+
num_layers=4,
|
101 |
+
num_heads=8,
|
102 |
+
num_kv_heads=2,
|
103 |
+
embed_dim=1024,
|
104 |
+
max_seq_len=2048,
|
105 |
+
intermediate_dim=8192,
|
106 |
+
attn_dropout=0.0,
|
107 |
+
norm_eps=1e-5,
|
108 |
+
rope_base=500_000,
|
109 |
+
scale_factor=32,
|
110 |
+
),
|
111 |
+
}
|
112 |
+
|
113 |
+
|
114 |
+
def _prepare_transformer(model):
|
115 |
+
"""Prepare transformer for use."""
|
116 |
+
embed_dim = model.tok_embeddings.embedding_dim
|
117 |
+
model.tok_embeddings = nn.Identity()
|
118 |
+
model.output = nn.Identity()
|
119 |
+
return model, embed_dim
|
120 |
+
|
121 |
+
|
122 |
+
def _create_causal_mask(seq_len: int, device: torch.device):
|
123 |
+
"""Create causal mask."""
|
124 |
+
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """Index causal mask.

    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    r = mask[input_pos, :]
    return r


def _multinomial_sample_one_no_sync(probs):
    """Do multinomial sampling without a cuda synchronization."""
    q = torch.empty_like(probs).exponential_(1)
    return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)


def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
    """Sample from top-k logits."""
    logits = logits / temperature
    filter_value: float = -float("Inf")
    indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
    scores_processed = logits.masked_fill(indices_to_remove, filter_value)
    scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
    probs = torch.nn.functional.softmax(scores_processed, dim=-1)
    sample_token = _multinomial_sample_one_no_sync(probs)
    return sample_token


@dataclass
class ModelArgs:
    """Model arguments."""
    backbone_flavor: str
    decoder_flavor: str
    text_vocab_size: int
    audio_vocab_size: int
    audio_num_codebooks: int


class Model(nn.Module):
    """CSM-1B model."""

    def __init__(self, args: ModelArgs):
        """Initialize model."""
        super().__init__()
        self.args = args
        logger.info(f"Creating model with backbone: {args.backbone_flavor}, decoder: {args.decoder_flavor}")

        # Load backbone and decoder
        self.backbone, backbone_dim = _prepare_transformer(FLAVORS[args.backbone_flavor]())
        self.decoder, decoder_dim = _prepare_transformer(FLAVORS[args.decoder_flavor]())

        # Embeddings
        self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
        self.audio_embeddings = nn.Embedding(args.audio_vocab_size * args.audio_num_codebooks, backbone_dim)

        # Projection and heads
        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
        self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
        self.audio_head = nn.Parameter(torch.empty(args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size))

        # Initialize audio head
        nn.init.normal_(self.audio_head, mean=0.0, std=0.02)

    def setup_caches(self, max_batch_size: int) -> torch.Tensor:
        """Setup KV caches and return a causal mask."""
        dtype = next(self.parameters()).dtype
        device = next(self.parameters()).device

        with device:
            self.backbone.setup_caches(max_batch_size, dtype)
            self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.args.audio_num_codebooks)

        self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
        self.register_buffer("decoder_causal_mask", _create_causal_mask(self.args.audio_num_codebooks, device))

    def generate_frame(
        self,
        tokens: torch.Tensor,
        tokens_mask: torch.Tensor,
        input_pos: torch.Tensor,
        temperature: float,
        topk: int,
    ) -> torch.Tensor:
        """Generate a frame of audio tokens.

        Args:
            tokens: (batch_size, seq_len, audio_num_codebooks+1)
            tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
            input_pos: (batch_size, seq_len) positions for each token

        Returns:
            (batch_size, audio_num_codebooks) sampled tokens
        """
        dtype = next(self.parameters()).dtype
        b, s = tokens.size()[:2]

        assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"

        curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
        embeds = self._embed_tokens(tokens)
        masked_embeds = embeds * tokens_mask.unsqueeze(-1)
        h = masked_embeds.sum(dim=2)

        h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)

        last_h = h[:, -1, :]
        c0_logits = self.codebook0_head(last_h)
        c0_sample = sample_topk(c0_logits, topk, temperature)
        c0_embed = self._embed_audio(0, c0_sample)

        curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
        curr_sample = c0_sample.clone()
        curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)

        # Decoder caches must be reset every frame.
        self.decoder.reset_caches()

        for i in range(1, self.args.audio_num_codebooks):
            curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
            decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
                dtype=dtype
            )
            ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
            ci_sample = sample_topk(ci_logits, topk, temperature)
            ci_embed = self._embed_audio(i, ci_sample)

            curr_h = ci_embed
            curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
            curr_pos = curr_pos[:, -1:] + 1

        return curr_sample

    def reset_caches(self):
        """Reset KV caches."""
        self.backbone.reset_caches()
        self.decoder.reset_caches()

    def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
        """Embed audio tokens."""
        return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)

    def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed tokens."""
        text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
        audio_tokens = tokens[:, :, :-1] + (
            self.args.audio_vocab_size * torch.arange(self.args.audio_num_codebooks, device=tokens.device)
        )
        audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
            tokens.size(0), tokens.size(1), self.args.audio_num_codebooks, -1
        )

        return torch.cat([audio_embeds, text_embeds], dim=-2)
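For orientation, the sketch below shows how a caller might drive Model.generate_frame in a frame-by-frame decoding loop. It is a minimal, hypothetical illustration only: the flavor names, vocab sizes, dummy token batch, and single-step loop are assumptions, not the wiring this repo actually uses (the real driver lives in app/generator.py).

# Hypothetical driver for Model.generate_frame (illustrative values only).
# Assumes "llama-1B" / "llama-100M" exist as keys in FLAVORS and a CUDA device is available.
import torch

args = ModelArgs(
    backbone_flavor="llama-1B",     # assumed FLAVORS key
    decoder_flavor="llama-100M",    # assumed FLAVORS key
    text_vocab_size=128256,
    audio_vocab_size=2051,
    audio_num_codebooks=32,
)
model = Model(args).to("cuda", dtype=torch.bfloat16)
model.setup_caches(max_batch_size=1)

# One (batch, seq, codebooks+1) step of token ids plus a matching mask.
tokens = torch.zeros(1, 1, args.audio_num_codebooks + 1, dtype=torch.long, device="cuda")
tokens_mask = torch.ones_like(tokens, dtype=torch.bool)
input_pos = torch.zeros(1, 1, dtype=torch.long, device="cuda")

frame = model.generate_frame(tokens, tokens_mask, input_pos, temperature=0.9, topk=50)
print(frame.shape)  # (1, audio_num_codebooks)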
app/utils/audio_utils.py
ADDED
@@ -0,0 +1,64 @@
"""Audio utilities for CSM-1B API."""

import io
import tempfile
from typing import Optional
import os

import torch
import torchaudio
import ffmpeg


def convert_audio_format(
    audio_tensor: torch.Tensor,
    sample_rate: int,
    format: str = "mp3",
    bit_rate: Optional[str] = "128k",
) -> bytes:
    """Convert audio tensor to specified format.

    Args:
        audio_tensor: Audio tensor (channels, samples)
        sample_rate: Sample rate
        format: Output format (mp3, opus, aac, flac, wav)
        bit_rate: Bit rate for lossy formats

    Returns:
        Audio bytes in specified format
    """
    # Create temporary files
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
        wav_path = temp_wav.name

    temp_out = tempfile.NamedTemporaryFile(suffix=f".{format}", delete=False)
    out_path = temp_out.name
    temp_out.close()

    try:
        # Save as WAV first (native format for torchaudio)
        torchaudio.save(wav_path, audio_tensor.unsqueeze(0) if audio_tensor.dim() == 1 else audio_tensor,
                        sample_rate)

        # Convert to desired format using ffmpeg
        if format == "mp3":
            ffmpeg.input(wav_path).output(out_path, format=format, audio_bitrate=bit_rate).run(quiet=True)
        elif format in ["opus", "aac"]:
            ffmpeg.input(wav_path).output(out_path, format=format).run(quiet=True)
        elif format == "flac":
            ffmpeg.input(wav_path).output(out_path, format=format).run(quiet=True)
        elif format == "wav":
            # Already saved as WAV
            pass

        # Read the output file
        with open(out_path if format != "wav" else wav_path, "rb") as f:
            audio_bytes = f.read()

        return audio_bytes

    finally:
        # Clean up temporary files
        for path in [wav_path, out_path]:
            if os.path.exists(path):
                os.unlink(path)
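A small usage sketch for convert_audio_format, assuming a mono tensor at 24 kHz (the generator's sample rate) and an ffmpeg binary on the PATH; the output file name is arbitrary.

# Illustrative call; `audio` here is a stand-in one-second silent tensor.
import torch
from app.utils.audio_utils import convert_audio_format

audio = torch.zeros(24000)  # replace with generator output in practice
mp3_bytes = convert_audio_format(audio, sample_rate=24000, format="mp3", bit_rate="128k")
with open("out.mp3", "wb") as f:
    f.write(mp3_bytes)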
app/utils/init.py
ADDED
@@ -0,0 +1 @@
"""Utilities for CSM-1B API."""
app/utils/scheduled_tasks.py
ADDED
@@ -0,0 +1,42 @@
"""Scheduled tasks for the TTS API."""
import asyncio
import logging
import time
from datetime import datetime

logger = logging.getLogger(__name__)

async def periodic_voice_profile_backup(app_state, interval_hours=6):
    """
    Periodically save voice profiles to persistent storage.

    Args:
        app_state: The application state object
        interval_hours: Backup interval in hours
    """
    while True:
        try:
            # Wait for the specified interval
            await asyncio.sleep(interval_hours * 3600)

            # Log the backup
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            logger.info(f"Scheduled voice profile backup started at {timestamp}")

            # Save voice profiles
            if hasattr(app_state, "voice_enhancement_enabled") and app_state.voice_enhancement_enabled:
                from app.voice_enhancement import save_voice_profiles
                save_voice_profiles()
                logger.info("Voice profiles saved successfully")

            # Save voice memories
            if hasattr(app_state, "voice_memory_enabled") and app_state.voice_memory_enabled:
                for voice_name in app_state.voice_cache:
                    if voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
                        from app.voice_memory import VOICE_MEMORIES
                        if voice_name in VOICE_MEMORIES:
                            VOICE_MEMORIES[voice_name].save()
                logger.info("Voice memories saved successfully")

        except Exception as e:
            logger.error(f"Error in periodic voice profile backup: {e}")
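The backup coroutine is meant to be launched from the application's startup hook. A minimal sketch follows, assuming a FastAPI app whose app.state carries the flags the task checks; the real wiring lives in app/main.py, so treat this as illustrative.

# Hypothetical startup wiring for the periodic backup task.
import asyncio
from fastapi import FastAPI
from app.utils.scheduled_tasks import periodic_voice_profile_backup

app = FastAPI()

@app.on_event("startup")
async def start_background_tasks():
    # Fire-and-forget: the coroutine sleeps, backs up, and loops forever.
    asyncio.create_task(periodic_voice_profile_backup(app.state, interval_hours=6))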
app/utils/voice_manager.py
ADDED
@@ -0,0 +1,132 @@
"""Utility functions for managing voice references and profiles."""
import os
import logging
import torch
import torchaudio
import shutil
from datetime import datetime
from typing import Dict, List, Optional, Any

logger = logging.getLogger(__name__)

# Define persistent paths
VOICE_REFERENCES_DIR = "/app/voice_references"
VOICE_PROFILES_DIR = "/app/voice_profiles"
VOICE_MEMORIES_DIR = "/app/voice_memories"

# Ensure directories exist
os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)

def backup_voice_data(backup_dir: str = "/app/voice_backups"):
    """Create a backup of all voice data."""
    os.makedirs(backup_dir, exist_ok=True)
    # Use the standard library datetime (the original called torch.datetime, which does not exist).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = os.path.join(backup_dir, f"voice_backup_{timestamp}")
    os.makedirs(backup_path, exist_ok=True)

    # Backup voice references
    if os.path.exists(VOICE_REFERENCES_DIR):
        refs_backup = os.path.join(backup_path, "voice_references")
        shutil.copytree(VOICE_REFERENCES_DIR, refs_backup)

    # Backup voice profiles
    if os.path.exists(VOICE_PROFILES_DIR):
        profiles_backup = os.path.join(backup_path, "voice_profiles")
        shutil.copytree(VOICE_PROFILES_DIR, profiles_backup)

    # Backup voice memories
    if os.path.exists(VOICE_MEMORIES_DIR):
        memories_backup = os.path.join(backup_path, "voice_memories")
        shutil.copytree(VOICE_MEMORIES_DIR, memories_backup)

    logger.info(f"Voice data backed up to {backup_path}")
    return backup_path

def restore_default_voices():
    """Reset voices to their default state by removing existing voice data."""
    backup_path = None
    for voice_dir in [VOICE_REFERENCES_DIR, VOICE_PROFILES_DIR, VOICE_MEMORIES_DIR]:
        if os.path.exists(voice_dir):
            # Create a backup before deleting
            backup_path = backup_voice_data()

            # Remove existing data
            for item in os.listdir(voice_dir):
                item_path = os.path.join(voice_dir, item)
                if os.path.isdir(item_path):
                    shutil.rmtree(item_path)
                else:
                    os.remove(item_path)

            logger.info(f"Removed existing voice data from {voice_dir}")

    logger.info(f"Voices reset to default state (backup created at {backup_path})")
    return backup_path

def verify_voice_references():
    """Check if voice references are complete and valid."""
    standard_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
    missing_voices = []

    for voice in standard_voices:
        voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice)
        # Check if directory exists and contains reference files
        if not os.path.exists(voice_dir) or len(os.listdir(voice_dir)) == 0:
            missing_voices.append(voice)

    return {
        "complete": len(missing_voices) == 0,
        "missing_voices": missing_voices,
        "references_dir": VOICE_REFERENCES_DIR
    }

def get_voice_storage_info() -> Dict[str, Any]:
    """Get information about voice storage usage and status."""
    result = {
        "voice_references": {
            "path": VOICE_REFERENCES_DIR,
            "exists": os.path.exists(VOICE_REFERENCES_DIR),
            "voices": [],
            "total_size_mb": 0
        },
        "voice_profiles": {
            "path": VOICE_PROFILES_DIR,
            "exists": os.path.exists(VOICE_PROFILES_DIR),
            "file_count": 0,
            "total_size_mb": 0
        },
        "voice_memories": {
            "path": VOICE_MEMORIES_DIR,
            "exists": os.path.exists(VOICE_MEMORIES_DIR),
            "voices": [],
            "total_size_mb": 0
        }
    }

    # Get voice references info
    if result["voice_references"]["exists"]:
        for voice in os.listdir(VOICE_REFERENCES_DIR):
            voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice)
            if os.path.isdir(voice_dir):
                file_count = len([f for f in os.listdir(voice_dir) if f.endswith('.wav')])
                dir_size = sum(os.path.getsize(os.path.join(voice_dir, f)) for f in os.listdir(voice_dir) if os.path.isfile(os.path.join(voice_dir, f)))
                result["voice_references"]["voices"].append({
                    "name": voice,
                    "file_count": file_count,
                    "size_mb": dir_size / (1024 * 1024)
                })
                result["voice_references"]["total_size_mb"] += dir_size / (1024 * 1024)

    # Get voice profiles info
    if result["voice_profiles"]["exists"]:
        files = [f for f in os.listdir(VOICE_PROFILES_DIR) if os.path.isfile(os.path.join(VOICE_PROFILES_DIR, f))]
        result["voice_profiles"]["file_count"] = len(files)
        result["voice_profiles"]["total_size_mb"] = sum(os.path.getsize(os.path.join(VOICE_PROFILES_DIR, f)) for f in files) / (1024 * 1024)

    # Get voice memories info
    if result["voice_memories"]["exists"]:
        files = [f for f in os.listdir(VOICE_MEMORIES_DIR) if os.path.isfile(os.path.join(VOICE_MEMORIES_DIR, f))]
        result["voice_memories"]["voices"] = [f.replace('.pt', '') for f in files if f.endswith('.pt')]
        result["voice_memories"]["total_size_mb"] = sum(os.path.getsize(os.path.join(VOICE_MEMORIES_DIR, f)) for f in files) / (1024 * 1024)

    return result
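A quick sketch of how these helpers might be called from a maintenance endpoint or an interactive session; the printed keys follow the dictionaries defined above, and the backup location is the module default.

from app.utils.voice_manager import backup_voice_data, verify_voice_references, get_voice_storage_info

status = verify_voice_references()
if not status["complete"]:
    print(f"Missing reference audio for: {status['missing_voices']}")

info = get_voice_storage_info()
print(f"Profiles on disk: {info['voice_profiles']['file_count']}")

backup_path = backup_voice_data()  # writes a timestamped copy under /app/voice_backups
print(f"Backup written to {backup_path}")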
app/voice_cloning.py
ADDED
@@ -0,0 +1,705 @@
"""
Voice cloning module for CSM-1B TTS API.

This module provides functionality to clone voices from audio samples,
with advanced audio preprocessing and voice adaptation techniques.
"""
import os
import io
import time
import tempfile
import logging
import asyncio
import yt_dlp
import whisper
from typing import Dict, List, Optional, Union, Tuple, BinaryIO
from pathlib import Path

import numpy as np
import torch
import torchaudio
from pydantic import BaseModel
from fastapi import UploadFile

from app.models import Segment

# Set up logging
logger = logging.getLogger(__name__)

# Directory for storing cloned voice data
CLONED_VOICES_DIR = "/app/cloned_voices"
os.makedirs(CLONED_VOICES_DIR, exist_ok=True)

class ClonedVoice(BaseModel):
    """Model representing a cloned voice."""
    id: str
    name: str
    created_at: float
    speaker_id: int
    description: Optional[str] = None
    audio_duration: float
    sample_count: int


class VoiceCloner:
    """Voice cloning utility for CSM-1B model."""

    def __init__(self, generator, device="cuda"):
        """Initialize the voice cloner with a generator instance."""
        self.generator = generator
        self.device = device
        self.sample_rate = generator.sample_rate
        self.cloned_voices = self._load_existing_voices()
        logger.info(f"Voice cloner initialized with {len(self.cloned_voices)} existing voices")

    def _load_existing_voices(self) -> Dict[str, ClonedVoice]:
        """Load existing cloned voices from disk."""
        voices = {}
        if not os.path.exists(CLONED_VOICES_DIR):
            return voices

        for voice_dir in os.listdir(CLONED_VOICES_DIR):
            voice_path = os.path.join(CLONED_VOICES_DIR, voice_dir)
            if not os.path.isdir(voice_path):
                continue

            info_path = os.path.join(voice_path, "info.json")
            if os.path.exists(info_path):
                try:
                    import json
                    with open(info_path, "r") as f:
                        voice_info = json.load(f)
                    voices[voice_dir] = ClonedVoice(**voice_info)
                    logger.info(f"Loaded cloned voice: {voice_dir}")
                except Exception as e:
                    logger.error(f"Error loading voice {voice_dir}: {e}")

        return voices

    async def process_audio_file(
        self,
        file: Union[UploadFile, BinaryIO, str],
        transcript: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[str], float]:
        """
        Process an audio file for voice cloning.

        Args:
            file: The audio file (UploadFile, file-like object, or path)
            transcript: Optional transcript of the audio

        Returns:
            Tuple of (processed_audio, transcript, duration_seconds)
        """
        temp_path = None

        try:
            # Handle different input types
            if isinstance(file, str):
                # It's a file path
                audio_path = file
                logger.info(f"Processing audio from file path: {audio_path}")
            else:
                # Create a temporary file
                temp_fd, temp_path = tempfile.mkstemp(suffix=".wav")
                os.close(temp_fd)  # Close the file descriptor

                if isinstance(file, UploadFile):
                    # It's a FastAPI UploadFile
                    logger.info("Processing audio from UploadFile")
                    contents = await file.read()
                    with open(temp_path, "wb") as f:
                        f.write(contents)
                elif hasattr(file, 'read'):
                    # It's a file-like object - check if it's async
                    logger.info("Processing audio from file-like object")
                    if asyncio.iscoroutinefunction(file.read):
                        # It's an async read method
                        contents = await file.read()
                    else:
                        # It's a sync read method
                        contents = file.read()

                    with open(temp_path, "wb") as f:
                        f.write(contents)
                else:
                    raise ValueError(f"Unsupported file type: {type(file)}")

                audio_path = temp_path
                logger.info(f"Saved uploaded audio to temporary file: {audio_path}")

            # Load audio
            logger.info(f"Loading audio from {audio_path}")
            audio, sr = torchaudio.load(audio_path)

            # Convert to mono if stereo
            if audio.shape[0] > 1:
                logger.info(f"Converting {audio.shape[0]} channels to mono")
                audio = torch.mean(audio, dim=0, keepdim=True)

            # Remove first dimension if it's 1
            if audio.shape[0] == 1:
                audio = audio.squeeze(0)

            # Resample if necessary
            if sr != self.sample_rate:
                logger.info(f"Resampling from {sr}Hz to {self.sample_rate}Hz")
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=self.sample_rate
                )

            # Get audio duration
            duration_seconds = len(audio) / self.sample_rate

            # Process audio for better quality
            logger.info("Preprocessing audio for quality enhancement")
            processed_audio = self._preprocess_audio(audio)
            processed_duration = len(processed_audio) / self.sample_rate

            logger.info(
                f"Processed audio: original duration={duration_seconds:.2f}s, "
                f"processed duration={processed_duration:.2f}s"
            )

            return processed_audio, transcript, duration_seconds

        except Exception as e:
            logger.error(f"Error processing audio: {e}", exc_info=True)
            raise RuntimeError(f"Failed to process audio file: {e}")

        finally:
            # Clean up temp file if we created one
            if temp_path and os.path.exists(temp_path):
                try:
                    os.unlink(temp_path)
                    logger.debug(f"Deleted temporary file {temp_path}")
                except Exception as e:
                    logger.warning(f"Failed to delete temporary file {temp_path}: {e}")

    def _preprocess_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """
        Preprocess audio for better voice cloning quality.

        Args:
            audio: Raw audio tensor

        Returns:
            Processed audio tensor
        """
        # Normalize volume
        if torch.max(torch.abs(audio)) > 0:
            audio = audio / torch.max(torch.abs(audio))

        # Remove silence with dynamic threshold
        audio = self._remove_silence(audio, threshold=0.02)  # Slightly higher threshold to remove more noise

        # Remove DC offset (very low frequency noise)
        audio = audio - torch.mean(audio)

        # Apply simple noise reduction
        # This filters out very high frequencies that might contain noise
        try:
            audio_np = audio.cpu().numpy()
            from scipy import signal

            # Apply a bandpass filter to focus on speech frequencies (80Hz - 8000Hz)
            sos = signal.butter(3, [80, 8000], 'bandpass', fs=self.sample_rate, output='sos')
            filtered = signal.sosfilt(sos, audio_np)

            # Normalize the filtered audio
            filtered = filtered / (np.max(np.abs(filtered)) + 1e-8)

            # Convert back to torch tensor
            audio = torch.tensor(filtered, device=audio.device)
        except Exception as e:
            logger.warning(f"Advanced audio filtering failed, using basic processing: {e}")

        # Ensure audio has correct amplitude
        audio = audio * 0.9  # Slightly reduce volume to prevent clipping

        return audio

    def _remove_silence(
        self,
        audio: torch.Tensor,
        threshold: float = 0.015,
        min_silence_duration: float = 0.2
    ) -> torch.Tensor:
        """
        Remove silence from audio while preserving speech rhythm.

        Args:
            audio: Input audio tensor
            threshold: Energy threshold for silence detection
            min_silence_duration: Minimum silence duration in seconds

        Returns:
            Audio with silence removed
        """
        # Convert to numpy for easier processing
        audio_np = audio.cpu().numpy()

        # Calculate energy
        energy = np.abs(audio_np)

        # Find regions above threshold (speech)
        is_speech = energy > threshold

        # Convert min_silence_duration to samples
        min_silence_samples = int(min_silence_duration * self.sample_rate)

        # Find speech segments
        speech_segments = []
        in_speech = False
        speech_start = 0

        for i in range(len(is_speech)):
            if is_speech[i] and not in_speech:
                # Start of speech segment
                in_speech = True
                speech_start = i
            elif not is_speech[i] and in_speech:
                # Potential end of speech segment
                # Only end if silence is long enough
                silence_count = 0
                for j in range(i, min(len(is_speech), i + min_silence_samples)):
                    if not is_speech[j]:
                        silence_count += 1
                    else:
                        break

                if silence_count >= min_silence_samples:
                    # End of speech segment
                    in_speech = False
                    speech_segments.append((speech_start, i))

        # Handle case where audio ends during speech
        if in_speech:
            speech_segments.append((speech_start, len(is_speech)))

        # If no speech segments found, return original audio
        if not speech_segments:
            logger.warning("No speech segments detected, returning original audio")
            return audio

        # Add small buffer around segments
        buffer_samples = int(0.05 * self.sample_rate)  # 50ms buffer
        processed_segments = []

        for start, end in speech_segments:
            buffered_start = max(0, start - buffer_samples)
            buffered_end = min(len(audio_np), end + buffer_samples)
            processed_segments.append(audio_np[buffered_start:buffered_end])

        # Concatenate all segments with small pauses between them
        small_pause = np.zeros(int(0.15 * self.sample_rate))  # 150ms pause
        result = processed_segments[0]

        for segment in processed_segments[1:]:
            result = np.concatenate([result, small_pause, segment])

        return torch.tensor(result, device=audio.device)

    def _enhance_speech(self, audio: torch.Tensor) -> torch.Tensor:
        """Enhance speech quality for better cloning results."""
        # This is a placeholder for more advanced speech enhancement
        # In a production implementation, you could add:
        # - Noise reduction
        # - Equalization for speech frequencies
        # - Gentle compression for better dynamics
        return audio

    async def clone_voice(
        self,
        audio_file: Union[UploadFile, BinaryIO, str],
        voice_name: str,
        transcript: Optional[str] = None,
        description: Optional[str] = None,
        speaker_id: Optional[int] = None  # Optional; auto-assigned when omitted
    ) -> ClonedVoice:
        """
        Clone a voice from an audio file.

        Args:
            audio_file: Audio file with the voice to clone
            voice_name: Name for the cloned voice
            transcript: Transcript of the audio (optional)
            description: Description of the voice (optional)
            speaker_id: Speaker ID to use (default: auto-assigned)

        Returns:
            ClonedVoice object with voice information
        """
        logger.info(f"Cloning new voice '{voice_name}' from audio file")

        # Process the audio file
        processed_audio, provided_transcript, duration = await self.process_audio_file(
            audio_file, transcript
        )

        # Use a better speaker ID assignment - use a small number similar to the built-in voices
        # This prevents issues with the speaker ID being interpreted as speech
        if speaker_id is None:
            # Use a number between 10-20 to avoid conflicts with built-in voices (0-5)
            # but not too large like 999 which might cause issues
            existing_ids = [v.speaker_id for v in self.cloned_voices.values()]
            for potential_id in range(10, 20):
                if potential_id not in existing_ids:
                    speaker_id = potential_id
                    break
            else:
                # If all IDs in range are taken, use a fallback
                speaker_id = 10

        # Generate a unique ID for the voice
        voice_id = f"{int(time.time())}_{voice_name.lower().replace(' ', '_')}"

        # Create directory for the voice
        voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
        os.makedirs(voice_dir, exist_ok=True)

        # Save the processed audio
        audio_path = os.path.join(voice_dir, "reference.wav")
        torchaudio.save(audio_path, processed_audio.unsqueeze(0).cpu(), self.sample_rate)

        # Save the transcript if provided
        if provided_transcript:
            transcript_path = os.path.join(voice_dir, "transcript.txt")
            with open(transcript_path, "w") as f:
                f.write(provided_transcript)

        # Create and save voice info
        voice_info = ClonedVoice(
            id=voice_id,
            name=voice_name,
            created_at=time.time(),
            speaker_id=speaker_id,
            description=description,
            audio_duration=duration,
            sample_count=len(processed_audio)
        )

        # Save voice info as JSON
        import json
        with open(os.path.join(voice_dir, "info.json"), "w") as f:
            f.write(json.dumps(voice_info.dict()))

        # Add to cloned voices dictionary
        self.cloned_voices[voice_id] = voice_info

        logger.info(f"Voice '{voice_name}' cloned successfully with ID: {voice_id} and speaker_id: {speaker_id}")

        return voice_info

    def get_voice_context(self, voice_id: str) -> List[Segment]:
        """
        Get context segments for a cloned voice.

        Args:
            voice_id: ID of the cloned voice

        Returns:
            List of context segments for the voice
        """
        if voice_id not in self.cloned_voices:
            logger.warning(f"Voice ID {voice_id} not found")
            return []

        voice = self.cloned_voices[voice_id]
        voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
        audio_path = os.path.join(voice_dir, "reference.wav")

        if not os.path.exists(audio_path):
            logger.error(f"Audio file for voice {voice_id} not found at {audio_path}")
            return []

        try:
            # Load the audio
            audio, sr = torchaudio.load(audio_path)
            audio = audio.squeeze(0)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = torchaudio.functional.resample(
                    audio, orig_freq=sr, new_freq=self.sample_rate
                )

            # Trim to a maximum of 5 seconds to avoid sequence length issues
            # This is a balance between voice quality and model limitations
            max_samples = 5 * self.sample_rate  # 5 seconds
            if audio.shape[0] > max_samples:
                logger.info(f"Trimming voice sample from {audio.shape[0]} to {max_samples} samples")
                # Take from beginning for better voice characteristics
                audio = audio[:max_samples]

            # Load transcript if available
            transcript_path = os.path.join(voice_dir, "transcript.txt")
            transcript = ""
            if os.path.exists(transcript_path):
                with open(transcript_path, "r") as f:
                    full_transcript = f.read()
                # Take a portion of transcript that roughly matches our audio portion
                words = full_transcript.split()
                # Estimate 3 words per second as a rough average
                word_count = min(len(words), int(5 * 3))  # 5 seconds * 3 words/second
                transcript = " ".join(words[:word_count])
            else:
                transcript = f"Voice sample for {voice.name}"

            # Create context segment
            segment = Segment(
                text=transcript,
                speaker=voice.speaker_id,
                audio=audio.to(self.device)
            )

            logger.info(f"Created voice context segment with {audio.shape[0]/self.sample_rate:.1f}s audio")
            return [segment]

        except Exception as e:
            logger.error(f"Error getting voice context for {voice_id}: {e}")
            return []

    def list_voices(self) -> List[ClonedVoice]:
        """List all available cloned voices."""
        return list(self.cloned_voices.values())

    def delete_voice(self, voice_id: str) -> bool:
        """
        Delete a cloned voice.

        Args:
            voice_id: ID of the voice to delete

        Returns:
            True if successful, False otherwise
        """
        if voice_id not in self.cloned_voices:
            return False

        voice_dir = os.path.join(CLONED_VOICES_DIR, voice_id)
        if os.path.exists(voice_dir):
            try:
                import shutil
                shutil.rmtree(voice_dir)
                del self.cloned_voices[voice_id]
                return True
            except Exception as e:
                logger.error(f"Error deleting voice {voice_id}: {e}")
                return False

        return False

    async def clone_voice_from_youtube(
        self,
        youtube_url: str,
        voice_name: str,
        start_time: int = 0,
        duration: int = 180,
        description: str = None
    ) -> ClonedVoice:
        """
        Clone a voice from a YouTube video.

        Args:
            youtube_url: URL of the YouTube video
            voice_name: Name for the cloned voice
            start_time: Start time in seconds
            duration: Duration to extract in seconds
            description: Optional description of the voice

        Returns:
            ClonedVoice object with voice information
        """
        logger.info(f"Cloning voice '{voice_name}' from YouTube: {youtube_url}")

        # Create temporary directory for processing
        with tempfile.TemporaryDirectory() as temp_dir:
            # Step 1: Download audio from YouTube
            audio_path = await self._download_youtube_audio(youtube_url, temp_dir, start_time, duration)

            # Step 2: Generate transcript using Whisper
            transcript = await self._generate_transcript(audio_path)

            # Step 3: Clone the voice using the extracted audio and transcript
            voice = await self.clone_voice(
                audio_file=audio_path,
                voice_name=voice_name,
                transcript=transcript,
                description=description or f"Voice cloned from YouTube: {youtube_url}"
            )

            return voice

    async def _download_youtube_audio(
        self,
        url: str,
        output_dir: str,
        start_time: int = 0,
        duration: int = 180
    ) -> str:
        """
        Download audio from a YouTube video.

        Args:
            url: YouTube URL
            output_dir: Directory to save the audio
            start_time: Start time in seconds
            duration: Duration to extract in seconds

        Returns:
            Path to the downloaded audio file
        """
        output_path = os.path.join(output_dir, "youtube_audio.wav")

        # Configure yt-dlp options
        ydl_opts = {
            'format': 'bestaudio/best',
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'outtmpl': output_path.replace(".wav", ""),
            'quiet': True,
            'no_warnings': True
        }

        # Download the video
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.download([url])

        # Trim the audio to the specified segment
        if start_time > 0 or duration < float('inf'):
            import ffmpeg
            trimmed_path = os.path.join(output_dir, "trimmed_audio.wav")

            # Use ffmpeg to trim the audio
            (
                ffmpeg.input(output_path)
                .audio
                .filter('atrim', start=start_time, duration=duration)
                .output(trimmed_path)
                .run(quiet=True, overwrite_output=True)
            )

            return trimmed_path

        return output_path

    async def _generate_transcript(self, audio_path: str) -> str:
        """
        Generate transcript from audio using Whisper.

        Args:
            audio_path: Path to the audio file

        Returns:
            Transcript text
        """
        # Load Whisper model (use small model for faster processing)
        model = whisper.load_model("small")

        # Transcribe the audio
        result = model.transcribe(audio_path)

        return result["text"]

    def generate_speech(
        self,
        text: str,
        voice_id: str,
        temperature: float = 0.65,
        topk: int = 30,
        max_audio_length_ms: int = 15000
    ) -> torch.Tensor:
        """
        Generate speech with a cloned voice.

        Args:
            text: Text to synthesize
            voice_id: ID of the cloned voice to use
            temperature: Sampling temperature (lower = more stable, higher = more varied)
            topk: Top-k sampling parameter
            max_audio_length_ms: Maximum audio length in milliseconds

        Returns:
            Generated audio tensor
        """
        # Note: this is a synchronous method, unlike the async cloning entry points.
        if voice_id not in self.cloned_voices:
            raise ValueError(f"Voice ID {voice_id} not found")
        voice = self.cloned_voices[voice_id]
        context = self.get_voice_context(voice_id)
        if not context:
            raise ValueError(f"Could not get context for voice {voice_id}")
        # Preprocess text for better pronunciation
        processed_text = self._preprocess_text(text)
        logger.info(f"Generating speech with voice '{voice.name}' (ID: {voice_id}, speaker: {voice.speaker_id})")
        try:
            # Check if text is too long and should be split
            if len(processed_text) > 200:
                logger.info(f"Text is long ({len(processed_text)} chars), splitting for better quality")
                from app.prompt_engineering import split_into_segments
                # Split text into manageable segments
                segments = split_into_segments(processed_text, max_chars=150)
                logger.info(f"Split text into {len(segments)} segments")
                all_audio_chunks = []
                # Process each segment
                for i, segment_text in enumerate(segments):
                    logger.info(f"Generating segment {i+1}/{len(segments)}")
                    # Generate this segment - using plain text without formatting
                    segment_audio = self.generator.generate(
                        text=segment_text,  # Use plain text, no formatting
                        speaker=voice.speaker_id,
                        context=context,
                        max_audio_length_ms=min(max_audio_length_ms, 10000),
                        temperature=temperature,
                        topk=topk,
                    )
                    all_audio_chunks.append(segment_audio)
                    # Use this segment as context for the next one for consistency
                    if i < len(segments) - 1:
                        context = [
                            Segment(
                                text=segment_text,
                                speaker=voice.speaker_id,
                                audio=segment_audio
                            )
                        ]
                # Combine chunks with small silence between them
                if len(all_audio_chunks) == 1:
                    audio = all_audio_chunks[0]
                else:
                    silence_samples = int(0.1 * self.sample_rate)  # 100ms silence
                    silence = torch.zeros(silence_samples, device=all_audio_chunks[0].device)
                    # Join segments with silence
                    audio_parts = []
                    for i, chunk in enumerate(all_audio_chunks):
                        audio_parts.append(chunk)
                        if i < len(all_audio_chunks) - 1:  # Don't add silence after the last chunk
                            audio_parts.append(silence)
                    # Concatenate all parts
                    audio = torch.cat(audio_parts)
                return audio
            else:
                # For short text, generate directly - using plain text without formatting
                audio = self.generator.generate(
                    text=processed_text,  # Use plain text, no formatting
                    speaker=voice.speaker_id,
                    context=context,
                    max_audio_length_ms=max_audio_length_ms,
                    temperature=temperature,
                    topk=topk,
                )
                return audio
        except Exception as e:
            logger.error(f"Error generating speech with voice {voice_id}: {e}")
            raise

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text for better pronunciation and voice cloning."""
        # Make sure text ends with punctuation for better phrasing
        text = text.strip()
        if not text.endswith(('.', '?', '!', ';')):
            text = text + '.'

        return text
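End to end, the cloner is used roughly as sketched below: clone once from a sample, then synthesize with the stored voice ID. The generator construction is elided because it depends on the rest of the app (app/generator.py), and the file path, voice name, and transcript are placeholders.

# Illustrative only: `generator` is the loaded CSM generator from app.generator.
import asyncio
import torchaudio
from app.voice_cloning import VoiceCloner

async def demo(generator):
    cloner = VoiceCloner(generator, device="cuda")
    voice = await cloner.clone_voice(
        audio_file="/data/sample.wav",           # placeholder path
        voice_name="my_voice",
        transcript="A short transcript of the sample.",
    )
    audio = cloner.generate_speech("Hello from my cloned voice.", voice.id)
    torchaudio.save("cloned.wav", audio.unsqueeze(0).cpu(), cloner.sample_rate)

# asyncio.run(demo(generator))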
app/voice_embeddings.py
ADDED
@@ -0,0 +1,85 @@
"""Voice embeddings for consistent voice generation."""
import os
import torch
import torchaudio
import numpy as np
from typing import Dict

# Path to store voice samples
VOICE_SAMPLES_DIR = os.path.join(os.path.dirname(__file__), "voice_samples")
os.makedirs(VOICE_SAMPLES_DIR, exist_ok=True)

# Dictionary to store voice embeddings/samples
VOICE_DICT: Dict[str, torch.Tensor] = {}


def initialize_voices(sample_rate: int = 24000):
    """Initialize voice dictionary with consistent samples."""
    # Generate consistent seed audio for each voice
    for voice_id in range(6):
        voice_name = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"][voice_id]

        # Create deterministic audio sample for each voice
        np.random.seed(voice_id + 42)  # Use a fixed seed based on voice ID

        # Generate 1 second of "seed" audio with deterministic characteristics
        # This differs per voice but remains consistent across runs
        duration = 1.0  # seconds
        t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)

        # Create a distinctive waveform for each voice
        if voice_id == 0:  # alloy - rich mid tones
            freq1, freq2 = 220, 440
            audio = 0.5 * np.sin(2 * np.pi * freq1 * t) + 0.3 * np.sin(2 * np.pi * freq2 * t)
        elif voice_id == 1:  # echo - reverberant
            freq = 330
            audio = np.sin(2 * np.pi * freq * t) * np.exp(-t * 3)
        elif voice_id == 2:  # fable - bright, higher pitch
            freq = 523
            audio = 0.7 * np.sin(2 * np.pi * freq * t)
        elif voice_id == 3:  # onyx - deep and resonant
            freq = 165
            audio = 0.8 * np.sin(2 * np.pi * freq * t)
        elif voice_id == 4:  # nova - warm and smooth
            freq1, freq2 = 392, 196
            audio = 0.4 * np.sin(2 * np.pi * freq1 * t) + 0.4 * np.sin(2 * np.pi * freq2 * t)
        else:  # shimmer - airy and light
            freq1, freq2, freq3 = 587, 880, 1174
            audio = 0.3 * np.sin(2 * np.pi * freq1 * t) + 0.2 * np.sin(2 * np.pi * freq2 * t) + 0.1 * np.sin(2 * np.pi * freq3 * t)

        # Normalize
        audio = audio / np.max(np.abs(audio))

        # Convert to tensor
        audio_tensor = torch.tensor(audio, dtype=torch.float32)

        # Store the audio tensor
        VOICE_DICT[voice_name] = audio_tensor

        # Save as wav for reference
        save_path = os.path.join(VOICE_SAMPLES_DIR, f"{voice_name}_seed.wav")
        torchaudio.save(save_path, audio_tensor.unsqueeze(0), sample_rate)

        print(f"Initialized voice seed for {voice_name}")


def get_voice_sample(voice_name: str) -> torch.Tensor:
    """Get the voice sample for a given voice name."""
    if not VOICE_DICT:
        initialize_voices()

    if voice_name in VOICE_DICT:
        return VOICE_DICT[voice_name]

    # Default to alloy if voice not found
    print(f"Voice {voice_name} not found, defaulting to alloy")
    return VOICE_DICT["alloy"]


def update_voice_sample(voice_name: str, audio: torch.Tensor):
    """Update the voice sample with recently generated audio."""
    # Only update if we've already initialized
    if VOICE_DICT:
        # Take the last second of audio (or whatever is available)
        sample_length = min(24000, audio.shape[0])
        VOICE_DICT[voice_name] = audio[-sample_length:].detach().cpu()
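Usage is straightforward; the seeds are deterministic, so repeated runs produce identical samples. A short sketch, assuming the default 24 kHz sample rate:

from app.voice_embeddings import initialize_voices, get_voice_sample, update_voice_sample

initialize_voices(sample_rate=24000)   # writes *_seed.wav files under voice_samples/
seed = get_voice_sample("nova")        # 1 s tensor at 24 kHz
print(seed.shape)                      # torch.Size([24000])

# After generating audio for a voice, the tail can be rolled back into the dictionary:
# update_voice_sample("nova", generated_audio)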
app/voice_enhancement.py
ADDED
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Advanced voice enhancement and consistency system for CSM-1B."""
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
import numpy as np
|
6 |
+
import soundfile as sf
|
7 |
+
from typing import Dict, List, Optional, Tuple
|
8 |
+
import logging
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from scipy import signal
|
11 |
+
|
12 |
+
# Setup logging
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
# Define persistent paths
|
16 |
+
VOICE_REFERENCES_DIR = "/app/voice_references"
|
17 |
+
VOICE_PROFILES_DIR = "/app/voice_profiles"
|
18 |
+
|
19 |
+
# Ensure directories exist
|
20 |
+
os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
|
21 |
+
os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class VoiceProfile:
|
25 |
+
"""Detailed voice profile with acoustic characteristics."""
|
26 |
+
name: str
|
27 |
+
speaker_id: int
|
28 |
+
# Acoustic parameters
|
29 |
+
pitch_range: Tuple[float, float] # Min/max pitch in Hz
|
30 |
+
intensity_range: Tuple[float, float] # Min/max intensity (volume)
|
31 |
+
spectral_tilt: float # Brightness vs. darkness
|
32 |
+
prosody_pattern: str # Pattern of intonation and rhythm
|
33 |
+
speech_rate: float # Relative speech rate (1.0 = normal)
|
34 |
+
formant_shift: float # Formant frequency shift (1.0 = no shift)
|
35 |
+
# Reference audio
|
36 |
+
reference_segments: List[torch.Tensor]
|
37 |
+
# Normalization parameters
|
38 |
+
target_rms: float = 0.2
|
39 |
+
target_peak: float = 0.95
|
40 |
+
|
41 |
+
def get_enhancement_params(self) -> Dict:
|
42 |
+
"""Get parameters for enhancing generated audio."""
|
43 |
+
return {
|
44 |
+
"target_rms": self.target_rms,
|
45 |
+
"target_peak": self.target_peak,
|
46 |
+
"pitch_range": self.pitch_range,
|
47 |
+
"formant_shift": self.formant_shift,
|
48 |
+
"speech_rate": self.speech_rate,
|
49 |
+
"spectral_tilt": self.spectral_tilt
|
50 |
+
}
|
51 |
+
|
52 |
+
# Voice profiles with carefully tuned parameters
|
53 |
+
VOICE_PROFILES = {
|
54 |
+
"alloy": VoiceProfile(
|
55 |
+
name="alloy",
|
56 |
+
speaker_id=0,
|
57 |
+
pitch_range=(85, 180), # Hz - balanced range
|
58 |
+
intensity_range=(0.15, 0.3), # moderate intensity
|
59 |
+
spectral_tilt=0.0, # neutral tilt
|
60 |
+
prosody_pattern="balanced",
|
61 |
+
speech_rate=1.0, # normal rate
|
62 |
+
formant_shift=1.0, # no shift
|
63 |
+
reference_segments=[],
|
64 |
+
target_rms=0.2,
|
65 |
+
target_peak=0.95
|
66 |
+
),
|
67 |
+
"echo": VoiceProfile(
|
68 |
+
name="echo",
|
69 |
+
speaker_id=1,
|
70 |
+
pitch_range=(75, 165), # Hz - lower, resonant
|
71 |
+
intensity_range=(0.2, 0.35), # slightly stronger
|
72 |
+
spectral_tilt=-0.2, # more low frequencies
|
73 |
+
prosody_pattern="deliberate",
|
74 |
+
speech_rate=0.95, # slightly slower
|
75 |
+
formant_shift=0.95, # slightly lower formants
|
76 |
+
reference_segments=[],
|
77 |
+
target_rms=0.22, # slightly louder
|
78 |
+
target_peak=0.95
|
79 |
+
),
|
80 |
+
"fable": VoiceProfile(
|
81 |
+
name="fable",
|
82 |
+
speaker_id=2,
|
83 |
+
pitch_range=(120, 250), # Hz - higher range
|
84 |
+
intensity_range=(0.15, 0.28), # moderate intensity
|
85 |
+
spectral_tilt=0.2, # more high frequencies
|
86 |
+
prosody_pattern="animated",
|
87 |
+
speech_rate=1.05, # slightly faster
|
88 |
+
formant_shift=1.05, # slightly higher formants
|
89 |
+
reference_segments=[],
|
90 |
+
target_rms=0.19,
|
91 |
+
target_peak=0.95
|
92 |
+
),
|
93 |
+
"onyx": VoiceProfile(
|
94 |
+
name="onyx",
|
95 |
+
speaker_id=3,
|
96 |
+
pitch_range=(65, 150), # Hz - deeper range
|
97 |
+
intensity_range=(0.18, 0.32), # moderate-strong
|
98 |
+
spectral_tilt=-0.3, # more low frequencies
|
99 |
+
prosody_pattern="authoritative",
|
100 |
+
speech_rate=0.93, # slightly slower
|
101 |
+
formant_shift=0.9, # lower formants
|
102 |
+
reference_segments=[],
|
103 |
+
target_rms=0.23, # stronger
|
104 |
+
target_peak=0.95
|
105 |
+
),
|
106 |
+
"nova": VoiceProfile(
|
107 |
+
name="nova",
|
108 |
+
speaker_id=4,
|
109 |
+
pitch_range=(90, 200), # Hz - warm midrange
|
110 |
+
intensity_range=(0.15, 0.27), # moderate
|
111 |
+
spectral_tilt=-0.1, # slightly warm
|
112 |
+
prosody_pattern="flowing",
|
113 |
+
speech_rate=1.0, # normal rate
|
114 |
+
formant_shift=1.0, # no shift
|
115 |
+
reference_segments=[],
|
116 |
+
target_rms=0.2,
|
117 |
+
target_peak=0.95
|
118 |
+
),
|
119 |
+
"shimmer": VoiceProfile(
|
120 |
+
name="shimmer",
|
121 |
+
speaker_id=5,
|
122 |
+
pitch_range=(140, 280), # Hz - brighter, higher
|
123 |
+
intensity_range=(0.15, 0.25), # moderate-light
|
124 |
+
spectral_tilt=0.3, # more high frequencies
|
125 |
+
prosody_pattern="light",
|
126 |
+
speech_rate=1.07, # slightly faster
|
127 |
+
formant_shift=1.1, # higher formants
|
128 |
+
reference_segments=[],
|
129 |
+
target_rms=0.18, # slightly softer
|
130 |
+
target_peak=0.95
|
131 |
+
)
|
132 |
+
}
|
133 |
+
|
134 |
+
# Voice-specific prompt templates - crafted to establish voice identity clearly
|
135 |
+
VOICE_PROMPTS = {
|
136 |
+
"alloy": [
|
137 |
+
"Hello, I'm Alloy. I speak with a balanced, natural tone that's easy to understand.",
|
138 |
+
"This is Alloy speaking. My voice is designed to be clear and conversational.",
|
139 |
+
"Alloy here - I have a neutral, friendly voice with balanced tone qualities."
|
140 |
+
],
|
141 |
+
"echo": [
|
142 |
+
"Hello, I'm Echo. I speak with a resonant, deeper voice that carries well.",
|
143 |
+
"This is Echo speaking. My voice has a rich, resonant quality with depth.",
|
144 |
+
"Echo here - My voice is characterized by its warm, resonant tones."
|
145 |
+
],
|
146 |
+
"fable": [
|
147 |
+
"Hello, I'm Fable. I speak with a bright, higher-pitched voice that's full of energy.",
|
148 |
+
"This is Fable speaking. My voice is characterized by its clear, bright quality.",
|
149 |
+
"Fable here - My voice is light, articulate, and slightly higher-pitched."
|
150 |
+
],
|
151 |
+
"onyx": [
|
152 |
+
"Hello, I'm Onyx. I speak with a deep, authoritative voice that commands attention.",
|
153 |
+
"This is Onyx speaking. My voice has a powerful, deep quality with gravitas.",
|
154 |
+
"Onyx here - My voice is characterized by its depth and commanding presence."
|
155 |
+
],
|
156 |
+
"nova": [
|
157 |
+
"Hello, I'm Nova. I speak with a warm, pleasant mid-range voice that's easy to listen to.",
|
158 |
+
"This is Nova speaking. My voice has a smooth, harmonious quality.",
|
159 |
+
"Nova here - My voice is characterized by its warm, friendly mid-tones."
|
160 |
+
],
|
161 |
+
"shimmer": [
|
162 |
+
"Hello, I'm Shimmer. I speak with a light, bright voice that's expressive and clear.",
|
163 |
+
"This is Shimmer speaking. My voice has an airy, higher-pitched quality.",
|
164 |
+
"Shimmer here - My voice is characterized by its bright, crystalline tones."
|
165 |
+
]
|
166 |
+
}
|
167 |
+
|
168 |
+
def initialize_voice_profiles():
|
169 |
+
"""Initialize voice profiles with default settings.
|
170 |
+
|
171 |
+
This function loads existing voice profiles from disk if available,
|
172 |
+
or initializes them with default settings.
|
173 |
+
"""
|
174 |
+
global VOICE_PROFILES
|
175 |
+
|
176 |
+
# Try to load existing profiles from persistent storage
|
177 |
+
profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
|
178 |
+
|
179 |
+
if os.path.exists(profile_path):
|
180 |
+
try:
|
181 |
+
logger.info(f"Loading voice profiles from {profile_path}")
|
182 |
+
saved_profiles = torch.load(profile_path)
|
183 |
+
|
184 |
+
# Update existing profiles with saved data
|
185 |
+
for name, data in saved_profiles.items():
|
186 |
+
if name in VOICE_PROFILES:
|
187 |
+
VOICE_PROFILES[name].reference_segments = [
|
188 |
+
seg.to(torch.device("cpu")) for seg in data.get('reference_segments', [])
|
189 |
+
]
|
190 |
+
|
191 |
+
logger.info(f"Loaded voice profiles for {len(saved_profiles)} voices")
|
192 |
+
except Exception as e:
|
193 |
+
logger.error(f"Error loading voice profiles: {e}")
|
194 |
+
logger.info("Using default voice profiles")
|
195 |
+
else:
|
196 |
+
logger.info("No saved voice profiles found, using defaults")
|
197 |
+
|
198 |
+
# Ensure all voices have at least empty reference segments
|
199 |
+
for name, profile in VOICE_PROFILES.items():
|
200 |
+
if not hasattr(profile, 'reference_segments'):
|
201 |
+
profile.reference_segments = []
|
202 |
+
|
203 |
+
logger.info(f"Voice profiles initialized for {len(VOICE_PROFILES)} voices")
|
204 |
+
return VOICE_PROFILES
|
205 |
+
|
206 |
+
def normalize_audio(audio: torch.Tensor, target_rms: float = 0.2, target_peak: float = 0.95) -> torch.Tensor:
|
207 |
+
"""Apply professional-grade normalization to audio.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
audio: Audio tensor
|
211 |
+
target_rms: Target RMS level for normalization
|
212 |
+
target_peak: Target peak level for limiting
|
213 |
+
|
214 |
+
Returns:
|
215 |
+
Normalized audio tensor
|
216 |
+
"""
|
217 |
+
# Ensure audio is on CPU for processing
|
218 |
+
audio_cpu = audio.detach().cpu()
|
219 |
+
|
220 |
+
# Handle silent audio
|
221 |
+
if audio_cpu.abs().max() < 1e-6:
|
222 |
+
logger.warning("Audio is nearly silent, returning original")
|
223 |
+
return audio
|
224 |
+
|
225 |
+
# Calculate current RMS
|
226 |
+
current_rms = torch.sqrt(torch.mean(audio_cpu ** 2))
|
227 |
+
|
228 |
+
# Apply RMS normalization
|
229 |
+
if current_rms > 0:
|
230 |
+
gain = target_rms / current_rms
|
231 |
+
normalized = audio_cpu * gain
|
232 |
+
else:
|
233 |
+
normalized = audio_cpu
|
234 |
+
|
235 |
+
# Apply peak limiting
|
236 |
+
current_peak = normalized.abs().max()
|
237 |
+
if current_peak > target_peak:
|
238 |
+
normalized = normalized * (target_peak / current_peak)
|
239 |
+
|
240 |
+
# Return to original device
|
241 |
+
return normalized.to(audio.device)
|
242 |
+
|
243 |
+
def apply_anti_muffling(audio: torch.Tensor, sample_rate: int, clarity_boost: float = 1.2) -> torch.Tensor:
|
244 |
+
"""Apply anti-muffling to improve clarity.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
audio: Audio tensor
|
248 |
+
sample_rate: Audio sample rate
|
249 |
+
clarity_boost: Amount of high frequency boost (1.0 = no boost)
|
250 |
+
|
251 |
+
Returns:
|
252 |
+
Processed audio tensor
|
253 |
+
"""
|
254 |
+
# Convert to numpy for filtering
|
255 |
+
audio_np = audio.detach().cpu().numpy()
|
256 |
+
|
257 |
+
try:
|
258 |
+
# Design a high shelf filter to boost high frequencies
|
259 |
+
# Use a standard high-shelf filter that's supported by scipy.signal
|
260 |
+
# We'll use a second-order Butterworth high-pass filter as an alternative
|
261 |
+
cutoff = 2000 # Hz
|
262 |
+
b, a = signal.butter(2, cutoff/(sample_rate/2), btype='high', analog=False)
|
263 |
+
|
264 |
+
# Apply the filter with the clarity boost gain
|
265 |
+
boosted = signal.filtfilt(b, a, audio_np, axis=0) * clarity_boost
|
266 |
+
|
267 |
+
# Mix with original to maintain some warmth
|
268 |
+
mix_ratio = 0.7 # 70% processed, 30% original
|
269 |
+
processed = mix_ratio * boosted + (1-mix_ratio) * audio_np
|
270 |
+
|
271 |
+
except Exception as e:
|
272 |
+
logger.warning(f"Audio enhancement failed, using original: {e}")
|
273 |
+
# Return original audio if enhancement fails
|
274 |
+
return audio
|
275 |
+
|
276 |
+
# Convert back to tensor on original device
|
277 |
+
return torch.tensor(processed, dtype=audio.dtype, device=audio.device)
|
278 |
+
|
279 |
+
def enhance_audio(audio: torch.Tensor, sample_rate: int, voice_profile: VoiceProfile) -> torch.Tensor:
|
280 |
+
"""Apply comprehensive audio enhancement based on voice profile.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
audio: Audio tensor
|
284 |
+
sample_rate: Audio sample rate
|
285 |
+
voice_profile: Voice profile containing enhancement parameters
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
Enhanced audio tensor
|
289 |
+
"""
|
290 |
+
if audio is None or audio.numel() == 0:
|
291 |
+
logger.error("Cannot enhance empty audio")
|
292 |
+
return audio
|
293 |
+
|
294 |
+
try:
|
295 |
+
# Step 1: Normalize audio levels
|
296 |
+
params = voice_profile.get_enhancement_params()
|
297 |
+
normalized = normalize_audio(
|
298 |
+
audio,
|
299 |
+
target_rms=params["target_rms"],
|
300 |
+
target_peak=params["target_peak"]
|
301 |
+
)
|
302 |
+
|
303 |
+
# Step 2: Apply anti-muffling based on spectral tilt
|
304 |
+
# A positive spectral tilt means a brighter voice, so less clarity boost is needed
|
305 |
+
clarity_boost = 1.0 + max(0, -params["spectral_tilt"]) * 0.5
|
306 |
+
clarified = apply_anti_muffling(
|
307 |
+
normalized,
|
308 |
+
sample_rate,
|
309 |
+
clarity_boost=clarity_boost
|
310 |
+
)
|
311 |
+
|
312 |
+
# Log the enhancement
|
313 |
+
logger.debug(
|
314 |
+
f"Enhanced audio for {voice_profile.name}: "
|
315 |
+
f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{clarified.pow(2).mean().sqrt().item():.3f}, "
|
316 |
+
f"Peak: {audio.abs().max().item():.3f}->{clarified.abs().max().item():.3f}"
|
317 |
+
)
|
318 |
+
|
319 |
+
return clarified
|
320 |
+
|
321 |
+
except Exception as e:
|
322 |
+
logger.error(f"Error in audio enhancement: {e}")
|
323 |
+
return audio # Return original audio if enhancement fails
|
324 |
+
|
325 |
+
def validate_generated_audio(
|
326 |
+
audio: torch.Tensor,
|
327 |
+
voice_name: str,
|
328 |
+
sample_rate: int,
|
329 |
+
min_expected_duration: float = 0.5
|
330 |
+
) -> Tuple[bool, torch.Tensor, str]:
|
331 |
+
"""Validate and fix generated audio.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
audio: Audio tensor to validate
|
335 |
+
voice_name: Name of the voice used
|
336 |
+
sample_rate: Audio sample rate
|
337 |
+
min_expected_duration: Minimum expected duration in seconds
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
Tuple of (is_valid, fixed_audio, message)
|
341 |
+
"""
|
342 |
+
if audio is None:
|
343 |
+
return False, torch.zeros(1), "Audio is None"
|
344 |
+
|
345 |
+
# Check for NaN values
|
346 |
+
if torch.isnan(audio).any():
|
347 |
+
logger.warning(f"Audio for {voice_name} contains NaN values, replacing with zeros")
|
348 |
+
audio = torch.where(torch.isnan(audio), torch.zeros_like(audio), audio)
|
349 |
+
|
350 |
+
# Check audio duration
|
351 |
+
duration = audio.shape[0] / sample_rate
|
352 |
+
if duration < min_expected_duration:
|
353 |
+
logger.warning(f"Audio for {voice_name} is too short ({duration:.2f}s < {min_expected_duration}s)")
|
354 |
+
return False, audio, f"Audio too short: {duration:.2f}s"
|
355 |
+
|
356 |
+
# Check for silent sections - this can indicate generation problems
|
357 |
+
rms = torch.sqrt(torch.mean(audio ** 2))
|
358 |
+
if rms < 0.01: # Very low RMS indicates near silence
|
359 |
+
logger.warning(f"Audio for {voice_name} is nearly silent (RMS: {rms:.6f})")
|
360 |
+
return False, audio, f"Audio nearly silent: RMS = {rms:.6f}"
|
361 |
+
|
362 |
+
# Check if audio suddenly cuts off - this detects premature stopping
|
363 |
+
# Calculate RMS in the last 100ms
|
364 |
+
last_samples = int(0.1 * sample_rate)
|
365 |
+
if audio.shape[0] > last_samples:
|
366 |
+
end_rms = torch.sqrt(torch.mean(audio[-last_samples:] ** 2))
|
367 |
+
if end_rms > 0.1: # High RMS at the end suggests an abrupt cutoff
|
368 |
+
logger.warning(f"Audio for {voice_name} may have cut off prematurely (end RMS: {end_rms:.3f})")
|
369 |
+
return True, audio, "Audio may have cut off prematurely"
|
370 |
+
|
371 |
+
return True, audio, "Audio validation passed"
|
372 |
+
|
373 |
+
def create_voice_segments(app_state, regenerate: bool = False):
|
374 |
+
"""Create high-quality voice reference segments.
|
375 |
+
|
376 |
+
Args:
|
377 |
+
app_state: Application state containing generator
|
378 |
+
regenerate: Whether to regenerate existing references
|
379 |
+
"""
|
380 |
+
generator = app_state.generator
|
381 |
+
if not generator:
|
382 |
+
logger.error("Cannot create voice segments: generator not available")
|
383 |
+
return
|
384 |
+
|
385 |
+
# Use persistent directory for voice reference segments
|
386 |
+
os.makedirs(VOICE_REFERENCES_DIR, exist_ok=True)
|
387 |
+
|
388 |
+
for voice_name, profile in VOICE_PROFILES.items():
|
389 |
+
voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
|
390 |
+
os.makedirs(voice_dir, exist_ok=True)
|
391 |
+
|
392 |
+
# Check if we already have references
|
393 |
+
if not regenerate and profile.reference_segments:
|
394 |
+
logger.info(f"Voice {voice_name} already has {len(profile.reference_segments)} reference segments")
|
395 |
+
continue
|
396 |
+
|
397 |
+
# Get prompts for this voice
|
398 |
+
prompts = VOICE_PROMPTS[voice_name]
|
399 |
+
|
400 |
+
# Generate reference segments
|
401 |
+
logger.info(f"Generating reference segments for voice: {voice_name}")
|
402 |
+
reference_segments = []
|
403 |
+
|
404 |
+
for i, prompt in enumerate(prompts):
|
405 |
+
ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
|
406 |
+
|
407 |
+
# Skip if file exists and we're not regenerating
|
408 |
+
if not regenerate and os.path.exists(ref_path):
|
409 |
+
try:
|
410 |
+
# Load existing reference
|
411 |
+
audio_tensor, sr = torchaudio.load(ref_path)
|
412 |
+
if sr != generator.sample_rate:
|
413 |
+
audio_tensor = torchaudio.functional.resample(
|
414 |
+
audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
|
415 |
+
)
|
416 |
+
else:
|
417 |
+
audio_tensor = audio_tensor.squeeze(0)
|
418 |
+
reference_segments.append(audio_tensor.to(generator.device))
|
419 |
+
logger.info(f"Loaded existing reference {i+1}/{len(prompts)} for {voice_name}")
|
420 |
+
continue
|
421 |
+
except Exception as e:
|
422 |
+
logger.warning(f"Failed to load existing reference {i+1} for {voice_name}: {e}")
|
423 |
+
|
424 |
+
try:
|
425 |
+
# Use a lower temperature for more stability in reference samples
|
426 |
+
logger.info(f"Generating reference {i+1}/{len(prompts)} for {voice_name}: '{prompt}'")
|
427 |
+
|
428 |
+
# We want references to be as clean as possible
|
429 |
+
audio = generator.generate(
|
430 |
+
text=prompt,
|
431 |
+
speaker=profile.speaker_id,
|
432 |
+
context=[], # No context for initial samples to prevent voice bleed
|
433 |
+
max_audio_length_ms=6000, # Shorter for more control
|
434 |
+
temperature=0.7, # Lower temperature for more stability
|
435 |
+
topk=30, # More focused sampling
|
436 |
+
)
|
437 |
+
|
438 |
+
# Validate and enhance the audio
|
439 |
+
is_valid, audio, message = validate_generated_audio(
|
440 |
+
audio, voice_name, generator.sample_rate
|
441 |
+
)
|
442 |
+
|
443 |
+
if is_valid:
|
444 |
+
# Enhance the audio
|
445 |
+
audio = enhance_audio(audio, generator.sample_rate, profile)
|
446 |
+
|
447 |
+
# Save the reference to persistent storage
|
448 |
+
torchaudio.save(ref_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
|
449 |
+
reference_segments.append(audio)
|
450 |
+
logger.info(f"Generated reference {i+1} for {voice_name}: {message}")
|
451 |
+
else:
|
452 |
+
logger.warning(f"Invalid reference for {voice_name}: {message}")
|
453 |
+
# Try again with different settings if invalid
|
454 |
+
if i < len(prompts) - 1:
|
455 |
+
logger.info(f"Trying again with next prompt")
|
456 |
+
continue
|
457 |
+
|
458 |
+
except Exception as e:
|
459 |
+
logger.error(f"Error generating reference for {voice_name}: {e}")
|
460 |
+
|
461 |
+
# Update the voice profile with references
|
462 |
+
if reference_segments:
|
463 |
+
VOICE_PROFILES[voice_name].reference_segments = reference_segments
|
464 |
+
logger.info(f"Updated {voice_name} with {len(reference_segments)} reference segments")
|
465 |
+
|
466 |
+
# Save the updated profiles to persistent storage
|
467 |
+
save_voice_profiles()
|
468 |
+
|
469 |
+
def get_voice_segments(voice_name: str, device: torch.device) -> List:
|
470 |
+
"""Get context segments for a given voice.
|
471 |
+
|
472 |
+
Args:
|
473 |
+
voice_name: Name of the voice to use
|
474 |
+
device: Device to place tensors on
|
475 |
+
|
476 |
+
Returns:
|
477 |
+
List of context segments
|
478 |
+
"""
|
479 |
+
from app.models import Segment
|
480 |
+
|
481 |
+
if voice_name not in VOICE_PROFILES:
|
482 |
+
logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
|
483 |
+
voice_name = "alloy"
|
484 |
+
|
485 |
+
profile = VOICE_PROFILES[voice_name]
|
486 |
+
|
487 |
+
# If we don't have reference segments yet, create them
|
488 |
+
if not profile.reference_segments:
|
489 |
+
try:
|
490 |
+
# Try to load from disk - use persistent storage
|
491 |
+
voice_dir = os.path.join(VOICE_REFERENCES_DIR, voice_name)
|
492 |
+
|
493 |
+
if os.path.exists(voice_dir):
|
494 |
+
reference_segments = []
|
495 |
+
prompts = VOICE_PROMPTS[voice_name]
|
496 |
+
|
497 |
+
for i, prompt in enumerate(prompts):
|
498 |
+
ref_path = os.path.join(voice_dir, f"{voice_name}_ref_{i}.wav")
|
499 |
+
if os.path.exists(ref_path):
|
500 |
+
audio_tensor, sr = torchaudio.load(ref_path)
|
501 |
+
audio_tensor = audio_tensor.squeeze(0)
|
502 |
+
reference_segments.append(audio_tensor)
|
503 |
+
|
504 |
+
if reference_segments:
|
505 |
+
profile.reference_segments = reference_segments
|
506 |
+
logger.info(f"Loaded {len(reference_segments)} reference segments for {voice_name}")
|
507 |
+
except Exception as e:
|
508 |
+
logger.error(f"Error loading reference segments for {voice_name}: {e}")
|
509 |
+
|
510 |
+
# Create context segments from references
|
511 |
+
context = []
|
512 |
+
if profile.reference_segments:
|
513 |
+
for i, ref_audio in enumerate(profile.reference_segments):
|
514 |
+
# Use corresponding prompt if available, otherwise use a generic one
|
515 |
+
text = VOICE_PROMPTS[voice_name][i] if i < len(VOICE_PROMPTS[voice_name]) else f"Voice reference for {voice_name}"
|
516 |
+
|
517 |
+
context.append(
|
518 |
+
Segment(
|
519 |
+
speaker=profile.speaker_id,
|
520 |
+
text=text,
|
521 |
+
audio=ref_audio.to(device)
|
522 |
+
)
|
523 |
+
)
|
524 |
+
|
525 |
+
logger.info(f"Returning {len(context)} context segments for {voice_name}")
|
526 |
+
return context
|
527 |
+
|
528 |
+
def save_voice_profiles():
|
529 |
+
"""Save voice profiles to persistent storage."""
|
530 |
+
os.makedirs(VOICE_PROFILES_DIR, exist_ok=True)
|
531 |
+
profile_path = os.path.join(VOICE_PROFILES_DIR, "voice_profiles.pt")
|
532 |
+
|
533 |
+
# Create a serializable version of the profiles
|
534 |
+
serializable_profiles = {}
|
535 |
+
for name, profile in VOICE_PROFILES.items():
|
536 |
+
serializable_profiles[name] = {
|
537 |
+
'reference_segments': [seg.cpu() for seg in profile.reference_segments]
|
538 |
+
}
|
539 |
+
|
540 |
+
# Save to persistent storage
|
541 |
+
torch.save(serializable_profiles, profile_path)
|
542 |
+
logger.info(f"Saved voice profiles to {profile_path}")
|
543 |
+
|
544 |
+
def process_generated_audio(
|
545 |
+
audio: torch.Tensor,
|
546 |
+
voice_name: str,
|
547 |
+
sample_rate: int,
|
548 |
+
text: str
|
549 |
+
) -> torch.Tensor:
|
550 |
+
"""Process generated audio for consistency and quality.
|
551 |
+
|
552 |
+
Args:
|
553 |
+
audio: Audio tensor
|
554 |
+
voice_name: Name of voice used
|
555 |
+
sample_rate: Audio sample rate
|
556 |
+
text: Text that was spoken
|
557 |
+
|
558 |
+
Returns:
|
559 |
+
Processed audio tensor
|
560 |
+
"""
|
561 |
+
# Validate the audio
|
562 |
+
is_valid, audio, message = validate_generated_audio(audio, voice_name, sample_rate)
|
563 |
+
if not is_valid:
|
564 |
+
logger.warning(f"Generated audio validation issue: {message}")
|
565 |
+
|
566 |
+
# Get voice profile for enhancement
|
567 |
+
profile = VOICE_PROFILES.get(voice_name, VOICE_PROFILES["alloy"])
|
568 |
+
|
569 |
+
# Enhance the audio based on voice profile
|
570 |
+
enhanced = enhance_audio(audio, sample_rate, profile)
|
571 |
+
|
572 |
+
# Log the enhancement
|
573 |
+
original_duration = audio.shape[0] / sample_rate
|
574 |
+
enhanced_duration = enhanced.shape[0] / sample_rate
|
575 |
+
logger.info(
|
576 |
+
f"Processed audio for '{voice_name}': "
|
577 |
+
f"Duration: {original_duration:.2f}s->{enhanced_duration:.2f}s, "
|
578 |
+
f"RMS: {audio.pow(2).mean().sqrt().item():.3f}->{enhanced.pow(2).mean().sqrt().item():.3f}"
|
579 |
+
)
|
580 |
+
|
581 |
+
return enhanced
|
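A minimal usage sketch (not part of the commit) of how the enhancement pipeline above might be wired together after generation. The import path app.voice_enhancement and the generator interface (generate(), sample_rate, device) are assumptions based on the code above.

# Hedged sketch: generate with reference-segment context, then run the enhancement pipeline.
# app.voice_enhancement as the import path and the generator interface are assumptions.
import torchaudio
from app.voice_enhancement import (
    initialize_voice_profiles,
    get_voice_segments,
    process_generated_audio,
)

def synthesize(generator, text: str, voice_name: str = "alloy") -> str:
    initialize_voice_profiles()                                  # load saved profiles or fall back to defaults
    context = get_voice_segments(voice_name, generator.device)   # reference segments used as generation context
    audio = generator.generate(
        text=text,
        speaker=0,              # "alloy" maps to speaker id 0 in the profiles above
        context=context,
        max_audio_length_ms=10000,
        temperature=0.7,
        topk=30,
    )
    # Validate, normalize, and apply the profile-specific clarity boost
    audio = process_generated_audio(audio, voice_name, generator.sample_rate, text)
    out_path = f"{voice_name}_output.wav"
    torchaudio.save(out_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
    return out_path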
app/voice_memory.py
ADDED
@@ -0,0 +1,474 @@
1 |
+
"""Advanced voice memory system for consistent voice generation."""
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import torchaudio
|
5 |
+
import numpy as np
|
6 |
+
import random
|
7 |
+
import logging
|
8 |
+
from typing import Dict, List, Optional
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from app.models import Segment
|
11 |
+
|
12 |
+
# Setup logging
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
# Path to store voice memories - use persistent location
|
16 |
+
VOICE_MEMORIES_DIR = "/app/voice_memories"
|
17 |
+
os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class VoiceMemory:
|
21 |
+
"""Store voice characteristics for consistent generation."""
|
22 |
+
name: str # Voice name (alloy, echo, etc.)
|
23 |
+
speaker_id: int # Speaker ID (0-5)
|
24 |
+
# Store multiple audio segments for context
|
25 |
+
audio_segments: List[torch.Tensor]
|
26 |
+
# Store text prompts that produced good results
|
27 |
+
text_segments: List[str]
|
28 |
+
# Base characteristics for this voice
|
29 |
+
pitch_base: float # Base pitch characteristic (Hz)
|
30 |
+
timbre: str # Voice quality descriptor
|
31 |
+
|
32 |
+
def get_context_segments(self, device: torch.device, max_segments: int = 2) -> List[Segment]:
|
33 |
+
"""Get context segments for this voice."""
|
34 |
+
if not self.audio_segments:
|
35 |
+
return []
|
36 |
+
|
37 |
+
# Select a limited number of segments to avoid context overflow
|
38 |
+
num_segments = min(len(self.audio_segments), max_segments)
|
39 |
+
indices = list(range(len(self.audio_segments)))
|
40 |
+
random.shuffle(indices)
|
41 |
+
selected_indices = indices[:num_segments]
|
42 |
+
|
43 |
+
segments = []
|
44 |
+
for i in selected_indices:
|
45 |
+
segments.append(
|
46 |
+
Segment(
|
47 |
+
speaker=self.speaker_id,
|
48 |
+
text=self.text_segments[i] if i < len(self.text_segments) else f"Voice sample {i}",
|
49 |
+
audio=self.audio_segments[i].to(device)
|
50 |
+
)
|
51 |
+
)
|
52 |
+
|
53 |
+
return segments
|
54 |
+
|
55 |
+
def update_with_new_audio(self, audio: torch.Tensor, text: str, max_stored: int = 5):
|
56 |
+
"""Update voice memory with newly generated audio."""
|
57 |
+
# Add new audio and text
|
58 |
+
self.audio_segments.append(audio.detach().cpu())
|
59 |
+
self.text_segments.append(text)
|
60 |
+
|
61 |
+
# Keep only the most recent segments
|
62 |
+
if len(self.audio_segments) > max_stored:
|
63 |
+
self.audio_segments = self.audio_segments[-max_stored:]
|
64 |
+
self.text_segments = self.text_segments[-max_stored:]
|
65 |
+
|
66 |
+
def save(self):
|
67 |
+
"""Save voice memory to persistent storage."""
|
68 |
+
data = {
|
69 |
+
"name": self.name,
|
70 |
+
"speaker_id": self.speaker_id,
|
71 |
+
"audio_segments": self.audio_segments,
|
72 |
+
"text_segments": self.text_segments,
|
73 |
+
"pitch_base": self.pitch_base,
|
74 |
+
"timbre": self.timbre
|
75 |
+
}
|
76 |
+
|
77 |
+
# Save to the persistent directory
|
78 |
+
save_path = os.path.join(VOICE_MEMORIES_DIR, f"{self.name}.pt")
|
79 |
+
try:
|
80 |
+
torch.save(data, save_path)
|
81 |
+
logger.info(f"Saved voice memory for {self.name} to {save_path}")
|
82 |
+
except Exception as e:
|
83 |
+
logger.error(f"Error saving voice memory for {self.name}: {e}")
|
84 |
+
|
85 |
+
@classmethod
|
86 |
+
def load(cls, name: str) -> Optional['VoiceMemory']:
|
87 |
+
"""Load voice memory from persistent storage."""
|
88 |
+
path = os.path.join(VOICE_MEMORIES_DIR, f"{name}.pt")
|
89 |
+
if not os.path.exists(path):
|
90 |
+
logger.info(f"No saved voice memory found for {name} at {path}")
|
91 |
+
return None
|
92 |
+
|
93 |
+
try:
|
94 |
+
data = torch.load(path)
|
95 |
+
return cls(
|
96 |
+
name=data["name"],
|
97 |
+
speaker_id=data["speaker_id"],
|
98 |
+
audio_segments=data["audio_segments"],
|
99 |
+
text_segments=data["text_segments"],
|
100 |
+
pitch_base=data["pitch_base"],
|
101 |
+
timbre=data["timbre"]
|
102 |
+
)
|
103 |
+
except Exception as e:
|
104 |
+
logger.error(f"Error loading voice memory for {name}: {e}")
|
105 |
+
return None
|
106 |
+
|
107 |
+
# Dictionary of voice memories
|
108 |
+
VOICE_MEMORIES: Dict[str, VoiceMemory] = {}
|
109 |
+
|
110 |
+
# Voice characteristics
|
111 |
+
VOICE_CHARACTERISTICS = {
|
112 |
+
"alloy": {"pitch": 220.0, "timbre": "balanced", "description": "A balanced, natural voice with medium pitch"},
|
113 |
+
"echo": {"pitch": 330.0, "timbre": "resonant", "description": "A resonant voice with a reverberant quality"},
|
114 |
+
"fable": {"pitch": 523.0, "timbre": "bright", "description": "A bright, higher-pitched voice with clear articulation"},
|
115 |
+
"onyx": {"pitch": 165.0, "timbre": "deep", "description": "A deep, authoritative voice with lower pitch"},
|
116 |
+
"nova": {"pitch": 392.0, "timbre": "warm", "description": "A warm, smooth voice with pleasant midrange tone"},
|
117 |
+
"shimmer": {"pitch": 587.0, "timbre": "light", "description": "A light, airy voice with higher frequencies"}
|
118 |
+
}
|
119 |
+
|
120 |
+
# Voice intro texts - carefully crafted to capture voice characteristics
|
121 |
+
VOICE_INTROS = {
|
122 |
+
"alloy": [
|
123 |
+
"Hello, I'm Alloy. My voice is designed to be clear and balanced.",
|
124 |
+
"This is the Alloy voice. I aim to sound natural and easy to understand.",
|
125 |
+
"Welcome, I'm the voice known as Alloy. I have a balanced, medium-range tone."
|
126 |
+
],
|
127 |
+
"echo": [
|
128 |
+
"Hello, I'm Echo. My voice has a rich, resonant quality.",
|
129 |
+
"This is the Echo voice. Notice my distinctive resonance and depth.",
|
130 |
+
"Welcome, I'm the voice known as Echo. My tone is designed to resonate clearly."
|
131 |
+
],
|
132 |
+
"fable": [
|
133 |
+
"Hello, I'm Fable. My voice is bright and articulate.",
|
134 |
+
"This is the Fable voice. I have a higher pitch with clear pronunciation.",
|
135 |
+
"Welcome, I'm the voice known as Fable. I speak with a bright, energetic tone."
|
136 |
+
],
|
137 |
+
"onyx": [
|
138 |
+
"Hello, I'm Onyx. My voice is deep and authoritative.",
|
139 |
+
"This is the Onyx voice. I speak with a lower pitch and commanding presence.",
|
140 |
+
"Welcome, I'm the voice known as Onyx. My tone is deep and resonant."
|
141 |
+
],
|
142 |
+
"nova": [
|
143 |
+
"Hello, I'm Nova. My voice is warm and harmonious.",
|
144 |
+
"This is the Nova voice. I have a smooth, pleasant mid-range quality.",
|
145 |
+
"Welcome, I'm the voice known as Nova. I speak with a warm, friendly tone."
|
146 |
+
],
|
147 |
+
"shimmer": [
|
148 |
+
"Hello, I'm Shimmer. My voice is light and expressive.",
|
149 |
+
"This is the Shimmer voice. I have a higher-pitched, airy quality.",
|
150 |
+
"Welcome, I'm the voice known as Shimmer. My tone is bright and crisp."
|
151 |
+
]
|
152 |
+
}
|
153 |
+
|
154 |
+
def initialize_voices(sample_rate: int = 24000):
|
155 |
+
"""Initialize voice memories with consistent base samples."""
|
156 |
+
global VOICE_MEMORIES
|
157 |
+
|
158 |
+
# Check if persistent directory exists, create if needed
|
159 |
+
os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
|
160 |
+
logger.info(f"Using voice memories directory: {VOICE_MEMORIES_DIR}")
|
161 |
+
|
162 |
+
# First try to load existing memories from persistent storage
|
163 |
+
for voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
|
164 |
+
memory = VoiceMemory.load(voice_name)
|
165 |
+
if memory:
|
166 |
+
VOICE_MEMORIES[voice_name] = memory
|
167 |
+
logger.info(f"Loaded existing voice memory for {voice_name} with {len(memory.audio_segments)} segments")
|
168 |
+
continue
|
169 |
+
|
170 |
+
# If no memory exists, create a new one
|
171 |
+
speaker_id = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"].index(voice_name)
|
172 |
+
characteristics = VOICE_CHARACTERISTICS[voice_name]
|
173 |
+
|
174 |
+
# Create deterministic seed audio
|
175 |
+
np.random.seed(speaker_id + 42)
|
176 |
+
duration = 1.0 # seconds
|
177 |
+
t = np.linspace(0, duration, int(sample_rate * duration), endpoint=False)
|
178 |
+
|
179 |
+
# Create characteristic waveform
|
180 |
+
pitch = characteristics["pitch"]
|
181 |
+
if voice_name == "alloy":
|
182 |
+
audio = 0.5 * np.sin(2 * np.pi * pitch * t) + 0.3 * np.sin(2 * np.pi * pitch * 2 * t)
|
183 |
+
elif voice_name == "echo":
|
184 |
+
audio = np.sin(2 * np.pi * pitch * t) * np.exp(-t * 3)
|
185 |
+
elif voice_name == "fable":
|
186 |
+
audio = 0.7 * np.sin(2 * np.pi * pitch * t)
|
187 |
+
elif voice_name == "onyx":
|
188 |
+
audio = 0.8 * np.sin(2 * np.pi * pitch * t) + 0.1 * np.sin(2 * np.pi * pitch * 0.5 * t)
|
189 |
+
elif voice_name == "nova":
|
190 |
+
audio = 0.4 * np.sin(2 * np.pi * pitch * t) + 0.4 * np.sin(2 * np.pi * pitch * 0.5 * t)
|
191 |
+
else: # shimmer
|
192 |
+
audio = 0.3 * np.sin(2 * np.pi * pitch * t) + 0.2 * np.sin(2 * np.pi * pitch * 1.5 * t) + 0.1 * np.sin(2 * np.pi * pitch * 2 * t)
|
193 |
+
|
194 |
+
# Normalize
|
195 |
+
audio = audio / np.max(np.abs(audio))
|
196 |
+
|
197 |
+
# Convert to tensor
|
198 |
+
audio_tensor = torch.tensor(audio, dtype=torch.float32)
|
199 |
+
|
200 |
+
# Create voice memory
|
201 |
+
memory = VoiceMemory(
|
202 |
+
name=voice_name,
|
203 |
+
speaker_id=speaker_id,
|
204 |
+
audio_segments=[audio_tensor],
|
205 |
+
text_segments=[f"This is the voice of {voice_name}"],
|
206 |
+
pitch_base=characteristics["pitch"],
|
207 |
+
timbre=characteristics["timbre"]
|
208 |
+
)
|
209 |
+
|
210 |
+
# Save the voice memory to persistent storage
|
211 |
+
memory.save()
|
212 |
+
|
213 |
+
# Store in dictionary
|
214 |
+
VOICE_MEMORIES[voice_name] = memory
|
215 |
+
|
216 |
+
# Save as wav for reference
|
217 |
+
save_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_seed.wav")
|
218 |
+
torchaudio.save(save_path, audio_tensor.unsqueeze(0), sample_rate)
|
219 |
+
|
220 |
+
logger.info(f"Initialized new voice memory for {voice_name}")
|
221 |
+
|
222 |
+
def get_voice_context(voice_name: str, device: torch.device, max_segments: int = 2) -> List[Segment]:
|
223 |
+
"""Get context segments for a given voice."""
|
224 |
+
if not VOICE_MEMORIES:
|
225 |
+
initialize_voices()
|
226 |
+
|
227 |
+
if voice_name in VOICE_MEMORIES:
|
228 |
+
return VOICE_MEMORIES[voice_name].get_context_segments(device, max_segments=max_segments)
|
229 |
+
|
230 |
+
# Default to alloy if voice not found
|
231 |
+
logger.warning(f"Voice {voice_name} not found, defaulting to alloy")
|
232 |
+
return VOICE_MEMORIES["alloy"].get_context_segments(device, max_segments=max_segments)
|
233 |
+
|
234 |
+
def update_voice_memory(voice_name: str, audio: torch.Tensor, text: str):
|
235 |
+
"""Update voice memory with newly generated audio and save to persistent storage."""
|
236 |
+
if not VOICE_MEMORIES:
|
237 |
+
initialize_voices()
|
238 |
+
|
239 |
+
if voice_name in VOICE_MEMORIES:
|
240 |
+
VOICE_MEMORIES[voice_name].update_with_new_audio(audio, text)
|
241 |
+
VOICE_MEMORIES[voice_name].save()
|
242 |
+
logger.info(f"Updated voice memory for {voice_name}, now has {len(VOICE_MEMORIES[voice_name].audio_segments)} segments")
|
243 |
+
|
244 |
+
def generate_voice_samples(app_state):
|
245 |
+
"""Generate high-quality voice samples for each voice.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
app_state: The FastAPI app state containing the generator
|
249 |
+
"""
|
250 |
+
generator = app_state.generator
|
251 |
+
if not generator:
|
252 |
+
logger.error("Cannot generate voice samples: generator not available")
|
253 |
+
return
|
254 |
+
|
255 |
+
logger.info("Beginning voice sample generation...")
|
256 |
+
|
257 |
+
# Ensure persistent directory exists
|
258 |
+
os.makedirs(VOICE_MEMORIES_DIR, exist_ok=True)
|
259 |
+
|
260 |
+
for voice_name in ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]:
|
261 |
+
speaker_id = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"].index(voice_name)
|
262 |
+
|
263 |
+
# Get multiple sample texts for this voice
|
264 |
+
sample_texts = VOICE_INTROS[voice_name]
|
265 |
+
|
266 |
+
# Generate a collection of samples for this voice
|
267 |
+
logger.info(f"Generating samples for voice: {voice_name}")
|
268 |
+
audio_segments = []
|
269 |
+
text_segments = []
|
270 |
+
|
271 |
+
for i, sample_text in enumerate(sample_texts):
|
272 |
+
try:
|
273 |
+
# Check if we already have a sample
|
274 |
+
sample_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_sample_{i}.wav")
|
275 |
+
if os.path.exists(sample_path):
|
276 |
+
logger.info(f"Found existing sample {i+1} for {voice_name}, loading from {sample_path}")
|
277 |
+
audio_tensor, sr = torchaudio.load(sample_path)
|
278 |
+
if sr != generator.sample_rate:
|
279 |
+
audio_tensor = torchaudio.functional.resample(
|
280 |
+
audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
|
281 |
+
)
|
282 |
+
else:
|
283 |
+
audio_tensor = audio_tensor.squeeze(0)
|
284 |
+
audio_segments.append(audio_tensor)
|
285 |
+
text_segments.append(sample_text)
|
286 |
+
continue
|
287 |
+
|
288 |
+
# Generate without context first for seed samples
|
289 |
+
logger.info(f"Generating sample {i+1}/{len(sample_texts)} for {voice_name}: '{sample_text}'")
|
290 |
+
|
291 |
+
# Use a lower temperature for more stable output
|
292 |
+
audio = generator.generate(
|
293 |
+
text=sample_text,
|
294 |
+
speaker=speaker_id,
|
295 |
+
context=[], # No context for initial samples
|
296 |
+
max_audio_length_ms=10000,
|
297 |
+
temperature=0.7, # Lower temperature for more stable output
|
298 |
+
topk=30,
|
299 |
+
)
|
300 |
+
|
301 |
+
# Save this segment
|
302 |
+
audio_segments.append(audio.detach().cpu())
|
303 |
+
text_segments.append(sample_text)
|
304 |
+
|
305 |
+
# Save as WAV for reference to persistent storage
|
306 |
+
torchaudio.save(sample_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
|
307 |
+
|
308 |
+
logger.info(f"Generated sample {i+1} for {voice_name}, length: {audio.shape[0]/generator.sample_rate:.2f}s")
|
309 |
+
|
310 |
+
except Exception as e:
|
311 |
+
logger.error(f"Error generating sample {i+1} for {voice_name}: {e}")
|
312 |
+
|
313 |
+
# Use the generated samples to update the voice memory
|
314 |
+
if voice_name in VOICE_MEMORIES and audio_segments:
|
315 |
+
# Replace existing samples with these high quality ones
|
316 |
+
VOICE_MEMORIES[voice_name].audio_segments = audio_segments
|
317 |
+
VOICE_MEMORIES[voice_name].text_segments = text_segments
|
318 |
+
VOICE_MEMORIES[voice_name].save()
|
319 |
+
|
320 |
+
logger.info(f"Updated voice memory for {voice_name} with {len(audio_segments)} high-quality samples")
|
321 |
+
|
322 |
+
# Now generate a second pass with context from these samples
|
323 |
+
if len(audio_segments) >= 2:
|
324 |
+
try:
|
325 |
+
# Check if we already have a character sample
|
326 |
+
character_path = os.path.join(VOICE_MEMORIES_DIR, f"{voice_name}_character.wav")
|
327 |
+
if os.path.exists(character_path):
|
328 |
+
logger.info(f"Found existing character sample for {voice_name}, loading from {character_path}")
|
329 |
+
audio_tensor, sr = torchaudio.load(character_path)
|
330 |
+
if sr != generator.sample_rate:
|
331 |
+
audio_tensor = torchaudio.functional.resample(
|
332 |
+
audio_tensor.squeeze(0), orig_freq=sr, new_freq=generator.sample_rate
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
audio_tensor = audio_tensor.squeeze(0)
|
336 |
+
|
337 |
+
character_sample_text = f"I'm the voice assistant known as {voice_name}. I'm designed to have a distinctive voice that you can easily recognize."
|
338 |
+
VOICE_MEMORIES[voice_name].audio_segments.append(audio_tensor)
|
339 |
+
VOICE_MEMORIES[voice_name].text_segments.append(character_sample_text)
|
340 |
+
VOICE_MEMORIES[voice_name].save()
|
341 |
+
continue
|
342 |
+
|
343 |
+
# Get intro and conclusion prompts that build voice consistency
|
344 |
+
context = [
|
345 |
+
Segment(
|
346 |
+
speaker=speaker_id,
|
347 |
+
text=text_segments[0],
|
348 |
+
audio=audio_segments[0].to(generator.device)
|
349 |
+
)
|
350 |
+
]
|
351 |
+
|
352 |
+
# Create a longer sample with the voice characteristics now established
|
353 |
+
character_sample_text = f"I'm the voice assistant known as {voice_name}. I'm designed to have a distinctive voice that you can easily recognize. My speech patterns and tone should remain consistent throughout our conversation."
|
354 |
+
|
355 |
+
logger.info(f"Generating character sample for {voice_name} with context")
|
356 |
+
character_audio = generator.generate(
|
357 |
+
text=character_sample_text,
|
358 |
+
speaker=speaker_id,
|
359 |
+
context=context,
|
360 |
+
max_audio_length_ms=15000,
|
361 |
+
temperature=0.7,
|
362 |
+
topk=30,
|
363 |
+
)
|
364 |
+
|
365 |
+
# Save this comprehensive character sample to persistent storage
|
366 |
+
torchaudio.save(character_path, character_audio.unsqueeze(0).cpu(), generator.sample_rate)
|
367 |
+
|
368 |
+
# Add this to the memory as well
|
369 |
+
VOICE_MEMORIES[voice_name].audio_segments.append(character_audio.detach().cpu())
|
370 |
+
VOICE_MEMORIES[voice_name].text_segments.append(character_sample_text)
|
371 |
+
VOICE_MEMORIES[voice_name].save()
|
372 |
+
|
373 |
+
logger.info(f"Generated character sample for {voice_name}, length: {character_audio.shape[0]/generator.sample_rate:.2f}s")
|
374 |
+
|
375 |
+
except Exception as e:
|
376 |
+
logger.error(f"Error generating character sample for {voice_name}: {e}")
|
377 |
+
|
378 |
+
def create_custom_voice(
|
379 |
+
app_state,
|
380 |
+
name: str,
|
381 |
+
initial_text: str,
|
382 |
+
speaker_id: int = 0,
|
383 |
+
pitch: Optional[float] = None,
|
384 |
+
timbre: str = "custom"
|
385 |
+
) -> Dict:
|
386 |
+
"""Create a new custom voice.
|
387 |
+
|
388 |
+
Args:
|
389 |
+
app_state: The FastAPI app state containing the generator
|
390 |
+
name: Name for the new voice
|
391 |
+
initial_text: Text for the initial voice sample
|
392 |
+
speaker_id: Base speaker ID (0-5)
|
393 |
+
pitch: Base pitch in Hz (optional)
|
394 |
+
timbre: Voice quality descriptor
|
395 |
+
|
396 |
+
Returns:
|
397 |
+
Dict with creation status and voice info
|
398 |
+
"""
|
399 |
+
generator = app_state.generator
|
400 |
+
if not generator:
|
401 |
+
return {"status": "error", "message": "Generator not available"}
|
402 |
+
|
403 |
+
# Check if voice already exists
|
404 |
+
if not VOICE_MEMORIES:
|
405 |
+
initialize_voices()
|
406 |
+
|
407 |
+
if name in VOICE_MEMORIES:
|
408 |
+
return {"status": "error", "message": f"Voice '{name}' already exists"}
|
409 |
+
|
410 |
+
# Generate a voice sample
|
411 |
+
try:
|
412 |
+
logger.info(f"Creating custom voice '{name}' with text: '{initial_text}'")
|
413 |
+
|
414 |
+
audio = generator.generate(
|
415 |
+
text=initial_text,
|
416 |
+
speaker=speaker_id,
|
417 |
+
context=[],
|
418 |
+
max_audio_length_ms=10000,
|
419 |
+
temperature=0.7,
|
420 |
+
)
|
421 |
+
|
422 |
+
# Determine base pitch if not provided
|
423 |
+
if pitch is None:
|
424 |
+
if speaker_id == 0: # alloy
|
425 |
+
pitch = 220.0
|
426 |
+
elif speaker_id == 1: # echo
|
427 |
+
pitch = 330.0
|
428 |
+
elif speaker_id == 2: # fable
|
429 |
+
pitch = 523.0
|
430 |
+
elif speaker_id == 3: # onyx
|
431 |
+
pitch = 165.0
|
432 |
+
elif speaker_id == 4: # nova
|
433 |
+
pitch = 392.0
|
434 |
+
else: # shimmer
|
435 |
+
pitch = 587.0
|
436 |
+
|
437 |
+
# Create a new voice memory
|
438 |
+
memory = VoiceMemory(
|
439 |
+
name=name,
|
440 |
+
speaker_id=speaker_id,
|
441 |
+
audio_segments=[audio.detach().cpu()],
|
442 |
+
text_segments=[initial_text],
|
443 |
+
pitch_base=pitch,
|
444 |
+
timbre=timbre
|
445 |
+
)
|
446 |
+
|
447 |
+
# Save the voice memory to persistent storage
|
448 |
+
memory.save()
|
449 |
+
VOICE_MEMORIES[name] = memory
|
450 |
+
|
451 |
+
# Save sample as WAV for reference to persistent storage
|
452 |
+
sample_path = os.path.join(VOICE_MEMORIES_DIR, f"{name}_sample.wav")
|
453 |
+
torchaudio.save(sample_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
|
454 |
+
|
455 |
+
logger.info(f"Created custom voice '{name}' successfully")
|
456 |
+
|
457 |
+
return {
|
458 |
+
"status": "success",
|
459 |
+
"message": f"Voice '{name}' created successfully",
|
460 |
+
"voice": {
|
461 |
+
"name": name,
|
462 |
+
"speaker_id": speaker_id,
|
463 |
+
"pitch": pitch,
|
464 |
+
"timbre": timbre,
|
465 |
+
"sample_length_seconds": audio.shape[0] / generator.sample_rate
|
466 |
+
}
|
467 |
+
}
|
468 |
+
|
469 |
+
except Exception as e:
|
470 |
+
logger.error(f"Error creating custom voice '{name}': {e}")
|
471 |
+
return {
|
472 |
+
"status": "error",
|
473 |
+
"message": f"Error creating voice: {str(e)}"
|
474 |
+
}
|
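A rough sketch (not part of the commit) of wrapping a generation call with the voice-memory API above; the import path app.voice_memory and the generator interface are assumptions based on the code shown.

# Hedged sketch: seed context from voice memory, generate, then persist the result.
from app.voice_memory import initialize_voices, get_voice_context, update_voice_memory

def speak_with_memory(generator, text: str, voice_name: str = "nova"):
    initialize_voices(sample_rate=generator.sample_rate)               # load or seed persisted voice memories
    context = get_voice_context(voice_name, generator.device, max_segments=2)
    audio = generator.generate(
        text=text,
        speaker=4,              # "nova" is speaker id 4 in this codebase
        context=context,
        max_audio_length_ms=10000,
        temperature=0.7,
        topk=30,
    )
    update_voice_memory(voice_name, audio, text)                       # store the new segment for future context
    return audio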
app/watermarking.py
ADDED
@@ -0,0 +1,22 @@
1 |
+
"""Stub for watermarking module.
|
2 |
+
|
3 |
+
The original CSM code has a watermarking module that adds
|
4 |
+
an imperceptible watermark to generated audio.
|
5 |
+
"""
|
6 |
+
|
7 |
+
# Watermark key used by CSM
|
8 |
+
CSM_1B_GH_WATERMARK = "CSM1B@GitHub"
|
9 |
+
|
10 |
+
def load_watermarker(device="cpu"):
|
11 |
+
"""Stub for watermarker loading.
|
12 |
+
|
13 |
+
In a real implementation, this would load the actual watermarker.
|
14 |
+
"""
|
15 |
+
return None
|
16 |
+
|
17 |
+
def watermark(watermarker, audio, sample_rate, key):
|
18 |
+
"""Stub for watermarking function.
|
19 |
+
|
20 |
+
In a real implementation, this would add the watermark.
|
21 |
+
"""
|
22 |
+
return audio, sample_rate
|
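The stub keeps the call sites identical to the original CSM watermarking API, so downstream code can call it unconditionally. A minimal sketch of that call pattern (the import path app.watermarking is an assumption):

# Hedged sketch of the watermarking call pattern; with this stub, audio passes through unchanged.
from app.watermarking import load_watermarker, watermark, CSM_1B_GH_WATERMARK

watermarker = load_watermarker(device="cpu")      # the stub returns None

def apply_watermark(audio, sample_rate):
    # Returns (audio, sample_rate) unchanged under the stub above.
    return watermark(watermarker, audio, sample_rate, CSM_1B_GH_WATERMARK)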
docker-compose.yml
ADDED
@@ -0,0 +1,23 @@
1 |
+
services:
|
2 |
+
csm-api:
|
3 |
+
build:
|
4 |
+
context: .
|
5 |
+
dockerfile: Dockerfile
|
6 |
+
args:
|
7 |
+
- HF_TOKEN=${HF_TOKEN}
|
8 |
+
ports:
|
9 |
+
- "8000:8000"
|
10 |
+
volumes:
|
11 |
+
- ./models:/app/models
|
12 |
+
- ./cloned_voices:/app/cloned_voices
|
13 |
+
- ./voice_references:/app/voice_references
|
14 |
+
- ./voice_profiles:/app/voice_profiles
|
15 |
+
environment:
|
16 |
+
- HF_TOKEN=${HF_TOKEN}
|
17 |
+
deploy:
|
18 |
+
resources:
|
19 |
+
reservations:
|
20 |
+
devices:
|
21 |
+
- driver: nvidia
|
22 |
+
count: all
|
23 |
+
capabilities: [gpu]
|
requirements.txt
ADDED
@@ -0,0 +1,24 @@
1 |
+
fastapi>=0.100.0
|
2 |
+
uvicorn>=0.24.0
|
3 |
+
pydantic>=2.4.0
|
4 |
+
python-multipart>=0.0.6
|
5 |
+
huggingface_hub>=0.20.0
|
6 |
+
torch>=2.0.0
|
7 |
+
torchaudio>=2.0.0
|
8 |
+
transformers>=4.35.0
|
9 |
+
tokenizers>=0.14.0
|
10 |
+
sentencepiece>=0.1.99
|
11 |
+
triton>=2.1.0; platform_system != "Windows"
|
12 |
+
triton-windows>=2.1.0; platform_system == "Windows"
|
13 |
+
torchao>=0.1.0
|
14 |
+
torchtune>=0.5.0
|
15 |
+
numpy>=1.25.0
|
16 |
+
ffmpeg-python>=0.2.0
|
17 |
+
moshi>=0.1.0
|
18 |
+
soundfile>=0.12.1
|
19 |
+
scipy>=1.10.0
|
20 |
+
librosa>=0.10.0
|
21 |
+
yt-dlp>=2023.3.4
|
22 |
+
openai-whisper>=20230314
|
23 |
+
ffmpeg-python>=0.2.0
|
24 |
+
accelerate>=0.20.0
|
static/voice-cloning.html
ADDED
@@ -0,0 +1,1385 @@
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en" data-theme="light">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no">
|
6 |
+
<title>Voice Cloning</title>
|
7 |
+
<style>
|
8 |
+
:root {
|
9 |
+
/* Base colors */
|
10 |
+
--bg-color: #ffffff;
|
11 |
+
--text-color: #333333;
|
12 |
+
--card-bg: #ffffff;
|
13 |
+
--card-border: #e2e8f0;
|
14 |
+
--card-shadow: rgba(0,0,0,0.1);
|
15 |
+
|
16 |
+
/* Primary colors */
|
17 |
+
--primary-color: #4f46e5;
|
18 |
+
--primary-hover: #4338ca;
|
19 |
+
--primary-light: rgba(79, 70, 229, 0.1);
|
20 |
+
|
21 |
+
/* Secondary colors */
|
22 |
+
--secondary-color: #6b7280;
|
23 |
+
--secondary-hover: #4b5563;
|
24 |
+
|
25 |
+
/* Accent colors */
|
26 |
+
--success-color: #10b981;
|
27 |
+
--success-hover: #059669;
|
28 |
+
--danger-color: #ef4444;
|
29 |
+
--danger-hover: #dc2626;
|
30 |
+
--warning-color: #f59e0b;
|
31 |
+
--warning-hover: #d97706;
|
32 |
+
|
33 |
+
/* Input elements */
|
34 |
+
--input-bg: #ffffff;
|
35 |
+
--input-border: #d1d5db;
|
36 |
+
--input-focus-border: #4f46e5;
|
37 |
+
--input-focus-shadow: rgba(79, 70, 229, 0.2);
|
38 |
+
|
39 |
+
/* Voice cards */
|
40 |
+
--voice-card-bg: #f9fafb;
|
41 |
+
--voice-card-border: #e5e7eb;
|
42 |
+
--voice-card-shadow: rgba(0,0,0,0.05);
|
43 |
+
|
44 |
+
/* Tabs */
|
45 |
+
--tab-border: #e5e7eb;
|
46 |
+
--tab-text: #4b5563;
|
47 |
+
--tab-active: #4f46e5;
|
48 |
+
--tab-active-bg: rgba(79, 70, 229, 0.1);
|
49 |
+
|
50 |
+
/* Toggle */
|
51 |
+
--toggle-bg: #e5e7eb;
|
52 |
+
--toggle-active: #4f46e5;
|
53 |
+
--toggle-circle: #ffffff;
|
54 |
+
|
55 |
+
/* Status indicators */
|
56 |
+
--status-success-bg: #dcfce7;
|
57 |
+
--status-success-text: #166534;
|
58 |
+
--status-error-bg: #fee2e2;
|
59 |
+
--status-error-text: #b91c1c;
|
60 |
+
--status-warning-bg: #fff7ed;
|
61 |
+
--status-warning-text: #c2410c;
|
62 |
+
--status-info-bg: #eff6ff;
|
63 |
+
--status-info-text: #1e40af;
|
64 |
+
}
|
65 |
+
|
66 |
+
[data-theme="dark"] {
|
67 |
+
/* Base colors */
|
68 |
+
--bg-color: #111827;
|
69 |
+
--text-color: #f3f4f6;
|
70 |
+
--card-bg: #1f2937;
|
71 |
+
--card-border: #374151;
|
72 |
+
--card-shadow: rgba(0,0,0,0.3);
|
73 |
+
|
74 |
+
/* Primary colors */
|
75 |
+
--primary-color: #6366f1;
|
76 |
+
--primary-hover: #4f46e5;
|
77 |
+
--primary-light: rgba(99, 102, 241, 0.2);
|
78 |
+
|
79 |
+
/* Secondary colors */
|
80 |
+
--secondary-color: #9ca3af;
|
81 |
+
--secondary-hover: #6b7280;
|
82 |
+
|
83 |
+
/* Accent colors - slightly brighter for dark theme */
|
84 |
+
--success-color: #34d399;
|
85 |
+
--success-hover: #10b981;
|
86 |
+
--danger-color: #f87171;
|
87 |
+
--danger-hover: #ef4444;
|
88 |
+
--warning-color: #fbbf24;
|
89 |
+
--warning-hover: #f59e0b;
|
90 |
+
|
91 |
+
/* Input elements */
|
92 |
+
--input-bg: #374151;
|
93 |
+
--input-border: #4b5563;
|
94 |
+
--input-focus-border: #6366f1;
|
95 |
+
--input-focus-shadow: rgba(99, 102, 241, 0.3);
|
96 |
+
|
97 |
+
/* Voice cards */
|
98 |
+
--voice-card-bg: #1f2937;
|
99 |
+
--voice-card-border: #374151;
|
100 |
+
--voice-card-shadow: rgba(0,0,0,0.2);
|
101 |
+
|
102 |
+
/* Tabs */
|
103 |
+
--tab-border: #374151;
|
104 |
+
--tab-text: #9ca3af;
|
105 |
+
--tab-active: #6366f1;
|
106 |
+
--tab-active-bg: rgba(99, 102, 241, 0.2);
|
107 |
+
|
108 |
+
/* Toggle */
|
109 |
+
--toggle-bg: #4b5563;
|
110 |
+
--toggle-active: #6366f1;
|
111 |
+
--toggle-circle: #e5e7eb;
|
112 |
+
|
113 |
+
/* Status indicators */
|
114 |
+
--status-success-bg: rgba(16, 185, 129, 0.2);
|
115 |
+
--status-success-text: #34d399;
|
116 |
+
--status-error-bg: rgba(239, 68, 68, 0.2);
|
117 |
+
--status-error-text: #f87171;
|
118 |
+
--status-warning-bg: rgba(245, 158, 11, 0.2);
|
119 |
+
--status-warning-text: #fbbf24;
|
120 |
+
--status-info-bg: rgba(59, 130, 246, 0.2);
|
121 |
+
--status-info-text: #60a5fa;
|
122 |
+
}
|
123 |
+
|
124 |
+
* {
|
125 |
+
box-sizing: border-box;
|
126 |
+
}
|
127 |
+
|
128 |
+
html, body {
|
129 |
+
margin: 0;
|
130 |
+
padding: 0;
|
131 |
+
overflow-x: hidden;
|
132 |
+
width: 100%;
|
133 |
+
}
|
134 |
+
|
135 |
+
body {
|
136 |
+
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif;
|
137 |
+
max-width: 100%;
|
138 |
+
padding: 16px;
|
139 |
+
line-height: 1.6;
|
140 |
+
background-color: var(--bg-color);
|
141 |
+
color: var(--text-color);
|
142 |
+
transition: background-color 0.3s, color 0.3s;
|
143 |
+
}
|
144 |
+
|
145 |
+
.container {
|
146 |
+
width: 100%;
|
147 |
+
max-width: 800px;
|
148 |
+
margin: 0 auto;
|
149 |
+
padding: 0 8px;
|
150 |
+
}
|
151 |
+
|
152 |
+
h1, h2, h3 {
|
153 |
+
color: var(--text-color);
|
154 |
+
margin-top: 0;
|
155 |
+
}
|
156 |
+
|
157 |
+
h1 {
|
158 |
+
font-size: 1.8rem;
|
159 |
+
margin-bottom: 1rem;
|
160 |
+
text-align: center;
|
161 |
+
}
|
162 |
+
|
163 |
+
h2 {
|
164 |
+
font-size: 1.4rem;
|
165 |
+
margin-bottom: 1rem;
|
166 |
+
}
|
167 |
+
|
168 |
+
.card {
|
169 |
+
border: 1px solid var(--card-border);
|
170 |
+
border-radius: 12px;
|
171 |
+
padding: 20px;
|
172 |
+
margin-bottom: 20px;
|
173 |
+
box-shadow: 0 4px 6px var(--card-shadow);
|
174 |
+
background-color: var(--card-bg);
|
175 |
+
transition: box-shadow 0.3s ease;
|
176 |
+
}
|
177 |
+
|
178 |
+
.card:hover {
|
179 |
+
box-shadow: 0 6px 12px var(--card-shadow);
|
180 |
+
}
|
181 |
+
|
182 |
+
.form-group {
|
183 |
+
margin-bottom: 20px;
|
184 |
+
}
|
185 |
+
|
186 |
+
label {
|
187 |
+
display: block;
|
188 |
+
margin-bottom: 6px;
|
189 |
+
font-weight: 500;
|
190 |
+
font-size: 0.95rem;
|
191 |
+
}
|
192 |
+
|
193 |
+
input, textarea, select {
|
194 |
+
width: 100%;
|
195 |
+
padding: 10px 12px;
|
196 |
+
border: 1px solid var(--input-border);
|
197 |
+
border-radius: 8px;
|
198 |
+
background-color: var(--input-bg);
|
199 |
+
color: var(--text-color);
|
200 |
+
font-size: 1rem;
|
201 |
+
transition: border-color 0.3s, box-shadow 0.3s;
|
}

input:focus, textarea:focus, select:focus {
    outline: none;
    border-color: var(--input-focus-border);
    box-shadow: 0 0 0 3px var(--input-focus-shadow);
}

/* File input styling */
input[type="file"] {
    padding: 8px;
    background-color: var(--input-bg);
    border: 1px dashed var(--input-border);
    border-radius: 8px;
    cursor: pointer;
}

input[type="file"]:hover {
    border-color: var(--primary-color);
}

button {
    background-color: var(--primary-color);
    color: white;
    border: none;
    border-radius: 8px;
    padding: 12px 20px;
    cursor: pointer;
    font-weight: 600;
    font-size: 1rem;
    transition: background-color 0.3s, transform 0.1s;
    width: 100%;
}

button:hover {
    background-color: var(--primary-hover);
}

button:active {
    transform: translateY(1px);
}

button:disabled {
    opacity: 0.7;
    cursor: not-allowed;
}

.btn-row {
    display: flex;
    gap: 10px;
}

.btn-row button {
    flex: 1;
}

.voice-list {
    display: grid;
    grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
    gap: 16px;
}

.voice-card {
    border: 1px solid var(--voice-card-border);
    border-radius: 10px;
    padding: 16px;
    background-color: var(--voice-card-bg);
    box-shadow: 0 2px 6px var(--voice-card-shadow);
    transition: transform 0.2s ease, box-shadow 0.2s ease;
}

.voice-card:hover {
    transform: translateY(-3px);
    box-shadow: 0 6px 12px var(--voice-card-shadow);
}

.controls {
    display: flex;
    gap: 8px;
    margin-top: 12px;
}

.voice-name {
    font-weight: 600;
    font-size: 18px;
    margin: 0 0 8px 0;
    color: var(--primary-color);
}

.btn-danger {
    background-color: var(--danger-color);
}

.btn-danger:hover {
    background-color: var(--danger-hover);
}

.btn-secondary {
    background-color: var(--secondary-color);
}

.btn-secondary:hover {
    background-color: var(--secondary-hover);
}

#audio-preview {
    margin-top: 20px;
    width: 100%;
    border-radius: 8px;
    background-color: var(--card-bg);
}

/* Status indicators */
.status-indicator {
    padding: 12px 16px;
    margin: 16px 0;
    border-radius: 8px;
    font-weight: 500;
    display: flex;
    align-items: center;
    opacity: 0;
    transition: opacity 0.3s ease;
    max-height: 0;
    overflow: hidden;
}

.status-indicator.show {
    opacity: 1;
    max-height: 100px;
}

.status-indicator.success {
    background-color: var(--status-success-bg);
    color: var(--status-success-text);
}

.status-indicator.error {
    background-color: var(--status-error-bg);
    color: var(--status-error-text);
}

.status-indicator.warning {
    background-color: var(--status-warning-bg);
    color: var(--status-warning-text);
}

.status-indicator.info {
    background-color: var(--status-info-bg);
    color: var(--status-info-text);
}

.status-indicator svg {
    margin-right: 8px;
    flex-shrink: 0;
}

/* Tabs */
.tabs {
    display: flex;
    margin-bottom: 20px;
    border-bottom: 1px solid var(--tab-border);
    overflow-x: auto;
    -webkit-overflow-scrolling: touch;
    scrollbar-width: none; /* Hide scrollbar for Firefox */
}

.tabs::-webkit-scrollbar {
    display: none; /* Hide scrollbar for Chrome/Safari */
}

.tabs button {
    background-color: transparent;
    color: var(--tab-text);
    border: none;
    padding: 12px 16px;
    margin-right: 8px;
    cursor: pointer;
    position: relative;
    font-weight: 600;
    border-radius: 8px 8px 0 0;
    white-space: nowrap;
    width: auto;
    flex-shrink: 0;
}

.tabs button.active {
    color: var(--tab-active);
    background-color: var(--tab-active-bg);
}

.tabs button.active::after {
    content: '';
    position: absolute;
    bottom: -1px;
    left: 0;
    right: 0;
    height: 2px;
    background-color: var(--tab-active);
}

.tab-content {
    display: none;
}

.tab-content.active {
    display: block;
}

/* Theme toggle switch */
.theme-switch-wrapper {
    display: flex;
    align-items: center;
    justify-content: flex-end;
    margin-bottom: 16px;
}

.theme-switch {
    display: inline-block;
    height: 28px;
    position: relative;
    width: 54px;
}

.theme-switch input {
    display: none;
}

.slider {
    background-color: var(--toggle-bg);
    bottom: 0;
    cursor: pointer;
    left: 0;
    position: absolute;
    right: 0;
    top: 0;
    transition: .4s;
    border-radius: 28px;
}

.slider:before {
    background-color: var(--toggle-circle);
    bottom: 4px;
    content: "";
    height: 20px;
    left: 4px;
    position: absolute;
    transition: .4s;
    width: 20px;
    border-radius: 50%;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
}

input:checked + .slider {
    background-color: var(--toggle-active);
}

input:checked + .slider:before {
    transform: translateX(26px);
}

.theme-icon {
    width: 16px;
    height: 16px;
    display: inline-block;
    margin: 0 8px;
    font-size: 16px;
    line-height: 1;
}

/* Progress bar */
.progress-bar {
    width: 100%;
    height: 12px;
    background-color: var(--input-bg);
    border-radius: 6px;
    margin-top: 12px;
    overflow: hidden;
    box-shadow: inset 0 1px 3px var(--card-shadow);
}

.progress-fill {
    height: 100%;
    background: linear-gradient(to right, var(--primary-color), var(--primary-hover));
    width: 0%;
    transition: width 0.5s ease-in-out;
    border-radius: 6px;
    position: relative;
}

.progress-fill::after {
    content: '';
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    bottom: 0;
    background: linear-gradient(
        -45deg,
        rgba(255, 255, 255, 0.2) 25%,
        transparent 25%,
        transparent 50%,
        rgba(255, 255, 255, 0.2) 50%,
        rgba(255, 255, 255, 0.2) 75%,
        transparent 75%
    );
    background-size: 16px 16px;
    animation: progress-animation 1s linear infinite;
    border-radius: 6px;
}

@keyframes progress-animation {
    0% {
        background-position: 0 0;
    }
    100% {
        background-position: 16px 0;
    }
}

/* Divider */
.divider {
    margin: 24px 0;
    border-top: 1px solid var(--card-border);
    position: relative;
}

.divider-text {
    position: absolute;
    top: -10px;
    left: 50%;
    transform: translateX(-50%);
    background-color: var(--bg-color);
    padding: 0 12px;
    color: var(--secondary-color);
    font-size: 0.9rem;
}

/* Small text */
small {
    color: var(--secondary-color);
    display: block;
    margin-top: 6px;
    font-size: 0.85rem;
}

/* Range slider styling */
input[type="range"] {
    -webkit-appearance: none;
    height: 8px;
    border-radius: 4px;
    background: var(--input-bg);
    outline: none;
    padding: 0;
    margin: 10px 0;
}

input[type="range"]::-webkit-slider-thumb {
    -webkit-appearance: none;
    appearance: none;
    width: 20px;
    height: 20px;
    border-radius: 50%;
    background: var(--primary-color);
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
}

input[type="range"]::-moz-range-thumb {
    width: 20px;
    height: 20px;
    border-radius: 50%;
    background: var(--primary-color);
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
}

input[type="range"]::-ms-thumb {
    width: 20px;
    height: 20px;
    border-radius: 50%;
    background: var(--primary-color);
    cursor: pointer;
    border: 2px solid white;
    box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
}

#temperature-value {
    display: inline-block;
    width: 40px;
    text-align: center;
    font-weight: 600;
    color: var(--primary-color);
}

/* Toast notifications */
.toast-container {
    position: fixed;
    bottom: 20px;
    right: 20px;
    z-index: 1000;
    max-width: 100%;
    width: 300px;
}

.toast {
    padding: 12px 16px;
    margin-bottom: 12px;
    border-radius: 8px;
    color: white;
    font-weight: 500;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
    display: flex;
    align-items: center;
    justify-content: space-between;
    animation: toast-in 0.3s ease forwards;
    max-width: 100%;
}

.toast.success {
    background-color: var(--success-color);
}

.toast.error {
    background-color: var(--danger-color);
}

.toast.warning {
    background-color: var(--warning-color);
}

.toast.info {
    background-color: var(--primary-color);
}

.toast-close {
    background: none;
    border: none;
    color: white;
    font-size: 18px;
    cursor: pointer;
    opacity: 0.8;
    width: auto;
    padding: 0 0 0 12px;
}

.toast-close:hover {
    opacity: 1;
    background: none;
}

@keyframes toast-in {
    from {
        transform: translateX(100%);
        opacity: 0;
    }
    to {
        transform: translateX(0);
        opacity: 1;
    }
}

/* Loading spinner */
.spinner {
    display: inline-block;
    width: 20px;
    height: 20px;
    margin-right: 8px;
    border: 3px solid rgba(255, 255, 255, 0.3);
    border-radius: 50%;
    border-top-color: white;
    animation: spin 1s ease-in-out infinite;
}

@keyframes spin {
    to { transform: rotate(360deg); }
}

/* Mobile responsiveness */
@media (max-width: 640px) {
    body {
        padding: 12px 8px;
    }

    h1 {
        font-size: 1.6rem;
    }

    .card {
        padding: 16px;
    }

    .voice-list {
        grid-template-columns: 1fr;
    }

    .controls {
        flex-direction: column;
    }

    .controls button {
        width: 100%;
    }

    .tabs button {
        padding: 8px 12px;
        font-size: 0.9rem;
    }
}
/* Checkbox styling */
.checkbox-group {
    margin-bottom: 20px;
}
.checkbox-label {
    display: flex;
    align-items: center;
    cursor: pointer;
    font-weight: 500;
    font-size: 0.95rem;
}
input[type="checkbox"] {
    width: auto;
    margin-right: 10px;
    height: 18px;
    width: 18px;
    cursor: pointer;
    accent-color: var(--primary-color);
}
.checkbox-text {
    position: relative;
    top: 1px;
}
</style>
</head>
<body>
    <div class="container">
        <div class="theme-switch-wrapper">
            <span class="theme-icon">☀️</span>
            <label class="theme-switch" for="checkbox">
                <input type="checkbox" id="checkbox" />
                <div class="slider"></div>
            </label>
            <span class="theme-icon">🌙</span>
        </div>

        <h1>Voice Cloning</h1>

        <div class="tabs">
            <button id="tab-clone" class="active">Clone Voice</button>
            <button id="tab-voices">My Voices</button>
            <button id="tab-generate">Generate Speech</button>
        </div>

        <!-- Status indicator -->
        <div id="status-message" class="status-indicator">
            <!-- Content will be dynamically inserted -->
        </div>

        <div id="clone-tab" class="tab-content active">
            <div class="card">
                <h2>Clone a New Voice</h2>
                <form id="clone-form">
                    <div class="form-group">
                        <label for="voice-name">Voice Name</label>
                        <input type="text" id="voice-name" name="name" required placeholder="e.g. My Voice">
                    </div>
                    <div class="form-group">
                        <label for="audio-file">Voice Sample (2-3 minute audio recording)</label>
                        <input type="file" id="audio-file" name="audio_file" required accept="audio/*">
                        <small>For best results, provide a clear recording with minimal background noise.</small>
                    </div>
                    <div class="form-group">
                        <label for="transcript">Transcript (Optional)</label>
                        <textarea id="transcript" name="transcript" rows="4" placeholder="Exact transcript of your audio sample..."></textarea>
                        <small>Adding a transcript helps improve voice accuracy.</small>
                    </div>
                    <div class="form-group">
                        <label for="description">Description (Optional)</label>
                        <input type="text" id="description" name="description" placeholder="A description of this voice">
                    </div>
                    <button type="submit">Clone Voice</button>
                </form>
            </div>

            <div class="divider">
                <span class="divider-text">OR</span>
            </div>

            <div class="card">
                <h2>Clone Voice from YouTube</h2>
                <form id="youtube-clone-form">
                    <div class="form-group">
                        <label for="youtube-url">YouTube URL</label>
                        <input type="url" id="youtube-url" name="youtube_url" required placeholder="https://www.youtube.com/watch?v=...">
                    </div>
                    <div class="form-group">
                        <label for="youtube-voice-name">Voice Name</label>
                        <input type="text" id="youtube-voice-name" name="voice_name" required placeholder="e.g. YouTube Voice">
                    </div>
                    <div class="form-group">
                        <label for="start-time">Start Time (seconds)</label>
                        <input type="number" id="start-time" name="start_time" min="0" value="0">
                    </div>
                    <div class="form-group">
                        <label for="duration">Duration (seconds)</label>
                        <input type="number" id="duration" name="duration" min="10" max="600" value="180">
                        <small>Recommended: 2-3 minutes of clear speech</small>
                    </div>
                    <div class="form-group">
                        <label for="youtube-description">Description (Optional)</label>
                        <input type="text" id="youtube-description" name="description" placeholder="A description of this voice">
                    </div>
                    <button type="submit">Clone from YouTube</button>
                </form>
                <div id="youtube-progress" style="display: none; margin-top: 16px;">
                    <p>Processing YouTube video... <span id="progress-status">Downloading</span></p>
                    <div class="progress-bar">
                        <div class="progress-fill"></div>
                    </div>
                </div>
            </div>
        </div>

        <div id="voices-tab" class="tab-content">
            <h2>My Cloned Voices</h2>
            <div id="voice-list" class="voice-list">
                <!-- Voice cards will be added here -->
            </div>
        </div>

        <div id="generate-tab" class="tab-content">
            <div class="card">
                <h2>Generate Speech with Cloned Voice</h2>
                <form id="generate-form">
                    <div class="form-group">
                        <label for="voice-select">Select Voice</label>
                        <select id="voice-select" name="voice" required>
                            <option value="">Select a voice</option>
                            <!-- Voice options will be added here -->
                        </select>
                    </div>
                    <div class="form-group">
                        <label for="generate-text">Text to Speak</label>
                        <textarea id="generate-text" name="text" rows="4" required placeholder="Enter text to synthesize with the selected voice..."></textarea>
                    </div>
                    <div class="form-group">
                        <label>Temperature: <span id="temperature-value">0.7</span></label>
                        <input type="range" id="temperature" name="temperature" min="0.5" max="1.0" step="0.05" value="0.7">
                        <small>Lower values (0.5-0.7) produce more consistent speech, higher values (0.8-1.0) produce more varied speech.</small>
                    </div>
                    <div class="form-group checkbox-group">
                        <label class="checkbox-label">
                            <input type="checkbox" id="use-streaming" name="use_streaming">
                            <span class="checkbox-text">Use streaming mode</span>
                        </label>
                        <small>Stream audio as it's generated for faster start and lower latency.</small>
                    </div>
                    <button type="submit">Generate Speech</button>
                </form>
                <audio id="audio-preview" controls style="display: none;"></audio>
            </div>
        </div>
    </div>

    <!-- Toast notifications container -->
    <div class="toast-container" id="toast-container"></div>

    <script>
        // Theme toggle functionality
        const toggleSwitch = document.querySelector('#checkbox');
        const html = document.querySelector('html');

        // Check for saved theme preference or use system preference
        function getThemePreference() {
            const savedTheme = localStorage.getItem('theme');
            if (savedTheme) {
                return savedTheme;
            }
            // Check if system prefers dark mode
            return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
        }

        // Apply the theme
        function setTheme(theme) {
            html.setAttribute('data-theme', theme);
            localStorage.setItem('theme', theme);
            toggleSwitch.checked = theme === 'dark';
        }

        // Initialize theme
        setTheme(getThemePreference());

        // Listen for toggle changes
        toggleSwitch.addEventListener('change', function(e) {
            if (e.target.checked) {
                setTheme('dark');
            } else {
                setTheme('light');
            }
        });
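        // Editorial note (assumption): the CSS custom properties used throughout the stylesheet above
        // are presumably keyed off this data-theme attribute, so setTheme('dark') / setTheme('light')
        // restyles the whole page in one step.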

        // Toast notification system
        function showToast(message, type = 'info', duration = 5000) {
            const container = document.getElementById('toast-container');
            const toast = document.createElement('div');
            toast.className = `toast ${type}`;
            toast.innerHTML = `
                <span>${message}</span>
                <button class="toast-close">×</button>
            `;
            container.appendChild(toast);

            // Auto remove after duration
            const timeout = setTimeout(() => {
                toast.style.opacity = '0';
                setTimeout(() => {
                    container.removeChild(toast);
                }, 300);
            }, duration);

            // Manual close
            toast.querySelector('.toast-close').addEventListener('click', () => {
                clearTimeout(timeout);
                toast.style.opacity = '0';
                setTimeout(() => {
                    container.removeChild(toast);
                }, 300);
            });
        }
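        // Usage sketch (editorial, not an extra call site): handlers below invoke this helper directly,
        // e.g. showToast('Voice cloned successfully!', 'success') or, with a custom duration,
        // showToast('Failed to load voices', 'error', 8000).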

        // Status indicator functions
        function showStatus(message, type) {
            const statusElem = document.getElementById('status-message');
            statusElem.className = `status-indicator ${type} show`;

            let icon = '';
            switch(type) {
                case 'success':
                    icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M9 12l2 2 4-4M21 12a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
                    break;
                case 'error':
                    icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
                    break;
                case 'warning':
                    icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
                    break;
                case 'info':
                    icon = '<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>';
                    break;
            }

            statusElem.innerHTML = icon + message;

            // Auto-hide after 5 seconds
            setTimeout(() => {
                statusElem.className = 'status-indicator';
            }, 5000);
        }

        function hideStatus() {
            const statusElem = document.getElementById('status-message');
            statusElem.className = 'status-indicator';
        }

        // Tab functionality
        const tabs = document.querySelectorAll('.tabs button');
        const tabContents = document.querySelectorAll('.tab-content');

        tabs.forEach(tab => {
            tab.addEventListener('click', () => {
                // Remove active class from all tabs
                tabs.forEach(t => t.classList.remove('active'));
                tabContents.forEach(tc => tc.classList.remove('active'));

                // Add active class to clicked tab
                tab.classList.add('active');

                // Show corresponding tab content
                const tabId = tab.id.replace('tab-', '');
                document.getElementById(`${tabId}-tab`).classList.add('active');

                // Hide any status messages when changing tabs
                hideStatus();
            });
        });
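        // Naming convention assumed by the handler above: a button with id "tab-clone" toggles the
        // panel with id "clone-tab", "tab-voices" toggles "voices-tab", and so on.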

        // Temperature slider
        const temperatureSlider = document.getElementById('temperature');
        const temperatureValue = document.getElementById('temperature-value');

        temperatureSlider.addEventListener('input', () => {
            temperatureValue.textContent = temperatureSlider.value;
        });

        // Load voices
        async function loadVoices() {
            try {
                const response = await fetch('/v1/voice-cloning/voices');
                const data = await response.json();
                const voiceList = document.getElementById('voice-list');
                const voiceSelect = document.getElementById('voice-select');

                // Clear existing content
                voiceList.innerHTML = '';

                // Clear voice select options but keep the first one
                while (voiceSelect.options.length > 1) {
                    voiceSelect.remove(1);
                }

                if (data.voices && data.voices.length > 0) {
                    data.voices.forEach(voice => {
                        // Add to voice list
                        const voiceCard = document.createElement('div');
                        voiceCard.className = 'voice-card';
                        voiceCard.innerHTML = `
                            <h3 class="voice-name">${voice.name}</h3>
                            <p>${voice.description || 'No description'}</p>
                            <p>Created: ${new Date(voice.created_at * 1000).toLocaleString()}</p>
                            <div class="controls">
                                <button class="btn-secondary preview-voice" data-id="${voice.id}">Preview</button>
                                <button class="btn-danger delete-voice" data-id="${voice.id}">Delete</button>
                            </div>
                        `;
                        voiceList.appendChild(voiceCard);

                        // Add to voice select
                        const option = document.createElement('option');
                        option.value = voice.id;
                        option.textContent = voice.name;
                        voiceSelect.appendChild(option);
                    });

                    // Add event listeners for preview and delete buttons
                    document.querySelectorAll('.preview-voice').forEach(button => {
                        button.addEventListener('click', previewVoice);
                    });
                    document.querySelectorAll('.delete-voice').forEach(button => {
                        button.addEventListener('click', deleteVoice);
                    });

                    showStatus(`Loaded ${data.voices.length} voices successfully`, 'success');
                } else {
                    voiceList.innerHTML = '<p>No cloned voices yet. Create one in the "Clone Voice" tab.</p>';
                }
            } catch (error) {
                console.error('Error loading voices:', error);
                showStatus('Failed to load voices', 'error');
            }
        }
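        // Response shape assumed by loadVoices() (inferred from the fields read above, not from the API schema):
        // { "voices": [ { "id": "...", "name": "...", "description": "...", "created_at": 1710000000 } ] }
        // with created_at in seconds, hence the * 1000 when constructing the Date.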

        // Preview voice
        async function previewVoice(event) {
            const button = event.target;
            const originalText = button.textContent;
            button.disabled = true;
            button.innerHTML = '<div class="spinner"></div> Loading...';

            const voiceId = button.dataset.id;
            const audioPreview = document.getElementById('audio-preview');

            try {
                const response = await fetch(`/v1/voice-cloning/voices/${voiceId}/preview`, {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        text: "This is a preview of my cloned voice. I hope you like how it sounds!"
                    })
                });

                if (response.ok) {
                    const blob = await response.blob();
                    const url = URL.createObjectURL(blob);
                    audioPreview.src = url;
                    audioPreview.style.display = 'block';

                    // Switch to the generate tab
                    document.getElementById('tab-generate').click();

                    // Set the voice in the select
                    document.getElementById('voice-select').value = voiceId;
                    audioPreview.play();

                    showToast('Voice preview loaded', 'success');
                } else {
                    showToast('Failed to preview voice', 'error');
                }
            } catch (error) {
                console.error('Error previewing voice:', error);
                showToast('Error previewing voice', 'error');
            } finally {
                button.disabled = false;
                button.textContent = originalText;
            }
        }
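        // Cleanup note (editorial suggestion, not in the original): each preview allocates a new object URL;
        // repeated previews could release the previous one first, e.g.
        // if (audioPreview.src.startsWith('blob:')) URL.revokeObjectURL(audioPreview.src);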

        // Delete voice
        async function deleteVoice(event) {
            if (!confirm('Are you sure you want to delete this voice? This cannot be undone.')) {
                return;
            }

            const button = event.target;
            const originalText = button.textContent;
            button.disabled = true;
            button.innerHTML = '<div class="spinner"></div> Deleting...';

            const voiceId = button.dataset.id;

            try {
                const response = await fetch(`/v1/voice-cloning/voices/${voiceId}`, {
                    method: 'DELETE'
                });

                if (response.ok) {
                    showToast('Voice deleted successfully', 'success');
                    loadVoices();
                } else {
                    showToast('Failed to delete voice', 'error');
                }
            } catch (error) {
                console.error('Error deleting voice:', error);
                showToast('Error deleting voice', 'error');
            } finally {
                button.disabled = false;
                button.textContent = originalText;
            }
        }

        // Clone voice form submission
        document.getElementById('clone-form').addEventListener('submit', async (event) => {
            event.preventDefault();
            const formData = new FormData(event.target);
            const submitButton = event.target.querySelector('button[type="submit"]');
            const originalText = submitButton.textContent;

            submitButton.disabled = true;
            submitButton.innerHTML = '<div class="spinner"></div> Cloning Voice...';
            showStatus('Processing your audio sample...', 'info');

            try {
                const response = await fetch('/v1/voice-cloning/clone', {
                    method: 'POST',
                    body: formData
                });

                if (response.ok) {
                    const result = await response.json();
                    showStatus('Voice cloned successfully!', 'success');
                    showToast('Voice cloned successfully!', 'success');
                    event.target.reset();

                    // Switch to the voices tab
                    document.getElementById('tab-voices').click();
                    loadVoices();
                } else {
                    const error = await response.json();
                    showStatus(`Failed to clone voice: ${error.detail}`, 'error');
                    showToast('Failed to clone voice', 'error');
                }
            } catch (error) {
                console.error('Error cloning voice:', error);
                showStatus('Error processing your request', 'error');
                showToast('Error cloning voice', 'error');
            } finally {
                submitButton.disabled = false;
                submitButton.textContent = originalText;
            }
        });
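        // What the handler above submits (inferred from the form markup): a multipart/form-data POST to
        // /v1/voice-cloning/clone carrying the fields name, audio_file, transcript and description,
        // exactly as named on the inputs of the clone form.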

        // YouTube voice cloning form submission
        document.getElementById('youtube-clone-form').addEventListener('submit', async (event) => {
            event.preventDefault();
            const formData = new FormData(event.target);
            const youtubeUrl = formData.get('youtube_url');
            const voiceName = formData.get('voice_name');
            const startTime = parseInt(formData.get('start_time'));
            const duration = parseInt(formData.get('duration'));
            const description = formData.get('description');
            const progressDiv = document.getElementById('youtube-progress');
            const progressStatus = document.getElementById('progress-status');
            const progressFill = document.querySelector('.progress-fill');
            const submitButton = event.target.querySelector('button[type="submit"]');
            const originalText = submitButton.textContent;

            submitButton.disabled = true;
            submitButton.innerHTML = '<div class="spinner"></div> Processing...';
            showStatus('Starting YouTube download...', 'info');

            // Show progress bar
            progressDiv.style.display = 'block';
            progressFill.style.width = '10%';
            progressStatus.textContent = 'Downloading audio...';

            // Simulate progress updates (since we can't get real-time updates easily)
            let progress = 10;
            const progressInterval = setInterval(() => {
                if (progress < 90) {
                    progress += 5;
                    progressFill.style.width = `${progress}%`;
                    if (progress > 30 && progress < 60) {
                        progressStatus.textContent = 'Generating transcript...';
                    } else if (progress >= 60) {
                        progressStatus.textContent = 'Cloning voice...';
                    }
                }
            }, 1000);

            try {
                const response = await fetch('/v1/voice-cloning/clone-from-youtube', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json'
                    },
                    body: JSON.stringify({
                        youtube_url: youtubeUrl,
                        voice_name: voiceName,
                        start_time: startTime,
                        duration: duration,
                        description: description
                    })
                });

                clearInterval(progressInterval);

                if (response.ok) {
                    progressFill.style.width = '100%';
                    progressStatus.textContent = 'Complete!';

                    const result = await response.json();
                    showStatus('Voice cloned successfully from YouTube!', 'success');
                    showToast('Voice cloned from YouTube!', 'success');

                    event.target.reset();

                    // Switch to the voices tab
                    document.getElementById('tab-voices').click();
                    loadVoices();
                } else {
                    const error = await response.json();
                    showStatus(`Failed to clone voice from YouTube: ${error.detail}`, 'error');
                    showToast('Failed to clone voice from YouTube', 'error');
                    progressDiv.style.display = 'none';
                }
            } catch (error) {
                console.error('Error cloning voice from YouTube:', error);
                showStatus('Error processing YouTube video', 'error');
                showToast('Error cloning voice from YouTube', 'error');
                progressDiv.style.display = 'none';
            } finally {
                clearInterval(progressInterval);
                submitButton.disabled = false;
                submitButton.textContent = originalText;
            }
        });
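        // Progress note (editorial): the bar above is simulated entirely on the client; the endpoint does not
        // report download or cloning progress, so the percentages and stage labels are cosmetic. Real progress
        // reporting would require something like server-sent events or polling a status endpoint, neither of
        // which is assumed to exist here.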

        // Generate speech form submission
        document.getElementById('generate-form').addEventListener('submit', async (event) => {
            event.preventDefault();
            const formData = new FormData(event.target);
            const voiceId = formData.get('voice');
            const text = formData.get('text');
            const temperature = formData.get('temperature');
            const useStreaming = formData.get('use_streaming') === 'on';

            if (!voiceId) {
                showToast('Please select a voice', 'warning');
                return;
            }

            const submitButton = event.target.querySelector('button[type="submit"]');
            const originalText = submitButton.textContent;
            submitButton.disabled = true;
            submitButton.innerHTML = '<div class="spinner"></div> Generating...';
            showStatus(useStreaming ? 'Streaming speech...' : 'Generating speech...', 'info');

            try {
                const audioPreview = document.getElementById('audio-preview');

                if (useStreaming) {
                    // For streaming, we need to handle the response differently to play audio as it arrives
                    try {
                        // Reset audio element
                        audioPreview.style.display = 'block';

                        // Prepare the request
                        const requestOptions = {
                            method: 'POST',
                            headers: {
                                'Content-Type': 'application/json'
                            },
                            body: JSON.stringify({
                                model: "csm-1b",
                                input: text,
                                voice: voiceId,
                                response_format: "mp3",
                                temperature: parseFloat(temperature),
                                speed: 1.0
                            })
                        };

                        // Create a unique URL for this request to avoid caching issues
                        const timestamp = new Date().getTime();
                        const streamingUrl = `/v1/audio/speech/streaming?t=${timestamp}`;

                        // Fetch from streaming endpoint
                        const response = await fetch(streamingUrl, requestOptions);

                        if (response.ok) {
                            // Create a blob URL for immediate playback
                            const blob = await response.blob();
                            const url = URL.createObjectURL(blob);

                            // Set the audio source and play immediately
                            audioPreview.src = url;
                            audioPreview.autoplay = true;

                            // Event listeners for success/failure
                            audioPreview.onplay = () => {
                                showStatus('Speech streamed successfully', 'success');
                                showToast('Speech streaming playback started', 'success');
                            };

                            audioPreview.onerror = (e) => {
                                console.error('Audio playback error:', e);
                                showStatus('Error playing streamed audio', 'error');
                                showToast('Streaming playback error', 'error');
                            };
                        } else {
                            const error = await response.json();
                            showStatus(`Failed to stream speech: ${error.detail || 'Unknown error'}`, 'error');
                            showToast('Failed to stream speech', 'error');
                        }
                    } catch (error) {
                        console.error('Streaming error:', error);
                        showStatus(`Error streaming speech: ${error.message}`, 'error');
                        showToast('Error streaming speech', 'error');
                    }
                } else {
                    // Non-streaming uses the original endpoint
                    const response = await fetch('/v1/voice-cloning/generate', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json'
                        },
                        body: JSON.stringify({
                            voice_id: voiceId,
                            text: text,
                            temperature: parseFloat(temperature)
                        })
                    });

                    if (response.ok) {
                        const blob = await response.blob();
                        const url = URL.createObjectURL(blob);
                        audioPreview.src = url;
                        audioPreview.style.display = 'block';
                        audioPreview.play();
                        showStatus('Speech generated successfully', 'success');
                        showToast('Speech generated successfully', 'success');
                    } else {
                        const error = await response.json();
                        showStatus(`Failed to generate speech: ${error.detail}`, 'error');
                        showToast('Failed to generate speech', 'error');
                    }
                }
            } catch (error) {
                console.error('Error generating speech:', error);
                showStatus('Error generating speech', 'error');
                showToast('Error generating speech', 'error');
            } finally {
                submitButton.disabled = false;
                submitButton.textContent = originalText;
            }
        });
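        // Streaming caveat (editorial): the "streaming" branch above still calls response.blob(), which waits
        // for the complete response before playback starts, so perceived latency matches the non-streaming path.
        // Truly progressive playback would need to read response.body incrementally (e.g. a ReadableStream reader
        // feeding a MediaSource buffer); that rework is not attempted here.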

        // Load voices on page load
        loadVoices();
    </script>
</body>
</html>