ZhaohanM committed on
Commit
5676c75
·
1 Parent(s): fde18d9

Update: SMILES-to-SELFIES conversion, UI polish, and usage guide

Browse files
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, argparse, tempfile, shutil, base64, io
2
+ from flask import Flask, request, render_template_string
3
+ from werkzeug.utils import secure_filename
4
+ from torch.utils.data import DataLoader
5
+ import selfies
6
+ from rdkit import Chem
7
+
8
+ import torch
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib import cm
13
+ from typing import Optional
14
+
15
+ from utils.drug_tokenizer import DrugTokenizer
16
+ from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
17
+ from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
18
+ from utils.foldseek_util import get_struc_seq
19
+
20
+ # ───── Biopython fallback ───────────────────────────────────────
21
+ from Bio.PDB import PDBParser, MMCIFParser
22
+ from Bio.Data import IUPACData
23
+
24
# Upper-case three-letter residue name → one-letter code.  SEC and PYL are not
# in the standard table, so they are mapped explicitly to C and K.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"SEC": "C", "PYL": "K"})


def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter residue sequence of the first chain in *path*.

    Used as a Biopython-only fallback when Foldseek cannot be used.

    Args:
        path: Path to a structure file.  Files ending in ``.cif`` or
            ``.mmcif`` (any letter case) are read with ``MMCIFParser``;
            everything else is read with ``PDBParser``.

    Returns:
        The concatenated one-letter codes of the first chain's residues.
        Residue names missing from ``three2one`` are rendered as ``"X"``.

    Raises:
        ValueError: If the parsed structure contains no chain.
    """
    # Compare the suffix case-insensitively so "1ABC.CIF" is not handed to
    # the PDB parser by mistake.
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)

    chain = next(parser.get_structure("P", path).get_chains(), None)
    if chain is None:
        raise ValueError(f"No chain found in structure file: {path}")

    return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)
30
+
31
+ # ───── global paths / args ──────────────────────────────────────
32
# Path of the `foldseek` executable found on PATH, or None when it is absent.
FOLDSEEK_BIN = shutil.which("foldseek")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append("..")


def parse_config():
    """Define the app's command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: The parsed options (encoder paths, fusion type,
        device, checkpoint prefix, dataset name, ...).
    """
    p = argparse.ArgumentParser()
    # NOTE(review): "-f" is accepted but never read in this file — presumably
    # it swallows the kernel-file flag passed by Jupyter; confirm.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # The help text must live inside the call; the original closed the
    # parenthesis too early, which is a syntax error.
    p.add_argument(
        "--dataset",
        default="BindingDB",
        help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')",
    )
    return p.parse_args()
49
+
50
args = parse_config()
DEVICE = args.device

# ───── tokenisers & encoders ────────────────────────────────────
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)

drug_tokenizer = DrugTokenizer()  # SELFIES
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)

encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)


# ─── collate fn ────────────────────────────────────────────────
def collate_fn(batch):
    """Tokenise a batch of ``(protein, drug, score)`` triples.

    Both sequences are padded/truncated to 512 tokens with special tokens
    added, using the module-level protein and drug tokenizers.

    Returns:
        tuple: ``(protein_input_ids, protein_mask, drug_input_ids,
        drug_mask, scores)`` where the masks are boolean tensors and
        ``scores`` is a tensor built from the third element of each triple.
    """
    proteins, drugs, labels = zip(*batch)

    # Identical tokenisation settings for both modalities.
    tokenizer_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **tokenizer_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **tokenizer_kwargs)
    label_tensor = torch.tensor(list(labels))

    return (
        prot_enc["input_ids"],
        prot_enc["attention_mask"].bool(),
        drug_enc["input_ids"],
        drug_enc["attention_mask"].bool(),
        label_tensor,
    )
88
+ # def collate_fn_batch_encoding(batch):
89
+
90
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, returning None on any failure.

    Args:
        smiles: Candidate SMILES string.

    Returns:
        The SELFIES encoding, or ``None`` when the input is blank, RDKit
        cannot parse it, or the SELFIES encoder raises.
    """
    # Reject blank input up front instead of relying on how the libraries
    # happen to treat an empty string.
    if not smiles or not smiles.strip():
        return None
    try:
        # RDKit acts as the validity check; the encoder works on the text.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Deliberate best-effort: callers treat None as "conversion failed".
        return None
99
+
100
+
101
+ # ───── single-case embedding ───────────────────────────────────
102
def get_case_feature(model, loader):
    """Encode the first batch of *loader* and return its features on CPU.

    Args:
        model: Object exposing ``eval()`` and
            ``encoding(p_ids, p_mask, d_ids, d_mask)``.
        loader: Iterable of 5-tuples ``(p_ids, p_mask, d_ids, d_mask, label)``.

    Returns:
        A one-element list holding the tuple ``(p_emb, d_emb, p_ids, d_ids,
        p_mask, d_mask, None)`` for the first batch, or ``None`` when the
        loader yields nothing.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            prot_ids, prot_mask, drug_ids, drug_mask, _label = batch
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            tensors = (prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask)
            # Only the first batch is used: this is a single-case helper.
            return [tuple(t.cpu() for t in tensors) + (None,)]
    return None
112
+
113
+ # ───── helper: filter out special tokens ────────────────────────
114
def clean_tokens(ids, tokenizer):
    """Decode *ids* to tokens and drop the tokenizer's special tokens.

    Args:
        ids: Object with a ``tolist()`` method yielding token ids.
        tokenizer: Object exposing ``convert_ids_to_tokens`` and
            ``all_special_tokens``.

    Returns:
        list: Decoded tokens, in order, without special tokens.
    """
    # Read the special-token list once; the original re-read it per token.
    special = set(tokenizer.all_special_tokens)
    tokens = tokenizer.convert_ids_to_tokens(ids.tolist())
    return [tok for tok in tokens if tok not in special]
117
+
118
+ # ───── visualisation ───────────────────────────────────────────
119
+
120
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    The token index shown on the x-axis (and accepted via *drug_idx*) is **the
    position of that token in the *original* drug sequence**, *after* the
    tokeniser but *before* any pruning or truncation (1-based in the labels,
    0-based for the function argument).

    Parameters
    ----------
    model
        Fusion model; called as ``model(p_emb, d_emb, p_mask, d_mask)`` and
        expected to return ``(output, attention)``.
    feats : list
        Single-case features as produced by ``get_case_feature``; only
        ``feats[0]`` is read.
    drug_idx : int, optional
        0-based drug-token index for which the Top-20 table is built.

    Returns
    -------
    html : str
        Optional Top-20 HTML table followed by the heat-map as a
        base64-embedded PNG (click to enlarge) plus a PDF download link.
    """
    from html import escape  # function-scope import: token text is HTML-escaped below

    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ──────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention, expected (B, n_p, n_d) ──
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # (n_p, n_d)

    # ── decode tokens (skip special symbols) ────────────────────────────────
    def clean_ids(ids, tokenizer):
        special = set(tokenizer.all_special_tokens)
        return [t for t in tokenizer.convert_ids_to_tokens(ids.tolist()) if t not in special]

    p_tokens = clean_ids(p_ids[0], prot_tokenizer)
    d_tokens = clean_ids(d_ids[0], drug_tokenizer)

    # NOTE(review): special tokens are removed from the labels but the
    # attention matrix is only cropped, not re-indexed — this assumes row/col i
    # of `attn` corresponds to the i-th non-special token; confirm against the
    # model's output layout.

    # ── safety cut-off so labels and attention matrix agree in size ─────────
    n_p = min(len(p_tokens), attn.size(0))
    n_d = min(len(d_tokens), attn.size(1))
    if n_p == 0 or n_d == 0:
        # Nothing to plot; attn.max() below would raise on an empty tensor.
        return ("<p style='color:red'><strong>No tokens left to visualise "
                "after removing special tokens.</strong></p>")
    p_tokens, d_tokens = p_tokens[:n_p], d_tokens[:n_d]
    p_indices = list(range(1, n_p + 1))  # 1-based original positions
    d_indices = list(range(1, n_d + 1))
    attn = attn[:n_p, :n_d]

    # ── adaptive sparsity pruning ───────────────────────────────────────────
    # Keep only rows/columns whose peak exceeds 5 % of the global maximum,
    # unless that would leave fewer than three of them.
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    row_mask, col_mask = row_keep.tolist(), col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(row_mask, p_tokens) if keep]
    p_indices = [idx for keep, idx in zip(row_mask, p_indices) if keep]
    d_tokens = [tok for keep, tok in zip(col_mask, d_tokens) if keep]
    d_indices = [idx for keep, idx in zip(col_mask, d_indices) if keep]

    # ── cap column count at 150 for readability ─────────────────────────────
    if attn.size(1) > 150:
        # topk orders columns by summed attention, not by sequence position.
        topc = torch.topk(attn.sum(0), k=150).indices
        attn = attn[:, topc]
        order = topc.tolist()
        d_tokens = [d_tokens[i] for i in order]
        d_indices = [d_indices[i] for i in order]

    # ── draw heat-map ───────────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # Only `pad` applies to the y-axis; top/bottom keywords are x-only.
        ax.tick_params(axis="y", pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # Render once per format while the figure is still open:
        # PNG for the inline preview (raster), PDF for download (vector).
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    # ── build the Top-20 table (if requested) ───────────────────────────────
    table_html = ""  # stays empty when no table is requested or possible
    if drug_idx is not None:
        # map original 0-based drug_idx → current column position
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            # NOTE(review): the requested token was pruned, so this falls back
            # to the column at that position — a different token, labelled with
            # its own original index; confirm this is intended.
            col_pos = drug_idx
        else:
            col_pos = None

        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

            rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
            res_row = "".join(f"<td>{escape(str(p_tokens[i]))}</td>" for i in topk)
            pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)

            drug_tok_text = escape(str(d_tokens[col_pos]))
            orig_idx = d_indices[col_pos]

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")

    # ── zoomable preview + downloadable heat-map ────────────────────────────
    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    # ── final HTML: table first, heat-map below ─────────────────────────────
    return table_html + html_heat
