# FusionDTI / utils / foldseek_util.py
# ZhaohanM - Update: SMILES-to-SELFIES conversion, UI polish, and usage guide (5676c75)
# 5.72 kB
import os
import time
import json
import numpy as np
import re
import sys
from Bio.PDB import PDBParser, MMCIFParser
sys.path.append(".")
# Get structural seqs from pdb file
def get_struc_seq(foldseek,
                  path,
                  chains: list = None,
                  process_id: int = 0,
                  plddt_mask: "bool | str" = "auto",
                  plddt_threshold: float = 70.,
                  foldseek_verbose: bool = False) -> dict:
    """
    Run foldseek on one structure file and return its per-chain sequences.

    Args:
        foldseek: Binary executable file of foldseek
        path: Path to pdb file
        chains: Chains to be extracted from pdb file. If None, all chains will be extracted.
        process_id: Process ID for temporary files. This is used for parallel processing.
        plddt_mask: If True, mask regions with plddt < plddt_threshold. plddt scores are from the pdb file.
            If "auto" (default), masking is enabled only when the file content contains the
            word "alphafold" (case-insensitive).
        plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked.
        foldseek_verbose: If True, foldseek will print verbose messages.

    Returns:
        seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of
        (seq, struc_seq, combined_seq). Masked positions of struc_seq are replaced by "#".

    Raises:
        AssertionError: If the foldseek binary or the structure file does not exist.
        OSError: If foldseek cannot be launched, or (FileNotFoundError) if it produced
            no output file.
    """
    # Local import so that this function does not depend on a file-level import.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert os.path.exists(path), f"PDB file not found: {path}"

    tmp_save_path = f"get_struc_seq_{process_id}_{time.time()}.tsv"

    # Build an argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are then passed to foldseek verbatim.
    cmd = [os.fspath(foldseek), "structureto3didescriptor"]
    if not foldseek_verbose:
        cmd += ["-v", "0"]
    cmd += ["--threads", "1", "--chain-name-mode", "1", os.fspath(path), tmp_save_path]

    seq_dict = {}
    try:
        # The exit status is deliberately ignored (as before); a failed run shows up
        # as a missing output file when it is opened below.
        subprocess.run(cmd, check=False)

        # Check whether the structure is predicted by AlphaFold2
        if plddt_mask == "auto":
            with open(path, "r") as r:
                plddt_mask = "alphafold" in r.read().lower()

        name = os.path.basename(path)
        plddts = None  # parsed lazily, at most once per successful parse
        with open(tmp_save_path, "r") as r:
            for line in r:
                if not line.strip():
                    continue
                # rstrip keeps a trailing newline out of the third field when the
                # row has exactly three columns.
                desc, seq, struc_seq = line.rstrip("\n").split("\t")[:3]

                name_chain = desc.split(" ")[0]
                chain = name_chain.replace(name, "").split("_")[-1]

                # Skip chains that were not requested or were already recorded
                # before doing any masking work for them.
                if chains is not None and chain not in chains:
                    continue
                if chain in seq_dict:
                    continue

                # Mask low plddt (best effort: on failure the unmasked sequence is kept)
                if plddt_mask:
                    try:
                        if plddts is None:
                            # NOTE(review): extract_plddt reads a single fixed chain, so
                            # for other chains the length check below may fail - confirm.
                            plddts = extract_plddt(path)
                        assert len(plddts) == len(struc_seq), f"Length mismatch: {len(plddts)} != {len(struc_seq)}"

                        # Mask regions with plddt < threshold
                        indices = np.where(plddts < plddt_threshold)[0]
                        np_seq = np.array(list(struc_seq))
                        np_seq[indices] = "#"
                        struc_seq = "".join(np_seq)

                    except Exception as e:
                        print(f"Error: {e}")
                        print(f"Failed to mask plddt for {name}")

                # Interleave amino-acid letters with lower-cased structural tokens.
                combined_seq = "".join(a + b.lower() for a, b in zip(seq, struc_seq))
                seq_dict[chain] = (seq, struc_seq, combined_seq)
    finally:
        # Always clean up, even when foldseek failed or parsing raised.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass

    return seq_dict
