|
import binascii |
|
import glob |
|
import hashlib |
|
import os |
|
import pickle |
|
from collections import defaultdict |
|
from multiprocessing import Pool |
|
import random |
|
import copy |
|
|
|
import numpy as np |
|
import torch |
|
from rdkit.Chem import MolToSmiles, MolFromSmiles, AddHs |
|
from torch_geometric.data import Dataset, HeteroData |
|
from torch_geometric.loader import DataLoader, DataListLoader |
|
from torch_geometric.transforms import BaseTransform |
|
from tqdm import tqdm |
|
|
|
from datasets.process_mols import ( |
|
read_molecule, |
|
get_rec_graph, |
|
generate_conformer, |
|
get_lig_graph_with_matching, |
|
extract_receptor_structure, |
|
parse_receptor, |
|
parse_pdb_from_path, |
|
) |
|
from utils.diffusion_utils import modify_conformer, set_time |
|
from utils.utils import read_strings_from_txt |
|
from utils import so3, torus |
|
|
|
|
|
class NoiseTransform(BaseTransform):
    """Transform that perturbs a complex's ligand pose with diffusion noise.

    Each call draws one diffusion time, converts it to translation / rotation /
    torsion noise scales via ``t_to_sigma``, applies a random perturbation to
    the ligand, and stores the corresponding score targets on ``data``.
    """

    def __init__(self, t_to_sigma, no_torsion, all_atom):
        # Callable mapping (t_tr, t_rot, t_tor) -> (tr_sigma, rot_sigma, tor_sigma).
        self.t_to_sigma = t_to_sigma
        # When True, torsion updates are not applied and torsion targets are None.
        self.no_torsion = no_torsion
        # Forwarded to ``set_time``.
        self.all_atom = all_atom

    def __call__(self, data):
        # One time value is shared by all three noise components.
        t = np.random.uniform()
        return self.apply_noise(data, t, t, t)

    def apply_noise(
        self,
        data,
        t_tr,
        t_rot,
        t_tor,
        tr_update=None,
        rot_update=None,
        torsion_updates=None,
    ):
        """Perturb ``data`` in place and attach score targets.

        Args:
            data: complex graph with a ``"ligand"`` node store.
            t_tr, t_rot, t_tor: diffusion times for each noise component.
            tr_update, rot_update, torsion_updates: optional explicit updates;
                any left as ``None`` is sampled from the matching noise scale.
                ``rot_update`` is passed to ``torch.from_numpy``, so it must
                be a NumPy array.

        Returns:
            The same ``data`` object with ``tr_score``, ``rot_score``,
            ``tor_score`` and ``tor_sigma_edge`` set.
        """
        # Several stored conformers (a list of positions): pick one at random.
        if not torch.is_tensor(data["ligand"].pos):
            data["ligand"].pos = random.choice(data["ligand"].pos)

        tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
        set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)

        # Sample missing updates; the order of RNG draws is significant.
        if tr_update is None:
            tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3))
        if rot_update is None:
            rot_update = so3.sample_vec(eps=rot_sigma)
        if torsion_updates is None:
            torsion_updates = np.random.normal(
                loc=0.0, scale=tor_sigma, size=data["ligand"].edge_mask.sum()
            )
        if self.no_torsion:
            torsion_updates = None

        rot_tensor = torch.from_numpy(rot_update).float()
        modify_conformer(data, tr_update, rot_tensor, torsion_updates)

        data.tr_score = -tr_update / tr_sigma**2
        rot_score = so3.score_vec(vec=rot_update, eps=rot_sigma)
        data.rot_score = torch.from_numpy(rot_score).float().unsqueeze(0)

        if self.no_torsion:
            data.tor_score = None
            data.tor_sigma_edge = None
        else:
            data.tor_score = torch.from_numpy(
                torus.score(torsion_updates, tor_sigma)
            ).float()
            data.tor_sigma_edge = np.ones(data["ligand"].edge_mask.sum()) * tor_sigma
        return data
