basilboy commited on
Commit
9ff1365
·
verified ·
1 Parent(s): 8a7095f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -56,20 +56,23 @@ def main():
56
  plot_prediction_graphs(all_data)
57
 
58
  def plot_prediction_graphs(data):
59
- # Function to plot graphs for predictions
60
  for model_name in models.keys():
61
- plt.figure(figsize=(10, 4))
62
- predictions = {seq: values[model_name][1] for seq, values in data.items()} # Using confidence for ordering
63
- # Sorting sequences based on confidence, descending
64
- sorted_sequences = sorted(predictions.items(), key=lambda x: x[1], reverse=True)
65
- sequences = [x[0] for x in sorted_sequences]
66
- conf_values = [x[1] for x in sorted_sequences]
67
- sns.barplot(x=sequences, y=conf_values, palette="viridis")
68
- plt.title(f'Confidence Scores for {model_name.capitalize()} Model')
69
- plt.xlabel('Sequences')
70
- plt.ylabel('Confidence')
71
- plt.xticks(rotation=45) # Rotate x labels for better visibility
72
- st.pyplot(plt) # Display each plot below the results table
 
 
 
73
 
74
  if __name__ == "__main__":
75
  main()
 
56
  plot_prediction_graphs(all_data)
57
 
58
  def plot_prediction_graphs(data):
59
+ # Function to plot graphs for predictions, divided by prediction value
60
  for model_name in models.keys():
61
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
62
+ for prediction_val in [0, 1]:
63
+ ax = ax1 if prediction_val == 0 else ax2
64
+ filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
65
+ # Sorting sequences based on confidence, descending
66
+ sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
67
+ sequences = [x[0] for x in sorted_sequences]
68
+ conf_values = [x[1][1] for x in sorted_sequences]
69
+ sns.barplot(x=sequences, y=conf_values, palette="coolwarm" if prediction_val == 0 else "cubehelix", ax=ax)
70
+ ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
71
+ ax.set_xlabel('Sequences')
72
+ ax.set_ylabel('Confidence')
73
+ ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
74
+
75
+ st.pyplot(fig) # Display the plot with two subplots below the results table
76
 
77
  if __name__ == "__main__":
78
  main()