feat: add example evaluate endpoint to fetch metrics for a single prediction given model and system ID
7d9b175
import random | |
import time | |
from pathlib import Path | |
import numpy as np | |
import pandas as pd | |
from biotite.structure.atoms import AtomArrayStack | |
from scipy.spatial.transform import Rotation as R | |
from pinder.core import PinderSystem | |
from pinder.core.structure import atoms | |
from pinder.core.structure.contacts import get_stack_contacts | |
from pinder.core.loader.structure import Structure | |
from pinder.eval.dockq import BiotiteDockQ | |
import gradio as gr | |
from gradio_molecule3d import Molecule3D | |
# Column names (in display order) of the evaluation metrics table; these are
# the names evaluate() assigns to the BiotiteDockQ output columns.
EVAL_METRICS = [
    "system",
    "L_rms",
    "I_rms",
    "F_nat",
    "DOCKQ",
    "CAPRI_class",
]
def predict(
    receptor_pdb: Path,
    ligand_pdb: Path,
    receptor_fasta: Path | None = None,
    ligand_fasta: Path | None = None,
) -> tuple[str, float]:
    """Dock a ligand monomer onto a receptor monomer with a random-pose baseline.

    Placeholder "model" for the inference template: it samples 50 random
    rigid-body poses of the ligand (random rotation plus a per-axis
    translation of 0-25 Å), counts receptor-ligand contacts within 1.2 Å
    for every pose, and keeps the pose with the fewest such clashes.
    Replace the body with real inference.

    Args:
        receptor_pdb: Path to the receptor monomer PDB file.
        ligand_pdb: Path to the ligand monomer PDB file.
        receptor_fasta: Receptor sequence file; accepted to match the
            template interface but unused by this baseline.
        ligand_fasta: Ligand sequence file; unused by this baseline.

    Returns:
        ``(output_pdb, run_time)``. ``output_pdb`` is the path (as ``str``)
        of the written receptor+ligand complex, saved in the receptor's
        directory as ``<receptor stem>--<ligand file name>``. ``run_time`` is
        the elapsed time in seconds.
    """
    # perf_counter is monotonic, unlike time.time(), so it is the right clock
    # for measuring elapsed time.
    start_time = time.perf_counter()
    # Do inference here
    # return an output pdb file with the protein and two chains R and L.
    receptor = atoms.atom_array_from_pdb_file(receptor_pdb, extra_fields=["b_factor"])
    ligand = atoms.atom_array_from_pdb_file(ligand_pdb, extra_fields=["b_factor"])
    receptor = atoms.normalize_orientation(receptor)
    ligand = atoms.normalize_orientation(ligand)
    # Number of random poses to generate
    M = 50
    # Initialize an empty stack with shape (M x n_ligand_atoms x 3)
    stack = AtomArrayStack(M, ligand.shape[0])
    # Copy per-atom annotations (chain, residue, ...) from the ligand so every
    # pose in the stack is a full copy of the ligand chain.
    for annot in ligand.get_annotation_categories():
        stack.set_annotation(annot, np.copy(ligand.get_annotation(annot)))
    # Candidate per-axis translation offsets: 0, 1, ..., 25 Å
    translation_magnitudes = np.linspace(0, 26, num=26, endpoint=False)
    # Generate one pose at a time: random rotation, then random x/y/z shift.
    for i in range(M):
        q = R.random()
        translation_vec = [random.choice(translation_magnitudes) for _ in range(3)]
        stack.coord[i, ...] = q.apply(ligand.coord) + translation_vec
    # Find clashes (1.2 Å contact radius); entries other than -1 are contacts.
    stack_conts = get_stack_contacts(receptor, stack, threshold=1.2)
    # Keep the "best" pose: the one with the fewest clashes. np.argmin returns
    # the first minimum, i.e. the lowest pose index on ties.
    clash_counts = [np.count_nonzero(pose_conts != -1) for pose_conts in stack_conts]
    best_pose_idx = int(np.argmin(clash_counts))
    best_pose = receptor + stack[best_pose_idx]
    output_dir = Path(receptor_pdb).parent
    # System ID
    pdb_name = Path(receptor_pdb).stem + "--" + Path(ligand_pdb).name
    output_pdb = output_dir / pdb_name
    atoms.write_pdb(best_pose, output_pdb)
    run_time = time.perf_counter() - start_time
    return str(output_pdb), run_time
