# pinder_inference_template / inference_app.py
# danielkovtun's picture
# feat: add example evaluate endpoint to fetch metrics for a single prediction given model and system ID
# 7d9b175
# raw
# history blame
# 7.07 kB
import random
import time
from pathlib import Path
import numpy as np
import pandas as pd
from biotite.structure.atoms import AtomArrayStack
from scipy.spatial.transform import Rotation as R
from pinder.core import PinderSystem
from pinder.core.structure import atoms
from pinder.core.structure.contacts import get_stack_contacts
from pinder.core.loader.structure import Structure
from pinder.eval.dockq import BiotiteDockQ
import gradio as gr
from gradio_molecule3d import Molecule3D
# Column names of the evaluation metrics table: used as the header of the
# (initially empty) results DataFrame in the UI, and matching the names that
# evaluate() renames the DockQ result columns to.
EVAL_METRICS = ["system", "L_rms", "I_rms", "F_nat", "DOCKQ", "CAPRI_class"]
def predict(
    receptor_pdb: Path,
    ligand_pdb: Path,
    receptor_fasta: Path | None = None,
    ligand_fasta: Path | None = None,
) -> tuple[str, float]:
    """Placeholder inference: pick the least-clashing of 50 random ligand poses.

    Both monomers are loaded and passed through ``normalize_orientation``; the
    ligand is then randomly rotated and translated 50 times, and the pose with
    the fewest receptor contacts (1.2 contact threshold) is written out
    together with the receptor.

    Args:
        receptor_pdb: Path to the receptor monomer PDB file.
        ligand_pdb: Path to the ligand monomer PDB file.
        receptor_fasta: Optional receptor sequence file (unused here).
        ligand_fasta: Optional ligand sequence file (unused here).

    Returns:
        ``(output_path, run_time)`` where ``output_path`` is the written PDB,
        named ``<receptor stem>--<ligand file name>`` and placed next to the
        receptor file, and ``run_time`` is the elapsed wall time in seconds.
    """
    started = time.time()

    # Do inference here
    # return an output pdb file with the protein and two chains R and L.
    receptor = atoms.normalize_orientation(
        atoms.atom_array_from_pdb_file(receptor_pdb, extra_fields=["b_factor"])
    )
    ligand = atoms.normalize_orientation(
        atoms.atom_array_from_pdb_file(ligand_pdb, extra_fields=["b_factor"])
    )

    # Number of random poses to generate
    n_poses = 50

    # Initialize an empty stack with shape (n_poses x n_atoms x 3) and give it
    # the same per-atom annotations as the ligand.
    stack = AtomArrayStack(n_poses, ligand.shape[0])
    for category in ligand.get_annotation_categories():
        stack.set_annotation(category, np.copy(ligand.get_annotation(category)))

    # Candidate per-axis translation offsets: 26 evenly spaced values starting
    # at 0 and stopping short of 26 (i.e. 0, 1, ..., 25).
    offsets = np.linspace(0, 26, num=26, endpoint=False)

    # Generate one pose at a time: random rotation first, then a random
    # offset for each of x, y, z.
    for pose_idx in range(n_poses):
        rotation = R.random()
        shift = [random.choice(offsets) for _ in range(3)]
        stack.coord[pose_idx, ...] = rotation.apply(ligand.coord) + shift

    # Find clashes (1.2 A contact radius); entries equal to -1 are not counted.
    contacts = get_stack_contacts(receptor, stack, threshold=1.2)
    clash_counts = [
        np.argwhere(pose_contacts != -1).shape[0] for pose_contacts in contacts
    ]

    # Keep the "best" pose, i.e. the one with the fewest clashes. The sort is
    # stable, so ties resolve to the lowest pose index.
    ranked = sorted(range(len(clash_counts)), key=clash_counts.__getitem__)
    best_pose = receptor + stack[ranked[0]]

    # The output name doubles as the system ID: "<receptor stem>--<ligand name>".
    receptor_path = Path(receptor_pdb)
    output_pdb = receptor_path.parent / f"{receptor_path.stem}--{Path(ligand_pdb).name}"
    atoms.write_pdb(best_pose, output_pdb)

    return str(output_pdb), time.time() - started
