import os
from datetime import datetime
from pathlib import Path
import huggingface_hub
import jiwer
import pandas as pd
import requests
import streamlit as st
from huggingface_hub import HfFileSystem
from st_fixed_container import st_fixed_container
from visual_eval.evaluator import HebrewTextNormalizer
from visual_eval.visualization import render_visualize_jiwer_result_html
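# Streamlit app for inspecting ASR evaluation results: load a per-entry results CSV
# (uploaded or picked from the leaderboard Space), visualize word-level differences
# between reference and predicted transcripts, and preview audio for known datasets.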
HF_API_TOKEN = None
try:
    HF_API_TOKEN = st.secrets["HF_API_TOKEN"]
except (FileNotFoundError, KeyError):
    # Fall back to the environment when no secrets file (or key) is available
    HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
has_api_token = HF_API_TOKEN is not None
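# Known datasets as ("repo_id:split:text_column", config, results_file_stem) tuples;
# the stem is matched against the results CSV filename to identify the source dataset.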
known_datasets = [
("ivrit-ai/eval-d1:test:text", None, "ivrit_ai_eval_d1"),
("upai-inc/saspeech:test:text", None, "saspeech"),
("google/fleurs:test:transcription", "he_il", "fleurs"),
("mozilla-foundation/common_voice_17_0:test:sentence", "he", "common_voice_17"),
("imvladikon/hebrew_speech_kan:validation:sentence", None, "hebrew_speech_kan"),
]
# Initialize session state for audio cache if it doesn't exist
if "audio_cache" not in st.session_state:
st.session_state.audio_cache = {}
if "audio_preview_active" not in st.session_state:
st.session_state.audio_preview_active = {}
if "results_file" not in st.session_state:
    st.session_state.results_file = None
if "selected_entry_idx" not in st.session_state:
st.session_state.selected_entry_idx = 0
if "total_entry_count" not in st.session_state:
st.session_state.total_entry_count = 0
if "entry_page_size" not in st.session_state:
st.session_state.entry_page_size = 20
def get_current_page_slice():
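    """Return the slice of entry indices for the page containing the selected entry."""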
ss = st.session_state
if ss.total_entry_count == 0:
return slice(0, 0)
page_first_entry = (
st.session_state.selected_entry_idx // ss.entry_page_size
) * ss.entry_page_size
page_last_entry = min(page_first_entry + ss.entry_page_size, ss.total_entry_count)
return slice(page_first_entry, page_last_entry)
def page_navigation():
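    """Render Prev/Next page buttons that move the selection across page boundaries."""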
ss = st.session_state
current_page_slice = get_current_page_slice()
    has_next_page = current_page_slice.stop < ss.total_entry_count
has_prev_page = current_page_slice.start >= ss.entry_page_size
col1, col2 = st.columns(2)
if col1.button("Prev Page", disabled=not has_prev_page):
ss.selected_entry_idx = current_page_slice.start - 1
st.rerun()
if col2.button("Next Page", disabled=not has_next_page):
ss.selected_entry_idx = current_page_slice.start + ss.entry_page_size
st.rerun()
def on_file_upload():
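    """Reset audio caches and selection state when a new results source is chosen."""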
st.session_state.audio_cache = {}
st.session_state.audio_preview_active = {}
st.session_state.selected_entry_idx = 0
st.session_state.results_file = None
@st.cache_data
def get_leaderboard_result_csv_paths(root_search_path):
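    """List result CSV paths under the search root, returned relative to that root."""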
fs = HfFileSystem(token=HF_API_TOKEN)
found_files = fs.glob(f"{root_search_path}/*/*.csv")
found_files_relative_paths = [f.split(root_search_path)[1] for f in found_files]
return found_files_relative_paths
@st.dialog("View Leaderboard Results")
def choose_input_file_from_leaderboard():
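    # Listing the leaderboard Space requires an API token; st.rerun() dismisses the dialog.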
if not has_api_token:
st.rerun()
root_search_path = "ivrit-ai/hebrew-transcription-leaderboard/results"
fsspec_spaces_root_search_path = f"spaces/{root_search_path}"
found_files_relative_paths = get_leaderboard_result_csv_paths(
fsspec_spaces_root_search_path
)
selected_file = st.selectbox(
"Select a CSV file from the leaderboard:",
found_files_relative_paths,
index=None,
)
# Get the selected file
if selected_file:
paths_part = Path(selected_file).parent
file_part = Path(selected_file).name
uploaded_file = huggingface_hub.hf_hub_url(
repo_id="ivrit-ai/hebrew-transcription-leaderboard",
subfolder=f"results{paths_part}",
filename=file_part,
repo_type="space",
)
on_file_upload()
st.session_state.results_file = uploaded_file
st.rerun()
@st.cache_data
def read_results_csv(uploaded_file):
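    """Read the results CSV from an uploaded file object or a hf_hub_url string."""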
with st.spinner("Loading results...", show_time=True):
results_df = pd.read_csv(uploaded_file)
return results_df
def display_rtl(html):
    """Render the provided HTML string inside a right-to-left container"""
    st.markdown(
        f'<div dir="rtl">{html}</div>',
        unsafe_allow_html=True,
    )
@st.cache_data
def calculate_final_metrics(uploaded_file, _df):
"""Calculate final metrics for all entries
Args:
uploaded_file: The uploaded file object (For cache hash gen)
_df: The dataframe containing the evaluation results (not included in cache hash)
Returns:
A dictionary containing the final metrics
"""
_df = _df.sort_values(by=["id"])
_df["reference_text"] = _df["reference_text"].fillna("")
_df["predicted_text"] = _df["predicted_text"].fillna("")
# convert to list of dicts
entries_data = _df.to_dict(orient="records")
htn = HebrewTextNormalizer()
# Calculate final metrics
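    # jiwer.process_words aggregates hits/substitutions/deletions/insertions over the
    # whole corpus, so the reported WER is corpus-level rather than a mean of per-entry WERs.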
results = jiwer.process_words(
[htn(entry["reference_text"]) for entry in entries_data],
[htn(entry["predicted_text"]) for entry in entries_data],
)
return results
def get_known_dataset_by_output_name(output_name):
for dataset in known_datasets:
if dataset[2] == output_name:
return dataset
return None
def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100):
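    """Fetch audio URLs for a range of dataset rows via the Hugging Face datasets-server /rows API."""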
if dataset is None or not has_api_token:
return None
dataset_repo_id, dataset_config, _ = dataset
if not dataset_config:
dataset_config = "default"
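    # Entries encode "repo:split:text_column"; only the repo and split are needed for the rows API.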
if ":" in dataset_repo_id:
dataset_repo_id, split, _ = dataset_repo_id.split(":")
else:
split = "test"
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
api_query_params = {
"dataset": dataset_repo_id,
"config": dataset_config,
"split": split,
"offset": offset,
"length": max_entries,
}
query_params_str = "&".join([f"{k}={v}" for k, v in api_query_params.items()])
API_URL = f"https://datasets-server.huggingface.co/rows?{query_params_str}"
def query():
response = requests.get(API_URL, headers=headers)
return response.json()
data = query()
def get_audio_url(row):
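        # The rows API returns the audio column as a list of dicts; take the first item's URL.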
audio_feature_list = row["row"]["audio"]
first_audio = audio_feature_list[0]
return first_audio["src"]
if "rows" in data and len(data["rows"]) > 0:
return [get_audio_url(row) for row in data["rows"]]
else:
return None
def get_audio_url_for_entry(
dataset, entry_idx, cache_neighbors=True, neighbor_range=20
):
"""
Get audio URL for a specific entry and optionally cache neighbors
Args:
dataset: Dataset tuple (repo_id, config, output_name)
entry_idx: Index of the entry to get audio URL for
cache_neighbors: Whether to cache audio URLs for neighboring entries
neighbor_range: Range of neighboring entries to cache
Returns:
Audio URL for the specified entry
"""
# Calculate the range of entries to load
if cache_neighbors:
start_idx = max(0, entry_idx - neighbor_range)
max_entries = neighbor_range * 2 + 1
else:
start_idx = entry_idx
max_entries = 1
# Get audio URLs for the range of entries
audio_urls = get_dataset_entries_audio_urls(dataset, start_idx, max_entries)
if not audio_urls:
return None
# Cache the audio URLs
for i, url in enumerate(audio_urls):
idx = start_idx + i
# Extract expiration time from URL if available
expires = None
if "expires=" in url:
try:
expires_param = url.split("expires=")[1].split("&")[0]
expires = datetime.fromtimestamp(int(expires_param))
except (ValueError, IndexError):
expires = None
st.session_state.audio_cache[idx] = {"url": url, "expires": expires}
# Return the URL for the requested entry
relative_idx = entry_idx - start_idx
if 0 <= relative_idx < len(audio_urls):
return audio_urls[relative_idx]
return None
def get_cached_audio_url(entry_idx):
"""
Get audio URL from cache if available and not expired
Args:
entry_idx: Index of the entry to get audio URL for
Returns:
Audio URL if available in cache and not expired, None otherwise
"""
if entry_idx not in st.session_state.audio_cache:
return None
cache_entry = st.session_state.audio_cache[entry_idx]
# Check if the URL is expired
if cache_entry["expires"] and datetime.now() > cache_entry["expires"]:
return None
return cache_entry["url"]
def main():
st.set_page_config(
page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide"
)
if not has_api_token:
st.warning("No Hugging Face API token found. Audio previews will not work.")
st.title("ASR Evaluation Visualizer")
# File uploader
uploaded_file = st.file_uploader(
"Upload evaluation results CSV",
type=["csv"],
on_change=on_file_upload,
key="uploaded_file",
)
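    # results_file holds the active results source: the uploaded file object, or a
    # hf_hub_url string when a CSV is chosen from the leaderboard dialog.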
if uploaded_file is not None:
st.session_state.results_file = uploaded_file
if uploaded_file is None:
st.write("Or:")
if st.button("Choose from leaderboard"):
choose_input_file_from_leaderboard()
if st.session_state.results_file is not None:
uploaded_file = st.session_state.results_file
# Load the data
try:
eval_results = read_results_csv(uploaded_file)
st.session_state.total_entry_count = len(eval_results)
st.success("File uploaded successfully!")
with st.sidebar:
# Toggle for calculating total metrics
show_total_metrics = st.toggle("Show total metrics", value=False)
if show_total_metrics:
total_metrics = calculate_final_metrics(uploaded_file, eval_results)
# Display total metrics in a nice format
with st.container():
st.metric("WER", f"{total_metrics.wer * 100:.4f}%")
st.table(
{
"Hits": total_metrics.hits,
"Subs": total_metrics.substitutions,
"Dels": total_metrics.deletions,
"Insrt": total_metrics.insertions,
}
)
# Toggle for normalized vs raw text
use_normalized = st.sidebar.toggle("Use normalized text", value=True)
# Create sidebar for entry selection
st.sidebar.header("Select Entry")
# Add Next/Prev buttons at the top of the sidebar
col1, col2 = st.sidebar.columns(2)
# Define navigation functions
def go_prev():
if st.session_state.selected_entry_idx > 0:
st.session_state.selected_entry_idx -= 1
def go_next():
if st.session_state.selected_entry_idx < len(eval_results) - 1:
st.session_state.selected_entry_idx += 1
# Add navigation buttons
col1.button("← Prev", on_click=go_prev, use_container_width=True)
col2.button("Next →", on_click=go_next, use_container_width=True)
# Use a container for better styling
entry_container = st.sidebar.container()
with entry_container:
page_navigation()
st.write(f"Total entries: {st.session_state.total_entry_count}")
# Create a data table with entries and their WER
entries_data = []
for i in range(len(eval_results)):
wer_value = eval_results.iloc[i].get("wer", 0)
# Format WER as percentage
wer_formatted = (
f"{wer_value*100:.2f}%"
if isinstance(wer_value, (int, float))
else wer_value
)
entries_data.append({"Entry": f"Entry #{i+1}", "WER": wer_formatted})
            # Entry selection: one radio option per entry on the current page
st.sidebar.write("Select an entry")
            # Each radio option shows the entry number and its WER
current_page_slice = get_current_page_slice()
entry_container.radio(
"Select an entry",
options=list(range(len(eval_results))[current_page_slice]),
format_func=lambda i: f"Entry #{i+1} ({entries_data[i]['WER']})",
label_visibility="collapsed",
key="selected_entry_idx",
)
# Use the selected entry
selected_entry = st.session_state.selected_entry_idx
# Get the text columns based on the toggle
if use_normalized and "norm_reference_text" in eval_results.columns:
ref_col, hyp_col = "norm_reference_text", "norm_predicted_text"
else:
ref_col, hyp_col = "reference_text", "predicted_text"
# Get the reference and hypothesis texts
ref, hyp = eval_results.iloc[selected_entry][[ref_col, hyp_col]].values
st.header("Visualization")
# Check if the CSV file is from a known dataset
dataset_name = None
            # Infer the source dataset from the results file name, falling back to the "dataset" column
if uploaded_file is not None:
if isinstance(uploaded_file, str):
filename_stem = Path(uploaded_file).stem
else:
filename_stem = Path(uploaded_file.name).stem
dataset_name = filename_stem
if not dataset_name and "dataset" in eval_results.columns:
dataset_name = eval_results.iloc[selected_entry]["dataset"]
# Get the known dataset if available
known_dataset = get_known_dataset_by_output_name(dataset_name)
# Display audio preview button if from a known dataset
if known_dataset:
# Check if we have the audio URL in cache
audio_url = get_cached_audio_url(selected_entry)
audio_preview_active = st.session_state.audio_preview_active.get(
selected_entry, False
)
preview_audio = False
if not audio_preview_active:
# Create a button to preview audio
preview_audio = st.button("Preview Audio", key="preview_audio")
if preview_audio or audio_url:
st.session_state.audio_preview_active[selected_entry] = True
with st_fixed_container(
mode="sticky", position="top", border=True, margin=0
):
# If button clicked or we already have the URL, get/use the audio URL
if not audio_url:
with st.spinner("Loading audio..."):
audio_url = get_audio_url_for_entry(
known_dataset, selected_entry
)
# Display the audio player in the sticky container at the top
if audio_url:
st.audio(audio_url)
else:
st.error("Failed to load audio for this entry.")
# Display the visualization
html = render_visualize_jiwer_result_html(ref, hyp)
display_rtl(html)
# Display metadata
st.header("Metadata")
metadata_cols = [
"metadata_uuid",
"model",
"dataset",
"dataset_split",
"engine",
]
metadata = eval_results.iloc[selected_entry][metadata_cols]
# Create a DataFrame for better display
metadata_df = pd.DataFrame(
{"Field": metadata_cols, "Value": metadata.values}
)
st.table(metadata_df)
except Exception as e:
st.error(f"Error processing file: {str(e)}")
else:
st.info(
"Please upload an evaluation results CSV file to visualize the results."
)
st.markdown(
"""
### Expected CSV Format
The CSV should have the following columns:
- id
- reference_text
- predicted_text
- norm_reference_text
- norm_predicted_text
- wer
- wil
- substitutions
- deletions
- insertions
- hits
- metadata_uuid
- model
- dataset
- dataset_split
- engine
"""
)
if __name__ == "__main__":
main()