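"""Gradio demo: Speech-to-Text for Ukrainian using the Yehor/w2v-bert-uk-v2.1-fp16 model.

Runs on CUDA (float16) when available, falls back to CPU (float32), and supports
Hugging Face ZeroGPU Spaces via the optional `spaces` package.
"""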
import sys
import time
from importlib.metadata import version, PackageNotFoundError
try:
    import spaces
except ImportError:
    print("ZeroGPU is not available, skipping...")
import torch
import torchaudio
import torchaudio.transforms as T
import gradio as gr
from gradio.themes import Soft
from gradio.utils import is_zero_gpu_space
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
try:
    spaces_version = version("spaces")
    print("ZeroGPU is available, changing inference call.")
except PackageNotFoundError:
    spaces_version = "N/A"
    print("ZeroGPU is not available, skipping...")
use_zero_gpu = is_zero_gpu_space()
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("CUDA is available, running inference on GPU with float16.")
    device = "cuda"
    torch_dtype = torch.float16
else:
    print("CUDA is not available, running inference on CPU with float32.")
    device = "cpu"
    torch_dtype = torch.float32
# Config
model_name = "Yehor/w2v-bert-uk-v2.1-fp16"
min_duration = 0.5  # seconds
max_duration = 60  # seconds
concurrency_limit = 5
use_torch_compile = False
# Load the model
asr_model = AutoModelForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)
if use_torch_compile:
    asr_model = torch.compile(asr_model)
# Elements
examples = [
"example_1.wav",
"example_2.wav",
"example_3.wav",
"example_4.wav",
"example_5.wav",
"example_6.wav",
]
examples_table = """
| File | Text |
| ------------- | ------------- |
| `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav` | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні |
| `example_4.wav` | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow him on social networks and **contact him** if you need any help or have any questions:
| **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw on Telegram |
| https://x.com/yehor_smoliakov on X |
| https://github.com/egorsmkv on GitHub |
| https://huggingface.co/Yehor on Hugging Face |
| or use [email protected] |
""".strip()
description_head = f"""
# Speech-to-Text for Ukrainian v2.1
This Space uses the https://huggingface.co/{model_name} model to transcribe audio files.
> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()
transcription_value = """
Recognized text will appear here.
Choose **an example file** below the Run button, upload **your audio file**, or use **the microphone** to record something.
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
#### Models
##### Acoustic model (Speech-to-Text)
- Name: wav2vec2-bert
- URL: https://huggingface.co/Yehor/w2v-bert-uk-v2.1-fp16
""".strip()
tech_libraries = f"""
#### Libraries
- torch: {version("torch")}
- torchaudio: {version("torchaudio")}
- transformers: {version("transformers")}
- accelerate: {version("accelerate")}
- gradio: {version("gradio")}
""".strip()
def inference(audio_path, progress=gr.Progress()):
    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    gr.Info("Starting...", duration=1)
    progress(0, desc="Recognizing")

    # Validate the duration before doing any heavy work
    meta = torchaudio.info(audio_path)
    duration = meta.num_frames / meta.sample_rate

    if duration < min_duration:
        raise gr.Error(
            f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds."
        )
    if duration > max_duration:
        raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")

    paths = [
        audio_path,
    ]

    results = []
    for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
        t0 = time.time()

        meta = torchaudio.info(path)
        audio_duration = meta.num_frames / meta.sample_rate

        audio_input, sr = torchaudio.load(path)

        # Downmix to mono and resample to the 16 kHz rate expected by the model
        if meta.num_channels > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        if meta.sample_rate != 16_000:
            resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze().numpy()

        # Extract features and run greedy CTC decoding
        features = processor([audio_input], sampling_rate=16_000).input_features
        features = torch.tensor(features).to(device)
        if torch_dtype == torch.float16:
            features = features.half()

        with torch.inference_mode():
            logits = asr_model(features).logits

        predicted_ids = torch.argmax(logits, dim=-1)
        predictions = processor.batch_decode(predicted_ids)
        if not predictions:
            predictions = ["-"]

        elapsed_time = round(time.time() - t0, 2)
        rtf = round(elapsed_time / audio_duration, 4)
        audio_duration = round(audio_duration, 2)

        results.append(
            {
                "path": path.split("/")[-1],
                "transcription": "\n".join(predictions),
                "audio_duration": audio_duration,
                "rtf": rtf,
            }
        )

    gr.Success("Finished!", duration=0.5)

    # Render the results as Markdown
    result_texts = []
    for result in results:
        result_texts.append(f"**{result['path']}**")
        result_texts.append("\n\n")
        result_texts.append(f"> {result['transcription']}")
        result_texts.append("\n\n")
        result_texts.append(f"**Audio duration**: {result['audio_duration']}")
        result_texts.append("\n")
        result_texts.append(f"**Real-Time Factor**: {result['rtf']}")

    return "\n".join(result_texts)
inference_func = inference
if use_zero_gpu:
    # On ZeroGPU Spaces, wrap inference so a GPU is attached per call
    inference_func = spaces.GPU(inference)
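# Gradio UI: one Blocks page per tab (recognition, authors, environment),
# combined into a TabbedInterface in create_demo().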
def create_app():
    tab = gr.Blocks(
        title="Speech-to-Text for Ukrainian",
        analytics_enabled=False,
        theme=Soft(),
    )

    with tab:
        gr.Markdown(description_head)

        gr.Markdown("## Usage")

        with gr.Column():
            audio_file = gr.Audio(label="Audio file", type="filepath")
            transcription = gr.Markdown(
                label="Transcription",
                value=transcription_value,
            )

        gr.Button("Run").click(
            inference_func,
            concurrency_limit=concurrency_limit,
            inputs=audio_file,
            outputs=transcription,
        )

        with gr.Row():
            gr.Examples(label="Choose an example", inputs=audio_file, examples=examples)

        gr.Markdown(examples_table)

    return tab
def create_env():
    with gr.Blocks(theme=Soft()) as tab:
        gr.Markdown(tech_env)
        gr.Markdown(tech_libraries)
    return tab
def create_authors():
    with gr.Blocks(theme=Soft()) as tab:
        gr.Markdown(authors_table)
    return tab
def create_demo():
    app_tab = create_app()
    authors_tab = create_authors()
    env_tab = create_env()

    return gr.TabbedInterface(
        [app_tab, authors_tab, env_tab],
        tab_names=[
            "🎙️ Recognition",
            "👥 Authors",
            "📦 Environment, Models, and Libraries",
        ],
    )
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()