viditk's picture
Upload 134 files
d44849f verified
raw
history blame contribute delete
3.95 kB
import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig, AutoTokenizer
from IndicTransToolkit import IndicProcessor
import speech_recognition as sr
# Constants
# Number of sentences sent through the model per batch in batch_translate.
BATCH_SIZE = 4
# Run on the GPU when torch reports CUDA is available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Quantization mode handed to initialize_model_and_tokenizer:
# "4-bit", "8-bit", or None to load the model unquantized.
quantization = None
# ---- IndicTrans2 Model Initialization ----
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    """Load a seq2seq checkpoint and its tokenizer, optionally quantized.

    Args:
        ckpt_dir: Checkpoint identifier or path handed to ``from_pretrained``.
        quantization: ``"4-bit"`` or ``"8-bit"`` to load the weights through
            bitsandbytes; any other value (including ``None``) loads the
            model unquantized.

    Returns:
        A ``(tokenizer, model)`` tuple, with the model switched to eval mode.
    """
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        # Double quantization and the compute dtype are 4-bit-only settings;
        # BitsAndBytesConfig has no bnb_8bit_* counterparts, so only the
        # 8-bit switch itself is passed.
        qconfig = BitsAndBytesConfig(load_in_8bit=True)
    else:
        qconfig = None

    # NOTE: trust_remote_code=True runs Python code shipped with the
    # checkpoint; only point ckpt_dir at sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    # Only an unquantized model is moved and cast by hand; a quantized model
    # is left exactly as from_pretrained returned it.
    if qconfig is None:
        model = model.to(DEVICE)
        if DEVICE == "cuda":
            model.half()

    model.eval()
    return tokenizer, model
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    """Translate sentences from ``src_lang`` to ``tgt_lang`` in fixed-size batches.

    Args:
        input_sentences: Sequence of source sentences.
        src_lang: Source language tag passed to the processor.
        tgt_lang: Target language tag passed to the processor.
        model: Model whose ``generate`` method produces the output tokens.
        tokenizer: Tokenizer used to encode inputs and decode outputs.
        ip: Processor providing ``preprocess_batch`` and ``postprocess_batch``.

    Returns:
        A list of the post-processed translations, in input order.
    """
    # Beam-search settings shared by every batch.
    generation_options = {
        "use_cache": True,
        "min_length": 0,
        "max_length": 256,
        "num_beams": 5,
        "num_return_sequences": 1,
    }

    results = []
    for start in range(0, len(input_sentences), BATCH_SIZE):
        chunk = input_sentences[start : start + BATCH_SIZE]
        prepared = ip.preprocess_batch(chunk, src_lang=src_lang, tgt_lang=tgt_lang)

        model_inputs = tokenizer(
            prepared,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        with torch.no_grad():
            output_ids = model.generate(**model_inputs, **generation_options)

        # Decode with the tokenizer switched to its target-side mode.
        with tokenizer.as_target_tokenizer():
            decoded = tokenizer.batch_decode(
                output_ids.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        results += ip.postprocess_batch(decoded, lang=tgt_lang)

        # Drop this batch's tensors before the next one is allocated.
        del model_inputs
        torch.cuda.empty_cache()

    return results
# Initialize IndicTrans2
# NOTE(review): the "en_indic" prefix is misleading: the checkpoint name is
# the Indic-to-English model, and it is used below for Malayalam -> English.
en_indic_ckpt_dir = "ai4bharat/indictrans2-indic-en-1B"
# Loaded once at import time; the resulting tokenizer and model are reused by
# transcribe_and_translate on every call.
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)
# Pre/post-processor handed to batch_translate.
ip = IndicProcessor(inference=True)
# ---- Gradio Function ----
def transcribe_and_translate(audio):
    """Transcribe a Malayalam audio file and translate the text to English.

    Args:
        audio: Path to the audio file, or a falsy value (``None`` / ``""``)
            when no audio was supplied.

    Returns:
        A ``(malayalam_text, english_translation)`` tuple. On failure the
        first element carries a human-readable message and the second is
        an empty string.
    """
    # Gradio hands the callback None when the user submits without
    # recording or uploading anything; answer instead of crashing.
    if not audio:
        return "No audio provided", ""

    recognizer = sr.Recognizer()
    try:
        with sr.AudioFile(audio) as source:
            audio_data = recognizer.record(source)
    except ValueError as e:
        # sr.AudioFile raises ValueError for files it cannot parse
        # (it reads WAV, AIFF and FLAC only).
        return f"Could not read audio file: {e}", ""

    try:
        # Malayalam transcription using Google API
        malayalam_text = recognizer.recognize_google(audio_data, language="ml-IN")
    except sr.UnknownValueError:
        return "Could not understand audio", ""
    except sr.RequestError as e:
        return f"Google API Error: {e}", ""

    # Translation
    sentences = [malayalam_text]
    src_lang, tgt_lang = "mal_Mlym", "eng_Latn"
    translations = batch_translate(
        sentences, src_lang, tgt_lang, en_indic_model, en_indic_tokenizer, ip
    )
    return malayalam_text, translations[0]
# ---- Gradio Interface ----
# Input widget: microphone recording or file upload, delivered to the
# callback as a path on disk.
_audio_input = gr.Audio(sources=["microphone", "upload"], type="filepath")

# Output widgets, in the order the callback returns its two strings.
_output_boxes = [
    gr.Textbox(label="Malayalam Transcription"),
    gr.Textbox(label="English Translation"),
]

iface = gr.Interface(
    fn=transcribe_and_translate,
    inputs=_audio_input,
    outputs=_output_boxes,
    title="Malayalam Speech Recognition & Translation",
    description="Speak in Malayalam → Transcribe using Google Speech Recognition → Translate to English using IndicTrans2.",
)

# Start the web app with debug output and a public share link.
iface.launch(debug=True, share=True)