def extract_plddt(pdb_path: str, chain_id: str = "A") -> np.ndarray:
    """
    Extract per-residue plddt scores from a structure file.

    The score of a residue is the mean of the B-factor values of its atoms
    (the B-factor column is assumed to hold plddt, as in predicted structures).

    Args:
        pdb_path: Path to the structure file. The extension must be '.pdb' or '.cif'
            (case-insensitive).
        chain_id: ID of the chain to read from the first model. Defaults to "A".

    Returns:
        plddts: 1-D array with one score per residue of the selected chain.

    Raises:
        ValueError: If the file extension is neither '.cif' nor '.pdb'.
        KeyError: If the selected chain is not present in the first model.
    """
    # Pick the parser from the file extension
    lowered = pdb_path.lower()
    if lowered.endswith(".cif"):
        parser = MMCIFParser()
    elif lowered.endswith(".pdb"):
        parser = PDBParser()
    else:
        raise ValueError("Invalid file format for plddt extraction. Must be '.cif' or '.pdb'.")

    structure = parser.get_structure('protein', pdb_path)
    # Only the first model of the structure is used.
    chain = structure[0][chain_id]

    # Average the atom-level B-factors within each residue
    return np.array([
        np.mean([atom.get_bfactor() for atom in residue])
        for residue in chain
    ])
def transform_pdb_dir(foldseek: str, pdb_dir: str, seq_type: str, save_path: str):
    """
    Transform a directory of pdb files into a fasta file.

    Args:
        foldseek: Binary executable file of foldseek.
        pdb_dir: Directory of pdb files.
        seq_type: Type of sequence to be extracted. Must be "aa" or "foldseek".
            With "foldseek" the structural sequence is written in lower case.
        save_path: Path to save the fasta file.

    Raises:
        AssertionError: If the foldseek binary does not exist or seq_type is invalid.
        OSError: If foldseek cannot be launched, or (FileNotFoundError) if it produced
            no output file.
    """
    # Local import so that this function does not depend on a file-level import.
    import subprocess

    assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}"
    assert seq_type in ["aa", "foldseek"], "seq_type must be 'aa' or 'foldseek'!"

    tmp_save_path = f"get_struc_seq_{time.time()}.tsv"
    # Argument list instead of a shell string, so directory names with spaces or
    # shell metacharacters are passed through unchanged.
    cmd = [
        os.fspath(foldseek),
        "structureto3didescriptor",
        "--chain-name-mode", "1",
        os.fspath(pdb_dir),
        tmp_save_path,
    ]

    try:
        # Exit status is ignored (as before); a failed run surfaces when the
        # output file is opened.
        subprocess.run(cmd, check=False)

        with open(tmp_save_path, "r") as r, open(save_path, "w") as w:
            for line in r:
                if not line.strip():
                    continue
                protein_id, aa_seq, foldseek_seq = line.strip().split("\t")[:3]

                if seq_type == "aa":
                    w.write(f">{protein_id}\n{aa_seq}\n")
                else:
                    w.write(f">{protein_id}\n{foldseek_seq.lower()}\n")
    finally:
        # Always remove the temporary files, tolerating ones that were never created.
        for leftover in (tmp_save_path, tmp_save_path + ".dbtype"):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
if __name__ == '__main__':
    # Manual smoke test: the paths below are machine-specific and must exist locally.
    foldseek = "/sujin/bin/foldseek"
    # Alternative input: test_path = "/sujin/Datasets/PDB/all/6xtd.cif"
    test_path = "/sujin/Datasets/FLIP/meltome/af2_structures/A0A061ACX4.pdb"

    # get_struc_seq has no `plddt_path` parameter (plddt is read from the structure
    # file itself), so passing it raised TypeError; it is no longer passed.
    res = get_struc_seq(foldseek, test_path, plddt_threshold=70.)
    # Print the structural sequence of chain A in lower case.
    print(res["A"][1].lower())