import ast
import json
from typing import Tuple

import gradio as gr
from datasets import load_dataset, Dataset

from env import MODELS, TASK, ORG_NAME
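# Inferred from usage below: MODELS is a list of "org/model" id strings, TASK is
# the task name (it may contain "/"), and ORG_NAME is the organization hosting
# the "details_*" result datasets.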

def aggregate_results() -> list:
    """Loads the latest results dataset for every model in MODELS and returns a
    list of dicts with org, model, duration, and one column per task metric.
    """
    all_results = []
    for org_model in MODELS:
        try:
            path = f"{ORG_NAME}/details_{org_model.replace('/', '__')}_private"
            ds = load_dataset(path, "results", split="latest")
            config = json.loads(ds["config_general"][0])
            results = json.loads(ds["results"][0])

            # Model data
            org, model = org_model.split("/")

            cur_result = {
                "Org": org,
                "Model": model,
                "Duration (s)": config["end_time"] - config["start_time"]
            }

            # Flatten per-task metrics into columns, skipping the aggregate "all" entry
            for k_metric, v_dict in results.items():
                if k_metric != "all":
                    for k, v in v_dict.items():
                        cur_result[f"{k}({k_metric})"] = v
            all_results.append(cur_result)
        except Exception as e:
            print(f"Error processing {org_model}: {e}")
    return all_results

def extract_dataviz() -> Tuple[list, list, list]:
    """Loads the per-sample details for every model and buckets the samples into
    easy (solved by all models), hard (failed by all models), and all samples.
    """
    all_samples = {}
    for org_model in MODELS:
        try:
            path = f"{ORG_NAME}/details_{org_model.replace('/', '__')}_private"
            ds = load_dataset(path, f"custom_{TASK.replace('/', '_')}_0", split="latest")

            for ix, row in enumerate(ds):
                prompt = row["full_prompt"]
                gold = row["gold"]
                score = list(row["metrics"].values())[0]  # first metric drives the easy/hard bucketing
                prediction = row["predictions"][0]

                # We store flattened samples in a dict:
                # ix -> ix, prompt, gold, plus a score and a prediction entry per model,
                # and two parallel lists (model_scores, models) to aggregate over more easily
                if ix not in all_samples:
                    all_samples[ix] = {
                        "ix": ix,
                        "prompt": prompt,
                        "gold": gold[0] if isinstance(gold, list) else gold,
                        # A bit redundant, but put in their own boxes for simplicity of access later
                        "model_scores": [],
                        "models": []
                    }
                if org_model not in all_samples[ix]["models"]:
                    all_samples[ix][f"{org_model}_score"] = row["metrics"]
                    all_samples[ix][f"{org_model}_prediction"] = prediction
                    all_samples[ix]["model_scores"].append(score)
                    all_samples[ix]["models"].append(org_model)

        except Exception as e:
            print(f"Error processing {org_model}: {e}")

    # Assuming binary 0/1 metrics: "hard" samples were failed by every model,
    # "easy" samples were solved by every model.
    full_samples = sorted(all_samples.values(), key=lambda r: r["ix"])
    hard_samples = [s for s in full_samples if sum(s["model_scores"]) == 0]
    easy_samples = [s for s in full_samples if sum(s["model_scores"]) == len(s["model_scores"])]
    return easy_samples, hard_samples, full_samples

def samples_to_box_display(samples: list, example_index: int = 0):
    """Renders one sample (gold answer, then each model's metrics, prompt, and
    prediction) as styled HTML.
    Adapted from Nathan's code in https://huggingface.co/spaces/SaylorTwift/OpenEvalsModelDetails/
    """
    if len(samples) == 0:
        return "No samples in this category!"
    outputs = []
    sample = samples[example_index]
    for model in sample["models"]:
        try:
            outputs.append({
                'Model': model,
                'Prediction': sample[f'{model}_prediction'],
                'Prompt': sample['prompt'],
                'Metrics': sample[f'{model}_score'],
                'Gold': sample['gold']
            })
        except (KeyError, IndexError):
            continue
    
    if not outputs:
        return "No results found for the selected combination."
    
    # Create HTML output with all models
    html_output = "<div style='max-width: 800px; margin: 0 auto;'>\n\n"
    
    # Show the gold answer at the top with distinct styling (outputs is guaranteed non-empty here)
    html_output += "<div style='background: #e6f3e6; padding: 20px; border-radius: 10px; margin-bottom: 20px;'>\n"
    html_output += "<h3 style='margin-top: 0;'>Ground Truth</h3>\n"
    html_output += "<div style='overflow-x: auto; max-width: 100%;'>\n"
    html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 0;'><code>{outputs[0]['Gold']}</code></pre>\n"
    html_output += "</div>\n"
    html_output += "</div>\n"
    
    for output in outputs:
        html_output += "<div style='background: #f5f5f5; padding: 20px; margin-bottom: 20px; border-radius: 10px;'>\n"
        html_output += f"<h2 style='margin-top: 0;'>{output['Model']}</h2>\n"
        
        # Format metrics as a clean table
        html_output += "<details open style='margin-bottom: 15px;'>\n"
        html_output += "<summary><h3 style='display: inline; margin: 0;'>Metrics</h3></summary>\n"
        metrics = output['Metrics']
        if isinstance(metrics, str):
            # Metrics can arrive as a stringified dict; parse it safely rather than eval()
            metrics = ast.literal_eval(metrics)
        html_output += "<div style='overflow-x: auto;'>\n"
        html_output += "<table style='width: 100%; margin: 10px 0; border-collapse: collapse;'>\n"
        for key, value in metrics.items():
            if isinstance(value, float):
                value = f"{value:.3f}"
            html_output += f"<tr><td style='padding: 5px; border-bottom: 1px solid #ddd;'><strong>{key}</strong></td><td style='padding: 5px; border-bottom: 1px solid #ddd;'>{value}</td></tr>\n"
        html_output += "</table>\n"
        html_output += "</div>\n"
        html_output += "</details>\n\n"
        
        # Handle prompt formatting with better styling
        html_output += "<details style='margin-bottom: 15px;'>\n"
        html_output += "<summary><h3 style='display: inline; margin: 0;'>Prompt</h3></summary>\n"
        html_output += "<div style='background: #ffffff; padding: 15px; border-radius: 5px; margin-top: 10px;'>\n"
        
        prompt_text = output['Prompt']
        if isinstance(prompt_text, list):
            for msg in prompt_text:  # chat-style prompt: a list of message dicts
                if isinstance(msg, dict) and 'content' in msg:
                    role = msg.get('role', 'message').title()
                    html_output += "<div style='margin-bottom: 10px;'>\n"
                    html_output += f"<strong>{role}:</strong>\n"
                    html_output += "<div style='overflow-x: auto;'>\n"
                    html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 5px 0;'><code>{msg['content']}</code></pre>\n"
                    html_output += "</div>\n"
                    html_output += "</div>\n"
                else:
                    html_output += "<div style='margin-bottom: 10px;'>\n"
                    html_output += "<div style='overflow-x: auto;'>\n"
                    html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 5px 0;'><code>{json.dumps(msg, indent=2)}</code></pre>\n"
                    html_output += "</div>\n"
                    html_output += "</div>\n"
        else:
            html_output += "<div style='overflow-x: auto;'>\n"
            if isinstance(prompt_text, dict) and 'content' in prompt_text:
                html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 5px 0;'><code>{prompt_text['content']}</code></pre>\n"
            else:
                html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 5px 0;'><code>{prompt_text}</code></pre>\n"
            html_output += "</div>\n"
        
        html_output += "</div>\n"
        html_output += "</details>\n\n"
        
        # Style prediction output - now in a collapsible section
        html_output += "<details open style='margin-bottom: 15px;'>\n"
        html_output += "<summary><h3 style='display: inline; margin: 0;'>Prediction</h3>"
        # Add word count in a muted style
        word_count = len(output['Prediction'].split())
        html_output += f"<span style='color: #666; font-size: 0.8em; margin-left: 10px;'>({word_count} words)</span>"
        html_output += "</summary>\n"
        html_output += "<div style='background: #ffffff; padding: 15px; border-radius: 5px; margin-top: 10px;'>\n"
        html_output += "<div style='overflow-x: auto;'>\n"
        html_output += f"<pre style='white-space: pre-wrap; word-wrap: break-word; margin: 0;'><code>{output['Prediction']}</code></pre>\n"
        html_output += "</div>\n"
        html_output += "</div>\n"
        html_output += "</details>\n"
        html_output += "</div>\n\n"
    
    html_output += "</div>"
    return html_output
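
# Example usage (hypothetical): inspect the first sample that every model failed.
#   easy, hard, full = extract_dataviz()
#   html = samples_to_box_display(hard, example_index=0)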

def run_pipeline(samples_ix: int = 0):
    results = aggregate_results()
    easy_samples, hard_samples, all_samples = extract_dataviz()
    return (
        gr.Dataframe(Dataset.from_list(results).to_pandas(), visible=True),
        gr.HTML(samples_to_box_display(easy_samples, samples_ix), label="Easiest samples (always found)", visible=True),
        gr.HTML(samples_to_box_display(hard_samples, samples_ix), label="Hardest samples (always failed)", visible=True),
        gr.HTML(samples_to_box_display(all_samples, samples_ix), label="All samples", visible=True),
    )

def update_examples(samples_ix: int = 0):
    easy_samples, hard_samples, all_samples = extract_dataviz()
    return (
        samples_to_box_display(easy_samples, samples_ix),
        samples_to_box_display(hard_samples, samples_ix),
        samples_to_box_display(all_samples, samples_ix),
    )
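
# Minimal wiring sketch (illustrative, not the Space's actual layout; component
# names here are hypothetical): run_pipeline fills all four outputs once, and
# update_examples refreshes the three HTML views when the sample index changes.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        sample_ix = gr.Number(value=0, precision=0, label="Sample index")
        run_btn = gr.Button("Run evaluation report")
        results_df = gr.Dataframe(visible=False)
        easy_html = gr.HTML(visible=False)
        hard_html = gr.HTML(visible=False)
        all_html = gr.HTML(visible=False)
        run_btn.click(run_pipeline, inputs=[sample_ix],
                      outputs=[results_df, easy_html, hard_html, all_html])
        sample_ix.change(update_examples, inputs=[sample_ix],
                         outputs=[easy_html, hard_html, all_html])
    demo.launch()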