Spaces:
Running
Running
import os
from datetime import datetime
from pathlib import Path
from urllib.parse import parse_qsl, urlsplit

import huggingface_hub
import jiwer
import pandas as pd
import requests
import streamlit as st
from huggingface_hub import HfFileSystem
from st_fixed_container import st_fixed_container

from visual_eval.evaluator import HebrewTextNormalizer
from visual_eval.visualization import render_visualize_jiwer_result_html
# Hugging Face API token: taken from Streamlit secrets when available,
# otherwise from the HF_API_TOKEN environment variable.
HF_API_TOKEN = None
try:
    HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
except (FileNotFoundError, KeyError):
    # FileNotFoundError: no secrets file is configured.
    # KeyError: a secrets file exists but does not define HF_API_TOKEN.
    HF_API_TOKEN = os.environ.get("HF_API_TOKEN")

# An empty token is treated the same as a missing one.
has_api_token = bool(HF_API_TOKEN)
# Datasets for which audio previews can be fetched. Each entry is a 3-tuple:
#   [0] "<hub repo id>:<split>:<third field>" - the audio lookup splits this on
#       ":" and uses the repo id and split; the third field is not used there
#       (presumably the transcript column name - confirm).
#   [1] dataset config name, or None (the audio lookup then uses "default").
#   [2] output name, used as the lookup key by get_known_dataset_by_output_name.
known_datasets = [
    # spec                                                  config   output name
    ("ivrit-ai/eval-d1:test:text",                          None,    "ivrit_ai_eval_d1"),
    ("upai-inc/saspeech:test:text",                         None,    "saspeech"),
    ("google/fleurs:test:transcription",                    "he_il", "fleurs"),
    ("mozilla-foundation/common_voice_17_0:test:sentence",  "he",    "common_voice_17"),
    ("imvladikon/hebrew_speech_kan:validation:sentence",    None,    "hebrew_speech_kan"),
]
# Per-session state, initialised only when a key is not present yet.
# The dict literal is re-evaluated on every script run, so each session gets
# its own fresh mutable containers.
_SESSION_STATE_DEFAULTS = {
    # entry index -> {"url": <audio url>, "expires": <datetime or None>}
    "audio_cache": {},
    # entry index -> True once the user asked to preview that entry's audio
    "audio_preview_active": {},
    # the results CSV being shown: an uploaded file or a leaderboard URL
    "results_file": None,
    # index of the selected entry (also the key of the entry radio widget)
    "selected_entry_idx": 0,
    # number of rows in the loaded results
    "total_entry_count": 0,
    # number of entries listed per sidebar page
    "entry_page_size": 20,
}
for _key, _default in _SESSION_STATE_DEFAULTS.items():
    # Guard on the very key being initialised, so an existing value is never
    # overwritten on a rerun.
    if _key not in st.session_state:
        st.session_state[_key] = _default
def get_current_page_slice():
    """Return the slice of entry indices on the page that holds the selection.

    Pages are ``entry_page_size`` entries long and the last page is cut off at
    ``total_entry_count``. With no entries loaded, ``slice(0, 0)`` is returned.
    """
    state = st.session_state
    total = state.total_entry_count
    if total == 0:
        return slice(0, 0)
    size = state.entry_page_size
    page_number = state.selected_entry_idx // size
    first = page_number * size
    last = min(first + size, total)
    return slice(first, last)
