# app.py — AI Language Translator (Hugging Face Space by DHEIVER)
# Exported from the Hub file viewer: "Update app.py", commit 6db90a4 (verified), 7.95 kB.
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, SeamlessM4Tv2Model
class TranslationModel:
    """Wrapper around Meta's SeamlessM4T-v2 for text-to-speech and
    speech-to-speech translation between a fixed set of languages.

    Loading the model downloads the checkpoint from the Hugging Face Hub
    on first use, so construction is slow and requires network access.
    """

    def __init__(self):
        # Checkpoint name on the Hugging Face Hub.
        self.model_name = "facebook/seamless-m4t-v2-large"
        print("Loading model...")
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        self.model = SeamlessM4Tv2Model.from_pretrained(self.model_name)
        # Generation is very slow on CPU; use a GPU when one is available.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        # Sample rate of the generated speech, taken from the model config.
        self.sample_rate = self.model.config.sampling_rate
        # UI display name -> SeamlessM4T language code.
        self.languages = {
            "English": "eng",
            "Spanish": "spa",
            "French": "fra",
            "German": "deu",
            "Italian": "ita",
            "Portuguese": "por",
            "Russian": "rus",
            "Chinese": "cmn",
            "Japanese": "jpn",
            "Korean": "kor"
        }

    def translate_text(self, text, src_lang, tgt_lang, progress=gr.Progress()):
        """Translate ``text`` from ``src_lang`` to ``tgt_lang`` and synthesize speech.

        Args:
            text: Source text typed by the user.
            src_lang / tgt_lang: Display names that must be keys of ``self.languages``.
            progress: Gradio progress tracker (injected by the UI).

        Returns:
            ``(sample_rate, audio_array)`` tuple suitable for a ``gr.Audio`` output.

        Raises:
            gr.Error: On empty input or any processing/generation failure.
        """
        # Guard empty input up front, mirroring translate_audio's check.
        if not text or not text.strip():
            raise gr.Error("Please enter some text to translate")
        try:
            progress(0.3, desc="Processing...")
            inputs = self.processor(
                text=text,
                src_lang=self.languages[src_lang],
                return_tensors="pt",
            ).to(self.device)
            progress(0.6, desc="Generating...")
            # inference_mode avoids autograd bookkeeping during generation.
            with torch.inference_mode():
                audio_array = (
                    self.model.generate(**inputs, tgt_lang=self.languages[tgt_lang])[0]
                    .cpu()
                    .numpy()
                    .squeeze()
                )
            progress(1.0, desc="Complete")
            return (self.sample_rate, audio_array)
        except Exception as e:
            # Surface the failure in the UI instead of a server traceback.
            raise gr.Error(str(e))

    def translate_audio(self, audio_path, tgt_lang, progress=gr.Progress()):
        """Translate speech in the file at ``audio_path`` into ``tgt_lang`` speech.

        Args:
            audio_path: Path to the uploaded/recorded audio file (or None).
            tgt_lang: Display name that must be a key of ``self.languages``.
            progress: Gradio progress tracker (injected by the UI).

        Returns:
            ``(sample_rate, audio_array)`` tuple suitable for a ``gr.Audio`` output.

        Raises:
            gr.Error: When no file was provided or processing fails.
        """
        if not audio_path:
            raise gr.Error("Please upload an audio file")
        try:
            progress(0.3, desc="Processing...")
            audio, orig_freq = torchaudio.load(audio_path)
            # Downmix stereo/multi-channel recordings to mono; the model
            # expects a single channel.
            if audio.shape[0] > 1:
                audio = audio.mean(dim=0, keepdim=True)
            # SeamlessM4T's feature extractor expects 16 kHz input.
            audio = torchaudio.functional.resample(audio, orig_freq=orig_freq, new_freq=16000)
            progress(0.6, desc="Translating...")
            # Declare the input sample rate explicitly so the feature
            # extractor does not have to assume it.
            inputs = self.processor(
                audios=audio,
                sampling_rate=16000,
                return_tensors="pt",
            ).to(self.device)
            with torch.inference_mode():
                audio_array = (
                    self.model.generate(**inputs, tgt_lang=self.languages[tgt_lang])[0]
                    .cpu()
                    .numpy()
                    .squeeze()
                )
            progress(1.0, desc="Complete")
            return (self.sample_rate, audio_array)
        except Exception as e:
            raise gr.Error(str(e))