|
|
|
|
|
class PDBBind(Dataset):
    """Protein-ligand complex dataset backed by an on-disk pickle cache.

    On construction the dataset loads ``heterographs.pkl`` (and, when
    ``require_ligand`` is set, ``rdkit_ligands.pkl``) from a cache directory
    whose name encodes the preprocessing options. If the needed files are
    missing, it builds them first.

    Two preprocessing modes exist:

    * training/validation (``protein_path_list`` or ``ligand_descriptions``
      is ``None``): complex names are read from ``split_path`` and every
      complex is loaded from ``root/<name>/``;
    * inference (both lists given): each protein file is paired with the
      ligand described by the corresponding entry of ``ligand_descriptions``,
      which is tried as SMILES first and as a molecule file path otherwise.

    ``ligands_list`` is accepted for backward compatibility but not used.
    """

    # Complexes per intermediate cache chunk when ``num_workers > 1``.
    _CHUNK_SIZE = 1000

    def __init__(
        self,
        root,
        transform=None,
        cache_path="data/cache",
        split_path="data/",
        limit_complexes=0,
        receptor_radius=30,
        num_workers=1,
        c_alpha_max_neighbors=None,
        popsize=15,
        maxiter=15,
        matching=True,
        keep_original=False,
        max_lig_size=None,
        remove_hs=False,
        num_conformers=1,
        all_atoms=False,
        atom_radius=5,
        atom_max_neighbors=None,
        esm_embeddings_path=None,
        require_ligand=False,
        ligands_list=None,
        protein_path_list=None,
        ligand_descriptions=None,
        keep_local_structures=False,
    ):
        super(PDBBind, self).__init__(root, transform)
        self.pdbbind_dir = root
        self.max_lig_size = max_lig_size
        self.split_path = split_path
        self.limit_complexes = limit_complexes
        self.receptor_radius = receptor_radius
        self.num_workers = num_workers
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.esm_embeddings_path = esm_embeddings_path
        self.require_ligand = require_ligand
        self.protein_path_list = protein_path_list
        self.ligand_descriptions = ligand_descriptions
        self.keep_local_structures = keep_local_structures

        # The cache directory name encodes every option that affects the
        # preprocessed graphs. It must stay stable so existing caches are found.
        if matching or (
            protein_path_list is not None and ligand_descriptions is not None
        ):
            cache_path += "_torsion"
        if all_atoms:
            cache_path += "_allatoms"
        self.full_cache_path = os.path.join(
            cache_path,
            f"limit{self.limit_complexes}"
            f"_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}"
            f"_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}"
            f"_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}"
            + (
                ""
                if not all_atoms
                else f"_atomRad{atom_radius}_atomMax{atom_max_neighbors}"
            )
            + ("" if not matching or num_conformers == 1 else f"_confs{num_conformers}")
            + ("" if self.esm_embeddings_path is None else "_esmEmbeddings")
            + ("" if not keep_local_structures else "_keptLocalStruct")
            + (
                ""
                if protein_path_list is None or ligand_descriptions is None
                else str(
                    binascii.crc32(
                        "".join(ligand_descriptions + protein_path_list).encode()
                    )
                )
            ),
        )
        self.popsize, self.maxiter = popsize, maxiter
        self.matching, self.keep_original = matching, keep_original
        self.num_conformers = num_conformers
        self.all_atoms = all_atoms
        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors

        heterographs_path = os.path.join(self.full_cache_path, "heterographs.pkl")
        rdkit_ligands_path = os.path.join(self.full_cache_path, "rdkit_ligands.pkl")
        if not os.path.exists(heterographs_path) or (
            require_ligand and not os.path.exists(rdkit_ligands_path)
        ):
            os.makedirs(self.full_cache_path, exist_ok=True)
            if protein_path_list is None or ligand_descriptions is None:
                self.preprocessing()
            else:
                self.inference_preprocessing()

        print("loading data from memory: ", heterographs_path)
        # NOTE: pickle.load can execute arbitrary code; only load caches that
        # this code produced.
        with open(heterographs_path, "rb") as f:
            self.complex_graphs = pickle.load(f)
        if require_ligand:
            with open(rdkit_ligands_path, "rb") as f:
                self.rdkit_ligands = pickle.load(f)

        print_statistics(self.complex_graphs)

    def len(self):
        """Return the number of cached complex graphs."""
        return len(self.complex_graphs)

    def get(self, idx):
        """Return a deep copy of complex ``idx``.

        When ``require_ligand`` is set, a copy of the matching RDKit molecule
        is attached as ``.mol``.
        """
        complex_graph = copy.deepcopy(self.complex_graphs[idx])
        if self.require_ligand:
            complex_graph.mol = copy.deepcopy(self.rdkit_ligands[idx])
        return complex_graph

    def preprocessing(self):
        """Build the training/validation cache from the names in ``split_path``."""
        print(
            f"Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]"
        )

        complex_names_all = read_strings_from_txt(self.split_path)
        if self.limit_complexes is not None and self.limit_complexes != 0:
            complex_names_all = complex_names_all[: self.limit_complexes]
        print(f"Loading {len(complex_names_all)} complexes.")

        if self.esm_embeddings_path is not None:
            id_to_embeddings = torch.load(self.esm_embeddings_path)
            # Set lookup: membership is tested once per embedding key.
            wanted_names = set(complex_names_all)
            chain_embeddings_dictlist = defaultdict(list)
            for key, embedding in id_to_embeddings.items():
                key_name = key.split("_")[0]
                if key_name in wanted_names:
                    chain_embeddings_dictlist[key_name].append(embedding)
            lm_embeddings_chains_all = [
                chain_embeddings_dictlist[name] for name in complex_names_all
            ]
        else:
            lm_embeddings_chains_all = [None] * len(complex_names_all)

        # Training mode: no pre-built ligands or ligand descriptions.
        no_inputs = [None] * len(complex_names_all)
        self._process_and_cache(
            complex_names_all, lm_embeddings_chains_all, no_inputs, no_inputs
        )

    def inference_preprocessing(self):
        """Build the inference cache from ``protein_path_list`` / ``ligand_descriptions``.

        Raises:
            ValueError: if a ligand description is neither valid SMILES nor a
                readable molecule file.
            FileNotFoundError: if ``esm_embeddings_path`` is set but missing.
        """
        ligands_list = []
        print("Reading molecules and generating local structures with RDKit")
        for ligand_description in tqdm(self.ligand_descriptions):
            mol = MolFromSmiles(ligand_description)
            if mol is not None:
                mol = AddHs(mol)
                generate_conformer(mol)
            else:
                # Not SMILES: treat the description as a molecule file path.
                mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
                if mol is None:
                    raise ValueError(
                        f"Could not parse ligand description {ligand_description!r} "
                        "as SMILES or read it as a molecule file."
                    )
                if not self.keep_local_structures:
                    mol.RemoveAllConformers()
                    mol = AddHs(mol)
                    generate_conformer(mol)
            ligands_list.append(mol)

        if self.esm_embeddings_path is not None:
            print("Reading language model embeddings.")
            if not os.path.exists(self.esm_embeddings_path):
                raise FileNotFoundError(
                    f"ESM embeddings path does not exist: {self.esm_embeddings_path}"
                )
            lm_embeddings_chains_all = []
            for protein_path in self.protein_path_list:
                pattern = (
                    os.path.join(
                        self.esm_embeddings_path, os.path.basename(protein_path)
                    )
                    + "*"
                )
                lm_embeddings_chains_all.append(
                    [
                        torch.load(embeddings_path)["representations"][33]
                        for embeddings_path in sorted(glob.glob(pattern))
                    ]
                )
        else:
            lm_embeddings_chains_all = [None] * len(self.protein_path_list)

        print("Generating graphs for ligands and proteins")
        self._process_and_cache(
            self.protein_path_list,
            lm_embeddings_chains_all,
            ligands_list,
            self.ligand_descriptions,
        )

    def _cache_file(self, filename):
        """Return the path of ``filename`` inside the cache directory."""
        return os.path.join(self.full_cache_path, filename)

    def _dump(self, filename, obj):
        """Pickle ``obj`` to ``filename`` inside the cache directory."""
        with open(self._cache_file(filename), "wb") as f:
            pickle.dump(obj, f)

    def _collect(self, map_fn, inputs, total, desc):
        """Run ``get_complex`` over ``inputs`` via ``map_fn`` and concatenate results.

        Graphs and ligands are extended from the same result tuple, so they
        stay aligned even when ``map_fn`` yields results out of order.
        """
        complex_graphs, rdkit_ligands = [], []
        with tqdm(total=total, desc=desc) as pbar:
            for graphs, ligs in map_fn(self.get_complex, inputs):
                complex_graphs.extend(graphs)
                rdkit_ligands.extend(ligs)
                pbar.update()
        return complex_graphs, rdkit_ligands

    def _process_and_cache(self, names, lm_embeddings_chains, ligands, ligand_descriptions):
        """Build graphs for all inputs and write ``heterographs.pkl`` / ``rdkit_ligands.pkl``.

        The four sequences are consumed in lockstep as ``get_complex`` inputs.
        With ``num_workers > 1`` the inputs are processed in chunks of
        ``_CHUNK_SIZE`` by a process pool. Each chunk is saved as
        ``heterographs{i}.pkl`` / ``rdkit_ligands{i}.pkl``, so an interrupted
        run resumes at the first chunk without both files. All chunks are then
        merged into the final files.
        """
        if self.num_workers <= 1:
            complex_graphs, rdkit_ligands = self._collect(
                map,
                zip(names, lm_embeddings_chains, ligands, ligand_descriptions),
                len(names),
                "loading complexes",
            )
            self._dump("heterographs.pkl", complex_graphs)
            self._dump("rdkit_ligands.pkl", rdkit_ligands)
            return

        chunk = self._CHUNK_SIZE
        num_chunks = len(names) // chunk + 1
        for i in range(num_chunks):
            graphs_file = f"heterographs{i}.pkl"
            ligands_file = f"rdkit_ligands{i}.pkl"
            # Skip only fully written chunks; a crash between the two dumps
            # would otherwise leave a chunk without its ligand file.
            if os.path.exists(self._cache_file(graphs_file)) and os.path.exists(
                self._cache_file(ligands_file)
            ):
                continue
            window = slice(chunk * i, chunk * (i + 1))
            chunk_inputs = list(
                zip(
                    names[window],
                    lm_embeddings_chains[window],
                    ligands[window],
                    ligand_descriptions[window],
                )
            )
            # The context manager terminates the pool even if a worker raises.
            with Pool(self.num_workers, maxtasksperchild=1) as pool:
                complex_graphs, rdkit_ligands = self._collect(
                    pool.imap_unordered,
                    chunk_inputs,
                    len(chunk_inputs),
                    f"loading complexes {i}/{num_chunks}",
                )
            self._dump(graphs_file, complex_graphs)
            self._dump(ligands_file, rdkit_ligands)

        complex_graphs_all = []
        for i in range(num_chunks):
            with open(self._cache_file(f"heterographs{i}.pkl"), "rb") as f:
                complex_graphs_all.extend(pickle.load(f))
        self._dump("heterographs.pkl", complex_graphs_all)

        rdkit_ligands_all = []
        for i in range(num_chunks):
            with open(self._cache_file(f"rdkit_ligands{i}.pkl"), "rb") as f:
                rdkit_ligands_all.extend(pickle.load(f))
        self._dump("rdkit_ligands.pkl", rdkit_ligands_all)

    def get_complex(self, par):
        """Build the heterographs for one complex.

        Args:
            par: ``(name, lm_embedding_chains, ligand, ligand_description)``.
                In training mode ``ligand`` is ``None`` and ``name`` is a
                folder under ``self.pdbbind_dir``. In inference mode ``name``
                is a protein file path and ``ligand`` a pre-built molecule.

        Returns:
            ``(complex_graphs, ligands)``: two equal-length lists where
            ``ligands[k]`` is the molecule ``complex_graphs[k]`` was built
            from. Skipped ligands are left out of both lists.
        """
        name, lm_embedding_chains, ligand, ligand_description = par
        if ligand is None and not os.path.exists(os.path.join(self.pdbbind_dir, name)):
            print("Folder not found", name)
            return [], []

        if ligand is not None:
            rec_model = parse_pdb_from_path(name)
            name = f"{name}____{ligand_description}"
            ligs = [ligand]
        else:
            try:
                rec_model = parse_receptor(name, self.pdbbind_dir)
            except Exception as e:
                print(f"Skipping {name} because of the error:")
                print(e)
                return [], []
            ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)

        complex_graphs, kept_ligs = [], []
        for lig in ligs:
            if (
                self.max_lig_size is not None
                and lig.GetNumHeavyAtoms() > self.max_lig_size
            ):
                print(
                    f"Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data."
                )
                continue
            complex_graph = HeteroData()
            complex_graph["name"] = name
            try:
                get_lig_graph_with_matching(
                    lig,
                    complex_graph,
                    self.popsize,
                    self.maxiter,
                    self.matching,
                    self.keep_original,
                    self.num_conformers,
                    remove_hs=self.remove_hs,
                )
                (
                    rec,
                    rec_coords,
                    c_alpha_coords,
                    n_coords,
                    c_coords,
                    lm_embeddings,
                ) = extract_receptor_structure(
                    copy.deepcopy(rec_model),
                    lig,
                    lm_embedding_chains=lm_embedding_chains,
                )
                if lm_embeddings is not None and len(c_alpha_coords) != len(
                    lm_embeddings
                ):
                    print(
                        f"LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}."
                    )
                    continue
                get_rec_graph(
                    rec,
                    rec_coords,
                    c_alpha_coords,
                    n_coords,
                    c_coords,
                    complex_graph,
                    rec_radius=self.receptor_radius,
                    c_alpha_max_neighbors=self.c_alpha_max_neighbors,
                    all_atoms=self.all_atoms,
                    atom_radius=self.atom_radius,
                    atom_max_neighbors=self.atom_max_neighbors,
                    remove_hs=self.remove_hs,
                    lm_embeddings=lm_embeddings,
                )
            except Exception as e:
                # Best-effort: a failing ligand is skipped, not fatal.
                print(f"Skipping {name} because of the error:")
                print(e)
                continue

            # Re-center every coordinate set on the mean receptor position.
            protein_center = torch.mean(
                complex_graph["receptor"].pos, dim=0, keepdim=True
            )
            complex_graph["receptor"].pos -= protein_center
            if self.all_atoms:
                complex_graph["atom"].pos -= protein_center
            if (not self.matching) or self.num_conformers == 1:
                complex_graph["ligand"].pos -= protein_center
            else:
                # Several conformers are stored as a list of tensors; shift
                # each one in place.
                for p in complex_graph["ligand"].pos:
                    p -= protein_center

            complex_graph.original_center = protein_center
            complex_graphs.append(complex_graph)
            kept_ligs.append(lig)
        return complex_graphs, kept_ligs
