pinder_inference_template / inference_app.py
danielkovtun's picture
chore: join the input filenames
be507b5
raw
history blame
4.52 kB
import random
import time
from pathlib import Path
import numpy as np
from biotite.structure.atoms import AtomArrayStack
from scipy.spatial.transform import Rotation as R
from pinder.core.structure.atoms import atom_array_from_pdb_file, normalize_orientation, write_pdb
from pinder.core.structure.contacts import get_stack_contacts
import gradio as gr
from gradio_molecule3d import Molecule3D
def predict(
    receptor_pdb: Path,
    ligand_pdb: Path,
    receptor_fasta: Path | None = None,
    ligand_fasta: Path | None = None,
) -> tuple[str, float]:
    """Dock ``ligand_pdb`` onto ``receptor_pdb`` with a random-pose baseline.

    Placeholder inference for the template: the ligand is randomly rotated
    and translated 50 times, receptor/ligand contacts within 1.2 Å are
    counted for every pose (treated as clashes), and the pose with the
    fewest clashes is written out together with the receptor.

    Args:
        receptor_pdb: Path to the receptor monomer PDB file.
        ligand_pdb: Path to the ligand monomer PDB file.
        receptor_fasta: Receptor sequence file. Accepted so the signature
            matches the UI inputs; not used by this baseline.
        ligand_fasta: Ligand sequence file. Accepted so the signature
            matches the UI inputs; not used by this baseline.

    Returns:
        ``(output_pdb, run_time)`` where ``output_pdb`` is the path (as a
        string) of the written complex, placed in the same directory as
        ``receptor_pdb`` and named ``<receptor stem>--<ligand file name>``,
        and ``run_time`` is the elapsed time in seconds.

    NOTE(review): chain IDs are copied from the inputs unchanged; the output
    is meant to contain chains "R" and "L", so inputs are presumably
    expected to already use those IDs — confirm.
    """
    start_time = time.perf_counter()

    # Replace everything below with real model inference; the contract is
    # to write one PDB containing the receptor and the docked ligand.
    receptor = normalize_orientation(
        atom_array_from_pdb_file(receptor_pdb, extra_fields=["b_factor"])
    )
    ligand = normalize_orientation(
        atom_array_from_pdb_file(ligand_pdb, extra_fields=["b_factor"])
    )

    num_poses = 50
    clash_radius = 1.2  # contact threshold in Å passed to get_stack_contacts
    # Candidate per-axis offsets: 0, 1, ..., 25 Å (translations are always
    # non-negative on every axis).
    axis_offsets = np.linspace(0, 26, num=26, endpoint=False)

    # One ligand copy per pose (num_poses x n_atoms x 3 coordinates), with
    # every annotation copied from the ligand.
    stack = AtomArrayStack(num_poses, ligand.shape[0])
    for annot in ligand.get_annotation_categories():
        stack.set_annotation(annot, np.copy(ligand.get_annotation(annot)))

    for pose_idx in range(num_poses):
        rotation = R.random()
        translation = [random.choice(axis_offsets) for _ in range(3)]  # x, y, z
        # Rotate about the coordinate origin, then shift.
        stack.coord[pose_idx, ...] = rotation.apply(ligand.coord) + translation

    stack_conts = get_stack_contacts(receptor, stack, threshold=clash_radius)
    # Per pose, count entries that are not -1 (-1 is treated as "no
    # contact" — NOTE(review): confirm against get_stack_contacts docs).
    clash_counts = [
        np.count_nonzero(stack_conts[i] != -1) for i in range(stack_conts.shape[0])
    ]
    # np.argmin returns the first pose among ties.
    best_pose_idx = int(np.argmin(clash_counts))
    best_pose = receptor + stack[best_pose_idx]

    # System ID: "<receptor stem>--<ligand file name>"
    receptor_path = Path(receptor_pdb)
    output_pdb = receptor_path.parent / f"{receptor_path.stem}--{Path(ligand_pdb).name}"
    write_pdb(best_pose, output_pdb)

    return str(output_pdb), time.perf_counter() - start_time
# Gradio UI: two monomer inputs (PDB plus optional FASTA each), a run
# button, a 3D viewer for the predicted complex and a runtime read-out.
with gr.Blocks() as app:
    gr.Markdown("# Template for inference")
    gr.Markdown("Title, description, and other information about the model")

    with gr.Row():
        # Protein 1 is passed to `predict` as the receptor, Protein 2 as
        # the ligand (see the `btn.click` wiring below).
        with gr.Column():
            input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
            input_fasta_1 = gr.File(label="Input Protein 1 monomer sequence (FASTA)")
        with gr.Column():
            input_protein_2 = gr.File(label="Input Protein 2 monomer (PDB)")
            input_fasta_2 = gr.File(label="Input Protein 2 monomer sequence (FASTA)")

    # Model options belong here; automated inference uses their defaults.
    # Examples:
    #   gr.Slider(0, 10, label="Slider Option")
    #   gr.Checkbox(label="Checkbox Option")
    #   gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Dropdown Option")

    btn = gr.Button("Run Inference")

    gr.Examples(
        examples=[
            ["8i5w_R.pdb", "8i5w_R.fasta", "8i5w_L.pdb", "8i5w_L.fasta"],
        ],
        inputs=[input_protein_1, input_fasta_1, input_protein_2, input_fasta_2],
    )

    # Viewer styling: a cartoon for each chain first, then side-chain sticks
    # for each chain, colored per chain.
    viewer_chains = (("R", "whiteCarbon"), ("L", "greenCarbon"))
    cartoon_reps = [
        {"model": 0, "style": "cartoon", "chain": chain, "color": color}
        for chain, color in viewer_chains
    ]
    sidechain_reps = [
        {"model": 0, "chain": chain, "style": "stick", "sidechain": True, "color": color}
        for chain, color in viewer_chains
    ]
    reps = cartoon_reps + sidechain_reps

    out = Molecule3D(reps=reps)
    run_time = gr.Textbox(label="Runtime")

    # Argument order matches predict(receptor_pdb, ligand_pdb,
    # receptor_fasta, ligand_fasta).
    btn.click(
        fn=predict,
        inputs=[input_protein_1, input_protein_2, input_fasta_1, input_fasta_2],
        outputs=[out, run_time],
    )

app.launch()