import functools

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

# The 20 standard amino acids, as one-letter upper-case codes.
_VALID_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")

# Longest sequence (in residues) that validate_sequence() accepts.
_MAX_SEQUENCE_LENGTH = 200

# Checkpoint read by load_model(); relative to the current working directory.
_MODEL_PATH = 'solubility_model.pth'

# Hugging Face identifier of the tokenizer used by predict().
_TOKENIZER_NAME = 'facebook/esm2_t6_8M_UR50D'


def validate_sequence(sequence):
    """Return whether *sequence* is an acceptable protein sequence.

    A sequence is accepted when it is non-empty, at most 200 residues long,
    and made up only of the 20 standard amino-acid letters in upper case.

    Args:
        sequence: The candidate sequence, normally a ``str``.

    Returns:
        ``True`` if the sequence passes every check, ``False`` otherwise
        (including for ``None`` and the empty string).
    """
    # Reject missing/empty input explicitly: all() over an empty iterable is
    # True, so without this guard "" would be reported as valid.
    if not sequence:
        return False
    # Cheap length check first so over-long input is not scanned at all.
    if len(sequence) > _MAX_SEQUENCE_LENGTH:
        return False
    return _VALID_AMINO_ACIDS.issuperset(sequence)


def load_model():
    """Load the serialized model onto the CPU and put it in eval mode.

    Returns:
        The object stored in ``solubility_model.pth`` after ``.eval()`` has
        been called on it.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which refuses fully pickled module objects. If this checkpoint is a
    # whole model, an explicit weights_only=False may be needed -- confirm
    # against the PyTorch version in use.
    model = torch.load(_MODEL_PATH, map_location=torch.device('cpu'))
    model.eval()
    return model


@functools.lru_cache(maxsize=1)
def _get_tokenizer():
    """Return the shared tokenizer, instantiating it on first use only."""
    return AutoTokenizer.from_pretrained(_TOKENIZER_NAME)


def predict(model, sequence):
    """Return the predicted class index for a single sequence.

    Args:
        model: A callable that accepts the tokenizer's output as keyword
            arguments and returns an object with a ``logits`` attribute
            (presumably a Hugging Face sequence-classification model --
            verify against the training code).
        sequence: One protein sequence as a string. The result is read with
            ``.item()``, which requires exactly one prediction, so a batch
            of several sequences is not supported.

    Returns:
        The index of the highest-probability class as a Python ``int``. How
        indices map to labels is defined by the trained model.
    """
    # Reuse one tokenizer instead of calling from_pretrained on every call.
    tokenizer = _get_tokenizer()
    tokenized_input = tokenizer(
        sequence, return_tensors="pt", truncation=True, padding=True
    )
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = model(**tokenized_input)
        # Convert logits to class probabilities, then pick the top class.
        probabilities = F.softmax(output.logits, dim=-1)
        predicted_label = torch.argmax(probabilities, dim=-1)
    return predicted_label.item()