bluenevus committed on
Commit
9e8b4c0
·
verified ·
1 Parent(s): 1e4432d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -87
app.py CHANGED
@@ -109,46 +109,89 @@ app.layout = dbc.Container([
109
  dcc.Store(id='generated-audio'),
110
  ])
111
 
112
- def process_prompt(text, voice, tokenizer, device):
113
- prompt = f"{voice}: {text}"
114
- inputs = tokenizer(prompt, return_tensors="pt")
115
- input_ids = inputs["input_ids"].to(device)
116
- attention_mask = inputs["attention_mask"].to(device)
117
- return input_ids, attention_mask
 
 
 
 
 
118
 
119
  def parse_output(generated_ids):
120
- decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
121
- code_list = [int(code) for code in decoded.split() if code.isdigit()]
122
- return code_list
 
123
 
124
- def redistribute_codes(code_list, snac_model):
125
- audio = snac_model.codes_to_audio(torch.tensor(code_list).unsqueeze(0).to(device))
126
- return audio.cpu().numpy().flatten()
 
 
127
 
128
- def detect_silence(audio, threshold=0.01, min_silence_len=1000):
129
- is_silent = np.abs(audio) < threshold
130
- silent_regions = []
131
- silent_start = None
132
- for i, silent in enumerate(is_silent):
133
- if silent and silent_start is None:
134
- silent_start = i
135
- elif not silent and silent_start is not None:
136
- if i - silent_start >= min_silence_len:
137
- silent_regions.append((silent_start, i))
138
- silent_start = None
139
- if silent_start is not None and len(audio) - silent_start >= min_silence_len:
140
- silent_regions.append((silent_start, len(audio)))
141
- return silent_regions
142
 
143
- import logging
144
- import numpy as np
145
- import torch
146
- import soundfile as sf
147
- import io
148
- from tqdm import tqdm
 
 
 
149
 
150
- logging.basicConfig(level=logging.INFO)
151
- logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  def generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens):
154
  try:
@@ -174,59 +217,26 @@ def generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p,
174
  max_new_tokens=max_new_tokens,
175
  num_return_sequences=1,
176
  eos_token_id=128258,
177
- pad_token_id=128258,
178
  )
179
 
180
  code_list = parse_output(generated_ids)
181
 
182
- # Ensure the code list matches the expected input size of the SNAC model
183
- expected_size = 2048 # This should match the model's expected input size
184
- if len(code_list) < expected_size:
185
- code_list = code_list + [0] * (expected_size - len(code_list))
186
- elif len(code_list) > expected_size:
187
- code_list = code_list[:expected_size]
188
-
189
- # Convert to float tensor to match bias type
190
- codes_tensor = torch.tensor(code_list, dtype=torch.float32).unsqueeze(0).to(device)
191
 
192
- # Reshape the tensor to match the expected input shape
193
- codes_tensor = codes_tensor.view(1, -1, 2048) # Adjust these dimensions as needed
194
-
195
- # Generate audio
196
- with torch.no_grad():
197
- paragraph_audio = snac_model(codes_tensor)
198
-
199
- # Handle tuple output
200
- if isinstance(paragraph_audio, tuple):
201
- paragraph_audio = paragraph_audio[0] # Assume the first element is the audio tensor
202
-
203
- paragraph_audio = paragraph_audio.cpu().numpy().flatten()
204
-
205
- # Log audio statistics
206
- logger.info(f"Paragraph {i+1} audio shape: {paragraph_audio.shape}, min: {np.min(paragraph_audio)}, max: {np.max(paragraph_audio)}")
207
-
208
- # Normalize audio to [-1, 1] range
209
- paragraph_audio = paragraph_audio / np.max(np.abs(paragraph_audio))
210
 
211
  audio_samples.append(paragraph_audio)
212
 
213
  final_audio = np.concatenate(audio_samples)
214
 