|
|
|
|
|
def print_statistics(complex_graphs):
    """Print summary statistics of a list of complex graphs.

    For each graph it computes:
    * the receptor radius (max norm of receptor positions);
    * the ligand radius around the ligand centroid;
    * the distance of the ligand centroid from the origin;
    * ``rmsd_matching`` when the graph has it, else 0.

    The first conformer is used when the ligand stores a list of positions.
    Mean, std and max of each quantity are printed. For an empty input only
    the count is printed, because the aggregates are undefined.
    """
    names = [
        "radius protein",
        "radius molecule",
        "distance protein-mol",
        "rmsd matching",
    ]
    statistics = ([], [], [], [])

    for complex_graph in complex_graphs:
        lig_pos = (
            complex_graph["ligand"].pos
            if torch.is_tensor(complex_graph["ligand"].pos)
            else complex_graph["ligand"].pos[0]
        )
        radius_protein = torch.max(
            torch.linalg.vector_norm(complex_graph["receptor"].pos, dim=1)
        )
        molecule_center = torch.mean(lig_pos, dim=0)
        radius_molecule = torch.max(
            torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1)
        )
        distance_center = torch.linalg.vector_norm(molecule_center)
        # float() works regardless of the tensor's device, unlike np.asarray.
        statistics[0].append(float(radius_protein))
        statistics[1].append(float(radius_molecule))
        statistics[2].append(float(distance_center))
        if "rmsd_matching" in complex_graph:
            statistics[3].append(complex_graph.rmsd_matching)
        else:
            statistics[3].append(0)

    print("Number of complexes: ", len(complex_graphs))
    if not complex_graphs:
        # np.max of an empty array raises ValueError.
        return
    for name, values in zip(names, statistics):
        array = np.asarray(values)
        print(
            f"{name}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}"
        )
