import os from datetime import datetime from pathlib import Path import huggingface_hub import jiwer import pandas as pd import requests import streamlit as st from huggingface_hub import HfFileSystem from st_fixed_container import st_fixed_container from visual_eval.evaluator import HebrewTextNormalizer from visual_eval.visualization import render_visualize_jiwer_result_html HF_API_TOKEN = None try: HF_API_TOKEN = st.secrets["HF_API_TOKEN"] except FileNotFoundError: HF_API_TOKEN = os.environ.get("HF_API_TOKEN") has_api_token = HF_API_TOKEN is not None known_datasets = [ ("ivrit-ai/eval-d1:test:text", None, "ivrit_ai_eval_d1"), ("upai-inc/saspeech:test:text", None, "saspeech"), ("google/fleurs:test:transcription", "he_il", "fleurs"), ("mozilla-foundation/common_voice_17_0:test:sentence", "he", "common_voice_17"), ("imvladikon/hebrew_speech_kan:validation:sentence", None, "hebrew_speech_kan"), ] # Initialize session state for audio cache if it doesn't exist if "audio_cache" not in st.session_state: st.session_state.audio_cache = {} if "audio_preview_active" not in st.session_state: st.session_state.audio_preview_active = {} if "uploaded_file" not in st.session_state: st.session_state.results_file = None if "selected_entry_idx" not in st.session_state: st.session_state.selected_entry_idx = 0 if "total_entry_count" not in st.session_state: st.session_state.total_entry_count = 0 if "entry_page_size" not in st.session_state: st.session_state.entry_page_size = 20 def get_current_page_slice(): ss = st.session_state if ss.total_entry_count == 0: return slice(0, 0) page_first_entry = ( st.session_state.selected_entry_idx // ss.entry_page_size ) * ss.entry_page_size page_last_entry = min(page_first_entry + ss.entry_page_size, ss.total_entry_count) return slice(page_first_entry, page_last_entry) def page_navigation(): ss = st.session_state current_page_slice = get_current_page_slice() has_next_page = current_page_slice.stop < ss.total_entry_count - 1 has_prev_page = current_page_slice.start >= ss.entry_page_size col1, col2 = st.columns(2) if col1.button("Prev Page", disabled=not has_prev_page): ss.selected_entry_idx = current_page_slice.start - 1 st.rerun() if col2.button("Next Page", disabled=not has_next_page): ss.selected_entry_idx = current_page_slice.start + ss.entry_page_size st.rerun() def on_file_upload(): st.session_state.audio_cache = {} st.session_state.audio_preview_active = {} st.session_state.selected_entry_idx = 0 st.session_state.results_file = None @st.cache_data def get_leaderboard_result_csv_paths(root_search_path): fs = HfFileSystem(token=HF_API_TOKEN) found_files = fs.glob(f"{root_search_path}/*/*.csv") found_files_relative_paths = [f.split(root_search_path)[1] for f in found_files] return found_files_relative_paths @st.dialog("View Leaderboard Results") def choose_input_file_from_leaderboard(): if not has_api_token: st.rerun() root_search_path = "ivrit-ai/hebrew-transcription-leaderboard/results" fsspec_spaces_root_search_path = f"spaces/{root_search_path}" found_files_relative_paths = get_leaderboard_result_csv_paths( fsspec_spaces_root_search_path ) selected_file = st.selectbox( "Select a CSV file from the leaderboard:", found_files_relative_paths, index=None, ) # Get the selected file if selected_file: paths_part = Path(selected_file).parent file_part = Path(selected_file).name uploaded_file = huggingface_hub.hf_hub_url( repo_id="ivrit-ai/hebrew-transcription-leaderboard", subfolder=f"results{paths_part}", filename=file_part, repo_type="space", ) on_file_upload() st.session_state.results_file = uploaded_file st.rerun() @st.cache_data def read_results_csv(uploaded_file): with st.spinner("Loading results...", show_time=True): results_df = pd.read_csv(uploaded_file) return results_df def display_rtl(html): """Render an RTL container with the provided HTML string""" st.markdown( f"""
{html}
""", unsafe_allow_html=True, ) @st.cache_data def calculate_final_metrics(uploaded_file, _df): """Calculate final metrics for all entries Args: uploaded_file: The uploaded file object (For cache hash gen) _df: The dataframe containing the evaluation results (not included in cache hash) Returns: A dictionary containing the final metrics """ _df = _df.sort_values(by=["id"]) _df["reference_text"] = _df["reference_text"].fillna("") _df["predicted_text"] = _df["predicted_text"].fillna("") # convert to list of dicts entries_data = _df.to_dict(orient="records") htn = HebrewTextNormalizer() # Calculate final metrics results = jiwer.process_words( [htn(entry["reference_text"]) for entry in entries_data], [htn(entry["predicted_text"]) for entry in entries_data], ) return results def get_known_dataset_by_output_name(output_name): for dataset in known_datasets: if dataset[2] == output_name: return dataset return None def get_dataset_entries_audio_urls(dataset, offset=0, max_entries=100): if dataset is None or not has_api_token: return None dataset_repo_id, dataset_config, _ = dataset if not dataset_config: dataset_config = "default" if ":" in dataset_repo_id: dataset_repo_id, split, _ = dataset_repo_id.split(":") else: split = "test" headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} api_query_params = { "dataset": dataset_repo_id, "config": dataset_config, "split": split, "offset": offset, "length": max_entries, } query_params_str = "&".join([f"{k}={v}" for k, v in api_query_params.items()]) API_URL = f"https://datasets-server.huggingface.co/rows?{query_params_str}" def query(): response = requests.get(API_URL, headers=headers) return response.json() data = query() def get_audio_url(row): audio_feature_list = row["row"]["audio"] first_audio = audio_feature_list[0] return first_audio["src"] if "rows" in data and len(data["rows"]) > 0: return [get_audio_url(row) for row in data["rows"]] else: return None def get_audio_url_for_entry( dataset, entry_idx, cache_neighbors=True, neighbor_range=20 ): """ Get audio URL for a specific entry and optionally cache neighbors Args: dataset: Dataset tuple (repo_id, config, output_name) entry_idx: Index of the entry to get audio URL for cache_neighbors: Whether to cache audio URLs for neighboring entries neighbor_range: Range of neighboring entries to cache Returns: Audio URL for the specified entry """ # Calculate the range of entries to load if cache_neighbors: start_idx = max(0, entry_idx - neighbor_range) max_entries = neighbor_range * 2 + 1 else: start_idx = entry_idx max_entries = 1 # Get audio URLs for the range of entries audio_urls = get_dataset_entries_audio_urls(dataset, start_idx, max_entries) if not audio_urls: return None # Cache the audio URLs for i, url in enumerate(audio_urls): idx = start_idx + i # Extract expiration time from URL if available expires = None if "expires=" in url: try: expires_param = url.split("expires=")[1].split("&")[0] expires = datetime.fromtimestamp(int(expires_param)) except (ValueError, IndexError): expires = None st.session_state.audio_cache[idx] = {"url": url, "expires": expires} # Return the URL for the requested entry relative_idx = entry_idx - start_idx if 0 <= relative_idx < len(audio_urls): return audio_urls[relative_idx] return None def get_cached_audio_url(entry_idx): """ Get audio URL from cache if available and not expired Args: entry_idx: Index of the entry to get audio URL for Returns: Audio URL if available in cache and not expired, None otherwise """ if entry_idx not in st.session_state.audio_cache: return None cache_entry = st.session_state.audio_cache[entry_idx] # Check if the URL is expired if cache_entry["expires"] and datetime.now() > cache_entry["expires"]: return None return cache_entry["url"] def main(): st.set_page_config( page_title="ASR Evaluation Visualizer", page_icon="🎤", layout="wide" ) if not has_api_token: st.warning("No Hugging Face API token found. Audio previews will not work.") st.title("ASR Evaluation Visualizer") # File uploader uploaded_file = st.file_uploader( "Upload evaluation results CSV", type=["csv"], on_change=on_file_upload, key="uploaded_file", ) if uploaded_file is not None: st.session_state.results_file = uploaded_file if uploaded_file is None: st.write("Or:") if st.button("Choose from leaderboard"): choose_input_file_from_leaderboard() if st.session_state.results_file is not None: uploaded_file = st.session_state.results_file # Load the data try: eval_results = read_results_csv(uploaded_file) st.session_state.total_entry_count = len(eval_results) st.success("File uploaded successfully!") with st.sidebar: # Toggle for calculating total metrics show_total_metrics = st.toggle("Show total metrics", value=False) if show_total_metrics: total_metrics = calculate_final_metrics(uploaded_file, eval_results) # Display total metrics in a nice format with st.container(): st.metric("WER", f"{total_metrics.wer * 100:.4f}%") st.table( { "Hits": total_metrics.hits, "Subs": total_metrics.substitutions, "Dels": total_metrics.deletions, "Insrt": total_metrics.insertions, } ) # Toggle for normalized vs raw text use_normalized = st.sidebar.toggle("Use normalized text", value=True) # Create sidebar for entry selection st.sidebar.header("Select Entry") # Add Next/Prev buttons at the top of the sidebar col1, col2 = st.sidebar.columns(2) # Define navigation functions def go_prev(): if st.session_state.selected_entry_idx > 0: st.session_state.selected_entry_idx -= 1 def go_next(): if st.session_state.selected_entry_idx < len(eval_results) - 1: st.session_state.selected_entry_idx += 1 # Add navigation buttons col1.button("← Prev", on_click=go_prev, use_container_width=True) col2.button("Next →", on_click=go_next, use_container_width=True) # Use a container for better styling entry_container = st.sidebar.container() with entry_container: page_navigation() st.write(f"Total entries: {st.session_state.total_entry_count}") # Create a data table with entries and their WER entries_data = [] for i in range(len(eval_results)): wer_value = eval_results.iloc[i].get("wer", 0) # Format WER as percentage wer_formatted = ( f"{wer_value*100:.2f}%" if isinstance(wer_value, (int, float)) else wer_value ) entries_data.append({"Entry": f"Entry #{i+1}", "WER": wer_formatted}) # Create a selection mechanism using radio buttons that look like a table st.sidebar.write("Select an entry") # Create a radio button for each entry, styled to look like a table row current_page_slice = get_current_page_slice() entry_container.radio( "Select an entry", options=list(range(len(eval_results))[current_page_slice]), format_func=lambda i: f"Entry #{i+1} ({entries_data[i]['WER']})", label_visibility="collapsed", key="selected_entry_idx", ) # Use the selected entry selected_entry = st.session_state.selected_entry_idx # Get the text columns based on the toggle if use_normalized and "norm_reference_text" in eval_results.columns: ref_col, hyp_col = "norm_reference_text", "norm_predicted_text" else: ref_col, hyp_col = "reference_text", "predicted_text" # Get the reference and hypothesis texts ref, hyp = eval_results.iloc[selected_entry][[ref_col, hyp_col]].values st.header("Visualization") # Check if the CSV file is from a known dataset dataset_name = None # If no dataset column, try to infer from filename if uploaded_file is not None: if isinstance(uploaded_file, str): filename_stem = Path(uploaded_file).stem else: filename_stem = Path(uploaded_file.name).stem dataset_name = filename_stem if not dataset_name and "dataset" in eval_results.columns: dataset_name = eval_results.iloc[selected_entry]["dataset"] # Get the known dataset if available known_dataset = get_known_dataset_by_output_name(dataset_name) # Display audio preview button if from a known dataset if known_dataset: # Check if we have the audio URL in cache audio_url = get_cached_audio_url(selected_entry) audio_preview_active = st.session_state.audio_preview_active.get( selected_entry, False ) preview_audio = False if not audio_preview_active: # Create a button to preview audio preview_audio = st.button("Preview Audio", key="preview_audio") if preview_audio or audio_url: st.session_state.audio_preview_active[selected_entry] = True with st_fixed_container( mode="sticky", position="top", border=True, margin=0 ): # If button clicked or we already have the URL, get/use the audio URL if not audio_url: with st.spinner("Loading audio..."): audio_url = get_audio_url_for_entry( known_dataset, selected_entry ) # Display the audio player in the sticky container at the top if audio_url: st.audio(audio_url) else: st.error("Failed to load audio for this entry.") # Display the visualization html = render_visualize_jiwer_result_html(ref, hyp) display_rtl(html) # Display metadata st.header("Metadata") metadata_cols = [ "metadata_uuid", "model", "dataset", "dataset_split", "engine", ] metadata = eval_results.iloc[selected_entry][metadata_cols] # Create a DataFrame for better display metadata_df = pd.DataFrame( {"Field": metadata_cols, "Value": metadata.values} ) st.table(metadata_df) # If we have audio URL, display it in the sticky container if "audio_url" in locals() and audio_url: pass # CSS is now applied globally except Exception as e: st.error(f"Error processing file: {str(e)}") else: st.info( "Please upload an evaluation results CSV file to visualize the results." ) st.markdown( """ ### Expected CSV Format The CSV should have the following columns: - id - reference_text - predicted_text - norm_reference_text - norm_predicted_text - wer - wil - substitutions - deletions - insertions - hits - metadata_uuid - model - dataset - dataset_split - engine """ ) if __name__ == "__main__": main()