Spaces:
Running
Running
ZhaohanM
committed on
Commit
·
5676c75
1
Parent(s):
fde18d9
Update: SMILES-to-SELFIES conversion, UI polish, and usage guide
Browse files- .ipynb_checkpoints/app-checkpoint.py +472 -0
- .ipynb_checkpoints/requirements-checkpoint.txt +11 -0
- app.py +432 -193
- requirements.txt +8 -2
- utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py +73 -0
- utils/.ipynb_checkpoints/metric_learning_models_att_maps-checkpoint.py +325 -0
- utils/__pycache__/foldseek_util.cpython-38.pyc +0 -0
- utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc +0 -0
- utils/drug_tokenizer.py +8 -1
- utils/foldseek_util.py +167 -0
- utils/metric_learning_models_att_maps.py +2 -7
.ipynb_checkpoints/app-checkpoint.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, argparse, tempfile, shutil, base64, io
|
2 |
+
from flask import Flask, request, render_template_string
|
3 |
+
from werkzeug.utils import secure_filename
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import selfies
|
6 |
+
from rdkit import Chem
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import matplotlib
|
10 |
+
matplotlib.use("Agg")
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
from matplotlib import cm
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
+
from utils.drug_tokenizer import DrugTokenizer
|
16 |
+
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
|
17 |
+
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
18 |
+
from utils.foldseek_util import get_struc_seq
|
19 |
+
|
20 |
+
# ───── Biopython fallback ───────────────────────────────────────
|
21 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
22 |
+
from Bio.Data import IUPACData
|
23 |
+
|
24 |
+
# Lookup table from upper-cased three-letter residue names to one-letter codes,
# seeded from Biopython's IUPAC table. Selenocysteine (SEC) and pyrrolysine
# (PYL) are added explicitly and mapped to "C" and "K" respectively.
three2one = {
    name.upper(): letter
    for name, letter in IUPACData.protein_letters_3to1.items()
}
three2one["SEC"] = "C"
three2one["PYL"] = "K"
|
26 |
+
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the first chain in a structure file.

    This is the Biopython fallback used when the Foldseek-based extraction is
    unavailable or fails.

    Parameters
    ----------
    path : str
        Path to a PDB or mmCIF file. Files ending in ``.cif`` / ``.mmcif``
        (case-insensitive) are read with ``MMCIFParser``; everything else is
        read with ``PDBParser``.

    Returns
    -------
    str
        One letter per residue of the first chain. Residue names missing from
        ``three2one`` are rendered as ``"X"``; water records are skipped.

    Raises
    ------
    ValueError
        If the parsed structure contains no chain at all.
    """
    # Compare case-insensitively so an upload named "MODEL.CIF" is not handed
    # to the PDB parser by mistake.
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)

    chain = next(parser.get_structure("P", path).get_chains(), None)
    if chain is None:
        raise ValueError(f"No chain found in structure file: {path}")

    # Biopython flags water residues with "W" in the first field of the
    # residue id; they are not part of the polymer and would otherwise show
    # up as a run of spurious "X" letters in the sequence.
    return "".join(
        three2one.get(res.get_resname().upper(), "X")
        for res in chain
        if res.get_id()[0] != "W"
    )
|
30 |
+
|
31 |
+
# ───── global paths / args ──────────────────────────────────────
|
32 |
+
FOLDSEEK_BIN = shutil.which("foldseek")
|
33 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
34 |
+
sys.path.append("..")
|
35 |
+
|
36 |
+
def parse_config():
    """Parse the command-line options that configure encoders and checkpoints.

    Returns
    -------
    argparse.Namespace
        Namespace with ``f``, ``prot_encoder_path``, ``drug_encoder_path``,
        ``agg_mode``, ``group_size``, ``lr``, ``fusion``, ``device``,
        ``save_path_prefix`` and ``dataset``.
    """
    p = argparse.ArgumentParser()
    # NOTE(review): "-f" is accepted but never read here — presumably it
    # absorbs the connection-file flag passed by Jupyter kernels; confirm.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # The original closed the call before ``help=``, which is a SyntaxError;
    # ``help`` belongs inside the add_argument() call.
    p.add_argument(
        "--dataset",
        default="BindingDB",
        help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')",
    )
    return p.parse_args()
|
49 |
+
|
50 |
+
args = parse_config()
|
51 |
+
DEVICE = args.device
|
52 |
+
|
53 |
+
# ───── tokenisers & encoders ────────────────────────────────────
|
54 |
+
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
|
55 |
+
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
|
56 |
+
|
57 |
+
drug_tokenizer = DrugTokenizer() # SELFIES
|
58 |
+
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
|
59 |
+
|
60 |
+
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
|
61 |
+
|
62 |
+
# ─── collate fn ────────────────────────────────────────────────
|
63 |
+
def collate_fn(batch):
    """Tokenise a batch of ``(protein, drug, score)`` triples.

    Both sequences are tokenised with the module-level ``prot_tokenizer`` and
    ``drug_tokenizer``, padded/truncated to 512 tokens, and returned as
    PyTorch tensors.

    Returns
    -------
    tuple
        ``(prot_input_ids, prot_mask, drug_input_ids, drug_mask, scores)``
        where the masks are the tokenisers' attention masks cast to bool and
        ``scores`` is a tensor built from the third element of each triple.
    """
    proteins, drugs, labels = zip(*batch)

    # Identical tokenisation settings are applied to both modalities.
    encode_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **encode_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **encode_kwargs)

    scores = torch.tensor(list(labels))
    prot_mask = prot_enc["attention_mask"].bool()
    drug_mask = drug_enc["attention_mask"].bool()

    return prot_enc["input_ids"], prot_mask, drug_enc["input_ids"], drug_mask, scores
|
88 |
+
# def collate_fn_batch_encoding(batch):
|
89 |
+
|
90 |
+
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, or return ``None`` on failure.

    The input is first parsed with RDKit; if RDKit cannot build a molecule
    from it, ``None`` is returned. Otherwise the original SMILES string is
    passed to ``selfies.encoder``. Any exception raised along the way is
    treated as a failed conversion and also yields ``None``.
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Deliberately best-effort: callers only distinguish success/failure.
        return None
|
99 |
+
|
100 |
+
|
101 |
+
# ───── single-case embedding ───────────────────────────────────
|
102 |
+
def get_case_feature(model, loader):
    """Encode the first batch yielded by ``loader`` and return its features.

    The model is switched to eval mode and run without gradient tracking.
    Input ids and masks are moved to ``DEVICE`` for ``model.encoding`` and
    every returned tensor is moved back to the CPU.

    Returns
    -------
    list
        A one-element list holding the tuple
        ``(prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask, None)``.
        If ``loader`` yields nothing, the function falls through and returns
        ``None``.
    """
    model.eval()
    with torch.no_grad():
        for prot_ids, prot_mask, drug_ids, drug_mask, _label in loader:
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            case = (
                prot_emb.cpu(),
                drug_emb.cpu(),
                prot_ids.cpu(),
                drug_ids.cpu(),
                prot_mask.cpu(),
                drug_mask.cpu(),
                None,  # placeholder slot; the label is not carried forward
            )
            # Only a single case is ever needed, so stop after the first batch.
            return [case]
|
112 |
+
|
113 |
+
# ───── helper:过滤特殊 token ───────────────────────────────────
|
114 |
+
def clean_tokens(ids, tokenizer):
    """Decode ``ids`` to token strings and drop the tokenizer's special tokens.

    ``ids`` must expose ``tolist()``; ``tokenizer`` must provide
    ``convert_ids_to_tokens`` and ``all_special_tokens``.
    """
    special = tokenizer.all_special_tokens
    decoded = tokenizer.convert_ids_to_tokens(ids.tolist())
    kept = []
    for token in decoded:
        if token in special:
            continue
        kept.append(token)
    return kept
|
117 |
+
|
118 |
+
# ───── visualisation ───────────────────────────────────────────
|
119 |
+
|
120 |
+
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """Render a Protein → Drug cross-attention heat-map as embeddable HTML.

    Optionally also produces a Top-20 protein-residue table for one chosen
    drug token.

    The token index shown on the x-axis (and accepted via *drug_idx*) is the
    position of that token in the drug sequence after tokenisation and
    removal of special tokens, but before any pruning (1-based in the labels,
    0-based for the function argument).

    Parameters
    ----------
    model
        Fusion model; called as ``model(p_emb, d_emb, p_mask, d_mask)`` and
        expected to return ``(output, attention)``.
    feats
        List whose first element is the 7-tuple produced by
        ``get_case_feature``.
    drug_idx : int, optional
        0-based drug-token index for which to list the most attended protein
        residues. ``None`` disables the table.

    Returns
    -------
    str
        HTML fragment: the optional table followed by a base64-embedded PNG
        preview (click to enlarge) and a link to download a PDF version.
    """
    # ── forward pass: Protein → Drug attention ──────────────────────────────
    model.eval()
    with torch.no_grad():
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # rows: protein tokens, cols: drug tokens

    # ── decode tokens (skip special symbols) ────────────────────────────────
    def clean_ids(ids, tokenizer):
        specials = tokenizer.all_special_tokens
        return [t for t in tokenizer.convert_ids_to_tokens(ids.tolist())
                if t not in specials]

    p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
    d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)

    # ── safety cut-off so labels and attention matrix have equal extents ────
    # NOTE(review): row/column i of the attention matrix is labelled with the
    # i-th *non-special* token. This assumes the model's attention output is
    # already aligned that way (e.g. no leading <cls> row) — confirm.
    n_rows = min(attn.size(0), len(p_tokens_full))
    n_cols = min(attn.size(1), len(d_tokens_full))
    p_tokens = p_tokens_full[:n_rows]
    p_indices_full = list(range(1, n_rows + 1))
    d_tokens_full = d_tokens_full[:n_cols]
    d_indices_full = list(range(1, n_cols + 1))
    attn = attn[:n_rows, :n_cols]

    # Nothing to plot (e.g. every token was special); ``attn.max()`` below
    # would raise on an empty tensor.
    if attn.numel() == 0:
        return ("<p style='color:red'><strong>No attention values to display "
                "for this input.</strong></p>")

    # ── adaptive sparsity pruning ───────────────────────────────────────────
    # Keep rows/columns whose peak attention exceeds 5 % of the global peak,
    # unless that would leave fewer than three of them.
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    row_flags, col_flags = row_keep.tolist(), col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(row_flags, p_tokens) if keep]
    p_indices = [idx for keep, idx in zip(row_flags, p_indices_full) if keep]
    d_tokens = [tok for keep, tok in zip(col_flags, d_tokens_full) if keep]
    d_indices = [idx for keep, idx in zip(col_flags, d_indices_full) if keep]

    # ── cap column count at 150 for readability ─────────────────────────────
    if attn.size(1) > 150:
        topc = torch.topk(attn.sum(0), k=150).indices
        attn = attn[:, topc]
        order = topc.tolist()
        d_tokens = [d_tokens[i] for i in order]
        d_indices = [d_indices[i] for i in order]

    # ── build the Top-20 table (if requested) ───────────────────────────────
    table_html = ""  # start empty so the final concatenation is uniform
    if drug_idx is not None:
        # Map the original 0-based drug_idx to the current column position.
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            # The requested token was pruned: fall back to the column at that
            # position. The header below reports the token actually shown.
            col_pos = drug_idx
        else:
            col_pos = None

        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

            rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
            res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
            pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)

            drug_tok_text = d_tokens[col_pos]
            orig_idx = d_indices[col_pos]

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")

    # ── draw heat-map ───────────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # ── zoomable preview + downloadable vector version ─────────────────
        # The figure is serialised exactly once per format and only closed
        # afterwards (the original rendered an extra, unused PNG and closed
        # the figure before these saves).
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)  # raster preview

        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")           # vector download
    finally:
        # Always release the figure so repeated requests do not leak memory.
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    # ── return the final HTML ───────────────────────────────────────────────
    return table_html + html_heat
|
279 |
+
|
280 |
+
|
281 |
+
# ───── Flask app ───────────────────────────────────────────────
|
282 |
+
app = Flask(__name__)
|
283 |
+
|
284 |
+
# Page template (original UI plus the drug-index input box). Kept at module
# level so the view function below stays readable.
_INDEX_TEMPLATE = """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
<a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#10b981,#059669);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
🌐 Project Page
</a>

