Update app.py
app.py
CHANGED
@@ -109,46 +109,89 @@ app.layout = dbc.Container([
     dcc.Store(id='generated-audio'),
 ])
 
-def process_prompt(
-    prompt = f"{voice}: {
+def process_prompt(prompt, voice, tokenizer, device):
+    prompt = f"{voice}: {prompt}"
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+    start_token = torch.tensor([[128259]], dtype=torch.int64)
+    end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
+
+    modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+    attention_mask = torch.ones_like(modified_input_ids)
+
+    return modified_input_ids.to(device), attention_mask.to(device)
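For orientation (an aside, not part of the commit): process_prompt frames the text with special-token IDs from the reserved range above the Llama 3 text vocabulary — 128259 as a start marker, then 128009 (which matches Llama 3's <|eot_id|>) and 128260 as end markers. A minimal usage sketch, assuming the tokenizer and device loaded elsewhere in app.py, with "tara" as a stand-in voice name:

# Hypothetical usage; `tokenizer` and `device` come from the app's model setup.
input_ids, attention_mask = process_prompt("Welcome back to the show.", "tara", tokenizer, device)
# Layout: [128259] + tokenizer("tara: Welcome back to the show.") + [128009, 128260]
print(input_ids.shape)              # (1, prompt_length + 3)
print(attention_mask.sum().item())  # == input_ids.shape[1]; every position is attended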
 
 def parse_output(generated_ids):
+    token_to_find = 128257
+    token_to_remove = 128258
+
+    token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
 
+    if len(token_indices[1]) > 0:
+        last_occurrence_idx = token_indices[1][-1].item()
+        cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
+    else:
+        cropped_tensor = generated_ids
 
+    processed_rows = []
+    for row in cropped_tensor:
+        masked_row = row[row != token_to_remove]
+        processed_rows.append(masked_row)
 
+    code_lists = []
+    for row in processed_rows:
+        row_length = row.size(0)
+        new_length = (row_length // 7) * 7
+        trimmed_row = row[:new_length]
+        trimmed_row = [t - 128266 for t in trimmed_row]
+        code_lists.append(trimmed_row)
+
+    return code_lists[0]
 
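To make the token arithmetic concrete (a constructed example, using only the constants above): everything after the last 128257 marker is treated as audio codes, any 128258 tokens are dropped, the tail is trimmed to whole 7-token frames, and the 128266 base offset is subtracted:

import torch

# One full 7-token frame after the 128257 marker, with a trailing 128258.
base = 128266
ids = torch.tensor([[500, 128257,
                     base, base + 4096, base + 2*4096, base + 3*4096,
                     base + 4*4096, base + 5*4096, base + 6*4096,
                     128258]])
codes = parse_output(ids)
# codes correspond to [0, 4096, 8192, 12288, 16384, 20480, 24576]:
# one de-offset frame, ready for redistribute_codes below.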
+def redistribute_codes(code_list, snac_model):
+    device = next(snac_model.parameters()).device  # Get the device of SNAC model
+
+    layer_1 = []
+    layer_2 = []
+    layer_3 = []
+    for i in range((len(code_list)+1)//7):
+        layer_1.append(code_list[7*i])
+        layer_2.append(code_list[7*i+1]-4096)
+        layer_3.append(code_list[7*i+2]-(2*4096))
+        layer_3.append(code_list[7*i+3]-(3*4096))
+        layer_2.append(code_list[7*i+4]-(4*4096))
+        layer_3.append(code_list[7*i+5]-(5*4096))
+        layer_3.append(code_list[7*i+6]-(6*4096))
+
+    codes = [
+        torch.tensor(layer_1, device=device).unsqueeze(0),
+        torch.tensor(layer_2, device=device).unsqueeze(0),
+        torch.tensor(layer_3, device=device).unsqueeze(0)
+    ]
+
+    audio_hat = snac_model.decode(codes)
+    return audio_hat.detach().squeeze().cpu().numpy()  # Always return CPU numpy array
+
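The loop above fans each 7-token frame out across SNAC's three codebook levels — one code to layer 1, two to layer 2, four to layer 3 — with each frame position carrying its own multiple-of-4096 offset. A self-contained sketch of just the de-offsetting, with the layout inferred from the loop:

# A frame whose seven tokens all decode to code 0 at their respective positions.
frame = [0, 4096, 2*4096, 3*4096, 4*4096, 5*4096, 6*4096]
layer_1 = [frame[0]]                                  # 1 coarse code per frame
layer_2 = [frame[1] - 4096, frame[4] - 4*4096]        # 2 mid codes per frame
layer_3 = [frame[2] - 2*4096, frame[3] - 3*4096,
           frame[5] - 5*4096, frame[6] - 6*4096]      # 4 fine codes per frame
assert (layer_1, layer_2, layer_3) == ([0], [0, 0], [0, 0, 0, 0])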
-    for i, silent in enumerate(is_silent):
-        if silent and silent_start is None:
-            silent_start = i
-        elif not silent and silent_start is not None:
-            if i - silent_start >= min_silence_len:
-                silent_regions.append((silent_start, i))
-                silent_start = None
-    if silent_start is not None and len(audio) - silent_start >= min_silence_len:
-        silent_regions.append((silent_start, len(audio)))
-    return silent_regions
+def detect_silence(audio, threshold=0.005, min_silence_duration=1.3):
+    sample_rate = 24000  # Adjust if your sample rate is different
+    is_silent = np.abs(audio) < threshold
+    silent_regions = np.where(is_silent)[0]
+
+    silence_starts = []
+    silence_ends = []
+
+    if len(silent_regions) > 0:
+        silence_starts.append(silent_regions[0])
+        for i in range(1, len(silent_regions)):
+            if silent_regions[i] - silent_regions[i-1] > 1:
+                silence_ends.append(silent_regions[i-1])
+                silence_starts.append(silent_regions[i])
+        silence_ends.append(silent_regions[-1])
+
+    long_silences = [(start, end) for start, end in zip(silence_starts, silence_ends)
+                     if (end - start) / sample_rate >= min_silence_duration]
+
+    return long_silences
 
def generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens):
|
197 |
try:
|
|
|
@@ -174,59 +217,26 @@ def generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p,
                 max_new_tokens=max_new_tokens,
                 num_return_sequences=1,
                 eos_token_id=128258,
-                pad_token_id=128258,
             )
 
             code_list = parse_output(generated_ids)
 
-            expected_size = 2048  # This should match the model's expected input size
-            if len(code_list) < expected_size:
-                code_list = code_list + [0] * (expected_size - len(code_list))
-            elif len(code_list) > expected_size:
-                code_list = code_list[:expected_size]
-
-            # Convert to float tensor to match bias type
-            codes_tensor = torch.tensor(code_list, dtype=torch.float32).unsqueeze(0).to(device)
-
-            #
-            paragraph_audio = snac_model(codes_tensor)
-
-            # Handle tuple output
-            if isinstance(paragraph_audio, tuple):
-                paragraph_audio = paragraph_audio[0]  # Assume the first element is the audio tensor
-
-            paragraph_audio = paragraph_audio.cpu().numpy().flatten()
-
-            # Log audio statistics
-            logger.info(f"Paragraph {i+1} audio shape: {paragraph_audio.shape}, min: {np.min(paragraph_audio)}, max: {np.max(paragraph_audio)}")
-
-            # Normalize audio to [-1, 1] range
-            paragraph_audio = paragraph_audio / np.max(np.abs(paragraph_audio))
+            paragraph_audio = redistribute_codes(code_list, snac_model)
 
+            # Add silence detection here
+            silences = detect_silence(paragraph_audio)
+            if silences:
+                # Trim the audio at the last detected silence
+                paragraph_audio = paragraph_audio[:silences[-1][1]]
 
             audio_samples.append(paragraph_audio)
 
         final_audio = np.concatenate(audio_samples)
 
-        #
-        # Convert to 16-bit PCM
-        final_audio = (final_audio * 32767).astype(np.int16)
-
-        buffer = io.BytesIO()
-        sf.write(buffer, final_audio, 24000, format='WAV', subtype='PCM_16')
-        buffer.seek(0)
-
-        # Log buffer size
-        logger.info(f"Audio buffer size: {buffer.getbuffer().nbytes} bytes")
-
-        return buffer
+        # Normalize the audio
+        final_audio = np.int16(final_audio / np.max(np.abs(final_audio)) * 32767)
 
+        return final_audio
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
         return None
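One caveat on the new normalization line: np.max(np.abs(final_audio)) is zero when every paragraph decodes to silence, so the division emits a RuntimeWarning and fills the buffer with NaNs before the int16 cast. A guarded variant (a sketch, not what the commit does; to_pcm16 is a hypothetical helper):

import numpy as np

def to_pcm16(audio: np.ndarray) -> np.ndarray:
    # Scale to full-scale 16-bit PCM; fall back to digital silence for an all-zero signal.
    peak = np.max(np.abs(audio))
    if peak == 0:
        return np.zeros(len(audio), dtype=np.int16)
    return np.int16(audio / peak * 32767)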
@@ -346,11 +356,16 @@ def combined_callback(generate_script_clicks, generate_audio_clicks, advanced_se
         if not script_output.strip():
             return dash.no_update, html.Div("No audio generated yet."), dash.no_update, dash.no_update, "", ""
 
-        if
+        final_audio = generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens)
 
+        if final_audio is not None:
+            # Convert to WAV format
+            buffer = io.BytesIO()
+            sf.write(buffer, final_audio, 24000, format='WAV', subtype='PCM_16')
+            buffer.seek(0)
+
             # Convert to base64 for audio playback
-            audio_base64 = base64.b64encode(
+            audio_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
             src = f"data:audio/wav;base64,{audio_base64}"
 
             # Log audio file size
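The playback path inlines the WAV as a data URI instead of serving a file. A self-contained sketch of the same round trip (assuming 24 kHz mono int16, using the same soundfile and base64 calls as the diff):

import base64
import io

import numpy as np
import soundfile as sf

samples = np.zeros(24000, dtype=np.int16)  # 1 s of silence at 24 kHz
buffer = io.BytesIO()
sf.write(buffer, samples, 24000, format='WAV', subtype='PCM_16')
buffer.seek(0)
src = f"data:audio/wav;base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
# `src` can be assigned directly to the src property of a dash html.Audio component.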
@@ -367,14 +382,7 @@ def combined_callback(generate_script_clicks, generate_audio_clicks, advanced_se
         else:
             logger.error("Failed to generate audio")
             return dash.no_update, html.Div("Error generating audio"), dash.no_update, dash.no_update, "", ""
-
-    elif trigger_id == "advanced-settings-toggle":
-        return dash.no_update, dash.no_update, not is_advanced_open, dash.no_update, "", ""
-
-    elif trigger_id == "clear-btn":
-        return "", html.Div("No audio generated yet."), dash.no_update, "", "", ""
-
-    return dash.no_update, dash.no_update, dash.no_update, dash.no_update, "", ""
+    return dash.no_update, dash.no_update, dash.no_update, dash.no_update, "", ""
 
 # Run the app
 if __name__ == '__main__':
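The change above removes the advanced-settings and clear-button branches and ends the callback with a single fall-through return for unmatched triggers. For readers unfamiliar with the idiom (generic Dash, not code from this app): the trigger_id such branches test is usually derived like this:

import dash

ctx = dash.callback_context
trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] if ctx.triggered else None
# e.g. "generate-audio-btn" when that button fired the combined callback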