Spaces:
Running
Running
import gradio as gr | |
import json | |
import os | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from io import BytesIO | |
from PIL import Image | |
# ------------------------------- | |
# 1. Load Results from Local File | |
# ------------------------------- | |
def load_results(results_file=None):
    """Load the model entries from the results JSON file.

    Args:
        results_file: Optional path to the JSON file. When omitted, the file
            ``files/aragen_v1_results.json`` next to this module is used.

    Returns:
        A list with every top-level entry of the JSON document that is a
        mapping containing a ``"Meta"`` key. Other entries (for example
        timestamp records) are dropped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if results_file is None:
        # Resolve relative to this file so the app works from any cwd.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")

    # Decode explicitly as UTF-8 instead of depending on the platform's
    # locale encoding.
    with open(results_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The isinstance check keeps strings/lists that merely contain "Meta"
    # from being mistaken for model entries.
    return [entry for entry in data if isinstance(entry, dict) and "Meta" in entry]
# Load the JSON results once, at import time; the heatmap callbacks read this
# module-level list rather than re-reading the file.
DATA = load_results()
# Collect the display names used to populate the model dropdown.
def get_model_names(data):
    """Return the ``Meta -> Model Name`` value of every entry, in input order."""
    names = []
    for record in data:
        meta = record["Meta"]
        names.append(meta["Model Name"])
    return names
MODEL_NAMES = get_model_names(DATA)

# -------------------------------
# 2. Define Metrics and Heatmap Generation Functions
# -------------------------------

# The six 3C3H metrics; this order fixes the row/column order of each heatmap.
METRICS = [
    "Correctness",
    "Completeness",
    "Conciseness",
    "Helpfulness",
    "Honesty",
    "Harmlessness",
]
def generate_heatmap_image(model_entry):
    """Render the metric-similarity heatmap of one model as a PIL image.

    The scores named in ``METRICS`` are read from
    ``model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]`` and combined
    into a square matrix with ``similarity[i, j] = 1 - |v_i - v_j|``. The
    strict upper triangle is masked, so only the lower triangle and the
    diagonal are drawn.

    Args:
        model_entry: One record of the results data. It must provide
            ``["Meta"]["Model Name"]`` and the score mapping described above.
            NOTE(review): the similarity formula presumably expects scores
            normalised to [0, 1] -- confirm against the results file.

    Returns:
        An RGB ``PIL.Image.Image`` that fits within 800x600 pixels.

    Raises:
        KeyError: If the model name, the judge scores, or any metric is
            missing from ``model_entry``.
    """
    scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
    model_name = model_entry["Meta"]["Model Name"]

    # Score vector in the canonical metric order.
    v = np.array([scores[m] for m in METRICS])
    # Pairwise similarity between every two metrics.
    matrix = 1 - np.abs(np.subtract.outer(v, v))
    # k=1 hides the strict upper triangle and keeps the diagonal visible.
    mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close the figure -- even when plotting or
    # saving fails -- so figures do not accumulate in a long-running app.
    fig, ax = plt.subplots(figsize=(6, 5), dpi=100)
    try:
        sns.heatmap(
            matrix,
            mask=mask,
            annot=True,
            fmt=".2f",
            cmap="viridis",
            xticklabels=METRICS,
            yticklabels=METRICS,
            cbar_kws={"label": "Similarity"},
            ax=ax,
        )
        # The plot shows pairwise similarity, not a confusion matrix.
        ax.set_title(f"Similarity Matrix for Model: {model_name}")
        ax.set_xlabel("Metrics")
        ax.set_ylabel("Metrics")
        fig.tight_layout()

        # Save the plot to an in-memory PNG.
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        plt.close(fig)

    buf.seek(0)
    # convert() forces the lazily opened PNG to be fully decoded.
    image = Image.open(buf).convert("RGB")
    # thumbnail() resizes in place, preserves aspect ratio, and never enlarges.
    image.thumbnail((800, 600), Image.Resampling.LANCZOS)
    return image
def generate_heatmaps(selected_model_names):
    """Generate one heatmap per selected model.

    Args:
        selected_model_names: The model names chosen in the dropdown. A list
            of names is expected; ``None`` or an empty selection yields no
            images, and a single bare string is treated as one name.

    Returns:
        A list of PIL images, one for each entry of the global ``DATA`` whose
        ``Meta -> Model Name`` is in the selection, in ``DATA`` order.
    """
    # Nothing selected (empty list or None): nothing to draw.
    if not selected_model_names:
        return []
    # Guard against substring matching if a single name arrives as a str.
    if isinstance(selected_model_names, str):
        selected_model_names = [selected_model_names]

    # Build the lookup set once instead of scanning a list for every entry.
    wanted = set(selected_model_names)
    return [
        generate_heatmap_image(entry)
        for entry in DATA
        if entry["Meta"]["Model Name"] in wanted
    ]
# -------------------------------
# 3. Build the Gradio Interface
# -------------------------------
# NOTE(review): the nesting below (dropdown and button inside the row, gallery
# underneath) is the presumed layout -- confirm against the deployed app.
with gr.Blocks(css="""
.gallery-item img {
max-width: 100% !important;
max-height: 100% !important;
object-fit: contain !important;
}
""") as demo:
    gr.HTML("""
<center>
<br></br>
<h1>3C3H Heatmap Generator</h1>
<h3>Select the models you want to compare and generate their heatmaps below.</h3>
<br></br>
</center>
""")
    with gr.Row():
        # Pre-select the preferred models, but only those present in the
        # loaded results, so the initial value is always a valid choice.
        preferred_defaults = ["silma-ai/SILMA-9B-Instruct-v1.0", "google/gemma-2-9b-it"]
        default_models = [name for name in preferred_defaults if name in MODEL_NAMES]
        model_dropdown = gr.Dropdown(
            choices=MODEL_NAMES,
            label="Select Model(s)",
            multiselect=True,
            value=default_models,
        )
        generate_btn = gr.Button("Generate Heatmaps")
    # Set height and columns for better display
    gallery = gr.Gallery(
        label="Heatmaps",
        columns=2,
        height="auto",
        object_fit="contain",
    )
    # Clicking the button renders one heatmap per selected model into the gallery.
    generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)

demo.launch()