"""Gradio app that renders 3C3H metric-similarity heatmaps for selected models.

Results are read once at import time from ``files/aragen_v1_results.json``
(next to this file). For every selected model a 6x6 matrix of pairwise metric
similarities is drawn with seaborn and shown in a gallery.
"""

import json
import os
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from PIL import Image

# Figures are only ever rendered to an in-memory PNG from Gradio callbacks, so
# pick the non-interactive backend explicitly instead of relying on whatever
# GUI backend the host machine would select by default.
plt.switch_backend("Agg")


# -------------------------------
# 1. Load Results from Local File
# -------------------------------
def load_results() -> list:
    """Load the per-model result entries from the bundled JSON file.

    The file is resolved relative to this module, not the working directory.

    Returns:
        The entries of the JSON document that are dicts containing a
        ``"Meta"`` key; anything else (e.g. timestamp entries) is dropped.

    Raises:
        OSError: If the results file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")

    # Decode as UTF-8 explicitly; the platform default encoding is not
    # guaranteed to be able to read the file.
    with open(results_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Require a dict so that `"Meta" in entry` is a key lookup and never a
    # substring test against a stray string entry.
    return [entry for entry in data if isinstance(entry, dict) and "Meta" in entry]


# Load the JSON data once when the app starts.
DATA = load_results()


def get_model_names(data) -> list:
    """Return the ``Meta["Model Name"]`` value of every entry, in order.

    Args:
        data: Iterable of result entries as returned by ``load_results``.

    Raises:
        KeyError: If an entry's ``"Meta"`` mapping has no ``"Model Name"``.
    """
    return [entry["Meta"]["Model Name"] for entry in data]


MODEL_NAMES = get_model_names(DATA)

# -------------------------------
# 2. Define Metrics and Heatmap Generation Functions
# -------------------------------
# The six metrics, in the order used for both heatmap axes.
METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]


def generate_heatmap_image(model_entry) -> Image.Image:
    """Render the metric-similarity heatmap of one model as a PIL image.

    The six scores listed in ``METRICS`` are read from
    ``model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]`` and combined
    into a 6x6 matrix with ``similarity = 1 - |v_i - v_j|``. Only the lower
    triangle (including the diagonal) is drawn.

    Args:
        model_entry: One result entry; must provide the score mapping above
            and ``model_entry["Meta"]["Model Name"]`` for the title.

    Returns:
        An RGB image no larger than 800x600 pixels.

    Raises:
        KeyError: If a required key or metric is missing from the entry.
    """
    scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]

    # NOTE(review): the similarity formula presumably expects scores in
    # [0, 1]; nothing here enforces that -- confirm against the data file.
    v = np.array([scores[m] for m in METRICS], dtype=float)
    matrix = 1 - np.abs(np.subtract.outer(v, v))

    # Hide the strict upper triangle; k=1 keeps the diagonal visible.
    mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", so overlapping callbacks cannot draw into or close
    # each other's figure.
    fig, ax = plt.subplots(figsize=(6, 5), dpi=100)
    try:
        sns.heatmap(
            matrix,
            mask=mask,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            xticklabels=METRICS,
            yticklabels=METRICS,
            cbar_kws={"label": "Similarity"},
            ax=ax,
        )
        # NOTE(review): the plot is a similarity matrix, yet the title says
        # "Confusion Matrix"; the text is kept unchanged -- confirm wording.
        ax.set_title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
        ax.set_xlabel("Metrics")
        ax.set_ylabel("Metrics")
        fig.tight_layout()

        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Always release the figure, including when drawing/saving fails.
        plt.close(fig)

    buf.seek(0)
    # convert() forces the lazily-opened PNG to be fully decoded.
    image = Image.open(buf).convert("RGB")

    # thumbnail() shrinks in place, keeps the aspect ratio, never upscales.
    max_size = (800, 600)
    image.thumbnail(max_size, Image.Resampling.LANCZOS)
    return image


def generate_heatmaps(selected_model_names) -> list:
    """Build one heatmap per selected model.

    Args:
        selected_model_names: Model names chosen in the dropdown. ``None`` or
            an empty selection yields an empty list; a single string is
            treated as a one-element selection.

    Returns:
        A list of PIL images, ordered as the matching entries appear in the
        global ``DATA`` (not in selection order).
    """
    if not selected_model_names:
        return []
    if isinstance(selected_model_names, str):
        selected_model_names = [selected_model_names]

    # Set membership keeps the scan over DATA linear.
    wanted = set(selected_model_names)
    return [
        generate_heatmap_image(entry)
        for entry in DATA
        if entry["Meta"]["Model Name"] in wanted
    ]


# -------------------------------
# 3. Build the Gradio Interface
# -------------------------------
# Keeps gallery images inside their cell without cropping.
_GALLERY_CSS = (
    " .gallery-item img {"
    " max-width: 100% !important;"
    " max-height: 100% !important;"
    " object-fit: contain !important;"
    " } "
)

# NOTE(review): the nesting of the widgets below was reconstructed (dropdown
# inside the row, button and gallery beneath it) -- confirm the intended layout.
with gr.Blocks(css=_GALLERY_CSS) as demo:
    # NOTE(review): the banner text below carries no HTML tags -- confirm the
    # markup was not lost.
    gr.HTML("""


3C3H Heatmap Generator

Select the models you want to compare and generate their heatmaps below.



""")
    with gr.Row():
        # Only preselect models that are actually present in the loaded
        # results, so the initial value is always a subset of the choices.
        default_models = [
            name
            for name in ["silma-ai/SILMA-9B-Instruct-v1.0", "google/gemma-2-9b-it"]
            if name in MODEL_NAMES
        ]
        model_dropdown = gr.Dropdown(
            choices=MODEL_NAMES,
            label="Select Model(s)",
            multiselect=True,
            value=default_models,
        )
    generate_btn = gr.Button("Generate Heatmaps")
    # Set height and columns for better display
    gallery = gr.Gallery(
        label="Heatmaps",
        columns=2,
        height="auto",
        object_fit="contain",
    )
    generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)

# Start the server only when run as a script, so importing this module does
# not block on a running web server.
if __name__ == "__main__":
    demo.launch()