import streamlit as st
from utils import validate_sequence, predict
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def main():
    """Render the amino-acid property inference demo page.

    Sequences come from the sidebar text box and/or an uploaded CSV that has
    a ``sequence`` column. When "Analyze Sequence" is pressed, every valid
    sequence is run through each model in ``models``. The per-model
    prediction and confidence are shown in a table and, if the
    "Show Prediction Graphs" box is ticked, as bar charts.
    """
    # Must be the first Streamlit call on the page; wide layout keeps the
    # table and the side-by-side charts readable.
    st.set_page_config(layout="wide")
    st.title("AA Property Inference Demo", anchor=None)

    # Instructional text below the title.
    st.markdown("""
← Don't know where to start? Open tab to input a sequence.
""", unsafe_allow_html=True)

    # Input section in the sidebar.
    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader(
        "Or upload a CSV file with amino acid sequences", type="csv"
    )
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    sequences = _collect_sequences(sequence, uploaded_file)

    # NOTE(review): st.button is True only on the rerun triggered by the click.
    # Ticking "Show Prediction Graphs" afterwards reruns the script with
    # analyze_pressed == False, so the results disappear. Persisting results
    # in st.session_state would be needed if that is not the intended UX.
    if not analyze_pressed:
        return
    if not sequences:
        st.sidebar.warning("Enter a sequence or upload a CSV file to analyze.")
        return

    results, all_data = _analyze_sequences(sequences)

    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        st.dataframe(results_df.style.format(precision=3), width=None, height=None)

    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data)


def _collect_sequences(sequence, uploaded_file):
    """Gather sequences to analyze: the typed one first, then any from the CSV.

    Surrounding whitespace is stripped and blank entries are skipped.
    Problems with the uploaded CSV (unparseable file, missing ``sequence``
    column) are reported in the sidebar instead of crashing the app. Any
    sequences collected so far are still returned.
    """
    sequences = []

    typed = sequence.strip() if sequence else ""
    if typed:
        sequences.append(typed)

    if not uploaded_file:
        return sequences

    try:
        df = pd.read_csv(uploaded_file)
    except ValueError as err:
        # pandas' ParserError and EmptyDataError, and UnicodeDecodeError,
        # all subclass ValueError.
        st.sidebar.error(f"Could not read the uploaded CSV: {err}")
        return sequences

    if "sequence" not in df.columns:
        st.sidebar.error("The uploaded CSV must contain a 'sequence' column.")
        return sequences

    # Empty cells load as NaN (a float); drop them and normalise the rest to
    # stripped strings before validation.
    column = df["sequence"].dropna().astype(str).str.strip()
    sequences.extend(column[column != ""].tolist())
    return sequences


def _analyze_sequences(sequences):
    """Run every model in ``models`` on each valid sequence.

    Returns ``(results, all_data)``:
      * ``results`` has one table row per valid input (duplicates included),
        with ``<model>_prediction`` and ``<model>_confidence`` columns. The
        confidence is rounded to 3 decimals.
      * ``all_data`` maps each sequence to
        ``{model_name: (prediction, confidence)}`` for plotting. The
        confidence here is unrounded.

    Invalid sequences are reported in the sidebar and skipped.
    """
    results = []
    all_data = {}
    for seq in sequences:
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence: {seq}")
            continue

        row = {"Sequence": seq}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            row[f"{model_name}_prediction"] = prediction
            row[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)

        results.append(row)
        all_data[seq] = graph_data
    return results, all_data


def plot_prediction_graphs(data):
    """Draw one figure per model with confidence bars split by prediction.

    ``data`` maps sequence -> ``{model_name: (prediction, confidence)}``, as
    built by ``_analyze_sequences``. The left subplot shows sequences
    predicted 0 and the right shows those predicted 1. Only these two
    prediction values are plotted. Bars are sorted by confidence
    (descending), and each sequence keeps the same colour in every chart.
    """
    # Iterating a dict yields its keys, which are already unique.
    unique_sequences = sorted(data)
    palette = sns.color_palette("hsv", len(unique_sequences))
    color_dict = dict(zip(unique_sequences, palette))

    for model_name in models:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        for prediction_val, ax in zip((0, 1), axes):
            filtered = [
                (seq, values[model_name][1])
                for seq, values in data.items()
                if values[model_name][0] == prediction_val
            ]
            # Highest confidence first; the sort is stable, so ties keep
            # their input order.
            filtered.sort(key=lambda item: item[1], reverse=True)

            if filtered:
                sequences = [seq for seq, _ in filtered]
                conf_values = [conf for _, conf in filtered]
                colors = [color_dict[seq] for seq in sequences]
                sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
            else:
                # Avoid calling barplot with no data and an empty palette.
                ax.text(0.5, 0.5, "No sequences", ha="center", va="center",
                        transform=ax.transAxes)

            ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
            ax.set_xlabel('Sequences')
            ax.set_ylabel('Confidence')
            ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility

        # Display the two subplots below the results table.
        st.pyplot(fig)
        # pyplot keeps every figure alive until it is closed. Each Streamlit
        # rerun creates new ones, so release it once it has been rendered.
        plt.close(fig)


if __name__ == "__main__":
    main()