import sys
import time
import subprocess

from importlib.metadata import version

import spaces
import torch
import torchaudio
import torchaudio.transforms as T
import gradio as gr
from transformers import AutoModelForCTC, Wav2Vec2BertProcessor
from torchaudio.models.decoder import ctc_decoder
# Install kenlm (required by torchaudio's ctc_decoder to load the KenLM binary LM)
res = subprocess.check_output(
    "pip install https://github.com/kpu/kenlm/archive/master.zip --no-build-isolation",
    stderr=subprocess.STDOUT,
    shell=True,
)
print(res)
use_cuda = torch.cuda.is_available()

if use_cuda:
    print("CUDA is available, using the GPU for inference.")
    device = "cuda"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
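# Note: fp16 typically halves activation memory and speeds up GPU inference,
# while CPU inference stays in fp32, where half precision is poorly supported.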
# Load the KenLM beam-search decoder
decoder = ctc_decoder(
    lexicon="lm/model_lexicon.txt",
    tokens="lm/model_tokens_w2v2.txt",
    lm="lm/lm.binary",
    nbest=1,
    beam_size=100,
    blank_token="<pad>",
)
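# Note: ctc_decoder also accepts `lm_weight` and `word_score` hyperparameters
# (left at their torchaudio defaults above); tuning them on a held-out set
# usually improves accuracy. A hypothetical, untuned example:
#
#   decoder = ctc_decoder(..., lm_weight=2.0, word_score=0.0)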
# Config
model_name = "Yehor/w2v-bert-uk-v2.1-fp16"
min_duration = 0.5
max_duration = 60
concurrency_limit = 5
use_torch_compile = False
# Load the model
asr_model = AutoModelForCTC.from_pretrained(
    model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2BertProcessor.from_pretrained(model_name)

if use_torch_compile:
    asr_model = torch.compile(asr_model)
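# Note: torch.compile can speed up repeated inference, but the first call pays
# a one-time compilation cost, so it is disabled by default for this demo.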
# Elements
examples = [
    "example_1.wav",
    "example_2.wav",
    "example_3.wav",
    "example_4.wav",
    "example_5.wav",
    "example_6.wav",
]
examples_table = """
| File | Text |
| ------------- | ------------- |
| `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav` | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні |
| `example_4.wav` | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors

Follow them on social networks and **contact** them if you need any help or have any questions:

| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use [email protected] |
""".strip()
description_head = f"""
# Speech-to-Text for Ukrainian v2.1 with LM

## Overview

This Space uses the https://huggingface.co/{model_name} acoustic model together with the https://huggingface.co/Yehor/kenlm-uk/tree/main/news/lm-4gram-500k language model to recognize audio files.

> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()
description_foot = f"""
{authors_table}
""".strip()
transcription_value = """
Recognized text will appear here.

Choose **an example file** below the Run button, upload **your audio file**, or use **the microphone** to record something.
""".strip()
tech_env = f"""
#### Environment

- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()
tech_libraries = f"""
#### Libraries

- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- gradio: {version('gradio')}
""".strip()
# Assumption: `@spaces.GPU` (the only use of the otherwise-unused `spaces`
# import) allocates a GPU for this call when running on ZeroGPU hardware.
@spaces.GPU
def inference(audio_path, progress=gr.Progress()):
    if not audio_path:
        raise gr.Error("Please upload an audio file.")

    gr.Info("Starting...", duration=1)
    progress(0, desc="Recognizing")

    meta = torchaudio.info(audio_path)
    duration = meta.num_frames / meta.sample_rate

    if duration < min_duration:
        raise gr.Error(
            f"The duration of the file must be at least {min_duration} seconds, but it is {round(duration, 2)} seconds."
        )
    if duration > max_duration:
        raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")

    paths = [
        audio_path,
    ]

    results = []
    for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
        t0 = time.time()

        meta = torchaudio.info(path)
        audio_duration = meta.num_frames / meta.sample_rate

        audio_input, sr = torchaudio.load(path)

        # Downmix to mono and resample to the model's 16 kHz input rate if needed
        if meta.num_channels > 1:
            audio_input = torch.mean(audio_input, dim=0, keepdim=True)
        if sr != 16_000:
            resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
            audio_input = resampler(audio_input)

        audio_input = audio_input.squeeze().numpy()

        features = processor([audio_input], sampling_rate=16_000).input_features
        features = torch.tensor(features).to(device)

        if torch_dtype == torch.float16:
            features = features.half()

        with torch.inference_mode():
            logits = asr_model(features).logits

        # Greedy (argmax) decoding, logged for comparison with the LM decoding
        predicted_ids = torch.argmax(logits, dim=-1)
        greedy_predictions = processor.batch_decode(predicted_ids)
        print("Greedy search:", greedy_predictions)

        # Decode with KenLM beam search (the decoder expects fp32 CPU emissions)
        decoded = decoder(logits.cpu().to(torch.float32))

        batch_tokens = [decoder.idxs_to_tokens(hypo[0].tokens) for hypo in decoded]
        transcripts = ["".join(tokens) for tokens in batch_tokens]
        predictions = [it.replace("|", " ").strip() for it in transcripts]
        print("KenLM decoded:", predictions)

        if not predictions:
            predictions = ["-"]

        elapsed_time = round(time.time() - t0, 2)
        # Real-Time Factor: processing time / audio duration (< 1 is faster than real time)
        rtf = round(elapsed_time / audio_duration, 4)
        audio_duration = round(audio_duration, 2)

        results.append(
            {
                "path": path.split("/")[-1],
                "transcription": "\n".join(predictions),
                "audio_duration": audio_duration,
                "rtf": rtf,
            }
        )

    gr.Success("Finished!", duration=0.5)

    result_texts = []
    for result in results:
        result_texts.append(f'**{result["path"]}**')
        result_texts.append("\n\n")
        result_texts.append(f'> {result["transcription"]}')
        result_texts.append("\n\n")
        result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
        result_texts.append("\n")
        result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')

    return "\n".join(result_texts)
demo = gr.Blocks(
    title="Speech-to-Text for Ukrainian",
    analytics_enabled=False,
    theme=gr.themes.Base(),
)

with demo:
    gr.Markdown(description_head)
    gr.Markdown("## Usage")

    with gr.Column():
        audio_file = gr.Audio(label="Audio file", type="filepath")
        transcription = gr.Markdown(
            label="Transcription",
            value=transcription_value,
        )

    gr.Button("Run").click(
        inference,
        concurrency_limit=concurrency_limit,
        inputs=audio_file,
        outputs=transcription,
    )

    with gr.Row():
        gr.Examples(label="Choose an example", inputs=audio_file, examples=examples)

    gr.Markdown(examples_table)
    gr.Markdown(description_foot)

    gr.Markdown("### This Gradio app uses:")
    gr.Markdown(tech_env)
    gr.Markdown(tech_libraries)

if __name__ == "__main__":
    demo.queue()
    demo.launch()