import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# NOTE(review): POSTER_BLUE (used by draw_barplots) is expected to come from
# this star import; kept as-is because other names may be relied on elsewhere.
from constants import *


def groupby(array_like, hue_order=None):
    """Yield ``(group_name, indices)`` pairs for each distinct value.

    ``indices`` is an integer array of positions in ``array_like`` that hold
    ``group_name``, in their original relative order (a stable sort is used).

    Args:
        array_like: 1-D sequence of sortable, hashable values.
        hue_order: Optional list of group names. When given as a list, only
            the names present in ``array_like`` are yielded, in the order
            they appear in ``hue_order``. Otherwise all groups are yielded
            in sorted order.

    Yields:
        Tuples ``(group_name, indices)``.
    """
    values = np.asarray(array_like)
    order = np.argsort(values, kind='stable')
    group_names, first_positions = np.unique(values[order], return_index=True)
    # Splitting the sorted positions at each group's first occurrence gives
    # one index array per distinct value.
    name2indices = dict(zip(group_names, np.split(order, first_positions[1:])))

    if hue_order is not None and isinstance(hue_order, list):
        # Respect the caller's ordering instead of re-sorting it.
        keys = [name for name in hue_order if name in name2indices]
    else:
        keys = sorted(name2indices)

    for name in keys:
        yield name, name2indices[name]


def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
                  hue_group_offset=0.5, hue_order=None, hue2count=None,
                  width=0.9, ax=None, show_legend=True, palette='tab10'):
    """Draw one group of horizontal bars per hue, showing its top labels.

    Samples are grouped by their value in ``targets_list``. For every group
    the labels in ``label_list`` are counted, the counts are divided by
    ``hue2count[hue_label]`` (1 when missing), and the ``top_n`` most
    frequent labels are drawn as horizontal bars. Groups are stacked
    vertically and separated by a dashed line.

    Args:
        targets_list: 1-D sequence of hue identifiers, one per sample. When
            ``hue_order`` is non-empty these are used as integer indices
            into ``hue_order``.
        label_list: 1-D sequence of labels, one per sample. Defaults to the
            hue label of each sample.
        top_n: Maximum number of bars per hue group.
        bin_width: Vertical space reserved per bar slot when computing the
            offset of each group (bars inside a group are one unit apart).
        hue_group_offset: Extra vertical gap between consecutive groups.
        hue_order: Sequence mapping hue index -> hue label; the position of
            a label also selects its palette colour. When empty or ``None``
            the raw target value is used as the label and colours follow
            the group's enumeration order.
        hue2count: Mapping hue label -> denominator used to normalise counts.
        width: Thickness of each bar along the y axis.
        ax: Matplotlib axes to draw on; defaults to the current axes.
        show_legend: Whether to add a legend below the plot.
        palette: Seaborn palette name or an indexable sequence of colours.

    Returns:
        The axes that were drawn on.
    """
    if isinstance(palette, str):
        palette = sns.color_palette(palette)
    hue_order = [] if hue_order is None else list(hue_order)
    hue2count = {} if hue2count is None else hue2count
    if ax is None:
        ax = plt.gca()

    targets = np.asarray(targets_list)

    def hue_label_of(hue_index):
        # Without an explicit hue_order the raw target value is the label.
        return hue_order[hue_index] if hue_order else hue_index

    if label_list is None:
        labels = np.asarray([hue_label_of(t) for t in targets])
    else:
        labels = np.asarray(label_list)

    # Each hue group gets a fixed-size vertical slot followed by a gap.
    hue_values = np.unique(targets)
    group_stride = top_n * bin_width + hue_group_offset
    hue_label2offset = {
        hue_label_of(hue_index): offset
        for hue_index, offset in zip(hue_values,
                                     np.arange(len(hue_values)) * group_stride)
    }

    tick_positions = []
    tick_labels = []
    max_x_value = 0
    for idx, (hue_index, member_indices) in enumerate(groupby(targets)):
        hue_label = hue_label_of(hue_index)

        bin_labels, bin_counts = np.unique(labels[member_indices],
                                           return_counts=True)
        bin_counts = bin_counts / hue2count.get(hue_label, 1)
        max_x_value = max(max_x_value, bin_counts.max())

        if hue_label in hue_order:
            color_index = hue_order.index(hue_label)
        else:
            color_index = idx

        # Keep only the top_n most frequent labels, largest first.
        top_indices = np.argsort(bin_counts)[::-1][:top_n]
        bin_labels = bin_labels[top_indices]
        bin_counts = bin_counts[top_indices]
        bin_positions = hue_label2offset[hue_label] + np.arange(len(bin_labels))

        tick_positions.extend(bin_positions)
        tick_labels.extend(bin_labels)

        ax.barh(
            bin_positions,
            bin_counts,
            height=width,
            label=hue_label,
            # Wrap around so more hues than palette colours cannot overflow.
            color=palette[color_index % len(palette)],
        )

        # Dashed separator just below the last bar of this group.
        line_pos = bin_positions.max() + width / 2 + hue_group_offset / 2
        ax.axhline(line_pos, linewidth=1, linestyle='dashed',
                   color=POSTER_BLUE)

    if not tick_positions:
        # Nothing was drawn; there are no ticks or limits to configure.
        return ax

    if show_legend:
        ax.legend(
            loc='upper center',
            bbox_to_anchor=(0.5, -0.05),
            fancybox=True,
            shadow=True,
            ncol=4,
        )

    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    if max_x_value <= 1:
        ax.set_xlim(0, 1.)
    ax.set_ylim(-0.5, np.max(tick_positions) + width / 2 + hue_group_offset / 2)
    # First group at the top of the figure.
    ax.invert_yaxis()
    return ax