test3 / app.py
basilboy's picture
Update app.py
088cb01 verified
raw
history blame
3.7 kB
import streamlit as st
from utils import validate_sequence, predict
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def main():
    """Render the Streamlit page: gather sequences, run the models, show results.

    Sequences come from the sidebar text box and/or the ``sequence`` column of
    an uploaded CSV file. Once "Analyze Sequence" is pressed, every sequence
    accepted by ``validate_sequence`` is passed to ``predict`` together with
    each entry of ``models``. The predictions and confidences are displayed as
    a table and, when the sidebar checkbox is ticked, as bar charts drawn by
    ``plot_prediction_graphs``. Rejected sequences are reported in the sidebar.
    """
    st.set_page_config(layout="wide")  # Keep the wide layout for overall flexibility
    st.title("AA Property Inference Demo", anchor=None)

    # Instructional text below title. The literal starts at column 0 on purpose:
    # it is passed to the markdown renderer exactly as written.
    st.markdown("""
<style>
.reportview-container {
font-family: 'Courier New', monospace;
}
</style>
<p style='font-size:16px;'><span style='font-size:24px;'>&larr;</span> Don't know where to start? Open tab to input a sequence.</p>
""", unsafe_allow_html=True)

    # Input section in the sidebar
    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader(
        "Or upload a CSV file with amino acid sequences", type="csv"
    )
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    # The typed sequence (if any) goes first, followed by the CSV rows.
    sequences = [sequence] if sequence else []
    if uploaded_file:
        sequences.extend(_read_uploaded_sequences(uploaded_file))

    results = []   # one flat row per valid sequence, for the results table
    all_data = {}  # sequence -> {model_name: (prediction, confidence)}, for plots
    if analyze_pressed:
        for seq in sequences:
            if not validate_sequence(seq):
                st.sidebar.error(f"Invalid sequence: {seq}")
                continue

            model_results = {}
            graph_data = {}
            for model_name, model in models.items():
                prediction, confidence = predict(model, seq)
                model_results[f"{model_name}_prediction"] = prediction
                model_results[f"{model_name}_confidence"] = round(confidence, 3)
                # The plots keep the unrounded confidence.
                graph_data[model_name] = (prediction, confidence)
            results.append({"Sequence": seq, **model_results})
            all_data[seq] = graph_data

    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        # No explicit width/height: rely on the defaults of the installed
        # Streamlit version instead of passing None for both.
        st.dataframe(results_df.style.format(precision=3))
        if show_graphs and all_data:
            st.write("## Graphs")
            plot_prediction_graphs(all_data)


def _read_uploaded_sequences(uploaded_file):
    """Return the entries of the ``sequence`` column of an uploaded CSV.

    Problems with the upload (unreadable file, missing column) are reported in
    the sidebar and yield an empty list instead of crashing the page.
    """
    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        st.sidebar.error(f"Could not read the uploaded CSV: {exc}")
        return []

    if "sequence" not in df.columns:
        st.sidebar.error("The uploaded CSV must contain a 'sequence' column.")
        return []

    # Blank cells are parsed as NaN (a float); drop them and make sure every
    # remaining entry is handed on as a string.
    return df["sequence"].dropna().astype(str).tolist()
def plot_prediction_graphs(data):
    """Render one two-panel confidence bar chart per model.

    Args:
        data: Mapping of sequence -> {model_name: (prediction, confidence)}.
            Every value is indexed with each key of the module-level
            ``models`` mapping.

    For each model, the left panel shows the sequences whose prediction is 0
    and the right panel those whose prediction is 1, ordered by descending
    confidence. Sequences with any other prediction value appear in neither
    panel. Each figure is handed to ``st.pyplot`` and then closed.
    """
    # Create a color palette that is consistent across graphs: one fixed
    # colour per sequence, assigned in sorted order.
    unique_sequences = sorted(data)
    palette = sns.color_palette("hsv", len(unique_sequences))
    color_dict = dict(zip(unique_sequences, palette))

    for model_name in models:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

        for prediction_val, ax in zip((0, 1), axes):
            # (sequence, confidence) pairs for this class, highest confidence first.
            ranked = sorted(
                (
                    (seq, values[model_name][1])
                    for seq, values in data.items()
                    if values[model_name][0] == prediction_val
                ),
                key=lambda pair: pair[1],
                reverse=True,
            )
            sequences = [seq for seq, _ in ranked]
            conf_values = [conf for _, conf in ranked]
            colors = [color_dict[seq] for seq in sequences]

            # Skip the draw call when no sequence fell into this class, so
            # seaborn is never given empty data with an empty palette; the
            # panel still gets its title and labels below.
            if sequences:
                # NOTE(review): newer seaborn releases reportedly warn when
                # `palette` is passed without `hue` - confirm against the
                # pinned seaborn version before changing this call.
                sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)

            ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
            ax.set_xlabel('Sequences')
            ax.set_ylabel('Confidence')
            ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility

        st.pyplot(fig)  # Display the plot with two subplots below the results table
        # Figures created through pyplot stay registered until closed; release
        # this one so figures do not pile up across Streamlit reruns.
        plt.close(fig)
# Script entry point: build the page only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()