<a href="https://arxiv.org/abs/2406.01651" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#ef4444,#dc2626);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
📄 ArXiv: 2406.01651
</a>

<a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
style="display:inline-block;margin:8px 18px;padding:10px 20px;
background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
font-weight:600;border-radius:8px;font-size:0.9rem;
font-family:Inter,sans-serif;text-decoration:none;
box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
💻 GitHub Repo
</a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
<h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
<ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
<li><strong>Convert protein structure into a structure-aware sequence:</strong>
Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
sequence will be generated using
<a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
based on 3D structures from
<a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

<li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
you must first visit the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

<li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
You can enter a SELFIES string directly, or paste a SMILES string.
SMILES will be automatically converted to SELFIES using
<a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
If conversion fails, a red error message will be displayed.</li>

<li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
to highlight the Top-20 interacting protein residues.</li>

<li>After inference, you can use the
“Download PDF” link to export a high-resolution vector version.</li>
</ul>
</div>

<div class="card">
<form method="POST" enctype="multipart/form-data" class="grid">

<div><label>Protein Structure (.pdb / .cif)</label>
<input type="file" name="structure_file">
<input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

<div><label>Protein Sequence</label>
<textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

<div><label>Drug Sequence (SELFIES/SMILES)</label>
<textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>

<label>Drug atom/substructure index (1-based) – show Top-20 related protein residue</label>
<input type="number" name="drug_idx" min="1" style="width:120px">

<div class="grid grid-2">
<button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
<button class="btn btn-primary" type="submit" name="Inference">Inference</button>
</div>
<button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
</form>

{% if structure_seq %}
<div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
{% endif %}
{% if result_html %}
<div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
{% endif %}
</div></body></html>
"""

_SMILES_ERROR_HTML = (
    "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
    "Please check the input string.</strong></p>"
)
_STRUCTURE_ERROR_HTML = (
    "<p style='color:red'><strong>Could not read a protein sequence from the "
    "uploaded structure file.</strong></p>"
)


def _drug_input_to_selfies(raw):
    """Return ``(selfies_or_None, error_html_or_None)`` for a drug text field.

    Input starting with ``[`` is heuristically treated as SELFIES and passed
    through; anything else is treated as SMILES and converted.
    """
    text = raw.strip()
    if not text or text.startswith("["):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return None, _SMILES_ERROR_HTML


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI and handle its three form actions.

    * ``clear`` – reset every field.
    * ``confirm_structure`` – derive a sequence from the uploaded structure
      (Foldseek first, Biopython as fallback) and normalise the drug input.
    * ``Inference`` – encode the protein/drug pair, run the fusion model and
      render the attention heat-map.

    Returns
    -------
    str
        The rendered HTML page.
    """
    protein_seq = drug_seq = structure_seq = ""
    result_html = None
    tmp_structure_path = ""

    if request.method == "POST":
        drug_idx_raw = request.form.get("drug_idx", "").strip()
        # The form is 1-based; visualize_attention expects a 0-based index.
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        # ── locate the structure file ───────────────────────────────────────
        # The hidden ``tmp_structure_path`` field round-trips through the
        # browser and is therefore untrusted: keep only its file name and
        # re-anchor it in the temp directory so a client cannot point the
        # parsers at arbitrary server files.
        upload = request.files.get("structure_file")
        has_upload = bool(upload and upload.filename)
        if has_upload:
            candidate = upload.filename
        else:
            candidate = os.path.basename(request.form.get("tmp_structure_path", ""))
        safe_name = secure_filename(candidate)
        if safe_name:
            tmp_structure_path = os.path.join(tempfile.gettempdir(), safe_name)
            if has_upload:
                upload.save(tmp_structure_path)

        if "clear" in request.form:
            protein_seq = drug_seq = structure_seq = ""
            tmp_structure_path = ""

        elif "confirm_structure" in request.form and tmp_structure_path:
            try:
                parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None, plddt_mask=False)
                chain = list(parsed.keys())[0]
                _, _, structure_seq = parsed[chain]
            except Exception:
                # Deliberate best-effort: if Foldseek is missing or fails,
                # fall back to a plain amino-acid sequence via Biopython.
                try:
                    structure_seq = simple_seq_from_structure(tmp_structure_path)
                except Exception:
                    # Report unreadable files to the user instead of a 500.
                    structure_seq = ""
                    result_html = _STRUCTURE_ERROR_HTML
            protein_seq = structure_seq

            # An empty drug field is not an error at this stage; only
            # non-empty input is validated/converted.
            drug_input = request.form.get("drug_sequence", "")
            converted, drug_error = _drug_input_to_selfies(drug_input)
            if drug_error:
                drug_seq = ""
                result_html = (result_html or "") + drug_error
            else:
                drug_seq = converted

        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            drug_input = request.form.get("drug_sequence", "")
            # Convert SMILES here as well, so pasting SMILES and pressing
            # "Inference" directly behaves as the guidelines promise.
            converted, drug_error = _drug_input_to_selfies(drug_input)
            if drug_error:
                drug_seq = drug_input  # keep the text so the user can fix it
                result_html = drug_error
            else:
                drug_seq = converted

            if protein_seq and drug_seq and drug_error is None:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                # NOTE(review): 446 and 768 are passed straight to FusionDTI;
                # presumably the protein/drug embedding sizes — confirm.
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                # If no checkpoint exists the freshly initialised model is
                # used as-is.
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                result_html = visualize_attention(model, feats, drug_idx)

    return render_template_string(
        _INDEX_TEMPLATE,
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
|
469 |
+
|
470 |
+
# ───── run ─────────────────────────────────────────────────────
|
471 |
+
if __name__ == "__main__":
    # The server binds to every interface on port 7860, so the Werkzeug
    # debugger (which allows arbitrary code execution from the browser) must
    # not be on by default. Opt in locally with FLASK_DEBUG=1.
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1",
            host="0.0.0.0", port=7860)
