latticetower commited on
Commit
b40aac1
·
1 Parent(s): 2d1d8cb

fix avxline in plots, use common legend in gradio, add reaction and loading on launch

Browse files
Files changed (4) hide show
  1. app.py +70 -9
  2. constants.py +17 -0
  3. mpl_data_plotter.py +28 -23
  4. plot_utils.py +1 -1
app.py CHANGED
@@ -17,11 +17,13 @@ def convert_int64_to_int32(df):
17
 
18
  print(f"Loading domains data...")
19
  single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
20
- single_df['biosyn_class_index'] = single_df.bgc_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
 
21
  single_df = convert_int64_to_int32(single_df)
22
 
23
  pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
24
- pair_df['biosyn_class_index'] = pair_df.bgc_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
 
25
  pair_df = convert_int64_to_int32(pair_df)
26
 
27
  num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
@@ -33,6 +35,53 @@ print(f"Initializing data plotter...")
33
  data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def update_all_plots(frequency, split_name):
37
  return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
38
 
@@ -43,6 +92,8 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
43
  gr.Markdown("## BGC Keyword Plotter")
44
  gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
45
 
 
 
46
  with gr.Row():
47
  frequency_slider = gr.Slider(
48
  minimum=int(unique_domain_lengths.min()),
@@ -51,14 +102,13 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
51
  value=int(unique_domain_lengths.min()),
52
  label="Min number of domains"
53
  )
 
 
 
 
 
54
 
55
  with gr.Row():
56
- with gr.Column():
57
- split_selector = gr.Dropdown(
58
- choices=["stratified"] + BIOSYN_CLASS_NAMES,
59
- value="stratified",
60
- label="Split name"
61
- )
62
  with gr.Column():
