basilboy commited on
Commit
12565b9
·
verified ·
1 Parent(s): 80f9224

Update utils.py

Browse files

new input to model keys as it has been moved from app.py

Files changed (1) hide show
  1. utils.py +13 -13
utils.py CHANGED
@@ -24,25 +24,25 @@ def predict(model, sequence):
24
  confidence = probabilities.max().item() * 0.85
25
  return predicted_label.item(), confidence
26
 
27
- def plot_prediction_graphs(data):
28
  # Create a color palette that is consistent across graphs
29
- unique_sequences = sorted(set(seq for seq in data))
30
- palette = sns.color_palette("hsv", len(unique_sequences))
31
- color_dict = {seq: color for seq, color in zip(unique_sequences, palette)}
32
 
33
- for model_name in models.keys():
34
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
35
  for prediction_val in [0, 1]:
36
  ax = ax1 if prediction_val == 0 else ax2
37
- filtered_data = {seq: values[model_name] for seq, values in data.items() if values[model_name][0] == prediction_val}
38
- # Sorting sequences based on confidence, descending
39
- sorted_sequences = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
40
- sequences = [x[0] for x in sorted_sequences]
41
- conf_values = [x[1][1] for x in sorted_sequences]
42
- colors = [color_dict[seq] for seq in sequences]
43
- sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
44
  ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
45
- ax.set_xlabel('Sequences')
46
  ax.set_ylabel('Confidence')
47
  ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
48
 
 
24
  confidence = probabilities.max().item() * 0.85
25
  return predicted_label.item(), confidence
26
 
27
+ def plot_prediction_graphs(data,model_keys):
28
  # Create a color palette that is consistent across graphs
29
+ unique_names = sorted(data.keys()) # Using names instead of sequences
30
+ palette = sns.color_palette("hsv", len(unique_names))
31
+ color_dict = {name: color for name, color in zip(unique_names, palette)}
32
 
33
+ for model_name in model_keys:
34
  fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
35
  for prediction_val in [0, 1]:
36
  ax = ax1 if prediction_val == 0 else ax2
37
+ filtered_data = {name: values[model_name] for name, values in data.items() if values[model_name][0] == prediction_val}
38
+ # Sorting names based on confidence, descending
39
+ sorted_names = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
40
+ names = [x[0] for x in sorted_names]
41
+ conf_values = [x[1][1] for x in sorted_names]
42
+ colors = [color_dict[name] for name in names]
43
+ sns.barplot(x=names, y=conf_values, palette=colors, ax=ax)
44
  ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
45
+ ax.set_xlabel('Names')
46
  ax.set_ylabel('Confidence')
47
  ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
48