"""Gradio app that plots single-domain and pair-domain data.

Loads the single- and pair-domain tables, counts the domains in each CDS
region, and renders two plots via ``MatplotlibDataPlotter``. The plots are
driven by a "Min number of domains" slider and a "Split name" dropdown.
"""

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from constants import BIOSYN_CLASS_NAMES, PAIR_DOMAINS_FILE, SINGLE_DOMAINS_FILE
from mpl_data_plotter import MatplotlibDataPlotter

_INT32_INFO = np.iinfo(np.int32)


def convert_int64_to_int32(df):
    """Downcast the int64 columns of *df* to int32 in place and return *df*.

    A column is only downcast when all of its values fit in the int32 range.
    ``astype('int32')`` would otherwise wrap out-of-range values silently.
    The name of every converted column is printed.

    Args:
        df: DataFrame to modify in place.

    Returns:
        The same DataFrame object, so calls can be chained or reassigned.
    """
    for col in df.columns:
        if df[col].dtype != 'int64':
            continue
        values = df[col]
        # An empty column trivially fits; otherwise refuse to wrap around.
        if not values.empty and (
            values.min() < _INT32_INFO.min or values.max() > _INT32_INFO.max
        ):
            print(f"Keeping {col} as int64: values exceed the int32 range")
            continue
        print(col)
        df[col] = values.astype('int32')
    return df


def _class_index_map(class_names):
    """Map each name to the index of its first occurrence, matching ``list.index``."""
    mapping = {}
    for idx, name in enumerate(class_names):
        mapping.setdefault(name, idx)
    return mapping


def _load_domains(path):
    """Read a gzip-compressed domains CSV and prepare it for plotting.

    Adds a ``biosyn_class_index`` column (position of ``bgc_class`` in
    ``BIOSYN_CLASS_NAMES``) and downcasts int64 columns to int32.

    Raises:
        ValueError: if some ``bgc_class`` value is not in ``BIOSYN_CLASS_NAMES``.
    """
    df = pd.read_csv(path, compression='gzip')
    # A dict lookup per row replaces the O(len(BIOSYN_CLASS_NAMES)) list.index scan.
    mapped = df['bgc_class'].map(_class_index_map(BIOSYN_CLASS_NAMES))
    unknown = df.loc[mapped.isna(), 'bgc_class'].unique()
    if len(unknown):
        raise ValueError(f"Unknown bgc_class values in {path}: {list(unknown)}")
    df['biosyn_class_index'] = mapped.astype('int64')
    return convert_int64_to_int32(df)


print("Loading domains data...")
single_df = _load_domains(SINGLE_DOMAINS_FILE)
pair_df = _load_domains(PAIR_DOMAINS_FILE)

# One row per CDS region, with the number of single-domain rows it contains.
num_domains_in_region_df = (
    single_df.groupby('cds_region_id', as_index=False)
    .agg({'as_domain_id': 'count'})
    .rename(columns={'as_domain_id': 'num_domains'})
)
unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
min_num_domains = int(unique_domain_lengths.min())
max_num_domains = int(unique_domain_lengths.max())

print("Initializing data plotter...")
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)


def update_all_plots(frequency, split_name):
    """Render both plots for the current slider value and split name.

    Args:
        frequency: Value of the "Min number of domains" slider. The parameter
            name is a leftover from a wave-plotter template.
        split_name: ``"stratified"`` or one of ``BIOSYN_CLASS_NAMES``.

    Returns:
        Tuple of the values returned by ``plot_single_domains`` and
        ``plot_pair_domains``, in that order. They feed the two ``gr.Plot``
        outputs.
    """
    return (
        data_plotter.plot_single_domains(frequency, split_name),
        data_plotter.plot_pair_domains(frequency, split_name),
    )


print("Defining blocks...")
# Create Gradio interface.
# NOTE(review): the title and description below come from a wave-plotter
# template and do not describe this app; consider rewording them.
with gr.Blocks(title="Interactive Wave Plotter") as demo:
    gr.Markdown("## Interactive Wave Plotter")
    gr.Markdown("Adjust the slider to change the frequency of all waves simultaneously.")

    with gr.Row():
        frequency_slider = gr.Slider(
            minimum=min_num_domains,
            maximum=max_num_domains,
            step=1,
            value=min_num_domains,
            label="Min number of domains",
        )

    with gr.Row():
        with gr.Column():
            split_selector = gr.Dropdown(
                choices=["stratified", *BIOSYN_CLASS_NAMES],
                value="stratified",
                label="Split name",
            )
        with gr.Column():
            single_domains_plot = gr.Plot(
                label="Single domains",
                container=True,
                elem_id="single_domains_plot",
            )
        with gr.Column():
            pair_domains_plot = gr.Plot(label="Pair domains")

    plot_inputs = [frequency_slider, split_selector]
    plot_outputs = [single_domains_plot, pair_domains_plot]
    # Redraw once the slider is released, not on every drag tick.
    frequency_slider.release(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
    # Redraw when the split changes, so the new split shows without touching the slider.
    split_selector.change(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
    # Draw the initial plots on page load instead of showing empty panels.
    demo.load(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)

print("Launching!...")
demo.launch()