|
|
|
|
|
def construct_loader(args, t_to_sigma):
    """Create the training and validation data loaders.

    Both datasets share every preprocessing option from ``args`` and the same
    ``NoiseTransform``. Only the training set receives
    ``args.num_conformers``. ``DataListLoader`` is used when CUDA is
    available, ``DataLoader`` otherwise.

    Returns:
        ``(train_loader, val_loader)``.
    """
    noise = NoiseTransform(
        t_to_sigma=t_to_sigma, no_torsion=args.no_torsion, all_atom=args.all_atoms
    )

    dataset_kwargs = dict(
        transform=noise,
        root=args.data_dir,
        limit_complexes=args.limit_complexes,
        receptor_radius=args.receptor_radius,
        c_alpha_max_neighbors=args.c_alpha_max_neighbors,
        remove_hs=args.remove_hs,
        max_lig_size=args.max_lig_size,
        matching=not args.no_torsion,
        popsize=args.matching_popsize,
        maxiter=args.matching_maxiter,
        num_workers=args.num_workers,
        all_atoms=args.all_atoms,
        atom_radius=args.atom_radius,
        atom_max_neighbors=args.atom_max_neighbors,
        esm_embeddings_path=args.esm_embeddings_path,
    )

    # The training set is built first, so its cache is prepared first.
    train_dataset = PDBBind(
        cache_path=args.cache_path,
        split_path=args.split_train,
        keep_original=True,
        num_conformers=args.num_conformers,
        **dataset_kwargs,
    )
    val_dataset = PDBBind(
        cache_path=args.cache_path,
        split_path=args.split_val,
        keep_original=True,
        **dataset_kwargs,
    )

    if torch.cuda.is_available():
        loader_cls = DataListLoader
    else:
        loader_cls = DataLoader
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_dataloader_workers,
        shuffle=True,
        pin_memory=args.pin_memory,
    )
    train_loader = loader_cls(dataset=train_dataset, **loader_kwargs)
    val_loader = loader_cls(dataset=val_dataset, **loader_kwargs)
    return train_loader, val_loader
