import html

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from constants import *
from mpl_data_plotter import MatplotlibDataPlotter


def convert_int64_to_int32(df):
    """Downcast every ``int64`` column of *df* to ``int32`` in place.

    A column is only downcast when all of its values fit into the ``int32``
    range; otherwise it is left as ``int64`` so that large values are not
    silently wrapped around by the narrowing cast.

    Args:
        df: DataFrame whose integer columns should be narrowed. It is
            modified in place.

    Returns:
        The same DataFrame object, for call chaining.
    """
    limits = np.iinfo(np.int32)
    for col in df.columns:
        series = df[col]
        if series.dtype != 'int64':
            continue
        # An empty column has no min/max to compare, and is always safe.
        fits = series.empty or (
            series.min() >= limits.min and series.max() <= limits.max
        )
        if fits:
            df[col] = series.astype('int32')
    return df


def create_color_legend(class_to_color):
    """Build a ``gr.HTML`` component listing each class next to its color.

    Args:
        class_to_color: Mapping from class name to a CSS color value
            (the caller in this file passes ``BIOSYN_CLASS_HEX_COLORS``).

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    # Names and colors are escaped so that unusual characters in either
    # cannot break out of the surrounding markup.
    entries = "".join(
        '<div style="display: flex; align-items: center; margin-right: 12px;">'
        '<span style="display: inline-block; width: 16px; height: 16px; '
        'margin-right: 6px; '
        f'background-color: {html.escape(str(color), quote=True)};"></span>'
        f'<span>{html.escape(str(class_name))}</span>'
        '</div>'
        for class_name, color in class_to_color.items()
    )
    legend_html = (
        '<div style="margin: 10px 0;">'
        '<strong>Color Legend:</strong>'
        f'<div style="display: flex; flex-wrap: wrap; margin-top: 6px;">{entries}</div>'
        '</div>'
    )
    return gr.HTML(legend_html)


def update_all_plots(frequency, split_name):
    """Redraw both plots for the current slider and radio selection.

    Relies on the module-level ``data_plotter`` that is created in the
    ``__main__`` section below, so it is only usable once that has run.

    Args:
        frequency: Value of the "Min number of domains" slider.
        split_name: Value of the "Model name" radio selector.

    Returns:
        A 2-tuple ``(single_domains_figure, pair_domains_figure)`` holding
        whatever the plotter's two methods return, in the order expected by
        the ``outputs`` lists of the event listeners.
    """
    single_plot = data_plotter.plot_single_domains(frequency, split_name)
    pair_plot = data_plotter.plot_pair_domains(frequency, split_name)
    return single_plot, pair_plot


def _load_domains(path):
    """Read one gzip-compressed domains CSV and normalise it for plotting.

    Renames ``bgc_class`` to ``biosyn_class``, adds ``biosyn_class_index``
    (position of the class in ``BIOSYN_CLASS_NAMES``; an unknown class raises
    ``ValueError``), and narrows ``int64`` columns to ``int32``.
    """
    df = pd.read_csv(path, compression='gzip')
    df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
    df['biosyn_class_index'] = df.biosyn_class.apply(
        lambda x: BIOSYN_CLASS_NAMES.index(x)
    )
    return convert_int64_to_int32(df)


if __name__ == "__main__":
    print("Loading domains data...")
    single_df = _load_domains(SINGLE_DOMAINS_FILE)
    pair_df = _load_domains(PAIR_DOMAINS_FILE)

    # Number of domains per region; drives the slider range below.
    num_domains_in_region_df = (
        single_df.groupby('cds_region_id', as_index=False)
        .agg({'as_domain_id': 'count'})
        .rename(columns={'as_domain_id': 'num_domains'})
    )
    unique_domain_lengths = num_domains_in_region_df.num_domains.unique()

    print("Initializing data plotter...")
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")
    # Create Gradio interface
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
        create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        with gr.Row():
            min_domains = int(unique_domain_lengths.min())
            max_domains = int(unique_domain_lengths.max())
            frequency_slider = gr.Slider(
                minimum=min_domains,
                maximum=max_domains,
                step=1,
                value=min_domains,
                label="Min number of domains",
            )
            model_selector = gr.Radio(
                choices=["stratified"] + BIOSYN_CLASS_NAMES,
                value="stratified",
                label="Model name",
            )

        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot",
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # Slider release, initial page load and radio input all redraw both
        # plots with the same inputs/outputs, so register them uniformly.
        # Listeners must be attached while the Blocks context is still open.
        for register in (
            frequency_slider.release,
            demo.load,
            model_selector.input,
        ):
            register(
                fn=update_all_plots,
                inputs=[frequency_slider, model_selector],
                outputs=[single_domains_plot, pair_domains_plot],
            )

    print("Launching!...")
    demo.launch()