63
  single_domains_plot = gr.Plot(
64
  label="Single domains",
@@ -80,11 +130,22 @@ with gr.Blocks(title="BGC Keyword Plotter") as demo:
80
 
81
  frequency_slider.release(
82
  fn=update_all_plots,
83
- inputs=[frequency_slider, split_selector],
84
  outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
85
  )
 
 
 
 
 
 
 
 
 
 
86
 
87
 
88
  print(f"Launching!...")
89
  demo.launch()
 
90
  # demo.load(filter_map, [min_price, max_price, boroughs], map)
 
17
 
18
  print(f"Loading domains data...")
19
  single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
20
+ single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
21
+ single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
22
  single_df = convert_int64_to_int32(single_df)
23
 
24
  pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
25
+ pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
26
+ pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
27
  pair_df = convert_int64_to_int32(pair_df)
28
 
29
  num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
 
35
  data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
36
 
37
 
38
+ def create_color_legend(class_to_color):
39
+ # Create HTML for the color legend
40
+ legend_html = """
41
+ <div style="
42
+ margin: 10px 0;
43
+ padding: 10px;
44
+ border: 1px solid #ddd;
45
+ border-radius: 4px;
46
+ background: white;
47
+ ">
48
+ <div style="
49
+ font-weight: bold;
50
+ margin-bottom: 8px;
51
+ ">Color Legend:</div>
52
+ <div style="
53
+ display: flex;
54
+ flex-wrap: wrap;
55
+ gap: 15px;
56
+ align-items: center;
57
+ ">
58
+ """
59
+ # Add each class and its color
60
+ for class_name, color in class_to_color.items():
61
+ legend_html += f"""
62
+ <div style="
63
+ display: flex;
64
+ align-items: center;
65
+ gap: 5px;
66
+ ">
67
+ <div style="
68
+ width: 20px;
69
+ height: 20px;
70
+ background-color: {color};
71
+ border-radius: 3px;
72
+ "></div>
73
+ <span>{class_name}</span>
74
+ </div>
75
+ """
76
+
77
+ legend_html += """
78
+ </div>
79
+ </div>
80
+ """
81
+
82
+ return gr.HTML(legend_html)
83
+
84
+
85
  def update_all_plots(frequency, split_name):
86
  return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
87
 
 
92
  gr.Markdown("## BGC Keyword Plotter")
93
  gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
94
 
95
+ color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
96
+
97
  with gr.Row():
98
  frequency_slider = gr.Slider(
99
  minimum=int(unique_domain_lengths.min()),
 
102
  value=int(unique_domain_lengths.min()),
103
  label="Min number of domains"
104
  )
105
+ model_selector = gr.Radio(
106
+ choices=["stratified"] + BIOSYN_CLASS_NAMES,
107
+ value="stratified",
108
+ label="Model name"
109
+ )
110
 
111
  with gr.Row():
 
 
 
 
 
 
112
  with gr.Column():
113
  single_domains_plot = gr.Plot(
114
  label="Single domains",
 
130
 
131
  frequency_slider.release(
132
  fn=update_all_plots,
133
+ inputs=[frequency_slider, model_selector],
134
  outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
135
  )
136
+ demo.load(
137
+ fn=update_all_plots,
138
+ inputs=[frequency_slider, model_selector],
139
+ outputs=[single_domains_plot, pair_domains_plot]
140
+ )
141
+ model_selector.input(
142
+ fn=update_all_plots,
143
+ inputs=[frequency_slider, model_selector],
144
+ outputs=[single_domains_plot, pair_domains_plot]
145
+ )
146
 
147
 
148
  print(f"Launching!...")
149
  demo.launch()
150
+
151
  # demo.load(filter_map, [min_price, max_price, boroughs], map)
constants.py CHANGED
@@ -1,4 +1,6 @@
1
 
 
 
2
  POSTER_BLUE = '#01589C'
3
 
4
  BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
@@ -6,3 +8,18 @@ BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Te
6
  SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
7
  PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import seaborn as sns
3
+
4
  POSTER_BLUE = '#01589C'
5
 
6
  BIOSYN_CLASS_NAMES = ['Alkaloid', 'NRP', 'Polyketide', 'RiPP', 'Saccharide', 'Terpene', "Other"]
 
8
  SINGLE_DOMAINS_FILE = 'data/single_domains.csv.gz'
9
  PAIR_DOMAINS_FILE = 'data/pair_domains.csv.gz'
10
 
11
+ BIOSYN_CLASS_HEX_COLORS = {
12
+ 'Alkaloid': '#a1c9f4',
13
+ 'NRP': '#ffb482',
14
+ 'Polyketide': '#8de5a1',
15
+ 'RiPP': '#ff9f9b',
16
+ 'Saccharide': '#d0bbff',
17
+ 'Terpene': '#debb9b',
18
+ 'Other': '#cfcfcf',
19
+ # 'stratified': '#01589C', # just in case
20
+ }
21
+
22
+ COLOR_PALETTE = sns.color_palette([
23
+ BIOSYN_CLASS_HEX_COLORS[biosyn_class]
24
+ for biosyn_class in BIOSYN_CLASS_NAMES
25
+ ])
mpl_data_plotter.py CHANGED
@@ -21,7 +21,12 @@ class MatplotlibDataPlotter:
21
  selected_region_ids = self.num_domains_in_region_df.loc[
22
  self.num_domains_in_region_df.num_domains >= num_domains,
23
  'cds_region_id'].values
 
24
  single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
 
 
 
 
25
  # split_name = 'stratified'
26
  column_name = f'cosine_similarity_{split_name}'
27
  # single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
@@ -35,12 +40,8 @@ class MatplotlibDataPlotter:
35
  bin_width=1
36
  hue_group_offset=0.5
37
  # hue_order=BIOSYN_CLASS_NAMES
38
- hue2count={}
39
  width=0.9
40
 
41
- show_legend=True
42
- print(matplotlib.get_backend())
43
-
44
  fig = self.single_domains_fig
45
  fig.clf()
46
 
@@ -48,23 +49,29 @@ class MatplotlibDataPlotter:
48
  plot_utils.draw_barplots(
49
  targets_list,
50
  label_list=label_list,
51
- top_n=5,
52
- bin_width=1,
53
- hue_group_offset=0.5,
54
  hue_order=BIOSYN_CLASS_NAMES,
55
- hue2count={},
56
- width=0.9,
57
  ax=ax,
58
- show_legend=True
 
59
  )
60
- plt.tight_layout()
61
- return fig # plt.gcf()
62
 
63
  def plot_pair_domains(self, num_domains, split_name):
64
  selected_region_ids = self.num_domains_in_region_df.loc[
65
  self.num_domains_in_region_df.num_domains >= num_domains,
66
  'cds_region_id'].values
 
67
  pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
 
 
 
 
68
  # split_name = 'stratified'
69
  column_name = f'cosine_similarity_{split_name}'
70
  # pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
@@ -83,27 +90,25 @@ class MatplotlibDataPlotter:
83
  hue2count={}
84
  width=0.9
85
 
86
- show_legend=True
87
- # fig = plt.figure(figsize=(5, 10))
88
  fig = self.pair_domains_fig
89
- # fig = plt.gcf()
90
  fig.clf()
91
- print(matplotlib.get_backend())
92
 
93
  ax = fig.gca()
94
  plot_utils.draw_barplots(
95
  targets_list,
96
  label_list=label_list,
97
- top_n=5,
98
- bin_width=1,
99
- hue_group_offset=0.5,
100
  hue_order=BIOSYN_CLASS_NAMES,
101
- hue2count={},
102
- width=0.9,
103
  ax=ax,
104
- show_legend=True
 
105
  )
106
- plt.tight_layout()
107
  return fig #plt.gcf()
108
 
109
 
 
21
  selected_region_ids = self.num_domains_in_region_df.loc[
22
  self.num_domains_in_region_df.num_domains >= num_domains,
23
  'cds_region_id'].values
24
+
25
  single_df_subset = self.single_df.loc[self.single_df.cds_region_id.isin(selected_region_ids)]
26
+
27
+ biosyn_counts_single = single_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
28
+ hue2count_single = dict(biosyn_counts_single.values)
29
+
30
  # split_name = 'stratified'
31
  column_name = f'cosine_similarity_{split_name}'
32
  # single_df_subset = single_df.loc[single_df.dom_location_len >= num_domains]
 
40
  bin_width=1
41
  hue_group_offset=0.5
42
  # hue_order=BIOSYN_CLASS_NAMES
 
43
  width=0.9
44
 
 
 
 
45
  fig = self.single_domains_fig
46
  fig.clf()
47
 
 
49
  plot_utils.draw_barplots(
50
  targets_list,
51
  label_list=label_list,
52
+ top_n=top_n,
53
+ bin_width=bin_width,
54
+ hue_group_offset=hue_group_offset,
55
  hue_order=BIOSYN_CLASS_NAMES,
56
+ hue2count=hue2count_single,
57
+ width=width,
58
  ax=ax,
59
+ show_legend=False,
60
+ palette=COLOR_PALETTE
61
  )
62
+ fig.tight_layout()
63
+ return fig
64
 
65
  def plot_pair_domains(self, num_domains, split_name):
66
  selected_region_ids = self.num_domains_in_region_df.loc[
67
  self.num_domains_in_region_df.num_domains >= num_domains,
68
  'cds_region_id'].values
69
+
70
  pair_df_subset = self.pair_df.loc[self.pair_df.cds_region_id.isin(selected_region_ids)]
71
+
72
+ biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
73
+ hue2count_pairs = dict(biosyn_counts_pairs.values)
74
+
75
  # split_name = 'stratified'
76
  column_name = f'cosine_similarity_{split_name}'
77
  # pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
 
90
  hue2count={}
91
  width=0.9
92
 
93
+ show_legend=False
 
94
  fig = self.pair_domains_fig
 
95
  fig.clf()
 
96
 
97
  ax = fig.gca()
98
  plot_utils.draw_barplots(
99
  targets_list,
100
  label_list=label_list,
101
+ top_n=top_n,
102
+ bin_width=bin_width,
103
+ hue_group_offset=hue_group_offset,
104
  hue_order=BIOSYN_CLASS_NAMES,
105
+ hue2count=hue2count_pairs,
106
+ width=width,
107
  ax=ax,
108
+ show_legend=show_legend,
109
+ palette=COLOR_PALETTE
110
  )
111
+ fig.tight_layout()
112
  return fig #plt.gcf()
113
 
114
 
plot_utils.py CHANGED
@@ -76,7 +76,7 @@ def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
76
  # if not normalize:
77
  # bottom[bin_indices] += bar_offset
78
  line_pos = bin_indices.max() + width/2 + hue_group_offset/2
79
- plt.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)
80
  if show_legend:
81
  ax.legend(
82
  loc='upper center', bbox_to_anchor=(0.5, -0.05),
 
76
  # if not normalize:
77
  # bottom[bin_indices] += bar_offset
78
  line_pos = bin_indices.max() + width/2 + hue_group_offset/2
79
+ ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)
80
  if show_legend:
81
  ax.legend(
82
  loc='upper center', bbox_to_anchor=(0.5, -0.05),