279
+
280
+
281
+ # ───── Flask app ───────────────────────────────────────────────
282
app = Flask(__name__)

# Page template: usage guide, input form, and (when present) the
# structure-aware sequence and the visualisation result.
_INDEX_TEMPLATE = """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
      border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
      border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
     font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
  <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#10b981,#059669);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    🌐 Project Page
  </a>

  <a href="https://arxiv.org/abs/2406.01651" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#ef4444,#dc2626);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    📄 ArXiv: 2406.01651
  </a>

  <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    💻 GitHub Repo
  </a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
  <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
  <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
    <li><strong>Convert protein structure into a structure-aware sequence:</strong>
        Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using
        <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
        based on 3D structures from
        <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

    <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
        you must first visit the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
        or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
        to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

    <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
        You can enter a SELFIES string directly, or paste a SMILES string.
        SMILES will be automatically converted to SELFIES using
        <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
        If conversion fails, a red error message will be displayed.</li>

    <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
        to highlight the Top-20 interacting protein residues.</li>

    <li>After inference, you can use the
        “Download PDF” link to export a high-resolution vector version.</li>
  </ul>
</div>

<div class="card">
  <form method="POST" enctype="multipart/form-data" class="grid">

    <div><label>Protein Structure (.pdb / .cif)</label>
      <input type="file" name="structure_file">
      <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

    <div><label>Protein Sequence</label>
      <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

    <div><label>Drug Sequence (SELFIES/SMILES)</label>
      <textarea name="drug_sequence" placeholder="[C][C][O] / CCO …">{{ drug_seq }}</textarea></div>

    <label>Drug atom/substructure index (1-based) – show Top-20 related protein residues</label>
    <input type="number" name="drug_idx" min="1" style="width:120px">

    <div class="grid grid-2">
      <button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
      <button class="btn btn-primary" type="submit" name="Inference">Inference</button>
    </div>
    <button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
  </form>

  {% if structure_seq %}
  <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
  {% endif %}
  {% if result_html %}
  <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
  {% endif %}
</div></body></html>
"""

_SMILES_ERROR_HTML = (
    "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
    "Please check the input string.</strong></p>"
)
_STRUCTURE_ERROR_HTML = (
    "<p style='color:red'><strong>Failed to read the structure file. "
    "Please upload a valid .pdb or .cif file.</strong></p>"
)


def _safe_tmp_structure_path(raw_path):
    """Map the client-echoed path onto an existing file in the temp dir, or ""."""
    # The hidden form field is client-controlled, so only its sanitised file
    # name is trusted; the directory is always the server's temp dir.
    name = secure_filename(os.path.basename(raw_path or ""))
    if not name:
        return ""
    candidate = os.path.join(tempfile.gettempdir(), name)
    return candidate if os.path.isfile(candidate) else ""


def _structure_sequence(path):
    """Return ``(sequence, error_html)`` for a structure file, Foldseek first."""
    try:
        parsed = get_struc_seq(FOLDSEEK_BIN, path, None, plddt_mask=False)
        first_chain = next(iter(parsed))
        # The third element of the per-chain tuple is used as the sequence.
        _, _, seq = parsed[first_chain]
        return seq, None
    except Exception:
        # Foldseek path failed: fall back to the Biopython reader.
        try:
            return simple_seq_from_structure(path), None
        except Exception:
            app.logger.exception("Could not read structure file %s", path)
            return "", _STRUCTURE_ERROR_HTML


