import functools

import torch
import torch.nn.functional as F

# The 20 standard amino acids (one-letter codes). B/J/O/U/X/Z are excluded.
_VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")
_MAX_SEQUENCE_LENGTH = 200


def validate_sequence(sequence: str) -> bool:
    """Return True if *sequence* uses only the 20 standard amino acids
    and is at most 200 residues long.

    NOTE: an empty string is accepted (``all()`` over an empty iterable
    is True), matching the original behavior.
    """
    return (
        len(sequence) <= _MAX_SEQUENCE_LENGTH
        and all(aa in _VALID_AMINO_ACIDS for aa in sequence)
    )


def load_model(model_name: str):
    """Load the serialized model ``'<model_name>_model.pth'`` onto CPU
    and switch it to eval mode.

    SECURITY: ``torch.load`` unpickles arbitrary objects — only load
    checkpoints from trusted sources. ``weights_only=True`` cannot be
    used here because a full model object (not a state_dict) is stored.
    """
    model = torch.load(f"{model_name}_model.pth", map_location=torch.device("cpu"))
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Load and cache the ESM-2 tokenizer.

    The original code called ``from_pretrained`` on every ``predict``
    call, re-resolving the tokenizer each time; caching makes repeated
    predictions cheap.
    """
    # Lazy import: transformers is heavy and only needed for prediction,
    # so module import (e.g. to use validate_sequence) stays cheap.
    from transformers import AutoTokenizer

    return AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")


def predict(model, sequence: str, confidence_scale: float = 0.85):
    """Classify *sequence* with *model*.

    Args:
        model: a sequence-classification model returning ``.logits``.
        sequence: the amino-acid sequence to classify.
        confidence_scale: multiplier applied to the max softmax
            probability. The default 0.85 preserves the original
            hard-coded damping factor — NOTE(review): this is an
            uncalibrated fudge factor; confirm it is intentional.

    Returns:
        ``(predicted_label, confidence)`` where ``predicted_label`` is
        the argmax class index (int) and ``confidence`` is the scaled
        max probability (float).
    """
    tokenizer = _get_tokenizer()
    inputs = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    # Inference only: model.eval() is set by load_model, and no_grad
    # skips autograd bookkeeping without changing outputs.
    with torch.no_grad():
        output = model(**inputs)
    probabilities = F.softmax(output.logits, dim=-1)
    predicted_label = torch.argmax(probabilities, dim=-1)
    confidence = probabilities.max().item() * confidence_scale
    return predicted_label.item(), confidence