# app.py: Speech-to-Hindi Gradio demo.
# Author: Kartheesh. Commit d27f30a ("Create app.py", verified); original size 2.15 kB.
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from transformers import VitsModel, MBartForConditionalGeneration, AutoTokenizer, pipeline
# Load the models and tokenizers once at startup; process_audio reuses them
# for every request.
ASR_MODEL_ID = "openai/whisper-base.en"  # speech recognition (English input)
TRANSLATION_MODEL_ID = "facebook/mbart-large-50-one-to-many-mmt"  # English -> Hindi translation
TTS_MODEL_ID = "facebook/mms-tts-hin"  # Hindi text-to-speech (VITS)

transcriber = pipeline("automatic-speech-recognition", model=ASR_MODEL_ID)
# The transcription fed to this tokenizer is English, so state the source
# language explicitly instead of relying on the tokenizer's default.
translation_tokenizer = AutoTokenizer.from_pretrained(
    TRANSLATION_MODEL_ID, src_lang="en_XX", use_fast=False
)
translation_model = MBartForConditionalGeneration.from_pretrained(TRANSLATION_MODEL_ID)
tts_tokenizer = AutoTokenizer.from_pretrained(TTS_MODEL_ID)
tts_model = VitsModel.from_pretrained(TTS_MODEL_ID)
def process_audio(audio):
    """Transcribe English speech, translate it to Hindi, and synthesize Hindi speech.

    Args:
        audio: The value produced by ``gr.Audio(type="numpy")``: a
            ``(sample_rate, samples)`` tuple, or ``None`` when nothing was
            recorded or uploaded.

    Returns:
        Path to a newly created WAV file containing the Hindi speech.

    Raises:
        gr.Error: If no usable audio was provided, no speech was recognized,
            or speech synthesis failed. Gradio shows the message in the UI.
    """
    # The interface has a single audio output, so failures are reported with
    # gr.Error rather than returned as an extra value.
    if audio is None:
        raise gr.Error("No audio provided.")
    sr, y = audio
    y = y.astype(np.float32)
    if y.ndim > 1:
        # Multi-channel input arrives as (samples, channels); the ASR pipeline
        # expects a single channel, so downmix by averaging.
        y = y.mean(axis=1)
    if y.size == 0:
        raise gr.Error("No audio provided.")
    peak = np.max(np.abs(y))
    if peak > 0:
        # Peak-normalize to [-1, 1]; skipped for pure silence to avoid 0/0 -> NaN.
        y = y / peak

    # Transcribe the audio
    transcription = transcriber({"sampling_rate": sr, "raw": y})["text"].strip()
    if not transcription:
        raise gr.Error("No speech was recognized in the recording.")

    # Translate from English to Hindi. The Hindi language code is forced as the
    # first generated token, which selects the target language for mBART-50.
    model_inputs = translation_tokenizer(
        transcription, return_tensors="pt", padding=True, truncation=True
    )
    generated_tokens = translation_model.generate(
        **model_inputs,
        forced_bos_token_id=translation_tokenizer.convert_tokens_to_ids("hi_IN"),
    )
    translated_text = translation_tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )[0]

    # Generate Hindi speech from translated text
    tts_inputs = tts_tokenizer(translated_text, return_tensors="pt")
    try:
        with torch.no_grad():
            tts_output = tts_model(**tts_inputs)
    except RuntimeError as e:
        raise gr.Error(f"Runtime Error: {e}") from e
    waveform = tts_output.waveform.squeeze().cpu().numpy()

    # Write at the model's own output rate; a hard-coded rate that differs
    # from it plays the speech back at the wrong speed and pitch. A unique temp
    # file keeps concurrent requests from overwriting each other's output.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        audio_path = tmp.name
    sf.write(audio_path, waveform, tts_model.config.sampling_rate)
    return audio_path
# Create the Gradio interface: one audio input (record or upload) and one
# audio output holding the synthesized Hindi speech.
demo = gr.Interface(
    fn=process_audio,
    # Accept uploads as well as microphone recordings, matching the description.
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs="audio",
    title="Speech-to-Hindi",
    description="Record your speech or upload an audio file to transcribe, translate to Hindi, and convert to speech.",
)

# Launch the Gradio app only when run as a script, so importing this module
# does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)