import gradio as gr
import torch
import numpy as np
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, MarianTokenizer, MarianMTModel
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Hindi.
The demo uses OpenAI's [Whisper Base](https://huggingface.co/openai/whisper-base) model to translate the source speech to English text,
then the MarianMT [opus-mt-en-hi](https://huggingface.co/Helsinki-NLP/opus-mt-en-hi) model to translate the English text to Hindi,
and finally Microsoft's SpeechT5, [fine-tuned for Hindi on the IndicTTS dataset](https://huggingface.co/navodit17/speecht5_finetuned_indic_tts_hi), for text-to-speech:
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
### NOTE: The goal is not to generate perfect Hindi speech or translation, but to demonstrate the cascaded STST approach using multiple models.
The pipeline may give poor results for very short inputs (one or two words); try sending longer audio in that case.
---
"""
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
# load speech translation checkpoint
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
# load text-to-speech checkpoint and speaker embeddings
processor = SpeechT5Processor.from_pretrained("navodit17/speecht5_finetuned_indic_tts_hi")
model = SpeechT5ForTextToSpeech.from_pretrained("navodit17/speecht5_finetuned_indic_tts_hi").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
# load english to hindi translation checkpoint
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
model_en_hi = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-hi")
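# Whisper's basic text normalizer (lowercasing, punctuation stripping), applied to the transliterated text before TTS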
normalizer = BasicTextNormalizer()
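# translate English text to Hindi with the MarianMT model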
def translate_en_hi(text):
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model_en_hi.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
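# Whisper transcribes and translates the source speech to English; MarianMT then translates it to Hindi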
def translate(audio):
    outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "translate"})
    print(f"Translated text - English: {outputs['text']}")
    translated_text = translate_en_hi(outputs["text"])
    print(f"Translated text - Hindi: {translated_text}")
    return translated_text
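# transliterate the Hindi text from Devanagari to ITRANS (Latin script) before synthesis, since the base SpeechT5 tokenizer only covers a limited Latin character set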
def synthesise(text):
    text = normalizer(transliterate(text, sanscript.DEVANAGARI, sanscript.ITRANS))
    print(f"Normalized Text: {text}")
    inputs = processor(text=text, return_tensors="pt")
    print(f"Inputs: {inputs['input_ids'].shape}")
    speech = model.generate_speech(input_ids=inputs["input_ids"].to(device), speaker_embeddings=speaker_embeddings.to(device), vocoder=vocoder)
    return speech.cpu()
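# full cascade: source speech -> English text -> Hindi text -> Hindi speech (16 kHz)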
def speech_to_speech_translation(audio):
    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)
    print(f"Generated speech shape: {synthesised_speech.shape}")
    # scale the float waveform in [-1, 1] to 16-bit PCM for the Gradio numpy audio output
    synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
    return 16000, synthesised_speech
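# Gradio UI: one tab for uploaded audio files, one for microphone input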
demo = gr.Blocks()
file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)
mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)
with demo:
    gr.TabbedInterface([file_translate, mic_translate], ["Audio File", "Microphone"])
demo.launch()