import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from constants import *
from mpl_data_plotter import MatplotlibDataPlotter
def convert_int64_to_int32(df):
    """Downcast ``int64`` columns of *df* to ``int32`` where that is lossless.

    The frame is modified in place and also returned, so the call can be
    used either as a statement or as an expression.

    A plain ``astype('int32')`` silently wraps values that do not fit into
    32 bits, so a column is only converted when its minimum and maximum lie
    inside the ``int32`` range. Columns holding larger values are left as
    ``int64``; columns of any other dtype are never touched.

    Args:
        df: ``pandas.DataFrame`` whose integer columns should be narrowed.

    Returns:
        The same ``DataFrame`` object that was passed in.
    """
    int32_limits = np.iinfo(np.int32)
    for col in df.columns:
        column = df[col]
        if column.dtype != 'int64':
            continue
        # An empty column has no min/max (they evaluate to NaN), and there is
        # nothing that could overflow, so it is always safe to narrow.
        fits_in_int32 = column.empty or (
            column.min() >= int32_limits.min
            and column.max() <= int32_limits.max
        )
        if fits_in_int32:
            df[col] = column.astype('int32')
    return df
def create_color_legend(class_to_color):
    """Build a Gradio HTML component showing one color swatch per class.

    Each mapping entry is rendered as a small colored square followed by the
    class name, under a "Color Legend:" heading. Names and colors are
    HTML-escaped so that unusual characters cannot break the markup.

    Args:
        class_to_color: Mapping of class name to a CSS color value; entries
            are rendered in the mapping's iteration order.

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    # Imported locally so this function does not rely on a new file-level
    # import being present.
    from html import escape

    legend_entries = [
        '<div style="display: flex; align-items: center; margin-right: 12px;">'
        '<span style="display: inline-block; width: 16px; height: 16px; '
        f'margin-right: 6px; background-color: {escape(str(color), quote=True)};">'
        '</span>'
        f'<span>{escape(str(class_name))}</span>'
        '</div>'
        for class_name, color in class_to_color.items()
    ]
    legend_html = (
        '<div style="margin: 10px 0;">'
        '<strong>Color Legend:</strong>'
        '<div style="display: flex; flex-wrap: wrap; margin-top: 6px;">'
        + "".join(legend_entries)
        + '</div>'
        '</div>'
    )
    return gr.HTML(legend_html)
def update_all_plots(frequency, split_name):
    """Redraw the single-domain and pair-domain plots for the given controls.

    Both arguments are forwarded unchanged to the module-level
    ``data_plotter``, which is created in the ``__main__`` section of this
    script.

    Args:
        frequency: Value passed as the first argument to both plot methods.
        split_name: Value passed as the second argument to both plot methods.

    Returns:
        A ``(single_domains_plot, pair_domains_plot)`` tuple, in that order.
    """
    single_plot = data_plotter.plot_single_domains(frequency, split_name)
    pair_plot = data_plotter.plot_pair_domains(frequency, split_name)
    return single_plot, pair_plot
def _load_domains_table(csv_path):
    """Load one gzip-compressed domains CSV and prepare it for plotting.

    The ``bgc_class`` column is renamed to ``biosyn_class``, a
    ``biosyn_class_index`` column is added holding each row's position in
    ``BIOSYN_CLASS_NAMES``, and ``int64`` columns are passed through
    ``convert_int64_to_int32``.

    Args:
        csv_path: Path of the gzip-compressed CSV file to read.

    Returns:
        The prepared ``pandas.DataFrame``.

    Raises:
        ValueError: If a row's class is not present in ``BIOSYN_CLASS_NAMES``
            (raised by the sequence's ``index`` method).
    """
    domains_df = pd.read_csv(csv_path, compression='gzip')
    domains_df = domains_df.rename(columns={'bgc_class': 'biosyn_class'})
    domains_df['biosyn_class_index'] = domains_df.biosyn_class.apply(
        BIOSYN_CLASS_NAMES.index
    )
    return convert_int64_to_int32(domains_df)


if __name__ == "__main__":
    print("Loading domains data...")
    single_df = _load_domains_table(SINGLE_DOMAINS_FILE)
    pair_df = _load_domains_table(PAIR_DOMAINS_FILE)

    # Number of non-null 'as_domain_id' values per 'cds_region_id'.
    num_domains_in_region_df = (
        single_df
        .groupby('cds_region_id', as_index=False)
        .agg({'as_domain_id': 'count'})
        .rename(columns={'as_domain_id': 'num_domains'})
    )
    unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
    # The slider spans the observed range of per-region domain counts.
    min_num_domains = int(unique_domain_lengths.min())
    max_num_domains = int(unique_domain_lengths.max())

    print("Initializing data plotter...")
    # update_all_plots reads this module-level name when an event fires.
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")
    # Create Gradio interface
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
        color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        with gr.Row():
            frequency_slider = gr.Slider(
                minimum=min_num_domains,
                maximum=max_num_domains,
                step=1,
                value=min_num_domains,
                label="Min number of domains",
            )
            model_selector = gr.Radio(
                choices=["stratified"] + BIOSYN_CLASS_NAMES,
                value="stratified",
                label="Model name",
            )

        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot",
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # Every trigger redraws both plots from the current control values,
        # so all three listeners share the same wiring.
        redraw_plots = dict(
            fn=update_all_plots,
            inputs=[frequency_slider, model_selector],
            outputs=[single_domains_plot, pair_domains_plot],
        )
        # Redraw when the slider handle is released, when the page first
        # loads, and when the user picks a different radio option.
        frequency_slider.release(**redraw_plots)
        demo.load(**redraw_plots)
        model_selector.input(**redraw_plots)

    print("Launching!...")
    demo.launch()