def _normalise_drug_input(raw):
    """Return ``(selfies, error_html)`` for the raw drug text from the form."""
    text = (raw or "").strip()
    if not text:
        return "", None  # an empty box is not a conversion failure
    # Heuristic: SELFIES strings start with "[", anything else is taken as SMILES.
    if text.startswith("["):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return "", _SMILES_ERROR_HTML


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI and handle its three form actions.

    POST actions are selected by the name of the pressed button:
    ``clear`` resets the form, ``confirm_structure`` derives the protein
    sequence from the uploaded structure file, and ``Inference`` runs the
    model and renders the attention visualisation.  Drug input is normalised
    to SELFIES for both ``confirm_structure`` and ``Inference``.

    Returns:
        The rendered HTML page.
    """
    protein_seq = drug_seq = structure_seq = ""
    tmp_structure_path = ""
    result_html = None

    if request.method == "POST":
        # The form index is 1-based; visualize_attention expects 0-based.
        # isdecimal() only accepts characters that int() can parse.
        drug_idx_raw = request.form.get("drug_idx", "")
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        upload = request.files.get("structure_file")
        if upload and upload.filename:
            fname = secure_filename(upload.filename)
            if fname:  # an empty name would make the path the temp dir itself
                tmp_structure_path = os.path.join(tempfile.gettempdir(), fname)
                upload.save(tmp_structure_path)
        else:
            tmp_structure_path = _safe_tmp_structure_path(
                request.form.get("tmp_structure_path", ""))

        if "clear" in request.form:
            tmp_structure_path = ""  # the sequences are already empty

        elif "confirm_structure" in request.form and tmp_structure_path:
            structure_seq, struct_err = _structure_sequence(tmp_structure_path)
            protein_seq = structure_seq
            drug_seq, drug_err = _normalise_drug_input(
                request.form.get("drug_sequence", ""))
            result_html = "".join(e for e in (struct_err, drug_err) if e) or None

        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            drug_seq, result_html = _normalise_drug_input(
                request.form.get("drug_sequence", ""))
            if protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                else:
                    # Without a checkpoint the model keeps its initial weights.
                    app.logger.warning("Checkpoint not found: %s", ckpt)
                result_html = visualize_attention(model, feats, drug_idx)

    return render_template_string(
        _INDEX_TEMPLATE,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
469
+
470
+ # ───── run ─────────────────────────────────────────────────────
471
if __name__ == "__main__":
    # The server listens on all interfaces, and the Werkzeug debugger allows
    # arbitrary code execution, so debug mode is opt-in via FLASK_DEBUG
    # instead of always on.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug, host="0.0.0.0", port=7860)
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Flask
2
+ torch
3
+ transformers
4
+ IPython
5
+ selfies
6
+ rdkit
7
+ biopython
8
+ matplotlib
9
+ scikit-learn
10
+ numpy
11
+ pandas
app.py CHANGED
@@ -1,209 +1,66 @@
1
- import os
2
- import sys
3
- import argparse
4
- import torch
5
  from torch.utils.data import DataLoader
6
- from transformers import EsmForMaskedLM, AutoModel, EsmTokenizer
 
 
 
 
 
 
 
 
 
7
  from utils.drug_tokenizer import DrugTokenizer
 
8
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
9
- from bertviz import head_view
10
- import tempfile
11
- from flask import Flask, request, render_template_string
12
 
13
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
14
- sys.path.append("../")
 
15
 
16
- app = Flask(__name__)
 
 
 
 
 
 
 
 
 
 
17
 
18
  def parse_config():
19
- parser = argparse.ArgumentParser()
20
- parser.add_argument('-f')
21
- parser.add_argument("--prot_encoder_path", type=str, default="westlake-repl/SaProt_650M_AF2", help="path/name of protein encoder model located")
22
- parser.add_argument("--drug_encoder_path", type=str, default="HUBioDataLab/SELFormer", help="path/name of SMILE pre-trained language model")
23
- parser.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
24
- parser.add_argument("--fusion", default="CAN", type=str, help="{CAN|BAN}")
25
- parser.add_argument("--batch_size", type=int, default=64)
26
- parser.add_argument("--group_size", type=int, default=1)
27
- parser.add_argument("--lr", type=float, default=1e-4)
28
- parser.add_argument("--dropout", type=float, default=0.1)
29
- parser.add_argument("--test", type=int, default=0)
30
- parser.add_argument("--use_pooled", action="store_true", default=True)
31
- parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
32
- parser.add_argument("--save_path_prefix", type=str, default="save_model_ckp/", help="save the result in which directory")
33
- parser.add_argument("--save_name", default="fine_tune", type=str, help="the name of the saved file")
34
- parser.add_argument("--dataset", type=str, default="Human", help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
35
- return parser.parse_args()
36
 
37
  args = parse_config()
38
- device = args.device
39
 
 
40
  prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
41
- drug_tokenizer = DrugTokenizer()
42
-
43
- prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
44
- drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
45
 
46
- encoding = Pre_encoded(prot_model, drug_model, args).to(device)
 
47
 
48
- def get_case_feature(model, dataloader, device):
49
- with torch.no_grad():
50
- for step, batch in enumerate(dataloader):
51
- prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask, label = batch
52
- prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask = \
53
- prot_input_ids.to(device), prot_attention_mask.to(device), drug_input_ids.to(device), drug_attention_mask.to(device)
54
-
55
- prot_embed, drug_embed = model.encoding(prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask)
56
- prot_embed, drug_embed = prot_embed.cpu(), drug_embed.cpu()
57
- prot_input_ids, drug_input_ids = prot_input_ids.cpu(), drug_input_ids.cpu()
58
- prot_attention_mask, drug_attention_mask = prot_attention_mask.cpu(), drug_attention_mask.cpu()
59
- label = label.cpu()
60
-
61
- return [(prot_embed, drug_embed, prot_input_ids, drug_input_ids, prot_attention_mask, drug_attention_mask, label)]
62
 
63
- def visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer):
64
- model.eval()
65
- with torch.no_grad():
66
- for batch in case_features:
67
- prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, label = batch
68
- prot, drug = prot.to(device), drug.to(device)
69
- prot_mask, drug_mask = prot_mask.to(device), drug_mask.to(device)
70
-
71
- output, attention_weights = model(prot, drug, prot_mask, drug_mask)
72
- prot_tokens = [prot_tokenizer.decode([pid.item()], skip_special_tokens=True) for pid in prot_ids.squeeze()]
73
- drug_tokens = [drug_tokenizer.decode([did.item()], skip_special_tokens=True) for did in drug_ids.squeeze()]
74
- tokens = prot_tokens + drug_tokens
75
-
76
- attention_weights = attention_weights.unsqueeze(1)
77
-
78
- # Generate HTML content using head_view with html_action='return'
79
- html_head_view = head_view(attention_weights, tokens, sentence_b_start=512, html_action='return')
80
-
81
- # Parse the HTML and modify it to replace sentence labels
82
- html_content = html_head_view.data
83
- html_content = html_content.replace("Sentence A -> Sentence A", "Protein -> Protein")
84
- html_content = html_content.replace("Sentence B -> Sentence B", "Drug -> Drug")
85
- html_content = html_content.replace("Sentence A -> Sentence B", "Protein -> Drug")
86
- html_content = html_content.replace("Sentence B -> Sentence A", "Drug -> Protein")
87
-
88
- # Save the modified HTML content to a temporary file
89
- with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
90
- f.write(html_content.encode('utf-8'))
91
- temp_file_path = f.name
92
-
93
- return temp_file_path
94
-
95
- @app.route('/', methods=['GET', 'POST'])
96
- def index():
97
- protein_sequence = ""
98
- drug_sequence = ""
99
- result = None
100
-
101
- if request.method == 'POST':
102
- if 'clear' in request.form:
103
- protein_sequence = ""
104
- drug_sequence = ""
105
- else:
106
- protein_sequence = request.form['protein_sequence']
107
- drug_sequence = request.form['drug_sequence']
108
-
109
- dataset = [(protein_sequence, drug_sequence, 1)]
110
- dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn_batch_encoding)
111
-
112
- case_features = get_case_feature(encoding, dataloader, device)
113
- model = FusionDTI(446, 768, args).to(device)
114
-
115
- best_model_dir = f"{args.save_path_prefix}{args.dataset}_{args.fusion}"
116
- checkpoint_path = os.path.join(best_model_dir, 'best_model.ckpt')
117
-
118
- if os.path.exists(checkpoint_path):
119
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
120
-
121
- html_file_path = visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer)
122
-
123
- with open(html_file_path, 'r') as f:
124
- result = f.read()
125
-
126
- return render_template_string('''
127
- <html>
128
- <head>
129
- <title>Drug Target Interaction Visualization</title>
130
- <style>
131
- body { font-family: 'Times New Roman', Times, serif; margin: 40px; }
132
- h2 { color: #333; }
133
- .container { display: flex; }
134
- .left { flex: 1; padding-right: 20px; }
135
- .right { flex: 1; }
136
- textarea {
137
- width: 100%;
138
- padding: 12px 20px;
139
- margin: 8px 0;
140
- display: inline-block;
141
- border: 1px solid #ccc;
142
- border-radius: 4px;
143
- box-sizing: border-box;
144
- font-size: 16px;
145
- font-family: 'Times New Roman', Times, serif;
146
- }
147
- .button-container {
148
- display: flex;
149
- justify-content: space-between;
150
- }
151
- input[type="submit"], .button {
152
- width: 48%;
153
- color: white;
154
- padding: 14px 20px;
155
- margin: 8px 0;
156
- border: none;
157
- border-radius: 4px;
158
- cursor: pointer;
159
- font-size: 16px;
160
- font-family: 'Times New Roman', Times, serif;
161
- }
162
- .submit {
163
- background-color: #FFA500;
164
- }
165
- .submit:hover {
166
- background-color: #FF8C00;
167
- }
168
- .clear {
169
- background-color: #D3D3D3;
170
- }
171
- .clear:hover {
172
- background-color: #A9A9A9;
173
- }
174
- .result {
175
- font-size: 18px;
176
- }
177
- </style>
178
- </head>
179
- <body>
180
- <h2 style="text-align: center;">Drug Target Interaction Visualization</h2>
181
- <div class="container">
182
- <div class="left">
183
- <form method="post">
184
- <label for="protein_sequence">Protein Sequence:</label>
185
- <textarea id="protein_sequence" name="protein_sequence" rows="4" placeholder="Enter protein sequence here..." required>{{ protein_sequence }}</textarea><br>
186
- <label for="drug_sequence">Drug Sequence:</label>
187
- <textarea id="drug_sequence" name="drug_sequence" rows="4" placeholder="Enter drug sequence here..." required>{{ drug_sequence }}</textarea><br>
188
- <div class="button-container">
189
- <input type="submit" name="submit" class="button submit" value="Submit">
190
- <input type="submit" name="clear" class="button clear" value="Clear">
191
- </div>
192
- </form>
193
- </div>
194
- <div class="right" style="display: flex; justify-content: center; align-items: center;">
195
- {% if result %}
196
- <div class="result">
197
- {{ result|safe }}
198
- </div>
199
- {% endif %}
200
- </div>
201
- </div>
202
- </body>
203
- </html>
204
- ''', protein_sequence=protein_sequence, drug_sequence=drug_sequence, result=result)
205
-
206
- def collate_fn_batch_encoding(batch):
207
  query1, query2, scores = zip(*batch)
208
 
209
  query_encodings1 = prot_tokenizer.batch_encode_plus(
@@ -228,6 +85,388 @@ def collate_fn_batch_encoding(batch):
228
  attention_mask2 = query_encodings2["attention_mask"].bool()
229
 
230
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
- if __name__ == '__main__':
 
233
  app.run(debug=True, host="0.0.0.0", port=7860)
 
1
+ import os, sys, argparse, tempfile, shutil, base64, io
2
+ from flask import Flask, request, render_template_string
3
+ from werkzeug.utils import secure_filename
 
4
  from torch.utils.data import DataLoader
5
+ import selfies
6
+ from rdkit import Chem
7
+
8
+ import torch
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib import cm
13
+ from typing import Optional
14
+
15
  from utils.drug_tokenizer import DrugTokenizer
16
+ from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
17
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
18
+ from utils.foldseek_util import get_struc_seq
 
 
19
 
20
+ # ───── Biopython fallback ───────────────────────────────────────
21
+ from Bio.PDB import PDBParser, MMCIFParser
22
+ from Bio.Data import IUPACData
23
 
24
# Three-letter → one-letter residue codes with upper-cased keys.  Selenocysteine
# and pyrrolysine are folded onto the standard residues C and K.
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"SEC": "C", "PYL": "K"})


def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter amino-acid sequence of the first chain in *path*.

    Biopython-only fallback used when Foldseek cannot produce a
    structure-aware sequence.

    Args:
        path: Path to a PDB or mmCIF file.  Files ending in ``.cif`` or
            ``.mmcif`` (case-insensitive) are parsed as mmCIF, everything
            else as PDB.

    Returns:
        One character per residue of the first chain; residues whose name is
        not in ``three2one`` are rendered as ``"X"``.

    Raises:
        ValueError: If the structure contains no chain.
    """
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)

    chain = next(structure.get_chains(), None)
    if chain is None:
        raise ValueError(f"No chain found in structure file: {path}")

    # Biopython marks water molecules with the hetero-flag "W" in the residue
    # id; they are not part of the polymer and would otherwise show up as "X".
    return "".join(
        three2one.get(res.get_resname().upper(), "X")
        for res in chain
        if res.id[0] != "W"
    )
