# Uploaded via huggingface_hub by 2dogey (commit 8918ac7, verified).
import biotite
import joblib
import math
import numpy as np
import os
import scipy.spatial as spa
import torch
import torch.nn.functional as F
from Bio import PDB
from Bio.SeqUtils import seq1
from pathlib import Path
from torch_geometric.data import Batch, Data
from torch_scatter import scatter_mean, scatter_sum, scatter_max
from tqdm import tqdm
from typing import List
from biotite.sequence import ProteinSequence
from biotite.structure import filter_backbone, get_chains
from biotite.structure.io import pdb, pdbx
from biotite.structure.residues import get_residues
from .encoder import AutoGraphEncoder
def _normalize(tensor, dim=-1):
"""
Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
"""
return torch.nan_to_num(
torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))
)
def _rbf(D, D_min=0.0, D_max=20.0, D_count=16, device="cpu"):
"""
From https://github.com/jingraham/neurips19-graph-protein-design
Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
That is, if `D` has shape [...dims], then the returned tensor will have
shape [...dims, D_count].
"""
D_mu = torch.linspace(D_min, D_max, D_count, device=device)
D_mu = D_mu.view([1, -1])
D_sigma = (D_max - D_min) / D_count
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
return RBF
def _orientations(X_ca):
    """Per-residue unit vectors toward the next and previous CA atom.

    Returns a [num_res, 2, 3] tensor; the chain ends are zero-padded
    where no neighbor exists.
    """
    fwd = F.pad(_normalize(X_ca[1:] - X_ca[:-1]), [0, 0, 0, 1])
    bwd = F.pad(_normalize(X_ca[:-1] - X_ca[1:]), [0, 0, 1, 0])
    return torch.stack([fwd, bwd], dim=-2)
def _sidechains(X):
    """Imputed side-chain direction per residue from backbone geometry.

    Args:
        X: [num_res, atoms, 3] backbone coordinates with atom order
            N, CA, C(, ...) — only the first three columns are read.

    Returns:
        [num_res, 3] unit vectors pointing toward the idealized CB position
        (tetrahedral mix of the N/C bisector and its perpendicular).
    """
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    # Explicit dim=-1: torch.cross's implicit dim selection is deprecated
    # (for these [L, 3] inputs it picked the same axis, so behavior is unchanged).
    perp = _normalize(torch.cross(c, n, dim=-1))
    vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
    return vec
