latticetower commited on
Commit
af49af1
·
1 Parent(s): b40aac1
Files changed (2) hide show
  1. app.py +73 -81
  2. mpl_data_plotter.py +3 -5
app.py CHANGED
@@ -11,29 +11,9 @@ from mpl_data_plotter import MatplotlibDataPlotter
11
  def convert_int64_to_int32(df):
12
  for col in df.columns:
13
  if df[col].dtype == 'int64':
14
- print(col)
15
  df[col] = df[col].astype('int32')
16
  return df
17
 
18
- print(f"Loading domains data...")
19
- single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
20
- single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
21
- single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
22
- single_df = convert_int64_to_int32(single_df)
23
-
24
- pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
25
- pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
26
- pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
27
- pair_df = convert_int64_to_int32(pair_df)
28
-
29
- num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
30
- columns={'as_domain_id': 'num_domains'})
31
-
32
- unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
33
-
34
- print(f"Initializing data plotter...")
35
- data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
36
-
37
 
38
  def create_color_legend(class_to_color):
39
  # Create HTML for the color legend
@@ -86,66 +66,78 @@ def update_all_plots(frequency, split_name):
86
  return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
87
 
88
 
89
- print(f"Defining blocks...")
90
- # Create Gradio interface
91
- with gr.Blocks(title="BGC Keyword Plotter") as demo:
92
- gr.Markdown("## BGC Keyword Plotter")
93
- gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
94
-
95
- color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
96
-
97
- with gr.Row():
98
- frequency_slider = gr.Slider(
99
- minimum=int(unique_domain_lengths.min()),
100
- maximum=int(unique_domain_lengths.max()),
101
- step=1,
102
- value=int(unique_domain_lengths.min()),
103
- label="Min number of domains"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  )
105
- model_selector = gr.Radio(
106
- choices=["stratified"] + BIOSYN_CLASS_NAMES,
107
- value="stratified",
108
- label="Model name"
109
  )
110
-
111
- with gr.Row():
112
- with gr.Column():
113
- single_domains_plot = gr.Plot(
114
- label="Single domains",
115
- container=True,
116
- elem_id="single_domains_plot"
117
- )
118
- # gr.HTML("""
119
- # <style>
120
- # #single_domains_plot {
121
- # height: 100% !important;
122
- # width: 100% !important;
123
- # }
124
- # </style>
125
- # """)
126
- with gr.Column():
127
- pair_domains_plot = gr.Plot(label="Pair domains")
128
- # with gr.Column():
129
- # combined_plot = gr.Plot(label="Combined Wave")
130
-
131
- frequency_slider.release(
132
- fn=update_all_plots,
133
- inputs=[frequency_slider, model_selector],
134
- outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
135
- )
136
- demo.load(
137
- fn=update_all_plots,
138
- inputs=[frequency_slider, model_selector],
139
- outputs=[single_domains_plot, pair_domains_plot]
140
- )
141
- model_selector.input(
142
- fn=update_all_plots,
143
- inputs=[frequency_slider, model_selector],
144
- outputs=[single_domains_plot, pair_domains_plot]
145
- )
146
-
147
-
148
- print(f"Launching!...")
149
- demo.launch()
150
-
151
- # demo.load(filter_map, [min_price, max_price, boroughs], map)
 
11
  def convert_int64_to_int32(df):
12
  for col in df.columns:
13
  if df[col].dtype == 'int64':
 
14
  df[col] = df[col].astype('int32')
15
  return df
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def create_color_legend(class_to_color):
19
  # Create HTML for the color legend
 
66
  return data_plotter.plot_single_domains(frequency, split_name), data_plotter.plot_pair_domains(frequency, split_name)
67
 
68
 
69
+ if __name__ == "__main__":
70
+ print(f"Loading domains data...")
71
+ single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
72
+ single_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
73
+ single_df['biosyn_class_index'] = single_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
74
+ single_df = convert_int64_to_int32(single_df)
75
+
76
+ pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
77
+ pair_df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
78
+ pair_df['biosyn_class_index'] = pair_df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
79
+ pair_df = convert_int64_to_int32(pair_df)
80
+
81
+ num_domains_in_region_df = single_df.groupby('cds_region_id', as_index=False).agg({'as_domain_id': 'count'}).rename(
82
+ columns={'as_domain_id': 'num_domains'})
83
+
84
+ unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
85
+
86
+ print(f"Initializing data plotter...")
87
+ data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
88
+
89
+
90
+ print(f"Defining blocks...")
91
+
92
+ # Create Gradio interface
93
+ with gr.Blocks(title="BGC Keyword Plotter") as demo:
94
+ gr.Markdown("## BGC Keyword Plotter")
95
+ gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
96
+
97
+ color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)
98
+
99
+ with gr.Row():
100
+ frequency_slider = gr.Slider(
101
+ minimum=int(unique_domain_lengths.min()),
102
+ maximum=int(unique_domain_lengths.max()),
103
+ step=1,
104
+ value=int(unique_domain_lengths.min()),
105
+ label="Min number of domains"
106
+ )
107
+ model_selector = gr.Radio(
108
+ choices=["stratified"] + BIOSYN_CLASS_NAMES,
109
+ value="stratified",
110
+ label="Model name"
111
+ )
112
+
113
+ with gr.Row():
114
+ with gr.Column():
115
+ single_domains_plot = gr.Plot(
116
+ label="Single domains",
117
+ container=True,
118
+ elem_id="single_domains_plot"
119
+ )
120
+ with gr.Column():
121
+ pair_domains_plot = gr.Plot(label="Pair domains")
122
+
123
+ frequency_slider.release(
124
+ fn=update_all_plots,
125
+ inputs=[frequency_slider, model_selector],
126
+ outputs=[single_domains_plot, pair_domains_plot]#, cosine_plot]
127
  )
