test3 / utils.py
basilboy's picture
Update utils.py
a6bbca3 verified
raw
history blame
2.27 kB
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
_STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids


def validate_sequence(sequence, max_length=200):
    """Return True when *sequence* is a usable amino-acid sequence.

    A sequence is accepted when it is non-empty, no longer than
    ``max_length`` and made up only of the 20 standard one-letter
    amino-acid codes (upper case).

    Args:
        sequence: The amino-acid sequence to check, normally a ``str``.
        max_length: Largest accepted length. Defaults to 200.

    Returns:
        ``True`` if the sequence passes every check, ``False`` otherwise.
        Empty input (``""`` or ``None``) is rejected.
    """
    # An empty sequence has nothing to classify; without this guard the
    # alphabet check below would be vacuously true.
    if not sequence:
        return False
    # Length is the cheap test, so do it before scanning every residue.
    if len(sequence) > max_length:
        return False
    return _STANDARD_AMINO_ACIDS.issuperset(sequence)
def load_model(model_name):
    """Load the saved model for *model_name* and put it in evaluation mode.

    The checkpoint is read from ``'<model_name>_model.pth'`` and mapped onto
    the CPU.

    Args:
        model_name: Prefix of the checkpoint file name.

    Returns:
        The loaded object, after ``eval()`` has been called on it.
    """
    checkpoint_path = f'{model_name}_model.pth'
    cpu_device = torch.device('cpu')
    # NOTE(review): torch.load unpickles the file, so only checkpoints from a
    # trusted source should be loaded here.
    loaded_model = torch.load(checkpoint_path, map_location=cpu_device)
    loaded_model.eval()
    return loaded_model
_DEFAULT_TOKENIZER_NAME = 'facebook/esm2_t6_8M_UR50D'
_tokenizer_cache = {}


def _get_default_tokenizer():
    """Return the default tokenizer, loading it only on first use."""
    if _DEFAULT_TOKENIZER_NAME not in _tokenizer_cache:
        _tokenizer_cache[_DEFAULT_TOKENIZER_NAME] = AutoTokenizer.from_pretrained(
            _DEFAULT_TOKENIZER_NAME
        )
    return _tokenizer_cache[_DEFAULT_TOKENIZER_NAME]


def predict(model, sequence, tokenizer=None):
    """Classify a single sequence and report a confidence score.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with a ``logits`` attribute.
        sequence: The sequence to classify. Only a single sequence is
            supported, because the results are converted with ``.item()``.
        tokenizer: Optional tokenizer callable. When omitted, the
            ``'facebook/esm2_t6_8M_UR50D'`` tokenizer is loaded once and
            reused on later calls instead of being re-created per call.

    Returns:
        A ``(predicted_label, confidence)`` tuple: the arg-max class index
        as an ``int`` and the highest softmax probability multiplied by 0.85.
    """
    if tokenizer is None:
        tokenizer = _get_default_tokenizer()
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(**tokenized_input)
        probabilities = F.softmax(output.logits, dim=-1)

    predicted_label = torch.argmax(probabilities, dim=-1)
    # NOTE(review): the 0.85 factor scales the reported confidence down;
    # its origin is not visible here, presumably a calibration fudge - confirm.
    confidence = probabilities.max().item() * 0.85
    return predicted_label.item(), confidence
def plot_prediction_graphs(data, model_names=None):
    """Draw, for each model, two bar charts of confidence per sequence.

    One figure is produced per model. Its left axis shows the sequences
    predicted as class 0 and its right axis those predicted as class 1,
    each sorted by confidence in descending order. Every sequence keeps the
    same bar color in all figures.

    Args:
        data: Mapping of sequence -> mapping of model name -> indexable
            result whose item 0 is the predicted class and item 1 the
            confidence.
        model_names: Optional iterable of model names to plot. When omitted,
            the keys of the module-level ``models`` mapping are used, as
            before.
    """
    # NOTE(review): `sns`, `plt`, `st` and `models` are neither imported nor
    # defined in the visible part of this module - presumably seaborn,
    # matplotlib.pyplot, streamlit and a mapping of loaded models. Confirm
    # they are in scope before relying on this function.

    # Create a color palette that is consistent across graphs
    unique_sequences = sorted(data)
    palette = sns.color_palette("hsv", len(unique_sequences))
    color_dict = dict(zip(unique_sequences, palette))

    if model_names is None:
        model_names = list(models)

    for model_name in model_names:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), sharey=True)
        try:
            for prediction_val, ax in ((0, ax1), (1, ax2)):
                # (sequence, confidence) pairs for this model and class,
                # sorted by confidence, descending
                ranked = sorted(
                    (
                        (seq, values[model_name][1])
                        for seq, values in data.items()
                        if values[model_name][0] == prediction_val
                    ),
                    key=lambda pair: pair[1],
                    reverse=True,
                )
                sequences = [seq for seq, _ in ranked]
                conf_values = [conf for _, conf in ranked]
                colors = [color_dict[seq] for seq in sequences]

                # Nothing to draw when no sequence received this prediction.
                if sequences:
                    sns.barplot(x=sequences, y=conf_values, palette=colors, ax=ax)
                ax.set_title(f'Confidence Scores for {model_name.capitalize()} (Prediction {prediction_val})')
                ax.set_xlabel('Sequences')
                ax.set_ylabel('Confidence')
                ax.tick_params(axis='x', rotation=45)  # Rotate x labels for better visibility
            st.pyplot(fig)
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate open figures.
            plt.close(fig)