imanibase committed on
Commit
db0a2ce
·
1 Parent(s): 5efe7f0

init commit

Browse files
Files changed (7) hide show
  1. Dockerfile +35 -0
  2. README.md +2 -2
  3. app.py +302 -0
  4. commit +3 -0
  5. kokoro.py +165 -0
  6. requirements.txt +17 -0
  7. tts_processor.py +142 -0
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the Flask service defined in app.py.
FROM python:3.11-slim

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    libsndfile1 \
    espeak-ng \
    ffmpeg \
    git \
    wget \
    && rm -rf /var/lib/apt/lists/*

# Create an unprivileged account with UID 1000
RUN useradd -m -u 1000 user

# Switch to the "user" user
USER user

# Set home to the user's home directory and put user-level pip scripts on PATH
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Create the app directory while running as "user". app.py creates ./files and
# ./kokoro_model relative to the working directory at startup, so the directory
# must be writable by that user regardless of how the builder assigns ownership
# to directories that WORKDIR creates.
RUN mkdir -p $HOME/app

# Set the working directory to the app directory inside the user's home
WORKDIR $HOME/app

# Copy and install Python dependencies.
# pip is upgraded first so the upgraded pip is the one that resolves requirements.
COPY --chown=user requirements.txt $HOME/app/
RUN pip install --no-cache-dir --upgrade pip && pip install --no-cache-dir -r requirements.txt

# Copy the current directory contents into the container at $HOME/app setting the owner to the user
COPY --chown=user . $HOME/app

# Expose port
EXPOSE 7860

# Run the application
CMD ["python", "app.py"]
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: NetMonTTS
3
- emoji: 📚
4
  colorFrom: indigo
5
  colorTo: indigo
6
  sdk: docker
 
1
  ---
2
+ title: PrivateTest
3
+ emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: indigo
6
  sdk: docker
app.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify, send_from_directory, abort
2
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
+ import librosa
4
+ import torch
5
+ import numpy as np
6
+ from onnxruntime import InferenceSession
7
+ import soundfile as sf
8
+ import os
9
+ import sys
10
+ import uuid
11
+ import logging
12
+ from flask_cors import CORS
13
+ import threading
14
+ import tempfile
15
+ from huggingface_hub import snapshot_download
16
+ from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
17
+ import time
18
+ from tts_processor import preprocess_all
19
+ import hashlib
20
+
21
# Configure logging
# NOTE: the root level is INFO, so the logger.debug(...) calls in this module
# are not emitted unless the level is lowered.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
# Allow cross-origin requests from any origin on every route.
CORS(app, resources={r"/*": {"origins": "*"}})

# Global lock to ensure one method runs at a time
# (held by both /generate_audio and /transcribe_audio for the whole request).
global_lock = threading.Lock()

# Repository ID and paths
kokoro_model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
# Passed to snapshot_download() as cache_dir and searched for model files.
model_path = 'kokoro_model'
# Voice to load: initialize_models() looks for "<voice_name>.bin" under the
# model directory (adjust as needed).
voice_name = 'am_adam'

# Directory to serve files from
SERVE_DIR = os.environ.get("SERVE_DIR", "./files")  # Default to './files' if not provided

# Make sure the serve directory exists before any request is handled.
os.makedirs(SERVE_DIR, exist_ok=True)
40
def validate_audio_file(file):
    """Reject uploads whose declared type or size is not acceptable.

    Args:
        file: An uploaded file object exposing ``content_type``, ``seek`` and
            ``tell``.

    Raises:
        ValueError: If the content type is not one of the accepted audio
            types, or the payload is larger than 10 MB.
    """
    accepted_types = ["audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3"]
    if file.content_type not in accepted_types:
        raise ValueError("Unsupported file type")

    # Measure the payload by seeking to its end, then rewind so that the
    # caller can still read the stream from the start.
    file.seek(0, os.SEEK_END)
    size_in_bytes = file.tell()
    file.seek(0)

    size_limit = 10 * 1024 * 1024  # 10 MB
    if size_in_bytes > size_limit:
        raise ValueError("File is too large (max 10 MB)")
48
+
49
def validate_text_input(text):
    """Check that ``text`` is a non-blank string of at most 1024 characters.

    Args:
        text: The value received from the client.

    Raises:
        ValueError: If ``text`` is not a ``str``, contains only whitespace,
            or is longer than 1024 characters.
    """
    max_chars = 1024

    if not isinstance(text, str):
        raise ValueError("Text input must be a string")

    stripped = text.strip()
    if not stripped:
        raise ValueError("Text input cannot be empty")

    # The limit applies to the raw input, not to the stripped copy.
    if len(text) > max_chars:
        raise ValueError("Text input is too long (max 1024 characters)")
56
+
57
# Paths that have been confirmed to exist on disk (path -> True).
file_cache = {}


def is_cached(cached_file_path):
    """Return True if ``cached_file_path`` exists, memoizing positive answers.

    Only positive results are remembered. A "missing" answer must never be
    memoized, because the caller creates the file right after a miss; a
    remembered miss would make every later request regenerate the same audio.

    Note: a remembered hit is not re-validated, so a file deleted from disk
    while the process is running is still reported as cached.

    Args:
        cached_file_path: Path of the file to look up.

    Returns:
        bool: Whether the file is known (or found) to exist.
    """
    if file_cache.get(cached_file_path):
        return True

    exists = os.path.exists(cached_file_path)  # Perform disk check
    if exists:
        file_cache[cached_file_path] = True
    else:
        # Drop any stale negative entry instead of keeping it around.
        file_cache.pop(cached_file_path, None)
    return exists
69
+ import time
70
+ from huggingface_hub import snapshot_download
71
+ from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
72
+
73
def _find_file(root_dir, file_name):
    """Return the first existing path named ``file_name`` under ``root_dir``, or None."""
    for current_dir, _, files in os.walk(root_dir):
        if file_name in files:
            candidate = os.path.join(current_dir, file_name)
            if os.path.exists(candidate):
                return candidate
    return None


def initialize_models():
    """Load the Kokoro ONNX session, the voice style vectors and Whisper.

    Populates the module globals ``sess``, ``voice_style``, ``processor`` and
    ``whisper_model``.

    The files under ``model_path`` are reused when both ``model.onnx`` and the
    voice file are present. Otherwise the snapshot is (re)downloaded, so a
    directory left incomplete by an earlier run is repaired instead of failing
    the same way on every retry.

    Raises:
        RepositoryNotFoundError, HfHubHTTPError, FileNotFoundError: Re-raised
            after ``max_retries`` failed attempts (exponential backoff between
            attempts).
    """
    global sess, voice_style, processor, whisper_model

    max_retries = 5  # Maximum number of retries
    retry_delay = 2  # Initial delay in seconds (will double after each retry)
    voice_file = f'{voice_name}.bin'

    for attempt in range(max_retries):
        try:
            kokoro_dir = model_path
            onnx_path = None
            voice_style_path = None
            if os.path.exists(model_path):
                onnx_path = _find_file(kokoro_dir, 'model.onnx')
                voice_style_path = _find_file(kokoro_dir, voice_file)

            if onnx_path and voice_style_path:
                logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")
            else:
                # Nothing usable on disk yet (first run or incomplete download).
                logger.info(f"Attempt {attempt + 1} to download and load Kokoro model...")
                kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
                logger.info(f"Kokoro model directory: {kokoro_dir}")
                onnx_path = _find_file(kokoro_dir, 'model.onnx')
                voice_style_path = _find_file(kokoro_dir, voice_file)

            # Validate ONNX file path
            if not onnx_path:
                raise FileNotFoundError(f"ONNX file not found after redownload at {kokoro_dir}")

            logger.info("Loading ONNX session...")
            sess = InferenceSession(onnx_path)
            logger.info(f"ONNX session loaded successfully from {onnx_path}")

            # Load the voice style vector
            if not voice_style_path:
                raise FileNotFoundError(f"Voice style file {voice_file} not found under {kokoro_dir}")

            logger.info("Loading voice style vector...")
            # The flat float32 file is viewed as rows of one 256-value vector each.
            voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
            logger.info(f"Voice style vector loaded successfully from {voice_style_path}")

            # Initialize Whisper model for S2T
            logger.info("Downloading and loading Whisper model...")
            processor = WhisperProcessor.from_pretrained("openai/whisper-base")
            whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
            whisper_model.config.forced_decoder_ids = None
            logger.info("Whisper model loaded successfully")

            # If everything succeeds, break out of the retry loop
            break

        except (RepositoryNotFoundError, HfHubHTTPError, FileNotFoundError) as e:
            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt == max_retries - 1:
                logger.error("Max retries reached. Failed to initialize models.")
                raise  # Re-raise the exception if max retries are reached
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff


# Initialize models
initialize_models()
138
+
139
# Health check endpoint
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the process is up.

    Returns a fixed ``{"status": "healthy"}`` payload with HTTP 200; it does
    not inspect the loaded models. If building the response fails, the error
    is logged and ``{"status": "unhealthy"}`` is returned with HTTP 500.
    """
    try:
        healthy = jsonify({"status": "healthy"})
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return jsonify({"status": "unhealthy"}), 500
    return healthy, 200
147
+
148
# Text-to-Speech (T2S) Endpoint
@app.route('/generate_audio', methods=['POST'])
def generate_audio():
    """Text-to-Speech (T2S) Endpoint.

    Expects a JSON object with:
      - ``text``: the string to synthesize (checked by ``validate_text_input``).
      - ``output_dir``: absolute path of an existing directory to write into.

    The preprocessed text is hashed with SHA-256 and the audio is written to
    ``<hash>.wav`` inside ``output_dir``. If that file is already known to
    exist, its path is returned without running the model again.

    Returns:
        ``{"status": "success", "output_path": ...}`` on success;
        ``{"status": "error", "message": ...}`` with HTTP 400 when the request
        is invalid, or with HTTP 500 when synthesis fails.
    """
    with global_lock:  # Acquire global lock to ensure only one instance runs
        # Stage 1: request validation. Anything wrong here is a client error.
        try:
            logger.debug("Received request to /generate_audio")
            # silent=True yields None for a missing/invalid JSON body instead
            # of raising, so it can be reported as a plain 400 below.
            data = request.get_json(silent=True)
            if not isinstance(data, dict):
                raise ValueError("Request body must be a JSON object")
            if 'text' not in data:
                raise ValueError("Text input is required but not provided")
            text = data['text']
            output_dir = data.get('output_dir')

            validate_text_input(text)
            logger.debug(f"Text: {text}")
            if not output_dir:
                raise ValueError("Output directory is required but not provided")

            # Ensure output_dir is an absolute path and valid
            if not isinstance(output_dir, str) or not os.path.isabs(output_dir):
                raise ValueError("Output directory must be an absolute path")
            if not os.path.isdir(output_dir):
                raise ValueError(f"Output directory does not exist: {output_dir}")
        except ValueError as e:
            logger.warning(f"Invalid /generate_audio request: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 400

        # Stage 2: synthesis. Failures here are server-side.
        try:
            # Generate a unique hash for the text
            text = preprocess_all(text)
            logger.debug(f"Processed Text {text}")
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            hashed_file_name = f"{text_hash}.wav"
            cached_file_path = os.path.join(output_dir, hashed_file_name)
            logger.debug(f"Generated hash for processed text: {text_hash}")
            logger.debug(f"Output directory: {output_dir}")
            logger.debug(f"Cached file path: {cached_file_path}")

            # Check if cached file exists
            if is_cached(cached_file_path):
                logger.info(f"Returning cached audio for text: {text}")
                return jsonify({"status": "success", "output_path": cached_file_path})

            # Tokenize text
            logger.debug("Tokenizing text...")
            from kokoro import phonemize, tokenize  # Import dynamically
            tokens = tokenize(phonemize(text, 'a'))
            logger.debug(f"Initial tokens: {tokens}")
            if not tokens:
                # Nothing survived preprocessing/phonemization; do not run the
                # model on an input that holds only the two pad tokens.
                message = "Text contains no characters that can be synthesized"
                logger.warning(message)
                return jsonify({"status": "error", "message": message}), 400
            if len(tokens) > 510:
                logger.warning("Text too long; truncating to 510 tokens.")
                tokens = tokens[:510]
            tokens = [[0, *tokens, 0]]  # Add pad tokens
            logger.debug(f"Final tokens: {tokens}")

            # Get style vector based on token length
            logger.debug("Fetching style vector...")
            # Index by the number of real tokens (pads excluded), clamped to
            # the last row so a full-length input cannot index past the table.
            # NOTE(review): confirm the row count of the voice .bin file.
            style_index = min(len(tokens[0]) - 2, len(voice_style) - 1)
            ref_s = voice_style[style_index]  # Shape: (1, 256)
            logger.debug(f"Style vector shape: {ref_s.shape}")

            # Run ONNX inference
            logger.debug("Running ONNX inference...")
            audio = sess.run(None, dict(
                input_ids=np.array(tokens, dtype=np.int64),
                style=ref_s,
                speed=np.ones(1, dtype=np.float32),
            ))[0]
            logger.debug(f"Audio generated with shape: {audio.shape}")

            # Fix audio data for saving
            audio = np.squeeze(audio)  # Remove extra dimension
            audio = audio.astype(np.float32)  # Ensure correct data type

            # Save audio
            logger.debug(f"Saving audio to {cached_file_path}...")
            sf.write(cached_file_path, audio, 24000)  # Save with 24 kHz sample rate
            # Record the new file so the next identical request is a cache hit
            # even if an earlier lookup had remembered it as missing.
            file_cache[cached_file_path] = True
            logger.info(f"Audio saved successfully to {cached_file_path}")
            return jsonify({"status": "success", "output_path": cached_file_path})
        except Exception as e:
            logger.error(f"Error generating audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
222
+
223
# Speech-to-Text (S2T) Endpoint
@app.route('/transcribe_audio', methods=['POST'])
def transcribe_audio():
    """Speech-to-Text (S2T) Endpoint.

    Expects a multipart upload with a ``file`` part. The upload is checked by
    ``validate_audio_file``, stored in a temporary file, loaded at 16 kHz and
    transcribed with the Whisper model.

    Returns:
        ``{"status": "success", "transcription": ...}`` on success;
        ``{"status": "error", "message": ...}`` with HTTP 400 when the upload
        is missing or rejected, or with HTTP 500 when transcription fails.
    """
    with global_lock:  # Acquire global lock to ensure only one instance runs
        audio_path = None
        try:
            logger.debug("Received request to /transcribe_audio")
            file = request.files.get('file')
            if file is None:
                message = "No file part named 'file' in the request"
                logger.warning(message)
                return jsonify({"status": "error", "message": message}), 400
            try:
                validate_audio_file(file)
            except ValueError as e:
                logger.warning(f"Rejected upload: {str(e)}")
                return jsonify({"status": "error", "message": str(e)}), 400

            # The client-supplied filename is untrusted, so it never becomes
            # part of the path. Only a purely alphanumeric extension is kept.
            extension = os.path.splitext(os.path.basename(file.filename or ""))[1]
            if not extension[1:].isalnum():
                extension = ""
            # mkstemp creates a uniquely named file that did not exist before.
            fd, audio_path = tempfile.mkstemp(prefix="upload_", suffix=extension, dir="/tmp")
            os.close(fd)
            file.save(audio_path)
            logger.debug(f"Audio file saved to {audio_path}")

            # Load and preprocess audio
            logger.debug("Processing audio for transcription...")
            audio_array, sampling_rate = librosa.load(audio_path, sr=16000)

            input_features = processor(
                audio_array,
                sampling_rate=sampling_rate,
                return_tensors="pt"
            ).input_features

            # Generate transcription
            logger.debug("Generating transcription...")
            predicted_ids = whisper_model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info(f"Transcription: {transcription}")

            return jsonify({"status": "success", "transcription": transcription})
        except Exception as e:
            logger.error(f"Error transcribing audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
        finally:
            # Ensure temporary file is removed
            if audio_path and os.path.exists(audio_path):
                os.remove(audio_path)
                logger.debug(f"Temporary file {audio_path} removed")
264
+
265
@app.route('/files/<filename>', methods=['GET'])
def serve_wav_file(filename):
    """
    Serve a .wav file from the configured directory.
    Only serves files ending with '.wav'.

    Responds with 400 for any other extension and 404 when the file does not
    exist in the serve directory.
    """
    # Ensure only .wav files are allowed
    if not filename.lower().endswith('.wav'):
        abort(400, "Only .wav files are allowed.")

    # Resolve the directory once to an absolute path. A relative directory is
    # interpreted against the process working directory by os.path but against
    # the application root by send_from_directory; using the absolute form for
    # both keeps the existence check and the actual send pointing at the same
    # place.
    serve_root = os.path.abspath(SERVE_DIR)

    # Check if the file exists in the directory
    file_path = os.path.join(serve_root, filename)
    logger.debug(f"Looking for file at: {file_path}")
    if not os.path.isfile(file_path):
        logger.error(f"File not found: {file_path}")
        abort(404, "File not found.")

    # Serve the file
    return send_from_directory(serve_root, filename)
284
+
285
# Error handlers
@app.errorhandler(400)
def bad_request(error):
    """Turn a 400 error into a JSON body carrying the error's text."""
    body = {"error": "Bad Request", "message": str(error)}
    return body, 400
290
+
291
@app.errorhandler(404)
def not_found(error):
    """Turn a 404 error into a JSON body carrying the error's text."""
    body = {"error": "Not Found", "message": str(error)}
    return body, 404
295
+
296
@app.errorhandler(500)
def internal_error(error):
    """Answer a 500 error with a fixed JSON body; the error's own text is not exposed."""
    body = {"error": "Internal Server Error", "message": "An unexpected error occurred."}
    return body, 500
300
+
301
# When executed as a script, start Flask's built-in server on all interfaces
# at port 7860.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
commit ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
#!/bin/sh
# Stage everything, commit using all arguments as the message, then push.
# Usage: commit <commit message words...>
#
# The steps are chained with && so a failed step (for example an empty commit
# message or nothing to commit) stops the sequence instead of pushing anyway.
git add . && git commit -m "$*" && git push
kokoro.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import phonemizer
2
+ import re
3
+ import torch
4
+ import numpy as np
5
+
6
def split_num(num):
    """Rewrite a matched number as it should be read aloud (``re.sub`` callback).

    Args:
        num: A match object whose text is a decimal, a clock time ``H:MM``,
            or a four-digit year optionally followed by ``s``.

    Returns:
        str: Decimals are returned untouched; times become ``"H o'clock"``,
        ``"H oh M"`` or ``"H M"``; years are split into two two-digit halves
        (for example ``"19 99"``), except years below 1100 or whose last three
        digits are below 10, which are returned untouched.
    """
    token = num.group()

    if '.' in token:
        return token

    if ':' in token:
        hours, minutes = (int(part) for part in token.split(':'))
        if minutes == 0:
            return f"{hours} o'clock"
        if minutes < 10:
            return f'{hours} oh {minutes}'
        return f'{hours} {minutes}'

    year = int(token[:4])
    if year < 1100 or year % 1000 < 10:
        return token

    century = token[:2]
    remainder = int(token[2:4])
    plural = 's' if token.endswith('s') else ''
    # "hundred" / "oh" wording only applies when the last three digits are >= 100.
    if 100 <= year % 1000 <= 999:
        if remainder == 0:
            return f'{century} hundred{plural}'
        if remainder < 10:
            return f'{century} oh {remainder}{plural}'
    return f'{century} {remainder}{plural}'
28
+
29
def flip_money(m):
    """Move the currency word after the amount (``re.sub`` callback).

    Args:
        m: A match object whose text starts with ``$`` or ``£`` followed by an
            amount, optionally ending in a scale word such as ``million``.

    Returns:
        str: For example ``"5 dollars"``, ``"3 million dollars"`` or
        ``"1 pound and 50 pence"``.
    """
    matched = m.group()
    symbol, figure = matched[0], matched[1:]
    in_dollars = symbol == '$'
    unit = 'dollar' if in_dollars else 'pound'

    # Amounts ending in a word ("3 million") always take the plural unit.
    if matched[-1].isalpha():
        return f'{figure} {unit}s'

    if '.' not in matched:
        plural = '' if figure == '1' else 's'
        return f'{figure} {unit}{plural}'

    whole, fraction = figure.split('.')
    plural = '' if whole == '1' else 's'
    # A single fractional digit means tens of cents/pence ("2.5" -> 50).
    minor = int(fraction.ljust(2, '0'))
    if in_dollars:
        coins = 'cent' if minor == 1 else 'cents'
    else:
        coins = 'penny' if minor == 1 else 'pence'
    return f'{whole} {unit}{plural} and {minor} {coins}'
42
+
43
def point_num(num):
    """Spell a matched decimal as ``<integer> point <d> <d> ...`` (``re.sub`` callback).

    The digits after the dot are separated by single spaces so each one is
    read individually; the part before the dot is kept as is (it may be empty).
    """
    integer_part, fraction_part = num.group().split('.')
    spaced_digits = ' '.join(fraction_part)
    return f'{integer_part} point {spaced_digits}'
46
+
47
def normalize_text(text):
    """Rewrite raw text into a form that is easier to phonemize.

    The substitutions run in a fixed order and later rules rely on the output
    of earlier ones: quotes and brackets are unified, CJK punctuation is
    mapped to ASCII, whitespace is collapsed, a few titles/abbreviations are
    expanded, and numbers, times, years and money amounts are rewritten by
    ``split_num``, ``flip_money`` and ``point_num``.

    Args:
        text: The input string.

    Returns:
        str: The normalized text with surrounding whitespace stripped.
    """
    # Curly single quotes -> ASCII apostrophe.
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    # Guillemets and curly double quotes -> ASCII double quote.
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    # Parentheses are re-encoded as guillemets (freed up by the step above).
    text = text.replace('(', '«').replace(')', '»')
    # CJK / full-width punctuation -> ASCII punctuation plus a space.
    # The source characters are written as escapes so that they cannot be
    # confused with (or silently normalized to) their ASCII look-alikes:
    # mapping ASCII ',' or ':' here would insert a space into "1,000" and
    # "10:30" and defeat the number and time rules below.
    for a, b in zip('\u3001\u3002\uff01\uff0c\uff1a\uff1b\uff1f', ',.!,:;?'):
        text = text.replace(a, b+' ')
    # Any whitespace other than space/newline becomes a space; runs collapse.
    text = re.sub(r'[^\S \n]', ' ', text)
    text = re.sub(r' +', ' ', text)
    text = re.sub(r'(?<=\n) +(?=\n)', '', text)
    # Titles and common abbreviations.
    text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
    text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
    text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
    text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
    text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
    text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
    # Decimals, four-digit years and H:MM times.
    text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
    # Thousands separators.
    text = re.sub(r'(?<=\d),(?=\d)', '', text)
    # Currency amounts.
    text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
    # Remaining decimals are read digit by digit.
    text = re.sub(r'\d*\.\d+', point_num, text)
    # Numeric ranges ("1-4" -> "1 to 4").
    text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
    text = re.sub(r'(?<=\d)S', ' S', text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", 's', text)
    # Dotted initialisms: dots become hyphens.
    text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
    text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
    return text.strip()
74
+
75
def get_vocab():
    """Build the symbol -> integer id table used for tokenization.

    Ids follow the position of each symbol in the sequence: pad symbol,
    punctuation, ASCII letters, IPA symbols. A symbol that occurs more than
    once keeps the id of its last occurrence.

    Returns:
        dict: Mapping from single-character symbol to its id.
    """
    pad = "$"
    punctuation = ';:,.!?¡¿—…"«»“” '
    ascii_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    ipa_letters = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [pad, *punctuation, *ascii_letters, *ipa_letters]
    return {symbol: index for index, symbol in enumerate(symbols)}


VOCAB = get_vocab()


def tokenize(ps):
    """Convert a phoneme string into a list of vocabulary ids.

    Characters that are not in ``VOCAB`` are skipped.
    """
    ids = []
    for symbol in ps:
        index = VOCAB.get(symbol)
        if index is not None:
            ids.append(index)
    return ids
89
+
90
# One espeak backend per language code: 'a' -> en-us, 'b' -> en-gb.
phonemizers = {
    code: phonemizer.backend.EspeakBackend(language=language, preserve_punctuation=True, with_stress=True)
    for code, language in (('a', 'en-us'), ('b', 'en-gb'))
}


def phonemize(text, lang, norm=True):
    """Convert text into the phoneme string expected by ``tokenize``.

    Args:
        text: Input text.
        lang: Key into ``phonemizers`` ('a' or 'b').
        norm: When true, ``text`` is passed through ``normalize_text`` first.

    Returns:
        str: Phonemes with a few fixed substitutions applied, restricted to
        symbols present in ``VOCAB`` and stripped of surrounding whitespace.
    """
    source = normalize_text(text) if norm else text
    results = phonemizers[lang].phonemize([source])
    ps = results[0] if results else ''

    # Fixed replacements, applied in this order.
    # https://en.wiktionary.org/wiki/kokoro#English
    substitutions = (
        ('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ'),
        ('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ'),
        ('ʲ', 'j'),
        ('r', 'ɹ'),
        ('x', 'k'),
        ('ɬ', 'l'),
    )
    for old, new in substitutions:
        ps = ps.replace(old, new)

    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)

    # Drop every symbol the vocabulary does not know.
    ps = ''.join(symbol for symbol in ps if symbol in VOCAB)
    return ps.strip()
108
+
109
def length_to_mask(lengths):
    """Build a boolean mask that is True at positions beyond each length.

    Args:
        lengths: 1-D integer tensor, one length per row.

    Returns:
        Tensor of shape ``(len(lengths), lengths.max())`` whose entry
        ``[i, j]`` is True when ``j + 1 > lengths[i]``.
    """
    row_count = lengths.shape[0]
    positions = torch.arange(lengths.max())
    positions = positions.unsqueeze(0).expand(row_count, -1).type_as(lengths)
    return torch.gt(positions + 1, lengths.unsqueeze(1))
113
+
114
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Run one synthesis pass of ``model`` and return the decoder output as a NumPy array.

    Args:
        model: Object exposing ``bert``, ``bert_encoder``, ``predictor``,
            ``text_encoder`` and ``decoder`` callables.
        tokens: Sequence of token ids without padding; a 0 is added at both ends here.
        ref_s: Style tensor; columns ``128:`` are given to the predictor and
            columns ``:128`` to the decoder. Its device is used for all tensors.
            NOTE(review): presumably shape (1, 256) — confirm against the voice pack.
        speed: Divisor applied to the predicted durations.

    Returns:
        numpy.ndarray: ``model.decoder(...)`` output, squeezed and moved to CPU.
    """
    device = ref_s.device
    # Wrap the ids in pad tokens and add a batch dimension of 1.
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    # The attention mask is the inverse of text_mask (1 = real position).
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    # Per-token duration: sum of sigmoids over the last axis, scaled by speed.
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    # Every token lasts at least one frame.
    pred_dur = torch.round(duration).clamp(min=1).long()
    # Alignment matrix (tokens x frames): row i is 1 over the frames of token i.
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
        c_frame += pred_dur[0,i].item()
    # Expand token-level features to frame level through the alignment matrix.
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
138
+
139
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    """Synthesize ``text`` in a single pass, truncating to 510 tokens.

    Args:
        model: Model object handed to ``forward``.
        text: Text to synthesize; ignored when ``ps`` is given.
        voicepack: Indexable collection of style tensors, indexed by token count.
        lang: Language key for ``phonemize``.
        speed: Speed factor handed to ``forward``.
        ps: Optional ready-made phoneme string.

    Returns:
        ``None`` when no tokens are produced; otherwise a tuple of the audio
        returned by ``forward`` and the phoneme string rebuilt from the tokens
        actually used.
    """
    phonemes = ps or phonemize(text, lang)
    tokens = tokenize(phonemes)
    if not tokens:
        return None

    if len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')

    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)

    # Rebuild the phoneme string from the (possibly truncated) token ids.
    symbol_by_id = {index: symbol for symbol, index in VOCAB.items()}
    ps = ''.join(symbol_by_id[token] for token in tokens)
    return out, ps
151
+
152
def generate_full(model, text, voicepack, lang='a', speed=1, ps=None):
    """Synthesize ``text`` of any length by processing 510-token chunks.

    Args:
        model: Model object handed to ``forward``.
        text: Text to synthesize; ignored when ``ps`` is given.
        voicepack: Indexable collection of style tensors, indexed by chunk length.
        lang: Language key for ``phonemize``.
        speed: Speed factor handed to ``forward``.
        ps: Optional ready-made phoneme string.

    Returns:
        ``None`` when no tokens are produced; otherwise a tuple of the
        concatenated chunk outputs and the phoneme string rebuilt from all tokens.
    """
    phonemes = ps or phonemize(text, lang)
    tokens = tokenize(phonemes)
    if not tokens:
        return None

    chunk_size = 510
    pieces = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        ref_s = voicepack[len(chunk)]
        pieces.append(forward(model, chunk, ref_s, speed))
    outs = np.concatenate(pieces)

    symbol_by_id = {index: symbol for symbol, index in VOCAB.items()}
    ps = ''.join(symbol_by_id[token] for token in tokens)
    return outs, ps
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Python dependencies for the service in app.py.
#
# Index options apply to the whole requirements file and must sit on their own
# line; written after "torch" on the same line the option is not applied to
# that requirement. The PyTorch CPU wheel index is added as an extra index so
# that every other package still resolves from the default index.
--extra-index-url https://download.pytorch.org/whl/cpu

flask
flask-cors
transformers
librosa
numpy
soundfile
huggingface_hub
phonemizer
munch
werkzeug
num2words
dateparser
# tts_processor.py imports dateutil directly, so it is listed explicitly.
python-dateutil
inflect
ftfy
sentencepiece
torch
onnxruntime
tts_processor.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dateutil.parser import parse
3
+ from num2words import num2words
4
+ import inflect
5
+ from ftfy import fix_text
6
+
7
# Initialize the inflect engine
# NOTE(review): not referenced by any function visible in this module.
inflect_engine = inflect.engine()

# Define alphabet pronunciation mapping
# Maps each uppercase letter to a spelled-out pronunciation. Every value is
# padded with a space on both sides so that consecutive letters stay separated
# when replace_abbreviations() joins them.
alphabet_map = {
    "A": " Eh ", "B": " Bee ", "C": " See ", "D": " Dee ", "E": " Eee ",
    "F": " Eff ", "G": " Jee ", "H": " Aitch ", "I": " Eye ", "J": " Jay ",
    "K": " Kay ", "L": " El ", "M": " Emm ", "N": " Enn ", "O": " Ohh ",
    "P": " Pee ", "Q": " Queue ", "R": " Are ", "S": " Ess ", "T": " Tee ",
    "U": " You ", "V": " Vee ", "W": " Double You ", "X": " Ex ", "Y": " Why ", "Z": " Zed "
}
18
+
19
# Function to add ordinal suffix to a number
def add_ordinal_suffix(day):
    """Adds ordinal suffix to a number (e.g., 13 -> 13th, 22 -> 22nd).

    The "teens" rule is applied to the last two digits, so values such as
    111, 112 and 113 also receive "th".

    Args:
        day: Integer to decorate.

    Returns:
        str: The number followed by "st", "nd", "rd" or "th".
    """
    if 11 <= day % 100 <= 13:  # Special case for 11th, 12th, 13th (and 111th, ...)
        return f"{day}th"
    suffix = {1: "st", 2: "nd", 3: "rd"}.get(day % 10, "th")
    return f"{day}{suffix}"
32
+
33
# Function to format dates in a human-readable form
def format_date(parsed_date, include_time=True):
    """Formats a parsed date into a human-readable string.

    Args:
        parsed_date: A ``datetime``-like object, or a falsy value.
        include_time: Whether a time of day may be appended.

    Returns:
        ``None`` for a falsy ``parsed_date``; otherwise e.g.
        ``"January 13th, 2024"`` or, when ``include_time`` is true and the
        time is not exactly midnight, ``"January 13th, 2024 at 2:00 PM"``.
    """
    if not parsed_date:
        return None

    # Convert the day into an ordinal (e.g., 13 -> 13th)
    day = add_ordinal_suffix(parsed_date.day)
    date_part = parsed_date.strftime(f"%B {day}, %Y")

    # A midnight time is treated as "no time given". Either a non-zero hour
    # or a non-zero minute is enough to keep the time (14:00 and 00:30 both do).
    has_time = parsed_date.hour != 0 or parsed_date.minute != 0
    if include_time and has_time:
        # 12-hour clock without zero padding, computed here because the
        # "%-I" strftime directive is not available on every platform.
        hour_12 = parsed_date.hour % 12 or 12
        return f"{date_part} at {hour_12}:{parsed_date.strftime('%M %p')}"
    return date_part  # Only date
46
+
47
# Normalize dates in the text
def normalize_dates(text):
    """
    Replace recognizable date strings in ``text`` with the output of ``format_date``.

    Recognized shapes: ``YYYY-MM-DD`` (optionally followed by a space or ``T``
    and ``HH:MM:SS``), ``DD/MM/YYYY``-style slashed dates, and
    ``<day> <word> <year>``. A match that cannot be parsed is left unchanged.
    """
    # Match common date formats
    date_pattern = r"\b(\d{4}-\d{2}-\d{2}(?:[ T]\d{2}:\d{2}:\d{2})?|\d{2}/\d{2}/\d{4}|\d{1,2} \w+ \d{4})\b"

    def replace_date(match):
        candidate = match.group(0)
        try:
            moment = parse(candidate)
            if moment:
                # Include time only if explicitly provided
                wants_time = "T" in candidate or " " in candidate
                return format_date(moment, wants_time)
        except ValueError:
            pass
        return candidate

    return re.sub(date_pattern, replace_date, text)
66
+
67
# Replace invalid characters and clean text
def replace_invalid_chars(string):
    """Repair the text with ``fix_text`` and apply a fixed list of literal substitutions.

    The substitutions run in the listed order on the repaired text.
    """
    cleaned = fix_text(string)
    substitutions = (
        ("**", ""),
        ('&#x27;', "'"),
        ('AI;', 'Artificial Intelligence!'),
        ('iddqd;', 'Immortality cheat code'),
        ('😉;', 'wink wink!'),
        (':D', '*laughs* Ahahaha!'),
        (';D', '*laughs* Ahahaha!'),
    )
    for needle, replacement in substitutions:
        cleaned = cleaned.replace(needle, replacement)
    return cleaned
82
+
83
# Replace numbers with their word equivalents
def replace_numbers(string):
    """Spell out standalone integers in ``string`` using ``num2words``.

    If the string contains anything that looks like an IPv4/IPv6 address, a
    numeric range, an ISO date or a mixed letter/digit token, it is returned
    unchanged.

    NOTE(review): the guard applies to the whole string, so a single such
    token prevents every other number in the string from being converted.
    """
    guard_patterns = (
        r'(\b\d{1,3}(\.\d{1,3}){3}\b)',                      # IPv4
        r'([0-9a-fA-F]{1,4}:){2,7}[0-9a-fA-F]{1,4}',         # IPv6
        r'\b\d+-\d+\b',                                      # Detect ranges like 1-4
        r'\b\d{4}-\d{2}-\d{2}(?:T\d{2}:\d{2}:\d{2})?\b',     # ISO date
        r'\b[A-Za-z]+\d+|\d+[A-Za-z]+\b',                    # alphanumerics
    )
    # Do not process IP addresses, date patterns, or alphanumerics
    if any(re.search(pattern, string) for pattern in guard_patterns):
        return string

    # Convert standalone numbers and port numbers
    def convert_number(match):
        number = match.group()
        if number.isdigit():
            return num2words(int(number))
        return number

    standalone_number = re.compile(r'\b\d+\b')
    return standalone_number.sub(convert_number, string)
102
+
103
# Replace abbreviations with expanded form
def replace_abbreviations(string):
    """Spell out short all-uppercase words letter by letter via ``alphabet_map``.

    A whitespace-separated word is expanded when it is uppercase, at most four
    characters long, contains no digit and is not "ID", "AM" or "PM".
    Characters missing from ``alphabet_map`` are kept as they are. The result
    is re-joined with single spaces.
    """
    keep_as_is = ["ID", "AM", "PM"]
    expanded = []
    for word in string.split():
        should_spell = (
            word.isupper()
            and len(word) <= 4
            and not any(char.isdigit() for char in word)
            and word not in keep_as_is
        )
        if should_spell:
            word = ''.join(alphabet_map.get(char, char) for char in word)
        expanded.append(word)
    return ' '.join(expanded)
110
+
111
# Clean up whitespace in the text
def clean_whitespace(string):
    """Remove whitespace before ``. , ? !`` and collapse all other whitespace runs.

    Returns:
        str: The text with single spaces between words and no leading or
        trailing whitespace.
    """
    space_before_punctuation = re.compile(r'\s+([.,?!])')
    tightened = space_before_punctuation.sub(r'\1', string)
    words = tightened.split()
    return ' '.join(words)
115
+
116
# Main preprocessing pipeline
def preprocess_all(string):
    """Run every preprocessing step over ``string`` and return the result.

    Order: dates, invalid characters, numbers, abbreviations, whitespace.
    """
    steps = (
        normalize_dates,
        replace_invalid_chars,
        replace_numbers,
        replace_abbreviations,
        clean_whitespace,
    )
    for step in steps:
        string = step(string)
    return string
124
+
125
# Expose a testing function for external use
def test_preprocessing(file_path):
    """Print each line of a text file next to its preprocessed form.

    The file is read as UTF-8 rather than the platform default encoding, so
    non-ASCII sample text is decoded the same way everywhere, and it is
    streamed line by line instead of being loaded completely first.

    Args:
        file_path: Path of the text file to read.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            original = line.strip()
            processed = preprocess_all(original)
            print(f"Original: {original}")
            print(f"Processed: {processed}\n")
134
+
135
# Command-line entry point: run test_preprocessing() on the file named by the
# first argument, or print a usage hint when no argument is given.
if __name__ == "__main__":
    import sys
    if len(sys.argv) > 1:
        test_file = sys.argv[1]
        test_preprocessing(test_file)
    else:
        print("Please provide a file path as an argument.")
142
+