Flux9665 committed on
Commit
5f0da2f
·
verified ·
1 Parent(s): 1d10354

overwrite some pitch values at the start and end to make it sound more lively

Browse files
Modules/ToucanTTS/InferenceToucanTTS.py CHANGED
@@ -219,32 +219,42 @@ class ToucanTTS(torch.nn.Module):
219
  encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
220
 
221
  # predicting pitch, energy and durations
222
- reduced_pitch_space = torchfunc.dropout(self.pitch_latent_reduction(encoded_texts), p=0.1).transpose(1, 2)
223
  pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space,
224
  mask=text_masks.float(),
225
- n_timesteps=10,
226
  temperature=prosody_creativity,
227
  c=utterance_embedding) if gold_pitch is None else gold_pitch
 
 
 
 
228
  pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
229
  embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2)
230
 
231
- reduced_energy_space = torchfunc.dropout(self.energy_latent_reduction(encoded_texts + embedded_pitch_curve), p=0.1).transpose(1, 2)
232
  energy_predictions = self.energy_predictor(mu=reduced_energy_space,
233
  mask=text_masks.float(),
234
- n_timesteps=10,
235
  temperature=prosody_creativity,
236
  c=utterance_embedding) if gold_energy is None else gold_energy
 
 
 
 
 
237
  energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
238
  embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2)
239
 
240
- reduced_duration_space = torchfunc.dropout(self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve), p=0.1).transpose(1, 2)
241
  predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space,
242
  mask=text_masks.float(),
243
- n_timesteps=10,
244
  temperature=prosody_creativity,
245
- c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations
246
 
247
  # modifying the predictions with control parameters
 
248
  for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
249
  if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
250
  predicted_durations[0][phoneme_index] = 0
@@ -267,8 +277,8 @@ class ToucanTTS(torch.nn.Module):
267
 
268
  refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2),
269
  mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2),
270
- n_timesteps=15,
271
- temperature=0.1, # low temperature, so the model follows the specified prosody curves better.
272
  c=None).transpose(1, 2)
273
 
274
  return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
@@ -326,19 +336,19 @@ class ToucanTTS(torch.nn.Module):
326
  lang_id = lang_id.to(text.device)
327
 
328
  outs, \
329
- predicted_durations, \
330
- pitch_predictions, \
331
- energy_predictions = self._forward(text.unsqueeze(0),
332
- text_length,
333
- gold_durations=durations,
334
- gold_pitch=pitch,
335
- gold_energy=energy,
336
- utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id,
337
- duration_scaling_factor=duration_scaling_factor,
338
- pitch_variance_scale=pitch_variance_scale,
339
- energy_variance_scale=energy_variance_scale,
340
- pause_duration_scaling_factor=pause_duration_scaling_factor,
341
- prosody_creativity=prosody_creativity)
342
 
343
  if return_duration_pitch_energy:
344
  return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions
 
219
  encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)
220
 
221
  # predicting pitch, energy and durations
222
+ reduced_pitch_space = self.pitch_latent_reduction(encoded_texts).transpose(1, 2)
223
  pitch_predictions = self.pitch_predictor(mu=reduced_pitch_space,
224
  mask=text_masks.float(),
225
+ n_timesteps=20,
226
  temperature=prosody_creativity,
227
  c=utterance_embedding) if gold_pitch is None else gold_pitch
228
+ # because of the way we are processing the data, the last few elements of a sequence will always receive an unnaturally low pitch value. To fix this, we just overwrite them here.
229
+ pitch_predictions[0][0][0] = pitch_predictions[0][0][1]
230
+ pitch_predictions[0][0][-1] = pitch_predictions[0][0][-3]
231
+ pitch_predictions[0][0][-2] = pitch_predictions[0][0][-3]
232
  pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
233
  embedded_pitch_curve = self.pitch_embed(pitch_predictions).transpose(1, 2)
234
 
235
+ reduced_energy_space = self.energy_latent_reduction(encoded_texts + embedded_pitch_curve).transpose(1, 2)
236
  energy_predictions = self.energy_predictor(mu=reduced_energy_space,
237
  mask=text_masks.float(),
238
+ n_timesteps=20,
239
  temperature=prosody_creativity,
240
  c=utterance_embedding) if gold_energy is None else gold_energy
241
+
242
+ # because of the way we are processing the data, the last few elements of a sequence will always receive an unnaturally low energy value. To fix this, we just overwrite them here.
243
+ energy_predictions[0][0][0] = energy_predictions[0][0][1]
244
+ energy_predictions[0][0][-1] = energy_predictions[0][0][-3]
245
+ energy_predictions[0][0][-2] = energy_predictions[0][0][-3]
246
  energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)
247
  embedded_energy_curve = self.energy_embed(energy_predictions).transpose(1, 2)
248
 
249
+ reduced_duration_space = self.duration_latent_reduction(encoded_texts + embedded_pitch_curve + embedded_energy_curve).transpose(1, 2)
250
  predicted_durations = torch.clamp(torch.ceil(self.duration_predictor(mu=reduced_duration_space,
251
  mask=text_masks.float(),
252
+ n_timesteps=20,
253
  temperature=prosody_creativity,
254
+ c=utterance_embedding)), min=0.0).long().squeeze(1) if gold_durations is None else gold_durations.squeeze(1)
255
 
256
  # modifying the predictions with control parameters
257
+ predicted_durations[0][0] = 1 # if the initial pause is too long, we get artifacts. This is once more a dirty hack.
258
  for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
259
  if phoneme_vector[get_feature_to_index_lookup()["word-boundary"]] == 1:
260
  predicted_durations[0][phoneme_index] = 0
 
277
 
278
  refined_codec_frames = self.flow_matching_decoder(mu=preliminary_spectrogram.transpose(1, 2),
279
  mask=make_non_pad_mask([len(decoded_speech[0])], device=decoded_speech.device).unsqueeze(-2),
280
+ n_timesteps=30,
281
+ temperature=0.2, # low temperature, so the model follows the specified prosody curves better.
282
  c=None).transpose(1, 2)
283
 
284
  return refined_codec_frames, predicted_durations.squeeze(), pitch_predictions.squeeze(), energy_predictions.squeeze()
 
336
  lang_id = lang_id.to(text.device)
337
 
338
  outs, \
339
+ predicted_durations, \
340
+ pitch_predictions, \
341
+ energy_predictions = self._forward(text.unsqueeze(0),
342
+ text_length,
343
+ gold_durations=durations,
344
+ gold_pitch=pitch,
345
+ gold_energy=energy,
346
+ utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None, lang_ids=lang_id,
347
+ duration_scaling_factor=duration_scaling_factor,
348
+ pitch_variance_scale=pitch_variance_scale,
349
+ energy_variance_scale=energy_variance_scale,
350
+ pause_duration_scaling_factor=pause_duration_scaling_factor,
351
+ prosody_creativity=prosody_creativity)
352
 
353
  if return_duration_pitch_energy:
354
  return outs.squeeze().transpose(0, 1), predicted_durations, pitch_predictions, energy_predictions