def page_navigation():
    """Render "Prev Page" / "Next Page" buttons for the paged entry list.

    Clicking a button moves ``selected_entry_idx`` onto the neighbouring page:
    the last entry of the previous page, or the first entry of the next page.
    It then reruns the script so the list is rebuilt for that page.
    """
    ss = st.session_state
    current_page_slice = get_current_page_slice()
    # ``stop`` is exclusive. A next page exists as soon as any entry lies at
    # or beyond it. Comparing against ``total - 1`` hid a final page holding
    # exactly one entry.
    has_next_page = current_page_slice.stop < ss.total_entry_count
    has_prev_page = current_page_slice.start >= ss.entry_page_size
    col1, col2 = st.columns(2)
    if col1.button("Prev Page", disabled=not has_prev_page):
        ss.selected_entry_idx = current_page_slice.start - 1
        st.rerun()
    if col2.button("Next Page", disabled=not has_next_page):
        ss.selected_entry_idx = current_page_slice.start + ss.entry_page_size
        st.rerun()
def on_file_upload():
    """Reset all per-file session state when a new results file is chosen.

    Empties the audio URL cache and the preview flags, selects the first entry
    again, and forgets the current results file.
    """
    fresh_values = {
        "audio_cache": {},
        "audio_preview_active": {},
        "selected_entry_idx": 0,
        "results_file": None,
    }
    for state_key, value in fresh_values.items():
        st.session_state[state_key] = value
def get_leaderboard_result_csv_paths(root_search_path):
    """List result CSVs exactly one directory below ``root_search_path``.

    Args:
        root_search_path: an ``HfFileSystem`` path such as
            ``"spaces/<org>/<space>/results"``.

    Returns:
        For each match of ``<root_search_path>/*/*.csv``, the text that follows
        ``root_search_path`` in the matched path. This is typically
        ``"/<subdir>/<file>.csv"``.
    """
    hub_fs = HfFileSystem(token=HF_API_TOKEN)
    matches = hub_fs.glob(f"{root_search_path}/*/*.csv")
    return [match.split(root_search_path)[1] for match in matches]
def choose_input_file_from_leaderboard():
    """Let the user pick a results CSV published by the leaderboard Space.

    The CSVs under the Space's ``results/`` directory are offered in a
    selectbox. Once one is chosen, the per-file state is reset via
    ``on_file_upload()``, the file's Hub download URL becomes the session's
    ``results_file``, and the script is rerun. Without an API token a warning
    is shown and nothing is listed.
    """
    if not has_api_token:
        # Explain why nothing can be listed instead of silently rerunning.
        st.warning("Browsing leaderboard results requires a Hugging Face API token.")
        return
    space_repo_id = "ivrit-ai/hebrew-transcription-leaderboard"
    root_search_path = f"{space_repo_id}/results"
    fsspec_spaces_root_search_path = f"spaces/{root_search_path}"
    found_files_relative_paths = get_leaderboard_result_csv_paths(
        fsspec_spaces_root_search_path
    )
    selected_file = st.selectbox(
        "Select a CSV file from the leaderboard:",
        found_files_relative_paths,
        index=None,
    )
    if not selected_file:
        return
    # Relative paths look like "/<subdir>/<file>.csv". Splitting on "/"
    # (instead of using pathlib) keeps Hub paths forward-slashed on any OS.
    subdir, _, file_name = selected_file.rpartition("/")
    results_url = huggingface_hub.hf_hub_url(
        repo_id=space_repo_id,
        subfolder=f"results{subdir}",
        filename=file_name,
        repo_type="space",
    )
    on_file_upload()
    st.session_state.results_file = results_url
    st.rerun()
def read_results_csv(uploaded_file):
    """Load an evaluation-results CSV while showing a timed spinner.

    Args:
        uploaded_file: anything ``pandas.read_csv`` accepts. In this app that
            is either a Streamlit uploaded file or a URL string.

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    loading_spinner = st.spinner("Loading results...", show_time=True)
    with loading_spinner:
        frame = pd.read_csv(uploaded_file)
    return frame
def display_rtl(html):
    """Render ``html`` inside a right-to-left ``<div>`` tagged as Hebrew.

    NOTE(review): the markup is emitted with ``unsafe_allow_html=True``. The
    texts shown in this app come from the loaded CSV, so confirm the caller
    escapes them before building ``html``.
    """
    rtl_markup = f"""