30
+
31
+ # ───── global paths / args ──────────────────────────────────────
32
+ FOLDSEEK_BIN = shutil.which("foldseek")
33
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
34
+ sys.path.append("..")
35
 
36
def parse_config(argv=None):
    """Build the command-line configuration for the demo app.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    # Accepted and ignored: Jupyter passes "-f <connection file>" to kernels.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # The help text used to sit outside the call parentheses, which made the
    # whole module fail to import with a SyntaxError.
    p.add_argument(
        "--dataset",
        default="BindingDB",
        help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')",
    )
    return p.parse_args(argv)
 
 
 
 
 
49
 
50
# Parsed once at import time; every component below reads from it.
args = parse_config()
DEVICE = args.device

# ───── tokenisers & encoders ────────────────────────────────────
# Protein side: tokenizer and masked-LM weights come from the same checkpoint.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)

# Drug side: local SELFIES tokenizer plus the pretrained drug encoder.
drug_tokenizer = DrugTokenizer()
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)

# Wrapper that runs both encoders; moved to the configured device.
encoding = Pre_encoded(prot_model, drug_model, args)
encoding = encoding.to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # ─── collate fn ────────────────────────────────────────────────
63
+ def collate_fn(batch):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  query1, query2, scores = zip(*batch)
65
 
66
  query_encodings1 = prot_tokenizer.batch_encode_plus(
 
85
  attention_mask2 = query_encodings2["attention_mask"].bool()
86
 
87
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
88
+ # def collate_fn_batch_encoding(batch):
89
+
90
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Translate a SMILES string into SELFIES.

    The input is first checked with RDKit; only strings RDKit can parse are
    handed to the SELFIES encoder.

    Returns:
        The SELFIES string, or ``None`` when RDKit rejects the input or any
        step raises.
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Best effort: any parser/encoder failure is reported as "no result".
        return None
99
+
100
+
101
+ # ───── single-case embedding ───────────────────────────────────
102
def get_case_feature(model, loader):
    """Encode the first batch of *loader* and return its features on the CPU.

    Returns a one-element list holding the tuple
    ``(prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask, None)``.
    Only the first batch is consumed; an empty loader yields ``None``.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            prot_ids, prot_mask, drug_ids, drug_mask, _unused = batch
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            case = (
                prot_emb.cpu(), drug_emb.cpu(),
                prot_ids.cpu(), drug_ids.cpu(),
                prot_mask.cpu(), drug_mask.cpu(),
                None,
            )
            return [case]
    return None
112
+
113
+ # ───── helper:过滤特殊 token ───────────────────────────────────
114
def clean_tokens(ids, tokenizer):
    """Decode *ids* to token strings and drop the tokenizer's special tokens.

    Args:
        ids: Object with a ``tolist()`` method yielding token ids
            (e.g. a 1-D tensor).
        tokenizer: Object exposing ``convert_ids_to_tokens`` and
            ``all_special_tokens``.

    Returns:
        The non-special tokens, in their original order.
    """
    # Read the special-token list once; it was previously re-evaluated for
    # every token in the comprehension.
    specials = set(tokenizer.all_special_tokens)
    toks = tokenizer.convert_ids_to_tokens(ids.tolist())
    return [t for t in toks if t not in specials]
117
+
118
+ # ───── visualisation ───────────────────────────────────────────
119
+
120
def _non_special_tokens(ids, tokenizer):
    """Return ``(positions, tokens)`` for the entries of *ids* that are not special tokens."""
    specials = set(tokenizer.all_special_tokens)
    decoded = tokenizer.convert_ids_to_tokens(ids.tolist())
    kept = [(pos, tok) for pos, tok in enumerate(decoded) if tok not in specials]
    return [pos for pos, _ in kept], [tok for _, tok in kept]


def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    The token index shown on the axes (and accepted via *drug_idx*) is the
    position of the token among the non-special tokens of its sequence,
    counted *before* any pruning or truncation (1-based in the labels,
    0-based for the function argument).

    Parameters
    ----------
    model :
        Fusion model; called as ``model(p_emb, d_emb, p_mask, d_mask)`` and
        expected to return ``(score, attention)``.
    feats :
        Output of ``get_case_feature``; only ``feats[0]`` is used.
    drug_idx :
        Optional 0-based drug-token index for the Top-20 table.

    Returns
    -------
    html : str
        Optional HTML table followed by a base64-embedded PNG heat-map that
        links to a full-size view, plus a PDF download link.
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ──────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention ──────────────────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # used below as (n_p, n_d)

    # ── decode tokens and remember where the non-special ones sit ───────
    p_pos, p_tokens = _non_special_tokens(p_ids[0], prot_tokenizer)
    d_pos, d_tokens = _non_special_tokens(d_ids[0], drug_tokenizer)

    if attn.size(0) == p_ids.size(1) and attn.size(1) == d_ids.size(1):
        # One attention row/column per input token: select exactly the
        # non-special positions so that labels and values stay aligned.
        # (Dropping special tokens from the labels only, as before, shifted
        # every label by the number of leading special tokens.)
        attn = attn.index_select(0, torch.tensor(p_pos, dtype=torch.long))
        attn = attn.index_select(1, torch.tensor(d_pos, dtype=torch.long))
    else:
        # Attention was computed on a different grid (e.g. grouped tokens):
        # fall back to trimming both sides to a common size.
        p_tokens = p_tokens[: attn.size(0)]
        d_tokens = d_tokens[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens)]

    if attn.numel() == 0:
        return "<p style='color:red'><strong>No tokens to visualise.</strong></p>"

    p_indices = list(range(1, len(p_tokens) + 1))
    d_indices = list(range(1, len(d_tokens) + 1))

    # ── adaptive sparsity pruning ───────────────────────────────────────
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr

    # Keep everything when pruning would leave fewer than 3 rows / columns.
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    rows, cols = row_keep.tolist(), col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(rows, p_tokens) if keep]
    p_indices = [idx for keep, idx in zip(rows, p_indices) if keep]
    d_tokens = [tok for keep, tok in zip(cols, d_tokens) if keep]
    d_indices = [idx for keep, idx in zip(cols, d_indices) if keep]

    # ── cap column count at 150 for readability ─────────────────────────
    if attn.size(1) > 150:
        # Pick the 150 strongest columns, then restore sequence order so the
        # x-axis still reads left-to-right along the drug.
        top_cols = torch.topk(attn.sum(0), k=150).indices.sort().values
        attn = attn[:, top_cols]
        picked = top_cols.tolist()
        d_tokens = [d_tokens[i] for i in picked]
        d_indices = [d_indices[i] for i in picked]

    # ── draw heat-map ───────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    im = ax.imshow(attn.numpy(), aspect="auto",
                   cmap=cm.viridis, interpolation="nearest")

    ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

    ax.set_xticks(range(len(x_labels)))
    ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                       ha="center", va="center")
    ax.tick_params(axis="x", top=True, bottom=False,
                   labeltop=True, labelbottom=False, pad=27)

    ax.set_yticks(range(len(y_labels)))
    ax.set_yticklabels(y_labels, fontsize=7)
    # top/bottom/labeltop/labelbottom only apply to the x-axis; for the
    # y-axis the padding is the only setting that had any effect.
    ax.tick_params(axis="y", pad=10)

    fig.colorbar(im, fraction=0.026, pad=0.01)
    fig.tight_layout()

    # Render once per format *before* closing the figure: PNG for the inline
    # preview (raster), PDF for the high-resolution download (vector).
    buf_png = io.BytesIO()
    fig.savefig(buf_png, format="png", dpi=140)
    buf_pdf = io.BytesIO()
    fig.savefig(buf_pdf, format="pdf")
    plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

    # ── Top-20 table (only when a drug token was requested) ─────────────
    table_html = ""
    if drug_idx is not None:
        # Map the original 0-based drug_idx onto the current column position;
        # if that token was pruned, fall back to the column at that offset.
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            col_pos = drug_idx
        else:
            col_pos = None

        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

            rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
            res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
            pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)

            drug_tok_text = d_tokens[col_pos]
            orig_idx = d_indices[col_pos]

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")

    # ── zoomable preview + downloadable PDF ─────────────────────────────
    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    return table_html + html_heat