215
- # Log final audio statistics
216
- logger.info(f"Final audio shape: {final_audio.shape}, min: {np.min(final_audio)}, max: {np.max(final_audio)}")
217
-
218
- # Convert to 16-bit PCM
219
- final_audio = (final_audio * 32767).astype(np.int16)
220
 
221
- # Save as WAV file in memory
222
- buffer = io.BytesIO()
223
- sf.write(buffer, final_audio, 24000, format='WAV', subtype='PCM_16')
224
- buffer.seek(0)
225
-
226
- # Log buffer size
227
- logger.info(f"Audio buffer size: {buffer.getbuffer().nbytes} bytes")
228
-
229
- return buffer
230
  except Exception as e:
231
  logger.error(f"Error generating speech: {str(e)}")
232
  return None
@@ -346,11 +356,16 @@ def combined_callback(generate_script_clicks, generate_audio_clicks, advanced_se
346
  if not script_output.strip():
347
  return dash.no_update, html.Div("No audio generated yet."), dash.no_update, dash.no_update, "", ""
348
 
349
- audio_buffer = generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens)
350
 
351
- if audio_buffer is not None:
 
 
 
 
 
352
  # Convert to base64 for audio playback
353
- audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')
354
  src = f"data:audio/wav;base64,{audio_base64}"
355
 
356
  # Log audio file size
@@ -367,14 +382,7 @@ def combined_callback(generate_script_clicks, generate_audio_clicks, advanced_se
367
  else:
368
  logger.error("Failed to generate audio")
369
  return dash.no_update, html.Div("Error generating audio"), dash.no_update, dash.no_update, "", ""
370
-
371
- elif trigger_id == "advanced-settings-toggle":
372
- return dash.no_update, dash.no_update, not is_advanced_open, dash.no_update, "", ""
373
-
374
- elif trigger_id == "clear-btn":
375
- return "", html.Div("No audio generated yet."), dash.no_update, "", "", ""
376
-
377
- return dash.no_update, dash.no_update, dash.no_update, dash.no_update, "", ""
378
 
379
  # Run the app
380
  if __name__ == '__main__':
 
109
  dcc.Store(id='generated-audio'),
110
  ])
111
 
112
+ def process_prompt(prompt, voice, tokenizer, device):
113
+ prompt = f"{voice}: {prompt}"
114
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
115
+
116
+ start_token = torch.tensor([[128259]], dtype=torch.int64)
117
+ end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
118
+
119
+ modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
120
+ attention_mask = torch.ones_like(modified_input_ids)
121
+
122
+ return modified_input_ids.to(device), attention_mask.to(device)
123
 
124
  def parse_output(generated_ids):
125
+ token_to_find = 128257
126
+ token_to_remove = 128258
127
+
128
+ token_indices = (generated_ids == token_to_find).nonzero(as_tuple=True)
129
 
130
+ if len(token_indices[1]) > 0:
131
+ last_occurrence_idx = token_indices[1][-1].item()
132
+ cropped_tensor = generated_ids[:, last_occurrence_idx+1:]
133
+ else:
134
+ cropped_tensor = generated_ids
135
 
136
+ processed_rows = []
137
+ for row in cropped_tensor:
138
+ masked_row = row[row != token_to_remove]
139
+ processed_rows.append(masked_row)
 
 
 
 
 
 
 
 
 
 
140
 
