File size: 3,367 Bytes
c8e874d
49c5855
f132523
7c87bff
232e962
 
c8e874d
49c5855
253caff
7c87bff
253caff
232e962
7c87bff
 
 
 
 
 
 
232e962
 
 
 
 
 
8a7095f
232e962
 
 
 
88dbd92
232e962
 
 
88dbd92
 
 
232e962
88dbd92
 
7c87bff
 
232e962
88dbd92
232e962
88dbd92
232e962
7c87bff
88dbd92
7c87bff
232e962
253caff
8a7095f
 
 
 
232e962
 
9ff1365
232e962
9ff1365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232e962
49c5855
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import streamlit as st
from utils import validate_sequence, predict
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def _read_uploaded_sequences(uploaded_file):
    """Return the usable sequences from an uploaded CSV file.

    Problems with the upload (unreadable file, missing ``sequence`` column)
    are reported in the sidebar and produce an empty list, so a bad file does
    not abort the whole page with a traceback.

    Args:
        uploaded_file: The file-like object returned by the sidebar file
            uploader.

    Returns:
        A list of strings taken from the ``sequence`` column, with empty
        cells removed.
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.sidebar.error(f"Could not read the uploaded CSV: {exc}")
        return []

    if 'sequence' not in df.columns:
        st.sidebar.error("The uploaded CSV must contain a 'sequence' column.")
        return []

    # Empty cells are parsed as NaN (a float); drop them and coerce the rest
    # to str so downstream validation always receives text.
    return df['sequence'].dropna().astype(str).tolist()


def main():
    """Build the Streamlit page: collect sequences, run every model, show results.

    Sequences come from the sidebar text input and/or the ``sequence`` column
    of an uploaded CSV. When the "Analyze Sequence" button is pressed, each
    sequence accepted by ``validate_sequence`` is passed to ``predict`` once
    per entry in ``models``; the predictions and confidences are shown in a
    table and, if the checkbox is ticked, as per-model bar charts. Rejected
    sequences are reported in the sidebar.
    """
    st.set_page_config(layout="wide")  # Keep the wide layout for overall flexibility
    st.title("AA Property Inference Demo", anchor=None)

    # Styling for the app to use monospace font
    st.markdown("""
        <style>
        .reportview-container {
            font-family: 'Courier New', monospace;
        }
        </style>
        """, unsafe_allow_html=True)

    # Input section in the sidebar
    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader("Or upload a CSV file with amino acid sequences", type="csv")
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    sequences = [sequence] if sequence else []
    if uploaded_file:
        sequences.extend(_read_uploaded_sequences(uploaded_file))

    if not analyze_pressed:
        return

    results = []   # one flat row per valid sequence, for the results table
    all_data = {}  # sequence -> {model_name: (prediction, confidence)}, for the graphs
    for seq in sequences:
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence: {seq}")
            continue

        model_results = {}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            model_results[f"{model_name}_prediction"] = prediction
            model_results[f"{model_name}_confidence"] = round(confidence, 3)
            # The graphs use the unrounded confidence.
            graph_data[model_name] = (prediction, confidence)
        results.append({"Sequence": seq, **model_results})
        all_data[seq] = graph_data

    if not results:
        return

    results_df = pd.DataFrame(results)
    st.write("### Results")
    st.dataframe(results_df.style.format(precision=3), width=None, height=None)

    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data)

def plot_prediction_graphs(data):
    """Draw, for every model, confidence bar charts split by predicted value.

    One figure is rendered per entry in ``models``. Each figure has two
    side-by-side axes sharing the y axis: the left one holds the sequences
    whose prediction equals 0, the right one those whose prediction equals 1.
    Within each axis the bars are ordered by descending confidence.

    Args:
        data: Mapping of sequence -> mapping of model name -> pair whose
            first item is the prediction and whose second item is the
            confidence.
    """
    for model_name in models:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        try:
            for prediction_val, ax, palette in ((0, ax1, "coolwarm"), (1, ax2, "cubehelix")):
                filtered_data = {
                    seq: values[model_name]
                    for seq, values in data.items()
                    if values[model_name][0] == prediction_val
                }
                # Sorting sequences based on confidence, descending
                sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
                sequences = [seq for seq, _ in sorted_sequences]
                conf_values = [result[1] for _, result in sorted_sequences]

                # Nothing to draw for this prediction value: leave the axis
                # empty instead of handing seaborn empty inputs.
                if sequences:
                    sns.barplot(x=sequences, y=conf_values, palette=palette, ax=ax)

                ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
                ax.set_xlabel('Sequences')
                ax.set_ylabel('Confidence')
                ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility

            st.pyplot(fig)  # Display the plot with two subplots below the results table
        finally:
            # Figures created through pyplot stay registered until closed;
            # release each one so they do not pile up across reruns.
            plt.close(fig)

# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()