|
.ipynb_checkpoints/requirements-checkpoint.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Flask
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
IPython
|
5 |
+
selfies
|
6 |
+
rdkit
|
7 |
+
biopython
|
8 |
+
matplotlib
|
9 |
+
scikit-learn
|
10 |
+
numpy
|
11 |
+
pandas
|
app.py
CHANGED
@@ -1,209 +1,66 @@
|
|
1 |
-
import os
|
2 |
-
import
|
3 |
-
import
|
4 |
-
import torch
|
5 |
from torch.utils.data import DataLoader
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
from utils.drug_tokenizer import DrugTokenizer
|
|
|
8 |
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
9 |
-
from
|
10 |
-
import tempfile
|
11 |
-
from flask import Flask, request, render_template_string
|
12 |
|
13 |
-
|
14 |
-
|
|
|
15 |
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
def parse_config():
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
|
32 |
-
parser.add_argument("--save_path_prefix", type=str, default="save_model_ckp/", help="save the result in which directory")
|
33 |
-
parser.add_argument("--save_name", default="fine_tune", type=str, help="the name of the saved file")
|
34 |
-
parser.add_argument("--dataset", type=str, default="Human", help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
|
35 |
-
return parser.parse_args()
|
36 |
|
37 |
args = parse_config()
|
38 |
-
|
39 |
|
|
|
40 |
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
|
41 |
-
|
42 |
-
|
43 |
-
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
|
44 |
-
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
|
45 |
|
46 |
-
|
|
|
47 |
|
48 |
-
|
49 |
-
with torch.no_grad():
|
50 |
-
for step, batch in enumerate(dataloader):
|
51 |
-
prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask, label = batch
|
52 |
-
prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask = \
|
53 |
-
prot_input_ids.to(device), prot_attention_mask.to(device), drug_input_ids.to(device), drug_attention_mask.to(device)
|
54 |
-
|
55 |
-
prot_embed, drug_embed = model.encoding(prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask)
|
56 |
-
prot_embed, drug_embed = prot_embed.cpu(), drug_embed.cpu()
|
57 |
-
prot_input_ids, drug_input_ids = prot_input_ids.cpu(), drug_input_ids.cpu()
|
58 |
-
prot_attention_mask, drug_attention_mask = prot_attention_mask.cpu(), drug_attention_mask.cpu()
|
59 |
-
label = label.cpu()
|
60 |
-
|
61 |
-
return [(prot_embed, drug_embed, prot_input_ids, drug_input_ids, prot_attention_mask, drug_attention_mask, label)]
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
with torch.no_grad():
|
66 |
-
for batch in case_features:
|
67 |
-
prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, label = batch
|
68 |
-
prot, drug = prot.to(device), drug.to(device)
|
69 |
-
prot_mask, drug_mask = prot_mask.to(device), drug_mask.to(device)
|
70 |
-
|
71 |
-
output, attention_weights = model(prot, drug, prot_mask, drug_mask)
|
72 |
-
prot_tokens = [prot_tokenizer.decode([pid.item()], skip_special_tokens=True) for pid in prot_ids.squeeze()]
|
73 |
-
drug_tokens = [drug_tokenizer.decode([did.item()], skip_special_tokens=True) for did in drug_ids.squeeze()]
|
74 |
-
tokens = prot_tokens + drug_tokens
|
75 |
-
|
76 |
-
attention_weights = attention_weights.unsqueeze(1)
|
77 |
-
|
78 |
-
# Generate HTML content using head_view with html_action='return'
|
79 |
-
html_head_view = head_view(attention_weights, tokens, sentence_b_start=512, html_action='return')
|
80 |
-
|
81 |
-
# Parse the HTML and modify it to replace sentence labels
|
82 |
-
html_content = html_head_view.data
|
83 |
-
html_content = html_content.replace("Sentence A -> Sentence A", "Protein -> Protein")
|
84 |
-
html_content = html_content.replace("Sentence B -> Sentence B", "Drug -> Drug")
|
85 |
-
html_content = html_content.replace("Sentence A -> Sentence B", "Protein -> Drug")
|
86 |
-
html_content = html_content.replace("Sentence B -> Sentence A", "Drug -> Protein")
|
87 |
-
|
88 |
-
# Save the modified HTML content to a temporary file
|
89 |
-
with tempfile.NamedTemporaryFile(delete=False, suffix=".html") as f:
|
90 |
-
f.write(html_content.encode('utf-8'))
|
91 |
-
temp_file_path = f.name
|
92 |
-
|
93 |
-
return temp_file_path
|
94 |
-
|
95 |
-
@app.route('/', methods=['GET', 'POST'])
|
96 |
-
def index():
|
97 |
-
protein_sequence = ""
|
98 |
-
drug_sequence = ""
|
99 |
-
result = None
|
100 |
-
|
101 |
-
if request.method == 'POST':
|
102 |
-
if 'clear' in request.form:
|
103 |
-
protein_sequence = ""
|
104 |
-
drug_sequence = ""
|
105 |
-
else:
|
106 |
-
protein_sequence = request.form['protein_sequence']
|
107 |
-
drug_sequence = request.form['drug_sequence']
|
108 |
-
|
109 |
-
dataset = [(protein_sequence, drug_sequence, 1)]
|
110 |
-
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn_batch_encoding)
|
111 |
-
|
112 |
-
case_features = get_case_feature(encoding, dataloader, device)
|
113 |
-
model = FusionDTI(446, 768, args).to(device)
|
114 |
-
|
115 |
-
best_model_dir = f"{args.save_path_prefix}{args.dataset}_{args.fusion}"
|
116 |
-
checkpoint_path = os.path.join(best_model_dir, 'best_model.ckpt')
|
117 |
-
|
118 |
-
if os.path.exists(checkpoint_path):
|
119 |
-
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
|
120 |
-
|
121 |
-
html_file_path = visualize_attention(model, case_features, device, prot_tokenizer, drug_tokenizer)
|
122 |
-
|
123 |
-
with open(html_file_path, 'r') as f:
|
124 |
-
result = f.read()
|
125 |
-
|
126 |
-
return render_template_string('''
|
127 |
-
<html>
|
128 |
-
<head>
|
129 |
-
<title>Drug Target Interaction Visualization</title>
|
130 |
-
<style>
|
131 |
-
body { font-family: 'Times New Roman', Times, serif; margin: 40px; }
|
132 |
-
h2 { color: #333; }
|
133 |
-
.container { display: flex; }
|
134 |
-
.left { flex: 1; padding-right: 20px; }
|
135 |
-
.right { flex: 1; }
|
136 |
-
textarea {
|
137 |
-
width: 100%;
|
138 |
-
padding: 12px 20px;
|
139 |
-
margin: 8px 0;
|
140 |
-
display: inline-block;
|
141 |
-
border: 1px solid #ccc;
|
142 |
-
border-radius: 4px;
|
143 |
-
box-sizing: border-box;
|
144 |
-
font-size: 16px;
|
145 |
-
font-family: 'Times New Roman', Times, serif;
|
146 |
-
}
|
147 |
-
.button-container {
|
148 |
-
display: flex;
|
149 |
-
justify-content: space-between;
|
150 |
-
}
|
151 |
-
input[type="submit"], .button {
|
152 |
-
width: 48%;
|
153 |
-
color: white;
|
154 |
-
padding: 14px 20px;
|
155 |
-
margin: 8px 0;
|
156 |
-
border: none;
|
157 |
-
border-radius: 4px;
|
158 |
-
cursor: pointer;
|
159 |
-
font-size: 16px;
|
160 |
-
font-family: 'Times New Roman', Times, serif;
|
161 |
-
}
|
162 |
-
.submit {
|
163 |
-
background-color: #FFA500;
|
164 |
-
}
|
165 |
-
.submit:hover {
|
166 |
-
background-color: #FF8C00;
|
167 |
-
}
|
168 |
-
.clear {
|
169 |
-
background-color: #D3D3D3;
|
170 |
-
}
|
171 |
-
.clear:hover {
|
172 |
-
background-color: #A9A9A9;
|
173 |
-
}
|
174 |
-
.result {
|
175 |
-
font-size: 18px;
|
176 |
-
}
|
177 |
-
</style>
|
178 |
-
</head>
|
179 |
-
<body>
|
180 |
-
<h2 style="text-align: center;">Drug Target Interaction Visualization</h2>
|
181 |
-
<div class="container">
|
182 |
-
<div class="left">
|
183 |
-
<form method="post">
|
184 |
-
<label for="protein_sequence">Protein Sequence:</label>
|
185 |
-
<textarea id="protein_sequence" name="protein_sequence" rows="4" placeholder="Enter protein sequence here..." required>{{ protein_sequence }}</textarea><br>
|
186 |
-
<label for="drug_sequence">Drug Sequence:</label>
|
187 |
-
<textarea id="drug_sequence" name="drug_sequence" rows="4" placeholder="Enter drug sequence here..." required>{{ drug_sequence }}</textarea><br>
|
188 |
-
<div class="button-container">
|
189 |
-
<input type="submit" name="submit" class="button submit" value="Submit">
|
190 |
-
<input type="submit" name="clear" class="button clear" value="Clear">
|
191 |
-
</div>
|
192 |
-
</form>
|
193 |
-
</div>
|
194 |
-
<div class="right" style="display: flex; justify-content: center; align-items: center;">
|
195 |
-
{% if result %}
|
196 |
-
<div class="result">
|
197 |
-
{{ result|safe }}
|
198 |
-
</div>
|
199 |
-
{% endif %}
|
200 |
-
</div>
|
201 |
-
</div>
|
202 |
-
</body>
|
203 |
-
</html>
|
204 |
-
''', protein_sequence=protein_sequence, drug_sequence=drug_sequence, result=result)
|
205 |
-
|
206 |
-
def collate_fn_batch_encoding(batch):
|
207 |
query1, query2, scores = zip(*batch)
|
208 |
|
209 |
query_encodings1 = prot_tokenizer.batch_encode_plus(
|
@@ -228,6 +85,388 @@ def collate_fn_batch_encoding(batch):
|
|
228 |
attention_mask2 = query_encodings2["attention_mask"].bool()
|
229 |
|
230 |
return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
-
|
|
|
233 |
app.run(debug=True, host="0.0.0.0", port=7860)
|
|
|
1 |
+
import os, sys, argparse, tempfile, shutil, base64, io
|
2 |
+
from flask import Flask, request, render_template_string
|
3 |
+
from werkzeug.utils import secure_filename
|
|
|
4 |
from torch.utils.data import DataLoader
|
5 |
+
import selfies
|
6 |
+
from rdkit import Chem
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import matplotlib
|
10 |
+
matplotlib.use("Agg")
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
from matplotlib import cm
|
13 |
+
from typing import Optional
|
14 |
+
|
15 |
from utils.drug_tokenizer import DrugTokenizer
|
16 |
+
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
|
17 |
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
|
18 |
+
from utils.foldseek_util import get_struc_seq
|
|
|
|
|
19 |
|
20 |
+
# ───── Biopython fallback ───────────────────────────────────────
|
21 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
22 |
+
from Bio.Data import IUPACData
|
23 |
|
24 |
+
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
|
25 |
+
three2one.update({"SEC": "C", "PYL": "K"})
|
26 |
+
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter amino-acid sequence of the first chain in *path*.

    This is the Biopython fallback used when Foldseek cannot process the file.

    Parameters
    ----------
    path : str
        Path to a structure file. Files ending in ``.cif``/``.mmcif`` (any
        letter case) are read with ``MMCIFParser``; everything else is read
        with ``PDBParser``.

    Returns
    -------
    str
        One letter per residue, looked up in the module-level ``three2one``
        table. Residues on ATOM records with an unknown name become ``"X"``.

    Raises
    ------
    ValueError
        If the parsed structure contains no chain at all.
    """
    is_mmcif = path.lower().endswith((".cif", ".mmcif"))
    parser = MMCIFParser(QUIET=True) if is_mmcif else PDBParser(QUIET=True)

    chain = next(parser.get_structure("P", path).get_chains(), None)
    if chain is None:
        raise ValueError(f"No chain found in structure file: {path}")

    letters = []
    for res in chain:
        name = res.get_resname().upper()
        # res.id[0] is Biopython's hetero flag: blank for ATOM records, "W" for
        # water and "H_xxx" for other HETATM groups.  Waters and ligands are not
        # part of the protein sequence, so skip hetero residues unless they are
        # a known amino acid (e.g. SEC/PYL registered in three2one).
        if res.id[0].strip() and name not in three2one:
            continue
        letters.append(three2one.get(name, "X"))
    return "".join(letters)
|
30 |
+
|
31 |
+
# ───── global paths / args ──────────────────────────────────────
|
32 |
+
FOLDSEEK_BIN = shutil.which("foldseek")
|
33 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
34 |
+
sys.path.append("..")
|
35 |
|
36 |
def parse_config():
    """Parse the command-line options that configure the demo app.

    Returns
    -------
    argparse.Namespace
        Namespace with the encoder paths, aggregation mode, fusion type,
        device, checkpoint prefix and dataset name. Every option has a
        default, so the app can be started without any arguments.
    """
    p = argparse.ArgumentParser()
    # NOTE(review): presumably here to absorb the "-f <connection file>" argument
    # that Jupyter/IPython kernels pass on the command line - confirm.
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    # The help text must be passed inside the add_argument() call; it was
    # previously written after the closing parenthesis, which is a SyntaxError.
    p.add_argument("--dataset", default="BindingDB",
                   help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
    return p.parse_args()
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
args = parse_config()
|
51 |
+
DEVICE = args.device
|
52 |
|
53 |
+
# ───── tokenisers & encoders ────────────────────────────────────
|
54 |
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
|
55 |
+
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
|
|
|
|
|
|
|
56 |
|
57 |
+
drug_tokenizer = DrugTokenizer() # SELFIES
|
58 |
+
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
|
59 |
|
60 |
+
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
+
# ─── collate fn ────────────────────────────────────────────────
|
63 |
+
def collate_fn(batch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
query1, query2, scores = zip(*batch)
|
65 |
|
66 |
query_encodings1 = prot_tokenizer.batch_encode_plus(
|
|
|
85 |
attention_mask2 = query_encodings2["attention_mask"].bool()
|
86 |
|
87 |
return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
|
88 |
+
# def collate_fn_batch_encoding(batch):
|
89 |
+
|
90 |
+
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, or return ``None`` on failure.

    The input is first parsed with RDKit as a validity check; the SELFIES
    string itself is produced by ``selfies.encoder`` from the original SMILES
    text. Any exception raised along the way is swallowed and reported as
    ``None`` so the caller can show a friendly error instead of crashing.
    """
    try:
        parsed_ok = Chem.MolFromSmiles(smiles) is not None
        return selfies.encoder(smiles) if parsed_ok else None
    except Exception:
        return None
|
99 |
+
|
100 |
+
|
101 |
+
# ───── single-case embedding ───────────────────────────────────
|
102 |
+
def get_case_feature(model, loader):
    """Encode the first batch yielded by *loader* and return it as CPU tensors.

    The model is switched to eval mode and run without gradient tracking.
    The result is a one-element list holding the tuple
    ``(prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask, None)``;
    the trailing ``None`` takes the place of the label. Only the first batch
    is used; if the loader yields nothing the function falls through and
    returns ``None``.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            prot_ids, prot_mask, drug_ids, drug_mask, _ = batch
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)
            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)
            case = (
                prot_emb.cpu(), drug_emb.cpu(),
                prot_ids.cpu(), drug_ids.cpu(),
                prot_mask.cpu(), drug_mask.cpu(),
                None,
            )
            return [case]
|
112 |
+
|
113 |
+
# ───── helper: filter out special tokens ───────────────────────
|
114 |
+
def clean_tokens(ids, tokenizer):
    """Map token ids to token strings and drop the tokenizer's special tokens.

    *ids* must expose ``.tolist()``; *tokenizer* must provide
    ``convert_ids_to_tokens`` and ``all_special_tokens``.
    """
    decoded = tokenizer.convert_ids_to_tokens(ids.tolist())
    kept = []
    for token in decoded:
        if token in tokenizer.all_special_tokens:
            continue
        kept.append(token)
    return kept
|
117 |
+
|
118 |
+
# ───── visualisation ───────────────────────────────────────────
|
119 |
+
|
120 |
+
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """Render a Protein → Drug cross-attention heat-map as embeddable HTML.

    Optionally also builds a Top-20 protein-residue table for one drug token.

    The token index shown on the x-axis (and accepted via *drug_idx*) is the
    position of that token in the original drug sequence, after tokenisation
    (special tokens removed) but before any pruning or column capping:
    1-based in the labels, 0-based for the function argument.

    Parameters
    ----------
    model :
        Fusion model; called as ``model(p_emb, d_emb, p_mask, d_mask)`` and
        expected to return ``(output, attention)``.
    feats : list
        One-element list as produced by ``get_case_feature``.
    drug_idx : int, optional
        0-based drug-token index for the Top-20 table. If that token was
        pruned from the heat-map, the value is used as a column position in
        the pruned map instead; the table header always reports the real
        original index of the token shown.

    Returns
    -------
    str
        HTML: optional table, then a base64-embedded PNG preview linking to
        itself for enlargement, plus a link to download the figure as PDF.
        If no protein or drug tokens are left to plot, a short red error
        paragraph is returned instead.
    """
    model.eval()

    # ── unpack single-case tensors and run the forward pass ─────────────────
    p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
    with torch.no_grad():
        _, att_pd = model(p_emb.to(DEVICE), d_emb.to(DEVICE),
                          p_mask.to(DEVICE), d_mask.to(DEVICE))
    attn = att_pd.squeeze(0).cpu()  # treated below as (n_protein, n_drug)

    # ── decode tokens (special symbols removed) ─────────────────────────────
    # NOTE(review): labels are built after dropping special tokens, while the
    # attention rows/columns are taken starting from index 0. This assumes the
    # attention map is already aligned with the cleaned token list - confirm
    # against the FusionDTI model.
    p_tokens = clean_tokens(p_ids[0], prot_tokenizer)
    d_tokens = clean_tokens(d_ids[0], drug_tokenizer)

    # ── trim tokens and attention to their common size ──────────────────────
    n_p = min(len(p_tokens), attn.size(0))
    n_d = min(len(d_tokens), attn.size(1))
    if n_p == 0 or n_d == 0:
        # attn.max() on an empty tensor raises; report the problem instead.
        return ("<p style='color:red'><strong>No protein or drug tokens to "
                "visualise. Please check the input sequences.</strong></p>")
    p_tokens, d_tokens = p_tokens[:n_p], d_tokens[:n_d]
    p_indices = list(range(1, n_p + 1))  # 1-based original positions
    d_indices = list(range(1, n_d + 1))
    attn = attn[:n_p, :n_d]

    # ── adaptive sparsity pruning: keep rows/cols above 5 % of the max ──────
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:  # too little left to draw: keep every row
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    rows, cols = row_keep.tolist(), col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(rows, p_tokens) if keep]
    p_indices = [idx for keep, idx in zip(rows, p_indices) if keep]
    d_tokens = [tok for keep, tok in zip(cols, d_tokens) if keep]
    d_indices = [idx for keep, idx in zip(cols, d_indices) if keep]

    # ── cap column count at 150 for readability ─────────────────────────────
    if attn.size(1) > 150:
        topc = torch.topk(attn.sum(0), k=150).indices.tolist()
        attn = attn[:, topc]
        d_tokens = [d_tokens[i] for i in topc]
        d_indices = [d_indices[i] for i in topc]

    # ── draw heat-map ───────────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))

    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # top/bottom/labeltop/labelbottom only affect the x-axis, so the
        # y-axis call needs nothing but the padding.
        ax.tick_params(axis="y", pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # Render once per format: raster PNG for the inline preview and a
        # vector PDF for the high-resolution download.
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
    finally:
        # Always release the figure so a failed render cannot leak it.
        plt.close(fig)

    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

    # ── build the Top-20 table (only when a drug token was requested) ───────
    table_html = ""
    if drug_idx is not None:
        # map original 0-based drug_idx → current column position
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            col_pos = drug_idx
        else:
            col_pos = None

        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

            rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
            res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
            pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)

            drug_tok_text = d_tokens[col_pos]
            orig_idx = d_indices[col_pos]

            table_html = (
                f"<h4 style='margin-bottom:6px'>"
                f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                f"→ Top-20 Protein residues</h4>"
                "<table class='tg' style='margin-bottom:8px'>"
                f"<tr><th>Rank</th>{rank_hdr}</tr>"
                f"<tr><td>Residue</td>{res_row}</tr>"
                f"<tr><td>Position</td>{pos_row}</tr>"
                "</table>")

    # ── zoomable preview + downloadable PDF ─────────────────────────────────
    html_heat = (
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
        f"title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
        f"<div style='margin-top:6px'>"
        f"<a href='data:application/pdf;base64,{pdf_b64}' "
        f"download='attention_heatmap.pdf'>Download PDF</a></div>"
    )

    return table_html + html_heat
|
279 |
+
|
280 |
+
|
281 |
+
# ───── Flask app ───────────────────────────────────────────────
|
282 |
+
app = Flask(__name__)
|
283 |
+
|
284 |
+
# Prefix given to every structure file this app writes to the temp directory;
# it is also how a path posted back by the browser is recognised as ours.
_UPLOAD_PREFIX = "fusiondti_"


def _save_structure_upload(upload) -> str:
    """Save an uploaded structure file under a unique temp name; return its path."""
    # Keep only the (sanitised) extension: downstream parsers pick their reader
    # from it, and a unique name avoids clashes between concurrent users.
    ext = os.path.splitext(secure_filename(upload.filename))[1].lower()
    fd, path = tempfile.mkstemp(prefix=_UPLOAD_PREFIX, suffix=ext)
    os.close(fd)
    upload.save(path)
    return path


def _is_own_upload(path: str) -> bool:
    """Return True only if *path* is an existing file this app saved in the temp dir."""
    if not path:
        return False
    tmp_root = os.path.realpath(tempfile.gettempdir())
    real = os.path.realpath(path)
    try:
        inside = os.path.commonpath([tmp_root, real]) == tmp_root
    except ValueError:  # e.g. paths on different drives
        return False
    return (inside
            and os.path.basename(real).startswith(_UPLOAD_PREFIX)
            and os.path.isfile(real))


def _normalise_drug_input(raw: str):
    """Return ``(drug_text, error_html)`` for the drug field of the form.

    Input starting with ``[`` is treated as SELFIES and passed through
    (heuristic: a SMILES string that starts with a bracket atom is not
    detected). Anything else is treated as SMILES and converted to SELFIES.
    On conversion failure the user's text is returned unchanged together with
    an error paragraph. Empty input is not an error.
    """
    text = raw.strip()
    if not text or text.startswith("["):
        return text, None
    converted = smiles_to_selfies(text)
    if converted:
        return converted, None
    return text, ("<p style='color:red'><strong>Failed to convert SMILES to SELFIES. "
                  "Please check the input string.</strong></p>")


@app.route("/", methods=["GET", "POST"])
def index():
    """Serve the single-page UI and handle its three form actions.

    * ``clear`` – reset every field and delete the stored structure upload.
    * ``confirm_structure`` – turn the uploaded structure into a sequence
      (Foldseek first, Biopython as fallback) and normalise the drug input.
    * ``Inference`` – normalise the drug input, run the encoders and the
      fusion model, and render the attention heat-map.
    """
    protein_seq = drug_seq = structure_seq = ""
    result_html = None
    tmp_structure_path = ""
    drug_idx = None

    if request.method == "POST":
        form = request.form

        # 1-based index from the form → 0-based for visualize_attention.
        drug_idx_raw = form.get("drug_idx", "").strip()
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdecimal() else None

        struct = request.files.get("structure_file")
        if struct and struct.filename:
            tmp_structure_path = _save_structure_upload(struct)
        else:
            # The hidden field comes back from the browser and is therefore
            # untrusted: accept it only if it points at one of our uploads.
            posted_path = form.get("tmp_structure_path", "")
            tmp_structure_path = posted_path if _is_own_upload(posted_path) else ""

        if "clear" in form:
            if tmp_structure_path:
                try:
                    os.remove(tmp_structure_path)
                except OSError:
                    pass  # best effort: a stale temp file is harmless
            protein_seq = drug_seq = structure_seq = ""
            tmp_structure_path = ""

        elif "confirm_structure" in form and tmp_structure_path:
            messages = []
            try:
                parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None, plddt_mask=False)
                chain = next(iter(parsed))
                _, _, structure_seq = parsed[chain]
            except Exception:
                # Foldseek is optional (missing binary, unsupported file, ...):
                # fall back to a plain amino-acid sequence from Biopython.
                try:
                    structure_seq = simple_seq_from_structure(tmp_structure_path)
                except Exception:
                    structure_seq = ""
                    messages.append(
                        "<p style='color:red'><strong>Failed to read the structure file. "
                        "Please upload a valid .pdb or .cif file.</strong></p>")
            protein_seq = structure_seq

            drug_seq, drug_error = _normalise_drug_input(form.get("drug_sequence", ""))
            if drug_error:
                messages.append(drug_error)
            if messages:
                result_html = "".join(messages)

        elif "Inference" in form:
            protein_seq = form.get("protein_sequence", "")
            # Convert SMILES here as well, so pasting SMILES and pressing
            # "Inference" directly behaves as the usage guide promises.
            drug_seq, drug_error = _normalise_drug_input(form.get("drug_sequence", ""))
            if drug_error:
                result_html = drug_error
            elif protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
                                    collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                warning = ""
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                else:
                    # Without the checkpoint the model keeps its random
                    # initial weights; tell the user instead of staying silent.
                    warning = ("<p style='color:red'><strong>Model checkpoint not found; "
                               "the attention map below comes from untrained weights."
                               "</strong></p>")
                result_html = warning + visualize_attention(model, feats, drug_idx)

        else:
            # No action could be carried out (e.g. "Confirm Structure" without
            # a file): keep what the user typed instead of wiping the form.
            protein_seq = form.get("protein_sequence", "")
            drug_seq = form.get("drug_sequence", "")

    return render_template_string(
        # ───────────── HTML (original UI + new input fields) ─────────────
        """
<!doctype html>
<html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">

<style>
:root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
*{box-sizing:border-box;margin:0;padding:0}
body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
.card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
      border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
label{font-weight:500;margin-bottom:6px;display:block}
textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
      border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
textarea{min-height:90px}
.btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
      font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
.btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
.btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
.grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
.vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}

/* ── tidy table for Top-20 list ─────────────────────────────── */
table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
table.tg th{background:var(--bg);font-weight:600}
</style>
</head>
<body>
<h1> Token-level Visualiser for Drug-Target Interaction</h1>

<!-- ───────────── Project Links (larger + spaced) ───────────── -->
<div style="margin-top:24px; text-align:center;">
  <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#10b981,#059669);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    🌐 Project Page
  </a>

  <a href="https://arxiv.org/abs/2406.01651" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#ef4444,#dc2626);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    📄 ArXiv: 2406.01651
  </a>

  <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
     style="display:inline-block;margin:8px 18px;padding:10px 20px;
            background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
            font-weight:600;border-radius:8px;font-size:0.9rem;
            font-family:Inter,sans-serif;text-decoration:none;
            box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
     onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
    💻 GitHub Repo
  </a>
</div>

<!-- ───────────── Guidelines for Use ───────────── -->
<div class="card" style="margin-bottom:24px">
  <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
  <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
    <li><strong>Convert protein structure into a structure-aware sequence:</strong>
        Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using
        <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
        based on 3D structures from
        <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

    <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
        you must first visit the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
        or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
        to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

    <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
        You can enter a SELFIES string directly, or paste a SMILES string.
        SMILES will be automatically converted to SELFIES using
        <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
        If conversion fails, a red error message will be displayed.</li>

    <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
        to highlight the Top-20 interacting protein residues.</li>

    <li>After inference, you can use the
        “Download PDF” link to export a high-resolution vector version.</li>
  </ul>
</div>

<div class="card">
  <form method="POST" enctype="multipart/form-data" class="grid">

    <div><label>Protein Structure (.pdb / .cif)</label>
      <input type="file" name="structure_file">
      <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>

    <div><label>Protein Sequence</label>
      <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>

    <div><label>Drug Sequence (SELFIES/SMILES)</label>
      <textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>

    <label>Drug atom/substructure index (1-based) – show Top-20 related protein residue</label>
    <input type="number" name="drug_idx" min="1" style="width:120px">

    <div class="grid grid-2">
      <button class="btn btn-primary" type="submit" name="confirm_structure">Confirm Structure</button>
      <button class="btn btn-primary" type="submit" name="Inference">Inference</button>
    </div>
    <button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
  </form>

  {% if structure_seq %}
  <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
  {% endif %}
  {% if result_html %}
  <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
  {% endif %}
</div></body></html>
""",
        protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
        result_html=result_html, tmp_structure_path=tmp_structure_path)
|
469 |
|
470 |
+
# ───── run ─────────────────────────────────────────────────────
|
471 |
+
if __name__ == "__main__":
    # The app listens on all interfaces, so the Werkzeug debugger (which lets
    # anyone who can reach the page run Python code on the server) must stay
    # off unless it is requested explicitly with FLASK_DEBUG=1.
    app.run(debug=os.environ.get("FLASK_DEBUG", "0") == "1",
            host="0.0.0.0", port=7860)
|
requirements.txt
CHANGED
@@ -1,5 +1,11 @@
|
|
1 |
Flask
|
2 |
torch
|
3 |
transformers
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
Flask
|
2 |
torch
|
3 |
transformers
|
4 |
+
IPython
|
5 |
+
selfies
|
6 |
+
rdkit
|
7 |
+
biopython
|
8 |
+
matplotlib
|
9 |
+
scikit-learn
|
10 |
+
numpy
|
11 |
+
pandas
|
utils/.ipynb_checkpoints/drug_tokenizer-checkpoint.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
class DrugTokenizer:
    """Small SELFIES tokenizer driven by a JSON vocabulary.

    A SELFIES string is split into its bracketed symbols (``[C][=O]`` →
    ``C``, ``=O``); each symbol is looked up in the vocabulary and the result
    is wrapped in the CLS / SEP tokens named by the special-tokens map.
    """

    def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
        """Load the vocabulary and cache the ids of the four special tokens."""
        loaded = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
        self.vocab, self.special_tokens = loaded
        for role in ("cls", "sep", "unk", "pad"):
            token_text = self.special_tokens[f"{role}_token"]
            setattr(self, f"{role}_token_id", self.vocab[token_text])
        self.id_to_token = {index: text for text, index in self.vocab.items()}

        self.all_special_tokens = list(self.special_tokens.values())

    def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
        """Read both JSON files; return ``(vocab, {role: token_text})``."""
        with open(vocab_path, 'r', encoding='utf-8') as handle:
            vocab = json.load(handle)
        with open(special_tokens_path, 'r', encoding='utf-8') as handle:
            raw_map = json.load(handle)

        # Each entry of the special-tokens map is a dict; only its 'content'
        # field (the token text) is kept.
        special_tokens = {}
        for role, entry in raw_map.items():
            special_tokens[role] = entry['content']
        return vocab, special_tokens

    def encode(self, sequence):
        """Encode one SELFIES string as CLS + symbol ids + SEP with an all-ones mask.

        Symbols missing from the vocabulary map to the UNK id.
        """
        symbols = re.findall(r'\[([^\[\]]+)\]', sequence)
        body = [self.vocab.get(symbol, self.unk_token_id) for symbol in symbols]
        input_ids = [self.cls_token_id, *body, self.sep_token_id]
        return {
            'input_ids': input_ids,
            'attention_mask': [1] * len(input_ids),
        }

    def batch_encode_plus(self, sequences, max_length, padding, truncation, add_special_tokens, return_tensors):
        """Encode several sequences into ``(batch, max_length)`` long tensors.

        Every sequence is cut or right-padded to exactly ``max_length``; padded
        positions get the PAD id and a mask value of 0. The ``padding``,
        ``truncation``, ``add_special_tokens`` and ``return_tensors``
        arguments are accepted but not consulted.
        """
        all_ids, all_masks = [], []

        for sequence in sequences:
            encoded = self.encode(sequence)
            ids = encoded['input_ids']
            mask = encoded['attention_mask']

            shortfall = max_length - len(ids)
            if shortfall < 0:
                ids, mask = ids[:max_length], mask[:max_length]
            elif shortfall > 0:
                ids = ids + [self.vocab[self.special_tokens['pad_token']]] * shortfall
                mask = mask + [0] * shortfall

            all_ids.append(ids)
            all_masks.append(mask)

        return {
            'input_ids': torch.tensor(all_ids, dtype=torch.long),
            'attention_mask': torch.tensor(all_masks, dtype=torch.long),
        }

    def decode(self, input_ids, skip_special_tokens=False):
        """Turn ids back into a bracketed string.

        With ``skip_special_tokens`` the CLS, SEP and PAD ids are dropped (the
        UNK id is kept). Ids absent from the vocabulary are rendered as the
        UNK token.
        """
        droppable = [self.cls_token_id, self.sep_token_id, self.pad_token_id]
        pieces = []
        for token_id in input_ids:
            if skip_special_tokens and token_id in droppable:
                continue
            text = self.id_to_token.get(token_id, self.special_tokens['unk_token'])
            pieces.append(f'[{text}]')
        return ''.join(pieces)

    # --- added ---
    def convert_ids_to_tokens(self, ids):
        """list[int] → list[str]; ids not in the vocabulary become the UNK token."""
        fallback = self.special_tokens['unk_token']
        return [self.id_to_token.get(token_id, fallback) for token_id in ids]
|
utils/.ipynb_checkpoints/metric_learning_models_att_maps-checkpoint.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
sys.path.append("../")
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
from torch.nn import Module
|
12 |
+
from tqdm import tqdm
|
13 |
+
from torch.nn.utils.weight_norm import weight_norm
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
|
16 |
+
LOGGER = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
class FusionDTI(nn.Module):
    """Fuse protein and drug token embeddings and score the pair.

    The fusion strategy is selected by ``args.fusion``:

    * ``"CAN"`` - project both inputs to 512 features and fuse with ``CAN_Layer``.
    * ``"BAN"`` - project both inputs to 512 features and fuse with ``BANLayer``.
    * ``"Nan"`` - no attention: mean over dim 1 of each input, concatenate,
      classify directly.
    """

    def __init__(self, prot_out_dim, disease_out_dim, args):
        """Build the projections and the fusion-specific sub-modules.

        Args:
            prot_out_dim: Feature size of the protein embeddings.
            disease_out_dim: Feature size of the drug embeddings (the name is
                historical; it feeds ``drug_reg``).
            args: Namespace providing ``fusion`` (and, for ``"CAN"``, whatever
                ``CAN_Layer`` reads from it).
        """
        super(FusionDTI, self).__init__()
        self.fusion = args.fusion
        self.drug_reg = nn.Linear(disease_out_dim, 512)
        self.prot_reg = nn.Linear(prot_out_dim, 512)

        if self.fusion == "CAN":
            self.can_layer = CAN_Layer(hidden_dim=512, num_heads=8, args=args)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=1024)
        elif self.fusion == "BAN":
            self.ban_layer = weight_norm(BANLayer(512, 512, 256, 2), name='h_mat', dim=None)
            self.mlp_classifier = MlPdecoder_CAN(input_dim=256)
        elif self.fusion == "Nan":
            # NOTE(review): 1214 is presumably prot_out_dim + disease_out_dim of
            # the encoders used in training - confirm before changing encoders.
            self.mlp_classifier_nan = MlPdecoder_CAN(input_dim=1214)

    def forward(self, prot_embed, drug_embed, prot_mask, drug_mask):
        """Score a batch of protein/drug pairs.

        Args:
            prot_embed: Protein embeddings; dim 1 is reduced with ``mean`` in
                ``"Nan"`` mode (presumably (batch, tokens, features)).
            drug_embed: Drug embeddings, same layout as ``prot_embed``.
            prot_mask: Protein mask, only used by the ``"CAN"`` fusion.
            drug_mask: Drug mask, only used by the ``"CAN"`` fusion.

        Returns:
            tuple: ``(score, att)``. ``att`` is the attention map returned by
            the fusion layer, or ``None`` in ``"Nan"`` mode where no attention
            is computed.

        Raises:
            NotImplementedError: If ``self.fusion`` is not a supported mode.
        """
        if self.fusion == "Nan":
            prot_embed = prot_embed.mean(1)
            drug_embed = drug_embed.mean(1)
            joint_embed = torch.cat([prot_embed, drug_embed], dim=1)
            score = self.mlp_classifier_nan(joint_embed)
            # No attention is computed in this mode; previously `att` was left
            # unassigned and the return statement raised UnboundLocalError.
            att = None
        elif self.fusion in ("CAN", "BAN"):
            prot_embed = self.prot_reg(prot_embed)
            drug_embed = self.drug_reg(drug_embed)

            if self.fusion == "CAN":
                joint_embed, att = self.can_layer(prot_embed, drug_embed, prot_mask, drug_mask)
            else:
                joint_embed, att = self.ban_layer(prot_embed, drug_embed)

            score = self.mlp_classifier(joint_embed)
        else:
            raise NotImplementedError(f"Unsupported fusion mode: {self.fusion!r}")

        return score, att
|
53 |
+
|
54 |
+
class Pre_encoded(nn.Module):
    """Thin wrapper that runs a protein encoder and a drug encoder side by side."""

    def __init__(
        self, prot_encoder, drug_encoder, args
    ):
        """Store the two encoders.

        Args:
            prot_encoder: Callable protein encoder; its output must expose
                ``.logits``.
            drug_encoder: Callable drug encoder; its output must expose
                ``.last_hidden_state``.
            args: Accepted for signature compatibility; not used here.
        """
        super().__init__()
        self.drug_encoder = drug_encoder
        self.prot_encoder = prot_encoder

    def encoding(self, prot_input_ids, prot_attention_mask, drug_input_ids, drug_attention_mask):
        """Encode one batch with both encoders.

        Returns:
            tuple: ``(prot_embed, drug_embed)`` - the protein encoder's
            ``logits`` and the drug encoder's ``last_hidden_state``.
        """
        # Protein side first, then drug side (same call order as before).
        prot_output = self.prot_encoder(
            input_ids=prot_input_ids,
            attention_mask=prot_attention_mask,
            return_dict=True,
        )
        prot_embed = prot_output.logits

        drug_output = self.drug_encoder(
            input_ids=drug_input_ids,
            attention_mask=drug_attention_mask,
            return_dict=True,
        )
        drug_embed = drug_output.last_hidden_state

        return prot_embed, drug_embed
|
83 |
+
|
84 |
+
|
85 |
+
class CAN_Layer(nn.Module):
    """Multi-head attention between grouped protein and drug tokens.

    Tokens of each modality are first averaged in consecutive groups of
    ``args.group_size``. Every modality then attends to itself and to the
    other modality; the two attended results are averaged and pooled into one
    vector per modality according to ``args.agg_mode``.
    """

    def __init__(self, hidden_dim, num_heads, args):
        """
        Args:
            hidden_dim: Feature size of both inputs and of every projection.
            num_heads: Number of attention heads; ``hidden_dim // num_heads``
                is used as the per-head size.
            args: Namespace providing ``agg_mode`` ("cls", "mean_all_tok" or
                "mean") and ``group_size``.
        """
        super(CAN_Layer, self).__init__()
        self.agg_mode = args.agg_mode
        self.group_size = args.group_size  # number of consecutive tokens averaged into one group
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads

        # Bias-free query/key/value projections for the protein side.
        self.query_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_p = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_p = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Bias-free query/key/value projections for the drug side.
        self.query_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key_d = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def alpha_logits(self, logits, mask_row, mask_col, inf=1e6):
        """Turn raw attention logits into masked attention weights.

        Args:
            logits: Tensor of shape (N, L1, L2, H).
            mask_row: Mask over the L1 (query) positions, viewable as (N, L1).
            mask_col: Mask over the L2 (key) positions, viewable as (N, L2).
            inf: Constant subtracted from logits of masked pairs.

        Returns:
            Tensor of shape (N, L1, L2, H), softmax-normalized over L2 and
            zeroed for masked query rows.

        NOTE(review): the masks are used as the condition of ``torch.where``,
        so they are presumably boolean - confirm against the caller.
        """
        N, L1, L2, H = logits.shape
        mask_row = mask_row.view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.view(N, L2, 1).repeat(1, 1, H)
        # Outer product of the two masks: true only where both query and key are valid.
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col)

        # Push masked pairs far down so softmax gives them ~0 weight.
        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        # Rows belonging to masked queries still get a (uniform) softmax; zero them out.
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1)
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def apply_heads(self, x, n_heads, n_ch):
        """Split the last dimension of ``x`` into (n_heads, n_ch)."""
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def group_embeddings(self, x, mask, group_size):
        """Average consecutive tokens in groups of ``group_size``.

        Args:
            x: Tensor of shape (N, L, D).
            mask: Mask viewable as (N, L).
            group_size: Tokens per group. The ``view`` calls require L to be
                an exact multiple of ``group_size``.

        Returns:
            tuple: ``(x_grouped, mask_grouped)`` with shapes
            (N, L // group_size, D) and (N, L // group_size). A group is
            marked valid when any of its tokens is valid; the mean is taken
            over all tokens of the group regardless of the mask.
        """
        N, L, D = x.shape
        groups = L // group_size
        x_grouped = x.view(N, groups, group_size, D).mean(dim=2)
        mask_grouped = mask.view(N, groups, group_size).any(dim=2)
        return x_grouped, mask_grouped

    def forward(self, protein, drug, mask_prot, mask_drug):
        """Fuse protein and drug embeddings.

        Args:
            protein: Protein embeddings, (N, Lp, hidden_dim).
            drug: Drug embeddings, (N, Ld, hidden_dim).
            mask_prot: Protein mask, viewable as (N, Lp).
            mask_drug: Drug mask, viewable as (N, Ld).

        Returns:
            tuple: ``(query_embed, att_pd)`` where ``query_embed`` is the
            concatenation of the pooled protein and drug vectors along dim 1
            and ``att_pd`` is the protein-to-drug attention averaged over
            heads, shape (N, protein groups, drug groups).

        Raises:
            NotImplementedError: For an unknown ``agg_mode``.
        """
        # Group embeddings before applying multi-head attention
        protein_grouped, mask_prot_grouped = self.group_embeddings(protein, mask_prot, self.group_size)
        drug_grouped, mask_drug_grouped = self.group_embeddings(drug, mask_drug, self.group_size)

        # Compute queries, keys, values for both protein and drug after grouping
        query_prot = self.apply_heads(self.query_p(protein_grouped), self.num_heads, self.head_size)
        key_prot = self.apply_heads(self.key_p(protein_grouped), self.num_heads, self.head_size)
        value_prot = self.apply_heads(self.value_p(protein_grouped), self.num_heads, self.head_size)

        query_drug = self.apply_heads(self.query_d(drug_grouped), self.num_heads, self.head_size)
        key_drug = self.apply_heads(self.key_d(drug_grouped), self.num_heads, self.head_size)
        value_drug = self.apply_heads(self.value_d(drug_grouped), self.num_heads, self.head_size)

        # Compute attention scores: dot product over the per-head channel axis,
        # giving (N, query_len, key_len, heads). No 1/sqrt(head_size) scaling is applied.
        logits_pp = torch.einsum('blhd, bkhd->blkh', query_prot, key_prot)
        logits_pd = torch.einsum('blhd, bkhd->blkh', query_prot, key_drug)
        logits_dp = torch.einsum('blhd, bkhd->blkh', query_drug, key_prot)
        logits_dd = torch.einsum('blhd, bkhd->blkh', query_drug, key_drug)

        alpha_pp = self.alpha_logits(logits_pp, mask_prot_grouped, mask_prot_grouped)
        alpha_pd = self.alpha_logits(logits_pd, mask_prot_grouped, mask_drug_grouped)
        alpha_dp = self.alpha_logits(logits_dp, mask_drug_grouped, mask_prot_grouped)
        alpha_dd = self.alpha_logits(logits_dd, mask_drug_grouped, mask_drug_grouped)

        # Each modality averages its self-attended and cross-attended values;
        # flatten(-2) merges (heads, head_size) back into hidden_dim.
        prot_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_pp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_pd, value_drug).flatten(-2)) / 2
        drug_embedding = (torch.einsum('blkh, bkhd->blhd', alpha_dp, value_prot).flatten(-2) +
                          torch.einsum('blkh, bkhd->blhd', alpha_dd, value_drug).flatten(-2)) / 2

        # Pool each modality to one vector per sample according to agg_mode
        if self.agg_mode == "cls":
            # Take the first group only.
            prot_embed = prot_embedding[:, 0]
            drug_embed = drug_embedding[:, 0]
        elif self.agg_mode == "mean_all_tok":
            # Unmasked mean over every group.
            prot_embed = prot_embedding.mean(1)
            drug_embed = drug_embedding.mean(1)
        elif self.agg_mode == "mean":
            # Masked mean: sum over valid groups divided by the number of valid groups.
            prot_embed = (prot_embedding * mask_prot_grouped.unsqueeze(-1)).sum(1) / mask_prot_grouped.sum(-1).unsqueeze(-1)
            drug_embed = (drug_embedding * mask_drug_grouped.unsqueeze(-1)).sum(1) / mask_drug_grouped.sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()

        query_embed = torch.cat([prot_embed, drug_embed], dim=1)

        # Protein-to-drug attention averaged over the head axis.
        att_pd = alpha_pd.mean(dim=-1)

        return query_embed, att_pd
|
182 |
+
|
183 |
+
class MlPdecoder_CAN(nn.Module):
    """Three hidden layers (Linear -> ReLU -> BatchNorm) followed by a sigmoid output unit.

    Layer widths are ``input_dim``, ``input_dim // 2`` and ``input_dim // 4``.
    """

    def __init__(self, input_dim):
        super().__init__()
        half_dim = input_dim // 2
        quarter_dim = input_dim // 4

        self.fc1 = nn.Linear(input_dim, input_dim)
        self.bn1 = nn.BatchNorm1d(input_dim)
        self.fc2 = nn.Linear(input_dim, half_dim)
        self.bn2 = nn.BatchNorm1d(half_dim)
        self.fc3 = nn.Linear(half_dim, quarter_dim)
        self.bn3 = nn.BatchNorm1d(quarter_dim)
        self.output = nn.Linear(quarter_dim, 1)

    def forward(self, x):
        """Return one sigmoid score per row of ``x``."""
        stages = (
            (self.fc1, self.bn1),
            (self.fc2, self.bn2),
            (self.fc3, self.bn3),
        )
        hidden = x
        for linear, norm in stages:
            # Batch norm is applied after the activation, as in the layer order above.
            hidden = norm(torch.relu(linear(hidden)))
        return torch.sigmoid(self.output(hidden))
|
200 |
+
|
201 |
+
class MLPdecoder_BAN(nn.Module):
    """Three Linear -> ReLU -> BatchNorm stages followed by a sigmoid output layer.

    Widths: ``in_dim -> hidden_dim -> hidden_dim -> out_dim -> binary``.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, binary=1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)

        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.bn3 = nn.BatchNorm1d(out_dim)

        self.fc4 = nn.Linear(out_dim, binary)

    def forward(self, x):
        """Return sigmoid-activated outputs of shape (batch, binary)."""
        hidden = x
        for linear, norm in zip((self.fc1, self.fc2, self.fc3), (self.bn1, self.bn2, self.bn3)):
            hidden = norm(F.relu(linear(hidden)))
        logits = self.fc4(hidden)
        return torch.sigmoid(logits)
|
219 |
+
|
220 |
+
class BANLayer(nn.Module):
    """Bilinear attention network layer.

    Modified from https://github.com/peizhenbai/DrugBAN/blob/main/ban.py

    Both inputs are projected to ``h_dim * k`` features, ``h_out`` bilinear
    attention maps are computed between them, and the attention-pooled joint
    features of all maps are summed and batch-normalized.
    """
    def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=0.2, k=3):
        """
        Args:
            v_dim: Feature size of the first input ``v``.
            q_dim: Feature size of the second input ``q``.
            h_dim: Size of the fused output vector.
            h_out: Number of attention maps (heads).
            act: Name of the ``torch.nn`` activation class passed to ``FCNet``.
            dropout: Dropout probability passed to ``FCNet``.
            k: Pooling factor; projections have ``h_dim * k`` features which
                are pooled back down to ``h_dim``.
        """
        super(BANLayer, self).__init__()

        # Threshold on h_out that selects between the two attention-map implementations.
        self.c = 32
        self.k = k
        self.v_dim = v_dim
        self.q_dim = q_dim
        self.h_dim = h_dim
        self.h_out = h_out

        self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout)
        self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout)
        if 1 < k:
            self.p_net = nn.AvgPool1d(self.k, stride=self.k)

        if h_out <= self.c:
            # Few heads: explicit bilinear weight tensor and bias, normally initialized.
            self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
            self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
        else:
            # Many heads: a weight-normalized linear layer over the pairwise products.
            self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)

        self.bn = nn.BatchNorm1d(h_dim)

    def attention_pooling(self, v, q, att_map):
        """Pool ``v`` and ``q`` under one attention map.

        Args:
            v: Projected first input, (b, v_num, h_dim * k).
            q: Projected second input, (b, q_num, h_dim * k).
            att_map: Attention map for one head, (b, v_num, q_num).

        Returns:
            Tensor of shape (b, h_dim) when ``k > 1`` (sum-pooled over
            consecutive windows of ``k`` features), else (b, h_dim * k).
        """
        fusion_logits = torch.einsum('bvk,bvq,bqk->bk', (v, att_map, q))
        if 1 < self.k:
            fusion_logits = fusion_logits.unsqueeze(1)  # b x 1 x d
            # Average pooling multiplied by k equals sum pooling.
            fusion_logits = self.p_net(fusion_logits).squeeze(1) * self.k  # sum-pooling
        return fusion_logits

    def forward(self, v, q, softmax=False):
        """Compute the fused representation and the attention maps.

        Args:
            v: First input, (b, v_num, v_dim).
            q: Second input, (b, q_num, q_dim).
            softmax: When true, each attention map is softmax-normalized over
                all ``v_num * q_num`` pairs.

        Returns:
            tuple: ``(logits, att_maps)`` - the batch-normalized fused vector
            and the attention maps of shape (b, h_out, v_num, q_num).
        """
        v_num = v.size(1)
        q_num = q.size(1)
        if self.h_out <= self.c:
            v_ = self.v_net(v)
            q_ = self.q_net(q)
            att_maps = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
        else:
            # NOTE(review): in this branch v_ and q_ become 4-D, while
            # attention_pooling() below uses a 3-index einsum on them; this
            # path looks incompatible with the pooling step - confirm before
            # using h_out > 32.
            v_ = self.v_net(v).transpose(1, 2).unsqueeze(3)
            q_ = self.q_net(q).transpose(1, 2).unsqueeze(2)
            d_ = torch.matmul(v_, q_)  # b x h_dim x v x q
            att_maps = self.h_net(d_.transpose(1, 2).transpose(2, 3))  # b x v x q x h_out
            att_maps = att_maps.transpose(2, 3).transpose(1, 2)  # b x h_out x v x q
        if softmax:
            p = nn.functional.softmax(att_maps.view(-1, self.h_out, v_num * q_num), 2)
            att_maps = p.view(-1, self.h_out, v_num, q_num)
        # Sum the pooled features of every attention head.
        logits = self.attention_pooling(v_, q_, att_maps[:, 0, :, :])
        for i in range(1, self.h_out):
            logits_i = self.attention_pooling(v_, q_, att_maps[:, i, :, :])
            logits += logits_i
        logits = self.bn(logits)
        return logits, att_maps
