Spaces:
Sleeping
Sleeping
Commit
·
694c1c6
1
Parent(s):
1d4cb28
make draft app
Browse files
- constants.py +6 -0
- mpl_data_plotter.py +109 -0
- plot_utils.py +93 -0
constants.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Biosynthetic class labels. The plotting code passes this list as `hue_order`
# and looks labels up by integer class index, so the position of each entry
# matters.
BIOSYN_CLASS_NAMES = [
    'Alkaloid',
    'NRP',
    'Polyketide',
    'RiPP',
    'Saccharide',
    'Terpene',
    'Other',
]

# Relative paths to the single-domain and domain-pair tables (.csv.gz).
SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
|
6 |
+
|
mpl_data_plotter.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import plot_utils
|
4 |
+
from constants import *
|
5 |
+
|
6 |
+
|
7 |
+
class MatplotlibDataPlotter:
    """Draw bar plots of the best-scoring domain profile of each region.

    Two figures are allocated once in ``__init__`` (one for single domains,
    one for domain pairs). Every plot call clears and redraws its own figure,
    so repeated calls do not create additional figures.
    """

    # Split whose similarity column is used when no split (or an unknown
    # split) is requested.
    _DEFAULT_SPLIT_NAME = 'stratified'

    def __init__(self, single_df, pair_df, num_domains_in_region_df):
        """Store the input tables and create the two reusable figures.

        Args:
            single_df: pandas-style table of single-domain rows. The plot
                methods read its ``cds_region_id``, ``biosyn_class_index``,
                ``profile_name`` and ``cosine_similarity_<split>`` columns.
            pair_df: table of domain-pair rows with the same columns.
            num_domains_in_region_df: table with ``cds_region_id`` and
                ``num_domains`` columns, used to filter regions.
        """
        self.single_df = single_df
        self.pair_df = pair_df

        self.num_domains_in_region_df = num_domains_in_region_df

        self.single_domains_fig = plt.figure(figsize=(5, 10))
        self.pair_domains_fig = plt.figure(figsize=(5, 10))

    def plot_single_domains(self, num_domains, split_name):
        """Redraw and return the single-domain figure.

        Args:
            num_domains: keep only regions whose ``num_domains`` is at least
                this value.
            split_name: suffix selecting the ``cosine_similarity_<split_name>``
                column. Falls back to ``'stratified'`` when empty or when the
                column does not exist.

        Returns:
            ``self.single_domains_fig`` (left blank when no region qualifies).
        """
        return self._plot_best_profiles(
            self.single_df, self.single_domains_fig, num_domains, split_name)

    def plot_pair_domains(self, num_domains, split_name):
        """Redraw and return the domain-pair figure.

        Args:
            num_domains: keep only regions whose ``num_domains`` is at least
                this value.
            split_name: suffix selecting the ``cosine_similarity_<split_name>``
                column. Falls back to ``'stratified'`` when empty or when the
                column does not exist.

        Returns:
            ``self.pair_domains_fig`` (left blank when no region qualifies).
        """
        return self._plot_best_profiles(
            self.pair_df, self.pair_domains_fig, num_domains, split_name)

    def _similarity_column(self, domains_df, split_name):
        """Return the similarity column for ``split_name``, or the default one."""
        import warnings

        default_column = f'cosine_similarity_{self._DEFAULT_SPLIT_NAME}'
        if not split_name:
            return default_column

        column_name = f'cosine_similarity_{split_name}'
        if column_name in domains_df.columns:
            return column_name

        # Earlier versions always used the default split; keep that behavior
        # for unknown names instead of failing, but make the fallback visible.
        warnings.warn(
            f'column {column_name!r} not found; using {default_column!r} instead',
            stacklevel=3,
        )
        return default_column

    def _plot_best_profiles(self, domains_df, fig, num_domains, split_name):
        """Clear ``fig`` and draw the per-class bar plot for ``domains_df``."""
        column_name = self._similarity_column(domains_df, split_name)

        # Regions that contain at least `num_domains` domains.
        region_counts = self.num_domains_in_region_df
        selected_region_ids = region_counts.loc[
            region_counts.num_domains >= num_domains, 'cds_region_id'].values
        subset = domains_df.loc[domains_df.cds_region_id.isin(selected_region_ids)]

        fig.clf()
        if subset.empty:
            # Nothing to draw; draw_barplots cannot handle an empty selection.
            return fig

        # Row label of the highest-similarity row within each region.
        best_rows = subset.groupby('cds_region_id').agg(
            {column_name: 'idxmax'}
        ).values.flatten()
        targets_list = subset.loc[best_rows, 'biosyn_class_index'].values
        label_list = subset.loc[best_rows, 'profile_name'].values

        plot_utils.draw_barplots(
            targets_list,
            label_list=label_list,
            top_n=5,
            bin_width=1,
            hue_group_offset=0.5,
            hue_order=BIOSYN_CLASS_NAMES,
            hue2count={},
            width=0.9,
            ax=fig.gca(),
            show_legend=True
        )
        # Lay out this figure explicitly: plt.tight_layout() would act on
        # pyplot's *current* figure, which is not necessarily `fig`.
        fig.tight_layout()
        return fig