279
+
280
+
281
+ # ───── Flask app ───────────────────────────────────────────────
282
+ app = Flask(__name__)
283
+
284
# Page template for the single-route UI.
_INDEX_TEMPLATE = """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
      border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
      border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
     font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
  <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#10b981,#059669);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    🌐 Project Page
  </a>

  <a href="https://arxiv.org/abs/2406.01651" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#ef4444,#dc2626);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    📄 ArXiv: 2406.01651
  </a>

  <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    💻 GitHub Repo
  </a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
  <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
  <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
    <li><strong>Convert protein structure into a structure-aware sequence:</strong>
        Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using
        <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
        based on 3D structures from
        <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

    <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
        you must first visit the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
        or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
        to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

    <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
        You can enter a SELFIES string directly, or paste a SMILES string.
        SMILES will be automatically converted to SELFIES using
        <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
        If conversion fails, a red error message will be displayed.</li>

    <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
        to highlight the Top-20 interacting protein residues.</li>

    <li>After inference, you can use the
        “Download PDF” link to export a high-resolution vector version.</li>
  </ul>
</div>

<div class="card">
  <form method="POST" enctype="multipart/form-data" class="grid">

    <div><label>Protein Structure (.pdb / .cif)</label>
      <input type="file" name="structure_file">
      <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

    <div><label>Protein Sequence</label>
      <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

    <div><label>Drug Sequence (SELFIES/SMILES)</label>
      <textarea name="drug_sequence" placeholder="[C][C][O] / CCO …">{{ drug_seq }}</textarea></div>

    <label>Drug atom/substructure index (1-based) – show Top-20 related protein residues</label>
    <input type="number" name="drug_idx" min="1" style="width:120px">

    <div class="grid grid-2">
      <button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
      <button class="btn btn-primary" type="submit" name="Inference">Inference</button>
    </div>
    <button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
  </form>

  {% if structure_seq %}
  <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
  {% endif %}
  {% if result_html %}
  <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
  {% endif %}
</div></body></html>
"""


def _trusted_tmp_path(path: str) -> str:
    """Return *path* only if it is an existing file inside the temp directory.

    The path is round-tripped through a hidden form field and is therefore
    client-controlled; anything outside ``tempfile.gettempdir()`` is dropped
    so the parsers cannot be pointed at arbitrary server files.
    """
    if not path:
        return ""
    tmp_root = os.path.realpath(tempfile.gettempdir())
    candidate = os.path.realpath(path)
    try:
        inside = os.path.commonpath([tmp_root, candidate]) == tmp_root
    except ValueError:  # e.g. paths on different drives
        inside = False
    return candidate if inside and os.path.isfile(candidate) else ""