def _positional_embeddings(edge_index, num_embeddings=16, period_range=[2, 1000]):
# From https://github.com/jingraham/neurips19-graph-protein-design
d = edge_index[0] - edge_index[1]
frequency = torch.exp(
torch.arange(0, num_embeddings, 2, dtype=torch.float32)
* -(np.log(10000.0) / num_embeddings)
)
angles = d.unsqueeze(-1) * frequency
E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
return E
def generate_graph(pdb_file, max_distance=10):
    """
    Build a residue-level graph from a PDB file.

    Args:
        pdb_file: path to the input PDB file
        max_distance: CA-CA distance cutoff (Angstrom) for creating edges

    Returns:
        torch_geometric.data.Data with fields:
            node_s: [num_res, 20] zero-filled scalar node features
            node_v: [num_res, 3, 3] vector node features (forward/backward
                CA orientations plus the imputed side-chain direction)
            edge_index: [2, num_edges] (self-loops removed)
            edge_s: [num_edges, 16 + 16] RBF distance + positional features
            edge_v: [num_edges, 1, 3] normalized CA-CA direction
            distances: full CA-CA distance matrix
            aa_seq: one-letter amino-acid sequence
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure("protein", pdb_file)
    # Extract the amino-acid sequence and backbone-atom coordinates.
    seq = []
    aa_coords = {"N": [], "CA": [], "C": [], "O": []}
    for model in structure:
        for chain in model:
            for residue in chain:
                # hetero-flag " " keeps standard residues (skips water/ligands)
                if residue.get_id()[0] == " ":
                    seq.append(residue.get_resname())
                    for atom_name in aa_coords.keys():
                        if atom_name in residue:
                            aa_coords[atom_name].append(
                                residue[atom_name].get_coord().tolist()
                            )
                        else:
                            # A missing backbone atom used to raise KeyError and
                            # desynchronize the coordinate lists; record NaN so
                            # the finite-mask below handles the residue instead.
                            aa_coords[atom_name].append([float("nan")] * 3)
    one_letter_seq = "".join([seq1(aa) for aa in seq])
    # Stack per-residue (N, CA, C, O) coordinates: [num_res, 4, 3].
    coords = list(zip(aa_coords["N"], aa_coords["CA"], aa_coords["C"], aa_coords["O"]))
    coords = torch.tensor(coords)
    # Residues with any missing coordinate are pushed to infinity so every
    # distance involving them exceeds the cutoff (they get no edges).
    mask = torch.isfinite(coords.sum(dim=(1, 2)))
    coords[~mask] = np.inf
    ca_coords = coords[:, 1]
    # Scalar node features are a zero placeholder (filled downstream).
    node_s = torch.zeros(len(ca_coords), 20)
    # Build edges from the CA-CA distance matrix and the cutoff.
    distances = spa.distance_matrix(ca_coords, ca_coords)
    edge_index = torch.tensor(np.array(np.where(distances < max_distance)))
    # Remove self-loops.
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    # Node vector features.
    orientations = _orientations(ca_coords)
    sidechains = _sidechains(coords)
    node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
    # Edge features.
    pos_embeddings = _positional_embeddings(edge_index)
    E_vectors = ca_coords[edge_index[0]] - ca_coords[edge_index[1]]
    rbf = _rbf(E_vectors.norm(dim=-1), D_count=16)
    edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
    edge_v = _normalize(E_vectors).unsqueeze(-2)
    # Replace any remaining NaNs (from masked residues) with zeros.
    node_s, node_v, edge_s, edge_v = map(
        torch.nan_to_num, (node_s, node_v, edge_s, edge_v)
    )
    return Data(
        node_s=node_s,
        node_v=node_v,
        edge_index=edge_index,
        edge_s=edge_s,
        edge_v=edge_v,
        distances=distances,
        aa_seq=one_letter_seq,
    )
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Per-residue coordinates of the named atoms (NaN where an atom is absent).

    Example for atoms argument: ["N", "CA", "C"]
    """
    def filterfn(s, axis=None):
        # One boolean column per requested atom name.
        name_hits = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = name_hits.sum(0)
        if not np.all(counts <= np.ones(name_hits.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        picked = s[name_hits.argmax(0)].coord
        # Atoms not present in this residue are marked with NaN.
        picked[counts == 0] = float("nan")
        return picked

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Extract backbone coordinates from a biotite structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        coords: an L x 3 x 3 array of N, CA, C coordinates per residue.
        (The docstring previously promised a (coords, seq) tuple, but only
        coords was ever returned; use extract_seq_from_pdb for the sequence.)
    """
    return get_atom_coords_residuewise(["N", "CA", "C"], structure)
def extract_seq_from_pdb(pdb_file, chain=None):
    """
    Extract the one-letter amino-acid sequence from a pdb/cif file.

    Args:
        pdb_file: path to a pdb or cif file
        chain: optional chain id (or list of chain ids) to restrict to
    Returns:
        seq: the extracted one-letter sequence string
    """
    structure = load_structure(pdb_file, chain)
    residue_identities = get_residues(structure)[1]
    return "".join(ProteinSequence.convert_letter_3to1(r) for r in residue_identities)
def generate_pos_subgraph(
    graph_data,
    subgraph_depth=None,
    subgraph_interval=1,
    max_distance=10,
    anchor_nodes=None,
    pure_subgraph=False,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """Cut a local neighborhood subgraph around each anchor residue.

    Args:
        graph_data: Data object as produced by ``generate_graph`` (carries
            node_s/node_v/edge_index/edge_s/edge_v/distances/aa_seq).
        subgraph_depth: defaults to 50 when None.
            NOTE(review): after defaulting it is never read again — the
            neighbor count below is the hard-coded ``k = 50``; confirm intent.
        subgraph_interval: stride between automatically chosen anchors.
        max_distance: cutoff used both for pruning neighbors and for
            re-deriving edges inside each subgraph.
        anchor_nodes: explicit anchor residue indices; derived from the
            sequence length and ``subgraph_interval`` when None.
        pure_subgraph: when False, each result also carries an ``index_map``
            from original node id to subgraph-local id.
        device: torch device string for all intermediate tensors.

    Returns:
        dict mapping anchor node id -> Data subgraph. Anchors whose
        neighborhood yields no nodes, no edges, or raises are skipped
        (errors are printed, not re-raised).
    """
    # Re-wrap graph_data so every field is a tensor living on `device`.
    graph_data = Data(
        node_s=graph_data.node_s.to(device) if torch.is_tensor(graph_data.node_s) else torch.tensor(graph_data.node_s, device=device),
        node_v=graph_data.node_v.to(device) if torch.is_tensor(graph_data.node_v) else torch.tensor(graph_data.node_v, device=device),
        edge_index=graph_data.edge_index.to(device) if torch.is_tensor(graph_data.edge_index) else torch.tensor(graph_data.edge_index, device=device),
        edge_s=graph_data.edge_s.to(device) if torch.is_tensor(graph_data.edge_s) else torch.tensor(graph_data.edge_s, device=device),
        edge_v=graph_data.edge_v.to(device) if torch.is_tensor(graph_data.edge_v) else torch.tensor(graph_data.edge_v, device=device),
        distances=graph_data.distances.to(device) if torch.is_tensor(graph_data.distances) else torch.tensor(graph_data.distances, device=device),
        aa_seq=graph_data.aa_seq
    )
    distances = graph_data.distances
    if subgraph_depth is None:
        subgraph_depth = 50
    # Calculate anchor nodes if not provided.
    if anchor_nodes is None:
        anchor_nodes = list(range(0, len(graph_data.aa_seq), subgraph_interval))
    anchor_nodes_tensor = torch.tensor(anchor_nodes, device=device)  # Move anchor nodes to device
    # k nearest neighbors for ALL nodes at once (batched argsort).
    k = 50
    nearest_indices = torch.argsort(distances, dim=1)[:, :k]  # (num_nodes, k)
    # Neighbors beyond the cutoff are flagged with -1 so they can be dropped.
    distance_mask = torch.gather(distances, 1, nearest_indices) < max_distance  # (num_nodes, k)
    nearest_indices = torch.where(distance_mask, nearest_indices, torch.tensor(-1, device=device))  # (num_nodes, k)
    subgraph_dict = {}
    for anchor_node in anchor_nodes:  # per-anchor loop keeps batch handling simple
        try:
            # Neighbors for this anchor (cutoff-filtered; -1 entries removed).
            k_neighbors = nearest_indices[anchor_node]
            k_neighbors = k_neighbors[k_neighbors != -1]
            if len(k_neighbors) == 0:  # Skip if no neighbors found
                continue
            # NOTE(review): thresholds disagree — neighborhoods larger than 30
            # are truncated to 40, not 30; confirm which bound was intended.
            if len(k_neighbors) > 30:
                k_neighbors = k_neighbors[:40]
            k_neighbors, _ = torch.sort(k_neighbors)
            # Pairwise distances restricted to the neighborhood.
            sub_matrix = distances.index_select(0, k_neighbors).index_select(1, k_neighbors)
            # Re-derive edges inside the subgraph from the distance cutoff.
            sub_edges = torch.nonzero(sub_matrix < max_distance, as_tuple=False)
            mask = sub_edges[:, 0] != sub_edges[:, 1]  # drop self-loops
            sub_edge_index = sub_edges[mask]
            if len(sub_edge_index) == 0:  # Skip if no edges found
                continue
            edge_index_device = graph_data.edge_index.to(device)
            # Map subgraph-local edge endpoints back to original node ids.
            original_edge_index = k_neighbors[sub_edge_index]
            # Match each subgraph edge against the full edge list so the
            # corresponding edge features can be gathered.
            matches = []
            for edge in original_edge_index:
                match = (edge_index_device[0] == edge[0]) & (edge_index_device[1] == edge[1])
                matches.append(match)
            matches = torch.stack(matches)
            # NOTE(review): as_tuple=True gives (row, col); [0] selects the
            # sub-edge row index, not the matched position in the full edge
            # list — [1] looks intended for feature lookup; verify.
            edge_to_feature_idx = torch.nonzero(matches, as_tuple=True)[0].to(device)
            if len(edge_to_feature_idx) == 0:  # Skip if no matching edges
                continue
            # Gather node/edge features for the subgraph.
            new_node_s = graph_data.node_s[k_neighbors].to(device)
            new_node_v = graph_data.node_v[k_neighbors].to(device)
            new_edge_s = graph_data.edge_s[edge_to_feature_idx].to(device)
            new_edge_v = graph_data.edge_v[edge_to_feature_idx].to(device)
            result = Data(
                edge_index=sub_edge_index.T.to(device),
                edge_s=new_edge_s.to(device),
                edge_v=new_edge_v.to(device),
                node_s=new_node_s.to(device),
                node_v=new_node_v.to(device),
            )
            if not pure_subgraph:
                # original node id -> local index within this subgraph
                result.index_map = {
                    int(old_id.to(device).item()): new_id
                    for new_id, old_id in enumerate(k_neighbors)
                }
            subgraph_dict[anchor_node] = result
        except Exception as e:
            print(f"Error processing anchor node {anchor_node}: {str(e)}")
            continue
    return subgraph_dict
def load_structure(fpath, chain=None):
    """
    Load a backbone-only biotite structure from a pdb or cif file.

    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id or list of chain ids to load (all chains when None)
    Returns:
        biotite.structure.AtomArray filtered to backbone atoms of the
        requested chain(s).
    Raises:
        ValueError: unsupported file extension, empty structure, or a
            requested chain missing from the file.
    """
    if fpath.endswith("cif"):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
            structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith("pdb"):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
            structure = pdb.get_structure(pdbf, model=1)
    else:
        # Previously any other extension fell through to an UnboundLocalError
        # on `structure`; fail explicitly instead.
        raise ValueError(f"Unsupported file format: {fpath} (expected .pdb or .cif)")
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError("No chains found in the input file.")
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    # Loop variable renamed so it no longer shadows the `chain` parameter.
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f"Chain {chain_id} not found in input file")
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """Per-residue coordinates of the named atoms (NaN where an atom is absent).

    Example for atoms argument: ["N", "CA", "C"]

    NOTE(review): exact duplicate of the definition earlier in this file;
    this later definition is the one that takes effect at import time.
    """
    def filterfn(s, axis=None):
        # One boolean column per requested atom name.
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        # `sum` shadows the builtin; counts of each atom name in this residue.
        sum = filters.sum(0)
        if not np.all(sum <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        # Atoms not present in this residue are marked with NaN.
        coords[sum == 0] = float("nan")
        return coords
    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Extract backbone coordinates from a biotite structure.

    Args:
        structure: An instance of biotite AtomArray
    Returns:
        coords: an L x 3 x 3 array of N, CA, C coordinates per residue.
        (The docstring previously promised a (coords, seq) tuple, but only
        coords was ever returned; use extract_seq_from_pdb for the sequence.)
    """
    return get_atom_coords_residuewise(["N", "CA", "C"], structure)
def extract_seq_from_pdb(pdb_file, chain=None):
    """
    Extract the one-letter amino-acid sequence from a pdb/cif file.

    Args:
        pdb_file: path to a pdb or cif file
        chain: optional chain id (or list of chain ids) to restrict to
    Returns:
        seq: the extracted one-letter sequence string
    """
    structure = load_structure(pdb_file, chain)
    residue_identities = get_residues(structure)[1]
    return "".join(ProteinSequence.convert_letter_3to1(r) for r in residue_identities)
def convert_graph(graph):
    """Cast a graph's tensors to the dtypes expected by the encoder
    (float32 features, int64 edge indices)."""
    return Data(
        node_s=graph.node_s.to(torch.float32),
        node_v=graph.node_v.to(torch.float32),
        edge_index=graph.edge_index.to(torch.int64),
        edge_s=graph.edge_s.to(torch.float32),
        edge_v=graph.edge_v.to(torch.float32),
    )
def predict_structure(model, cluster_models, dataloader, datalabels, device):
    """Assign discrete structure labels to each protein.

    The encoder embeds each protein's subgraphs, embeddings are mean-pooled
    per subgraph and L2-normalized, and each loaded cluster model predicts a
    label per subgraph.

    Args:
        model: encoder exposing ``get_embedding(h_V, edge_index, h_E)``.
        cluster_models: paths to joblib-serialized clustering models.
        dataloader: yields one Batch per protein (see ``pdb_converter``).
        datalabels: per-protein metadata dicts with 'name' and 'aa_seq',
            aligned with ``dataloader``.
        device: torch device for inference.

    Returns:
        {cluster_model_name: {protein_name: {"seq": ..., "struct": [...]}}}
    """
    struc_label_dict = {}
    cluster_model_dict = {}
    for cluster_model_path in cluster_models:
        cluster_model_name = cluster_model_path.split("/")[-1].split(".")[0]
        struc_label_dict[cluster_model_name] = {}
        cluster_model_dict[cluster_model_name] = joblib.load(cluster_model_path)
    with torch.no_grad():
        for batch, label_dict in zip(dataloader, datalabels):
            batch.to(device)
            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            node_embeddings = model.get_embedding(h_V, batch.edge_index, h_E)
            # Pool node embeddings to one vector per subgraph, then normalize.
            graph_embeddings = scatter_mean(node_embeddings, batch.batch, dim=0).to(device)
            norm_graph_embeddings = F.normalize(graph_embeddings, p=2, dim=1)
            for name, cluster_model in cluster_model_dict.items():
                batch_structure_labels = cluster_model.predict(
                    norm_graph_embeddings.cpu()
                ).tolist()
                # Initialize the per-protein entry for THIS model. The original
                # initialized it via a leaked loop variable (the last model's
                # name only), raising KeyError with multiple cluster models.
                entry = struc_label_dict[name].setdefault(label_dict['name'], {})
                entry['seq'] = label_dict['aa_seq']
                entry['struct'] = batch_structure_labels
    return struc_label_dict
def get_embeds(model, dataloader, device, pooling="mean"):
    """L2-normalized per-graph embeddings for every batch in `dataloader`.

    Node embeddings from ``model.get_embedding`` are pooled per graph with
    the requested reduction ("mean", "sum" or "max") and normalized.
    """
    pooled = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch.to(device)
            node_embeds = model.get_embedding(
                (batch.node_s, batch.node_v),
                batch.edge_index,
                (batch.edge_s, batch.edge_v),
            ).cpu()
            index = batch.batch.cpu()
            if pooling == "mean":
                graph_embeds = scatter_mean(node_embeds, index, dim=0)
            elif pooling == "sum":
                graph_embeds = scatter_sum(node_embeds, index, dim=0)
            elif pooling == "max":
                graph_embeds, _ = scatter_max(node_embeds, index, dim=0)
            else:
                raise ValueError("pooling should be mean, sum or max")
            pooled.append(graph_embeds)
    return F.normalize(torch.cat(pooled, dim=0), p=2, dim=1)
def process_pdb_file(
    pdb_file,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """Turn one PDB file into an ordered list of encoder-ready subgraphs.

    Returns:
        (subgraphs, result_dict, num_anchors); on failure subgraphs is None,
        num_anchors is 0, and result_dict may carry an 'error' message.
    """
    result_dict = {"name": Path(pdb_file).name}
    try:
        graph = generate_graph(pdb_file, max_distance)
    except Exception as err:
        print(f"Error in processing {pdb_file}")
        result_dict["error"] = str(err)
        return None, result_dict, 0
    result_dict["aa_seq"] = graph.aa_seq
    # One anchor every `subgraph_interval` residues.
    anchor_nodes = list(range(0, len(graph.node_s), subgraph_interval))
    try:
        subgraph_dict = generate_pos_subgraph(
            graph,
            subgraph_depth,
            subgraph_interval,
            max_distance,
            anchor_nodes=anchor_nodes,
            pure_subgraph=True,
            device=device,
        )
        # Cast every subgraph to the dtypes the encoder expects.
        subgraph_dict = {
            anchor: convert_graph(sub) for anchor, sub in subgraph_dict.items()
        }
    except Exception as err:
        print(f"Error processing subgraph {err}")
        return None, result_dict, 0
    # Order subgraphs by anchor index before returning.
    subgraphs = [subgraph_dict[anchor] for anchor in sorted(subgraph_dict)]
    return subgraphs, result_dict, len(anchor_nodes)
def pdb_converter(
    pdb_files,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=32
):
    """Convert PDB files into a lazy loader of batched subgraphs.

    Args:
        pdb_files: list of PDB file paths.
        subgraph_depth / subgraph_interval / max_distance: forwarded to
            ``process_pdb_file``.
        device: device passed down to subgraph generation.
        batch_size: chunk size for the progress loop (files are still
            processed one at a time within each chunk).

    Returns:
        (generator, results): the generator yields one Batch per successfully
        processed protein; ``results`` holds per-protein metadata dicts in the
        same order. Failed files are reported to stdout and skipped.
    """
    error_proteins, error_messages = [], []
    dataset, results, node_counts = [], [], []
    for i in tqdm(range(0, len(pdb_files), batch_size), desc="Processing PDB files"):
        batch = pdb_files[i:i + batch_size]
        for pdb_file in batch:
            pdb_subgraphs, result_dict, node_count = process_pdb_file(
                pdb_file,
                subgraph_depth,
                subgraph_interval,
                max_distance,
                device=device
            )
            if pdb_subgraphs is None:
                # Failed conversion: remember name/message for the report below.
                error_proteins.append(result_dict["name"])
                error_messages.append(result_dict["error"])
                continue
            dataset.append(pdb_subgraphs)
            results.append(result_dict)
            node_counts.append(node_count)
    if error_proteins:
        print(f"Found {len(error_proteins)} errors:")
        for name, msg in zip(error_proteins, error_messages):
            print(f"{name}: {msg}")
    def collate_fn(batch):
        # Flatten the per-protein subgraph lists into a single PyG Batch.
        batch_graphs = []
        for d in batch:
            batch_graphs.extend(d)
        batch_graphs = Batch.from_data_list(batch_graphs)
        # NOTE(review): scalar node features are zeroed before encoding —
        # presumably the encoder relies on geometry only; confirm.
        batch_graphs.node_s = torch.zeros_like(batch_graphs.node_s)
        return batch_graphs
    def data_loader():
        # One yielded Batch == all subgraphs of a single protein.
        for item in dataset:
            yield collate_fn([item])
    return data_loader(), results
class PdbQuantizer:
    """Quantize protein structures (PDB files) into discrete structure tokens.

    A pretrained graph auto-encoder embeds local residue subgraphs, and a
    pre-fitted clustering model maps each embedding to one of
    ``structure_vocab_size`` discrete labels.
    """

    def __init__(
        self,
        structure_vocab_size=2048,
        max_distance=10,
        subgraph_depth=None,
        subgraph_interval=1,
        anchor_nodes=None,
        model_path=None,
        cluster_dir=None,
        cluster_model=None,
        device=None,
        batch_size=16,
    ) -> None:
        """
        Args:
            structure_vocab_size: number of discrete structure tokens; must
                match one of the shipped cluster models.
            max_distance: CA-CA cutoff (Angstrom) for graph edges.
            subgraph_depth: forwarded to subgraph generation.
            subgraph_interval: stride between anchor residues.
            anchor_nodes: optional explicit anchor residue indices.
            model_path: path to the auto-encoder weights; defaults to the
                bundled ``static/AE.pt``.
            cluster_dir / cluster_model: location and filenames of the
                clustering model(s); default to ``static/<vocab>.joblib``.
            device: torch device string; auto-detects CUDA when None.
            batch_size: number of PDB files per converter chunk.
        """
        assert structure_vocab_size in [20, 64, 128, 512, 1024, 2048, 4096]
        self.batch_size = batch_size
        self.max_distance = max_distance
        self.subgraph_depth = subgraph_depth
        self.subgraph_interval = subgraph_interval
        self.anchor_nodes = anchor_nodes
        if model_path is None:
            self.model_path = str(Path(__file__).parent / "static" / "AE.pt")
        else:
            self.model_path = model_path
        self.structure_vocab_size = structure_vocab_size
        if cluster_dir is None:
            self.cluster_dir = str(Path(__file__).parent / "static")
            self.cluster_model = [
                Path(self.cluster_dir) / f"{structure_vocab_size}.joblib",
            ]
        else:
            self.cluster_dir = cluster_dir
            self.cluster_model = cluster_model
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Load the pretrained auto-encoder; dims must match the checkpoint.
        node_dim = (256, 32)
        edge_dim = (64, 2)
        model = AutoGraphEncoder(
            node_in_dim=(20, 3),
            node_h_dim=node_dim,
            edge_in_dim=(32, 1),
            edge_h_dim=edge_dim,
            num_layers=6,
        )
        # map_location lets a GPU-saved checkpoint load on a CPU-only host
        # (the original call failed there with a CUDA deserialization error).
        model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        model = model.to(self.device)
        model = model.eval()
        self.model = model
        # os.path.join with an already-absolute entry returns that entry
        # unchanged, so the default (absolute) paths pass through intact.
        self.cluster_models = [
            os.path.join(self.cluster_dir, m) for m in self.cluster_model
        ]

    def __call__(self, pdb_files, return_residue_seq=False):
        """Tokenize one or many PDB files.

        Args:
            pdb_files: a path or list of paths to PDB files.
            return_residue_seq: keep the amino-acid sequence in the output.

        Returns:
            {cluster_model_name: {pdb_name: {"struct": [...], ("seq": ...)}}}
        """
        if isinstance(pdb_files, str):
            pdb_files = [pdb_files]
        elif not isinstance(pdb_files, list):
            raise ValueError("pdb_files should be either a string or a list of strings")
        data_loader, results = pdb_converter(
            pdb_files,
            self.subgraph_depth,
            self.subgraph_interval,
            self.max_distance,
            device=self.device,
            batch_size=self.batch_size,
        )
        structures = predict_structure(
            self.model, self.cluster_models, data_loader, results, self.device
        )
        if not return_residue_seq:
            # Strip the residue sequence from every per-protein record.
            for model_name in structures:
                for protein_name in structures[model_name]:
                    structures[model_name][protein_name].pop('seq', None)
        return structures