from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from pytorch_ie import Document, DocumentMetric


class ScoreDistribution(DocumentMetric):
    """Computes the distribution of prediction scores for annotations in a layer. The scores are
    separated into true positives (TP) and false positives (FP) based on the gold annotations.

    Args:
        layer: The name of the annotation layer to analyze.
        label_field: The name of the annotation field that holds the label; used to separate
            the scores per label. Default is "label".
        per_label: If True, the scores are separated per label. Default is False.
        show_plot: If True, a plot of the score distribution is shown. Default is False.
        equal_sample_size_binning: If True, the scores are binned so that each bin contains
            (approximately) the same number of samples; if False, the bins have equal width.
            Equal-sample-size binning is useful when the distribution of scores is skewed.
            Default is True.
        plotting_backend: The plotting backend to use. Currently, only "plotly" is supported.
            Default is "plotly".
        plotting_caption_mapping: A mapping to rename caption entries for plotting, e.g. the
            layer name, labels, or TP/FP gold values. Default is None.
        plotting_colors: A dictionary mapping from gold scores (1.0 for TP, 0.0 for FP) to
            colors for plotting. Default is None.
        plotly_use_create_distplot: If True, plot the score distribution with
            plotly.figure_factory.create_distplot; otherwise, use plotly.express.histogram.
            Default is True.
        plotly_barmode: If provided, the barmode for the plotly figure, e.g. "overlay" or
            "stack". Default is None.
        plotly_marginal: The marginal subplot type ("violin", "box", "rug", or "histogram")
            for the plotly.express histogram. Only used if plotly_use_create_distplot is
            False. Default is "violin".
        plotly_font_size: The font size for the plotly figures. Default is 18.
        plotly_font_family: If provided, the font family for the plotly figures. Default is None.
        plotly_background_color: If provided, the background color for the plotly figures.
            Default is None.
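
    Example (a minimal usage sketch; the layer name "labeled_spans" and the
    call pattern are illustrative assumptions):

        metric = ScoreDistribution(layer="labeled_spans", per_label=True)
        metric(documents)  # update with a collection of Documents
        result = metric.compute()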
    """

    def __init__(
        self,
        layer: str,
        label_field: str = "label",
        per_label: bool = False,
        show_plot: bool = False,
        equal_sample_size_binning: bool = True,
        plotting_backend: str = "plotly",
        plotting_caption_mapping: Optional[Dict[str, str]] = None,
        plotting_colors: Optional[Dict[str, str]] = None,
        plotly_use_create_distplot: bool = True,
        plotly_barmode: Optional[str] = None,
        plotly_marginal: Optional[str] = "violin",
        plotly_font_size: int = 18,
        plotly_font_family: Optional[str] = None,
        plotly_background_color: Optional[str] = None,
    ):
        super().__init__()
        self.layer = layer
        self.label_field = label_field
        self.per_label = per_label
        self.equal_sample_size_binning = equal_sample_size_binning
        self.plotting_backend = plotting_backend
        self.show_plot = show_plot
        self.plotting_caption_mapping = plotting_caption_mapping or {}
        self.plotting_colors = plotting_colors
        self.plotly_use_create_distplot = plotly_use_create_distplot
        self.plotly_barmode = plotly_barmode
        self.plotly_marginal = plotly_marginal
        self.plotly_font_size = plotly_font_size
        self.plotly_font_family = plotly_font_family
        self.plotly_background_color = plotly_background_color

        self.scores: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))

    def reset(self):
        self.scores = defaultdict(lambda: defaultdict(list))

    def _update(self, document: Document):
        # Match predicted annotations against the gold annotations: a prediction
        # that also occurs in the gold layer is a true positive (TP), otherwise
        # a false positive (FP).
        gold_annotations = set(document[self.layer])
        for ann in document[self.layer].predictions:
            if self.per_label:
                label = getattr(ann, self.label_field)
            else:
                label = "ALL"
            if ann in gold_annotations:
                self.scores[label]["TP"].append(ann.score)
            else:
                self.scores[label]["FP"].append(ann.score)

    def _combine_scores(
        self,
        scores_tp: List[float],
        scores_fp: List[float],
        col_name_pred: str = "prediction",
        col_name_gold: str = "gold",
    ) -> pd.DataFrame:
        # Stack TP and FP scores into one frame; the "gold" column holds the binary
        # target (1.0 for TP, 0.0 for FP) used downstream for calibration analysis.
        scores_tp_df = pd.DataFrame(scores_tp, columns=[col_name_pred])
        scores_tp_df[col_name_gold] = 1.0
        scores_fp_df = pd.DataFrame(scores_fp, columns=[col_name_pred])
        scores_fp_df[col_name_gold] = 0.0
        scores_df = pd.concat([scores_tp_df, scores_fp_df])
        return scores_df

    def _get_calibration_data_and_metrics(
        self, scores: pd.DataFrame, q: int = 20
    ) -> Tuple[pd.DataFrame, pd.Series]:
        from sklearn.metrics import brier_score_loss

        if self.equal_sample_size_binning:
            # Quantile-based binning: each bin holds (approximately) the same
            # number of samples.
            scores["bin"] = pd.qcut(scores["prediction"], q=q, labels=False)
        else:
            # Equal-width binning over the score range.
            scores["bin"] = pd.cut(
                scores["prediction"],
                bins=q,
                include_lowest=True,
                right=True,
                labels=False,
            )

        # Per bin: mean predicted score, fraction of gold positives, and sample count.
        calibration_data = (
            scores.groupby("bin")
            .apply(
                lambda x: pd.Series(
                    {
                        "avg_score": x["prediction"].mean(),
                        "fraction_positive": x["gold"].mean(),
                        "count": len(x),
                    }
                )
            )
            .reset_index()
        )

        total_count = scores.shape[0]
        calibration_data["bin_weight"] = calibration_data["count"] / total_count

        calibration_data["abs_diff"] = abs(
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        )
        calibration_data["squared_diff"] = (
            calibration_data["avg_score"] - calibration_data["fraction_positive"]
        ) ** 2

        # Expected Calibration Error: bin-weight-averaged absolute difference between
        # the mean predicted score and the fraction of gold positives.
        ece = (calibration_data["abs_diff"] * calibration_data["bin_weight"]).sum()
        # Maximum Calibration Error: worst-case absolute difference over all bins.
        mce = calibration_data["abs_diff"].max()
        # Mean Squared Error: bin-weight-averaged squared difference.
        mse = (calibration_data["squared_diff"] * calibration_data["bin_weight"]).sum()
        # Brier score: mean squared error between individual predictions and gold targets.
        brier = brier_score_loss(scores["gold"], scores["prediction"])

        values = {
            "ece": ece,
            "mce": mce,
            "mse": mse,
            "brier": brier,
        }
        return calibration_data, pd.Series(values)

    def calculate_calibration_metrics(self, scores_combined: pd.DataFrame) -> pd.DataFrame:
        calibration_data_dict = {}
        calibration_metrics_dict = {}
        for label, current_scores in scores_combined.groupby("label"):
            calibration_data, calibration_metrics = self._get_calibration_data_and_metrics(
                current_scores, q=20
            )
            calibration_data_dict[label] = calibration_data
            calibration_metrics_dict[label] = calibration_metrics
        all_calibration_data = pd.concat(
            calibration_data_dict, names=["label", "idx"]
        ).reset_index(level=0)
        all_calibration_metrics = pd.concat(calibration_metrics_dict, axis=1).T

        if self.show_plot:
            self.plot_calibration_data(calibration_data=all_calibration_data)

        return all_calibration_metrics

    def calculate_correlation(self, scores: pd.DataFrame) -> pd.Series:
        # Pearson correlation between the predicted scores and the binary gold targets.
        result_dict = {}
        for label, current_scores in scores.groupby("label"):
            result_dict[label] = current_scores.drop("label", axis=1).corr()["prediction"]["gold"]
        return pd.Series(result_dict, name="correlation")

    @property
    def mapped_layer(self):
        return self.plotting_caption_mapping.get(self.layer, self.layer)

    def plot_score_distribution(self, scores: pd.DataFrame):
        if self.plotting_backend == "plotly":
            for label in scores["label"].unique():
                description = f"Distribution of Predicted Scores for {self.mapped_layer}"
                if self.per_label:
                    label_mapped = self.plotting_caption_mapping.get(label, label)
                    description += f" ({label_mapped})"
                current_scores = scores[scores["label"] == label]
                if self.plotly_use_create_distplot:
                    import plotly.figure_factory as ff

                    # One list of scores per gold value (1.0 for TP, 0.0 for FP).
                    scores_dict = (
                        current_scores.groupby("gold")["prediction"].apply(list).to_dict()
                    )
                    group_labels, hist_data = zip(*scores_dict.items())
                    group_labels_renamed = [
                        self.plotting_caption_mapping.get(group_label, group_label)
                        for group_label in group_labels
                    ]
                    if self.plotting_colors is not None:
                        colors = [
                            self.plotting_colors[group_label] for group_label in group_labels
                        ]
                    else:
                        colors = None
                    fig = ff.create_distplot(
                        hist_data,
                        group_labels=group_labels_renamed,
                        show_hist=True,
                        colors=colors,
                        bin_size=0.025,
                    )
                else:
                    import plotly.express as px

                    fig = px.histogram(
                        current_scores,
                        x="prediction",
                        color="gold",
                        marginal=self.plotly_marginal,
                        hover_data=current_scores.columns,
                        color_discrete_map=self.plotting_colors,
                        nbins=50,
                    )

                fig.update_layout(
                    height=600,
                    width=800,
                    title_text=description,
                    title_x=0.5,
                    font=dict(size=self.plotly_font_size),
                    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
                )
                if self.plotly_barmode is not None:
                    fig.update_layout(barmode=self.plotly_barmode)
                if self.plotly_font_family is not None:
                    fig.update_layout(font_family=self.plotly_font_family)
                if self.plotly_background_color is not None:
                    fig.update_layout(
                        plot_bgcolor=self.plotly_background_color,
                        paper_bgcolor=self.plotly_background_color,
                    )

                fig.show()
        else:
            raise NotImplementedError(f"Plotting backend {self.plotting_backend} not implemented")

    def plot_calibration_data(self, calibration_data: pd.DataFrame):
        import plotly.express as px
        import plotly.graph_objects as go

        color = "label" if self.per_label else None
        x_col = "avg_score"
        y_col = "fraction_positive"
        fig = px.scatter(
            calibration_data,
            x=x_col,
            y=y_col,
            color=color,
            trendline="ols",
            labels=self.plotting_caption_mapping,
        )
        if not self.per_label:
            fig["data"][1]["name"] = "prediction vs. gold"

        # px.scatter with trendline="ols" creates pairs of traces (scatter points,
        # trendline); show only the trendlines in the legend.
        for idx, trace_data in enumerate(fig["data"]):
            if idx % 2 == 0:
                trace_data["showlegend"] = False
            else:
                trace_data["showlegend"] = True

        # Add the diagonal that a perfectly calibrated model would follow.
        minimum = calibration_data[x_col].min()
        maximum = calibration_data[x_col].max()
        fig.add_trace(
            go.Scatter(
                x=[minimum, maximum],
                y=[minimum, maximum],
                mode="lines",
                name="optimal",
                line=dict(color="black", dash="dash"),
            )
        )
        fig.update_layout(
            height=600,
            width=800,
            title_text=f"Mean Binned Scores for {self.mapped_layer}",
            title_x=0.5,
            font=dict(size=self.plotly_font_size),
        )
        fig.update_layout(
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="left",
                x=0.01,
                title="OLS trendline" + ("s" if self.per_label else ""),
            ),
        )
        if self.plotly_background_color is not None:
            fig.update_layout(
                plot_bgcolor=self.plotly_background_color,
                paper_bgcolor=self.plotly_background_color,
            )
        if self.plotly_font_family is not None:
            fig.update_layout(font_family=self.plotly_font_family)

        fig.show()

    def _compute(self) -> Dict[str, Dict[str, Any]]:
        # Combine the collected TP and FP scores of all labels into one DataFrame
        # with columns "label", "prediction", and "gold".
        scores_combined = pd.concat(
            {
                label: self._combine_scores(scores["TP"], scores["FP"])
                for label, scores in self.scores.items()
            },
            names=["label", "idx"],
        ).reset_index(level=0)

        result_df = scores_combined.groupby("label")["prediction"].agg(["mean", "std", "count"])
        if self.show_plot:
            self.plot_score_distribution(scores=scores_combined)

        calibration_metrics = self.calculate_calibration_metrics(scores_combined)
        calibration_metrics["correlation"] = self.calculate_correlation(scores_combined)

        result_df = pd.concat(
            {"prediction": result_df, "prediction vs. gold": calibration_metrics}, axis=1
        )

        if not self.per_label:
            result = result_df.xs("ALL")
        else:
            result = result_df.T.stack().unstack()

        result_dict = {
            main_key: result.xs(main_key).T.to_dict()
            for main_key in result.index.get_level_values(0).unique()
        }

        return result_dict
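

# Minimal self-check sketch (illustrative, not part of the module): it bypasses
# document handling and feeds synthetic TP/FP scores directly into the internal
# buffer before computing the summary. Enough samples are needed for the 20-bin
# quantile binning in _get_calibration_data_and_metrics to succeed.
#
#     import numpy as np
#
#     rng = np.random.default_rng(0)
#     metric = ScoreDistribution(layer="dummy")
#     metric.scores["ALL"]["TP"] = list(rng.beta(5, 2, size=200))  # skewed towards 1
#     metric.scores["ALL"]["FP"] = list(rng.beta(2, 5, size=200))  # skewed towards 0
#     print(metric._compute())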
|