128
+ demo.load(
129
+ fn=update_all_plots,
130
+ inputs=[frequency_slider, model_selector],
131
+ outputs=[single_domains_plot, pair_domains_plot]
132
  )
133
+ model_selector.input(
134
+ fn=update_all_plots,
135
+ inputs=[frequency_slider, model_selector],
136
+ outputs=[single_domains_plot, pair_domains_plot]
137
+ )
138
+
139
+
140
+ print(f"Launching!...")
141
+ demo.launch()
142
+
143
+ # demo.load(filter_map, [min_price, max_price, boroughs], map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mpl_data_plotter.py CHANGED
@@ -17,7 +17,7 @@ class MatplotlibDataPlotter:
17
  self.single_domains_fig = plt.figure(figsize=(5, 10))
18
  self.pair_domains_fig = plt.figure(figsize=(5, 10))
19
 
20
- def plot_single_domains(self, num_domains, split_name):
21
  selected_region_ids = self.num_domains_in_region_df.loc[
22
  self.num_domains_in_region_df.num_domains >= num_domains,
23
  'cds_region_id'].values
@@ -39,7 +39,6 @@ class MatplotlibDataPlotter:
39
  top_n=5
40
  bin_width=1
41
  hue_group_offset=0.5
42
- # hue_order=BIOSYN_CLASS_NAMES
43
  width=0.9
44
 
45
  fig = self.single_domains_fig
@@ -62,7 +61,7 @@ class MatplotlibDataPlotter:
62
  fig.tight_layout()
63
  return fig
64
 
65
- def plot_pair_domains(self, num_domains, split_name):
66
  selected_region_ids = self.num_domains_in_region_df.loc[
67
  self.num_domains_in_region_df.num_domains >= num_domains,
68
  'cds_region_id'].values
@@ -72,9 +71,8 @@ class MatplotlibDataPlotter:
72
  biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
73
  hue2count_pairs = dict(biosyn_counts_pairs.values)
74
 
75
- # split_name = 'stratified'
76
  column_name = f'cosine_similarity_{split_name}'
77
- # pair_df_subset = pair_df.loc[pair_df.dom_location_len >= num_domains]
78
  selected_keyword_index = pair_df_subset.groupby('cds_region_id').agg(
79
  {column_name: 'idxmax'}
80
  ).values.flatten()
 
17
  self.single_domains_fig = plt.figure(figsize=(5, 10))
18
  self.pair_domains_fig = plt.figure(figsize=(5, 10))
19
 
20
+ def plot_single_domains(self, num_domains, split_name="stratified"):
21
  selected_region_ids = self.num_domains_in_region_df.loc[
22
  self.num_domains_in_region_df.num_domains >= num_domains,
23
  'cds_region_id'].values
 
39
  top_n=5
40
  bin_width=1
41
  hue_group_offset=0.5
 
42
  width=0.9
43
 
44
  fig = self.single_domains_fig
 
61
  fig.tight_layout()
62
  return fig
63
 
64
+ def plot_pair_domains(self, num_domains, split_name="stratified"):
65
  selected_region_ids = self.num_domains_in_region_df.loc[
66
  self.num_domains_in_region_df.num_domains >= num_domains,
67
  'cds_region_id'].values
 
71
  biosyn_counts_pairs = pair_df_subset[['cds_region_id', 'biosyn_class']].drop_duplicates().groupby("biosyn_class", as_index=False).count()
72
  hue2count_pairs = dict(biosyn_counts_pairs.values)
73
 
 
74
  column_name = f'cosine_similarity_{split_name}'
75
+
76
  selected_keyword_index = pair_df_subset.groupby('cds_region_id').agg(
77
  {column_name: 'idxmax'}
78
  ).values.flatten()