import os
import gc
import sys
import time
import json
import math
import random
import argparse
import itertools
from abc import ABCMeta, abstractmethod
from collections import Counter

import numpy as np
import pandas as pd
import networkx as nx
from scipy import sparse

import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.gluon import nn
from mxnet.autograd import Function
from mxnet.gluon.data import Dataset, DataLoader
from mxnet.gluon.data.sampler import Sampler

import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import Draw, AllChem, rdmolfiles
from rdkit.Chem.Draw import IPythonConsole
from molvs import standardize_smiles


class MoleculeSpec(object):

    def __init__(self, file_name='models_folder/atom_types.txt'):
        self.atom_types = []
        self.atom_symbols = []
        with open(file_name) as f:
            for line in f:
                atom_type_i = line.strip('\n').split(',')
                self.atom_types.append((atom_type_i[0], int(atom_type_i[1]), int(atom_type_i[2])))
                if atom_type_i[0] not in self.atom_symbols:
                    self.atom_symbols.append(atom_type_i[0])
        self.bond_orders = [Chem.BondType.AROMATIC,
                            Chem.BondType.SINGLE,
                            Chem.BondType.DOUBLE,
                            Chem.BondType.TRIPLE]
        self.max_iter = 120

    def get_atom_type(self, atom):
        atom_symbol = atom.GetSymbol()
        atom_charge = atom.GetFormalCharge()
        atom_hs = atom.GetNumExplicitHs()
        return self.atom_types.index((atom_symbol, atom_charge, atom_hs))

    def get_bond_type(self, bond):
        return self.bond_orders.index(bond.GetBondType())

    def index_to_atom(self, idx):
        atom_symbol, atom_charge, atom_hs = self.atom_types[idx]
        a = Chem.Atom(atom_symbol)
        a.SetFormalCharge(atom_charge)
        a.SetNumExplicitHs(atom_hs)
        return a

    def index_to_bond(self, mol, begin_id, end_id, idx):
        mol.AddBond(begin_id, end_id, self.bond_orders[idx])

    @property
    def num_atom_types(self):
        return len(self.atom_types)

    @property
    def num_bond_types(self):
        return len(self.bond_orders)


_mol_spec = None


def get_mol_spec():
    global _mol_spec
    if _mol_spec is None:
        _mol_spec = MoleculeSpec()
    return _mol_spec


def get_graph_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)

    # build graph
    atom_types, atom_ranks, bonds, bond_types = [], [], [], []
    for a, r in zip(mol.GetAtoms(), Chem.CanonicalRankAtoms(mol)):
        atom_types.append(get_mol_spec().get_atom_type(a))
        atom_ranks.append(r)
    for b in mol.GetBonds():
        idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
        bonds.append([idx_1, idx_2])
        bond_types.append(bt)

    # build nx graph
    graph = nx.Graph()
    graph.add_nodes_from(range(len(atom_types)))
    graph.add_edges_from(bonds)

    return graph, atom_types, atom_ranks, bonds, bond_types


def get_graph_from_smiles_list(smiles_list):
    graph_list = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)

        # build graph
        atom_types, bonds, bond_types = [], [], []
        for a in mol.GetAtoms():
            atom_types.append(get_mol_spec().get_atom_type(a))
        for b in mol.GetBonds():
            idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
            bonds.append([idx_1, idx_2])
            bond_types.append(bt)

        X_0 = np.array(atom_types, dtype=np.int64)
        A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
                              np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
                             axis=1)

        graph_list.append([X_0, A_0])
    return graph_list
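# Illustrative usage sketch (not part of the original module): converting a
# SMILES string into the (X_0, A_0) arrays used throughout this file. 'CCO'
# is an arbitrary example; it assumes C and O entries are present in
# models_folder/atom_types.txt so get_mol_spec() can resolve the atom types.
def _demo_smiles_to_graph(smiles='CCO'):
    X_0, A_0 = get_graph_from_smiles_list([smiles])[0]
    # X_0: (num_atoms,) atom-type indices
    # A_0: (num_bonds, 3) rows of [begin_atom, end_atom, bond_type]
    return X_0, A_0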
def traverse_graph(graph, atom_ranks, current_node=None, step_ids=None, p=0.9, log_p=0.0):
    if current_node is None:
        next_nodes = range(len(atom_ranks))
        step_ids = [-1, ] * len(next_nodes)
        next_node_ranks = atom_ranks
    else:
        next_nodes = graph.neighbors(current_node)  # get neighbor nodes
        next_nodes = [n for n in next_nodes if step_ids[n] < 0]  # filter visited nodes
        next_node_ranks = [atom_ranks[n] for n in next_nodes]  # get ranks for neighbors
    next_nodes = [n for n, r in sorted(zip(next_nodes, next_node_ranks), key=lambda _x: _x[1])]  # sort by rank

    # iterate through neighbors
    while len(next_nodes) > 0:
        if len(next_nodes) == 1:
            next_node = next_nodes[0]
        elif random.random() >= (1 - p):
            # with probability p, follow the lowest-ranked neighbor
            next_node = next_nodes[0]
            log_p += np.log(p)
        else:
            # otherwise pick one of the remaining neighbors uniformly
            next_node = next_nodes[random.randint(1, len(next_nodes) - 1)]
            log_p += np.log((1.0 - p) / (len(next_nodes) - 1))
        step_ids[next_node] = max(step_ids) + 1
        _, log_p = traverse_graph(graph, atom_ranks, next_node, step_ids, p, log_p)
        next_nodes = [n for n in next_nodes if step_ids[n] < 0]  # filter visited nodes

    return step_ids, log_p


def single_reorder(X_0, A_0, step_ids):
    X_0, A_0 = np.copy(X_0), np.copy(A_0)

    step_ids = np.array(step_ids, dtype=np.int64)

    # sort by step_ids
    sorted_ids = np.argsort(step_ids)
    X_0 = X_0[sorted_ids]
    A_0[:, 0], A_0[:, 1] = step_ids[A_0[:, 0]], step_ids[A_0[:, 1]]
    max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
    A_0 = A_0[np.lexsort([-min_b, max_b]), :]

    # separate append and connect
    max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
    is_append = np.concatenate([np.array([True]), max_b[1:] > max_b[:-1]])
    A_0 = np.concatenate([np.where(is_append[:, np.newaxis],
                                   np.stack([min_b, max_b], axis=1),
                                   np.stack([max_b, min_b], axis=1)),
                          A_0[:, -1:]],
                         axis=1)

    return X_0, A_0


def single_expand(X_0, A_0):
    X_0, A_0 = np.copy(X_0), np.copy(A_0)

    # expand X
    is_append_iter = np.less(A_0[:, 0], A_0[:, 1]).astype(np.int64)
    NX = np.cumsum(np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1))
    shift = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
    X_index = np.arange(NX.sum(), dtype=np.int64) - np.repeat(shift, NX)
    X = X_0[X_index]

    # expand A
    _, A_index = np.tril_indices(A_0.shape[0])
    A = A_0[A_index, :]
    NA = np.arange(A_0.shape[0] + 1)

    # get action
    # action_type, atom_type, bond_type, append_pos, connect_pos
    action_type = 1 - is_append_iter
    atom_type = np.where(action_type == 0, X_0[A_0[:, 1]], 0)
    bond_type = A_0[:, 2]
    append_pos = np.where(action_type == 0, A_0[:, 0], 0)
    connect_pos = np.where(action_type == 1, A_0[:, 1], 0)
    actions = np.stack([action_type, atom_type, bond_type, append_pos, connect_pos], axis=1)
    last_action = [[2, 0, 0, 0, 0]]
    actions = np.append(actions, last_action, axis=0)

    action_0 = np.array([X_0[0]], dtype=np.int64)

    # get mask
    last_atom_index = shift + NX - 1
    last_atom_mask = np.zeros_like(X)
    last_atom_mask[last_atom_index] = np.where(
        np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1) == 1,
        np.ones_like(last_atom_index),
        np.ones_like(last_atom_index) * 2)

    return action_0, X, NX, A, NA, actions, last_atom_mask
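# Illustrative sketch (not part of the original module): sampling one decoding
# route for a molecule. With p close to 1 the traversal almost always follows
# canonical-rank order; log_p is the log-probability of the sampled route.
def _demo_traverse(smiles='CCO', p=0.9):
    graph, atom_types, atom_ranks, bonds, bond_types = get_graph_from_smiles(smiles)
    step_ids, log_p = traverse_graph(graph, atom_ranks, p=p)
    # step_ids[i] is the step at which atom i is generated
    return step_ids, log_p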
def get_d(A, X):
    _to_sparse = lambda _A, _X: sparse.coo_matrix(
        (np.ones([_A.shape[0] * 2], dtype=np.int64),
         (np.concatenate([_A[:, 0], _A[:, 1]], axis=0),
          np.concatenate([_A[:, 1], _A[:, 0]], axis=0))),
        shape=[_X.shape[0], ] * 2)
    A_sparse = _to_sparse(A, X)

    d2 = A_sparse * A_sparse
    d3 = d2 * A_sparse

    # get D_2
    D_2 = np.stack(d2.nonzero(), axis=1)
    D_2 = D_2[D_2[:, 0] < D_2[:, 1], :]

    # get D_3
    D_3 = np.stack(d3.nonzero(), axis=1)
    D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]

    # remove D_1 elements from D_3
    D_3_sparse = _to_sparse(D_3, X)
    D_3_sparse = D_3_sparse - D_3_sparse.multiply(A_sparse)
    D_3 = np.stack(D_3_sparse.nonzero(), axis=1)
    D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]

    return D_2, D_3


def merge_single_0(X_0, A_0, NX_0, NA_0):
    # shift bond indices by the cumulative atom counts
    cumsum = np.cumsum(np.pad(NX_0, [[1, 0]], mode='constant')[:-1])
    A_0[:, :2] += np.stack([np.repeat(cumsum, NA_0), ] * 2, axis=1)

    # get distance-2 and distance-3 neighborhoods
    D_0_2, D_0_3 = get_d(A_0, X_0)

    # split A by bond type
    A_split = []
    for i in range(get_mol_spec().num_bond_types):
        A_i = A_0[A_0[:, 2] == i, :2]
        A_split.append(A_i)
    A_split.extend([D_0_2, D_0_3])
    A_0 = A_split

    # NX_rep
    NX_rep_0 = np.repeat(np.arange(NX_0.shape[0]), NX_0)

    return X_0, A_0, NX_0, NX_rep_0


def merge_single(X, A, NX, NA, mol_ids, rep_ids, iw_ids,
                 action_0, actions, last_append_mask, log_p):
    X, A, NX, NX_rep = merge_single_0(X, A, NX, NA)
    cumsum = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
    actions[:, -2] += cumsum * (actions[:, 0] == 0)
    actions[:, -1] += cumsum * (actions[:, 0] == 1)
    mol_ids_rep = np.repeat(mol_ids, NX)
    rep_ids_rep = np.repeat(rep_ids, NX)

    return X, A, \
        mol_ids_rep, rep_ids_rep, iw_ids, \
        last_append_mask, \
        NX, NX_rep, \
        action_0, actions, \
        log_p


def process_single(smiles, k, p):
    graph, atom_types, atom_ranks, bonds, bond_types = get_graph_from_smiles(smiles)

    # original graph
    X_0 = np.array(atom_types, dtype=np.int64)
    A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
                          np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
                         axis=1)

    X, A = [], []
    NX, NA = [], []
    mol_ids, rep_ids, iw_ids = [], [], []
    action_0, actions = [], []
    last_append_mask = []
    log_p = []

    # randomly sample k decoding routes
    for i in range(k):
        step_ids_i, log_p_i = traverse_graph(graph, atom_ranks, p=p)
        X_i, A_i = single_reorder(X_0, A_0, step_ids_i)
        action_0_i, X_i, NX_i, A_i, NA_i, actions_i, last_atom_mask_i = single_expand(X_i, A_i)

        # appends
        X.append(X_i)
        A.append(A_i)
        NX.append(NX_i)
        NA.append(NA_i)
        action_0.append(action_0_i)
        actions.append(actions_i)
        last_append_mask.append(last_atom_mask_i)
        mol_ids.append(np.zeros_like(NX_i, dtype=np.int64))
        rep_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
        iw_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
        log_p.append(log_p_i)

    # concatenate
    X = np.concatenate(X, axis=0)
    A = np.concatenate(A, axis=0)
    NX = np.concatenate(NX, axis=0)
    NA = np.concatenate(NA, axis=0)
    action_0 = np.concatenate(action_0, axis=0)
    actions = np.concatenate(actions, axis=0)
    last_append_mask = np.concatenate(last_append_mask, axis=0)
    mol_ids = np.concatenate(mol_ids, axis=0)
    rep_ids = np.concatenate(rep_ids, axis=0)
    iw_ids = np.concatenate(iw_ids, axis=0)
    log_p = np.array(log_p, dtype=np.float32)

    return X, A, NX, NA, mol_ids, rep_ids, iw_ids, action_0, actions, last_append_mask, log_p


# noinspection PyArgumentList
def get_mol_from_graph(X, A, sanitize=True):
    try:
        mol = Chem.RWMol(Chem.Mol())

        X, A = X.tolist(), A.tolist()
        for i, atom_type in enumerate(X):
            mol.AddAtom(get_mol_spec().index_to_atom(atom_type))
        for atom_id1, atom_id2, bond_type in A:
            get_mol_spec().index_to_bond(mol, atom_id1, atom_id2, bond_type)
    except Exception:
        return None

    if sanitize:
        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            return mol
        except Exception:
            return None
    else:
        return mol


def get_mol_from_graph_list(graph_list, sanitize=True):
    mol_list = [get_mol_from_graph(X, A, sanitize) for X, A in graph_list]
    return mol_list
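# Illustrative sketch (not part of the original module): expanding a molecule
# into k decoding routes with process_single, and rebuilding a molecule from
# raw graph arrays. The round trip should recover the input structure.
def _demo_expand_and_rebuild(smiles='CCO'):
    # k=2 decoding routes sampled with rank-following probability p=0.9
    outputs = process_single(smiles, k=2, p=0.9)
    X_0, A_0 = get_graph_from_smiles_list([smiles])[0]
    mol = get_mol_from_graph(X_0, A_0, sanitize=True)
    return outputs, (Chem.MolToSmiles(mol) if mol is not None else None)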
class GraphConvFn(Function):

    def __init__(self, A):
        self.A = A  # type: nd.sparse.CSRNDArray
        self.A_T = self.A  # assume A is symmetric
        super(GraphConvFn, self).__init__()

    def forward(self, X):
        if self.A is not None:
            if len(X.shape) > 2:
                X_resized = X.reshape((X.shape[0], -1))
                output = nd.sparse.dot(self.A, X_resized)
                output = output.reshape([-1, ] + [X.shape[i] for i in range(1, len(X.shape))])
            else:
                output = nd.sparse.dot(self.A, X)
            return output
        else:
            return nd.zeros_like(X)

    def backward(self, grad_output):
        if self.A is not None:
            if len(grad_output.shape) > 2:
                grad_output_resized = grad_output.reshape((grad_output.shape[0], -1))
                grad_input = nd.sparse.dot(self.A_T, grad_output_resized)
                grad_input = grad_input.reshape([-1] + [grad_output.shape[i]
                                                        for i in range(1, len(grad_output.shape))])
            else:
                grad_input = nd.sparse.dot(self.A_T, grad_output)
            return grad_input
        else:
            return nd.zeros_like(grad_output)


class EfficientGraphConvFn(Function):
    """Memory-efficient graph convolution: the concatenated neighborhood
    features are re-computed during the backward pass instead of stored."""

    def __init__(self, A_list):
        self.A_list = A_list
        super(EfficientGraphConvFn, self).__init__()

    def forward(self, X, W):
        X_list = [X]
        for A in self.A_list:
            if A is not None:
                X_list.append(nd.sparse.dot(A, X))
            else:
                X_list.append(nd.zeros_like(X))
        X_out = nd.concat(*X_list, dim=1)
        self.save_for_backward(X, W)
        return nd.dot(X_out, W)

    def backward(self, grad_output):
        X, W = self.saved_tensors

        # recompute X_out
        X_list = [X, ]
        for A in self.A_list:
            if A is not None:
                X_list.append(nd.sparse.dot(A, X))
            else:
                X_list.append(nd.zeros_like(X))
        X_out = nd.concat(*X_list, dim=1)

        grad_W = nd.dot(X_out.T, grad_output)

        grad_X_out = nd.dot(grad_output, W.T)
        grad_X_out_list = nd.split(grad_X_out, num_outputs=len(self.A_list) + 1)

        grad_X = [grad_X_out_list[0], ]
        for A, grad_X_out_i in zip(self.A_list, grad_X_out_list[1:]):
            if A is not None:
                grad_X.append(nd.sparse.dot(A, grad_X_out_i))  # A assumed symmetric
            else:
                grad_X.append(nd.zeros_like(grad_X_out_i))
        grad_X = sum(grad_X)

        return grad_X, grad_W


class SegmentSumFn(GraphConvFn):

    def __init__(self, idx, num_seg):
        # build the segment-sum operator as a (num_seg x len(idx)) indicator
        # matrix in CSR format
        data = nd.ones(idx.shape[0], ctx=idx.context, dtype='int64')
        row, col = idx, nd.arange(idx.shape[0], ctx=idx.context, dtype='int64')
        shape = (num_seg, int(idx.shape[0]))

        A_csr = nd.sparse.csr_matrix((data, (row, col)), shape=shape,
                                     ctx=idx.context, dtype='float32')
        super(SegmentSumFn, self).__init__(A_csr)

        # the transpose is needed for the backward pass
        A_csr_T = nd.sparse.csr_matrix((data, (col, row)), shape=(shape[1], shape[0]),
                                       ctx=idx.context, dtype='float32')
        self.A_T = A_csr_T


def squeeze(x, axis):
    assert x.shape[axis] == 1

    new_shape = list(x.shape)
    del new_shape[axis]

    return x.reshape(new_shape)


def unsqueeze(x, axis):
    return nd.expand_dims(x, axis=axis)


def logsumexp(inputs, axis=None, keepdims=False):
    """Numerically stable logsumexp.

    Args:
        inputs: A Variable with any shape.
        axis: An integer.
        keepdims: A boolean.

    Returns:
        Equivalent of log(sum(exp(inputs), axis=axis, keepdims=keepdims)).

    Adopted from: https://github.com/pytorch/pytorch/issues/2591
    """
    # For a 1-D array x (or any array reduced along a single dimension),
    # logsumexp(x) = s + log(sum(exp(x - s)))
    # with s = max(x) being a common choice.
    if axis is None:
        inputs = inputs.reshape([-1])
        axis = 0
    s = nd.max(inputs, axis=axis, keepdims=True)
    outputs = s + (inputs - s).exp().sum(axis=axis, keepdims=True).log()
    if not keepdims:
        outputs = nd.sum(outputs, axis=axis, keepdims=False)
    return outputs
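# Quick sanity sketch (not part of the original module): SegmentSumFn sums the
# rows of X within each segment given by idx; here [1+2+3, 4] -> [[6.], [4.]].
def _demo_segment_sum():
    idx = nd.array([0, 0, 0, 1], dtype='int64')
    X = nd.array([[1.], [2.], [3.], [4.]])
    return SegmentSumFn(idx, 2)(X).asnumpy()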
def get_activation(name):
    activation_dict = {
        'relu': nd.relu,
        'tanh': nd.tanh
    }
    return activation_dict[name]


class Linear_BN(nn.Sequential):

    def __init__(self, F_in, F_out):
        super(Linear_BN, self).__init__()
        self.add(nn.Dense(F_out, in_units=F_in, use_bias=False))
        self.add(BatchNorm(in_channels=F_out))


class GraphConv(nn.Block):

    def __init__(self, Fin, Fout, D):
        super(GraphConv, self).__init__()

        # model settings
        self.Fin = Fin
        self.Fout = Fout
        self.D = D

        # model parameters
        self.W = self.params.get('w',
                                 shape=(self.Fin * (self.D + 1), self.Fout),
                                 init=None,
                                 allow_deferred_init=False)

    def forward(self, X, A_list):
        try:
            assert len(A_list) == self.D
        except AssertionError as e:
            print(self.D, len(A_list))
            raise e
        return EfficientGraphConvFn(A_list)(X, self.W.data(X.context))


class Policy(nn.Block):

    def __init__(self, F_in, F_h, N_A, N_B, k=1):
        super(Policy, self).__init__()

        self.F_in = F_in  # number of input features for each atom
        self.F_h = F_h    # number of context variables
        self.N_A = N_A    # number of atom types
        self.N_B = N_B    # number of bond types
        self.k = k        # number of softmax components in the mixture

        with self.name_scope():
            self.linear_h = Linear_BN(F_in * 2, self.F_h * k)
            self.linear_h_t = Linear_BN(F_in, self.F_h * k)

            self.linear_x = nn.Dense(self.N_B + self.N_B * self.N_A, in_units=self.F_h)
            self.linear_x_t = nn.Dense(1, in_units=self.F_h)

            if self.k > 1:
                self.linear_pi = nn.Dense(self.k, in_units=self.F_in)
            else:
                self.linear_pi = None

    def forward(self, X, NX, NX_rep, X_end=None):
        # segment mean for X
        if X_end is None:
            X_end = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X = nd.concat(X, X_end[NX_rep, :], dim=1)

        X_h = nd.relu(self.linear_h(X)).reshape([-1, self.F_h])
        X_h_end = nd.relu(self.linear_h_t(X_end)).reshape([-1, self.F_h])

        X_x = nd.exp(self.linear_x(X_h)).reshape([-1, self.k, self.N_B + self.N_B * self.N_A])
        X_x_end = nd.exp(self.linear_x_t(X_h_end)).reshape([-1, self.k, 1])

        X_sum = nd.sum(SegmentSumFn(NX_rep, NX.shape[0])(X_x), -1, keepdims=True) + X_x_end
        X_sum_gathered = X_sum[NX_rep, :, :]

        X_softmax = X_x / X_sum_gathered
        X_softmax_end = X_x_end / X_sum

        if self.k > 1:
            pi = unsqueeze(nd.softmax(self.linear_pi(X_end), axis=1), -1)
            pi_gathered = pi[NX_rep, :, :]

            X_softmax = nd.sum(X_softmax * pi_gathered, axis=1)
            X_softmax_end = nd.sum(X_softmax_end * pi, axis=1)
        else:
            X_softmax = squeeze(X_softmax, 1)
            X_softmax_end = squeeze(X_softmax_end, 1)

        # generate output probabilities
        connect, append = X_softmax[:, :self.N_B], X_softmax[:, self.N_B:]
        append = append.reshape([-1, self.N_A, self.N_B])
        end = squeeze(X_softmax_end, -1)

        return append, connect, end
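# Illustrative shape check (not part of the original module): the Policy head
# maps per-atom features to (append, connect, end) distributions. All feature
# sizes below are arbitrary placeholders.
def _demo_policy_shapes():
    F_in, F_h, N_A, N_B = 8, 16, 10, 4
    policy = Policy(F_in, F_h, N_A, N_B)
    policy.collect_params().initialize(ctx=mx.cpu())
    # two molecules with 3 and 2 atoms respectively
    X = nd.random.normal(shape=(5, F_in))
    NX = nd.array([3, 2], dtype='int64')
    NX_rep = nd.array([0, 0, 0, 1, 1], dtype='int64')
    with autograd.predict_mode():
        append, connect, end = policy(X, NX, NX_rep)
    # append: (5, N_A, N_B); connect: (5, N_B); end: (2,)
    return append.shape, connect.shape, end.shape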
class BatchNorm(nn.Block):

    def __init__(self, in_channels, momentum=0.9, eps=1e-5):
        super(BatchNorm, self).__init__()

        self.F = in_channels

        self.bn_weight = self.params.get('bn_weight', shape=(self.F,),
                                         init=mx.init.One(), allow_deferred_init=False)
        self.bn_bias = self.params.get('bn_bias', shape=(self.F,),
                                       init=mx.init.Zero(), allow_deferred_init=False)

        self.running_mean = self.params.get('running_mean', grad_req='null',
                                            shape=(self.F,), init=mx.init.Zero(),
                                            allow_deferred_init=False, differentiable=False)
        self.running_var = self.params.get('running_var', grad_req='null',
                                           shape=(self.F,), init=mx.init.One(),
                                           allow_deferred_init=False, differentiable=False)

        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        if autograd.is_training():
            return nd.BatchNorm(x,
                                gamma=self.bn_weight.data(x.context),
                                beta=self.bn_bias.data(x.context),
                                moving_mean=self.running_mean.data(x.context),
                                moving_var=self.running_var.data(x.context),
                                eps=self.eps, momentum=self.momentum,
                                use_global_stats=False)
        else:
            return nd.BatchNorm(x,
                                gamma=self.bn_weight.data(x.context),
                                beta=self.bn_bias.data(x.context),
                                moving_mean=self.running_mean.data(x.context),
                                moving_var=self.running_var.data(x.context),
                                eps=self.eps, momentum=self.momentum,
                                use_global_stats=True)


class MoleculeGenerator(nn.Block, metaclass=ABCMeta):

    def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy,
                 activation, *args, **kwargs):
        super(MoleculeGenerator, self).__init__()
        self.N_A = N_A
        self.N_B = N_B
        self.D = D
        self.F_e = F_e
        self.F_skip = F_skip
        self.F_c = list(F_c) if isinstance(F_c, tuple) else F_c
        self.Fh_policy = Fh_policy
        self.activation = get_activation(activation)

        with self.name_scope():
            # embeddings
            self.embedding_atom = nn.Embedding(self.N_A, self.F_e)
            self.embedding_mask = nn.Embedding(3, self.F_e)

            # graph conv
            self._build_graph_conv(*args, **kwargs)

            # fully connected
            self.dense = nn.Sequential()
            for i, (f_in, f_out) in enumerate(zip([self.F_skip, ] + self.F_c[:-1], self.F_c)):
                self.dense.add(Linear_BN(f_in, f_out))

            # policy
            self.policy_0 = self.params.get('policy_0', shape=[self.N_A, ],
                                            init=mx.init.Zero(),
                                            allow_deferred_init=False)
            self.policy_h = Policy(self.F_c[-1], self.Fh_policy, self.N_A, self.N_B)

        self.mode = 'loss'

    @abstractmethod
    def _build_graph_conv(self, *args, **kwargs):
        raise NotImplementedError

    @abstractmethod
    def _graph_conv_forward(self, X, A):
        raise NotImplementedError

    def _policy_0(self, ctx):
        policy_0 = nd.exp(self.policy_0.data(ctx))
        policy_0 = policy_0 / policy_0.sum()
        return policy_0

    def _policy(self, X, A, NX, NX_rep, last_append_mask):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)

        # convolution
        X = self._graph_conv_forward(X, A)

        # linear
        X = self.dense(X)

        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep)

        return append, connect, end

    def _likelihood(self, init, append, connect, end,
                    action_0, actions, iw_ids, log_p_sigma,
                    batch_size, iw_size):
        # decompose action:
        action_type, node_type, edge_type, append_pos, connect_pos = \
            actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]
        _log_mask = lambda _x, _mask: _mask * nd.log(_x + 1e-10) + (1 - _mask) * nd.zeros_like(_x)

        # init
        init = init.reshape([batch_size * iw_size, self.N_A])
        index = nd.stack(nd.arange(action_0.shape[0], ctx=action_0.context, dtype='int64'),
                         action_0, axis=0)
        loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)

        # end
        loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))

        # append
        index = nd.stack(append_pos, node_type, edge_type, axis=0)
        loss_append = _log_mask(nd.gather_nd(append, index), nd.cast(action_type == 0, 'float32'))

        # connect
        index = nd.stack(connect_pos, edge_type, axis=0)
        loss_connect = _log_mask(nd.gather_nd(connect, index), nd.cast(action_type == 1, 'float32'))

        # sum up results
        log_p_x = loss_end + loss_append + loss_connect
        log_p_x = squeeze(SegmentSumFn(iw_ids, batch_size * iw_size)(unsqueeze(log_p_x, -1)), -1)
        log_p_x = log_p_x + loss_init

        # reshape and compute the importance-weighted log-likelihood
        log_p_x = log_p_x.reshape([batch_size, iw_size])
        log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
        l = log_p_x - log_p_sigma
        l = logsumexp(l, axis=1) - math.log(float(iw_size))
        return l

    def forward(self, *inputs):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size = inputs

            init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask)
            l = self._likelihood(init, append, connect, end,
                                 action_0, actions, iw_ids, log_p,
                                 batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(inputs[0])
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask = inputs
            return self._policy(X, A, NX, NX_rep, last_append_mask)
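# Note (added for clarity, not part of the original module): MoleculeGenerator
# and its subclasses are driven through self.mode rather than separate
# methods. 'loss' returns the scalar training loss, 'likelihood' the
# per-molecule importance-weighted log-likelihood, 'decode_0' the distribution
# over the first atom, and 'decode_step' one step of the append/connect/end
# policy used during sampling.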
class MoleculeGenerator_RNN(MoleculeGenerator):

    def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy,
                 activation, N_rnn, *args, **kwargs):
        super(MoleculeGenerator_RNN, self).__init__(N_A, N_B, D, F_e, F_skip, F_c, Fh_policy,
                                                    activation, *args, **kwargs)
        self.N_rnn = N_rnn

        with self.name_scope():
            self.rnn = gluon.rnn.GRU(hidden_size=self.F_c[-1],
                                     num_layers=self.N_rnn,
                                     layout='NTC',
                                     input_size=self.F_c[-1] * 2)

    def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
        X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X_curr = nd.take(X, indices=NX_cum - 1)
        X = nd.concat(X_avg, X_curr, dim=1)

        # rnn
        X = nd.take(X, indices=graph_to_rnn)  # batch_size, iw_size, length, num_features
        batch_size, iw_size, length, num_features = X.shape
        X = X.reshape([batch_size * iw_size, length, num_features])
        X = self.rnn(X)

        X = X.reshape([batch_size, iw_size, length, -1])
        X = nd.gather_nd(X, indices=rnn_to_graph)

        return X

    def _rnn_test(self, X, NX, NX_rep, NX_cum, h):
        # note: one partition for one molecule
        X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
        X_curr = nd.take(X, indices=NX_cum - 1)
        X = nd.concat(X_avg, X_curr, dim=1)  # size: [NX, F_in * 2]

        # rnn
        X = unsqueeze(X, axis=1)
        X, h = self.rnn(X, h)
        X = squeeze(X, axis=1)

        return X, h

    def _policy(self, X, A, NX, NX_rep, last_append_mask, graph_to_rnn, rnn_to_graph, NX_cum):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)

        # convolution
        X = self._graph_conv_forward(X, A)

        # linear
        X = self.dense(X)

        # rnn
        X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)

        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)

        return append, connect, end

    def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)

        # convolution
        X = self._graph_conv_forward(X, A)

        # linear
        X = self.dense(X)

        # rnn
        X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)

        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)

        return append, connect, end, h

    def forward(self, *inputs):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size, \
                graph_to_rnn, rnn_to_graph, NX_cum = inputs

            init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask,
                                                graph_to_rnn, rnn_to_graph, NX_cum)
            l = self._likelihood(init, append, connect, end,
                                 action_0, actions, iw_ids, log_p,
                                 batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(inputs[0])
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask, NX_cum, h = inputs
            return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h)
        else:
            raise ValueError
class _TwoLayerDense(nn.Block):

    def __init__(self, input_size, hidden_size, output_size):
        super(_TwoLayerDense, self).__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size

        with self.name_scope():
            # config 1
            self.input = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
            self.bn_input = BatchNorm(in_channels=hidden_size)
            self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)

            # config 2
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)

            # config 3
            # self.input1 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
            # self.bn_input1 = BatchNorm(in_channels=self.hidden_size)
            # self.input2 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.hidden_size)
            # self.bn_input2 = BatchNorm(in_channels=self.hidden_size)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)

            # config 4
            # self.bn_input = BatchNorm(in_channels=self.input_size)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)

            # config 5
            # self.bn_input = BatchNorm(in_channels=1024)
            # self.output = nn.Dense(self.output_size, use_bias=True, in_units=1024)

    def forward(self, c):
        # config 1
        return nd.softmax(self.output(nd.relu(self.bn_input(self.input(c)))), axis=-1)
        # config 2
        # return nd.softmax(self.output(c), axis=-1)
        # config 3
        # return nd.softmax(self.output(nd.relu(self.bn_input2(self.input2(
        #     nd.relu(self.bn_input1(self.input1(c))))))), axis=-1)
        # config 4
        # return nd.softmax(self.output(nd.relu(self.bn_input(c))), axis=-1)


class CMoleculeGenerator_RNN(MoleculeGenerator_RNN):

    def __init__(self, N_A, N_B, N_C, D, F_e, F_skip, F_c, Fh_policy,
                 activation, N_rnn, *args, **kwargs):
        self.N_C = N_C  # number of conditional variables
        super(CMoleculeGenerator_RNN, self).__init__(N_A, N_B, D, F_e, F_skip, F_c,
                                                     Fh_policy, activation, N_rnn,
                                                     *args, **kwargs)
        with self.name_scope():
            self.dense_policy_0 = _TwoLayerDense(self.N_C, self.N_A * 3, self.N_A)

    @abstractmethod
    def _graph_conv_forward(self, X, A, c, ids):
        raise NotImplementedError

    def _policy_0(self, c):
        return self.dense_policy_0(c) + 0.0 * self.policy_0.data(c.context)

    def _policy(self, X, A, NX, NX_rep, last_append_mask,
                graph_to_rnn, rnn_to_graph, NX_cum, c, ids):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)

        # convolution
        X = self._graph_conv_forward(X, A, c, ids)

        # linear
        X = self.dense(X)

        # rnn
        X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)

        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end

    def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids):
        # get initial embedding
        X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)

        # convolution
        X = self._graph_conv_forward(X, A, c, ids)

        # linear
        X = self.dense(X)

        # rnn
        X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)

        # policy
        append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
        return append, connect, end, h

    def forward(self, *inputs):
        if self.mode == 'loss' or self.mode == 'likelihood':
            X, A, iw_ids, last_append_mask, \
                NX, NX_rep, action_0, actions, log_p, \
                batch_size, iw_size, \
                graph_to_rnn, rnn_to_graph, NX_cum, \
                c, ids = inputs

            init = nd.tile(unsqueeze(self._policy_0(c), axis=1), [1, iw_size, 1])
            append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask,
                                                graph_to_rnn, rnn_to_graph, NX_cum, c, ids)
            l = self._likelihood(init, append, connect, end,
                                 action_0, actions, iw_ids, log_p,
                                 batch_size, iw_size)
            if self.mode == 'likelihood':
                return l
            else:
                return -l.mean()
        elif self.mode == 'decode_0':
            return self._policy_0(*inputs)
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids = inputs
            return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids)
        else:
            raise ValueError
class CVanillaMolGen_RNN(CMoleculeGenerator_RNN):

    def __init__(self, N_A, N_B, N_C, D, F_e, F_h, F_skip, F_c, Fh_policy,
                 activation, N_rnn, rename=False):
        self.rename = rename
        super(CVanillaMolGen_RNN, self).__init__(N_A, N_B, N_C, D, F_e, F_skip, F_c,
                                                 Fh_policy, activation, N_rnn, F_h)

    def _build_graph_conv(self, F_h):
        self.F_h = list(F_h) if isinstance(F_h, tuple) else F_h
        self.conv, self.bn = [], []
        for i, (f_in, f_out) in enumerate(zip([self.F_e] + self.F_h[:-1], self.F_h)):
            conv = GraphConv(f_in, f_out, self.N_B + self.D)
            self.conv.append(conv)
            self.register_child(conv)
            if i != 0:
                bn = BatchNorm(in_channels=f_in)
                self.register_child(bn)
            else:
                bn = None
            self.bn.append(bn)

        self.bn_skip = BatchNorm(in_channels=sum(self.F_h))
        self.linear_skip = Linear_BN(sum(self.F_h), self.F_skip)

        # projectors for the conditional variable (protein embedding)
        self.linear_c = []
        for i, f_out in enumerate(self.F_h):
            if self.rename:
                linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C,
                                    prefix='cond_{}'.format(i))
            else:
                linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C)
            self.register_child(linear_c)
            self.linear_c.append(linear_c)

    def _graph_conv_forward(self, X, A, c, ids):
        X_out = [X]
        for conv, bn, linear_c in zip(self.conv, self.bn, self.linear_c):
            X = X_out[-1]
            if bn is not None:
                X_out.append(conv(self.activation(bn(X)), A) + linear_c(c)[ids, :])
            else:
                X_out.append(conv(X, A) + linear_c(c)[ids, :])
        X_out = nd.concat(*X_out[1:], dim=1)
        return self.activation(self.linear_skip(self.activation(self.bn_skip(X_out))))
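# Illustrative sketch (not part of the original module): the keyword arguments
# a CVanilla_RNN_Builder (defined below) reads from configs.json. All values
# here are made-up placeholders; N_A, N_B and D=2 are supplied by the builder
# itself, and N_C must match the conditional-code dimension used in training.
_EXAMPLE_CONFIGS = {
    'N_C': 1024,                          # size of the conditional code c
    'F_e': 16,                            # atom-embedding size
    'F_h': [32, 64, 128, 128, 256, 256],  # graph-conv layer widths
    'F_skip': 256,                        # skip-connection output width
    'F_c': [512, ],                       # fully connected layer widths
    'Fh_policy': 128,                     # policy hidden size
    'activation': 'relu',
    'N_rnn': 3,                           # number of GRU layers
}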
def _decode_step(X, A, NX, NA, last_action, finished,
                 get_init, get_action,
                 random=False,
                 n_node_types=get_mol_spec().num_atom_types,
                 n_edge_types=get_mol_spec().num_bond_types):
    # note: the two defaults above are evaluated once, at import time
    if X is None:
        init = get_init()

        if random:
            X = []
            for i in range(init.shape[0]):
                # init probabilities (for the first atom)
                p = init[i, :]
                # random sampling using the init probability distribution
                selected_atom = np.random.choice(np.arange(init.shape[1]), 1, p=p)[0]
                X.append(selected_atom)
            X = np.array(X, dtype=np.int64)
        else:
            X = np.argmax(init, axis=1)
        A = np.zeros((0, 3), dtype=np.int64)
        NX = last_action = np.ones([X.shape[0]], dtype=np.int64)
        NA = np.zeros([X.shape[0]], dtype=np.int64)
        finished = np.array([False, ] * X.shape[0], dtype=bool)

        return X, A, NX, NA, last_action, finished
    else:
        X_u = X[np.repeat(np.logical_not(finished), NX)]
        A_u = A[np.repeat(np.logical_not(finished), NA), :]
        NX_u = NX[np.logical_not(finished)]
        NA_u = NA[np.logical_not(finished)]
        last_action_u = last_action[np.logical_not(finished)]

        # conv
        mol_ids_rep = NX_rep = np.repeat(np.arange(NX_u.shape[0]), NX_u)
        rep_ids_rep = np.zeros_like(mol_ids_rep)

        if A.shape[0] == 0:
            D_2 = D_3 = np.zeros((0, 2), dtype=np.int64)
            A_u = [np.zeros((0, 2), dtype=np.int64) for _ in range(get_mol_spec().num_bond_types)]
            A_u += [D_2, D_3]
        else:
            cumsum = np.cumsum(np.pad(NX_u, [[1, 0]], mode='constant')[:-1])
            shift = np.repeat(cumsum, NA_u)
            A_u[:, :2] += np.stack([shift, ] * 2, axis=1)
            D_2, D_3 = get_d(A_u, X_u)
            A_u = [A_u[A_u[:, 2] == _i, :2] for _i in range(n_edge_types)]
            A_u += [D_2, D_3]

        mask = np.zeros([X_u.shape[0]], dtype=np.int64)
        last_append_index = np.cumsum(NX_u) - 1
        mask[last_append_index] = np.where(last_action_u == 1,
                                           np.ones_like(last_append_index, dtype=np.int64),
                                           np.ones_like(last_append_index, dtype=np.int64) * 2)

        decode_input = [X_u, A_u, NX_u, NX_rep, mask, mol_ids_rep, rep_ids_rep]
        append, connect, end = get_action(decode_input)

        if A.shape[0] == 0:
            max_index = np.argmax(np.reshape(append, [-1, n_node_types * n_edge_types]), axis=1)
            atom_type, bond_type = np.unravel_index(max_index, [n_node_types, n_edge_types])
            X = np.reshape(np.stack([X, atom_type], axis=1), [-1])
            NX = np.array([2, ] * len(finished), dtype=np.int64)
            A = np.stack([np.zeros([len(finished), ], dtype=np.int64),
                          np.ones([len(finished), ], dtype=np.int64),
                          bond_type], axis=1)
            NA = np.ones([len(finished), ], dtype=np.int64)
            last_action = np.ones_like(NX, dtype=np.int64)
        else:
            # process each molecule separately
            append, connect = np.split(append, np.cumsum(NX_u)), np.split(connect, np.cumsum(NX_u))
            end = end.tolist()

            unfinished_ids = np.where(np.logical_not(finished))[0].tolist()
            cumsum = np.cumsum(NX)
            cumsum_a = np.cumsum(NA)

            X_insert = []
            X_insert_ids = []
            A_insert = []
            A_insert_ids = []
            finished_ids = []

            for i, (unfinished_id, append_i, connect_i, end_i) \
                    in enumerate(zip(unfinished_ids, append, connect, end)):
                if random:
                    def _rand_id(*_x):
                        _x_reshaped = [np.reshape(_xi, [-1]) for _xi in _x]
                        _x_length = np.array([_x_reshape_i.shape[0] for _x_reshape_i in _x_reshaped],
                                             dtype=np.int64)
                        _begin = np.cumsum(np.pad(_x_length, [[1, 0]], mode='constant')[:-1])
                        _end = np.cumsum(_x_length) - 1
                        _p = np.concatenate(_x_reshaped)
                        _p = _p / np.sum(_p)
                        # count NaN values; fall back to uniform sampling if any
                        num_nan = np.isnan(_p).sum()
                        if num_nan > 0:
                            print(f'Number of NaN values in _p: {num_nan}')
                            _rand_index = np.random.choice(np.arange(len(_p)), 1)[0]
                        else:
                            _rand_index = np.random.choice(np.arange(_p.shape[0]), 1, p=_p)[0]
                        _p_step = _p[_rand_index]
                        _x_index = np.where(np.logical_and(_begin <= _rand_index,
                                                           _end >= _rand_index))[0][0]
                        _rand_index = _rand_index - _begin[_x_index]
                        _rand_index = np.unravel_index(_rand_index, _x[_x_index].shape)
                        return _x_index, _rand_index, _p_step

                    action_type, action_index, p_step = _rand_id(append_i, connect_i,
                                                                 np.array([end_i]))
                else:
                    _argmax = lambda _x: np.unravel_index(np.argmax(_x), _x.shape)
                    append_id, append_val = _argmax(append_i), np.max(append_i)
                    connect_id, connect_val = _argmax(connect_i), np.max(connect_i)
                    end_val = end_i
                    if end_val >= append_val and end_val >= connect_val:
                        action_type = 2
                        action_index = None
                    elif append_val >= connect_val and append_val >= end_val:
                        action_type = 0
                        action_index = append_id
                    else:
                        action_type = 1
                        action_index = connect_id

                if action_type == 2:
                    # finish growth
                    finished_ids.append(unfinished_id)
                elif action_type == 0:
                    # append action
                    append_pos, atom_type, bond_type = action_index
                    X_insert.append(atom_type)
                    X_insert_ids.append(unfinished_id)
                    A_insert.append([append_pos, NX[unfinished_id], bond_type])
                    A_insert_ids.append(unfinished_id)
                else:
                    # connect action
                    connect_ps, bond_type = action_index
                    A_insert.append([NX[unfinished_id] - 1, connect_ps, bond_type])
                    A_insert_ids.append(unfinished_id)

            if len(A_insert_ids) > 0:
                A = np.insert(A, cumsum_a[A_insert_ids], A_insert, axis=0)
                NA[A_insert_ids] += 1
                last_action[A_insert_ids] = 0
            if len(X_insert_ids) > 0:
                X = np.insert(X, cumsum[X_insert_ids], X_insert, axis=0)
                NX[X_insert_ids] += 1
                last_action[X_insert_ids] = 1
            if len(finished_ids) > 0:
                finished[finished_ids] = True

        return X, A, NX, NA, last_action, finished
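# Note (added for clarity, not part of the original module): _decode_step keeps
# the whole batch in flat arrays. X holds the atom types of all partial graphs
# concatenated, A holds their [begin, end, bond_type] rows, NX/NA give the
# per-molecule atom and bond counts used to slice them, last_action records
# whether each molecule last appended an atom (1) or connected two (0), and
# finished marks molecules whose policy chose the stop action (2).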
class Builder(metaclass=ABCMeta):

    def __init__(self, model_loc, gpu_id=None):
        with open(os.path.join(model_loc, 'configs.json')) as f:
            configs = json.load(f)

        self.mdl = self.__class__._get_model(configs)
        self.ctx = mx.gpu(gpu_id) if gpu_id is not None else mx.cpu()
        self.mdl.load_parameters(os.path.join(model_loc, 'ckpt.params'),
                                 ctx=self.ctx, allow_missing=True)

    @staticmethod
    def _get_model(configs):
        raise NotImplementedError

    @abstractmethod
    def sample(self, num_samples, *args, **kwargs):
        raise NotImplementedError


class CVanilla_RNN_Builder(Builder):

    @staticmethod
    def _get_model(configs):
        return CVanillaMolGen_RNN(get_mol_spec().num_atom_types,
                                  get_mol_spec().num_bond_types,
                                  D=2, **configs)

    def sample(self, num_samples, c, output_type='mol', sanitize=True, random=True):
        if len(c.shape) == 1:
            c = np.stack([c, ] * num_samples, axis=0)

        with autograd.predict_mode():
            # step one
            finished = [False, ] * num_samples

            def get_init():
                self.mdl.mode = 'decode_0'
                _c = nd.array(c, dtype='float32', ctx=self.ctx)
                init = self.mdl(_c).asnumpy()
                return init

            outputs = _decode_step(X=None, A=None, NX=None, NA=None,
                                   last_action=None, finished=finished,
                                   get_init=get_init, get_action=None,
                                   n_node_types=self.mdl.N_A,
                                   n_edge_types=self.mdl.N_B,
                                   random=random)
            if outputs is None:
                return None
            X, A, NX, NA, last_action, finished = outputs

            count = 1
            h = np.zeros([self.mdl.N_rnn, num_samples, self.mdl.F_c[-1]], dtype=np.float32)
            while not np.all(finished) and count < 100:
                def get_action(inputs):
                    self.mdl.mode = 'decode_step'
                    _h = nd.array(h[:, np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
                    _c = nd.array(c[np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
                    _X, _A_sparse, _NX, _NX_rep, _mask, _NX_cum = self.to_nd(inputs)
                    _append, _connect, _end, _h = self.mdl(_X, _A_sparse, _NX, _NX_rep,
                                                           _mask, _NX_cum, _h, _c, _NX_rep)
                    h[:, np.logical_not(finished), :] = _h[0].asnumpy()
                    return _append.asnumpy(), _connect.asnumpy(), _end.asnumpy()

                outputs = _decode_step(X, A, NX, NA, last_action, finished,
                                       get_init=None, get_action=get_action,
                                       n_node_types=self.mdl.N_A,
                                       n_edge_types=self.mdl.N_B,
                                       random=random)
                X, A, NX, NA, last_action, finished = outputs
                count += 1

            graph_list = []
            cumsum_X_ = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')).tolist()
            cumsum_A_ = np.cumsum(np.pad(NA, [[1, 0]], mode='constant')).tolist()
            for cumsum_A_pre, cumsum_A_post, \
                cumsum_X_pre, cumsum_X_post in zip(cumsum_A_[:-1], cumsum_A_[1:],
                                                   cumsum_X_[:-1], cumsum_X_[1:]):
                graph_list.append([X[cumsum_X_pre:cumsum_X_post],
                                   A[cumsum_A_pre:cumsum_A_post, :]])

            if output_type == 'graph':
                return graph_list
            elif output_type == 'mol':
                return get_mol_from_graph_list(graph_list, sanitize)
            elif output_type == 'smiles':
                mol_list = get_mol_from_graph_list(graph_list, sanitize=True)
                smiles_list = [Chem.MolToSmiles(m) if m is not None else None for m in mol_list]
                return smiles_list
            else:
                raise ValueError('Unrecognized output type')
    def to_nd(self, inputs):
        X, A, NX, NX_rep, mask = inputs[:-2]
        NX_cum = np.cumsum(NX)

        # convert to ndarray
        _to_ndarray = lambda _x: nd.array(_x, self.ctx, 'int64')
        X, NX, NX_rep, mask, NX_cum = \
            _to_ndarray(X), _to_ndarray(NX), _to_ndarray(NX_rep), \
            _to_ndarray(mask), _to_ndarray(NX_cum)

        A_sparse = []
        for _A_i in A:
            if _A_i.shape[0] == 0:
                A_sparse.append(None)
            else:
                # transpose may not be supported on GPU, so symmetrize on the CPU side
                _A_i = np.concatenate([_A_i, _A_i[:, [1, 0]]], axis=0)

                # construct the csr matrix
                _data = np.ones((_A_i.shape[0],), dtype=np.float32)
                _row, _col = _A_i[:, 0], _A_i[:, 1]
                _A_sparse_i = nd.sparse.csr_matrix((_data, (_row, _col)),
                                                   shape=tuple([int(X.shape[0]), ] * 2),
                                                   ctx=self.ctx, dtype='float32')

                # append to the list
                A_sparse.append(_A_sparse_i)

        return X, A_sparse, NX, NX_rep, mask, NX_cum
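# Illustrative end-to-end sketch (not part of the original module): sampling
# molecules from a trained model. 'models_folder' and the conditional-code
# size (1024 here, standing in for N_C) are placeholders and must match the
# checkpoint stored in ckpt.params.
if __name__ == '__main__':
    builder = CVanilla_RNN_Builder('models_folder', gpu_id=None)
    cond = np.zeros(1024, dtype=np.float32)  # hypothetical conditional code of size N_C
    smiles_list = builder.sample(10, c=cond, output_type='smiles', random=True)
    for s in smiles_list:
        if s is not None:
            print(s)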