basilboy commited on
Commit
a6bbca3
·
verified ·
1 Parent(s): 186e4b6

Update utils.py

Browse files

moved plotting function here

Files changed (1) hide show
  1. utils.py +25 -1
utils.py CHANGED
@@ -20,4 +20,28 @@ def predict(model, sequence):
20
  probabilities = F.softmax(output.logits, dim=-1)
21
  predicted_label = torch.argmax(probabilities, dim=-1)
22
  confidence = probabilities.max().item() * 0.85
23
- return predicted_label.item(), confidence
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  probabilities = F.softmax(output.logits, dim=-1)
21
  predicted_label = torch.argmax(probabilities, dim=-1)
22
  confidence = probabilities.max().item() * 0.85
23
+ return predicted_label.item(), confidence
24
+
25
+ def plot_prediction_graphs(data):
26
+ # Create a color palette that is consistent across graphs
27
+ unique_sequences = sorted(set(seq for seq in data))
28
+ palette = sns.color_palette("hsv", len(unique_sequences))
29
+ color_dict = {seq: color for seq, color in zip(unique_sequences, palette)}
30
+
31
+ for model_name in models.keys():
32
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
33
+ for prediction_val in [0, 1]:
34
+ ax = ax1 if prediction_val == 0 else ax2
35
+ filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
36
+ # Sorting sequences based on confidence, descending
37
+ sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
38
+ sequences = [x[0] for x in sorted_sequences]
39
+ conf_values = [x[1][1] for x in sorted_sequences]
40
+ colors = [color_dict[seq] for seq in sequences]
41
+ sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
42
+ ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
43
+ ax.set_xlabel('Sequences')
44
+ ax.set_ylabel('Confidence')
45
+ ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
46
+
47
+ st.pyplot(fig)