def evaluate(
    system_id: str,
    prediction_pdb: Path,
) -> tuple[pd.DataFrame, list[str], float]:
    """Score a predicted complex against the native structure of a PINDER system.

    Args:
        system_id: PINDER system ID used to construct the ``PinderSystem``
            whose native structure serves as the reference.
        prediction_pdb: Path to the predicted complex PDB file. NOTE: this
            file is overwritten with the superimposed prediction.

    Returns:
        A 3-tuple ``(metrics, [prediction_path, native_path], run_time)``:
        ``metrics`` is a DataFrame with the columns ``system``, ``L_rms``,
        ``I_rms``, ``F_nat``, ``DOCKQ`` and ``CAPRI_class``; the list holds
        the prediction and native file paths as strings; ``run_time`` is the
        wall time in seconds spent computing the metrics (the superposition
        step below is not included).
    """
    start_time = time.time()
    system = PinderSystem(system_id)
    native = system.native.filepath
    bdq = BiotiteDockQ(native, Path(prediction_pdb), parallel_io=False)
    metrics = bdq.calculate()

    # Keep only the reported columns and rename them to the display names.
    metrics = metrics[["system", "LRMS", "iRMS", "Fnat", "DockQ", "CAPRI"]].copy()
    metrics.rename(
        columns={
            "LRMS": "L_rms",
            "iRMS": "I_rms",
            "Fnat": "F_nat",
            "DockQ": "DOCKQ",
            "CAPRI": "CAPRI_class",
        },
        inplace=True,
    )

    # The timer stops here on purpose: only the metric calculation is timed.
    end_time = time.time()
    run_time = end_time - start_time

    # Superimpose the prediction against the native (first element of the
    # returned triple is kept) and write it back over the prediction file.
    pred = Structure(Path(prediction_pdb))
    nat = Structure(Path(native))
    pred, _, _ = pred.superimpose(nat)
    pred.to_pdb(Path(prediction_pdb))

    return metrics, [str(prediction_pdb), str(native)], run_time
# Gradio UI: one tab runs the inference template, the other scores a
# prediction against a PINDER system. Components render in creation order.
with gr.Blocks() as app:
    with gr.Tab("🧬 PINDER inference template"):
        gr.Markdown("Title, description, and other information about the model")

        # Two side-by-side columns: structure + sequence upload per monomer.
        with gr.Row():
            with gr.Column():
                input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
                input_fasta_1 = gr.File(label="Input Protein 1 monomer sequence (FASTA)")
            with gr.Column():
                input_protein_2 = gr.File(label="Input Protein 2 monomer (PDB)")
                input_fasta_2 = gr.File(label="Input Protein 2 monomer sequence (FASTA)")

        # define any options here
        # for automated inference the default options are used
        # slider_option = gr.Slider(0,10, label="Slider Option")
        # checkbox_option = gr.Checkbox(label="Checkbox Option")
        # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Radio Option")
        btn = gr.Button("Run Inference")

        # Example row order follows the inputs list: PDB, FASTA, PDB, FASTA.
        gr.Examples(
            examples=[["8i5w_R.pdb", "8i5w_R.fasta", "8i5w_L.pdb", "8i5w_L.fasta"]],
            inputs=[input_protein_1, input_fasta_1, input_protein_2, input_fasta_2],
        )

        # Viewer representations: cartoon plus side-chain sticks for chains
        # R and L of model 0.
        reps = [
            {"model": 0, "style": "cartoon", "chain": "R", "color": "whiteCarbon"},
            {"model": 0, "style": "cartoon", "chain": "L", "color": "greenCarbon"},
            {"model": 0, "chain": "R", "style": "stick", "sidechain": True, "color": "whiteCarbon"},
            {"model": 0, "chain": "L", "style": "stick", "sidechain": True, "color": "greenCarbon"},
        ]
        out = Molecule3D(reps=reps)
        run_time = gr.Textbox(label="Runtime")

        # predict() takes both PDB files first, then both FASTA files.
        btn.click(
            fn=predict,
            inputs=[input_protein_1, input_protein_2, input_fasta_1, input_fasta_2],
            outputs=[out, run_time],
        )

    with gr.Tab("⚖️ PINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id = gr.Textbox(label="PINDER system ID")
                input_prediction_pdb = gr.File(label="Top ranked prediction (PDB with chains R and L)")

        eval_btn = gr.Button("Run Evaluation")
        gr.Examples(
            examples=[["3g9w__A1_Q71LX4--3g9w__D1_P05556", "3g9w_R--3g9w_L.pdb"]],
            inputs=[input_system_id, input_prediction_pdb],
        )

        # Models 0 and 1 (the two files returned by evaluate()) each get
        # their own cartoon colors for chains R and L.
        reps = [
            {"model": 0, "style": "cartoon", "chain": "R", "color": "greenCarbon"},
            {"model": 0, "style": "cartoon", "chain": "L", "color": "cyanCarbon"},
            {"model": 1, "style": "cartoon", "chain": "R", "color": "grayCarbon"},
            {"model": 1, "style": "cartoon", "chain": "L", "color": "blueCarbon"},
        ]
        pred_native = Molecule3D(reps=reps, config={"backgroundColor": "black"})
        eval_run_time = gr.Textbox(label="Evaluation runtime")
        metric_table = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS),
            label="Evaluation metrics",
        )

        # Output order matches evaluate()'s return: table, structures, runtime.
        eval_btn.click(
            fn=evaluate,
            inputs=[input_system_id, input_prediction_pdb],
            outputs=[metric_table, pred_native, eval_run_time],
        )

app.launch()