alielfilali01 commited on
Commit
6a40ae3
·
verified ·
1 Parent(s): 3df2235

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py CHANGED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import requests
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ from io import BytesIO
8
+
9
+ # -------------------------------
10
+ # 1. Configuration and Data Loading
11
+ # -------------------------------
12
+ # URL to the JSON file (the URL below resolves to the raw file)
13
+ DATA_URL = "https://huggingface.co/spaces/alielfilali01/3C3H-HeatMap/resolve/main/files/aragen_v1_results.json"
14
+
15
+ # Define the metrics order (6 dimensions)
16
+ METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
17
+
18
+ def load_data(url=DATA_URL):
19
+ response = requests.get(url)
20
+ data = response.json()
21
+ # Filter out any non-model entries (e.g. timestamp entries)
22
+ model_data = [entry for entry in data if "Meta" in entry]
23
+ return model_data
24
+
25
+ # Load the JSON data once when the app starts
26
+ DATA = load_data()
27
+
28
+ # Extract model names for the dropdown based on the JSON "Meta" field
29
+ def get_model_names(data):
30
+ model_names = [entry["Meta"]["Model Name"] for entry in data]
31
+ return model_names
32
+
33
+ MODEL_NAMES = get_model_names(DATA)
34
+
35
+ # -------------------------------
36
+ # 2. Heatmap Generation Functions
37
+ # -------------------------------
38
+ def generate_heatmap_image(model_entry):
39
+ """
40
+ Given a model entry from the JSON data, this function extracts the 6 metrics,
41
+ computes a 6x6 similarity matrix using the definition: similarity = 1 - |v_i - v_j|,
42
+ and returns the heatmap image as bytes.
43
+ """
44
+ scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
45
+ # Create a vector with the metrics in the defined order
46
+ v = np.array([scores[m] for m in METRICS])
47
+ # Compute the 6x6 similarity matrix
48
+ matrix = 1 - np.abs(np.subtract.outer(v, v))
49
+
50
+ # Create a mask for the upper triangle (diagonal remains visible)
51
+ mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
52
+
53
+ plt.figure(figsize=(6, 5))
54
+ ax = sns.heatmap(matrix,
55
+ mask=mask,
56
+ annot=True,
57
+ fmt=".2f",
58
+ cmap="viridis",
59
+ xticklabels=METRICS,
60
+ yticklabels=METRICS,
61
+ cbar_kws={"label": "Similarity"})
62
+ plt.title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
63
+ plt.xlabel("Metrics")
64
+ plt.ylabel("Metrics")
65
+ plt.tight_layout()
66
+
67
+ # Save the figure to a bytes buffer
68
+ buf = BytesIO()
69
+ plt.savefig(buf, format="png")
70
+ plt.close()
71
+ buf.seek(0)
72
+ return buf.read()
73
+
74
+ def generate_heatmaps(selected_model_names):
75
+ """
76
+ Filters the global DATA for entries matching the selected model names,
77
+ generates a heatmap for each one, and returns a list of image bytes.
78
+ """
79
+ filtered_entries = [entry for entry in DATA if entry["Meta"]["Model Name"] in selected_model_names]
80
+ images = []
81
+ for entry in filtered_entries:
82
+ img_bytes = generate_heatmap_image(entry)
83
+ images.append(img_bytes)
84
+ return images
85
+
86
+ # -------------------------------
87
+ # 3. Build the Gradio Interface
88
+ # -------------------------------
89
+ with gr.Blocks() as demo:
90
+ gr.Markdown("## 3C3H Heatmap Generator")
91
+ gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
92
+
93
+ with gr.Row():
94
+ model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
95
+
96
+ generate_btn = gr.Button("Generate Heatmaps")
97
+ gallery = gr.Gallery(label="Heatmaps").style(grid=[2], height="auto")
98
+
99
+ generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
100
+
101
+ # Launch the Gradio app
102
+ demo.launch()