def _prepare_drug_sequence(raw: str):
    """Normalise the drug field to SELFIES.

    Returns ``(selfies, error_html)``.  Blank input yields ``("", None)``;
    input starting with ``[`` is treated as SELFIES; everything else is
    treated as SMILES and converted.
    """
    text = raw.strip()
    if not text or text.startswith("["):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return "", ("<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
                "Please check the input string.</strong></p>")


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI and handle its three form actions.

    * ``clear`` resets every field.
    * ``confirm_structure`` derives a protein sequence from the uploaded
      structure (Foldseek first, Biopython as fallback).
    * ``Inference`` runs the model and renders the attention heat-map.

    Submitted sequences are echoed back into the form; a SMILES drug input
    is converted to SELFIES for both non-clear actions.
    """
    protein_seq = drug_seq = structure_seq = ""
    result_html = None
    tmp_structure_path = ""

    if request.method == "POST":
        drug_idx_raw = request.form.get("drug_idx", "")
        # isdecimal() only accepts characters int() can parse.
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        upload = request.files.get("structure_file")
        safe_name = secure_filename(upload.filename) if upload and upload.filename else ""
        if safe_name:
            tmp_structure_path = os.path.join(tempfile.gettempdir(), safe_name)
            upload.save(tmp_structure_path)
        else:
            tmp_structure_path = _trusted_tmp_path(request.form.get("tmp_structure_path", ""))

        if "clear" in request.form:
            tmp_structure_path = ""
        else:
            protein_seq = request.form.get("protein_sequence", "")
            drug_seq, drug_error = _prepare_drug_sequence(request.form.get("drug_sequence", ""))
            if drug_error:
                result_html = drug_error

            if "confirm_structure" in request.form and tmp_structure_path:
                try:
                    parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None, plddt_mask=False)
                    first_chain = next(iter(parsed))
                    _, _, structure_seq = parsed[first_chain]
                except Exception:
                    # Deliberate best effort: any Foldseek failure falls back
                    # to the plain Biopython sequence.
                    try:
                        structure_seq = simple_seq_from_structure(tmp_structure_path)
                    except Exception:
                        structure_seq = ""
                        result_html = ("<p style='color:red'><strong>Could not read the uploaded "
                                       "structure file. Please upload a valid .pdb or .cif file."
                                       "</strong></p>")
                if structure_seq:
                    protein_seq = structure_seq

            elif "Inference" in request.form and protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                warning = ""
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                else:
                    # Without the checkpoint the fusion layers are randomly
                    # initialised; tell the user instead of silently plotting.
                    warning = ("<p style='color:red'><strong>Model checkpoint not found; "
                               "the attention map below comes from untrained weights."
                               "</strong></p>")
                result_html = warning + visualize_attention(model, feats, drug_idx)

    return render_template_string(
        _INDEX_TEMPLATE,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
469
 
470
+ # ───── run ─────────────────────────────────────────────────────
471
if __name__ == "__main__":
    # The server listens on all interfaces, so the Werkzeug debugger (which
    # allows arbitrary code execution) must stay off unless explicitly
    # requested through the FLASK_DEBUG environment variable.
    debug_mode = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(debug=debug_mode, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -1,5 +1,11 @@
1
  Flask
2
  torch
3
  transformers
4
- bertviz
5
- IPython
 
 
 
 
 
 
 
1
  Flask
2
  torch
3
  transformers
4
+ IPython
5
+ selfies
6
+ rdkit
7
+ biopython
8
+ matplotlib
9
+ scikit-learn
10
+ numpy
11
+ pandas
utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.nn import functional as F
6
+
7
class DrugTokenizer:
    """Minimal SELFIES tokenizer backed by a JSON vocabulary.

    A SELFIES string is split into its bracketed symbols; each symbol is
    looked up in the vocabulary *without* its brackets.  Every encoded
    sequence is wrapped in the CLS and SEP tokens.
    """

    # Matches the content of each "[...]" symbol.
    _TOKEN_RE = re.compile(r'\[([^\[\]]+)\]')

    def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
        """Load the vocabulary and special-token map from the given JSON files."""
        self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
        self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
        self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
        self.unk_token_id = self.vocab[self.special_tokens['unk_token']]
        self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
        self.id_to_token = {v: k for k, v in self.vocab.items()}

        self.all_special_tokens = list(self.special_tokens.values())

    def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
        """Read both JSON files and return ``(vocab, special_tokens)``.

        Each entry of the special-token file is expected to be a mapping with
        a ``'content'`` key; only that content string is kept.
        """
        with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
            vocab = json.load(vocab_file)
        with open(special_tokens_path, 'r', encoding='utf-8') as special_tokens_file:
            special_tokens_raw = json.load(special_tokens_file)

        special_tokens = {key: value['content'] for key, value in special_tokens_raw.items()}
        return vocab, special_tokens

    def encode(self, sequence):
        """Encode one SELFIES string as ``[CLS] ids... [SEP]``.

        Symbols missing from the vocabulary map to the unk id.  Returns a
        dict with ``'input_ids'`` and an all-ones ``'attention_mask'``.
        """
        tokens = self._TOKEN_RE.findall(sequence)
        input_ids = [self.cls_token_id] + [self.vocab.get(token, self.unk_token_id) for token in tokens] + [self.sep_token_id]
        attention_mask = [1] * len(input_ids)
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }

    def batch_encode_plus(self, sequences, max_length, padding=True, truncation=True,
                          add_special_tokens=True, return_tensors="pt"):
        """Encode *sequences* into fixed-length ``torch.long`` tensors.

        Every sequence is cut or right-padded to exactly *max_length*.
        ``padding``, ``truncation``, ``add_special_tokens`` and
        ``return_tensors`` exist only for call-compatibility with Hugging
        Face tokenizers; their values are ignored.
        """
        input_ids_list = []
        attention_mask_list = []

        for sequence in sequences:
            encoded = self.encode(sequence)
            # Slicing is a no-op for sequences that already fit.
            input_ids = encoded['input_ids'][:max_length]
            attention_mask = encoded['attention_mask'][:max_length]

            pad_length = max_length - len(input_ids)
            if pad_length > 0:
                input_ids = input_ids + [self.pad_token_id] * pad_length
                attention_mask = attention_mask + [0] * pad_length

            input_ids_list.append(input_ids)
            attention_mask_list.append(attention_mask)

        return {
            'input_ids': torch.tensor(input_ids_list, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask_list, dtype=torch.long)
        }

    def decode(self, input_ids, skip_special_tokens=False):
        """Turn ids back into a bracketed SELFIES-style string.

        *input_ids* may be a list of ints or a 1-D tensor.  With
        ``skip_special_tokens`` the CLS, SEP and PAD ids are dropped (the
        unk id is kept).  Unknown ids are rendered as the unk token.
        """
        skipped = {self.cls_token_id, self.sep_token_id, self.pad_token_id} if skip_special_tokens else set()
        unk = self.special_tokens['unk_token']
        tokens = []
        for raw_id in input_ids:
            # int() makes tensor / numpy scalars usable as dictionary keys;
            # without it every tensor element decoded to the unk token.
            token_id = int(raw_id)
            if token_id in skipped:
                continue
            tokens.append(self.id_to_token.get(token_id, unk))
        return ''.join(f'[{token}]' for token in tokens)

    def convert_ids_to_tokens(self, ids):
        """Map ids to token strings; unknown ids become the unk token."""
        unk = self.special_tokens['unk_token']
        return [self.id_to_token.get(int(i), unk) for i in ids]
utils/.ipynb_checkpoints/metric_learning_models_att_maps-checkpoint.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ sys.path.append("../")
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import functional as F
10
+ from torch.cuda.amp import autocast
11
+ from torch.nn import Module
12
+ from tqdm import tqdm
13
+ from torch.nn.utils.weight_norm import weight_norm
14
+ from torch.utils.data import Dataset
15
+
16
+ LOGGER = logging.getLogger(__name__)
17
+
18
class FusionDTI(nn.Module):
    """Fusion head that scores a protein/drug pair from token embeddings.

    ``args.fusion`` selects the fusion module:

    * ``"CAN"`` – cross-attention layer, classifier on a 1024-d joint vector.
    * ``"BAN"`` – bilinear attention layer, classifier on a 256-d joint vector.
    * ``"Nan"`` – no fusion: mean-pooled raw embeddings are concatenated.
    """

    def __init__(self, prot_out_dim, disease_out_dim, args):
        """
        Args:
            prot_out_dim: Feature size of the protein token embeddings.
            disease_out_dim: Feature size of the drug token embeddings.
            args: Namespace providing ``fusion`` (and whatever the selected
                fusion layer reads from it).
        """
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            # The classifier sees the two mean-pooled raw embeddings side by
            # side, so its input width is the sum of both feature sizes
            # (1214 for the 446/768 encoders this was hard-coded for).
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=prot_out_dim + disease_out_dim)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Score one batch.

        Returns:
            ``(score, att)``.  ``att`` is the attention produced by the
            fusion layer, or ``None`` in ``"Nan"`` mode, which has none.

        Raises:
            ValueError: If ``self.fusion`` is not one of the supported modes.
        """
        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)  # [batch_size, hidden]
            drug_embed = drug_embed.mean(1)  # [batch_size, hidden]
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            # No attention in this mode; returning None avoids the
            # UnboundLocalError the shared ``return score, att`` used to hit.
            return self.mlp_classifier_nan(joint_embed), None

        prot_embed = self.prot_reg(prot_embed)
        drug_embed = self.drug_reg(drug_embed)

        if self.fusion == "CAN":
            joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
        elif self.fusion == "BAN":
            joint_embed, att = self.ban_layer(prot_embed, drug_embed)
        else:
            raise ValueError(
                f"Unsupported fusion mode {self.fusion!r}; expected 'CAN', 'BAN' or 'Nan'"
            )

        score = self.mlp_classifier(joint_embed)
        return score, att
53
+
54
class Pre_encoded(nn.Module):
    """Thin wrapper that runs the protein and drug encoders side by side."""

    def __init__(self, prot_encoder, drug_encoder, args):
        """Store the two encoders.

        Args:
            prot_encoder: Protein structure-aware sequence encoder.
            drug_encoder: Drug SELFIES encoder.
            args: Run configuration; accepted but not used by this class.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.drug_encoder = drug_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Encode one batch and return ``(prot_embed, drug_embed)``.

        The protein embedding is the encoder output's ``logits``; the drug
        embedding is the encoder output's ``last_hidden_state``.
        """
        prot_out = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            return_dict=True,
        )
        drug_out = self.drug_encoder(
            input_ids=drug_input_ids,
            attention_mask=drug_attention_mask,
            return_dict=True,
        )
        return prot_out.logits, drug_out.last_hidden_state
83
+
84
+
85
+ class CAN_Layer(nn.Module):
86
+ def __init__(self, hidden_dim, num_heads, args):
87
+ super(CAN_Layer, self).__init__()
88
+ self.agg_mode = args.agg_mode
89
+ self.group_size = args.group_size # Control Fusion Scale
90
+ self.hidden_dim = hidden_dim
91
+ self.num_heads = num_heads
92
+ self.head_size = hidden_dim // num_heads
93
+
94
+ self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
95
+ self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
96
+ self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
97
+
98
+ self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
99
+ self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
100
+ self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
101
+
102
+ def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
103
+ N, L1, L2, H = logits.shape
104
+ mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
105
+ mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
106
+ mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)
107
+
108
+ logits = torch.where(mask_pair, logits, logits - inf)
109
+ alpha = torch.softmax(logits, dim=2)
110
+ mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
111
+ alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
112
+ return alpha
113
+
114
+ def apply_heads(self, x, n_heads, n_ch):
115
+ s = list(x.size())[:-1] + [n_heads, n_ch]
116
+ return x.view(*s)
117
+
118
+ def group_embeddings(self, x, mask, group_size):
119
+ N, L, D = x.shape
120
+ groups = L // group_size
121
+ x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
122
+ mask_grouped = mask.view(N, groups, group_size).any(dim=2)
123
+ return x_grouped, mask_grouped
124
+
125
def forward(self, protein, drug, mask_prot, mask_drug):
    """Fuse protein and drug token embeddings with cross attention.

    Both inputs are first pooled with ``self.group_embeddings`` using
    ``self.group_size``. Each modality then attends to itself and to the
    other one, and the two attended results are averaged.

    Args:
        protein: Protein token embeddings, passed to ``group_embeddings``.
        drug: Drug token embeddings, passed to ``group_embeddings``.
        mask_prot: Token mask matching ``protein``.
        mask_drug: Token mask matching ``drug``.

    Returns:
        A pair ``(query_embed, att_pd)``. ``query_embed`` is the pooled
        protein embedding concatenated with the pooled drug embedding along
        dim 1. ``att_pd`` is the protein-to-drug attention averaged over the
        head axis.

    Raises:
        NotImplementedError: If ``self.agg_mode`` is not ``"cls"``,
            ``"mean_all_tok"`` or ``"mean"``.
    """
    prot_tok, prot_valid = self.group_embeddings(protein, mask_prot, self.group_size)
    drug_tok, drug_valid = self.group_embeddings(drug, mask_drug, self.group_size)

    def project(layer, tokens):
        # Linear projection followed by the split into attention heads.
        return self.apply_heads(layer(tokens), self.num_heads, self.head_size)

    q_prot = project(self.query_p, prot_tok)
    k_prot = project(self.key_p, prot_tok)
    v_prot = project(self.value_p, prot_tok)

    q_drug = project(self.query_d, drug_tok)
    k_drug = project(self.key_d, drug_tok)
    v_drug = project(self.value_d, drug_tok)

    score_eq = 'blhd, bkhd->blkh'   # query x key -> per-head scores
    mix_eq = 'blkh, bkhd->blhd'     # attention x value -> per-head output

    alpha_pp = self.alpha_logits(torch.einsum(score_eq, q_prot, k_prot), prot_valid, prot_valid)
    alpha_pd = self.alpha_logits(torch.einsum(score_eq, q_prot, k_drug), prot_valid, drug_valid)
    alpha_dp = self.alpha_logits(torch.einsum(score_eq, q_drug, k_prot), drug_valid, prot_valid)
    alpha_dd = self.alpha_logits(torch.einsum(score_eq, q_drug, k_drug), drug_valid, drug_valid)

    # Average the self-attended and cross-attended representation.
    prot_embedding = (torch.einsum(mix_eq, alpha_pp, v_prot).flatten(-2) +
                      torch.einsum(mix_eq, alpha_pd, v_drug).flatten(-2)) / 2
    drug_embedding = (torch.einsum(mix_eq, alpha_dp, v_prot).flatten(-2) +
                      torch.einsum(mix_eq, alpha_dd, v_drug).flatten(-2)) / 2

    def pool(tokens, valid):
        # Collapse the token axis according to the configured aggregation.
        if self.agg_mode == "cls":
            return tokens[:, 0]
        if self.agg_mode == "mean_all_tok":
            return tokens.mean(1)
        if self.agg_mode == "mean":
            return (tokens * valid.unsqueeze(-1)).sum(1) / valid.sum(-1).unsqueeze(-1)
        raise NotImplementedError()

    prot_embed = pool(prot_embedding, prot_valid)
    drug_embed = pool(drug_embedding, drug_valid)

    query_embed = torch.cat([prot_embed, drug_embed], dim=1)
    att_pd = alpha_pd.mean(dim=-1)
    return query_embed, att_pd
