latticetower's picture
cleanup
af49af1
raw
history blame
4.65 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from constants import *
from mpl_data_plotter import MatplotlibDataPlotter
def convert_int64_to_int32(df):
    """Downcast int64 columns of ``df`` to int32, in place, when lossless.

    A column is converted only if all its values lie within the int32 range.
    Columns holding larger magnitudes stay int64, because a plain
    ``astype('int32')`` would silently wrap them around. Empty int64 columns
    are converted. Non-int64 columns are left untouched.

    Args:
        df: DataFrame to modify. Its columns are reassigned in place.

    Returns:
        The same ``df`` object, for call chaining.
    """
    int32_info = np.iinfo(np.int32)
    for col in df.columns:
        if df[col].dtype == 'int64':
            series = df[col]
            # min()/max() of an empty series are NaN, so treat "empty" as fitting.
            fits = series.empty or (
                int32_info.min <= series.min() and series.max() <= int32_info.max
            )
            if fits:
                df[col] = series.astype('int32')
    return df
def create_color_legend(class_to_color):
    """Build a Gradio HTML component with one color swatch per class.

    Args:
        class_to_color: Mapping from class name to a CSS color value (e.g. a
            hex string). Entries are rendered in the mapping's iteration order.

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    import html  # stdlib; imported locally so the file's import block is unchanged

    # Escape both values. Class names land in element text and colors in a
    # double-quoted style attribute, so '<', '&' or '"' would otherwise break
    # the markup.
    items = "".join(
        '<div style="display: flex; align-items: center; gap: 5px;">'
        '<div style="width: 20px; height: 20px; '
        f'background-color: {html.escape(str(color), quote=True)}; '
        'border-radius: 3px;"></div>'
        f'<span>{html.escape(str(class_name))}</span>'
        '</div>'
        for class_name, color in class_to_color.items()
    )
    legend_html = (
        '<div style="margin: 10px 0; padding: 10px; border: 1px solid #ddd; '
        'border-radius: 4px; background: white;">'
        '<div style="font-weight: bold; margin-bottom: 8px;">Color Legend:</div>'
        '<div style="display: flex; flex-wrap: wrap; gap: 15px; align-items: center;">'
        f'{items}'
        '</div>'
        '</div>'
    )
    return gr.HTML(legend_html)
def update_all_plots(frequency, split_name):
    """Re-render both plots for the selected domain threshold and model.

    Used as the callback for the slider, the model radio and page load.
    It relies on the module-level ``data_plotter`` created in the
    ``__main__`` section.

    Args:
        frequency: Minimum number of domains, taken from the slider.
        split_name: Model name, taken from the radio ("stratified" or a
            biosynthetic class name).

    Returns:
        A ``(single_domains_result, pair_domains_result)`` tuple, in the same
        order as the Gradio outputs it is wired to.
    """
    single_result = data_plotter.plot_single_domains(frequency, split_name)
    pair_result = data_plotter.plot_pair_domains(frequency, split_name)
    return single_result, pair_result
def _load_domains_csv(path):
    """Read a gzip-compressed domains CSV and prepare it for plotting.

    Renames ``bgc_class`` to ``biosyn_class`` and adds ``biosyn_class_index``,
    the position of each row's class in ``BIOSYN_CLASS_NAMES``. It then
    downcasts int64 columns with ``convert_int64_to_int32``.

    Raises:
        ValueError: if a row's class is not present in ``BIOSYN_CLASS_NAMES``
            (raised by ``.index``).
    """
    df = pd.read_csv(path, compression='gzip')
    df.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
    df['biosyn_class_index'] = df.biosyn_class.apply(lambda x: BIOSYN_CLASS_NAMES.index(x))
    return convert_int64_to_int32(df)


if __name__ == "__main__":
    print("Loading domains data...")
    single_df = _load_domains_csv(SINGLE_DOMAINS_FILE)
    pair_df = _load_domains_csv(PAIR_DOMAINS_FILE)

    # Count of non-null as_domain_id values per cds_region_id; its range
    # bounds the "Min number of domains" slider.
    num_domains_in_region_df = (
        single_df.groupby('cds_region_id', as_index=False)
        .agg({'as_domain_id': 'count'})
        .rename(columns={'as_domain_id': 'num_domains'})
    )
    min_num_domains = int(num_domains_in_region_df.num_domains.min())
    max_num_domains = int(num_domains_in_region_df.num_domains.max())

    print("Initializing data plotter...")
    # Must stay module-level: update_all_plots reads this global.
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")
    # Create Gradio interface
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
        color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        with gr.Row():
            frequency_slider = gr.Slider(
                minimum=min_num_domains,
                maximum=max_num_domains,
                step=1,
                value=min_num_domains,
                label="Min number of domains"
            )
            model_selector = gr.Radio(
                choices=["stratified"] + BIOSYN_CLASS_NAMES,
                value="stratified",
                label="Model name"
            )

        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot"
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # Every trigger redraws both plots from the same inputs. Keep these as
        # lists: Gradio treats non-list/set arguments as a single component.
        plot_inputs = [frequency_slider, model_selector]
        plot_outputs = [single_domains_plot, pair_domains_plot]
        # Slider redraws on release rather than on every intermediate value.
        frequency_slider.release(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
        demo.load(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)
        model_selector.input(fn=update_all_plots, inputs=plot_inputs, outputs=plot_outputs)

    print("Launching!...")
    demo.launch()