Update utils.py
Browse filesnew input to model keys as it has been moved from app.py
utils.py
CHANGED
@@ -24,25 +24,25 @@ def predict(model, sequence):
|
|
24 |
confidence = probabilities.max().item() * 0.85
|
25 |
return predicted_label.item(), confidence
|
26 |
|
27 |
-
def plot_prediction_graphs(data):
|
28 |
# Create a color palette that is consistent across graphs
|
29 |
-
|
30 |
-
palette = sns.color_palette("hsv", len(
|
31 |
-
color_dict = {
|
32 |
|
33 |
-
for model_name in
|
34 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
|
35 |
for prediction_val in [0, 1]:
|
36 |
ax = ax1 if prediction_val == 0 else ax2
|
37 |
-
filtered_data = {
|
38 |
-
# Sorting
|
39 |
-
|
40 |
-
|
41 |
-
conf_values = [x[1][1] for x in
|
42 |
-
colors = [color_dict[
|
43 |
-
sns.barplot(x=
|
44 |
ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
|
45 |
-
ax.set_xlabel('
|
46 |
ax.set_ylabel('Confidence')
|
47 |
ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
|
48 |
|
|
|
24 |
confidence = probabilities.max().item() * 0.85
|
25 |
return predicted_label.item(), confidence
|
26 |
|
27 |
+
def plot_prediction_graphs(data,model_keys):
|
28 |
# Create a color palette that is consistent across graphs
|
29 |
+
unique_names = sorted(data.keys()) # Using names instead of sequences
|
30 |
+
palette = sns.color_palette("hsv", len(unique_names))
|
31 |
+
color_dict = {name: color for name, color in zip(unique_names, palette)}
|
32 |
|
33 |
+
for model_name in model_keys:
|
34 |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
|
35 |
for prediction_val in [0, 1]:
|
36 |
ax = ax1 if prediction_val == 0 else ax2
|
37 |
+
filtered_data = {name: values[model_name] for name, values in data.items() if values[model_name][0] == prediction_val}
|
38 |
+
# Sorting names based on confidence, descending
|
39 |
+
sorted_names = sorted(filtered_data.items(), key=lambda x: x[1][1], reverse=True)
|
40 |
+
names = [x[0] for x in sorted_names]
|
41 |
+
conf_values = [x[1][1] for x in sorted_names]
|
42 |
+
colors = [color_dict[name] for name in names]
|
43 |
+
sns.barplot(x=names, y=conf_values, palette=colors, ax=ax)
|
44 |
ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
|
45 |
+
ax.set_xlabel('Names')
|
46 |
ax.set_ylabel('Confidence')
|
47 |
ax.tick_params(axis='x', rotation=45) # Rotate x labels for better visibility
|
48 |
|