from transformers import AutoTokenizer
import torch
import torch.nn.functional as F

def validate_sequence(sequence):
    valid_amino_acids = set("ACDEFGHIKLMNPQRSTVWY")  # 20 standard amino acids
    # Reject empty or over-length sequences and any non-standard residue
    return 0 < len(sequence) <= 200 and all(aa in valid_amino_acids for aa in sequence)

def load_model():
    # Load the serialized solubility classifier onto CPU and switch to eval mode
    model = torch.load('solubility_model.pth', map_location=torch.device('cpu'))
    model.eval()
    return model

def predict(model, sequence):
    # Tokenize with the ESM-2 tokenizer that matches the fine-tuned backbone
    tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
    tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**tokenized_input)
    logits = output.logits  # raw classification scores
    probabilities = F.softmax(logits, dim=-1)  # convert logits to class probabilities
    predicted_label = torch.argmax(probabilities, dim=-1)  # index of the most likely class
    return predicted_label.item()  # return the label as a Python integer
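
# Minimal usage sketch (not part of the original file): assumes a checkpoint named
# 'solubility_model.pth' sits next to this script and that the returned class index
# encodes solubility (e.g. 1 = soluble); the example sequence below is hypothetical.
if __name__ == "__main__":
    example_sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # hypothetical test peptide
    if validate_sequence(example_sequence):
        model = load_model()
        label = predict(model, example_sequence)
        print(f"Predicted class: {label}")
    else:
        print("Invalid sequence: only the 20 standard amino acids, max length 200.")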