<div dir="rtl" lang="he">
{html}
</div>
"""
    st.markdown(rtl_markup, unsafe_allow_html=True)
def calculate_final_metrics(uploaded_file, _df):
    """Compute word-level error statistics over all entries at once.

    Args:
        uploaded_file: the results source. It is not read here; it exists for
            cache hashing. NOTE(review): this function carries no caching
            decorator, so the parameter naming (``_df`` excluded from hashing)
            currently has no effect, and the metrics are recomputed on every call.
        _df: results DataFrame with ``id``, ``reference_text`` and
            ``predicted_text`` columns. It is not modified; a copy sorted by
            ``id`` is used.

    Returns:
        The ``jiwer.process_words`` result for all entries. Missing texts are
        treated as empty strings, and both sides are normalized with
        ``HebrewTextNormalizer``.
    """
    ordered = _df.sort_values(by=["id"])
    references = ordered["reference_text"].fillna("")
    predictions = ordered["predicted_text"].fillna("")
    normalize = HebrewTextNormalizer()
    normalized_refs = [normalize(text) for text in references]
    normalized_hyps = [normalize(text) for text in predictions]
    return jiwer.process_words(normalized_refs, normalized_hyps)
def get_known_dataset_by_output_name(output_name):
    """Return the first ``known_datasets`` tuple whose output name (third
    field) equals ``output_name``, or ``None`` when nothing matches."""
    return next(
        (candidate for candidate in known_datasets if candidate[2] == output_name),
        None,
    )
def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100, *, timeout=30):
    """Fetch audio URLs for a range of rows from the HF datasets-server.

    Args:
        dataset: a ``known_datasets`` tuple ``(spec, config, output_name)``.
            ``spec`` is ``"<repo_id>[:<split>[:...]]"``; the split defaults to
            ``"test"``, and a falsy config is sent as ``"default"``.
        offset: index of the first row to fetch.
        max_entries: number of rows to request (the ``length`` parameter).
        timeout: seconds to wait for the HTTP response.

    Returns:
        A list holding, for each returned row in order, the ``src`` URL of
        the row's first ``audio`` item. ``None`` is returned in any of these cases:
        - ``dataset`` is ``None`` or there is no API token;
        - the request fails or returns a non-JSON body;
        - no rows come back;
        - a row lacks the expected audio layout.
    """
    if dataset is None or not has_api_token:
        return None

    dataset_spec, dataset_config, _ = dataset
    # Only the repo id and split are used from "<repo_id>:<split>:<extra>".
    spec_parts = dataset_spec.split(":")
    dataset_repo_id = spec_parts[0]
    split = spec_parts[1] if len(spec_parts) > 1 else "test"

    api_query_params = {
        "dataset": dataset_repo_id,
        "config": dataset_config or "default",
        "split": split,
        "offset": offset,
        "length": max_entries,
    }
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    try:
        # ``params=`` lets requests URL-encode the query values.
        response = requests.get(
            "https://datasets-server.huggingface.co/rows",
            params=api_query_params,
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError):
        # Network error, HTTP error status or non-JSON body: no audio available.
        return None

    rows = data.get("rows") if isinstance(data, dict) else None
    if not rows:
        return None
    try:
        return [row["row"]["audio"][0]["src"] for row in rows]
    except (KeyError, IndexError, TypeError):
        # A row without the expected audio layout. Returning a partial list
        # would misalign row indices, so report failure for the whole range.
        return None
def _url_expiry_time(url):
    """Return the ``Expires`` query parameter of ``url`` as a naive local
    datetime, or ``None`` when absent or unparsable.

    The parameter name is matched exactly but case-insensitively. Presumably
    the datasets-server URLs are signed with a capitalised ``Expires=`` (to be
    confirmed). Exact matching also avoids picking up parameters like
    ``X-Amz-Expires``, which hold a duration rather than a timestamp.
    """
    if not url:
        return None
    for name, value in parse_qsl(urlsplit(url).query):
        if name.lower() == "expires":
            try:
                # Naive local time, matching the datetime.now() comparison
                # done when reading the cache.
                return datetime.fromtimestamp(int(value))
            except (ValueError, OverflowError, OSError):
                return None
    return None


def get_audio_url_for_entry(
    dataset, entry_idx, cache_neighbors=True, neighbor_range=20
):
    """Fetch the audio URL of one entry and cache it in the session.

    Args:
        dataset: ``known_datasets`` tuple ``(spec, config, output_name)``.
        entry_idx: row index of the entry within the dataset split.
        cache_neighbors: when true, also fetch and cache up to
            ``neighbor_range`` rows on each side of ``entry_idx``.
        neighbor_range: number of neighbouring rows fetched on each side.
            NOTE(review): ``2 * neighbor_range + 1`` rows are requested in one
            call; the rows endpoint presumably caps the page length. Confirm
            the limit before raising this.

    Returns:
        The audio URL for ``entry_idx``, or ``None`` if it could not be fetched.

    Each fetched URL is stored in ``st.session_state.audio_cache[row index]``
    as ``{"url": url, "expires": datetime or None}``.
    """
    if cache_neighbors:
        start_idx = max(0, entry_idx - neighbor_range)
        max_entries = neighbor_range * 2 + 1
    else:
        start_idx = entry_idx
        max_entries = 1

    audio_urls = get_dataset_entries_audio_urls(dataset, start_idx, max_entries)
    if not audio_urls:
        return None

    for position, url in enumerate(audio_urls):
        st.session_state.audio_cache[start_idx + position] = {
            "url": url,
            "expires": _url_expiry_time(url),
        }

    # Fewer rows than requested come back near the end of the split.
    relative_idx = entry_idx - start_idx
    if 0 <= relative_idx < len(audio_urls):
        return audio_urls[relative_idx]
    return None
def get_cached_audio_url(entry_idx):
    """Return the cached audio URL for ``entry_idx`` if it is still valid.

    Returns:
        The URL stored in ``st.session_state.audio_cache``, or ``None`` if the
        entry was never cached or its recorded expiry time has passed.
        Entries cached without an expiry time never expire.
    """
    url_cache = st.session_state.audio_cache
    if entry_idx not in url_cache:
        return None
    cached = url_cache[entry_idx]
    expires_at = cached["expires"]
    is_expired = bool(expires_at) and datetime.now() > expires_at
    return None if is_expired else cached["url"]
def _render_total_metrics(results_file, eval_results):
    """Sidebar toggle that shows corpus-level WER and edit-operation counts."""
    with st.sidebar:
        show_total_metrics = st.toggle("Show total metrics", value=False)
        if show_total_metrics:
            total_metrics = calculate_final_metrics(results_file, eval_results)
            with st.container():
                st.metric("WER", f"{total_metrics.wer * 100:.4f}%")
                st.table(
                    {
                        "Hits": total_metrics.hits,
                        "Subs": total_metrics.substitutions,
                        "Dels": total_metrics.deletions,
                        "Insrt": total_metrics.insertions,
                    }
                )


def _format_wer(wer_value):
    """Format a per-entry WER as a percentage; non-numeric values pass through."""
    if isinstance(wer_value, (int, float)):
        return f"{wer_value * 100:.2f}%"
    return wer_value


def _render_entry_selector(eval_results):
    """Render the sidebar entry navigation and return the selected entry index."""
    ss = st.session_state
    entry_count = len(eval_results)

    st.sidebar.header("Select Entry")

    # on_click callbacks run before the next script run, so they can change
    # the radio's session-state key before the radio is created.
    def go_prev():
        if ss.selected_entry_idx > 0:
            ss.selected_entry_idx -= 1

    def go_next():
        if ss.selected_entry_idx < entry_count - 1:
            ss.selected_entry_idx += 1

    col1, col2 = st.sidebar.columns(2)
    col1.button("← Prev", on_click=go_prev, use_container_width=True)
    col2.button("Next →", on_click=go_next, use_container_width=True)

    # tolist() yields Python scalars, so integer WER columns are formatted as
    # percentages too. Labels are only formatted for the entries on the page.
    if "wer" in eval_results.columns:
        wer_values = eval_results["wer"].tolist()
    else:
        wer_values = [0] * entry_count

    with st.sidebar.container():
        page_navigation()
        st.write(f"Total entries: {ss.total_entry_count}")
        st.write("Select an entry")
        # Only the current page is listed. The radio's key is
        # selected_entry_idx, so choosing an option updates the selection.
        st.radio(
            "Select an entry",
            options=list(range(entry_count)[get_current_page_slice()]),
            format_func=lambda i: f"Entry #{i + 1} ({_format_wer(wer_values[i])})",
            label_visibility="collapsed",
            key="selected_entry_idx",
        )
    return ss.selected_entry_idx


def _entry_texts(eval_results, entry_idx, use_normalized):
    """Return the (reference, hypothesis) texts of one entry.

    The normalized columns are used when requested and both exist; otherwise
    the raw columns are used. Empty CSV cells (read as NaN) become "".
    """
    norm_cols = ("norm_reference_text", "norm_predicted_text")
    if use_normalized and all(col in eval_results.columns for col in norm_cols):
        text_cols = norm_cols
    else:
        text_cols = ("reference_text", "predicted_text")
    row = eval_results.iloc[entry_idx]
    ref, hyp = ("" if pd.isna(row[col]) else str(row[col]) for col in text_cols)
    return ref, hyp


def _resolve_known_dataset(results_file, eval_results, entry_idx):
    """Return the ``known_datasets`` tuple matching the results, or ``None``.

    The results file name stem is tried first. For a leaderboard URL, that is
    the stem of the last path component. If the stem matches nothing, the
    entry's ``dataset`` column value is tried.
    """
    if results_file is None:
        filename_stem = None
    elif isinstance(results_file, str):
        filename_stem = Path(results_file).stem
    else:
        filename_stem = Path(results_file.name).stem
    known_dataset = get_known_dataset_by_output_name(filename_stem)
    if known_dataset is None and "dataset" in eval_results.columns:
        # NOTE(review): assumes the column holds an output name such as
        # "fleurs" rather than a Hub repo id - confirm against the CSV writer.
        known_dataset = get_known_dataset_by_output_name(
            eval_results.iloc[entry_idx]["dataset"]
        )
    return known_dataset


def _render_audio_preview(known_dataset, entry_idx):
    """Show a "Preview Audio" button, or a sticky audio player once requested.

    NOTE(review): audio is looked up by the entry's row position in the CSV.
    This assumes the CSV rows follow the dataset split's order - confirm.
    """
    ss = st.session_state
    audio_url = get_cached_audio_url(entry_idx)
    preview_active = ss.audio_preview_active.get(entry_idx, False)
    preview_clicked = False
    if not preview_active:
        preview_clicked = st.button("Preview Audio", key="preview_audio")
    # Include ``preview_active`` so an expired cached URL is fetched again.
    # Otherwise neither the button nor a player would be shown.
    if not (preview_clicked or audio_url or preview_active):
        return
    ss.audio_preview_active[entry_idx] = True
    with st_fixed_container(mode="sticky", position="top", border=True, margin=0):
        if not audio_url:
            with st.spinner("Loading audio..."):
                audio_url = get_audio_url_for_entry(known_dataset, entry_idx)
        if audio_url:
            st.audio(audio_url)
        else:
            st.error("Failed to load audio for this entry.")


def _render_metadata(eval_results, entry_idx):
    """Show the entry's metadata fields that are present in the CSV."""
    st.header("Metadata")
    metadata_cols = [
        col
        for col in ("metadata_uuid", "model", "dataset", "dataset_split", "engine")
        if col in eval_results.columns
    ]
    if not metadata_cols:
        st.write("No metadata columns found in this file.")
        return
    metadata = eval_results.iloc[entry_idx][metadata_cols]
    st.table(pd.DataFrame({"Field": metadata_cols, "Value": metadata.values}))


