basilboy commited on
Commit
21849ba
·
verified ·
1 Parent(s): 4811f4a

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +6 -5
utils.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
 
3
  def validate_sequence(sequence):
@@ -5,13 +6,13 @@ def validate_sequence(sequence):
5
  return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
6
 
7
  def load_model():
8
- # Assuming the model is a simple PyTorch model, adjust the path as needed
9
  model = torch.load('solubility_model.pth', map_location=torch.device('cpu'))
10
  model.eval()
11
  return model
12
 
13
  def predict(model, sequence):
14
- # Dummy tensor conversion, replace with your actual model's input handling
15
- tensor = torch.tensor([ord(char) for char in sequence], dtype=torch.float32)
16
- output = model(tensor)
17
- return output.item()
 
1
+ from transformers import AutoTokenizer
2
  import torch
3
 
4
  def validate_sequence(sequence):
 
6
  return all(aa in valid_amino_acids for aa in sequence) and len(sequence) <= 200
7
 
8
  def load_model():
9
+ # Load your model as before
10
  model = torch.load('solubility_model.pth', map_location=torch.device('cpu'))
11
  model.eval()
12
  return model
13
 
14
  def predict(model, sequence):
15
+ tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t6_8M_UR50D')
16
+ tokenized_input = tokenizer(sequence, return_tensors="pt", truncation=True, padding=True)
17
+ output = model(**tokenized_input)
18
+ return output.item()