|
284 |
+
|
285 |
+
|
286 |
+
class FCNet(nn.Module):
    """Stack of weight-normalized linear layers with optional dropout and activation.

    Modified from https://github.com/jnhwkim/ban-vqa/blob/master/fc.py

    For every consecutive pair in ``dims`` the stack contains, in order:
    dropout (if ``dropout > 0``), a weight-normalized ``nn.Linear`` and the
    activation (unless ``act`` is the empty string). The activation is also
    applied after the last layer.
    """

    def __init__(self, dims, act='ReLU', dropout=0):
        super().__init__()

        # All hidden transitions first, then the final one (dims[-2] -> dims[-1]).
        transitions = [(dims[idx], dims[idx + 1]) for idx in range(len(dims) - 2)]
        transitions.append((dims[-2], dims[-1]))

        with_dropout = 0 < dropout
        with_activation = '' != act

        modules = []
        for fan_in, fan_out in transitions:
            if with_dropout:
                modules.append(nn.Dropout(dropout))
            modules.append(weight_norm(nn.Linear(fan_in, fan_out), dim=None))
            if with_activation:
                modules.append(getattr(nn, act)())

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.main(x)
|
313 |
+
|
314 |
+
|
315 |
+
class BatchFileDataset_Case(Dataset):
    """Dataset whose items are pre-saved batch files loaded with ``torch.load``."""

    def __init__(self, file_list):
        # Paths of the saved batch files, one per dataset item.
        self.file_list = file_list

    def __len__(self):
        """Number of batch files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the file at position ``idx`` and return its fields as a tuple.

        The tuple order is: prot, drug, prot_ids, drug_ids, prot_mask,
        drug_mask, y.
        """
        payload = torch.load(self.file_list[idx])
        prot, drug = payload['prot'], payload['drug']
        prot_ids, drug_ids = payload['prot_ids'], payload['drug_ids']
        prot_mask, drug_mask = payload['prot_mask'], payload['drug_mask']
        label = payload['y']
        return prot, drug, prot_ids, drug_ids, prot_mask, drug_mask, label
