# keyword-embeddings-space / plot_utils.py
# Author: latticetower
# Commit b40aac1: fix avxline in plots, use common legend in gradio, add reaction and loading on launch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from constants import *
def groupby(array_like, hue_order=None):
    """Group the positions of ``array_like`` by their distinct values.

    Yields ``(value, positions)`` pairs, where ``positions`` is an integer
    ndarray of the indices at which ``value`` occurs in ``array_like``.
    Because a stable sort is used, each ``positions`` array is ascending.

    Args:
        array_like: 1-D sequence or ndarray of sortable, hashable values.
            Lists/tuples are accepted and converted with ``np.asarray``.
        hue_order: Optional ``list`` of values restricting the output. When it
            is a list, only values present both in ``hue_order`` and in
            ``array_like`` are yielded, in ``sorted(hue_order)`` order. Any
            other type (including ``None``) is ignored and every distinct value
            of ``array_like`` is yielded in sorted order.

    Yields:
        tuple: ``(value, positions)`` for each selected group.
    """
    values_array = np.asarray(array_like)
    # Stable sort keeps equal values in their original relative order,
    # so each group's index array comes out in ascending order.
    order = np.argsort(values_array, kind='stable')
    group_values, group_starts = np.unique(values_array[order], return_index=True)
    # group_starts are the offsets (within the sorted order) where each
    # distinct value begins; splitting there yields one index array per group.
    group_positions = np.split(order, group_starts[1:])
    value2positions = dict(zip(group_values, group_positions))

    if isinstance(hue_order, list):
        # NOTE(review): hue_order is sorted here instead of being honoured as
        # given; kept for backward compatibility - confirm this is intended.
        for key in sorted(hue_order):
            if key in value2positions:
                yield key, value2positions[key]
        return
    for key in sorted(value2positions):
        yield key, value2positions[key]
def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
                  hue_group_offset=0.5, hue_order=None,
                  hue2count=None, width=0.9, ax=None, show_legend=True,
                  palette='tab10'):
    """Draw grouped horizontal bars of the most frequent labels per hue.

    Samples are grouped by ``targets_list`` (integer indices into
    ``hue_order``). For each hue, the ``top_n`` most frequent labels from
    ``label_list`` are drawn as horizontal bars in that hue's own vertical
    band, followed by a dashed separator line. The y-axis is inverted at the
    end so the first hue group appears at the top.

    Args:
        targets_list: 1-D sequence of integer indices into ``hue_order``,
            one per sample.
        label_list: 1-D sequence of per-sample labels aligned with
            ``targets_list``. When ``None``, each sample's label is
            ``hue_order[target]``.
        top_n: Maximum number of bars drawn per hue group.
        bin_width: Multiplier on ``top_n`` used when spacing hue groups.
        hue_group_offset: Extra gap inserted between consecutive hue groups.
        hue_order: Sequence mapping target indices to hue labels; also selects
            each hue's palette colour. Effectively required (``None`` is
            treated as an empty list).
        hue2count: Optional mapping ``hue_label -> denominator`` used to
            normalise that hue's counts; missing hues are divided by 1.
        width: Bar thickness passed to ``ax.barh``.
        ax: Matplotlib axes to draw on; defaults to ``plt.gca()``.
        show_legend: Whether to draw a legend below the axes.
        palette: Seaborn palette name or an indexable sequence of colours.
            Colours are reused cyclically when there are more hues than colours.

    Returns:
        None. Everything is drawn onto ``ax``.
    """
    hue_order = [] if hue_order is None else hue_order
    hue2count = {} if hue2count is None else hue2count
    if ax is None:
        ax = plt.gca()
    if isinstance(palette, str):
        palette = sns.color_palette(palette)

    # Fancy indexing below (label_list[positions]) requires ndarrays.
    targets_list = np.asarray(targets_list)
    if label_list is None:
        label_list = np.asarray([hue_order[x] for x in targets_list])
    else:
        label_list = np.asarray(label_list)

    hue_values = np.unique(targets_list)
    # Each hue group reserves a band for top_n bars plus a separating gap.
    group_stride = top_n * bin_width + hue_group_offset
    hue_offsets = np.arange(len(hue_values)) * group_stride
    hue_label2offset = {
        hue_order[k]: offset for k, offset in zip(hue_values, hue_offsets)
    }

    tick_positions = []
    tick_labels = []
    max_x_value = 0
    for hue_index, hue_indices in groupby(targets_list):
        hue_label = hue_order[hue_index]
        bin_labels, bin_counts = np.unique(label_list[hue_indices], return_counts=True)
        bin_counts = bin_counts / hue2count.get(hue_label, 1)
        max_x_value = max(max_x_value, bin_counts.max())
        # hue_label was taken from hue_order, so it is always present there.
        color_index = hue_order.index(hue_label)

        # Keep only the top_n most frequent labels, most frequent first.
        top_indices = np.argsort(bin_counts)[::-1][:top_n]
        bin_labels = bin_labels[top_indices]
        bin_counts = bin_counts[top_indices]
        bin_indices = hue_label2offset[hue_label] + np.arange(len(bin_labels))
        tick_positions.extend(bin_indices)
        tick_labels.extend(bin_labels)

        ax.barh(bin_indices, bin_counts, width, label=hue_label,
                color=palette[color_index % len(palette)])
        # Dashed separator between this hue group and the next one.
        line_pos = bin_indices.max() + width / 2 + hue_group_offset / 2
        ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)

    if show_legend:
        ax.legend(
            loc='upper center', bbox_to_anchor=(0.5, -0.05),
            fancybox=True, shadow=True,
            ncol=4,
        )
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    if max_x_value <= 1:
        # All bars fit within 1 (e.g. normalised fractions): pin x to [0, 1].
        ax.set_xlim(0, 1.)
    if tick_positions:
        # np.max would raise on an empty list (no samples at all).
        ax.set_ylim(-0.5, np.max(tick_positions) + width / 2 + hue_group_offset / 2)
    ax.invert_yaxis()