css = """
:root {
--primary-color: #2D3648;
--secondary-color: #5E6AD2;
--background-color: #F5F7FF;
--text-color: #2D3648;
--border-radius: 12px;
--spacing: 20px;
}
.gradio-container {
background-color: var(--background-color) !important;
}
.main-container {
max-width: 1200px !important;
margin: 0 auto !important;
padding: var(--spacing) !important;
}
.app-header {
text-align: center;
padding: 40px 20px;
background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
border-radius: var(--border-radius);
color: white !important;
margin-bottom: var(--spacing);
}
.app-title {
font-size: 2.5em;
font-weight: 700;
margin-bottom: 10px;
color: white !important;
}
.app-subtitle {
font-size: 1.2em;
opacity: 0.9;
color: white !important;
}
.content-block {
background: white;
padding: var(--spacing);
border-radius: var(--border-radius);
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
margin-bottom: var(--spacing);
}
.gr-button {
background: var(--secondary-color) !important;
border: none !important;
color: white !important;
}
.gr-button:hover {
box-shadow: 0 4px 10px rgba(94, 106, 210, 0.3) !important;
transform: translateY(-1px);
}
.gr-input, .gr-select {
border-radius: 8px !important;
border: 2px solid #E5E7EB !important;
padding: 12px !important;
}
.gr-input:focus, .gr-select:focus {
border-color: var(--secondary-color) !important;
box-shadow: 0 0 0 3px rgba(94, 106, 210, 0.1) !important;
}
.gr-form {
background: white !important;
padding: var(--spacing) !important;
border-radius: var(--border-radius) !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05) !important;
}
.gr-box {
border-radius: var(--border-radius) !important;
border: none !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05) !important;
}
.footer {
text-align: center;
color: var(--text-color);
padding: var(--spacing);
opacity: 0.8;
}
/* Custom Tabs Styling */
.tab-nav {
background: white !important;
padding: 10px !important;
border-radius: var(--border-radius) !important;
margin-bottom: var(--spacing) !important;
}
.tab-nav button {
border-radius: 8px !important;
padding: 12px 24px !important;
}
.tab-nav button.selected {
background: var(--secondary-color) !important;
color: white !important;
}
"""
def create_ui():
    """Build and return the Gradio Blocks interface.

    Instantiates the translation model once, lays out two tabs (text-to-speech
    and speech-to-speech), and wires the buttons to the model's methods.
    """
    translator = TranslationModel()

    with gr.Blocks(css=css, title="AI Language Translator") as demo:
        # Gradient banner at the top of the page.
        gr.HTML(
            """
            <div class="app-header">
                <div class="app-title">AI Language Translator</div>
                <div class="app-subtitle">Powered by Neural Machine Translation</div>
            </div>
            """
        )

        language_choices = sorted(translator.languages)

        with gr.Tabs():
            # --- Tab 1: type text, hear the translation ---
            with gr.Tab("Text to Speech"):
                with gr.Column(variant="panel"):
                    gr.Markdown("### Enter Text")
                    text_box = gr.Textbox(
                        label="",
                        placeholder="Type or paste your text here...",
                        lines=4,
                    )
                    with gr.Row():
                        source_dd = gr.Dropdown(
                            choices=language_choices, value="English", label="From"
                        )
                        target_dd = gr.Dropdown(
                            choices=language_choices, value="Spanish", label="To"
                        )
                    run_text_btn = gr.Button("Translate", size="lg")
                    gr.Markdown("### Translation Output")
                    text_audio_out = gr.Audio(
                        label="", type="numpy", show_download_button=True
                    )

            # --- Tab 2: upload/record speech, hear the translation ---
            with gr.Tab("Speech to Speech"):
                with gr.Column(variant="panel"):
                    gr.Markdown("### Upload Audio")
                    speech_in = gr.Audio(
                        label="",
                        type="filepath",
                        sources=["upload", "microphone"],
                    )
                    speech_target_dd = gr.Dropdown(
                        choices=language_choices, value="English", label="Translate to"
                    )
                    run_audio_btn = gr.Button("Translate Audio", size="lg")
                    gr.Markdown("### Translation Output")
                    speech_audio_out = gr.Audio(
                        label="", type="numpy", show_download_button=True
                    )

        gr.HTML(
            """
            <div class="footer">
                Built with ❤️ using Meta's SeamlessM4T and Gradio
            </div>
            """
        )

        # Wire the buttons to the model methods.
        run_text_btn.click(
            fn=translator.translate_text,
            inputs=[text_box, source_dd, target_dd],
            outputs=text_audio_out,
        )
        run_audio_btn.click(
            fn=translator.translate_audio,
            inputs=[speech_in, speech_target_dd],
            outputs=speech_audio_out,
        )

    return demo
if __name__ == "__main__":
    # Build the interface, enable request queuing, and start the server.
    create_ui().queue().launch()