|
utils/__pycache__/foldseek_util.cpython-38.pyc
ADDED
Binary file (4.86 kB). View file
|
|
utils/__pycache__/metric_learning_models_att_maps.cpython-38.pyc
ADDED
Binary file (10.8 kB). View file
|
|
utils/drug_tokenizer.py
CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
|
|
5 |
from torch.nn import functional as F
|
6 |
|
7 |
class DrugTokenizer:
|
8 |
-
def __init__(self, vocab_path="
|
9 |
self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
|
10 |
self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
|
11 |
self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
|
@@ -13,6 +13,8 @@ class DrugTokenizer:
|
|
13 |
self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
|
14 |
self.id_to_token = {v: k for k, v in self.vocab.items()}
|
15 |
|
|
|
|
|
16 |
def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
|
17 |
with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
|
18 |
vocab = json.load(vocab_file)
|
@@ -64,3 +66,8 @@ class DrugTokenizer:
|
|
64 |
tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
|
65 |
sequence = ''.join([f'[{token}]' for token in tokens])
|
66 |
return sequence
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from torch.nn import functional as F
|
6 |
|
7 |
class DrugTokenizer:
|
8 |
+
def __init__(self, vocab_path="data/Tokenizer/vocab.json", special_tokens_path="data/Tokenizer/special_tokens_map.json"):
|
9 |
self.vocab, self.special_tokens = self.load_vocab_and_special_tokens(vocab_path, special_tokens_path)
|
10 |
self.cls_token_id = self.vocab[self.special_tokens['cls_token']]
|
11 |
self.sep_token_id = self.vocab[self.special_tokens['sep_token']]
|
|
|
13 |
self.pad_token_id = self.vocab[self.special_tokens['pad_token']]
|
14 |
self.id_to_token = {v: k for k, v in self.vocab.items()}
|
15 |
|
16 |
+
self.all_special_tokens = list(self.special_tokens.values())
|
17 |
+
|
18 |
def load_vocab_and_special_tokens(self, vocab_path, special_tokens_path):
|
19 |
with open(vocab_path, 'r', encoding='utf-8') as vocab_file:
|
20 |
vocab = json.load(vocab_file)
|
|
|
66 |
tokens.append(self.id_to_token.get(id, self.special_tokens['unk_token']))
|
67 |
sequence = ''.join([f'[{token}]' for token in tokens])
|
68 |
return sequence
|
69 |
+
|
70 |
+
# --- 新增 ---
|
71 |
+
def convert_ids_to_tokens(self, ids):
|
72 |
+
"""list[int] → list[str],跳过未知 id"""
|
73 |
+
return [self.id_to_token.get(i, self.special_tokens['unk_token']) for i in ids]
|
utils/foldseek_util.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import re
|
6 |
+
import sys
|
7 |
+
|
8 |
+
from Bio.PDB import PDBParser, MMCIFParser
|
9 |
+
|
10 |
+
|
11 |
+
sys.path.append(".")
|
12 |
+
|
13 |
+
|
14 |
+
# Get structural seqs from pdb file
|
15 |
+
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: "bool | str" = "auto",
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Run ``foldseek structureto3didescriptor`` on one structure file and parse
    the per-chain sequences from its TSV output.

    Args:
        foldseek: Binary executable file of foldseek

        path: Path to pdb file

        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.

        process_id: Process ID for temporary files. This is used for parallel processing.

        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
            With the default "auto", masking is enabled when the file content
            contains the text "alphafold" (case-insensitive).

        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.

        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq). Masked positions of struc_seq are "#".
        When a chain ID occurs more than once, the first occurrence is kept.
    """
    # Local import so this function does not depend on a module-level import.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed through verbatim and never interpreted.
    cmd = [foldseek, "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", path, tmp_save_path]

    try:
        # The exit status is not checked (as before); a failed run surfaces
        # below when the output file cannot be opened.
        subprocess.run(cmd, check=False)

        # Check whether the structure is predicted by AlphaFold2
        if plddt_mask == "auto":
            with open(path, "r") as r:
                plddt_mask = "alphafold" in r.read().lower()

        seq_dict = {}
        name = os.path.basename(path)
        plddts = None  # parsed lazily, at most once per successful extraction

        with open(tmp_save_path, "r") as r:
            for line in r:
                desc, seq, struc_seq = line.split("\t")[:3]

                # The description starts with "<file name>_<chain>".
                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                # Skip chains that were not requested or were already stored,
                # before doing any masking work for them.
                if chains is not None and chain not in chains:
                    continue
                if chain in seq_dict:
                    continue

                # Mask low plddt (best effort: on failure the unmasked
                # sequence is kept and the problem is reported).
                if plddt_mask:
                    try:
                        if plddts is None:
                            plddts = extract_plddt(path)
                        assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                        # Mask regions with plddt < threshold
                        indices = np.where(plddts < plddt_threshold)[0]
                        np_seq = np.array(list(struc_seq))
                        np_seq[indices] = "#"
                        struc_seq = "".join(np_seq)

                    except Exception as e:
                        print(f"Error: {e}")
                        print(f"Failed to mask plddt for {name}")

                # Interleave residue letters with lower-cased structure letters.
                combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)])
                seq_dict[chain] = (seq, struc_seq, combined_seq)

        return seq_dict
    finally:
        # Always clean up foldseek's outputs, even when parsing failed.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)
|
91 |
+
|
92 |
+
|
93 |
+
def extract_plddt(pdb_path: str, chain_id: str = "A") -> np.ndarray:
    """
    Extract plddt scores from pdb file.

    The score of a residue is the mean of the B-factor values of its atoms,
    read from the first model of the structure.

    Args:
        pdb_path: Path to pdb file. Must end with '.pdb' or '.cif'.

        chain_id: ID of the chain to read. Defaults to "A", the chain that was
            previously hard-coded.

    Returns:
        plddts: plddt scores, one value per residue of the chain.

    Raises:
        ValueError: If the file extension is neither '.cif' nor '.pdb'.
        KeyError: If the first model has no chain ``chain_id``.
    """

    # Initialize parser
    if pdb_path.endswith(".cif"):
        parser = MMCIFParser()
    elif pdb_path.endswith(".pdb"):
        parser = PDBParser()
    else:
        raise ValueError("Invalid file format for plddt extraction. Must be '.cif' or '.pdb'.")

    structure = parser.get_structure('protein', pdb_path)
    model = structure[0]
    chain = model[chain_id]

    # Extract plddt scores: per-residue mean over the atoms' B-factors
    plddts = [
        np.mean([atom.get_bfactor() for atom in residue])
        for residue in chain
    ]

    return np.array(plddts)
|
127 |
+
|
128 |
+
|
129 |
+
def transform_pdb_dir(foldseek: str, pdb_dir: str, seq_type: str, save_path: str):
    """
    Transform a directory of pdb files into a fasta file.

    Args:
        foldseek: Binary executable file of foldseek.

        pdb_dir: Directory of pdb files.

        seq_type: Type of sequence to be extracted. Must be "aa" or "foldseek"
            ("foldseek" sequences are written in lower case).

        save_path: Path to save the fasta file.
    """
    # Local import so this function does not depend on a module-level import.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert seq_type in ["aa", "foldseek"], "seq_type must be 'aa' or 'foldseek'!"

    tmp_save_path = f"get_struc_seq_{time.time()}.tsv"

    # Argument list instead of a shell string: directory names containing
    # spaces or shell metacharacters are passed through verbatim.
    cmd = [foldseek, "structureto3didescriptor", "--chain-name-mode", "1", pdb_dir, tmp_save_path]

    try:
        # The exit status is not checked (as before); a failed run surfaces
        # below when the output file cannot be opened.
        subprocess.run(cmd, check=False)

        with open(tmp_save_path, "r") as r, open(save_path, "w") as w:
            for line in r:
                protein_id, aa_seq, foldseek_seq = line.strip().split("\t")[:3]

                if seq_type == "aa":
                    w.write(f">{protein_id}\n{aa_seq}\n")
                else:
                    w.write(f">{protein_id}\n{foldseek_seq.lower()}\n")
    finally:
        # Always clean up foldseek's outputs, even when parsing failed.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            if os.path.exists(leftover):
                os.remove(leftover)
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == '__main__':
    # Manual smoke test: print the lower-cased structural sequence of chain A.
    foldseek = "/sujin/bin/foldseek"
    # test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"
    # get_struc_seq() has no `plddt_path` parameter (passing it raised
    # TypeError); plddt values are read from the structure file itself.
    res = get_struc_seq(foldseek, test_path, plddt_threshold=70.)
    print(res["A"][1].lower())
|
utils/metric_learning_models_att_maps.py
CHANGED
@@ -175,15 +175,10 @@ class CAN_Layer(nn.Module):
|
|
175 |
|
176 |
query_embed = torch.cat([prot_embed, drug_embed], dim=1)
|
177 |
|
178 |
-
|
179 |
-
att = torch.zeros(1, 1, 1024, 1024)
|
180 |
-
att[:, :, :512, :512] = alpha_pp.mean(dim=-1) # Protein to Protein
|
181 |
-
att[:, :, :512, 512:] = alpha_pd.mean(dim=-1) # Protein to Drug
|
182 |
-
att[:, :, 512:, :512] = alpha_dp.mean(dim=-1) # Drug to Protein
|
183 |
-
att[:, :, 512:, 512:] = alpha_dd.mean(dim=-1) # Drug to Drug
|
184 |
|
185 |
# print("query_embed:", query_embed.shape)
|
186 |
-
return query_embed,
|
187 |
|
188 |
class MlPdecoder_CAN(nn.Module):
|
189 |
def __init__(self, input_dim):
|
|
|
175 |
|
176 |
query_embed = torch.cat([prot_embed, drug_embed], dim=1)
|
177 |
|
178 |
+
att_pd = alpha_pd.mean(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
# print("query_embed:", query_embed.shape)
|
181 |
+
return query_embed, att_pd
|
182 |
|
183 |
class MlPdecoder_CAN(nn.Module):
|
184 |
def __init__(self, input_dim):
|