File size: 4,653 Bytes
3a329d1
7ac370b
 
 
3a329d1
7ac370b
3a329d1
7ac370b
 
 
ca7444f
 
 
 
 
7ac370b
 
b40aac1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4c1e61
7ac370b
 
 
af49af1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ac370b
af49af1
 
 
 
b40aac1
af49af1
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from constants import *


from mpl_data_plotter import MatplotlibDataPlotter

def convert_int64_to_int32(df):
    """Downcast ``int64`` columns of *df* to ``int32`` when that is lossless.

    The frame is modified in place and also returned, so the call can be
    used either as a statement or in an assignment.

    A column is only downcast when every value fits in the ``int32`` range;
    a blind ``astype('int32')`` would silently wrap larger values around and
    corrupt the data. Columns that do not fit keep their ``int64`` dtype.
    Empty ``int64`` columns are downcast, since there is nothing to lose.

    Args:
        df: pandas DataFrame whose integer columns should be narrowed.

    Returns:
        The same DataFrame object that was passed in.
    """
    bounds = np.iinfo(np.int32)
    for col in df.columns:
        column = df[col]
        # min()/max() of an empty column are NaN, so handle that case first.
        if column.dtype == 'int64' and (
            column.empty
            or (column.min() >= bounds.min and column.max() <= bounds.max)
        ):
            df[col] = column.astype('int32')
    return df


def create_color_legend(class_to_color):
    """Build a Gradio HTML component showing one color swatch per class.

    Args:
        class_to_color: Mapping of class name to a CSS color value. Entries
            are rendered in the mapping's iteration order. Names and colors
            are interpolated into the markup as-is (no HTML escaping).

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    # Outer bordered box, the "Color Legend:" title, and the opening tag of
    # the flex container that holds the individual entries.
    opening = """
        <div style="
            margin: 10px 0;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 4px;
            background: white;
        ">
            <div style="
                font-weight: bold;
                margin-bottom: 8px;
            ">Color Legend:</div>
            <div style="
                display: flex;
                flex-wrap: wrap;
                gap: 15px;
                align-items: center;
            ">
    """

    # One swatch-plus-label entry per class.
    entries = [
        f"""
            <div style="
                display: flex;
                align-items: center;
                gap: 5px;
            ">
                <div style="
                    width: 20px;
                    height: 20px;
                    background-color: {swatch_color};
                    border-radius: 3px;
                "></div>
                <span>{label}</span>
            </div>
        """
        for label, swatch_color in class_to_color.items()
    ]

    # Close the flex container and the outer box.
    closing = """
            </div>
        </div>
    """

    return gr.HTML(opening + "".join(entries) + closing)


def update_all_plots(frequency, split_name):
    """Redraw both plots for the current UI control values.

    Relies on the module-level ``data_plotter`` created in the ``__main__``
    section of this script.

    Args:
        frequency: Value of the "Min number of domains" slider.
        split_name: Value of the "Model name" radio selector.

    Returns:
        A ``(single_domains_plot, pair_domains_plot)`` tuple, in the order
        expected by the Gradio event outputs.
    """
    single_figure = data_plotter.plot_single_domains(frequency, split_name)
    pair_figure = data_plotter.plot_pair_domains(frequency, split_name)
    return single_figure, pair_figure


if __name__ == "__main__":
    # --- Data loading -----------------------------------------------------
    # Both gzip-compressed CSVs get the same treatment: rename 'bgc_class' to
    # 'biosyn_class', add the class's position in BIOSYN_CLASS_NAMES as
    # 'biosyn_class_index', then narrow int64 columns.
    print("Loading domains data...")
    single_df = pd.read_csv(SINGLE_DOMAINS_FILE, compression='gzip')
    single_df = single_df.rename(columns={'bgc_class': 'biosyn_class'})
    single_df['biosyn_class_index'] = single_df['biosyn_class'].apply(
        lambda name: BIOSYN_CLASS_NAMES.index(name)
    )
    single_df = convert_int64_to_int32(single_df)

    pair_df = pd.read_csv(PAIR_DOMAINS_FILE, compression='gzip')
    pair_df = pair_df.rename(columns={'bgc_class': 'biosyn_class'})
    pair_df['biosyn_class_index'] = pair_df['biosyn_class'].apply(
        lambda name: BIOSYN_CLASS_NAMES.index(name)
    )
    pair_df = convert_int64_to_int32(pair_df)

    # Count of 'as_domain_id' entries per 'cds_region_id'.
    domain_counts = single_df.groupby('cds_region_id', as_index=False).agg(
        {'as_domain_id': 'count'}
    )
    num_domains_in_region_df = domain_counts.rename(
        columns={'as_domain_id': 'num_domains'}
    )

    # The distinct per-region counts determine the slider's range.
    unique_domain_lengths = num_domains_in_region_df['num_domains'].unique()
    min_num_domains = int(unique_domain_lengths.min())
    max_num_domains = int(unique_domain_lengths.max())

    print("Initializing data plotter...")
    # Must be a module-level name: update_all_plots() looks it up as a global.
    data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)

    print("Defining blocks...")

    # --- Gradio interface -------------------------------------------------
    with gr.Blocks(title="BGC Keyword Plotter") as demo:
        gr.Markdown("## BGC Keyword Plotter")
        gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")

        color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)

        # Controls: domain-count threshold and model/split choice.
        with gr.Row():
            frequency_slider = gr.Slider(
                label="Min number of domains",
                minimum=min_num_domains,
                maximum=max_num_domains,
                value=min_num_domains,
                step=1,
            )
            model_selector = gr.Radio(
                label="Model name",
                choices=["stratified"] + BIOSYN_CLASS_NAMES,
                value="stratified",
            )

        # Outputs: the two plots side by side.
        with gr.Row():
            with gr.Column():
                single_domains_plot = gr.Plot(
                    label="Single domains",
                    container=True,
                    elem_id="single_domains_plot",
                )
            with gr.Column():
                pair_domains_plot = gr.Plot(label="Pair domains")

        # Redraw both plots when the slider is released, when the page
        # loads, and when the user picks a different model. All three events
        # share the same handler, inputs and outputs.
        for register_event in (
            frequency_slider.release,
            demo.load,
            model_selector.input,
        ):
            register_event(
                fn=update_all_plots,
                inputs=[frequency_slider, model_selector],
                outputs=[single_domains_plot, pair_domains_plot],
            )

    print("Launching!...")
    demo.launch()