danielkovtun committed on
Commit
1de0dc0
·
1 Parent(s): 7d59aa6

feat: update inference app template to implement a dummy poc predict endpoint that follows the suggested input specs

Browse files
Files changed (1) hide show
  1. inference_app.py +80 -20
inference_app.py CHANGED
@@ -1,20 +1,81 @@
1
-
2
  import time
 
 
 
 
 
 
 
3
 
4
  import gradio as gr
5
 
6
  from gradio_molecule3d import Molecule3D
7
 
8
 
9
-
10
-
11
- def predict (input_seq_1, input_protein_1, input_seq_2, input_protein_2):
 
 
 
12
  start_time = time.time()
13
  # Do inference here
14
  # return an output pdb file with the protein and two chains A and B.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  end_time = time.time()
16
  run_time = end_time - start_time
17
- return "test_out.pdb", run_time
 
18
 
19
  with gr.Blocks() as app:
20
 
@@ -23,11 +84,11 @@ with gr.Blocks() as app:
23
  gr.Markdown("Title, description, and other information about the model")
24
  with gr.Row():
25
  with gr.Column():
26
- input_seq_1 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
27
- input_protein_1 = gr.File(label="Input Protein 2 monomer (PDB)")
28
  with gr.Column():
29
- input_seq_2 = gr.Textbox(lines=3, label="Input Protein 1 sequence (FASTA)")
30
- input_protein_2 = gr.File(label="Input Protein 2 structure (PDB)")
31
 
32
 
33
 
@@ -43,38 +104,37 @@ with gr.Blocks() as app:
43
  gr.Examples(
44
  [
45
  [
46
- "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
47
- "3v1c_A.pdb",
48
- "GSGSPLAQQIKNIHSFIHQAKAAGRMDEVRTLQENLHQLMHEYFQQSD",
49
- "3v1c_B.pdb",
50
-
51
  ],
52
  ],
53
- [input_seq_1, input_protein_1, input_seq_2, input_protein_2],
54
  )