141
+ code_lists = []
142
+ for row in processed_rows:
143
+ row_length = row.size(0)
144
+ new_length = (row_length // 7) * 7
145
+ trimmed_row = row[:new_length]
146
+ trimmed_row = [t - 128266 for t in trimmed_row]
147
+ code_lists.append(trimmed_row)
148
+
149
+ return code_lists[0]
150
 
151
+ def redistribute_codes(code_list, snac_model):
152
+ device = next(snac_model.parameters()).device # Get the device of SNAC model
153
+
154
+ layer_1 = []
155
+ layer_2 = []
156
+ layer_3 = []
157
+ for i in range((len(code_list)+1)//7):
158
+ layer_1.append(code_list[7*i])
159
+ layer_2.append(code_list[7*i+1]-4096)
160
+ layer_3.append(code_list[7*i+2]-(2*4096))
161
+ layer_3.append(code_list[7*i+3]-(3*4096))
162
+ layer_2.append(code_list[7*i+4]-(4*4096))
163
+ layer_3.append(code_list[7*i+5]-(5*4096))
164
+ layer_3.append(code_list[7*i+6]-(6*4096))
165
+
166
+ codes = [
167
+ torch.tensor(layer_1, device=device).unsqueeze(0),
168
+ torch.tensor(layer_2, device=device).unsqueeze(0),
169
+ torch.tensor(layer_3, device=device).unsqueeze(0)
170
+ ]
171
+
172
+ audio_hat = snac_model.decode(codes)
173
+ return audio_hat.detach().squeeze().cpu().numpy() # Always return CPU numpy array
174
+
175
+ def detect_silence(audio, threshold=0.005, min_silence_duration=1.3):
176
+ sample_rate = 24000 # Adjust if your sample rate is different
177
+ is_silent = np.abs(audio) < threshold
178
+ silent_regions = np.where(is_silent)[0]
179
+
180
+ silence_starts = []
181
+ silence_ends = []
182
+
183
+ if len(silent_regions) > 0:
184
+ silence_starts.append(silent_regions[0])
185
+ for i in range(1, len(silent_regions)):
186
+ if silent_regions[i] - silent_regions[i-1] > 1:
187
+ silence_ends.append(silent_regions[i-1])
188
+ silence_starts.append(silent_regions[i])
189
+ silence_ends.append(silent_regions[-1])
190
+
191
+ long_silences = [(start, end) for start, end in zip(silence_starts, silence_ends)
192
+ if (end - start) / sample_rate >= min_silence_duration]
193
+
194
+ return long_silences
195
 
196
  def generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens):
197
  try:
 
217
  max_new_tokens=max_new_tokens,
218
  num_return_sequences=1,
219
  eos_token_id=128258,
 
220
  )
221
 
222
  code_list = parse_output(generated_ids)
223
 
224
+ paragraph_audio = redistribute_codes(code_list, snac_model)
 
 
 
 
 
 
 
 
225
 
226
+ # Add silence detection here
227
+ silences = detect_silence(paragraph_audio)
228
+ if silences:
229
+ # Trim the audio at the last detected silence
230
+ paragraph_audio = paragraph_audio[:silences[-1][1]]
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  audio_samples.append(paragraph_audio)
233
 
234
  final_audio = np.concatenate(audio_samples)
235
 
236
+ # Normalize the audio
237
+ final_audio = np.int16(final_audio / np.max(np.abs(final_audio)) * 32767)
 
 
 
238
 
239
+ return final_audio
 
 
 
 
 
 
 
 
240
  except Exception as e:
241
  logger.error(f"Error generating speech: {str(e)}")
242
  return None
 
356
  if not script_output.strip():
357
  return dash.no_update, html.Div("No audio generated yet."), dash.no_update, dash.no_update, "", ""
358
 
359
+ final_audio = generate_audio(script_output, voice1, voice2, num_hosts, temperature, top_p, repetition_penalty, max_new_tokens)
360
 
361
+ if final_audio is not None:
362
+ # Convert to WAV format
363
+ buffer = io.BytesIO()
364
+ sf.write(buffer, final_audio, 24000, format='WAV', subtype='PCM_16')
365
+ buffer.seek(0)
366
+
367
  # Convert to base64 for audio playback
368
+ audio_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
369
  src = f"data:audio/wav;base64,{audio_base64}"
370
 
371
  # Log audio file size
 
382
  else:
383
  logger.error("Failed to generate audio")
384
  return dash.no_update, html.Div("Error generating audio"), dash.no_update, dash.no_update, "", ""
385
+ return dash.no_update, dash.no_update, dash.no_update, dash.no_update, "", ""
 
 
 
 
 
 
 
386
 
387
  # Run the app
388
  if __name__ == '__main__':