def _render_expected_format():
    """Explain what to upload when no results file is loaded."""
    st.info(
        "Please upload an evaluation results CSV file to visualize the results."
    )
    expected_columns = (
        "id",
        "reference_text",
        "predicted_text",
        "norm_reference_text",
        "norm_predicted_text",
        "wer",
        "wil",
        "substitutions",
        "deletions",
        "insertions",
        "hits",
        "metadata_uuid",
        "model",
        "dataset",
        "dataset_split",
        "engine",
    )
    column_list = "\n".join(f"- {name}" for name in expected_columns)
    st.markdown(
        "### Expected CSV Format\n\n"
        "The CSV should have the following columns:\n\n"
        f"{column_list}"
    )


def _render_results(results_file):
    """Load the results file and render the selector, visualization and metadata."""
    ss = st.session_state
    eval_results = read_results_csv(results_file)
    ss.total_entry_count = len(eval_results)
    st.success("File uploaded successfully!")
    if eval_results.empty:
        st.warning("The results file contains no entries.")
        return
    # Keep the selection inside the loaded file; the radio is not created yet
    # in this run, so its key may still be set here.
    if not 0 <= ss.selected_entry_idx < len(eval_results):
        ss.selected_entry_idx = 0

    _render_total_metrics(results_file, eval_results)
    use_normalized = st.sidebar.toggle("Use normalized text", value=True)
    selected_entry = _render_entry_selector(eval_results)
    ref, hyp = _entry_texts(eval_results, selected_entry, use_normalized)

    st.header("Visualization")
    known_dataset = _resolve_known_dataset(results_file, eval_results, selected_entry)
    if known_dataset:
        _render_audio_preview(known_dataset, selected_entry)
    display_rtl(render_visualize_jiwer_result_html(ref, hyp))

    _render_metadata(eval_results, selected_entry)


