Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from constants import * | |
from mpl_data_plotter import MatplotlibDataPlotter | |
def convert_int64_to_int32(df):
    """Downcast int64 columns of ``df`` to int32 when their values fit.

    The DataFrame is modified in place and also returned, so callers may use
    either ``convert_int64_to_int32(df)`` or ``df = convert_int64_to_int32(df)``.

    A column is converted only if every value lies within the int32 range.
    Plain ``astype('int32')`` would silently wrap out-of-range values, so
    columns that do not fit are left as int64 instead of being corrupted.
    Non-int64 columns are never touched.

    Args:
        df: pandas DataFrame whose int64 columns should be downcast.

    Returns:
        The same DataFrame object, with eligible columns stored as int32.
    """
    int32_info = np.iinfo(np.int32)
    for col in df.columns:
        series = df[col]
        if series.dtype != 'int64':
            continue
        # An empty column trivially fits; min()/max() would be NaN for it.
        fits_int32 = series.empty or (
            series.min() >= int32_info.min and series.max() <= int32_info.max
        )
        if fits_int32:
            df[col] = series.astype('int32')
    return df
def create_color_legend(class_to_color):
    """Build a Gradio HTML component showing one colour swatch per class.

    Args:
        class_to_color: mapping of class name -> CSS colour string (the app
            passes ``BIOSYN_CLASS_HEX_COLORS``). Entries are rendered in the
            mapping's iteration order.

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    import html  # local import keeps this helper self-contained

    # Collect fragments and join once instead of repeatedly concatenating.
    parts = [
        '<div style="margin: 10px 0; padding: 10px; border: 1px solid #ddd; '
        'border-radius: 4px; background: white;">',
        '<div style="font-weight: bold; margin-bottom: 8px;">Color Legend:</div>',
        '<div style="display: flex; flex-wrap: wrap; gap: 15px; align-items: center;">',
    ]
    for class_name, color in class_to_color.items():
        # Escape values so names/colours containing &, <, > or quotes cannot
        # break out of the element or the style attribute.
        safe_name = html.escape(str(class_name))
        safe_color = html.escape(str(color), quote=True)
        parts.append(
            '<div style="display: flex; align-items: center; gap: 5px;">'
            f'<div style="width: 20px; height: 20px; background-color: {safe_color}; '
            'border-radius: 3px;"></div>'
            f'<span>{safe_name}</span>'
            '</div>'
        )
    # Close the flex container and the outer legend box.
    parts.append('</div>')
    parts.append('</div>')
    return gr.HTML("\n".join(parts))
def update_all_plots(frequency, split_name):
    """Regenerate the single-domain and pair-domain plots for the UI selection.

    Relies on the module-level ``data_plotter`` created in the ``__main__``
    block before the interface is launched.

    Args:
        frequency: value of the "Min number of domains" slider.
        split_name: value of the "Model name" radio ("stratified" or a
            biosynthetic class name).

    Returns:
        Tuple ``(single_plot, pair_plot)``, matching the order of the Gradio
        ``outputs`` list.
    """
    single_plot = data_plotter.plot_single_domains(frequency, split_name)
    pair_plot = data_plotter.plot_pair_domains(frequency, split_name)
    return single_plot, pair_plot
def _load_domains_df(path):
    """Load a gzip-compressed domains CSV and add biosynthetic-class columns.

    Renames ``bgc_class`` to ``biosyn_class``, adds ``biosyn_class_index`` (the
    first position of each class name in ``BIOSYN_CLASS_NAMES``), then
    downcasts int64 columns via ``convert_int64_to_int32``.

    Args:
        path: path to the gzip-compressed CSV file.

    Returns:
        The prepared pandas DataFrame.

    Raises:
        ValueError: if any row's class is not listed in ``BIOSYN_CLASS_NAMES``.
    """
    df = pd.read_csv(path, compression='gzip')
    df = df.rename(columns={'bgc_class': 'biosyn_class'})

    # Build the name -> index lookup once; setdefault keeps the FIRST index,
    # matching list.index semantics if a name were ever duplicated.
    class_to_index = {}
    for idx, name in enumerate(BIOSYN_CLASS_NAMES):
        class_to_index.setdefault(name, idx)

    indices = df['biosyn_class'].map(class_to_index)
    unknown = df.loc[indices.isna(), 'biosyn_class'].unique()
    if len(unknown):
        raise ValueError(
            f"Unknown biosynthetic class(es) {list(unknown)} in {path}; "
            f"expected one of {list(BIOSYN_CLASS_NAMES)}"
        )
    df['biosyn_class_index'] = indices.astype('int64')
    return convert_int64_to_int32(df)


if __name__ == "__main__":
    print("Loading domains data...")
    single_df = _load_domains_df(SINGLE_DOMAINS_FILE)
    pair_df = _load_domains_df(PAIR_DOMAINS_FILE)

    # Number of domains per CDS region; drives the slider range.
    num_domains_in_region_df = (
        single_df.groupby('cds_region_id', as_index=False)
        .agg({'as_domain_id': 'count'})
        .rename(columns={'as_domain_id': 'num_domains'})
    )
    unique_domain_lengths = num_domains_in_region_df.num_domains.unique()
    min_domains = int(unique_domain_lengths.min())
    max_domains = int(unique_domain_lengths.max())

    print("Initializing data plotter...")
    # Module-level global read by update_all_plots.
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
        # Called for its side effect: creating the component inside the
        # Blocks context places it in the layout.
        create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        with gr.Row():
            frequency_slider = gr.Slider(
                minimum=min_domains,
                maximum=max_domains,
                step=1,
                value=min_domains,
                label="Min number of domains",
            )
            model_selector = gr.Radio(
                choices=["stratified", *BIOSYN_CLASS_NAMES],
                value="stratified",
                label="Model name",
            )

        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot",
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # All triggers redraw both plots from the same inputs.
        plot_inputs = [frequency_slider, model_selector]
        plot_outputs = [single_domains_plot, pair_domains_plot]
        frequency_slider.release(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
        demo.load(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
        model_selector.input(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)

    print("Launching!...")
    demo.launch()