|
import random |
|
import time |
|
from pathlib import Path |
|
|
|
import numpy as np |
|
from biotite.structure.atoms import AtomArrayStack |
|
from scipy.spatial.transform import Rotation as R |
|
from pinder.core.structure.atoms import atom_array_from_pdb_file, normalize_orientation, write_pdb |
|
from pinder.core.structure.contacts import get_stack_contacts |
|
|
|
import gradio as gr |
|
|
|
from gradio_molecule3d import Molecule3D |
|
|
|
|
|
def predict(
    receptor_pdb: Path,
    ligand_pdb: Path,
    receptor_fasta: Path | None = None,
    ligand_fasta: Path | None = None,
) -> tuple[str, float]:
    """Place the ligand next to the receptor by random pose sampling.

    Both structures are loaded (keeping the ``b_factor`` field) and passed
    through ``normalize_orientation``. Fifty ligand poses are then generated,
    each a uniformly random rotation followed by a translation whose x, y and
    z components are drawn independently from ``{0, 1, ..., 25}``. The pose
    with the fewest receptor contacts reported by ``get_stack_contacts`` at
    ``threshold=1.2`` is kept; ties go to the earliest pose.

    The receptor and the chosen ligand pose are concatenated and written as a
    PDB file into the receptor file's directory, named
    ``<receptor stem>--<ligand file name>``.

    Args:
        receptor_pdb: Path to the receptor PDB file.
        ligand_pdb: Path to the ligand PDB file.
        receptor_fasta: Accepted for interface compatibility; not used.
        ligand_fasta: Accepted for interface compatibility; not used.

    Returns:
        A ``(output_path, run_time)`` tuple: the written PDB path as a string
        and the elapsed time of this call in seconds.

    Note:
        Sampling uses the unseeded global ``random`` state and
        ``Rotation.random()``, so results differ between calls.
    """
    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()

    receptor = normalize_orientation(
        atom_array_from_pdb_file(receptor_pdb, extra_fields=["b_factor"])
    )
    ligand = normalize_orientation(
        atom_array_from_pdb_file(ligand_pdb, extra_fields=["b_factor"])
    )

    # One stack "model" per candidate pose; every pose shares the ligand's
    # per-atom annotations and differs only in coordinates.
    num_poses = 50
    stack = AtomArrayStack(num_poses, ligand.shape[0])
    for annot in ligand.get_annotation_categories():
        stack.set_annotation(annot, np.copy(ligand.get_annotation(annot)))

    # Candidate per-axis offsets: 0.0, 1.0, ..., 25.0 (coordinate units).
    translation_magnitudes = np.linspace(0, 26, num=26, endpoint=False)

    for i in range(num_poses):
        rotation = R.random()
        translation_vec = [random.choice(translation_magnitudes) for _ in range(3)]
        stack.coord[i, ...] = rotation.apply(ligand.coord) + translation_vec

    # Entries different from -1 are counted as contacts (presumably -1 is the
    # padding value of the contact array -- confirm against pinder docs).
    stack_conts = np.asarray(get_stack_contacts(receptor, stack, threshold=1.2))
    per_pose_axes = tuple(range(1, stack_conts.ndim))
    clash_counts = np.count_nonzero(stack_conts != -1, axis=per_pose_axes)

    # argmin returns the first minimum, i.e. the lowest pose index on ties.
    best_pose_idx = int(np.argmin(clash_counts))
    best_pose = receptor + stack[best_pose_idx]

    receptor_path = Path(receptor_pdb)
    output_pdb = receptor_path.parent / f"{receptor_path.stem}--{Path(ligand_pdb).name}"
    write_pdb(best_pose, output_pdb)

    run_time = time.perf_counter() - start_time
    return str(output_pdb), run_time
|
|
|
|
|
# Representations handed to the Molecule3D viewer: a cartoon entry for each
# of chains "R" and "L", followed by a side-chain stick entry for each.
# (Chains "R"/"L" are presumably receptor/ligand -- confirm against the
# chain IDs in the PDB files that predict() writes.)
reps = [
    {"model": 0, "style": "cartoon", "chain": chain_id, "color": carbon_color}
    for chain_id, carbon_color in (("R", "whiteCarbon"), ("L", "greenCarbon"))
] + [
    {
        "model": 0,
        "chain": chain_id,
        "style": "stick",
        "sidechain": True,
        "color": carbon_color,
    }
    for chain_id, carbon_color in (("R", "whiteCarbon"), ("L", "greenCarbon"))
]

with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    # Two side-by-side columns, one per protein: structure plus sequence.
    with gr.Row():
        with gr.Column():
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
            input_fasta_1 = gr.File(label="Input Protein 1 monomer sequence (FASTA)")
        with gr.Column():
            input_protein_2 = gr.File(label="Input Protein 2 monomer (PDB)")
            input_fasta_2 = gr.File(label="Input Protein 2 monomer sequence (FASTA)")

    btn = gr.Button("Run Inference")

    # Example row order follows the component order given in `inputs`.
    gr.Examples(
        examples=[
            ["8i5w_R.pdb", "8i5w_R.fasta", "8i5w_L.pdb", "8i5w_L.fasta"],
        ],
        inputs=[input_protein_1, input_fasta_1, input_protein_2, input_fasta_2],
    )

    out = Molecule3D(reps=reps)
    run_time = gr.Textbox(label="Runtime")

    # Input order here matches predict()'s positional parameters
    # (receptor_pdb, ligand_pdb, receptor_fasta, ligand_fasta); it differs
    # from the Examples order above on purpose.
    btn.click(
        predict,
        inputs=[input_protein_1, input_protein_2, input_fasta_1, input_fasta_2],
        outputs=[out, run_time],
    )

app.launch()
|
|