Spaces:
Running
Running
init commit
Browse files- Dockerfile +35 -0
- README.md +2 -2
- app.py +302 -0
- commit +3 -0
- kokoro.py +165 -0
- requirements.txt +17 -0
- tts_processor.py +142 -0
Dockerfile
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.11-slim
|
2 |
+
|
3 |
+
# Install system dependencies
|
4 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
5 |
+
libsndfile1 \
|
6 |
+
espeak-ng \
|
7 |
+
ffmpeg \
|
8 |
+
git \
|
9 |
+
wget \
|
10 |
+
&& rm -rf /var/lib/apt/lists/*
|
11 |
+
RUN useradd -m -u 1000 user
|
12 |
+
|
13 |
+
# Switch to the "user" user
|
14 |
+
USER user
|
15 |
+
|
16 |
+
# Set home to the user's home directory
|
17 |
+
ENV HOME=/home/user \
|
18 |
+
PATH=/home/user/.local/bin:$PATH
|
19 |
+
|
20 |
+
# Set the working directory to the user's home directory
|
21 |
+
WORKDIR $HOME/app
|
22 |
+
|
23 |
+
# Copy and install Python dependencies
|
24 |
+
COPY requirements.txt $HOME/app/
|
25 |
+
RUN pip install --no-cache-dir -r requirements.txt && pip install --upgrade pip
|
26 |
+
|
27 |
+
# Copy the current directory contents into the container at $HOME/app setting the owner to the user
|
28 |
+
COPY --chown=user . $HOME/app
|
29 |
+
|
30 |
+
|
31 |
+
# Expose port
|
32 |
+
EXPOSE 7860
|
33 |
+
|
34 |
+
# Run the application
|
35 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
6 |
sdk: docker
|
|
|
1 |
---
|
2 |
+
title: PrivateTest
|
3 |
+
emoji: 🐨
|
4 |
colorFrom: indigo
|
5 |
colorTo: indigo
|
6 |
sdk: docker
|
app.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, request, jsonify, send_from_directory, abort
|
2 |
+
from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
3 |
+
import librosa
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from onnxruntime import InferenceSession
|
7 |
+
import soundfile as sf
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import uuid
|
11 |
+
import logging
|
12 |
+
from flask_cors import CORS
|
13 |
+
import threading
|
14 |
+
import tempfile
|
15 |
+
from huggingface_hub import snapshot_download
|
16 |
+
from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
|
17 |
+
import time
|
18 |
+
from tts_processor import preprocess_all
|
19 |
+
import hashlib
|
20 |
+
|
21 |
+
# Configure logging
|
22 |
+
logging.basicConfig(level=logging.INFO)
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
app = Flask(__name__)
|
26 |
+
CORS(app, resources={r"/*": {"origins": "*"}})
|
27 |
+
|
28 |
+
# Global lock to ensure one method runs at a time
|
29 |
+
global_lock = threading.Lock()
|
30 |
+
|
31 |
+
# Repository ID and paths
|
32 |
+
kokoro_model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
|
33 |
+
model_path = 'kokoro_model'
|
34 |
+
voice_name = 'am_adam' # Example voice: af (adjust as needed)
|
35 |
+
|
36 |
+
# Directory to serve files from
|
37 |
+
SERVE_DIR = os.environ.get("SERVE_DIR", "./files") # Default to './files' if not provided
|
38 |
+
|
39 |
+
os.makedirs(SERVE_DIR, exist_ok=True)
|
40 |
+
def validate_audio_file(file):
|
41 |
+
if file.content_type not in ["audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3"]:
|
42 |
+
raise ValueError("Unsupported file type")
|
43 |
+
file.seek(0, os.SEEK_END)
|
44 |
+
file_size = file.tell()
|
45 |
+
file.seek(0) # Reset file pointer
|
46 |
+
if file_size > 10 * 1024 * 1024: # 10 MB limit
|
47 |
+
raise ValueError("File is too large (max 10 MB)")
|
48 |
+
|
49 |
+
def validate_text_input(text):
|
50 |
+
if not isinstance(text, str):
|
51 |
+
raise ValueError("Text input must be a string")
|
52 |
+
if len(text.strip()) == 0:
|
53 |
+
raise ValueError("Text input cannot be empty")
|
54 |
+
if len(text) > 1024: # Limit to 1024 characters
|
55 |
+
raise ValueError("Text input is too long (max 1024 characters)")
|
56 |
+
|
57 |
+
file_cache = {}
|
58 |
+
|
59 |
+
def is_cached(cached_file_path):
|
60 |
+
"""
|
61 |
+
Check if a file exists in the cache.
|
62 |
+
If the file is not in the cache, perform a disk check and update the cache.
|
63 |
+
"""
|
64 |
+
if cached_file_path in file_cache:
|
65 |
+
return file_cache[cached_file_path] # Return cached result
|
66 |
+
exists = os.path.exists(cached_file_path) # Perform disk check
|
67 |
+
file_cache[cached_file_path] = exists # Update the cache
|
68 |
+
return exists
|
69 |
+
import time
|
70 |
+
from huggingface_hub import snapshot_download
|
71 |
+
from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
|
72 |
+
|
73 |
+
def initialize_models():
|
74 |
+
global sess, voice_style, processor, whisper_model
|
75 |
+
|
76 |
+
max_retries = 5 # Maximum number of retries
|
77 |
+
retry_delay = 2 # Initial delay in seconds (will double after each retry)
|
78 |
+
|
79 |
+
for attempt in range(max_retries):
|
80 |
+
try:
|
81 |
+
# Download the ONNX model if not already downloaded
|
82 |
+
if not os.path.exists(model_path):
|
83 |
+
logger.info(f"Attempt {attempt + 1} to download and load Kokoro model...")
|
84 |
+
kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
|
85 |
+
logger.info(f"Kokoro model directory: {kokoro_dir}")
|
86 |
+
else:
|
87 |
+
kokoro_dir = model_path
|
88 |
+
logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")
|
89 |
+
|
90 |
+
# Validate ONNX file path
|
91 |
+
onnx_path = None
|
92 |
+
for root, _, files in os.walk(kokoro_dir):
|
93 |
+
if 'model.onnx' in files:
|
94 |
+
onnx_path = os.path.join(root, 'model.onnx')
|
95 |
+
break
|
96 |
+
|
97 |
+
if not onnx_path or not os.path.exists(onnx_path):
|
98 |
+
raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")
|
99 |
+
|
100 |
+
logger.info("Loading ONNX session...")
|
101 |
+
sess = InferenceSession(onnx_path)
|
102 |
+
logger.info(f"ONNX session loaded successfully from {onnx_path}")
|
103 |
+
|
104 |
+
# Load the voice style vector
|
105 |
+
voice_style_path = None
|
106 |
+
for root, _, files in os.walk(kokoro_dir):
|
107 |
+
if f'{voice_name}.bin' in files:
|
108 |
+
voice_style_path = os.path.join(root, f'{voice_name}.bin')
|
109 |
+
break
|
110 |
+
|
111 |
+
if not voice_style_path or not os.path.exists(voice_style_path):
|
112 |
+
raise FileNotFoundError(f"Voice style file not found at {voice_style_path}")
|
113 |
+
|
114 |
+
logger.info("Loading voice style vector...")
|
115 |
+
voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
|
116 |
+
logger.info(f"Voice style vector loaded successfully from {voice_style_path}")
|
117 |
+
|
118 |
+
# Initialize Whisper model for S2T
|
119 |
+
logger.info("Downloading and loading Whisper model...")
|
120 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
|
121 |
+
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
|
122 |
+
whisper_model.config.forced_decoder_ids = None
|
123 |
+
logger.info("Whisper model loaded successfully")
|
124 |
+
|
125 |
+
# If everything succeeds, break out of the retry loop
|
126 |
+
break
|
127 |
+
|
128 |
+
except (RepositoryNotFoundError, HfHubHTTPError, FileNotFoundError) as e:
|
129 |
+
logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
|
130 |
+
if attempt == max_retries - 1:
|
131 |
+
logger.error("Max retries reached. Failed to initialize models.")
|
132 |
+
raise # Re-raise the exception if max retries are reached
|
133 |
+
time.sleep(retry_delay)
|
134 |
+
retry_delay *= 2 # Exponential backoff
|
135 |
+
|
136 |
+
# Initialize models
|
137 |
+
initialize_models()
|
138 |
+
|
139 |
+
# Health check endpoint
|
140 |
+
@app.route('/health', methods=['GET'])
|
141 |
+
def health_check():
|
142 |
+
try:
|
143 |
+
return jsonify({"status": "healthy"}), 200
|
144 |
+
except Exception as e:
|
145 |
+
logger.error(f"Health check failed: {str(e)}")
|
146 |
+
return jsonify({"status": "unhealthy"}), 500
|
147 |
+
|
148 |
+
# Text-to-Speech (T2S) Endpoint
|
149 |
+
@app.route('/generate_audio', methods=['POST'])
|
150 |
+
def generate_audio():
|
151 |
+
"""Text-to-Speech (T2S) Endpoint"""
|
152 |
+
with global_lock: # Acquire global lock to ensure only one instance runs
|
153 |
+
try:
|
154 |
+
logger.debug("Received request to /generate_audio")
|
155 |
+
data = request.json
|
156 |
+
text = data['text']
|
157 |
+
output_dir = data.get('output_dir')
|
158 |
+
|
159 |
+
validate_text_input(text)
|
160 |
+
logger.debug(f"Text: {text}")
|
161 |
+
if not output_dir:
|
162 |
+
raise ValueError("Output directory is required but not provided")
|
163 |
+
|
164 |
+
# Ensure output_dir is an absolute path and valid
|
165 |
+
if not os.path.isabs(output_dir):
|
166 |
+
raise ValueError("Output directory must be an absolute path")
|
167 |
+
if not os.path.exists(output_dir):
|
168 |
+
raise ValueError(f"Output directory does not exist: {output_dir}")
|
169 |
+
|
170 |
+
# Generate a unique hash for the text
|
171 |
+
text = preprocess_all(text)
|
172 |
+
logger.debug(f"Processed Text {text}")
|
173 |
+
text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
|
174 |
+
hashed_file_name = f"{text_hash}.wav"
|
175 |
+
cached_file_path = os.path.join(output_dir, hashed_file_name)
|
176 |
+
logger.debug(f"Generated hash for processed text: {text_hash}")
|
177 |
+
logger.debug(f"Output directory: {output_dir}")
|
178 |
+
logger.debug(f"Cached file path: {cached_file_path}")
|
179 |
+
|
180 |
+
# Check if cached file exists
|
181 |
+
if is_cached(cached_file_path):
|
182 |
+
logger.info(f"Returning cached audio for text: {text}")
|
183 |
+
return jsonify({"status": "success", "output_path": cached_file_path})
|
184 |
+
|
185 |
+
# Tokenize text
|
186 |
+
logger.debug("Tokenizing text...")
|
187 |
+
from kokoro import phonemize, tokenize # Import dynamically
|
188 |
+
tokens = tokenize(phonemize(text, 'a'))
|
189 |
+
logger.debug(f"Initial tokens: {tokens}")
|
190 |
+
if len(tokens) > 510:
|
191 |
+
logger.warning("Text too long; truncating to 510 tokens.")
|
192 |
+
tokens = tokens[:510]
|
193 |
+
tokens = [[0, *tokens, 0]] # Add pad tokens
|
194 |
+
logger.debug(f"Final tokens: {tokens}")
|
195 |
+
|
196 |
+
# Get style vector based on token length
|
197 |
+
logger.debug("Fetching style vector...")
|
198 |
+
ref_s = voice_style[len(tokens[0]) - 2] # Shape: (1, 256)
|
199 |
+
logger.debug(f"Style vector shape: {ref_s.shape}")
|
200 |
+
|
201 |
+
# Run ONNX inference
|
202 |
+
logger.debug("Running ONNX inference...")
|
203 |
+
audio = sess.run(None, dict(
|
204 |
+
input_ids=np.array(tokens, dtype=np.int64),
|
205 |
+
style=ref_s,
|
206 |
+
speed=np.ones(1, dtype=np.float32),
|
207 |
+
))[0]
|
208 |
+
logger.debug(f"Audio generated with shape: {audio.shape}")
|
209 |
+
|
210 |
+
# Fix audio data for saving
|
211 |
+
audio = np.squeeze(audio) # Remove extra dimension
|
212 |
+
audio = audio.astype(np.float32) # Ensure correct data type
|
213 |
+
|
214 |
+
# Save audio
|
215 |
+
logger.debug(f"Saving audio to {cached_file_path}...")
|
216 |
+
sf.write(cached_file_path, audio, 24000) # Save with 24 kHz sample rate
|
217 |
+
logger.info(f"Audio saved successfully to {cached_file_path}")
|
218 |
+
return jsonify({"status": "success", "output_path": cached_file_path})
|
219 |
+
except Exception as e:
|
220 |
+
logger.error(f"Error generating audio: {str(e)}")
|
221 |
+
return jsonify({"status": "error", "message": str(e)}), 500
|
222 |
+
|
223 |
+
# Speech-to-Text (S2T) Endpoint
|
224 |
+
@app.route('/transcribe_audio', methods=['POST'])
|
225 |
+
def transcribe_audio():
|
226 |
+
"""Speech-to-Text (S2T) Endpoint"""
|
227 |
+
with global_lock: # Acquire global lock to ensure only one instance runs
|
228 |
+
audio_path = None
|
229 |
+
try:
|
230 |
+
logger.debug("Received request to /transcribe_audio")
|
231 |
+
file = request.files['file']
|
232 |
+
validate_audio_file(file)
|
233 |
+
# Generate a unique filename using uuid
|
234 |
+
unique_filename = f"{uuid.uuid4().hex}_{file.filename}"
|
235 |
+
audio_path = os.path.join("/tmp", unique_filename)
|
236 |
+
file.save(audio_path)
|
237 |
+
logger.debug(f"Audio file saved to {audio_path}")
|
238 |
+
|
239 |
+
# Load and preprocess audio
|
240 |
+
logger.debug("Processing audio for transcription...")
|
241 |
+
audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
|
242 |
+
|
243 |
+
input_features = processor(
|
244 |
+
audio_array,
|
245 |
+
sampling_rate=sampling_rate,
|
246 |
+
return_tensors="pt"
|
247 |
+
).input_features
|
248 |
+
|
249 |
+
# Generate transcription
|
250 |
+
logger.debug("Generating transcription...")
|
251 |
+
predicted_ids = whisper_model.generate(input_features)
|
252 |
+
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
253 |
+
logger.info(f"Transcription: {transcription}")
|
254 |
+
|
255 |
+
return jsonify({"status": "success", "transcription": transcription})
|
256 |
+
except Exception as e:
|
257 |
+
logger.error(f"Error transcribing audio: {str(e)}")
|
258 |
+
return jsonify({"status": "error", "message": str(e)}), 500
|
259 |
+
finally:
|
260 |
+
# Ensure temporary file is removed
|
261 |
+
if audio_path and os.path.exists(audio_path):
|
262 |
+
os.remove(audio_path)
|
263 |
+
logger.debug(f"Temporary file {audio_path} removed")
|
264 |
+
|
265 |
+
@app.route('/files/<filename>', methods=['GET'])
|
266 |
+
def serve_wav_file(filename):
|
267 |
+
"""
|
268 |
+
Serve a .wav file from the configured directory.
|
269 |
+
Only serves files ending with '.wav'.
|
270 |
+
"""
|
271 |
+
# Ensure only .wav files are allowed
|
272 |
+
if not filename.lower().endswith('.wav'):
|
273 |
+
abort(400, "Only .wav files are allowed.")
|
274 |
+
|
275 |
+
# Check if the file exists in the directory
|
276 |
+
file_path = os.path.join(SERVE_DIR, filename)
|
277 |
+
logger.debug(f"Looking for file at: {file_path}")
|
278 |
+
if not os.path.isfile(file_path):
|
279 |
+
logger.error(f"File not found: {file_path}")
|
280 |
+
abort(404, "File not found.")
|
281 |
+
|
282 |
+
# Serve the file
|
283 |
+
return send_from_directory(SERVE_DIR, filename)
|
284 |
+
|
285 |
+
# Error handlers
|
286 |
+
@app.errorhandler(400)
|
287 |
+
def bad_request(error):
|
288 |
+
"""Handle 400 errors."""
|
289 |
+
return {"error": "Bad Request", "message": str(error)}, 400
|
290 |
+
|
291 |
+
@app.errorhandler(404)
|
292 |
+
def not_found(error):
|
293 |
+
"""Handle 404 errors."""
|
294 |
+
return {"error": "Not Found", "message": str(error)}, 404
|
295 |
+
|
296 |
+
@app.errorhandler(500)
|
297 |
+
def internal_error(error):
|
298 |
+
"""Handle unexpected errors."""
|
299 |
+
return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500
|
300 |
+
|
301 |
+
if __name__ == "__main__":
|
302 |
+
app.run(host="0.0.0.0", port=7860)
|
commit
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
git add .
|
2 |
+
git commit -m "$*"
|
3 |
+
git push
|
kokoro.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import phonemizer
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
def split_num(num):
|
7 |
+
num = num.group()
|
8 |
+
if '.' in num:
|
9 |
+
return num
|
10 |
+
elif ':' in num:
|
11 |
+
h, m = [int(n) for n in num.split(':')]
|
12 |
+
if m == 0:
|
13 |
+
return f"{h} o'clock"
|
14 |
+
elif m < 10:
|
15 |
+
return f'{h} oh {m}'
|
16 |
+
return f'{h} {m}'
|
17 |
+
year = int(num[:4])
|
18 |
+
if year < 1100 or year % 1000 < 10:
|
19 |
+
return num
|
20 |
+
left, right = num[:2], int(num[2:4])
|
21 |
+
s = 's' if num.endswith('s') else ''
|
22 |
+
if 100 <= year % 1000 <= 999:
|
23 |
+
if right == 0:
|
24 |
+
return f'{left} hundred{s}'
|
25 |
+
elif right < 10:
|
26 |
+
return f'{left} oh {right}{s}'
|
27 |
+
return f'{left} {right}{s}'
|
28 |
+
|
29 |
+
def flip_money(m):
|
30 |
+
m = m.group()
|
31 |
+
bill = 'dollar' if m[0] == '$' else 'pound'
|
32 |
+
if m[-1].isalpha():
|
33 |
+
return f'{m[1:]} {bill}s'
|
34 |
+
elif '.' not in m:
|
35 |
+
s = '' if m[1:] == '1' else 's'
|
36 |
+
return f'{m[1:]} {bill}{s}'
|
37 |
+
b, c = m[1:].split('.')
|
38 |
+
s = '' if b == '1' else 's'
|
39 |
+
c = int(c.ljust(2, '0'))
|
40 |
+
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
|
41 |
+
return f'{b} {bill}{s} and {c} {coins}'
|
42 |
+
|
43 |
+
def point_num(num):
|
44 |
+
a, b = num.group().split('.')
|
45 |
+
return ' point '.join([a, ' '.join(b)])
|
46 |
+
|
47 |
+
def normalize_text(text):
|
48 |
+
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
49 |
+
text = text.replace('«', chr(8220)).replace('»', chr(8221))
|
50 |
+
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
51 |
+
text = text.replace('(', '«').replace(')', '»')
|
52 |
+
for a, b in zip('、。!,:;?', ',.!,:;?'):
|
53 |
+
text = text.replace(a, b+' ')
|
54 |
+
text = re.sub(r'[^\S \n]', ' ', text)
|
55 |
+
text = re.sub(r' +', ' ', text)
|
56 |
+
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
|
57 |
+
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
|
58 |
+
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
|
59 |
+
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
|
60 |
+
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
|
61 |
+
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
|
62 |
+
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
|
63 |
+
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
|
64 |
+
text = re.sub(r'(?<=\d),(?=\d)', '', text)
|
65 |
+
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
|
66 |
+
text = re.sub(r'\d*\.\d+', point_num, text)
|
67 |
+
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
|
68 |
+
text = re.sub(r'(?<=\d)S', ' S', text)
|
69 |
+
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
|
70 |
+
text = re.sub(r"(?<=X')S\b", 's', text)
|
71 |
+
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
|
72 |
+
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
|
73 |
+
return text.strip()
|
74 |
+
|
75 |
+
def get_vocab():
|
76 |
+
_pad = "$"
|
77 |
+
_punctuation = ';:,.!?¡¿—…"«»“” '
|
78 |
+
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
|
79 |
+
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
80 |
+
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
|
81 |
+
dicts = {}
|
82 |
+
for i in range(len((symbols))):
|
83 |
+
dicts[symbols[i]] = i
|
84 |
+
return dicts
|
85 |
+
|
86 |
+
VOCAB = get_vocab()
|
87 |
+
def tokenize(ps):
|
88 |
+
return [i for i in map(VOCAB.get, ps) if i is not None]
|
89 |
+
|
90 |
+
phonemizers = dict(
|
91 |
+
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
|
92 |
+
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
|
93 |
+
)
|
94 |
+
def phonemize(text, lang, norm=True):
|
95 |
+
if norm:
|
96 |
+
text = normalize_text(text)
|
97 |
+
ps = phonemizers[lang].phonemize([text])
|
98 |
+
ps = ps[0] if ps else ''
|
99 |
+
# https://en.wiktionary.org/wiki/kokoro#English
|
100 |
+
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
|
101 |
+
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
|
102 |
+
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
|
103 |
+
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
|
104 |
+
if lang == 'a':
|
105 |
+
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
|
106 |
+
ps = ''.join(filter(lambda p: p in VOCAB, ps))
|
107 |
+
return ps.strip()
|
108 |
+
|
109 |
+
def length_to_mask(lengths):
|
110 |
+
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
|
111 |
+
mask = torch.gt(mask+1, lengths.unsqueeze(1))
|
112 |
+
return mask
|
113 |
+
|
114 |
+
@torch.no_grad()
|
115 |
+
def forward(model, tokens, ref_s, speed):
|
116 |
+
device = ref_s.device
|
117 |
+
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
118 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
119 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
120 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
121 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
122 |
+
s = ref_s[:, 128:]
|
123 |
+
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
124 |
+
x, _ = model.predictor.lstm(d)
|
125 |
+
duration = model.predictor.duration_proj(x)
|
126 |
+
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
127 |
+
pred_dur = torch.round(duration).clamp(min=1).long()
|
128 |
+
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
129 |
+
c_frame = 0
|
130 |
+
for i in range(pred_aln_trg.size(0)):
|
131 |
+
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
|
132 |
+
c_frame += pred_dur[0,i].item()
|
133 |
+
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
134 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
135 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
136 |
+
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
137 |
+
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
138 |
+
|
139 |
+
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
|
140 |
+
ps = ps or phonemize(text, lang)
|
141 |
+
tokens = tokenize(ps)
|
142 |
+
if not tokens:
|
143 |
+
return None
|
144 |
+
elif len(tokens) > 510:
|
145 |
+
tokens = tokens[:510]
|
146 |
+
print('Truncated to 510 tokens')
|
147 |
+
ref_s = voicepack[len(tokens)]
|
148 |
+
out = forward(model, tokens, ref_s, speed)
|
149 |
+
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
150 |
+
return out, ps
|
151 |
+
|
152 |
+
def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
|
153 |
+
ps = ps or phonemize(text, lang)
|
154 |
+
tokens = tokenize(ps)
|
155 |
+
if not tokens:
|
156 |
+
return None
|
157 |
+
outs = []
|
158 |
+
loop_count = len(tokens)//510 + (1 if len(tokens) % 510 != 0 else 0)
|
159 |
+
for i in range(loop_count):
|
160 |
+
ref_s = voicepack[len(tokens[i*510:(i+1)*510])]
|
161 |
+
out = forward(model, tokens[i*510:(i+1)*510], ref_s, speed)
|
162 |
+
outs.append(out)
|
163 |
+
outs = np.concatenate(outs)
|
164 |
+
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
|
165 |
+
return outs, ps
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
flask
|
2 |
+
flask-cors
|
3 |
+
transformers
|
4 |
+
librosa
|
5 |
+
numpy
|
6 |
+
soundfile
|
7 |
+
huggingface_hub
|
8 |
+
phonemizer
|
9 |
+
munch
|
10 |
+
werkzeug
|
11 |
+
num2words
|
12 |
+
dateparser
|
13 |
+
inflect
|
14 |
+
ftfy
|
15 |
+
sentencepiece
|
16 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
17 |
+
onnxruntime
|
tts_processor.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from dateutil.parser import parse
|
3 |
+
from num2words import num2words
|
4 |
+
import inflect
|
5 |
+
from ftfy import fix_text
|
6 |
+
|
7 |
+
# Initialize the inflect engine
|
8 |
+
inflect_engine = inflect.engine()
|
9 |
+
|
10 |
+
# Define alphabet pronunciation mapping
|
11 |
+
alphabet_map = {
|
12 |
+
"A": " Eh ", "B": " Bee ", "C": " See ", "D": " Dee ", "E": " Eee ",
|
13 |
+
"F": " Eff ", "G": " Jee ", "H": " Aitch ", "I": " Eye ", "J": " Jay ",
|
14 |
+
"K": " Kay ", "L": " El ", "M": " Emm ", "N": " Enn ", "O": " Ohh ",
|
15 |
+
"P": " Pee ", "Q": " Queue ", "R": " Are ", "S": " Ess ", "T": " Tee ",
|
16 |
+
"U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
|
17 |
+
}
|
18 |
+
|
19 |
+
# Function to add ordinal suffix to a number
|
20 |
+
def add_ordinal_suffix(day):
|
21 |
+
"""Adds ordinal suffix to a day (e.g., 13 -> 13th)."""
|
22 |
+
if 11 <= day <= 13: # Special case for 11th, 12th, 13th
|
23 |
+
return f"{day}th"
|
24 |
+
elif day % 10 == 1:
|
25 |
+
return f"{day}st"
|
26 |
+
elif day % 10 == 2:
|
27 |
+
return f"{day}nd"
|
28 |
+
elif day % 10 == 3:
|
29 |
+
return f"{day}rd"
|
30 |
+
else:
|
31 |
+
return f"{day}th"
|
32 |
+
|
33 |
+
# Function to format dates in a human-readable form
|
34 |
+
def format_date(parsed_date, include_time=True):
|
35 |
+
"""Formats a parsed date into a human-readable string."""
|
36 |
+
if not parsed_date:
|
37 |
+
return None
|
38 |
+
|
39 |
+
# Convert the day into an ordinal (e.g., 13 -> 13th)
|
40 |
+
day = add_ordinal_suffix(parsed_date.day)
|
41 |
+
|
42 |
+
# Format the date in a TTS-friendly way
|
43 |
+
if include_time and parsed_date.hour != 0 and parsed_date.minute != 0:
|
44 |
+
return parsed_date.strftime(f"%B {day}, %Y at %-I:%M %p") # Unix
|
45 |
+
return parsed_date.strftime(f"%B {day}, %Y") # Only date
|
46 |
+
|
47 |
+
# Normalize dates in the text
|
48 |
+
def normalize_dates(text):
|
49 |
+
"""
|
50 |
+
Finds and replaces date strings with a nicely formatted, TTS-friendly version.
|
51 |
+
"""
|
52 |
+
def replace_date(match):
|
53 |
+
raw_date = match.group(0)
|
54 |
+
try:
|
55 |
+
parsed_date = parse(raw_date)
|
56 |
+
if parsed_date:
|
57 |
+
include_time = "T" in raw_date or " " in raw_date # Include time only if explicitly provided
|
58 |
+
return format_date(parsed_date, include_time)
|
59 |
+
except ValueError:
|
60 |
+
pass
|
61 |
+
return raw_date
|
62 |
+
|
63 |
+
# Match common date formats
|
64 |
+
date_pattern = r"\b(\d{4}-\d{2}-\d{2}(?:[ T]\d{2}:\d{2}:\d{2})?|\d{2}/\d{2}/\d{4}|\d{1,2} \w+ \d{4})\b"
|
65 |
+
return re.sub(date_pattern, replace_date, text)
|
66 |
+
|
67 |
+
# Replace invalid characters and clean text
|
68 |
+
def replace_invalid_chars(string):
|
69 |
+
string = fix_text(string)
|
70 |
+
replacements = {
|
71 |
+
"**": "",
|
72 |
+
''': "'",
|
73 |
+
'AI;': 'Artificial Intelligence!',
|
74 |
+
'iddqd;': 'Immortality cheat code',
|
75 |
+
'😉;': 'wink wink!',
|
76 |
+
':D': '*laughs* Ahahaha!',
|
77 |
+
';D': '*laughs* Ahahaha!'
|
78 |
+
}
|
79 |
+
for old, new in replacements.items():
|
80 |
+
string = string.replace(old, new)
|
81 |
+
return string
|
82 |
+
|
83 |
+
# Replace numbers with their word equivalents
|
84 |
+
def replace_numbers(string):
|
85 |
+
ipv4_pattern = r'(\b\d{1,3}(\.\d{1,3}){3}\b)'
|
86 |
+
ipv6_pattern = r'([0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}'
|
87 |
+
range_pattern = r'\b\d+-\d+\b' # Detect ranges like 1-4
|
88 |
+
date_pattern = r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b'
|
89 |
+
alphanumeric_pattern = r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b'
|
90 |
+
|
91 |
+
# Do not process IP addresses, date patterns, or alphanumerics
|
92 |
+
if re.search(ipv4_pattern, string) or re.search(ipv6_pattern, string) or re.search(range_pattern, string) or re.search(date_pattern, string) or re.search(alphanumeric_pattern, string):
|
93 |
+
return string
|
94 |
+
|
95 |
+
# Convert standalone numbers and port numbers
|
96 |
+
def convert_number(match):
|
97 |
+
number = match.group()
|
98 |
+
return num2words(int(number)) if number.isdigit() else number
|
99 |
+
|
100 |
+
pattern = re.compile(r'\b\d+\b')
|
101 |
+
return re.sub(pattern, convert_number, string)
|
102 |
+
|
103 |
+
# Replace abbreviations with expanded form
|
104 |
+
def replace_abbreviations(string):
|
105 |
+
words = string.split()
|
106 |
+
for i, word in enumerate(words):
|
107 |
+
if word.isupper() and len(word) <= 4 and not any(char.isdigit() for char in word) and word not in ["ID", "AM", "PM"]:
|
108 |
+
words[i] = ''.join([alphabet_map.get(char, char) for char in word])
|
109 |
+
return ' '.join(words)
|
110 |
+
|
111 |
+
# Clean up whitespace in the text
|
112 |
+
def clean_whitespace(string):
|
113 |
+
string = re.sub(r'\s+([.,?!])', r'\1', string)
|
114 |
+
return ' '.join(string.split())
|
115 |
+
|
116 |
+
# Main preprocessing pipeline
|
117 |
+
def preprocess_all(string):
|
118 |
+
string = normalize_dates(string)
|
119 |
+
string = replace_invalid_chars(string)
|
120 |
+
string = replace_numbers(string)
|
121 |
+
string = replace_abbreviations(string)
|
122 |
+
string = clean_whitespace(string)
|
123 |
+
return string
|
124 |
+
|
125 |
+
# Expose a testing function for external use
|
126 |
+
def test_preprocessing(file_path):
|
127 |
+
with open(file_path, 'r') as file:
|
128 |
+
lines = file.readlines()
|
129 |
+
for line in lines:
|
130 |
+
original = line.strip()
|
131 |
+
processed = preprocess_all(original)
|
132 |
+
print(f"Original: {original}")
|
133 |
+
print(f"Processed: {processed}\n")
|
134 |
+
|
135 |
+
if __name__ == "__main__":
|
136 |
+
import sys
|
137 |
+
if len(sys.argv) > 1:
|
138 |
+
test_file = sys.argv[1]
|
139 |
+
test_preprocessing(test_file)
|
140 |
+
else:
|
141 |
+
print("Please provide a file path as an argument.")
|
142 |
+
|