def main():
    """Streamlit entry point: choose a results CSV and browse its entries.

    The results come either from the file uploader or from the leaderboard
    picker. Either way they are stored in ``st.session_state.results_file``
    and rendered by ``_render_results``.
    """
    st.set_page_config(
        page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide"
    )
    if not has_api_token:
        st.warning("No Hugging Face API token found. Audio previews will not work.")
    st.title("ASR Evaluation Visualizer")

    uploaded_file = st.file_uploader(
        "Upload evaluation results CSV",
        type=["csv"],
        on_change=on_file_upload,
        key="uploaded_file",
    )
    if uploaded_file is not None:
        st.session_state.results_file = uploaded_file
        st.session_state.show_leaderboard_picker = False
    else:
        st.write("Or:")
        # st.button is True only for the run right after the click. The picker
        # is therefore kept open through a session flag, so its selectbox still
        # exists on the run where the user's choice arrives.
        if st.button("Choose from leaderboard"):
            st.session_state.show_leaderboard_picker = True
        if st.session_state.get("show_leaderboard_picker", False):
            # choose_input_file_from_leaderboard() reruns the script once a
            # file is picked. Clearing the flag first lets that rerun close the
            # picker; the flag is restored only if the call returns, i.e. while
            # the user is still choosing.
            st.session_state.show_leaderboard_picker = False
            choose_input_file_from_leaderboard()
            st.session_state.show_leaderboard_picker = True

    results_file = st.session_state.results_file
    if results_file is None:
        _render_expected_format()
        return

    try:
        _render_results(results_file)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")


if __name__ == "__main__":
    main()