import math
import os
from pathlib import Path
from typing import List

import biotite
import joblib
import numpy as np
import scipy.spatial as spa
import torch
import torch.nn.functional as F
from Bio import PDB
from Bio.SeqUtils import seq1
from biotite.sequence import ProteinSequence
from biotite.structure import filter_backbone, get_chains
from biotite.structure.io import pdb, pdbx
from biotite.structure.residues import get_residues
from torch_geometric.data import Batch, Data
from torch_scatter import scatter_max, scatter_mean, scatter_sum
from tqdm import tqdm

from .encoder import AutoGraphEncoder
def _normalize(tensor, dim=-1):
    """
    Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
    """
    return torch.nan_to_num(
        torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))
    )
def _rbf(D, D_min=0.0, D_max=20.0, D_count=16, device="cpu"):
    """
    From https://github.com/jingraham/neurips19-graph-protein-design

    Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1.
    That is, if `D` has shape [...dims], then the returned tensor will have
    shape [...dims, D_count].
    """
    D_mu = torch.linspace(D_min, D_max, D_count, device=device)
    D_mu = D_mu.view([1, -1])
    D_sigma = (D_max - D_min) / D_count
    D_expand = torch.unsqueeze(D, -1)
    RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))
    return RBF
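# Example sketch (tensor values are illustrative only): a [3]-shaped distance
# tensor becomes a [3, 16] soft one-hot embedding over 16 Gaussian bins
# spanning [0, 20] Angstroms, with values near 1 where a distance falls close
# to a bin center D_mu.
#
#     _rbf(torch.tensor([1.0, 5.0, 19.0])).shape  # torch.Size([3, 16])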
def _orientations(X_ca):
    # Unit vectors to the next and previous CA atom, zero-padded at the termini.
    forward = _normalize(X_ca[1:] - X_ca[:-1])
    backward = _normalize(X_ca[:-1] - X_ca[1:])
    forward = F.pad(forward, [0, 0, 0, 1])
    backward = F.pad(backward, [0, 0, 1, 0])
    return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)
def _sidechains(X):
    # Imputed sidechain (CB) direction from the N, CA, C backbone atoms,
    # assuming tetrahedral geometry at the CA.
    n, origin, c = X[:, 0], X[:, 1], X[:, 2]
    c, n = _normalize(c - origin), _normalize(n - origin)
    bisector = _normalize(c + n)
    perp = _normalize(torch.cross(c, n, dim=-1))
    vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
    return vec
def _positional_embeddings(edge_index, num_embeddings=16, period_range=[2, 1000]):
    # Sinusoidal embedding of the sequence offset between the two endpoints
    # of each edge. From https://github.com/jingraham/neurips19-graph-protein-design
    # Note: period_range is kept for signature compatibility but is unused here.
    d = edge_index[0] - edge_index[1]
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * frequency
    E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
    return E
def generate_graph(pdb_file, max_distance=10):
    """
    Generate graph data from a PDB file.

    Args:
        pdb_file: path to the PDB file
        max_distance: CA-CA distance cutoff (in Angstroms) for drawing edges

    Returns:
        torch_geometric.data.Data holding node/edge features, the CA distance
        matrix, and the one-letter amino acid sequence
    """
    pdb_parser = PDB.PDBParser(QUIET=True)
    structure = pdb_parser.get_structure("protein", pdb_file)
    # use the first model only
    model = structure[0]
    # extract the amino acid sequence and backbone-atom coordinates
    seq = []
    aa_coords = {"N": [], "CA": [], "C": [], "O": []}
    for chain in model:
        for residue in chain:
            if residue.get_id()[0] == " ":
                seq.append(residue.get_resname())
                for atom_name in aa_coords.keys():
                    if atom_name in residue:
                        aa_coords[atom_name].append(residue[atom_name].get_coord().tolist())
                    else:
                        # missing backbone atom: mark with NaN, masked out below
                        aa_coords[atom_name].append([np.nan] * 3)
    one_letter_seq = "".join([seq1(aa) for aa in seq])
    coords = list(zip(aa_coords["N"], aa_coords["CA"], aa_coords["C"], aa_coords["O"]))
    coords = torch.tensor(coords)
    # mask out residues with missing coordinates
    mask = torch.isfinite(coords.sum(dim=(1, 2)))
    coords[~mask] = np.inf
    ca_coords = coords[:, 1]
    # placeholder scalar node features (20-dim, zero-filled)
    node_s = torch.zeros(len(ca_coords), 20)
    # build edges between CA atoms within the max_distance cutoff
    distances = spa.distance_matrix(ca_coords, ca_coords)
    edge_index = torch.tensor(np.array(np.where(distances < max_distance)))
    # remove self-loops
    mask = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, mask]
    # node features
    orientations = _orientations(ca_coords)
    sidechains = _sidechains(coords)
    node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
    # edge features
    pos_embeddings = _positional_embeddings(edge_index)
    E_vectors = ca_coords[edge_index[0]] - ca_coords[edge_index[1]]
    rbf = _rbf(E_vectors.norm(dim=-1), D_count=16)
    edge_s = torch.cat([rbf, pos_embeddings], dim=-1)
    edge_v = _normalize(E_vectors).unsqueeze(-2)
    # shapes: node_v [node_num, 3, 3], edge_index [2, edge_num],
    #         edge_s [edge_num, 16 + 16], edge_v [edge_num, 1, 3]
    node_s, node_v, edge_s, edge_v = map(
        torch.nan_to_num, (node_s, node_v, edge_s, edge_v)
    )
    data = Data(
        node_s=node_s,
        node_v=node_v,
        edge_index=edge_index,
        edge_s=edge_s,
        edge_v=edge_v,
        distances=distances,
        aa_seq=one_letter_seq,
    )
    return data
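# Usage sketch for generate_graph (the file path below is hypothetical):
#
#     graph = generate_graph("example.pdb", max_distance=10)
#     graph.node_s.shape  # [L, 20]   zero-filled placeholder scalar features
#     graph.node_v.shape  # [L, 3, 3] forward/backward/sidechain unit vectors
#     graph.edge_s.shape  # [E, 32]   16 RBF bins + 16 positional dims
#     graph.edge_v.shape  # [E, 1, 3] normalized CA-CA directions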
def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Extract per-residue coordinates for the given atom names.
    Example for atoms argument: ["N", "CA", "C"]
    """

    def filterfn(s, axis=None):
        filters = np.stack([s.atom_name == name for name in atoms], axis=1)
        counts = filters.sum(0)
        if not np.all(counts <= np.ones(filters.shape[1])):
            raise RuntimeError("structure has multiple atoms with same name")
        index = filters.argmax(0)
        coords = s[index].coord
        coords[counts == 0] = float("nan")
        return coords

    return biotite.structure.apply_residue_wise(struct, struct, filterfn)
def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Args:
        structure: An instance of biotite AtomArray
    Returns:
        Tuple (coords, seq)
            - coords is an L x 3 x 3 array for N, CA, C coordinates
            - seq is the extracted sequence
    """
    coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
    residue_identities = get_residues(structure)[1]
    seq = "".join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return coords, seq
def extract_seq_from_pdb(pdb_file, chain=None):
    """
    Args:
        pdb_file: filepath to either a PDB or CIF file
        chain: the chain id or list of chain ids to load
    Returns:
        seq: the extracted one-letter sequence
    """
    structure = load_structure(pdb_file, chain)
    residue_identities = get_residues(structure)[1]
    seq = "".join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
    return seq
def generate_pos_subgraph(
    graph_data,
    subgraph_depth=None,
    subgraph_interval=1,
    max_distance=10,
    anchor_nodes=None,
    pure_subgraph=False,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    # move graph_data to the target device, converting numpy arrays as needed
    def _to_device(x):
        return x.to(device) if torch.is_tensor(x) else torch.tensor(x, device=device)

    graph_data = Data(
        node_s=_to_device(graph_data.node_s),
        node_v=_to_device(graph_data.node_v),
        edge_index=_to_device(graph_data.edge_index),
        edge_s=_to_device(graph_data.edge_s),
        edge_v=_to_device(graph_data.edge_v),
        distances=_to_device(graph_data.distances),
        aa_seq=graph_data.aa_seq,
    )
    distances = graph_data.distances
    if subgraph_depth is None:
        subgraph_depth = 50
    # compute anchor nodes if not provided
    if anchor_nodes is None:
        anchor_nodes = list(range(0, len(graph_data.aa_seq), subgraph_interval))
    # k nearest neighbors for all nodes at once; note the kNN width is fixed
    # at 50 here, independent of subgraph_depth
    k = 50
    nearest_indices = torch.argsort(distances, dim=1)[:, :k]  # (num_nodes, k)
    # neighbors beyond the distance cutoff are marked with -1
    distance_mask = torch.gather(distances, 1, nearest_indices) < max_distance
    nearest_indices = torch.where(
        distance_mask, nearest_indices, torch.tensor(-1, device=device)
    )
    subgraph_dict = {}
    for anchor_node in anchor_nodes:
        try:
            # neighbors of this anchor node, excluding the -1 padding
            k_neighbors = nearest_indices[anchor_node]
            k_neighbors = k_neighbors[k_neighbors != -1]
            if len(k_neighbors) == 0:  # skip if no neighbors found
                continue
            # cap the neighborhood size at 40 nodes
            k_neighbors = k_neighbors[:40]
            k_neighbors, _ = torch.sort(k_neighbors)
            sub_matrix = distances.index_select(0, k_neighbors).index_select(1, k_neighbors)
            # edges within the subgraph, without self-loops
            sub_edges = torch.nonzero(sub_matrix < max_distance, as_tuple=False)
            mask = sub_edges[:, 0] != sub_edges[:, 1]
            sub_edge_index = sub_edges[mask]
            if len(sub_edge_index) == 0:  # skip if no edges found
                continue
            # map local edge endpoints back to global node indices
            original_edge_index = k_neighbors[sub_edge_index]
            # match each subgraph edge against the global edge list to recover
            # its feature row; one boolean row per subgraph edge
            matches = []
            for edge in original_edge_index:
                match = (graph_data.edge_index[0] == edge[0]) & (graph_data.edge_index[1] == edge[1])
                matches.append(match)
            matches = torch.stack(matches)
            # column indices = positions of the matched edges in the global edge list
            edge_to_feature_idx = torch.nonzero(matches, as_tuple=True)[1]
            if len(edge_to_feature_idx) == 0:  # skip if no matching edges
                continue
            result = Data(
                edge_index=sub_edge_index.T,
                edge_s=graph_data.edge_s[edge_to_feature_idx],
                edge_v=graph_data.edge_v[edge_to_feature_idx],
                node_s=graph_data.node_s[k_neighbors],
                node_v=graph_data.node_v[k_neighbors],
            )
            if not pure_subgraph:
                # map global node ids to their local indices in the subgraph
                result.index_map = {
                    int(old_id.item()): new_id
                    for new_id, old_id in enumerate(k_neighbors)
                }
            subgraph_dict[anchor_node] = result
        except Exception as e:
            print(f"Error processing anchor node {anchor_node}: {e}")
            continue
    return subgraph_dict
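# Usage sketch for generate_pos_subgraph, building on the generate_graph
# example above (the file path is hypothetical):
#
#     graph = generate_graph("example.pdb")
#     subgraphs = generate_pos_subgraph(graph, subgraph_interval=1, pure_subgraph=True)
#     # maps each anchor residue index to a Data object covering at most 40
#     # of its nearest spatial neighbors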
def load_structure(fpath, chain=None):
    """
    Args:
        fpath: filepath to either pdb or cif file
        chain: the chain id or list of chain ids to load
    Returns:
        biotite.structure.AtomArray
    """
    if fpath.endswith("cif"):
        with open(fpath) as fin:
            pdbxf = pdbx.PDBxFile.read(fin)
            structure = pdbx.get_structure(pdbxf, model=1)
    elif fpath.endswith("pdb"):
        with open(fpath) as fin:
            pdbf = pdb.PDBFile.read(fin)
            structure = pdb.get_structure(pdbf, model=1)
    else:
        raise ValueError(f"Unsupported file format: {fpath}")
    bbmask = filter_backbone(structure)
    structure = structure[bbmask]
    all_chains = get_chains(structure)
    if len(all_chains) == 0:
        raise ValueError("No chains found in the input file.")
    if chain is None:
        chain_ids = all_chains
    elif isinstance(chain, list):
        chain_ids = chain
    else:
        chain_ids = [chain]
    for chain_id in chain_ids:
        if chain_id not in all_chains:
            raise ValueError(f"Chain {chain_id} not found in input file")
    chain_filter = [a.chain_id in chain_ids for a in structure]
    structure = structure[chain_filter]
    return structure
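# Usage sketch for load_structure (hypothetical paths): load all chains of a
# PDB file, or a single chain of a CIF file:
#
#     atoms = load_structure("example.pdb")
#     chain_a = load_structure("example.cif", chain="A")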
def convert_graph(graph):
    # cast features to the dtypes expected by the encoder
    graph = Data(
        node_s=graph.node_s.to(torch.float32),
        node_v=graph.node_v.to(torch.float32),
        edge_index=graph.edge_index.to(torch.int64),
        edge_s=graph.edge_s.to(torch.float32),
        edge_v=graph.edge_v.to(torch.float32),
    )
    return graph
def predict_structure(model, cluster_models, dataloader, datalabels, device):
    epoch_iterator = dataloader
    struc_label_dict = {}
    cluster_model_dict = {}
    for cluster_model_path in cluster_models:
        cluster_model_name = Path(cluster_model_path).stem
        struc_label_dict[cluster_model_name] = {}
        cluster_model_dict[cluster_model_name] = joblib.load(cluster_model_path)
    with torch.no_grad():
        for batch, label_dict in zip(epoch_iterator, datalabels):
            batch = batch.to(device)
            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            node_embeddings = model.get_embedding(h_V, batch.edge_index, h_E)
            # mean-pool node embeddings into one embedding per subgraph
            graph_embeddings = scatter_mean(node_embeddings, batch.batch, dim=0)
            norm_graph_embeddings = F.normalize(graph_embeddings, p=2, dim=1)
            for name, cluster_model in cluster_model_dict.items():
                batch_structure_labels = cluster_model.predict(
                    norm_graph_embeddings.cpu()
                ).tolist()
                struc_label_dict[name][label_dict["name"]] = {
                    "seq": label_dict["aa_seq"],
                    "struct": batch_structure_labels,
                }
    return struc_label_dict
def get_embeds(model, dataloader, device, pooling="mean"):
    epoch_iterator = tqdm(dataloader)
    embeds = []
    with torch.no_grad():
        for batch in epoch_iterator:
            batch = batch.to(device)
            h_V = (batch.node_s, batch.node_v)
            h_E = (batch.edge_s, batch.edge_v)
            node_embeds = model.get_embedding(h_V, batch.edge_index, h_E).cpu()
            if pooling == "mean":
                graph_embeds = scatter_mean(node_embeds, batch.batch.cpu(), dim=0)
            elif pooling == "sum":
                graph_embeds = scatter_sum(node_embeds, batch.batch.cpu(), dim=0)
            elif pooling == "max":
                graph_embeds, _ = scatter_max(node_embeds, batch.batch.cpu(), dim=0)
            else:
                raise ValueError("pooling should be mean, sum or max")
            embeds.append(graph_embeds)
    embeds = torch.cat(embeds, dim=0)
    norm_embeds = F.normalize(embeds, p=2, dim=1)
    return norm_embeds
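# Usage sketch for get_embeds (hypothetical names; assumes a DataLoader of
# batched subgraphs and a trained AutoGraphEncoder):
#
#     embeds = get_embeds(model, loader, device="cpu", pooling="mean")
#     embeds.shape  # [num_graphs, embed_dim], rows L2-normalized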
def process_pdb_file(
    pdb_file,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    result_dict, subgraph_dict = {}, {}
    result_dict["name"] = Path(pdb_file).name
    try:
        graph = generate_graph(pdb_file, max_distance)
    except Exception as e:
        print(f"Error in processing {pdb_file}")
        result_dict["error"] = str(e)
        return None, result_dict, 0
    result_dict["aa_seq"] = graph.aa_seq
    # one anchor node every subgraph_interval residues
    anchor_nodes = list(range(0, len(graph.node_s), subgraph_interval))
    try:
        subgraph_dict = generate_pos_subgraph(
            graph,
            subgraph_depth,
            subgraph_interval,
            max_distance,
            anchor_nodes=anchor_nodes,
            pure_subgraph=True,
            device=device,
        )
        # cast all subgraph features to the dtypes expected by the encoder
        for key in subgraph_dict.keys():
            subgraph_dict[key] = convert_graph(subgraph_dict[key])
    except Exception as e:
        print(f"Error processing subgraph {e}")
        result_dict["error"] = str(e)
        return None, result_dict, 0
    subgraph_dict = dict(sorted(subgraph_dict.items(), key=lambda x: x[0]))
    subgraphs = list(subgraph_dict.values())
    return subgraphs, result_dict, len(anchor_nodes)
def pdb_converter(
    pdb_files,
    subgraph_depth,
    subgraph_interval,
    max_distance,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=32,
):
    error_proteins, error_messages = [], []
    dataset, results, node_counts = [], [], []
    for i in tqdm(range(0, len(pdb_files), batch_size), desc="Processing PDB files"):
        batch = pdb_files[i : i + batch_size]
        for pdb_file in batch:
            pdb_subgraphs, result_dict, node_count = process_pdb_file(
                pdb_file,
                subgraph_depth,
                subgraph_interval,
                max_distance,
                device=device,
            )
            if pdb_subgraphs is None:
                error_proteins.append(result_dict["name"])
                error_messages.append(result_dict["error"])
                continue
            dataset.append(pdb_subgraphs)
            results.append(result_dict)
            node_counts.append(node_count)
    if error_proteins:
        print(f"Found {len(error_proteins)} errors:")
        for name, msg in zip(error_proteins, error_messages):
            print(f"{name}: {msg}")

    def collate_fn(batch):
        # flatten the per-protein subgraph lists into a single PyG batch
        batch_graphs = []
        for d in batch:
            batch_graphs.extend(d)
        batch_graphs = Batch.from_data_list(batch_graphs)
        # scalar node features are placeholders; zero them out explicitly
        batch_graphs.node_s = torch.zeros_like(batch_graphs.node_s)
        return batch_graphs

    def data_loader():
        # yield one batched graph per protein
        for item in dataset:
            yield collate_fn([item])

    return data_loader(), results
class PdbQuantizer:
    def __init__(
        self,
        structure_vocab_size=2048,
        max_distance=10,
        subgraph_depth=None,
        subgraph_interval=1,
        anchor_nodes=None,
        model_path=None,
        cluster_dir=None,
        cluster_model=None,
        device=None,
        batch_size=16,
    ) -> None:
        assert structure_vocab_size in [20, 64, 128, 512, 1024, 2048, 4096]
        self.batch_size = batch_size
        self.max_distance = max_distance
        self.subgraph_depth = subgraph_depth
        self.subgraph_interval = subgraph_interval
        self.anchor_nodes = anchor_nodes
        if model_path is None:
            self.model_path = str(Path(__file__).parent / "static" / "AE.pt")
        else:
            self.model_path = model_path
        self.structure_vocab_size = structure_vocab_size
        if cluster_dir is None:
            self.cluster_dir = str(Path(__file__).parent / "static")
            self.cluster_model = [f"{structure_vocab_size}.joblib"]
        else:
            self.cluster_dir = cluster_dir
            self.cluster_model = cluster_model
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        # Load the pretrained graph autoencoder
        node_dim = (256, 32)
        edge_dim = (64, 2)
        model = AutoGraphEncoder(
            node_in_dim=(20, 3),
            node_h_dim=node_dim,
            edge_in_dim=(32, 1),
            edge_h_dim=edge_dim,
            num_layers=6,
        )
        model.load_state_dict(torch.load(self.model_path, map_location=self.device))
        model = model.to(self.device)
        model = model.eval()
        self.model = model
        self.cluster_models = [
            os.path.join(self.cluster_dir, m) for m in self.cluster_model
        ]
    def __call__(self, pdb_files, return_residue_seq=False):
        if isinstance(pdb_files, str):
            pdb_files = [pdb_files]
        elif not isinstance(pdb_files, list):
            raise ValueError("pdb_files should be either a string or a list of strings")
        data_loader, results = pdb_converter(
            pdb_files,
            self.subgraph_depth,
            self.subgraph_interval,
            self.max_distance,
            device=self.device,
            batch_size=self.batch_size,
        )
        structures = predict_structure(
            self.model, self.cluster_models, data_loader, results, self.device
        )
        if not return_residue_seq:
            # drop the residue sequences, keeping only the structure tokens
            for cluster_model_name in structures.keys():
                for protein_name in structures[cluster_model_name].keys():
                    structures[cluster_model_name][protein_name].pop("seq", None)
        return structures
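# A minimal end-to-end sketch. The PDB path below is a placeholder, so the
# demo only runs when such a file actually exists.
if __name__ == "__main__":
    _demo_pdb = "example.pdb"  # hypothetical input file
    if os.path.exists(_demo_pdb):
        quantizer = PdbQuantizer(structure_vocab_size=2048)
        tokens = quantizer(_demo_pdb, return_residue_seq=False)
        # tokens: {cluster_model_name: {pdb_name: {"struct": [token ids, ...]}}}
        print(tokens)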