import os, sys, argparse, tempfile, shutil, base64, io
from flask import Flask, request, render_template_string
from werkzeug.utils import secure_filename
from torch.utils.data import DataLoader
import selfies
from rdkit import Chem
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional
from utils.drug_tokenizer import DrugTokenizer
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq
# ───── Biopython fallback ───────────────────────────────────────
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"SEC": "C", "PYL": "K"})
def simple_seq_from_structure(path: str) -> str:
    """Fallback when Foldseek is unavailable: read the first chain of a
    PDB/mmCIF file and return its one-letter sequence ("X" for unknown residues)."""
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    chain = next(parser.get_structure("P", path).get_chains())
    return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)
# ───── global paths / args ──────────────────────────────────────
FOLDSEEK_BIN = shutil.which("foldseek")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append("..")
def parse_config():
    p = argparse.ArgumentParser()
    p.add_argument("-f")
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="BindingDB",
                   help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
    return p.parse_args()
args = parse_config()
DEVICE = args.device
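# Example invocation (the script name is illustrative; the flags are the ones
# defined in parse_config above):
#   python app.py --dataset BindingDB --fusion CAN --device cuda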
# ───── tokenisers & encoders ────────────────────────────────────
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
drug_tokenizer = DrugTokenizer() # SELFIES
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
# ─── collate fn ────────────────────────────────────────────────
def collate_fn(batch):
    query1, query2, scores = zip(*batch)

    query_encodings1 = prot_tokenizer.batch_encode_plus(
        list(query1),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    query_encodings2 = drug_tokenizer.batch_encode_plus(
        list(query2),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    scores = torch.tensor(list(scores))

    attention_mask1 = query_encodings1["attention_mask"].bool()
    attention_mask2 = query_encodings2["attention_mask"].bool()

    return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
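# Shape sketch for a single (protein, drug, label) triple — this mirrors how the
# Flask handler below builds its one-item DataLoader (the sequences are
# placeholders, not real inputs):
#   loader = DataLoader([("<protein sequence>", "<SELFIES string>", 1)],
#                       batch_size=1, collate_fn=collate_fn)
#   p_ids, p_mask, d_ids, d_mask, y = next(iter(loader))
#   # p_ids: (1, 512) long, p_mask: (1, 512) bool,
#   # d_ids: (1, 512) long, d_mask: (1, 512) bool, y: (1,) long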
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Validate a SMILES string with RDKit and return its SELFIES encoding,
    or None if parsing or encoding fails."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        selfies_str = selfies.encoder(smiles)
        return selfies_str
    except Exception:
        return None
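# Quick sanity check of the converter above (illustrative; the exact SELFIES
# output may vary with the installed `selfies` version):
#   >>> smiles_to_selfies("CCO")            # ethanol
#   '[C][C][O]'
#   >>> smiles_to_selfies("not-a-smiles") is None
#   True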
# ───── single-case embedding ───────────────────────────────────
def get_case_feature(model, loader):
    """Encode the single (protein, drug) pair in *loader* and return its
    embeddings, token ids and masks on CPU."""
    model.eval()
    with torch.no_grad():
        for p_ids, p_mask, d_ids, d_mask, _ in loader:
            p_ids, p_mask = p_ids.to(DEVICE), p_mask.to(DEVICE)
            d_ids, d_mask = d_ids.to(DEVICE), d_mask.to(DEVICE)
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            return [(p_emb.cpu(), d_emb.cpu(),
                     p_ids.cpu(), d_ids.cpu(),
                     p_mask.cpu(), d_mask.cpu(), None)]
# ───── helper: filter out special tokens ───────────────────────
def clean_tokens(ids, tokenizer):
    toks = tokenizer.convert_ids_to_tokens(ids.tolist())
    return [t for t in toks if t not in tokenizer.all_special_tokens]
# ───── visualisation ───────────────────────────────────────────
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    The token index shown on the x-axis (and accepted via *drug_idx*) is **the
    position of that token in the *original* drug sequence**, *after* the
    tokeniser but *before* any pruning or truncation (1-based in the labels,
    0-based for the function argument).

    Returns
    -------
    html : str
        Base64-embedded PNG heat-map (+ optional HTML table).
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ──────────────────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention (B, n_p, n_d) ────────────
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()          # (n_p, n_d)

        # ── decode tokens (skip special symbols) ────────────────────────────
        def clean_ids(ids, tokenizer):
            toks = tokenizer.convert_ids_to_tokens(ids.tolist())
            return [t for t in toks if t not in tokenizer.all_special_tokens]

        # ── decode full sequences + record 1-based indices ──────────────────
        p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
        p_indices_full = list(range(1, len(p_tokens_full) + 1))
        d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
        d_indices_full = list(range(1, len(d_tokens_full) + 1))

        # ── safety cut-off so token lists match the attention matrix ────────
        p_tokens = p_tokens_full[: attn.size(0)]
        p_indices_full = p_indices_full[: attn.size(0)]
        d_tokens_full = d_tokens_full[: attn.size(1)]
        d_indices_full = d_indices_full[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens_full)]
        # ── adaptive sparsity pruning ────────────────────────────────────────
        thr = attn.max().item() * 0.05
        row_keep = (attn.max(dim=1).values > thr)
        col_keep = (attn.max(dim=0).values > thr)
        if row_keep.sum() < 3:
            row_keep[:] = True
        if col_keep.sum() < 3:
            col_keep[:] = True

        attn = attn[row_keep][:, col_keep]
        p_tokens  = [tok for keep, tok in zip(row_keep, p_tokens) if keep]
        p_indices = [idx for keep, idx in zip(row_keep, p_indices_full) if keep]
        d_tokens  = [tok for keep, tok in zip(col_keep, d_tokens_full) if keep]
        d_indices = [idx for keep, idx in zip(col_keep, d_indices_full) if keep]

        # ── cap column count at 150 for readability ──────────────────────────
        if attn.size(1) > 150:
            topc = torch.topk(attn.sum(0), k=150).indices
            attn = attn[:, topc]
            d_tokens  = [d_tokens[i] for i in topc]
            d_indices = [d_indices[i] for i in topc]
        # ── draw heat-map ─────────────────────────────────────────────────────
        x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
        y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

        fig_w = min(22, max(8, len(x_labels) * 0.6))   # ~0.6″ per column
        fig_h = min(24, max(6, len(p_tokens) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", top=True, bottom=False,
                       labeltop=True, labelbottom=False,
                       pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # ── export figure as a base64-embedded <img> tag ────────────────────
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=140)
        plt.close(fig)
        img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        html = f'<img src="data:image/png;base64,{img_b64}" alt="Protein → Drug attention heat-map"/>'
        # ───────────────────── build the Top-20 table (if requested) ─────────
        table_html = ""          # empty by default so it can always be appended
        if drug_idx is not None:
            # map the original 0-based drug_idx → current column position
            if (drug_idx + 1) in d_indices:
                col_pos = d_indices.index(drug_idx + 1)
            elif 0 <= drug_idx < len(d_tokens):
                col_pos = drug_idx
            else:
                col_pos = None

            if col_pos is not None:
                col_vec = attn[:, col_pos]
                topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()
                drug_tok_text = f"{d_indices[col_pos]}:{d_tokens[col_pos]}"
                rank_hdr = "".join(f"<th>{r}</th>" for r in range(1, len(topk) + 1))
                res_row  = "".join(f"<td>{p_tokens[i]}</td>"  for i in topk)
                pos_row  = "".join(f"<td>{p_indices[i]}</td>" for i in topk)
                table_html = (
                    f"<h4>{drug_tok_text} → Top-20 Protein residues</h4>"
                    "<table>"
                    f"<tr><th>Rank</th>{rank_hdr}</tr>"
                    f"<tr><th>Residue</th>{res_row}</tr>"
                    f"<tr><th>Position</th>{pos_row}</tr>"
                    "</table>"
                )

    return html + table_html
# ───── Flask UI ────────────────────────────────────────────────
app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def index():
    # Form keys marked (*) are placeholders; the control flow and user-facing
    # strings follow the original handler.
    protein_seq, drug_seq, structure_seq, result_html = "", "", "", ""
    drug_idx = None

    if request.method == "POST":
        # (The original handler also turns an uploaded .pdb/.cif file into a
        #  structure-aware sequence via Foldseek's get_struc_seq, falling back
        #  to simple_seq_from_structure when the foldseek binary is missing.)
        if "Convert" in request.form:                            # (*)
            drug_input = request.form.get("drug_input", "")      # (*)
            if smiles_to_selfies(drug_input) is None:
                result_html = "Failed to convert SMILES to SELFIES. Please check the input string."
            else:
                drug_seq = drug_input
        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            drug_seq = request.form.get("drug_sequence", "")
            drug_idx_raw = request.form.get("drug_idx", "").strip()   # (*)
            drug_idx = int(drug_idx_raw) if drug_idx_raw.isdigit() else None
            if protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)],
                                    batch_size=1, collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                                    "best_model.ckpt")
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                result_html = visualize_attention(model, feats, drug_idx)

    return render_template_string(
        # ───────────── HTML (original UI + new input box) ─────────────
        # Abridged template: only the recoverable text of the original page is kept.
        """
        <p>Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using Foldseek, based on 3D structures from
        AlphaFold DB or the Protein Data Bank (PDB).</p>
        <pre>{{ structure_seq }}</pre>
        {{ result_html | safe }}
        """,
        structure_seq=structure_seq,
        result_html=result_html,
    )
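# Minimal launcher sketch (an assumption; host/port are left at the Flask
# development-server defaults, 127.0.0.1:5000):
if __name__ == "__main__":
    app.run()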