Update app.py
Browse files
app.py
CHANGED
@@ -140,16 +140,7 @@ def handle_name(name=None, pdb_input=None, model_version="ESM3"):
|
|
140 |
return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
|
141 |
|
142 |
@spaces.GPU(duration=300)
|
143 |
-
def run_model(model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
|
144 |
-
if model_version == "ESM3":
|
145 |
-
model = ESM_model(method='esm3')
|
146 |
-
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt'), strict=False)
|
147 |
-
else:
|
148 |
-
model = ESM_model(method='esm2', nheads=8, nlayers=12, layer=30).to(DEVICE)
|
149 |
-
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1-esm2.pt'), strict=False)
|
150 |
-
|
151 |
-
model.eval()
|
152 |
-
|
153 |
if model_version == "ESM3":
|
154 |
logits = model((seq_input, struct_input), sequence_id)
|
155 |
else:
|
@@ -170,6 +161,14 @@ def predict_dynamics(sequence=None, pdb_input=None, chain_id='A', use_pdb_seq=Fa
|
|
170 |
seq_input, struct_input = None, None
|
171 |
sequence = validate_sequence(sequence) if sequence else None
|
172 |
protein = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
if pdb_input and model_version == "ESM3":
|
175 |
protein, protein_chain = process_structure(pdb_input, chain_id)
|
|
|
140 |
return f'{pdb_name}-Dyna1{"" if model_version == "ESM3" else "-ESM2"}'
|
141 |
|
142 |
@spaces.GPU(duration=300)
|
143 |
+
def run_model(model, model_version='ESM2', seq_input=None, struct_input=None, sequence_id=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
if model_version == "ESM3":
|
145 |
logits = model((seq_input, struct_input), sequence_id)
|
146 |
else:
|
|
|
161 |
seq_input, struct_input = None, None
|
162 |
sequence = validate_sequence(sequence) if sequence else None
|
163 |
protein = None
|
164 |
+
if model_version == "ESM3":
|
165 |
+
model = ESM_model(method='esm3')
|
166 |
+
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1.pt'), strict=False)
|
167 |
+
else:
|
168 |
+
model = ESM_model(method='esm2', nheads=8, nlayers=12, layer=30).to(DEVICE)
|
169 |
+
model.load_state_dict(torch.load('Dyna-1/model/weights/dyna1-esm2.pt'), strict=False)
|
170 |
+
|
171 |
+
model.eval()
|
172 |
|
173 |
if pdb_input and model_version == "ESM3":
|
174 |
protein, protein_chain = process_structure(pdb_input, chain_id)
|