import os, sys, argparse, tempfile, shutil, base64, io

from flask import Flask, request, render_template_string
from werkzeug.utils import secure_filename
from torch.utils.data import DataLoader
import selfies
from rdkit import Chem
import torch

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm

from typing import Optional

from utils.drug_tokenizer import DrugTokenizer
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq

# ───── Biopython fallback ───────────────────────────────────────
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData

three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
three2one.update({"SEC": "C", "PYL": "K"})


def simple_seq_from_structure(path: str) -> str:
    """Extract a one-letter sequence from the first chain of a PDB/mmCIF file."""
    parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
    chain = next(parser.get_structure("P", path).get_chains())
    return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)


# ───── global paths / args ──────────────────────────────────────
FOLDSEEK_BIN = shutil.which("foldseek")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append("..")


def parse_config():
    p = argparse.ArgumentParser()
    p.add_argument("-f")  # absorbs stray "-f" flags (e.g. when run inside Jupyter)
    p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
    p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
    p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
    p.add_argument("--group_size", type=int, default=1)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--fusion", default="CAN")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--save_path_prefix", default="save_model_ckp/")
    p.add_argument("--dataset", default="BindingDB",
                   help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
    return p.parse_args()


args = parse_config()
DEVICE = args.device

# ───── tokenisers & encoders ────────────────────────────────────
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)

drug_tokenizer = DrugTokenizer()  # SELFIES
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)

encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)


# ─── collate fn ────────────────────────────────────────────────
def collate_fn(batch):
    """Tokenise a batch of (protein, drug, score) triples into padded tensors."""
    query1, query2, scores = zip(*batch)

    query_encodings1 = prot_tokenizer.batch_encode_plus(
        list(query1),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    query_encodings2 = drug_tokenizer.batch_encode_plus(
        list(query2),
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    scores = torch.tensor(list(scores))

    attention_mask1 = query_encodings1["attention_mask"].bool()
    attention_mask2 = query_encodings2["attention_mask"].bool()

    return (query_encodings1["input_ids"], attention_mask1,
            query_encodings2["input_ids"], attention_mask2, scores)


def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES; return None if parsing fails."""
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        return None
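# Round-trip example for the helper above (illustrative; assumes the standard
# `selfies` package encoder/decoder API):
#   >>> smiles_to_selfies("CCO")        # ethanol, given as SMILES
#   '[C][C][O]'
#   >>> selfies.decoder("[C][C][O]")    # back to SMILES
#   'CCO'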
# ───── single-case embedding ───────────────────────────────────
def get_case_feature(model, loader):
    """Encode one (protein, drug) pair and return the tensors needed for plotting."""
    model.eval()
    with torch.no_grad():
        for p_ids, p_mask, d_ids, d_mask, _ in loader:
            p_ids, p_mask = p_ids.to(DEVICE), p_mask.to(DEVICE)
            d_ids, d_mask = d_ids.to(DEVICE), d_mask.to(DEVICE)
            p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
            # only the first (single) batch is needed
            return [(p_emb.cpu(), d_emb.cpu(),
                     p_ids.cpu(), d_ids.cpu(),
                     p_mask.cpu(), d_mask.cpu(), None)]


# ───── helper: filter special tokens ───────────────────────────
def clean_tokens(ids, tokenizer):
    toks = tokenizer.convert_ids_to_tokens(ids.tolist())
    return [t for t in toks if t not in tokenizer.all_special_tokens]


# ───── visualisation ───────────────────────────────────────────
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-20 protein-residue table for a chosen drug-token index.

    The token index shown on the x-axis (and accepted via *drug_idx*) is
    **the position of that token in the *original* drug sequence**, *after*
    the tokeniser but *before* any pruning or truncation (1-based in the
    labels, 0-based for the function argument).

    Returns
    -------
    html : str
        Base64-embedded PNG heat-map (+ optional HTML table).
    """
    model.eval()
    with torch.no_grad():
        # ── unpack single-case tensors ──────────────────────────
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)

        # ── forward pass: Protein → Drug attention (B, n_p, n_d) ──
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        attn = att_pd.squeeze(0).cpu()  # (n_p, n_d)

        # ── decode full sequences (skip special symbols) + record
        #    1-based indices ────────────────────────────────────
        p_tokens_full = clean_tokens(p_ids[0], prot_tokenizer)
        p_indices_full = list(range(1, len(p_tokens_full) + 1))
        d_tokens_full = clean_tokens(d_ids[0], drug_tokenizer)
        d_indices_full = list(range(1, len(d_tokens_full) + 1))

        # ── safety cut-off to match attention matrix size ───────
        p_tokens = p_tokens_full[: attn.size(0)]
        p_indices_full = p_indices_full[: attn.size(0)]
        d_tokens_full = d_tokens_full[: attn.size(1)]
        d_indices_full = d_indices_full[: attn.size(1)]
        attn = attn[: len(p_tokens), : len(d_tokens_full)]

        # ── adaptive sparsity pruning ───────────────────────────
        thr = attn.max().item() * 0.05
        row_keep = attn.max(dim=1).values > thr
        col_keep = attn.max(dim=0).values > thr
        if row_keep.sum() < 3:
            row_keep[:] = True
        if col_keep.sum() < 3:
            col_keep[:] = True

        attn = attn[row_keep][:, col_keep]
        p_tokens = [tok for keep, tok in zip(row_keep, p_tokens) if keep]
        p_indices = [idx for keep, idx in zip(row_keep, p_indices_full) if keep]
        d_tokens = [tok for keep, tok in zip(col_keep, d_tokens_full) if keep]
        d_indices = [idx for keep, idx in zip(col_keep, d_indices_full) if keep]

        # ── cap column count at 150 for readability ─────────────
        if attn.size(1) > 150:
            topc = torch.topk(attn.sum(0), k=150).indices
            attn = attn[:, topc]
            d_tokens = [d_tokens[i] for i in topc]
            d_indices = [d_indices[i] for i in topc]
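        # Illustration (hypothetical numbers, not from the original): with
        # attn.max() == 0.40 the pruning threshold is 0.02, so a protein row
        # survives only if it attends to some drug token with weight > 0.02;
        # the same test prunes drug columns, and at most the 150 columns with
        # the largest summed attention are drawn.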
        # ── draw heat-map ───────────────────────────────────────
        x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
        y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

        fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
        fig_h = min(24, max(6, len(p_tokens) * 0.8))
        fig, ax = plt.subplots(figsize=(fig_w, fig_h))
        im = ax.imshow(attn.numpy(), aspect="auto", cmap=cm.viridis,
                       interpolation="nearest")

        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)
        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8, ha="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        ax.tick_params(axis="y", pad=10)
        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # ── build the Top-20 table (if requested) ───────────────
        table_html = ""  # start empty so the final concatenation is uniform
        if drug_idx is not None:
            # map the original 0-based drug_idx → current column position
            if (drug_idx + 1) in d_indices:
                col_pos = d_indices.index(drug_idx + 1)
            elif 0 <= drug_idx < len(d_tokens):
                col_pos = drug_idx
            else:
                col_pos = None

            if col_pos is not None:
                col_vec = attn[:, col_pos]
                topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()

                rank_hdr = "".join(f"<td>{r + 1}</td>" for r in range(len(topk)))
                res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
                pos_row = "".join(f"<td>{p_indices[i]}</td>" for i in topk)
                drug_tok_text = d_tokens[col_pos]
                orig_idx = d_indices[col_pos]

                table_html = (
                    f"<h4>Drug token #{orig_idx} <code>{drug_tok_text}</code> "
                    f"→ Top-20 Protein residues</h4>"
                    "<table border='1' style='border-collapse:collapse;'>"
                    f"<tr><th>Rank</th>{rank_hdr}</tr>"
                    f"<tr><th>Residue</th>{res_row}</tr>"
                    f"<tr><th>Position</th>{pos_row}</tr>"
                    "</table>"
                )

        # ── zoomable preview + downloadable heat-map ────────────
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)  # preview (raster)
        buf_png.seek(0)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")           # high-res download (vector)
        buf_pdf.seek(0)
        plt.close(fig)

        png_b64 = base64.b64encode(buf_png.getvalue()).decode()
        pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()

        html_heat = (
            f"<img src='data:image/png;base64,{png_b64}' "
            f"style='max-width:100%;cursor:zoom-in;' alt='attention heat-map'>"
            f"<br><a download='attention_map.pdf' "
            f"href='data:application/pdf;base64,{pdf_b64}'>Download PDF</a>"
        )

        # ── final HTML ──────────────────────────────────────────
        return table_html + html_heat


# ───── Flask app ───────────────────────────────────────────────
app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def index():
    protein_seq = drug_seq = structure_seq = ""
    result_html = None
    tmp_structure_path = ""
    drug_idx = None

    if request.method == "POST":
        drug_idx_raw = request.form.get("drug_idx", "")
        drug_idx = int(drug_idx_raw) - 1 if drug_idx_raw.isdigit() else None

        struct = request.files.get("structure_file")
        if struct and struct.filename:
            path = os.path.join(tempfile.gettempdir(), secure_filename(struct.filename))
            struct.save(path)
            tmp_structure_path = path
        else:
            tmp_structure_path = request.form.get("tmp_structure_path", "")

        if "clear" in request.form:
            protein_seq = drug_seq = structure_seq = ""
            tmp_structure_path = ""

        elif "confirm_structure" in request.form and tmp_structure_path:
            try:
                parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, None,
                                       plddt_mask=False)
                chain = list(parsed.keys())[0]
                _, _, structure_seq = parsed[chain]
            except Exception:
                structure_seq = simple_seq_from_structure(tmp_structure_path)
            protein_seq = structure_seq

            drug_input = request.form.get("drug_sequence", "")
            # Heuristic: SELFIES tokens always start with "[", so anything
            # else is assumed to be SMILES and converted.
            if not drug_input.strip().startswith("["):
                converted = smiles_to_selfies(drug_input.strip())
                if converted:
                    drug_seq = converted
                else:
                    drug_seq = ""
                    result_html = (
                        "<p style='color:red;'>Failed to convert SMILES to "
                        "SELFIES. Please check the input string.</p>"
                    )
            else:
                drug_seq = drug_input

        elif "Inference" in request.form:
            protein_seq = request.form.get("protein_sequence", "")
            drug_seq = request.form.get("drug_sequence", "")
            if protein_seq and drug_seq:
                loader = DataLoader([(protein_seq, drug_seq, 1)],
                                    batch_size=1, collate_fn=collate_fn)
                feats = get_case_feature(encoding, loader)
                model = FusionDTI(446, 768, args).to(DEVICE)
                ckpt = os.path.join(
                    f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
                    "best_model.ckpt")
                if os.path.isfile(ckpt):
                    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
                result_html = visualize_attention(model, feats, drug_idx)

    # ───────────── HTML (original UI + new input field) ─────────────
    # NOTE: the original template's styling and guideline text did not survive
    # in this copy; the markup below is a minimal reconstruction that keeps the
    # surviving headings/links and the field names the handlers above expect.
    return render_template_string(
        """
<!doctype html>
<html>
<head><meta charset="utf-8"><title>FusionDTI</title></head>
<body>

<h1>Token-level Visualiser for Drug-Target Interaction</h1>

<p>
  <a href="#">🌐 Project Page</a>
  <a href="https://arxiv.org/abs/2406.01651">📄 ArXiv: 2406.01651</a>
  <a href="#">💻 GitHub Repo</a>
  <!-- the Project Page and GitHub hrefs did not survive in this copy -->
</p>

<h3>Guidelines for Use</h3>
<!-- the guideline text did not survive in this copy -->

<form method="post" enctype="multipart/form-data">
  <label>Protein sequence:</label><br>
  <textarea name="protein_sequence" rows="4" cols="80">{{ protein_seq }}</textarea><br>
  <label>Drug sequence (SMILES or SELFIES):</label><br>
  <textarea name="drug_sequence" rows="4" cols="80">{{ drug_seq }}</textarea><br>
  <label>Drug token index (optional, 1-based):</label>
  <input type="text" name="drug_idx"><br>
  <label>Structure file (.pdb / .cif):</label>
  <input type="file" name="structure_file">
  <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"><br>
  <button type="submit" name="confirm_structure" value="1">Confirm structure</button>
  <button type="submit" name="Inference" value="1">Inference</button>
  <button type="submit" name="clear" value="1">Clear</button>
</form>
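<!-- Wiring note: the field names above are dictated by the handlers in
     index(). request.form keys: protein_sequence, drug_sequence, drug_idx,
     tmp_structure_path. request.files key: structure_file. The submit-button
     names (clear, confirm_structure, Inference) select the POST branch. -->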

{% if structure_seq %}
  <h3>Structure-aware sequence:</h3>
  <pre style="white-space:pre-wrap;">{{ structure_seq }}</pre>
{% endif %}

{% if result_html %}
  <div>{{ result_html|safe }}</div>
{% endif %}

</body>
</html>
        """,
        protein_seq=protein_seq,
        drug_seq=drug_seq,
        structure_seq=structure_seq,
        result_html=result_html,
        tmp_structure_path=tmp_structure_path,
    )


# ───── run ─────────────────────────────────────────────────────
if __name__ == "__main__":
    app.run(debug=True, host="0.0.0.0", port=7860)