|
107 |
+
|
108 |
+
|
109 |
+
|
plot_utils.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import seaborn as sns
|
4 |
+
|
5 |
+
|
6 |
+
# Hex color of the dashed separator lines drawn between hue groups.
POSTER_BLUE = '#01589C'
|
7 |
+
|
8 |
+
|
9 |
+
def groupby(array_like, hue_order=None):
    """Yield ``(group_value, member_indices)`` for each distinct value.

    Args:
        array_like: 1-D sequence or array of group keys (converted with
            ``np.asarray``, so plain lists are accepted).
        hue_order: optional list of group values. When given, only the listed
            groups that actually occur are yielded, in the order listed.
            Anything that is not a list is ignored.

    Yields:
        Tuples ``(value, indices)`` where ``indices`` is an integer array of
        the positions in ``array_like`` equal to ``value``, in their original
        relative order. Without ``hue_order`` the groups come in sorted order.
    """
    keys = np.asarray(array_like)
    # A stable sort keeps members of a group in their original relative order.
    order = np.argsort(keys, kind='stable')
    values, first_positions = np.unique(keys[order], return_index=True)
    # Cut the sorted index array where each new group starts.
    members = np.split(order, first_positions[1:])
    name2indices = dict(zip(values, members))

    if hue_order is not None and isinstance(hue_order, list):
        for k in hue_order:
            if k in name2indices:
                yield k, name2indices[k]
        return
    for k in sorted(name2indices):
        yield k, name2indices[k]
|
21 |
+
|
22 |
+
|
23 |
+
def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
                  hue_group_offset=0.5, hue_order=None,
                  hue2count=None, width=0.9, ax=None, show_legend=True,
                  palette='tab10'):
    """Draw grouped horizontal bars of the most frequent labels per hue.

    Samples are grouped by their value in ``targets_list`` (the hue). For each
    hue the ``top_n`` most frequent labels are drawn as horizontal bars, and a
    dashed line is drawn after each hue group.

    Args:
        targets_list: integer hue index of every sample; each value is used
            as an index into ``hue_order``.
        label_list: label of every sample, same length as ``targets_list``.
            Defaults to the hue name of each sample.
        top_n: maximum number of bars per hue group.
        bin_width: factor in the vertical stride between hue groups
            (``top_n * bin_width + hue_group_offset``). Bars inside a group
            are always one unit apart.
        hue_group_offset: extra vertical gap between hue groups.
        hue_order: sequence of hue names indexed by hue index. Its ``index``
            method also selects the palette color.
        hue2count: optional mapping from hue name to a denominator for
            normalizing counts. Hues not in the mapping are divided by 1.
        width: thickness of each bar along the y axis.
        ax: matplotlib axes to draw on; defaults to the current axes.
        show_legend: whether to add a legend below the axes.
        palette: seaborn palette name, or an indexable sequence of colors.

    Returns:
        None. All output is drawn on ``ax``.
    """
    hue_order = [] if hue_order is None else hue_order
    hue2count = {} if hue2count is None else hue2count
    if ax is None:
        ax = plt.gca()
    if isinstance(palette, str):
        palette = sns.color_palette(palette)

    # Fancy indexing below needs arrays, so accept plain sequences too.
    targets_list = np.asarray(targets_list)
    if label_list is None:
        label_list = np.asarray([hue_order[x] for x in targets_list])
    else:
        label_list = np.asarray(label_list)

    hue_values = np.unique(targets_list)
    n_bins = len(hue_values)
    bin_size = top_n

    # y position of the first bar of every hue group.
    hue_offset = np.arange(n_bins) * (bin_size * bin_width + hue_group_offset)
    hue_label2offset = {hue_order[k]: v for k, v in zip(hue_values, hue_offset)}

    tick_positions = []
    tick_labels = []
    max_x_value = 0

    for hue_index, hue_indices in groupby(targets_list):
        hue_label = hue_order[hue_index]
        bin_labels, bin_counts = np.unique(label_list[hue_indices], return_counts=True)

        # Normalize only when a denominator is supplied for this hue.
        bin_counts = bin_counts / hue2count.get(hue_label, 1)
        max_x_value = max(max_x_value, bin_counts.max())

        # Keep the `bin_size` most frequent labels, most frequent first.
        top_indices = np.argsort(bin_counts)[::-1][:bin_size]
        bin_labels = bin_labels[top_indices]
        bin_counts = bin_counts[top_indices]

        bin_indices = hue_label2offset[hue_label] + np.arange(len(bin_labels))
        tick_positions.extend(bin_indices)
        tick_labels.extend(bin_labels)

        # Wrap around instead of failing when there are more hues than colors.
        color_index = hue_order.index(hue_label) % len(palette)
        ax.barh(
            bin_indices, bin_counts, height=width, label=hue_label,
            color=palette[color_index])

        # Draw the separator on `ax` itself; plt.axhline would target
        # pyplot's current axes, which may belong to a different figure.
        line_pos = bin_indices.max() + width / 2 + hue_group_offset / 2
        ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)

    if not tick_positions:
        # No samples: leave the axes empty instead of failing on np.max([]).
        return

    if show_legend:
        ax.legend(
            loc='upper center', bbox_to_anchor=(0.5, -0.05),
            fancybox=True, shadow=True,
            ncol=4
        )

    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    if max_x_value <= 1:
        ax.set_xlim(0, 1.)
    ax.set_ylim(-0.5, np.max(tick_positions) + width / 2 + hue_group_offset / 2)
    ax.invert_yaxis()
|