55
  reps = [
56
  {
57
  "model": 0,
58
  "style": "cartoon",
59
- "chain": "A",
60
  "color": "whiteCarbon",
61
  },
62
  {
63
  "model": 0,
64
  "style": "cartoon",
65
- "chain": "B",
66
  "color": "greenCarbon",
67
  },
68
  {
69
  "model": 0,
70
- "chain": "A",
71
  "style": "stick",
72
  "sidechain": True,
73
  "color": "whiteCarbon",
74
  },
75
  {
76
  "model": 0,
77
- "chain": "B",
78
  "style": "stick",
79
  "sidechain": True,
80
  "color": "greenCarbon"
@@ -84,6 +144,6 @@ with gr.Blocks() as app:
84
  out = Molecule3D(reps=reps)
85
  run_time = gr.Textbox(label="Runtime")
86
 
87
- btn.click(predict, inputs=[input_seq_1, input_protein_1, input_seq_2, input_protein_2], outputs=[out, run_time])
88
 
89
  app.launch()
 
1
+ import random
2
  import time
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ from biotite.structure.atoms import AtomArrayStack
7
+ from scipy.spatial.transform import Rotation as R
8
+ from pinder.core.structure.atoms import atom_array_from_pdb_file, normalize_orientation, write_pdb
9
+ from pinder.core.structure.contacts import get_stack_contacts
10
 
11
  import gradio as gr
12
 
13
  from gradio_molecule3d import Molecule3D
14
 
15
 
16
def predict(
    receptor_pdb: Path,
    ligand_pdb: Path,
    receptor_fasta: Path | None = None,
    ligand_fasta: Path | None = None,
) -> tuple[str, float]:
    """Dummy proof-of-concept prediction: pick the least-clashing random ligand pose.

    Loads the receptor and ligand structures, generates a fixed number of
    randomly rotated and translated copies of the ligand, counts contacts
    against the receptor for each copy, and writes the receptor combined with
    the copy that has the fewest contacts to a PDB file next to the receptor
    input.

    Args:
        receptor_pdb: Path to the receptor monomer PDB file.
        ligand_pdb: Path to the ligand monomer PDB file.
        receptor_fasta: Optional receptor sequence file. Accepted to match the
            suggested input spec; not read by this placeholder.
        ligand_fasta: Optional ligand sequence file. Accepted to match the
            suggested input spec; not read by this placeholder.

    Returns:
        A ``(output_pdb_path, run_time_seconds)`` tuple.
    """
    start_time = time.time()
    # Do inference here
    receptor = atom_array_from_pdb_file(receptor_pdb, extra_fields=["b_factor"])
    ligand = atom_array_from_pdb_file(ligand_pdb, extra_fields=["b_factor"])
    receptor = normalize_orientation(receptor)
    ligand = normalize_orientation(ligand)

    # Number of random poses to generate
    n_poses = 50
    # Largest per-axis translation (angstroms); kept separate from n_poses so
    # changing the number of poses does not silently change the search range.
    max_translation = 50

    # Initialize an empty stack holding n_poses copies of the ligand atoms
    stack = AtomArrayStack(n_poses, ligand.shape[0])

    # copy annotations from ligand
    for annot in ligand.get_annotation_categories():
        stack.set_annotation(annot, np.copy(ligand.get_annotation(annot)))

    # Candidate translations: 0, 1, ..., max_translation along each axis
    translation_magnitudes = np.linspace(
        0,
        max_translation + 1,
        num=max_translation + 1,
        endpoint=False,
    )

    # generate one pose at a time
    for i in range(n_poses):
        rotation = R.random()
        # independent draws for x, y, z
        translation_vec = [random.choice(translation_magnitudes) for _ in range(3)]
        # transform the ligand chain
        stack.coord[i, ...] = rotation.apply(ligand.coord) + translation_vec

    # Find clashes (1.2 A contact radius)
    stack_conts = get_stack_contacts(receptor, stack, threshold=1.2)

    # Keep the "best" pose: the one with the fewest entries that are not -1.
    # NOTE(review): -1 is treated as "no contact" here, as in the original
    # code; confirm against get_stack_contacts' return convention.
    clash_counts = [
        np.count_nonzero(stack_conts[i] != -1)
        for i in range(stack_conts.shape[0])
    ]
    # argmin returns the first minimum, so ties resolve to the earliest pose.
    best_pose_idx = int(np.argmin(clash_counts))
    best_pose = receptor + stack[best_pose_idx]

    output_dir = Path(receptor_pdb).parent
    # System ID: "<receptor>--<ligand>". Use removesuffix, not rstrip:
    # rstrip("-R") strips any trailing run of the characters '-' and 'R',
    # which corrupts stems that legitimately end in those characters.
    receptor_id = Path(receptor_pdb).stem.removesuffix("-R")
    ligand_id = Path(ligand_pdb).stem.removesuffix("-L")
    pdb_name = f"{receptor_id}--{ligand_id}.pdb"
    output_pdb = output_dir / pdb_name
    write_pdb(best_pose, output_pdb)

    end_time = time.time()
    run_time = end_time - start_time
    return str(output_pdb), run_time
78
+
79
 
80
  with gr.Blocks() as app:
81
 
 
84
  gr.Markdown("Title, description, and other information about the model")
85
  with gr.Row():
86
  with gr.Column():
87
+ input_protein_1 = gr.File(label="Input Protein 1 monomer (PDB)")
88
+ input_fasta_1 = gr.File(label="Input Protein 1 monomer sequence (FASTA)")
89
  with gr.Column():
90
+ input_protein_2 = gr.File(label="Input Protein 2 monomer (PDB)")
91
+ input_fasta_2 = gr.File(label="Input Protein 2 monomer sequence (FASTA)")
92
 
93
 
94
 
 
104
  gr.Examples(
105
  [
106
  [
107
+ "8i5w_R.pdb",
108
+ "8i5w_R.fasta",
109
+ "8i5w_L.pdb",
110
+ "8i5w_L.fasta",
 
111
  ],
112
  ],
113
+ [input_protein_1, input_fasta_1, input_protein_2, input_fasta_2],
114
  )
115
  reps = [
116
  {
117
  "model": 0,
118
  "style": "cartoon",
119
+ "chain": "R",
120
  "color": "whiteCarbon",
121
  },
122
  {
123
  "model": 0,
124
  "style": "cartoon",
125
+ "chain": "L",
126
  "color": "greenCarbon",
127
  },
128
  {
129
  "model": 0,
130
+ "chain": "R",
131
  "style": "stick",
132
  "sidechain": True,
133
  "color": "whiteCarbon",
134
  },
135
  {
136
  "model": 0,
137
+ "chain": "L",
138
  "style": "stick",
139
  "sidechain": True,
140
  "color": "greenCarbon"
 
144
  out = Molecule3D(reps=reps)
145
  run_time = gr.Textbox(label="Runtime")
146
 
147
+ btn.click(predict, inputs=[input_protein_1, input_protein_2, input_fasta_1, input_fasta_2], outputs=[out, run_time])
148
 
149
  app.launch()