File size: 3,038 Bytes
c8e874d 49c5855 f132523 7c87bff 232e962 c8e874d 49c5855 253caff 7c87bff 253caff 232e962 7c87bff 232e962 8a7095f 232e962 88dbd92 232e962 88dbd92 232e962 88dbd92 7c87bff 232e962 88dbd92 232e962 88dbd92 232e962 7c87bff 88dbd92 7c87bff 232e962 253caff 8a7095f 232e962 8a7095f 232e962 8a7095f 232e962 8a7095f 232e962 49c5855 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import streamlit as st
from utils import validate_sequence, predict
from model import models
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def main():
    """Render the amino-acid property inference demo.

    Collects sequences from the sidebar text box and/or an uploaded CSV file
    (which must have a ``sequence`` column). When "Analyze Sequence" is
    pressed, every model in ``models`` is run on each valid sequence. The
    results are shown as a table, and per-model confidence bar charts are
    drawn when "Show Prediction Graphs" is ticked.
    """
    st.set_page_config(layout="wide")  # Keep the wide layout for overall flexibility
    st.title("AA Property Inference Demo", anchor=None)

    # Styling for the app to use a monospace font.
    # NOTE(review): `.reportview-container` is a class name from older
    # Streamlit releases; confirm it still matches the installed version.
    st.markdown("""
<style>
.reportview-container {
font-family: 'Courier New', monospace;
}
</style>
""", unsafe_allow_html=True)

    # Input section in the sidebar
    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader("Or upload a CSV file with amino acid sequences", type="csv")
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    sequences = _collect_sequences(sequence, uploaded_file)

    # NOTE(review): st.button is True only on the rerun triggered by the click,
    # so toggling the checkbox afterwards reruns without results. Persist them
    # in st.session_state if the graphs should be toggleable after analysis.
    if not analyze_pressed:
        return

    results, all_data = _run_predictions(sequences)
    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        st.dataframe(results_df.style.format(precision=3), width=None, height=None)
    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data)


def _collect_sequences(sequence, uploaded_file):
    """Return the non-empty sequences typed in the sidebar and found in the CSV.

    Problems with the uploaded CSV (unreadable file, missing ``sequence``
    column) are reported in the sidebar instead of crashing the app. Any
    typed sequence is still returned.
    """
    typed = sequence.strip() if sequence else ""
    sequences = [typed] if typed else []
    if not uploaded_file:
        return sequences

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.sidebar.error(f"Could not read CSV file: {err}")
        return sequences

    if "sequence" not in df.columns:
        st.sidebar.error("CSV file must contain a 'sequence' column.")
        return sequences

    # Blank cells load as NaN (a float). Drop them and coerce the rest to
    # trimmed strings so validate_sequence only ever sees text.
    column = df["sequence"].dropna().astype(str).str.strip()
    sequences.extend(column[column != ""].tolist())
    return sequences


def _run_predictions(sequences):
    """Run every model on each valid sequence.

    Returns ``(results, all_data)``:

    - ``results`` holds one table row per valid sequence. Each row has a
      ``Sequence`` key plus ``<model>_prediction`` and ``<model>_confidence``
      keys, with the confidence rounded to 3 places.
    - ``all_data`` maps each valid sequence to
      ``{model_name: (prediction, confidence)}`` for plotting.

    Invalid sequences are reported in the sidebar and skipped.
    """
    results = []
    all_data = {}
    for seq in sequences:
        if not validate_sequence(seq):
            st.sidebar.error(f"Invalid sequence: {seq}")
            continue
        row = {"Sequence": seq}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            row[f"{model_name}_prediction"] = prediction
            row[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)
        results.append(row)
        all_data[seq] = graph_data
    return results, all_data
def plot_prediction_graphs(data):
    """Draw one confidence bar chart per model, sequences sorted high to low.

    Parameters
    ----------
    data : dict
        Maps each sequence to ``{model_name: (prediction, confidence)}``.
        Every name in ``models`` must be present for every sequence.
    """
    for model_name in models:
        # Pair each sequence with this model's confidence (element 1 of the
        # tuple) and order by it, highest first. sorted() is stable, so ties
        # keep their insertion order.
        ranked = sorted(
            ((seq, values[model_name][1]) for seq, values in data.items()),
            key=lambda item: item[1],
            reverse=True,
        )
        sequences = [seq for seq, _ in ranked]
        conf_values = [conf for _, conf in ranked]

        # Draw on an explicit figure so exactly this figure is handed to
        # Streamlit and can be released afterwards.
        fig, ax = plt.subplots(figsize=(10, 4))
        sns.barplot(x=sequences, y=conf_values, palette="viridis", ax=ax)
        ax.set_title(f'Confidence Scores for {model_name.capitalize()} Model')
        ax.set_xlabel('Sequences')
        ax.set_ylabel('Confidence')
        ax.tick_params(axis="x", labelrotation=45)  # Rotate x labels for better visibility
        st.pyplot(fig)  # Display each plot below the results table
        # pyplot keeps every figure alive until closed. Close it so figures do
        # not accumulate across Streamlit reruns.
        plt.close(fig)
# Launch the app only when this file is executed as a script, so importing it
# (e.g. to reuse plot_prediction_graphs) has no side effects.
# NOTE(review): this relies on `streamlit run` executing the script with
# __name__ == "__main__"; confirm for the Streamlit version in use.
if __name__ == "__main__":
    main()
|