def evaluate(
    system_id: str,
    prediction_pdb: Path,
) -> tuple[pd.DataFrame, list[str], float]:
    """Score a predicted complex against the PINDER native structure.

    Args:
        system_id: PINDER system identifier (surrounding whitespace ignored).
        prediction_pdb: Predicted complex PDB file with chains R and L.

    Returns:
        ``(metrics, structures, run_time)``:

        - ``metrics``: DataFrame with columns ``system, L_rms, I_rms, F_nat,
          DOCKQ, CAPRI_class`` (BiotiteDockQ output, renamed).
        - ``structures``: ``[superimposed_prediction_path, native_path]``.
          The prediction is superimposed onto the native and written to a
          new ``<stem>_superimposed.pdb`` next to the input, so the uploaded
          file itself is not modified.
        - ``run_time``: seconds spent computing the metrics (the
          superposition done for display is not included).
    """
    start_time = time.perf_counter()
    system = PinderSystem(system_id.strip())
    native = system.native.filepath
    prediction_pdb = Path(prediction_pdb)
    bdq = BiotiteDockQ(native, prediction_pdb, parallel_io=False)
    metrics = bdq.calculate()
    # BiotiteDockQ column -> display column. The key order is also the
    # column order of the returned table.
    column_map = {
        "system": "system",
        "LRMS": "L_rms",
        "iRMS": "I_rms",
        "Fnat": "F_nat",
        "DockQ": "DOCKQ",
        "CAPRI": "CAPRI_class",
    }
    metrics = metrics[list(column_map)].rename(columns=column_map)
    run_time = time.perf_counter() - start_time
    # Superimpose the prediction onto the native purely for visualization.
    # Write it to a separate file instead of overwriting the uploaded input.
    pred = Structure(prediction_pdb)
    nat = Structure(Path(native))
    pred, _, _ = pred.superimpose(nat)
    superimposed_pdb = prediction_pdb.with_name(f"{prediction_pdb.stem}_superimposed.pdb")
    pred.to_pdb(superimposed_pdb)
    return metrics, [str(superimposed_pdb), str(native)], run_time
# Two-tab Gradio UI: the first tab runs `predict` on a receptor/ligand pair,
# the second runs `evaluate` on a prediction for a given PINDER system ID.
with gr.Blocks() as app:
    with gr.Tab("🧬 PINDER inference template"):
        gr.Markdown("Title, description, and other information about the model")
        with gr.Row():
            with gr.Column():
                input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
                input_fasta_1 = gr.File(label="Input Protein 1 monomer sequence (FASTA)")
            with gr.Column():
                input_protein_2 = gr.File(label="Input Protein 2 monomer (PDB)")
                input_fasta_2 = gr.File(label="Input Protein 2 monomer sequence (FASTA)")
        # define any options here
        # for automated inference the default options are used
        # slider_option = gr.Slider(0,10, label="Slider Option")
        # checkbox_option = gr.Checkbox(label="Checkbox Option")
        # dropdown_option = gr.Dropdown(["Option 1", "Option 2", "Option 3"], label="Dropdown Option")
        btn = gr.Button("Run Inference")
        gr.Examples(
            [
                [
                    "8i5w_R.pdb",
                    "8i5w_R.fasta",
                    "8i5w_L.pdb",
                    "8i5w_L.fasta",
                ],
            ],
            [input_protein_1, input_fasta_1, input_protein_2, input_fasta_2],
        )
        # Viewer styles for the predicted complex: receptor chain R in white,
        # ligand chain L in green, as cartoon plus side-chain sticks.
        reps = [
            {
                "model": 0,
                "style": "cartoon",
                "chain": "R",
                "color": "whiteCarbon",
            },
            {
                "model": 0,
                "style": "cartoon",
                "chain": "L",
                "color": "greenCarbon",
            },
            {
                "model": 0,
                "chain": "R",
                "style": "stick",
                "sidechain": True,
                "color": "whiteCarbon",
            },
            {
                "model": 0,
                "chain": "L",
                "style": "stick",
                "sidechain": True,
                "color": "greenCarbon",
            },
        ]
        out = Molecule3D(reps=reps)
        run_time = gr.Textbox(label="Runtime")
        # Inputs are passed positionally, so this order must match
        # predict(receptor_pdb, ligand_pdb, receptor_fasta, ligand_fasta).
        btn.click(
            predict,
            inputs=[input_protein_1, input_protein_2, input_fasta_1, input_fasta_2],
            outputs=[out, run_time],
        )
    with gr.Tab("⚖️ PINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id = gr.Textbox(label="PINDER system ID")
                input_prediction_pdb = gr.File(label="Top ranked prediction (PDB with chains R and L)")
        eval_btn = gr.Button("Run Evaluation")
        gr.Examples(
            [
                [
                    "3g9w__A1_Q71LX4--3g9w__D1_P05556",
                    "3g9w_R--3g9w_L.pdb",
                ],
            ],
            [input_system_id, input_prediction_pdb],
        )
        # Viewer styles for prediction vs. native. evaluate() returns
        # [prediction, native], so model 0 is presumably the prediction and
        # model 1 the native (assumes Molecule3D numbers files in list order).
        reps = [
            {
                "model": 0,
                "style": "cartoon",
                "chain": "R",
                "color": "greenCarbon",
            },
            {
                "model": 0,
                "style": "cartoon",
                "chain": "L",
                "color": "cyanCarbon",
            },
            {
                "model": 1,
                "style": "cartoon",
                "chain": "R",
                "color": "grayCarbon",
            },
            {
                "model": 1,
                "style": "cartoon",
                "chain": "L",
                "color": "blueCarbon",
            },
        ]
        pred_native = Molecule3D(reps=reps, config={"backgroundColor": "black"})
        eval_run_time = gr.Textbox(label="Evaluation runtime")
        metric_table = gr.DataFrame(pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics")
        # Outputs map positionally onto evaluate()'s (metrics, structure paths, run_time).
        eval_btn.click(
            evaluate,
            inputs=[input_system_id, input_prediction_pdb],
            outputs=[metric_table, pred_native, eval_run_time],
        )

# Only start the server when run as a script (e.g. `python app.py`); importing
# the module, or Gradio's reload mode, must not launch a second server.
if __name__ == "__main__":
    app.launch()