|
import streamlit as st |
|
from utils import validate_sequence, predict |
|
from model import models |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
def _read_uploaded_sequences(uploaded_file):
    """Return the amino acid sequences listed in an uploaded CSV file.

    The file must have a ``sequence`` column. Empty cells are dropped and
    surrounding whitespace is stripped, so blank rows never reach
    ``validate_sequence`` as ``NaN`` or ``""``.

    Args:
        uploaded_file: anything ``pandas.read_csv`` accepts (here, the object
            returned by ``st.sidebar.file_uploader``).

    Returns:
        The sequences as a list of strings, in file order.

    Raises:
        ValueError: if the file cannot be parsed as CSV (pandas' parser and
            empty-file errors, and Unicode decode errors, all subclass
            ``ValueError``) or if it has no ``sequence`` column.
    """
    df = pd.read_csv(uploaded_file)
    if "sequence" not in df.columns:
        raise ValueError("the CSV file has no 'sequence' column")
    cells = df["sequence"].dropna().astype(str).str.strip()
    return [seq for seq in cells if seq]


def _analyze_sequences(sequences):
    """Run every model in ``models`` on each valid sequence.

    Args:
        sequences: the candidate sequences, in display order.

    Returns:
        A tuple ``(results, all_data, invalid)``:

        - ``results``: one table row per valid sequence, with a
          ``"Sequence"`` column followed by ``<model>_prediction`` and
          ``<model>_confidence`` columns (confidence rounded to 3 places).
        - ``all_data``: maps each valid sequence to
          ``{model_name: (prediction, confidence)}``, used for plotting.
        - ``invalid``: the sequences rejected by ``validate_sequence``.
    """
    results = []
    all_data = {}
    invalid = []
    for seq in sequences:
        if not validate_sequence(seq):
            invalid.append(seq)
            continue
        row = {"Sequence": seq}
        graph_data = {}
        for model_name, model in models.items():
            prediction, confidence = predict(model, seq)
            row[f"{model_name}_prediction"] = prediction
            row[f"{model_name}_confidence"] = round(confidence, 3)
            graph_data[model_name] = (prediction, confidence)
        results.append(row)
        all_data[seq] = graph_data
    return results, all_data, invalid


def main():
    """Render the demo page and run the models when "Analyze Sequence" is clicked."""
    st.set_page_config(layout="wide")
    st.title("AA Property Inference Demo", anchor=None)

    st.markdown("""
<style>
.reportview-container {
font-family: 'Courier New', monospace;
}
</style>
<p style='font-size:16px;'><span style='font-size:24px;'>←</span> Don't know where to start? Open tab to input a sequence.</p>
""", unsafe_allow_html=True)

    sequence = st.sidebar.text_input("Enter your amino acid sequence:")
    uploaded_file = st.sidebar.file_uploader(
        "Or upload a CSV file with amino acid sequences", type="csv"
    )
    analyze_pressed = st.sidebar.button("Analyze Sequence")
    show_graphs = st.sidebar.checkbox("Show Prediction Graphs")

    # Strip so stray whitespace from copy/paste does not fail validation.
    sequence = (sequence or "").strip()
    sequences = [sequence] if sequence else []
    if uploaded_file is not None:
        try:
            sequences.extend(_read_uploaded_sequences(uploaded_file))
        except ValueError as err:
            # A malformed upload should not take down the whole page.
            st.sidebar.error(f"Could not read the uploaded CSV file: {err}")

    if analyze_pressed:
        if not sequences:
            st.sidebar.warning("Enter a sequence or upload a CSV file first.")
        results, all_data, invalid = _analyze_sequences(sequences)
        for seq in invalid:
            st.sidebar.error(f"Invalid sequence: {seq}")
        # st.button is True only on the rerun triggered by the click, and every
        # widget interaction (e.g. ticking "Show Prediction Graphs") reruns the
        # script. Keep the last analysis in session state, tagged with the
        # input it was computed from, so it survives those reruns.
        st.session_state["analysis"] = (sequences, results, all_data)

    # Only show a stored analysis while the inputs still match it, so edited
    # input never displays stale predictions.
    stored = st.session_state.get("analysis")
    if stored is not None and stored[0] == sequences:
        _, results, all_data = stored
    else:
        results, all_data = [], {}

    if results:
        results_df = pd.DataFrame(results)
        st.write("### Results")
        st.dataframe(results_df.style.format(precision=3), width=None, height=None)

    if show_graphs and all_data:
        st.write("## Graphs")
        plot_prediction_graphs(all_data)
|
|
|
def _confidences_for_class(data, model_name, prediction_val):
    """Return ``(sequence, confidence)`` pairs that *model_name* assigned to *prediction_val*.

    *data* maps each sequence to ``{model_name: (prediction, confidence)}``.
    Pairs are ordered by confidence, highest first. Ties keep *data*'s
    iteration order because the sort is stable.
    """
    pairs = [
        (seq, values[model_name][1])
        for seq, values in data.items()
        if values[model_name][0] == prediction_val
    ]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs


def plot_prediction_graphs(data):
    """Render one figure per model with confidence bar charts split by prediction.

    Each figure has two panels sharing a y axis. The left panel holds the
    sequences predicted as 0 and the right panel those predicted as 1. Within
    a panel, bars are sorted by confidence, highest first.

    Args:
        data: mapping ``sequence -> {model_name: (prediction, confidence)}``.
            Every model in ``models`` must have an entry for every sequence.
    """
    # Give each sequence one fixed colour so it is recognisable across charts.
    unique_sequences = sorted(data)
    palette = sns.color_palette("hsv", len(unique_sequences))
    color_dict = dict(zip(unique_sequences, palette))

    for model_name in models:
        fig, axes = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        try:
            for prediction_val, ax in zip((0, 1), axes):
                pairs = _confidences_for_class(data, model_name, prediction_val)
                if pairs:
                    seqs = [seq for seq, _ in pairs]
                    conf_values = [conf for _, conf in pairs]
                    # Plain Axes.bar avoids seaborn's deprecated
                    # palette-without-hue path and behaves the same across
                    # seaborn versions.
                    ax.bar(seqs, conf_values, color=[color_dict[seq] for seq in seqs])
                else:
                    ax.text(0.5, 0.5, "No sequences", ha="center", va="center",
                            transform=ax.transAxes)
                ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
                ax.set_xlabel('Sequences')
                ax.set_ylabel('Confidence')
                ax.tick_params(axis='x', rotation=45)
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until closed. Streamlit reruns
            # this on each interaction, so unclosed figures pile up.
            plt.close(fig)
|
|
|
# Streamlit's script runner executes this file as the ``__main__`` module,
# so the guard also fires under ``streamlit run``.
if __name__ == "__main__":
    main()
|
|