|
|
|
|
|
def read_mol(pdbbind_dir, name, remove_hs=False):
    """Read the ligand of complex ``name`` from ``pdbbind_dir/name/``.

    ``{name}_ligand.sdf`` is tried first. If reading it yields ``None``,
    ``{name}_ligand.mol2`` is tried. Returns the molecule, or ``None`` if
    both attempts fail.
    """
    complex_dir = os.path.join(pdbbind_dir, name)
    lig = None
    for extension in ("sdf", "mol2"):
        lig = read_molecule(
            os.path.join(complex_dir, f"{name}_ligand.{extension}"),
            remove_hs=remove_hs,
            sanitize=True,
        )
        if lig is not None:
            break
    return lig
|
|
|
|
|
def read_mols(pdbbind_dir, name, remove_hs=False):
    """Read every ligand ``.sdf`` file in ``pdbbind_dir/name/``.

    Files whose name contains ``"rdkit"`` are ignored. When an ``.sdf`` file
    cannot be read and a ``.mol2`` file with the same stem exists, that file
    is tried instead. Unreadable ligands are dropped.

    Files are visited in sorted order. ``os.listdir`` order is arbitrary, and
    sorting keeps the ligand order (and thus the cache contents)
    reproducible across machines.

    Returns:
        List of successfully read molecules.
    """
    complex_dir = os.path.join(pdbbind_dir, name)
    ligs = []
    for file in sorted(os.listdir(complex_dir)):
        if not file.endswith(".sdf") or "rdkit" in file:
            continue
        lig = read_molecule(
            os.path.join(complex_dir, file),
            remove_hs=remove_hs,
            sanitize=True,
        )
        mol2_path = os.path.join(complex_dir, file[:-4] + ".mol2")
        if lig is None and os.path.exists(mol2_path):
            print(
                "Using the .sdf file failed. We found a .mol2 file instead and are trying to use that."
            )
            lig = read_molecule(
                mol2_path,
                remove_hs=remove_hs,
                sanitize=True,
            )
        if lig is not None:
            ligs.append(lig)
    return ligs
|
|