182
+
183
class MlPdecoder_CAN(nn.Module):
    """MLP head with three hidden layers and a sigmoid output.

    Layer widths are ``input_dim -> input_dim -> input_dim // 2 ->
    input_dim // 4 -> 1``. Each hidden layer is Linear, then ReLU, then
    BatchNorm1d.
    """

    def __init__(self, input_dim):
        super().__init__()
        half_dim = input_dim // 2
        quarter_dim = input_dim // 4

        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, half_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.bn3 = nn.BatchNorm1d(quarter_dim)
        self.output = nn.Linear(quarter_dim, 1)

    def forward(self, x):
        """Return sigmoid scores with a trailing dimension of size 1."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in hidden_stages:
            x = norm(torch.relu(linear(x)))
        return torch.sigmoid(self.output(x))
200
+
201
class MLPdecoder_BAN(nn.Module):
    """MLP head: three Linear/ReLU/BatchNorm1d stages, then a sigmoid layer.

    Layer widths are ``in_dim -> hidden_dim -> hidden_dim -> out_dim ->
    binary``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)
        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Return sigmoid scores with a trailing dimension of size ``binary``."""
        hidden_stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        for linear, norm in hidden_stages:
            x = norm(F.relu(linear(x)))
        return torch.sigmoid(self.fc4(x))
219
+
220
class BANLayer(nn.Module):
    """Bilinear attention network layer.

    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py

    Both inputs are projected to ``h_dim * k`` features by an ``FCNet``.
    ``h_out`` bilinear attention maps are computed between every ``v``
    position and every ``q`` position, and each map is used to pool the
    projected features into one fused vector of size ``h_dim``.
    """

    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        super(BANLayer, self).__init__()

        # Up to ``c`` maps use an explicit weight tensor (h_mat / h_bias);
        # above that a weight-normalised linear layer (h_net) is used.
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` (b, v, K) and ``q`` (b, q, K) with ``att_map`` (b, v, q).

        Returns a (b, K) tensor, reduced to (b, K // k) by sum-pooling over
        windows of ``k`` features when ``k > 1``.
        """
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            # Average pooling times k equals sum pooling.
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k
        return fusion_logits

    def forward(self, v, q, softmax=False):
        """Compute the fused representation and the attention maps.

        Args:
            v: First input; ``v.size(1)`` is its number of positions.
            q: Second input; ``q.size(1)`` is its number of positions.
            softmax: If true, normalise each attention map jointly over all
                (v, q) position pairs.

        Returns:
            A pair ``(logits, att_maps)``: the batch-normalised sum of the
            pooled features over all ``h_out`` maps, and the attention maps
            of shape (b, h_out, v, q).
        """
        v_num = v.size(1)
        q_num = q.size(1)

        # Keep the 3-D projections (b x positions x K) in both branches:
        # attention_pooling's einsum needs exactly this layout. The original
        # code overwrote them with 4-D tensors in the h_net branch, which made
        # that branch fail as soon as attention_pooling was called.
        v_ = self.v_net(v)
        q_ = self.q_net(q)

        if self.h_out <= self.c:
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            v_4d = v_.transpose(1, 2).unsqueeze(3)  # b x K x v x 1
            q_4d = q_.transpose(1, 2).unsqueeze(2)  # b x K x 1 x q
            d_ = torch.matmul(v_4d, q_4d)  # b x K x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q

        if softmax:
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)

        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits = logits + self.attention_pooling(v_, q_, att_maps[:, i, :, :])
        logits = self.bn(logits)
        return logits, att_maps
284
+
285
+
286
class FCNet(nn.Module):
    """Simple non-linear fully connected network.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py

    ``dims`` lists the layer widths. Every consecutive pair becomes one
    stage: optional Dropout (when ``dropout > 0``), a weight-normalised
    Linear layer, and an optional activation looked up by name on
    ``torch.nn`` (skipped when ``act`` is the empty string).
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super().__init__()

        def build_stage(in_dim, out_dim):
            # One dropout / linear / activation group.
            stage = []
            if dropout > 0:
                stage.append(nn.Dropout(dropout))
            stage.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
            if act != '':
                stage.append(getattr(nn, act)())
            return stage

        modules = []
        for idx in range(len(dims) - 2):
            modules.extend(build_stage(dims[idx], dims[idx + 1]))
        # The final pair is indexed from the end, as in the original code.
        modules.extend(build_stage(dims[-2], dims[-1]))

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.main(x)
313
+
314
+
315
class BatchFileDataset_Case(Dataset):
    """Dataset in which every item is one pre-saved batch file.

    Each file is read with ``torch.load`` and must contain a mapping with
    the keys listed in ``_FIELDS``.
    """

    # Keys read from each loaded file, in the order they are returned.
    _FIELDS = ('prot', 'drug', 'prot_ids', 'drug_ids', 'prot_mask', 'drug_mask', 'y')

    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load batch file ``idx`` and return its fields as a 7-tuple."""
        payload = torch.load(self.file_list[idx])
        return tuple(payload[field] for field in self._FIELDS)
utils/__pycache__/foldseek_util.cpython-38.pyc ADDED
Binary file (4.86 kB). View file
 
utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc ADDED
Binary file (10.8 kB). View file
 
utils/drug_tokenizer.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
5
  from torch.nn import functional as F
6
 
7
  class DrugTokenizer:
8
- def __init__(self, vocab_path="tokenizer/vocab.json", special_tokens_path="tokenizer/special_tokens_map.json"):
9
  self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
10
  self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
11
  self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
@@ -13,6 +13,8 @@ class DrugTokenizer:
13
  self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
14
  self.id_to_token = {v: k for k, v in self.vocab.items()}
15
 
 
 
16
  def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
17
  with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
18
  vocab = json.load(vocab_file)
@@ -64,3 +66,8 @@ class DrugTokenizer:
64
  tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
65
  sequence = ''.join([f'[{token}]' for token in tokens])
66
  return sequence
 
 
 
 
 
 
5
  from torch.nn import functional as F
6
 
7
  class DrugTokenizer:
8
+ def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
9
  self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
10
  self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
11
  self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
 
13
  self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
14
  self.id_to_token = {v: k for k, v in self.vocab.items()}
15
 
16
+ self.all_special_tokens = list(self.special_tokens.values())
17
+
18
  def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
19
  with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
20
  vocab = json.load(vocab_file)
 
66
  tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
67
  sequence = ''.join([f'[{token}]' for token in tokens])
68
  return sequence
