latticetower commited on
Commit
694c1c6
·
1 Parent(s): 1d4cb28

make draft app

Browse files
Files changed (3) hide show
  1. constants.py +6 -0
  2. mpl_data_plotter.py +109 -0
  3. plot_utils.py +93 -0
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
3
+
4
+ SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
5
+ PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
6
+
mpl_data_plotter.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import matplotlib.pyplot as plt
3
+ import plot_utils
4
+ from constants import *
5
+
6
+
7
+ class MatplotlibDataPlotter:
8
+ def __init__(self, single_df, pair_df, num_domains_in_region_df):
9
+ self.single_df = single_df
10
+ self.pair_df = pair_df
11
+
12
+ self.num_domains_in_region_df = num_domains_in_region_df
13
+
14
+ self.single_domains_fig = plt.figure(figsize=(5, 10))
15
+ self.pair_domains_fig = plt.figure(figsize=(5, 10))
16
+
17
+ def plot_single_domains(self, num_domains, split_name):
18
+
19
+ #fig = plt.gcf()
20
+ #fig.set_size_inches(10, 5)
21
+ selected_region_ids = self.num_domains_in_region_df.loc[
22
+ self.num_domains_in_region_df.num_domains >= num_domains,
23
+ 'cds_region_id'].values
24
+ single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
25
+
26
+ split_name = 'stratified'
27
+ column_name = f'cosine_similarity_{split_name}'
28
+ # single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
29
+ selected_keyword_index = single_df_subset.groupby('cds_region_id').agg(
30
+ {column_name: 'idxmax'}
31
+ ).values.flatten()
32
+ targets_list = single_df_subset.loc[selected_keyword_index, 'biosyn_class_index'].values
33
+ label_list = single_df_subset.loc[selected_keyword_index, 'profile_name'].values
34
+
35
+ top_n=5
36
+ bin_width=1
37
+ hue_group_offset=0.5
38
+ # hue_order=BIOSYN_CLASS_NAMES
39
+ hue2count={}
40
+ width=0.9
41
+
42
+ show_legend=True
43
+
44
+ fig = self.single_domains_fig
45
+ fig.clf()
46
+
47
+ ax = fig.gca()
48
+ plot_utils.draw_barplots(
49
+ targets_list,
50
+ label_list=label_list,
51
+ top_n=5,
52
+ bin_width=1,
53
+ hue_group_offset=0.5,
54
+ hue_order=BIOSYN_CLASS_NAMES,
55
+ hue2count={},
56
+ width=0.9,
57
+ ax=ax,
58
+ show_legend=True
59
+ )
60
+ plt.tight_layout()
61
+ return fig # plt.gcf()
62
+
63
+ def plot_pair_domains(self, num_domains, split_name):
64
+ selected_region_ids = self.num_domains_in_region_df.loc[
65
+ self.num_domains_in_region_df.num_domains >= num_domains,
66
+ 'cds_region_id'].values
67
+ pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
68
+
69
+ split_name = 'stratified'
70
+ column_name = f'cosine_similarity_{split_name}'
71
+ # pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
72
+ selected_keyword_index = pair_df_subset.groupby('cds_region_id').agg(
73
+ {column_name: 'idxmax'}
74
+ ).values.flatten()
75
+ targets_list = pair_df_subset.loc[
76
+ selected_keyword_index, 'biosyn_class_index'].values
77
+ label_list=pair_df_subset.loc[
78
+ selected_keyword_index, 'profile_name'].values
79
+
80
+ top_n=5
81
+ bin_width=1
82
+ hue_group_offset=0.5
83
+ # hue_order=BIOSYN_CLASS_NAMES
84
+ hue2count={}
85
+ width=0.9
86
+
87
+ show_legend=True
88
+ # fig = plt.figure(figsize=(5, 10))
89
+ fig = self.pair_domains_fig
90
+ fig.clf()
91
+
92
+ ax = fig.gca()
93
+ plot_utils.draw_barplots(
94
+ targets_list,
95
+ label_list=label_list,
96
+ top_n=5,
97
+ bin_width=1,
98
+ hue_group_offset=0.5,
99
+ hue_order=BIOSYN_CLASS_NAMES,
100
+ hue2count={},
101
+ width=0.9,
102
+ ax=ax,
103
+ show_legend=True
104
+ )
105
+ plt.tight_layout()
106
+ return fig #plt.gcf()
107
+
108
+
109
+
plot_utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+
5
+
6
+ POSTER_BLUE = '#01589C'
7
+
8
+
9
+ def groupby(array_like, hue_order=None):
10
+ idx = np.argsort(array_like, kind='stable')
11
+ values, indices, counts = np.unique(array_like[idx], return_counts=True, return_index=True)
12
+ split_idx = np.split(idx, indices[1:])
13
+ name2indices = {group_name: indices for group_name, indices in zip(values, split_idx)}
14
+ if hue_order is not None and isinstance(hue_order, list):
15
+ for k in sorted(hue_order):
16
+ if k in name2indices:
17
+ yield k, name2indices[k]
18
+ return
19
+ for k in sorted(name2indices):
20
+ yield k, name2indices[k]
21
+
22
+
23
+ def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
24
+ hue_group_offset=0.5, hue_order=[],
25
+ hue2count={}, width=0.9, ax=None, show_legend=True,
26
+ palette='tab10'):
27
+ if isinstance(palette, str):
28
+ palette = sns.color_palette(palette)
29
+ if label_list is None:
30
+ label_list = np.asarray([hue_order[x] for x in targets_list])
31
+ hue_values, ucount = np.unique(targets_list, return_counts=True)
32
+ n_bins = max(len(hue_values), len(hue_values))
33
+ bin_size = top_n
34
+
35
+ hue_offset = np.arange(n_bins)*(bin_size*bin_width + hue_group_offset) #
36
+ hue_label2offset = {hue_order[k]: v for k, v in zip(hue_values, hue_offset)}
37
+ # print(hue_label2offset)
38
+ tick_positions = []
39
+ tick_labels = []
40
+ max_x_value = 0
41
+
42
+ for idx, (hue_index, hue_indices) in enumerate(groupby(targets_list)):
43
+ hue_label = hue_order[hue_index]
44
+ #print(idx, hue_label, hue_indices)
45
+ bottom = np.zeros(n_bins*bin_size)
46
+ subset_y = label_list[hue_indices]
47
+ #print(subset_y)
48
+ bin_labels, bin_counts = np.unique(subset_y, return_counts=True)
49
+
50
+ # if normalize:
51
+ denominator = hue2count.get(hue_label, 1)
52
+ bin_counts = bin_counts / denominator
53
+ max_x_value = max(max_x_value, bin_counts.max())
54
+
55
+ if hue_label in hue_order:
56
+ color_index = hue_order.index(hue_label)
57
+ else:
58
+ color_index = idx
59
+ # new
60
+ top_indices = np.argsort(bin_counts)[::-1][:bin_size]
61
+ bin_labels = bin_labels[top_indices]
62
+ bin_counts = bin_counts[top_indices]
63
+
64
+ bin_indices = np.asarray([hue_label2offset[hue_label] + i for i, label in enumerate(bin_labels)])
65
+ tick_positions.extend(bin_indices)
66
+ tick_labels.extend(bin_labels)
67
+ # old
68
+ #offset = hue_offsets.get(hue_label, 0)
69
+
70
+ #bin_indices = np.asarray([label2tick[t]+offset for t in bin_labels])
71
+
72
+ p = ax.barh(
73
+ bin_indices, bin_counts, width, label=hue_label, # left=bottom[bin_indices],
74
+ color=palette[color_index])
75
+ # if do_stack:
76
+ # bottom[bin_indices] += bin_counts
77
+ # if not normalize:
78
+ # bottom[bin_indices] += bar_offset
79
+ line_pos = bin_indices.max() + width/2 + hue_group_offset/2
80
+ plt.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)
81
+ if show_legend:
82
+ ax.legend(
83
+ loc='upper center', bbox_to_anchor=(0.5, -0.05),
84
+ fancybox=True, shadow=True,
85
+ ncol=4
86
+ )
87
+
88
+ ax.set_yticks(tick_positions)
89
+ ax.set_yticklabels(tick_labels)
90
+ if max_x_value <= 1:
91
+ ax.set_xlim(0, 1.)
92
+ ax.set_ylim(-0.5, np.max(tick_positions)+width/2+hue_group_offset/2)
93
+ ax.invert_yaxis()