keyword-embeddings-space / mpl_data_plotter.py
latticetower's picture
fix axvline in plots, use common legend in gradio, add reaction and loading on launch
b40aac1
raw
history blame
3.98 kB
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')
import plot_utils
from constants import *
class MatplotlibDataPlotter:
    """Draw bar plots of best-matching keywords for single domains and domain pairs.

    The plotter keeps one matplotlib figure per plot kind and redraws into it
    on every call, so repeated calls return the same (cleared and refilled)
    figure object rather than allocating a new one.
    """

    def __init__(self, single_df, pair_df, num_domains_in_region_df):
        """Store the source tables and allocate the two reusable figures.

        Args:
            single_df: table of single-domain rows. The plotting methods read
                its ``cds_region_id``, ``biosyn_class``, ``biosyn_class_index``,
                ``profile_name`` and ``cosine_similarity_<split_name>`` columns.
            pair_df: table of domain-pair rows; the same columns are read.
            num_domains_in_region_df: table with ``cds_region_id`` and
                ``num_domains`` columns, used to filter regions by size.
        """
        self.single_df = single_df
        self.pair_df = pair_df
        self.num_domains_in_region_df = num_domains_in_region_df
        # Figures are created once through pyplot and reused by the plot methods.
        self.single_domains_fig = plt.figure(figsize=(5, 10))
        self.pair_domains_fig = plt.figure(figsize=(5, 10))

    def _draw_top_keywords(self, df, fig, num_domains, split_name):
        """Filter ``df`` by region size and draw its best keywords into ``fig``.

        Shared implementation behind ``plot_single_domains`` and
        ``plot_pair_domains``, which differ only in the table and figure used.

        Args:
            df: table to plot (``self.single_df`` or ``self.pair_df``).
            fig: matplotlib figure to clear and draw into.
            num_domains: minimum ``num_domains`` a region must have to be kept.
            split_name: suffix selecting the ``cosine_similarity_<split_name>``
                column used to pick the best row of each region.

        Returns:
            The same ``fig`` object, after redrawing.
        """
        # Keep only regions that contain at least `num_domains` domains.
        region_counts = self.num_domains_in_region_df
        selected_region_ids = region_counts.loc[
            region_counts.num_domains >= num_domains, 'cds_region_id'
        ].values
        subset = df.loc[df.cds_region_id.isin(selected_region_ids)]

        # Map each biosyn_class to the number of distinct
        # (cds_region_id, biosyn_class) pairs left in the subset.
        class_counts = (
            subset[['cds_region_id', 'biosyn_class']]
            .drop_duplicates()
            .groupby('biosyn_class', as_index=False)
            .count()
        )
        hue2count = dict(class_counts.values)

        # For every region take the index label of the row with the highest
        # cosine similarity for the requested split.
        # NOTE(review): the .loc lookups below return one row per region only
        # if the index labels of `df` are unique -- confirm against the loader.
        column_name = f'cosine_similarity_{split_name}'
        best_row_index = (
            subset.groupby('cds_region_id')
            .agg({column_name: 'idxmax'})
            .values.flatten()
        )
        targets_list = subset.loc[best_row_index, 'biosyn_class_index'].values
        label_list = subset.loc[best_row_index, 'profile_name'].values

        # Redraw from scratch into the reused figure.
        fig.clf()
        ax = fig.gca()
        plot_utils.draw_barplots(
            targets_list,
            label_list=label_list,
            top_n=5,
            bin_width=1,
            hue_group_offset=0.5,
            hue_order=BIOSYN_CLASS_NAMES,
            hue2count=hue2count,
            width=0.9,
            ax=ax,
            show_legend=False,
            palette=COLOR_PALETTE,
        )
        fig.tight_layout()
        return fig

    def plot_single_domains(self, num_domains, split_name):
        """Redraw and return the single-domain figure.

        Args:
            num_domains: minimum ``num_domains`` a region must have to be kept.
            split_name: suffix of the ``cosine_similarity_<split_name>`` column.

        Returns:
            ``self.single_domains_fig`` after redrawing.
        """
        return self._draw_top_keywords(
            self.single_df, self.single_domains_fig, num_domains, split_name
        )

    def plot_pair_domains(self, num_domains, split_name):
        """Redraw and return the domain-pair figure.

        Args:
            num_domains: minimum ``num_domains`` a region must have to be kept.
            split_name: suffix of the ``cosine_similarity_<split_name>`` column.

        Returns:
            ``self.pair_domains_fig`` after redrawing.
        """
        return self._draw_top_keywords(
            self.pair_df, self.pair_domains_fig, num_domains, split_name
        )