69
+
70
+ # --- Newly added ---
71
def convert_ids_to_tokens(self, ids):
    """Map token ids to token strings (list[int] -> list[str]).

    Ids missing from ``self.id_to_token`` are replaced by the ``unk_token``
    entry of ``self.special_tokens``. They are not dropped, so the result
    always has one token per input id.
    """
    tokens = []
    for token_id in ids:
        tokens.append(self.id_to_token.get(token_id, self.special_tokens['unk_token']))
    return tokens
utils/foldseek_util.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import numpy as np
5
+ import re
6
+ import sys
7
+
8
+ from Bio.PDB import PDBParser, MMCIFParser
9
+
10
+
11
+ sys.path.append(".")
12
+
13
+
14
+ # Get structural seqs from pdb file
15
def get_struc_seq(foldseek,
                  path,
                  chains: "list | None" = None,
                  process_id: int = 0,
                  plddt_mask: "bool | str" = "auto",
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """Get amino-acid and structural (3Di) sequences from a structure file.

    Runs ``foldseek structureto3didescriptor`` on ``path`` and parses the
    tab-separated output.

    Args:
        foldseek: Binary executable file of foldseek.
        path: Path to pdb file.
        chains: Chains to be extracted from pdb file. If None, all chains
            will be extracted.
        process_id: Process ID used in the temporary file name, so parallel
            workers do not collide.
        plddt_mask: If True, mask regions with plddt < plddt_threshold using
            the scores returned by ``extract_plddt``. With the default
            ``"auto"``, masking is enabled when the file content contains
            the text "alphafold" (case-insensitive).
        plddt_threshold: Positions whose plddt is below this value are
            replaced by "#" in the structural sequence.
        foldseek_verbose: If True, foldseek prints verbose messages;
            otherwise ``-v 0`` is passed.

    Returns:
        seq_dict: A dict keyed by chain ID. The values are tuples of
        (seq, struc_seq, combined_seq), where combined_seq interleaves each
        residue with its lower-cased structural token. Only the first
        record of each chain is kept.

    Raises:
        AssertionError: If ``foldseek`` or ``path`` does not exist.
        FileNotFoundError: If foldseek did not produce its output file.
    """
    # Local import so this function stays self-contained.
    import subprocess

    # Raised explicitly (not via ``assert``) so the checks survive ``-O``,
    # while keeping the exception type callers already see.
    if not os.path.exists(foldseek):
        raise AssertionError(f"Foldseek not found: {foldseek}")
    if not os.path.exists(path):
        raise AssertionError(f"PDB file not found: {path}")

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through verbatim.
    cmd = [foldseek, "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]

    try:
        # The exit status is deliberately not checked (as before); a failed
        # run surfaces as a missing output file below.
        subprocess.run(cmd, check=False)

        # Check whether the structure is predicted by AlphaFold2
        if plddt_mask == "auto":
            with open(path, "r") as r:
                plddt_mask = "alphafold" in r.read().lower()

        seq_dict = {}
        name = os.path.basename(path)
        plddts = None  # parsed lazily, once, and reused for every record
        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # Mask low plddt
                if plddt_mask:
                    try:
                        if plddts is None:
                            plddts = extract_plddt(path)
                        if len(plddts) != len(struc_seq):
                            raise ValueError(
                                f"Length mismatch: {len(plddts)} != {len(struc_seq)}")

                        # Mask regions with plddt < threshold
                        indices = np.where(plddts < plddt_threshold)[0]
                        np_seq = np.array(list(struc_seq))
                        np_seq[indices] = "#"
                        struc_seq = "".join(np_seq)

                    except Exception as e:
                        # Best effort: keep the unmasked sequence and report.
                        print(f"Error: {e}")
                        print(f"Failed to mask plddt for {name}")

                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                if chains is None or chain in chains:
                    if chain not in seq_dict:
                        combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                        seq_dict[chain] = (seq, struc_seq, combined_seq)

        return seq_dict
    finally:
        # Always remove foldseek's temporary outputs, even on failure.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)
91
+
92
+
93
def extract_plddt(pdb_path: str) -> np.ndarray:
    """Extract per-residue plddt scores from a structure file.

    The score of a residue is the mean B-factor of its atoms. Only chain
    "A" of the first model is read.

    Args:
        pdb_path: Path to the structure file; must end in '.cif' or '.pdb'.

    Returns:
        plddts: One score per residue of chain "A", as a numpy array.

    Raises:
        ValueError: If the path ends in neither '.cif' nor '.pdb'.
    """
    is_cif = pdb_path.endswith(".cif")
    if not (is_cif or pdb_path.endswith(".pdb")):
        raise ValueError("Invalid file format for plddt extraction. Must be '.cif' or '.pdb'.")

    # Pick the parser that matches the file extension.
    parser = MMCIFParser() if is_cif else PDBParser()
    chain = parser.get_structure('protein', pdb_path)[0]["A"]

    # Average the per-atom B-factor within each residue.
    per_residue = [
        np.mean([atom.get_bfactor() for atom in residue])
        for residue in chain
    ]
    return np.array(per_residue)
127
+
128
+
129
def transform_pdb_dir(foldseek: str, pdb_dir: str, seq_type: str, save_path: str):
    """Transform a directory of pdb files into a fasta file.

    Runs ``foldseek structureto3didescriptor`` on ``pdb_dir`` and writes one
    fasta record per output line.

    Args:
        foldseek: Binary executable file of foldseek.
        pdb_dir: Directory of pdb files.
        seq_type: Type of sequence to be extracted. Must be "aa" or
            "foldseek"; the latter is written lower-cased.
        save_path: Path to save the fasta file.

    Raises:
        AssertionError: If ``foldseek`` does not exist or ``seq_type`` is
            invalid.
        FileNotFoundError: If foldseek did not produce its output file.
    """
    # Local import so this function stays self-contained.
    import subprocess

    # Raised explicitly (not via ``assert``) so the checks survive ``-O``,
    # while keeping the exception type callers already see.
    if not os.path.exists(foldseek):
        raise AssertionError(f"Foldseek not found: {foldseek}")
    if seq_type not in ("aa", "foldseek"):
        raise AssertionError("seq_type must be 'aa' or 'foldseek'!")

    tmp_save_path = f"get_struc_seq_{time.time()}.tsv"

    # Argument list instead of a shell string: directory names containing
    # spaces or shell metacharacters are passed through verbatim.
    cmd = [foldseek, "structureto3didescriptor", "--chain-name-mode", "1",
           pdb_dir, tmp_save_path]

    try:
        # The exit status is deliberately not checked (as before); a failed
        # run surfaces as a missing output file below.
        subprocess.run(cmd, check=False)

        with open(tmp_save_path, "r") as r, open(save_path, "w") as w:
            for line in r:
                protein_id, aa_seq, foldseek_seq = line.strip().split("\t")[:3]

                if seq_type == "aa":
                    w.write(f">{protein_id}\n{aa_seq}\n")
                else:
                    w.write(f">{protein_id}\n{foldseek_seq.lower()}\n")
    finally:
        # Always remove foldseek's temporary outputs, even on failure.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)
159
+
160
+
161
if __name__ == '__main__':
    # Developer smoke test with machine-specific paths.
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq has no ``plddt_path`` parameter (passing it raised
    # TypeError); the scores are read from the structure file itself.
    res = get_struc_seq(foldseek, test_path, plddt_threshold=70.)
    print(res["A"][1].lower())
utils/metric_learning_models_att_maps.py CHANGED
@@ -175,15 +175,10 @@ class CAN_Layer(nn.Module):
175
 
176
  query_embed = torch.cat([prot_embed, drug_embed], dim=1)
177
 
178
-
179
- att = torch.zeros(1, 1, 1024, 1024)
180
- att[:, :, :512, :512] = alpha_pp.mean(dim=-1) # Protein to Protein
181
- att[:, :, :512, 512:] = alpha_pd.mean(dim=-1) # Protein to Drug
182
- att[:, :, 512:, :512] = alpha_dp.mean(dim=-1) # Drug to Protein
183
- att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1) # Drug to Drug
184
 
185
  # print("query_embed:", query_embed.shape)
186
- return query_embed, att
187
 
188
  class MlPdecoder_CAN(nn.Module):
189
  def __init__(self, input_dim):
 
175
 
176
  query_embed = torch.cat([prot_embed, drug_embed], dim=1)
177
 
178
+ att_pd = alpha_pd.mean(dim=-1)
 
 
 
 
 
179
 
180
  # print("query_embed:", query_embed.shape)
181
+ return query_embed, att_pd
182
 
183
  class MlPdecoder_CAN(nn.Module):
184
  def __init__(self, input_dim):