Spaces:
Sleeping
Sleeping
Commit
·
b40aac1
1
Parent(s):
2d1d8cb
fix avxline in plots, use common legend in gradio, add reaction and loading on launch
Browse files- app.py +70 -9
- constants.py +17 -0
- mpl_data_plotter.py +28 -23
- plot_utils.py +1 -1
app.py
CHANGED
@@ -17,11 +17,13 @@ def convert_int64_to_int32(df):
|
|
17 |
|
18 |
print(f"Loading domains data...")
|
19 |
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
20 |
-
single_df
|
|
|
21 |
single_df = convert_int64_to_int32(single_df)
|
22 |
|
23 |
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
24 |
-
pair_df
|
|
|
25 |
pair_df = convert_int64_to_int32(pair_df)
|
26 |
|
27 |
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
@@ -33,6 +35,53 @@ print(f"Initializing data plotter...")
|
|
33 |
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
34 |
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def update_all_plots(frequency, split_name):
|
37 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
38 |
|
@@ -43,6 +92,8 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
43 |
gr.Markdown("## BGC Keyword Plotter")
|
44 |
gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
|
45 |
|
|
|
|
|
46 |
with gr.Row():
|
47 |
frequency_slider = gr.Slider(
|
48 |
minimum=int(unique_domain_lengths.min()),
|
@@ -51,14 +102,13 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
51 |
value=int(unique_domain_lengths.min()),
|
52 |
label="Min number of domains"
|
53 |
)
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
with gr.Row():
|
56 |
-
with gr.Column():
|
57 |
-
split_selector = gr.Dropdown(
|
58 |
-
choices=["stratified"] + BIOSYN_CLASS_NAMES,
|
59 |
-
value="stratified",
|
60 |
-
label="Split name"
|
61 |
-
)
|
62 |
with gr.Column():
|
63 |
single_domains_plot = gr.Plot(
|
64 |
label="Single domains",
|
@@ -80,11 +130,22 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
|
|
80 |
|
81 |
frequency_slider.release(
|
82 |
fn=update_all_plots,
|
83 |
-
inputs=[frequency_slider,
|
84 |
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
85 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
|
88 |
print(f"Launching!...")
|
89 |
demo.launch()
|
|
|
90 |
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
|
|
17 |
|
18 |
print(f"Loading domains data...")
|
19 |
single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
|
20 |
+
single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
21 |
+
single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
22 |
single_df = convert_int64_to_int32(single_df)
|
23 |
|
24 |
pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
|
25 |
+
pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
|
26 |
+
pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
|
27 |
pair_df = convert_int64_to_int32(pair_df)
|
28 |
|
29 |
num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
|
|
|
35 |
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
|
36 |
|
37 |
|
38 |
+
def create_color_legend(class_to_color):
|
39 |
+
# Create HTML for the color legend
|
40 |
+
legend_html = """
|
41 |
+
<div style="
|
42 |
+
margin: 10px 0;
|
43 |
+
padding: 10px;
|
44 |
+
border: 1px solid #ddd;
|
45 |
+
border-radius: 4px;
|
46 |
+
background: white;
|
47 |
+
">
|
48 |
+
<div style="
|
49 |
+
font-weight: bold;
|
50 |
+
margin-bottom: 8px;
|
51 |
+
">Color Legend:</div>
|
52 |
+
<div style="
|
53 |
+
display: flex;
|
54 |
+
flex-wrap: wrap;
|
55 |
+
gap: 15px;
|
56 |
+
align-items: center;
|
57 |
+
">
|
58 |
+
"""
|
59 |
+
# Add each class and its color
|
60 |
+
for class_name, color in class_to_color.items():
|
61 |
+
legend_html += f"""
|
62 |
+
<div style="
|
63 |
+
display: flex;
|
64 |
+
align-items: center;
|
65 |
+
gap: 5px;
|
66 |
+
">
|
67 |
+
<div style="
|
68 |
+
width: 20px;
|
69 |
+
height: 20px;
|
70 |
+
background-color: {color};
|
71 |
+
border-radius: 3px;
|
72 |
+
"></div>
|
73 |
+
<span>{class_name}</span>
|
74 |
+
</div>
|
75 |
+
"""
|
76 |
+
|
77 |
+
legend_html += """
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
"""
|
81 |
+
|
82 |
+
return gr.HTML(legend_html)
|
83 |
+
|
84 |
+
|
85 |
def update_all_plots(frequency, split_name):
|
86 |
return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
|
87 |
|
|
|
92 |
gr.Markdown("## BGC Keyword Plotter")
|
93 |
gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
|
94 |
|
95 |
+
color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
|
96 |
+
|
97 |
with gr.Row():
|
98 |
frequency_slider = gr.Slider(
|
99 |
minimum=int(unique_domain_lengths.min()),
|
|
|
102 |
value=int(unique_domain_lengths.min()),
|
103 |
label="Min number of domains"
|
104 |
)
|
105 |
+
model_selector = gr.Radio(
|
106 |
+
choices=["stratified"] + BIOSYN_CLASS_NAMES,
|
107 |
+
value="stratified",
|
108 |
+
label="Model name"
|
109 |
+
)
|
110 |
|
111 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
with gr.Column():
|
113 |
single_domains_plot = gr.Plot(
|
114 |
label="Single domains",
|
|
|
130 |
|
131 |
frequency_slider.release(
|
132 |
fn=update_all_plots,
|
133 |
+
inputs=[frequency_slider, model_selector],
|
134 |
outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
|
135 |
)
|
136 |
+
demo.load(
|
137 |
+
fn=update_all_plots,
|
138 |
+
inputs=[frequency_slider, model_selector],
|
139 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
140 |
+
)
|
141 |
+
model_selector.input(
|
142 |
+
fn=update_all_plots,
|
143 |
+
inputs=[frequency_slider, model_selector],
|
144 |
+
outputs=[single_domains_plot, pair_domains_plot]
|
145 |
+
)
|
146 |
|
147 |
|
148 |
print(f"Launching!...")
|
149 |
demo.launch()
|
150 |
+
|
151 |
# demo.load(filter_map, [min_price, max_price, boroughs], map)
|
constants.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
|
|
|
|
|
2 |
POSTER_BLUE = '#01589C'
|
3 |
|
4 |
BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
|
@@ -6,3 +8,18 @@ BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Te
|
|
6 |
SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
|
7 |
PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
|
2 |
+
import seaborn as sns
|
3 |
+
|
4 |
POSTER_BLUE = '#01589C'
|
5 |
|
6 |
BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
|
|
|
8 |
SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
|
9 |
PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
|
10 |
|
11 |
+
BIOSYN_CLASS_HEX_COLORS = {
|
12 |
+
'Alkaloid': '#a1c9f4',
|
13 |
+
'NRP': '#ffb482',
|
14 |
+
'Polyketide': '#8de5a1',
|
15 |
+
'RiPP': '#ff9f9b',
|
16 |
+
'Saccharide': '#d0bbff',
|
17 |
+
'Terpene': '#debb9b',
|
18 |
+
'Other': '#cfcfcf',
|
19 |
+
# 'stratified': '#01589C', # just in case
|
20 |
+
}
|
21 |
+
|
22 |
+
COLOR_PALETTE = sns.color_palette([
|
23 |
+
BIOSYN_CLASS_HEX_COLORS[biosyn_class]
|
24 |
+
for biosyn_class in BIOSYN_CLASS_NAMES
|
25 |
+
])
|
mpl_data_plotter.py
CHANGED
@@ -21,7 +21,12 @@ class MatplotlibDataPlotter:
|
|
21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
23 |
'cds_region_id'].values
|
|
|
24 |
single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
|
|
|
|
|
|
|
|
|
25 |
# split_name = 'stratified'
|
26 |
column_name = f'cosine_similarity_{split_name}'
|
27 |
# single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
|
@@ -35,12 +40,8 @@ class MatplotlibDataPlotter:
|
|
35 |
bin_width=1
|
36 |
hue_group_offset=0.5
|
37 |
# hue_order=BIOSYN_CLASS_NAMES
|
38 |
-
hue2count={}
|
39 |
width=0.9
|
40 |
|
41 |
-
show_legend=True
|
42 |
-
print(matplotlib.get_backend())
|
43 |
-
|
44 |
fig = self.single_domains_fig
|
45 |
fig.clf()
|
46 |
|
@@ -48,23 +49,29 @@ class MatplotlibDataPlotter:
|
|
48 |
plot_utils.draw_barplots(
|
49 |
targets_list,
|
50 |
label_list=label_list,
|
51 |
-
top_n=
|
52 |
-
bin_width=
|
53 |
-
hue_group_offset=
|
54 |
hue_order=BIOSYN_CLASS_NAMES,
|
55 |
-
hue2count=
|
56 |
-
width=
|
57 |
ax=ax,
|
58 |
-
show_legend=
|
|
|
59 |
)
|
60 |
-
|
61 |
-
return fig
|
62 |
|
63 |
def plot_pair_domains(self, num_domains, split_name):
|
64 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
65 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
66 |
'cds_region_id'].values
|
|
|
67 |
pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
|
|
|
|
|
|
|
|
|
68 |
# split_name = 'stratified'
|
69 |
column_name = f'cosine_similarity_{split_name}'
|
70 |
# pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
|
@@ -83,27 +90,25 @@ class MatplotlibDataPlotter:
|
|
83 |
hue2count={}
|
84 |
width=0.9
|
85 |
|
86 |
-
show_legend=
|
87 |
-
# fig = plt.figure(figsize=(5, 10))
|
88 |
fig = self.pair_domains_fig
|
89 |
-
# fig = plt.gcf()
|
90 |
fig.clf()
|
91 |
-
print(matplotlib.get_backend())
|
92 |
|
93 |
ax = fig.gca()
|
94 |
plot_utils.draw_barplots(
|
95 |
targets_list,
|
96 |
label_list=label_list,
|
97 |
-
top_n=
|
98 |
-
bin_width=
|
99 |
-
hue_group_offset=
|
100 |
hue_order=BIOSYN_CLASS_NAMES,
|
101 |
-
hue2count=
|
102 |
-
width=
|
103 |
ax=ax,
|
104 |
-
show_legend=
|
|
|
105 |
)
|
106 |
-
|
107 |
return fig #plt.gcf()
|
108 |
|
109 |
|
|
|
21 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
22 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
23 |
'cds_region_id'].values
|
24 |
+
|
25 |
single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
|
26 |
+
|
27 |
+
biosyn_counts_single = single_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
28 |
+
hue2count_single = dict(biosyn_counts_single.values)
|
29 |
+
|
30 |
# split_name = 'stratified'
|
31 |
column_name = f'cosine_similarity_{split_name}'
|
32 |
# single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
|
|
|
40 |
bin_width=1
|
41 |
hue_group_offset=0.5
|
42 |
# hue_order=BIOSYN_CLASS_NAMES
|
|
|
43 |
width=0.9
|
44 |
|
|
|
|
|
|
|
45 |
fig = self.single_domains_fig
|
46 |
fig.clf()
|
47 |
|
|
|
49 |
plot_utils.draw_barplots(
|
50 |
targets_list,
|
51 |
label_list=label_list,
|
52 |
+
top_n=top_n,
|
53 |
+
bin_width=bin_width,
|
54 |
+
hue_group_offset=hue_group_offset,
|
55 |
hue_order=BIOSYN_CLASS_NAMES,
|
56 |
+
hue2count=hue2count_single,
|
57 |
+
width=width,
|
58 |
ax=ax,
|
59 |
+
show_legend=False,
|
60 |
+
palette=COLOR_PALETTE
|
61 |
)
|
62 |
+
fig.tight_layout()
|
63 |
+
return fig
|
64 |
|
65 |
def plot_pair_domains(self, num_domains, split_name):
|
66 |
selected_region_ids = self.num_domains_in_region_df.loc[
|
67 |
self.num_domains_in_region_df.num_domains >= num_domains,
|
68 |
'cds_region_id'].values
|
69 |
+
|
70 |
pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
|
71 |
+
|
72 |
+
biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
|
73 |
+
hue2count_pairs = dict(biosyn_counts_pairs.values)
|
74 |
+
|
75 |
# split_name = 'stratified'
|
76 |
column_name = f'cosine_similarity_{split_name}'
|
77 |
# pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
|
|
|
90 |
hue2count={}
|
91 |
width=0.9
|
92 |
|
93 |
+
show_legend=False
|
|
|
94 |
fig = self.pair_domains_fig
|
|
|
95 |
fig.clf()
|
|
|
96 |
|
97 |
ax = fig.gca()
|
98 |
plot_utils.draw_barplots(
|
99 |
targets_list,
|
100 |
label_list=label_list,
|
101 |
+
top_n=top_n,
|
102 |
+
bin_width=bin_width,
|
103 |
+
hue_group_offset=hue_group_offset,
|
104 |
hue_order=BIOSYN_CLASS_NAMES,
|
105 |
+
hue2count=hue2count_pairs,
|
106 |
+
width=width,
|
107 |
ax=ax,
|
108 |
+
show_legend=show_legend,
|
109 |
+
palette=COLOR_PALETTE
|
110 |
)
|
111 |
+
fig.tight_layout()
|
112 |
return fig #plt.gcf()
|
113 |
|
114 |
|
plot_utils.py
CHANGED
@@ -76,7 +76,7 @@ def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
|
|
76 |
# if not normalize:
|
77 |
# bottom[bin_indices] += bar_offset
|
78 |
line_pos = bin_indices.max() + width/2 + hue_group_offset/2
|
79 |
-
|
80 |
if show_legend:
|
81 |
ax.legend(
|
82 |
loc='upper center', bbox_to_anchor=(0.5, -0.05),
|
|
|
76 |
# if not normalize:
|
77 |
# bottom[bin_indices] += bar_offset
|
78 |
line_pos = bin_indices.max() + width/2 + hue_group_offset/2
|
79 |
+
ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)
|
80 |
if show_legend:
|
81 |
ax.legend(
|
82 |
loc='upper center', bbox_to_anchor=(0.5, -0.05),
|