# modelstrc.py
import os
import json
import math
import random
from abc import ABCMeta, abstractmethod

import numpy as np
import networkx as nx
from scipy import sparse

import mxnet as mx
from mxnet import gluon, autograd, nd
from mxnet.autograd import Function
from mxnet.gluon import nn

from rdkit import Chem
class MoleculeSpec(object):
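    """Vocabulary of atom types (symbol, formal charge, explicit H count) read from
    `atom_types.txt`, plus the supported bond orders; maps atoms/bonds to integer
    indices and back."""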
def __init__(self, file_name='models_folder/atom_types.txt'):
self.atom_types = []
self.atom_symbols = []
with open(file_name) as f:
for line in f:
atom_type_i = line.strip('\n').split(',')
self.atom_types.append((atom_type_i[0], int(atom_type_i[1]), int(atom_type_i[2])))
if atom_type_i[0] not in self.atom_symbols:
self.atom_symbols.append(atom_type_i[0])
self.bond_orders = [Chem.BondType.AROMATIC,
Chem.BondType.SINGLE,
Chem.BondType.DOUBLE,
Chem.BondType.TRIPLE]
self.max_iter = 120
def get_atom_type(self, atom):
atom_symbol = atom.GetSymbol()
atom_charge = atom.GetFormalCharge()
atom_hs = atom.GetNumExplicitHs()
return self.atom_types.index((atom_symbol, atom_charge, atom_hs))
def get_bond_type(self, bond):
return self.bond_orders.index(bond.GetBondType())
def index_to_atom(self, idx):
atom_symbol, atom_charge, atom_hs = self.atom_types[idx]
a = Chem.Atom(atom_symbol)
a.SetFormalCharge(atom_charge)
a.SetNumExplicitHs(atom_hs)
return a
def index_to_bond(self, mol, begin_id, end_id, idx):
mol.AddBond(begin_id, end_id, self.bond_orders[idx])
@property
def num_atom_types(self):
return len(self.atom_types)
@property
def num_bond_types(self):
return len(self.bond_orders)
_mol_spec = None
def get_mol_spec():
global _mol_spec
if _mol_spec is None:
_mol_spec = MoleculeSpec()
return _mol_spec
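# Illustrative round trip through the shared spec (assumes the atom type appears
# in models_folder/atom_types.txt):
#   spec = get_mol_spec()
#   a = Chem.MolFromSmiles('CO').GetAtomWithIdx(0)
#   idx = spec.get_atom_type(a)
#   a2 = spec.index_to_atom(idx)  # equivalent RDKit atom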
def get_graph_from_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError('Invalid SMILES: {}'.format(smiles))
# build graph
atom_types, atom_ranks, bonds, bond_types = [], [], [], []
for a, r in zip(mol.GetAtoms(), Chem.CanonicalRankAtoms(mol)):
atom_types.append(get_mol_spec().get_atom_type(a))
atom_ranks.append(r)
for b in mol.GetBonds():
idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
bonds.append([idx_1, idx_2])
bond_types.append(bt)
# build nx graph
graph = nx.Graph()
graph.add_nodes_from(range(len(atom_types)))
graph.add_edges_from(bonds)
return graph, atom_types, atom_ranks, bonds, bond_types
def get_graph_from_smiles_list(smiles_list):
graph_list = []
for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError('Invalid SMILES: {}'.format(smiles))
# build graph
atom_types, bonds, bond_types = [], [], []
for a in mol.GetAtoms():
atom_types.append(get_mol_spec().get_atom_type(a))
for b in mol.GetBonds():
idx_1, idx_2, bt = b.GetBeginAtomIdx(), b.GetEndAtomIdx(), get_mol_spec().get_bond_type(b)
bonds.append([idx_1, idx_2])
bond_types.append(bt)
X_0 = np.array(atom_types, dtype=np.int64)
A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
axis=1)
graph_list.append([X_0, A_0])
return graph_list
def traverse_graph(graph, atom_ranks, current_node=None, step_ids=None, p=0.9, log_p=0.0):
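    """Randomly sample a decoding order over `graph`.

    At each step the lowest-ranked unvisited candidate is taken with probability
    `p`; otherwise one of the remaining candidates is chosen uniformly. Returns
    the per-node step ids and the log-probability of the sampled order."""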
if current_node is None:
next_nodes = range(len(atom_ranks))
step_ids = [-1, ] * len(next_nodes)
next_node_ranks = atom_ranks
else:
next_nodes = graph.neighbors(current_node) # get neighbor nodes
next_nodes = [n for n in next_nodes if step_ids[n] < 0] # filter visited nodes
next_node_ranks = [atom_ranks[n] for n in next_nodes] # get ranks for neighbors
next_nodes = [n for n, r in sorted(zip(next_nodes, next_node_ranks), key=lambda _x:_x[1])] # sort by rank
# iterate through neighbors
while len(next_nodes) > 0:
if len(next_nodes)==1:
next_node = next_nodes[0]
elif random.random() >= (1 - p):
next_node = next_nodes[0]
log_p += np.log(p)
else:
next_node = next_nodes[random.randint(1, len(next_nodes) - 1)]
log_p += np.log((1.0 - p) / (len(next_nodes) - 1))
step_ids[next_node] = max(step_ids) + 1
_, log_p = traverse_graph(graph, atom_ranks, next_node, step_ids, p, log_p)
next_nodes = [n for n in next_nodes if step_ids[n] < 0] # filter visited nodes
return step_ids, log_p
def single_reorder(X_0, A_0, step_ids):
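    """Relabel atoms by the sampled visiting order `step_ids` and orient each bond
    as (earlier, later) for append steps or (later, earlier) for connect steps."""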
X_0, A_0 = np.copy(X_0), np.copy(A_0)
step_ids = np.array(step_ids, dtype=np.int64)
# sort by step_ids
sorted_ids = np.argsort(step_ids)
X_0 = X_0[sorted_ids]
A_0[:, 0], A_0[:, 1] = step_ids[A_0[:, 0]], step_ids[A_0[:, 1]]
max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
A_0 = A_0[np.lexsort([-min_b, max_b]), :]
# separate append and connect
max_b, min_b = np.amax(A_0[:, :2], axis=1), np.amin(A_0[:, :2], axis=1)
is_append = np.concatenate([np.array([True]), max_b[1:] > max_b[:-1]])
A_0 = np.concatenate([np.where(is_append[:, np.newaxis],
np.stack([min_b, max_b], axis=1),
np.stack([max_b, min_b], axis=1)),
A_0[:, -1:]], axis=1)
return X_0, A_0
def single_expand(X_0, A_0):
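    """Unroll one ordered graph into per-step partial graphs plus the action
    sequence (action_type, atom_type, bond_type, append_pos, connect_pos),
    where action_type 0 = append, 1 = connect, 2 = end."""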
X_0, A_0 = np.copy(X_0), np.copy(A_0)
# expand X
is_append_iter = np.less(A_0[:, 0], A_0[:, 1]).astype(np.int64)
NX = np.cumsum(np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1))
shift = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
X_index = np.arange(NX.sum(), dtype=np.int64) - np.repeat(shift, NX)
X = X_0[X_index]
# expand A
_, A_index = np.tril_indices(A_0.shape[0])
A = A_0[A_index, :]
NA = np.arange(A_0.shape[0] + 1)
    # {{{ get action
    # each action is (action_type, atom_type, bond_type, append_pos, connect_pos)
action_type = 1 - is_append_iter
atom_type = np.where(action_type == 0, X_0[A_0[:, 1]], 0)
bond_type = A_0[:, 2]
append_pos = np.where(action_type == 0, A_0[:, 0], 0)
connect_pos = np.where(action_type == 1, A_0[:, 1], 0)
actions = np.stack([action_type, atom_type, bond_type, append_pos, connect_pos],
axis=1)
last_action = [[2, 0, 0, 0, 0]]
actions = np.append(actions, last_action, axis=0)
action_0 = np.array([X_0[0]], dtype=np.int64)
# }}}
# {{{ Get mask
last_atom_index = shift + NX - 1
last_atom_mask = np.zeros_like(X)
last_atom_mask[last_atom_index] = np.where(
np.pad(is_append_iter, [[1, 0]], mode='constant', constant_values=1) == 1,
np.ones_like(last_atom_index),
np.ones_like(last_atom_index) * 2)
# }}}
return action_0, X, NX, A, NA, actions, last_atom_mask
def get_d(A, X):
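    """Return index pairs of atoms joined by walks of length 2 (D_2) and by walks
    of length 3 excluding directly bonded pairs (D_3), computed from the
    symmetrized sparse adjacency matrix."""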
_to_sparse = lambda _A, _X: sparse.coo_matrix((np.ones([_A.shape[0] * 2], dtype=np.int64),
(np.concatenate([_A[:, 0], _A[:, 1]], axis=0),
np.concatenate([_A[:, 1], _A[:, 0]], axis=0))),
shape=[_X.shape[0], ] * 2)
A_sparse = _to_sparse(A, X)
d2 = A_sparse * A_sparse
d3 = d2 * A_sparse
# get D_2
D_2 = np.stack(d2.nonzero(), axis=1)
D_2 = D_2[D_2[:, 0] < D_2[:, 1], :]
# get D_3
D_3 = np.stack(d3.nonzero(), axis=1)
D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]
# remove D_1 elements from D_3
D_3_sparse = _to_sparse(D_3, X)
D_3_sparse = D_3_sparse - D_3_sparse.multiply(A_sparse)
D_3 = np.stack(D_3_sparse.nonzero(), axis=1)
D_3 = D_3[D_3[:, 0] < D_3[:, 1], :]
return D_2, D_3
def merge_single_0(X_0, A_0, NX_0, NA_0):
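    """Pack concatenated graphs into batch form: offset bond indices per graph,
    split bonds by bond type, append the D_2/D_3 pairs, and build NX_rep
    (the graph id of every atom)."""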
# shift_ids
cumsum = np.cumsum(np.pad(NX_0, [[1, 0]], mode='constant')[:-1])
A_0[:, :2] += np.stack([np.repeat(cumsum, NA_0), ] * 2, axis=1)
# get D
D_0_2, D_0_3 = get_d(A_0, X_0)
# split A
A_split = []
for i in range(get_mol_spec().num_bond_types):
A_i = A_0[A_0[:, 2] == i, :2]
A_split.append(A_i)
A_split.extend([D_0_2, D_0_3])
A_0 = A_split
# NX_rep
NX_rep_0 = np.repeat(np.arange(NX_0.shape[0]), NX_0)
return X_0, A_0, NX_0, NX_rep_0
def merge_single(X, A,
NX, NA,
mol_ids, rep_ids, iw_ids,
action_0, actions,
last_append_mask,
log_p):
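    """Batch one molecule's sampled routes: merge the graphs, offset the
    append/connect positions in `actions`, and repeat per-graph ids per atom."""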
X, A, NX, NX_rep = merge_single_0(X, A, NX, NA)
cumsum = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')[:-1])
actions[:, -2] += cumsum * (actions[:, 0] == 0)
actions[:, -1] += cumsum * (actions[:, 0] == 1)
mol_ids_rep = np.repeat(mol_ids, NX)
rep_ids_rep = np.repeat(rep_ids, NX)
return X, A,\
mol_ids_rep, rep_ids_rep, iw_ids,\
last_append_mask,\
NX, NX_rep,\
action_0, actions, \
log_p
def process_single(smiles, k, p):
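    """Sample `k` decoding routes for one SMILES (first-choice probability `p`)
    and expand them into the flattened, step-wise training arrays."""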
graph, atom_types, atom_ranks, bonds, bond_types = get_graph_from_smiles(smiles)
# original
X_0 = np.array(atom_types, dtype=np.int64)
A_0 = np.concatenate([np.array(bonds, dtype=np.int64),
np.array(bond_types, dtype=np.int64)[:, np.newaxis]],
axis=1)
X, A = [], []
NX, NA = [], []
mol_ids, rep_ids, iw_ids = [], [], []
action_0, actions = [], []
last_append_mask = []
log_p = []
# random sampling decoding route
for i in range(k):
step_ids_i, log_p_i = traverse_graph(graph, atom_ranks, p=p)
X_i, A_i = single_reorder(X_0, A_0, step_ids_i)
action_0_i, X_i, NX_i, A_i, NA_i, actions_i, last_atom_mask_i = single_expand(X_i, A_i)
# appends
X.append(X_i)
A.append(A_i)
NX.append(NX_i)
NA.append(NA_i)
action_0.append(action_0_i)
actions.append(actions_i)
last_append_mask.append(last_atom_mask_i)
mol_ids.append(np.zeros_like(NX_i, dtype=np.int64))
rep_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
iw_ids.append(np.ones_like(NX_i, dtype=np.int64) * i)
log_p.append(log_p_i)
# concatenate
X = np.concatenate(X, axis=0)
    A = np.concatenate(A, axis=0)
    NX = np.concatenate(NX, axis=0)
    NA = np.concatenate(NA, axis=0)
    action_0 = np.concatenate(action_0, axis=0)
    actions = np.concatenate(actions, axis=0)
    last_append_mask = np.concatenate(last_append_mask, axis=0)
    mol_ids = np.concatenate(mol_ids, axis=0)
    rep_ids = np.concatenate(rep_ids, axis=0)
    iw_ids = np.concatenate(iw_ids, axis=0)
log_p = np.array(log_p, dtype=np.float32)
return X, A, NX, NA, mol_ids, rep_ids, iw_ids, action_0, actions, last_append_mask, log_p
# noinspection PyArgumentList
def get_mol_from_graph(X, A, sanitize=True):
    """Rebuild an RDKit molecule from typed atom/bond arrays; returns None on failure."""
    try:
        mol = Chem.RWMol(Chem.Mol())
        X, A = X.tolist(), A.tolist()
        for atom_type in X:
            mol.AddAtom(get_mol_spec().index_to_atom(atom_type))
        for atom_id1, atom_id2, bond_type in A:
            get_mol_spec().index_to_bond(mol, atom_id1, atom_id2, bond_type)
    except Exception:
        return None
    if sanitize:
        try:
            mol = mol.GetMol()
            Chem.SanitizeMol(mol)
            return mol
        except Exception:
            return None
    else:
        return mol
def get_mol_from_graph_list(graph_list, sanitize=True):
mol_list = [get_mol_from_graph(X, A, sanitize) for X, A in graph_list]
return mol_list
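# Example round trip (illustrative; assumes the atom types appear in
# models_folder/atom_types.txt):
#   graphs = get_graph_from_smiles_list(['CCO', 'c1ccccc1'])
#   mols = get_mol_from_graph_list(graphs)  # back to RDKit molecules
#   smiles = [Chem.MolToSmiles(m) for m in mols if m is not None]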
class GraphConvFn(Function):
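    """Multiplies node features by a fixed sparse matrix A, with an explicit
    backward pass through A_T (set to A here, i.e. A is assumed symmetric;
    subclasses may override A_T)."""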
def __init__(self, A):
self.A = A # type: nd.sparse.CSRNDArray
self.A_T = self.A # assume symmetric
super(GraphConvFn, self).__init__()
def forward(self, X):
if self.A is not None:
if len(X.shape) > 2:
X_resized = X.reshape((X.shape[0], -1))
output = nd.sparse.dot(self.A, X_resized)
output = output.reshape([-1, ] + [X.shape[i] for i in range(1, len(X.shape))])
else:
output = nd.sparse.dot(self.A, X)
return output
else:
return nd.zeros_like(X)
def backward(self, grad_output):
if self.A is not None:
if len(grad_output.shape) > 2:
grad_output_resized = grad_output.reshape((grad_output.shape[0], -1))
grad_input = nd.sparse.dot(self.A_T, grad_output_resized)
grad_input = grad_input.reshape([-1] + [grad_output.shape[i]
for i in range(1, len(grad_output.shape))])
else:
grad_input = nd.sparse.dot(self.A_T, grad_output)
return grad_input
else:
return nd.zeros_like(grad_output)
class EfficientGraphConvFn(Function):
"""Save memory by re-computation"""
def __init__(self, A_list):
self.A_list = A_list
super(EfficientGraphConvFn, self).__init__()
def forward(self, X, W):
X_list = [X]
for A in self.A_list:
if A is not None:
X_list.append(nd.sparse.dot(A, X))
else:
X_list.append(nd.zeros_like(X))
X_out = nd.concat(*X_list, dim=1)
self.save_for_backward(X, W)
return nd.dot(X_out, W)
def backward(self, grad_output):
X, W = self.saved_tensors
# recompute X_out
X_list = [X, ]
for A in self.A_list:
if A is not None:
X_list.append(nd.sparse.dot(A, X))
else:
X_list.append(nd.zeros_like(X))
X_out = nd.concat(*X_list, dim=1)
grad_W = nd.dot(X_out.T, grad_output)
grad_X_out = nd.dot(grad_output, W.T)
grad_X_out_list = nd.split(grad_X_out, num_outputs=len(self.A_list) + 1)
grad_X = [grad_X_out_list[0], ]
for A, grad_X_out in zip(self.A_list, grad_X_out_list[1:]):
if A is not None:
grad_X.append(nd.sparse.dot(A, grad_X_out))
else:
grad_X.append(nd.zeros_like(grad_X_out))
grad_X = sum(grad_X)
return grad_X, grad_W
class SegmentSumFn(GraphConvFn):
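    """Segment sum as a sparse matrix product: out[s] is the sum of the input
    rows whose segment id in `idx` equals s."""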
def __init__(self, idx, num_seg):
# build A
# construct coo
data = nd.ones(idx.shape[0], ctx=idx.context, dtype='int64')
row, col = idx, nd.arange(idx.shape[0], ctx=idx.context, dtype='int64')
shape = (num_seg, int(idx.shape[0]))
sparse = nd.sparse.csr_matrix((data, (row, col)), shape=shape,
ctx=idx.context, dtype='float32')
super(SegmentSumFn, self).__init__(sparse)
sparse = nd.sparse.csr_matrix((data, (col, row)), shape=(shape[1], shape[0]),
ctx=idx.context, dtype='float32')
self.A_T = sparse
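# Example (illustrative): with idx = nd.array([0, 0, 1], dtype='int64') and
# num_seg = 2, SegmentSumFn(idx, 2)(X) returns [X[0] + X[1], X[2]].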
def squeeze(input, axis):
assert input.shape[axis] == 1
new_shape = list(input.shape)
del new_shape[axis]
return input.reshape(new_shape)
def unsqueeze(input, axis):
return nd.expand_dims(input, axis=axis)
def logsumexp(inputs, axis=None, keepdims=False):
    """Numerically stable log-sum-exp.

    Args:
        inputs: An NDArray of any shape.
        axis: An integer, or None to reduce over all elements.
        keepdims: A boolean.
    Returns:
        The equivalent of log(sum(exp(inputs), axis=axis, keepdims=keepdims)).
    Adapted from: https://github.com/pytorch/pytorch/issues/2591
    """
# For a 1-D array x (any array along a single dimension),
# log sum exp(x) = s + log sum exp(x - s)
# with s = max(x) being a common choice.
if axis is None:
inputs = inputs.reshape([-1])
axis = 0
s = nd.max(inputs, axis=axis, keepdims=True)
outputs = s + (inputs - s).exp().sum(axis=axis, keepdims=True).log()
if not keepdims:
outputs = nd.sum(outputs, axis=axis, keepdims=False)
return outputs
def get_activation(name):
    activation_dict = {
        'relu': nd.relu,
        'tanh': nd.tanh
    }
    return activation_dict[name]
class Linear_BN(nn.Sequential):
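    """A bias-free Dense layer followed by the custom BatchNorm defined below."""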
def __init__(self, F_in, F_out):
super(Linear_BN, self).__init__()
self.add(nn.Dense(F_out, in_units=F_in, use_bias=False))
self.add(BatchNorm(in_channels=F_out))
class GraphConv(nn.Block):
def __init__(self, Fin, Fout, D):
super(GraphConv, self).__init__()
# model settings
self.Fin = Fin
self.Fout = Fout
self.D = D
# model parameters
self.W = self.params.get('w', shape=(self.Fin * (self.D + 1), self.Fout),
init=None, allow_deferred_init=False)
    def forward(self, X, A_list):
        if len(A_list) != self.D:
            raise ValueError('expected {} adjacency matrices, got {}'.format(self.D, len(A_list)))
        return EfficientGraphConvFn(A_list)(X, self.W.data(X.context))
class Policy(nn.Block):
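    """Policy head producing append / connect / end action probabilities,
    optionally as a mixture of k softmax distributions."""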
def __init__(self, F_in, F_h, N_A, N_B, k=1):
super(Policy, self).__init__()
self.F_in = F_in # number of input features for each atom
self.F_h = F_h # number of context variables
self.N_A = N_A # number of atom types
self.N_B = N_B # number of bond types
self.k = k # number of softmax used in the mixture
with self.name_scope():
self.linear_h = Linear_BN(F_in * 2, self.F_h * k)
self.linear_h_t = Linear_BN(F_in, self.F_h * k)
self.linear_x = nn.Dense(self.N_B + self.N_B*self.N_A, in_units=self.F_h)
self.linear_x_t = nn.Dense(1, in_units=self.F_h)
if self.k > 1:
self.linear_pi = nn.Dense(self.k, in_units=self.F_in)
else:
self.linear_pi = None
    def forward(self, X, NX, NX_rep, X_end=None):
        # segment mean of atom features as the per-molecule context
        if X_end is None:
            X_end = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
X = nd.concat(X, X_end[NX_rep, :], dim=1)
X_h = nd.relu(self.linear_h(X)).reshape([-1, self.F_h])
X_h_end = nd.relu(self.linear_h_t(X_end)).reshape([-1, self.F_h])
X_x = nd.exp(self.linear_x(X_h)).reshape([-1, self.k, self.N_B + self.N_B*self.N_A])
X_x_end = nd.exp(self.linear_x_t(X_h_end)).reshape([-1, self.k, 1])
X_sum = nd.sum(SegmentSumFn(NX_rep, NX.shape[0])(X_x), -1, keepdims=True) + X_x_end
X_sum_gathered = X_sum[NX_rep, :, :]
X_softmax = X_x / X_sum_gathered
X_softmax_end = X_x_end/ X_sum
if self.k > 1:
pi = unsqueeze(nd.softmax(self.linear_pi(X_end), axis=1), -1)
pi_gathered = pi[NX_rep, :, :]
X_softmax = nd.sum(X_softmax * pi_gathered, axis=1)
X_softmax_end = nd.sum(X_softmax_end * pi, axis=1)
else:
X_softmax = squeeze(X_softmax, 1)
X_softmax_end = squeeze(X_softmax_end, 1)
# generate output probabilities
connect, append = X_softmax[:, :self.N_B], X_softmax[:, self.N_B:]
append = append.reshape([-1, self.N_A, self.N_B])
end = squeeze(X_softmax_end, -1)
return append, connect, end
class BatchNorm(nn.Block):
def __init__(self, in_channels, momentum=0.9, eps=1e-5):
super(BatchNorm, self).__init__()
self.F = in_channels
self.bn_weight = self.params.get('bn_weight', shape=(self.F,), init=mx.init.One(),
allow_deferred_init=False)
self.bn_bias = self.params.get('bn_bias', shape=(self.F,), init=mx.init.Zero(),
allow_deferred_init=False)
self.running_mean = self.params.get('running_mean', grad_req='null',
shape=(self.F,),
init=mx.init.Zero(),
allow_deferred_init=False,
differentiable=False)
self.running_var = self.params.get('running_var', grad_req='null',
shape=(self.F,),
init=mx.init.One(),
allow_deferred_init=False,
differentiable=False)
self.momentum = momentum
self.eps = eps
    def forward(self, x):
        # batch statistics during training, running averages at inference
        return nd.BatchNorm(x,
                            gamma=self.bn_weight.data(x.context),
                            beta=self.bn_bias.data(x.context),
                            moving_mean=self.running_mean.data(x.context),
                            moving_var=self.running_var.data(x.context),
                            eps=self.eps, momentum=self.momentum,
                            use_global_stats=not autograd.is_training())
class MoleculeGenerator(nn.Block, metaclass=ABCMeta):
def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
*args, **kwargs):
super(MoleculeGenerator, self).__init__()
self.N_A = N_A
self.N_B = N_B
self.D = D
self.F_e = F_e
self.F_skip = F_skip
self.F_c = list(F_c) if isinstance(F_c, tuple) else F_c
self.Fh_policy = Fh_policy
self.activation = get_activation(activation)
with self.name_scope():
# embeddings
self.embedding_atom = nn.Embedding(self.N_A, self.F_e)
self.embedding_mask = nn.Embedding(3, self.F_e)
# graph conv
self._build_graph_conv(*args, **kwargs)
# fully connected
self.dense = nn.Sequential()
for i, (f_in, f_out) in enumerate(zip([self.F_skip, ] + self.F_c[:-1], self.F_c)):
self.dense.add(Linear_BN(f_in, f_out))
# policy
self.policy_0 = self.params.get('policy_0', shape=[self.N_A, ],
init=mx.init.Zero(),
allow_deferred_init=False)
self.policy_h = Policy(self.F_c[-1], self.Fh_policy, self.N_A, self.N_B)
self.mode = 'loss'
@abstractmethod
def _build_graph_conv(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
def _graph_conv_forward(self, X, A):
raise NotImplementedError
def _policy_0(self, ctx):
policy_0 = nd.exp(self.policy_0.data(ctx))
policy_0 = policy_0/policy_0.sum()
return policy_0
def _policy(self, X, A, NX, NX_rep, last_append_mask):
# get initial embedding
X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
# convolution
X = self._graph_conv_forward(X, A)
# linear
X = self.dense(X)
# policy
append, connect, end = self.policy_h(X, NX, NX_rep)
return append, connect, end
def _likelihood(self, init, append, connect, end,
action_0, actions, iw_ids, log_p_sigma,
batch_size, iw_size):
# decompose action:
action_type, node_type, edge_type, append_pos, connect_pos = \
actions[:, 0], actions[:, 1], actions[:, 2], actions[:, 3], actions[:, 4]
        # masked log-probability: log(x) where mask == 1, zero elsewhere
        _log_mask = lambda _x, _mask: _mask * nd.log(_x + 1e-10)
# init
init = init.reshape([batch_size * iw_size, self.N_A])
index = nd.stack(nd.arange(action_0.shape[0], ctx=action_0.context, dtype='int64'), action_0, axis=0)
loss_init = nd.log(nd.gather_nd(init, index) + 1e-10)
# end
loss_end = _log_mask(end, nd.cast(action_type == 2, 'float32'))
# append
index = nd.stack(append_pos, node_type, edge_type, axis=0)
loss_append = _log_mask(nd.gather_nd(append, index), nd.cast(action_type == 0, 'float32'))
# connect
index = nd.stack(connect_pos, edge_type, axis=0)
loss_connect = _log_mask(nd.gather_nd(connect, index), nd.cast(action_type == 1, 'float32'))
# sum up results
log_p_x = loss_end + loss_append + loss_connect
log_p_x = squeeze(SegmentSumFn(iw_ids, batch_size*iw_size)(unsqueeze(log_p_x, -1)), -1)
log_p_x = log_p_x + loss_init
# reshape
log_p_x = log_p_x.reshape([batch_size, iw_size])
log_p_sigma = log_p_sigma.reshape([batch_size, iw_size])
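        # importance-weighted estimate: log (1/K) * sum_k p(x, route_k)/q(route_k), K = iw_size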
l = log_p_x - log_p_sigma
l = logsumexp(l, axis=1) - math.log(float(iw_size))
return l
def forward(self, *input):
if self.mode=='loss' or self.mode=='likelihood':
X, A, iw_ids, last_append_mask, \
NX, NX_rep, action_0, actions, log_p, \
batch_size, iw_size = input
init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask)
l = self._likelihood(init, append, connect, end, action_0, actions, iw_ids, log_p, batch_size, iw_size)
if self.mode=='likelihood':
return l
else:
return -l.mean()
elif self.mode == 'decode_0':
return self._policy_0(input[0])
        elif self.mode == 'decode_step':
            X, A, NX, NX_rep, last_append_mask = input
            return self._policy(X, A, NX, NX_rep, last_append_mask)
        else:
            raise ValueError('Unknown mode: {}'.format(self.mode))
class MoleculeGenerator_RNN(MoleculeGenerator):
def __init__(self, N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
N_rnn, *args, **kwargs):
super(MoleculeGenerator_RNN, self).__init__(N_A, N_B, D, F_e, F_skip, F_c, Fh_policy, activation,
*args, **kwargs)
self.N_rnn = N_rnn
with self.name_scope():
self.rnn = gluon.rnn.GRU(hidden_size=self.F_c[-1],
num_layers=self.N_rnn,
layout='NTC', input_size=self.F_c[-1] * 2)
def _rnn_train(self, X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum):
X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
X_curr = nd.take(X, indices=NX_cum-1)
X = nd.concat(X_avg, X_curr, dim=1)
# rnn
X = nd.take(X, indices=graph_to_rnn) # batch_size, iw_size, length, num_features
batch_size, iw_size, length, num_features = X.shape
X = X.reshape([batch_size*iw_size, length, num_features])
X = self.rnn(X)
X = X.reshape([batch_size, iw_size, length, -1])
X = nd.gather_nd(X, indices=rnn_to_graph)
return X
def _rnn_test(self, X, NX, NX_rep, NX_cum, h):
# note: one partition for one molecule
X_avg = SegmentSumFn(NX_rep, NX.shape[0])(X) / nd.cast(unsqueeze(NX, 1), 'float32')
X_curr = nd.take(X, indices=NX_cum - 1)
X = nd.concat(X_avg, X_curr, dim=1) # size: [NX, F_in * 2]
# rnn
X = unsqueeze(X, axis=1)
X, h = self.rnn(X, h)
X = squeeze(X, axis=1)
return X, h
def _policy(self, X, A, NX, NX_rep, last_append_mask, graph_to_rnn, rnn_to_graph, NX_cum):
# get initial embedding
X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
# convolution
X = self._graph_conv_forward(X, A)
# linear
X = self.dense(X)
# rnn
X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)
# policy
append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
return append, connect, end
def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h):
# get initial embedding
X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
# convolution
X = self._graph_conv_forward(X, A)
# linear
X = self.dense(X)
# rnn
X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)
# policy
append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
return append, connect, end, h
def forward(self, *input):
if self.mode=='loss' or self.mode=='likelihood':
X, A, iw_ids, last_append_mask, \
NX, NX_rep, action_0, actions, log_p, \
batch_size, iw_size, \
graph_to_rnn, rnn_to_graph, NX_cum = input
init = self._policy_0(X.context).tile([batch_size * iw_size, 1])
append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask, graph_to_rnn, rnn_to_graph, NX_cum)
l = self._likelihood(init, append, connect, end, action_0, actions, iw_ids, log_p, batch_size, iw_size)
if self.mode=='likelihood':
return l
else:
return -l.mean()
elif self.mode == 'decode_0':
return self._policy_0(input[0])
elif self.mode == 'decode_step':
X, A, NX, NX_rep, last_append_mask, NX_cum, h = input
return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h)
        else:
            raise ValueError('Unknown mode: {}'.format(self.mode))
class _TwoLayerDense(nn.Block):
def __init__(self, input_size, hidden_size, output_size):
super(_TwoLayerDense, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.input_size = input_size
with self.name_scope():
# config 1
self.input = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
self.bn_input = BatchNorm(in_channels=hidden_size)
self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)
# config 2
#self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)
# config 3
#self.input1 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.input_size)
#self.bn_input1 = BatchNorm(in_channels=self.hidden_size)
#self.input2 = nn.Dense(self.hidden_size, use_bias=False, in_units=self.hidden_size)
#self.bn_input2 = BatchNorm(in_channels=self.hidden_size)
#self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.hidden_size)
# config 4
#self.bn_input = BatchNorm(in_channels=self.input_size)
#self.output = nn.Dense(self.output_size, use_bias=True, in_units=self.input_size)
# config 5
#self.bn_input = BatchNorm(in_channels=1024)
#self.output = nn.Dense(self.output_size, use_bias=True, in_units=1024)
def forward(self, c):
# config 1
return nd.softmax(self.output(nd.relu(self.bn_input(self.input(c)))), axis=-1)
# config 2
#return nd.softmax(self.output(c), axis=-1)
# config 3
#return nd.softmax(self.output(nd.relu(self.bn_input2(self.input2(nd.relu(self.bn_input1(self.input1(c))))))), axis=-1)
# config 4
#return nd.softmax(self.output(nd.relu(self.bn_input(c))), axis=-1)
# config 5
#return nd.softmax(self.output(c), axis=-1)
class CMoleculeGenerator_RNN(MoleculeGenerator_RNN):
def __init__(self, N_A, N_B, N_C, D,
F_e, F_skip, F_c, Fh_policy,
activation, N_rnn,
*args, **kwargs):
self.N_C = N_C # number of conditional variables
super(CMoleculeGenerator_RNN, self).__init__(N_A, N_B, D,
F_e, F_skip, F_c, Fh_policy,
activation, N_rnn,
*args, **kwargs)
with self.name_scope():
self.dense_policy_0 = _TwoLayerDense(self.N_C, self.N_A * 3, self.N_A)
@abstractmethod
def _graph_conv_forward(self, X, A, c, ids):
raise NotImplementedError
def _policy_0(self, c):
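        # the 0.0 * term ties the (otherwise unused) inherited policy_0 parameter into the graph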
return self.dense_policy_0(c) + 0.0 * self.policy_0.data(c.context)
def _policy(self, X, A, NX, NX_rep, last_append_mask,
graph_to_rnn, rnn_to_graph, NX_cum,
c, ids):
# get initial embedding
X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
# convolution
X = self._graph_conv_forward(X, A, c, ids)
# linear
X = self.dense(X)
# rnn
X_mol = self._rnn_train(X, NX, NX_rep, graph_to_rnn, rnn_to_graph, NX_cum)
# policy
append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
return append, connect, end
def _decode_step(self, X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids):
# get initial embedding
X = self.embedding_atom(X) + self.embedding_mask(last_append_mask)
# convolution
X = self._graph_conv_forward(X, A, c, ids)
# linear
X = self.dense(X)
# rnn
X_mol, h = self._rnn_test(X, NX, NX_rep, NX_cum, h)
# policy
append, connect, end = self.policy_h(X, NX, NX_rep, X_mol)
return append, connect, end, h
def forward(self, *input):
if self.mode=='loss' or self.mode=='likelihood':
X, A, iw_ids, last_append_mask, \
NX, NX_rep, action_0, actions, log_p, \
batch_size, iw_size, \
graph_to_rnn, rnn_to_graph, NX_cum, \
c, ids = input
init = nd.tile(unsqueeze(self._policy_0(c), axis=1), [1, iw_size, 1])
append, connect, end = self._policy(X, A, NX, NX_rep, last_append_mask,
graph_to_rnn, rnn_to_graph, NX_cum,
c, ids)
l = self._likelihood(init, append, connect, end,
action_0, actions, iw_ids, log_p,
batch_size, iw_size)
if self.mode=='likelihood':
return l
else:
return -l.mean()
elif self.mode == 'decode_0':
return self._policy_0(*input)
elif self.mode == 'decode_step':
X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids = input
return self._decode_step(X, A, NX, NX_rep, last_append_mask, NX_cum, h, c, ids)
        else:
            raise ValueError('Unknown mode: {}'.format(self.mode))
class CVanillaMolGen_RNN(CMoleculeGenerator_RNN):
def __init__(self, N_A, N_B, N_C, D,
F_e, F_h, F_skip, F_c, Fh_policy,
activation, N_rnn, rename=False):
self.rename = rename
super(CVanillaMolGen_RNN, self).__init__(N_A, N_B, N_C, D,
F_e, F_skip, F_c, Fh_policy,
activation, N_rnn,
F_h)
def _build_graph_conv(self, F_h):
self.F_h = list(F_h) if isinstance(F_h, tuple) else F_h
self.conv, self.bn = [], []
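        # each GraphConv sees N_B bond-type adjacencies plus the D distance matrices (D_2, D_3)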
for i, (f_in, f_out) in enumerate(zip([self.F_e] + self.F_h[:-1], self.F_h)):
conv = GraphConv(f_in, f_out, self.N_B + self.D)
self.conv.append(conv)
self.register_child(conv)
if i != 0:
bn = BatchNorm(in_channels=f_in)
self.register_child(bn)
else:
bn = None
self.bn.append(bn)
self.bn_skip = BatchNorm(in_channels=sum(self.F_h))
self.linear_skip = Linear_BN(sum(self.F_h), self.F_skip)
# projectors for conditional variable (protein embedding)
self.linear_c = []
for i, f_out in enumerate(self.F_h):
if self.rename:
linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C, prefix='cond_{}'.format(i))
else:
linear_c = nn.Dense(f_out, use_bias=False, in_units=self.N_C)
self.register_child(linear_c)
self.linear_c.append(linear_c)
def _graph_conv_forward(self, X, A, c, ids):
X_out = [X]
for conv, bn, linear_c in zip(self.conv, self.bn, self.linear_c):
X = X_out[-1]
if bn is not None:
X_out.append(conv(self.activation(bn(X)), A) + linear_c(c)[ids, :])
else:
X_out.append(conv(X, A) + linear_c(c)[ids, :])
X_out = nd.concat(*X_out[1:], dim=1)
return self.activation(self.linear_skip(self.activation(self.bn_skip(X_out))))
def _decode_step(X, A, NX, NA, last_action, finished,
get_init, get_action,
random=False, n_node_types=get_mol_spec().num_atom_types,
n_edge_types=get_mol_spec().num_bond_types):
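    """One step of batched, NumPy-side autoregressive decoding.

    On the first call (X is None) the initial atom of every graph is drawn from
    `get_init()`; on later calls `get_action` scores append/connect/end actions
    for each unfinished graph, which are applied greedily or by sampling."""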
if X is None:
init = get_init()
if random:
X = []
for i in range(init.shape[0]):
# init probabilities(for first atom)
p = init[i, :]
# Random sampling using init probability distribution
selected_atom = np.random.choice(np.arange(init.shape[1]), 1, p=p)[0]
X.append(selected_atom)
X = np.array(X, dtype=np.int64)
else:
X = np.argmax(init, axis=1)
A = np.zeros((0, 3), dtype=np.int64)
NX = last_action = np.ones([X.shape[0]], dtype=np.int64)
NA = np.zeros([X.shape[0]], dtype=np.int64)
        finished = np.array([False] * X.shape[0], dtype=bool)
return X, A, NX, NA, last_action, finished
else:
X_u = X[np.repeat(np.logical_not(finished), NX)]
A_u = A[np.repeat(np.logical_not(finished), NA), :]
NX_u = NX[np.logical_not(finished)]
NA_u = NA[np.logical_not(finished)]
last_action_u = last_action[np.logical_not(finished)]
# conv
mol_ids_rep = NX_rep = np.repeat(np.arange(NX_u.shape[0]), NX_u)
rep_ids_rep = np.zeros_like(mol_ids_rep)
if A.shape[0] == 0:
D_2 = D_3 = np.zeros((0, 2), dtype=np.int64)
A_u = [np.zeros((0, 2), dtype=np.int64) for _ in range(get_mol_spec().num_bond_types)]
A_u += [D_2, D_3]
else:
cumsum = np.cumsum(np.pad(NX_u, [[1, 0]], mode='constant')[:-1])
shift = np.repeat(cumsum, NA_u)
A_u[:, :2] += np.stack([shift, ] * 2, axis=1)
D_2, D_3 = get_d(A_u, X_u)
A_u = [A_u[A_u[:, 2] == _i, :2] for _i in range(n_edge_types)]
A_u += [D_2, D_3]
mask = np.zeros([X_u.shape[0]], dtype=np.int64)
last_append_index = np.cumsum(NX_u) - 1
mask[last_append_index] = np.where(last_action_u == 1,
np.ones_like(last_append_index, dtype=np.int64),
np.ones_like(last_append_index, dtype=np.int64) * 2)
decode_input = [X_u, A_u, NX_u, NX_rep, mask, mol_ids_rep, rep_ids_rep]
append, connect, end = get_action(decode_input)
if A.shape[0] == 0:
max_index = np.argmax(np.reshape(append, [-1, n_node_types * n_edge_types]), axis=1)
atom_type, bond_type = np.unravel_index(max_index, [n_node_types, n_edge_types])
X = np.reshape(np.stack([X, atom_type], axis=1), [-1])
NX = np.array([2, ] * len(finished), dtype=np.int64)
A = np.stack([np.zeros([len(finished), ], dtype=np.int64),
np.ones([len(finished), ], dtype=np.int64),
bond_type], axis=1)
NA = np.ones([len(finished), ], dtype=np.int64)
last_action = np.ones_like(NX, dtype=np.int64)
else:
# process for each molecule
append, connect = np.split(append, np.cumsum(NX_u)), np.split(connect, np.cumsum(NX_u))
end = end.tolist()
unfinished_ids = np.where(np.logical_not(finished))[0].tolist()
cumsum = np.cumsum(NX)
cumsum_a = np.cumsum(NA)
X_insert = []
X_insert_ids = []
A_insert = []
A_insert_ids = []
finished_ids = []
for i, (unfinished_id, append_i, connect_i, end_i) \
in enumerate(zip(unfinished_ids, append, connect, end)):
if random:
def _rand_id(*_x):
_x_reshaped = [np.reshape(_xi, [-1]) for _xi in _x]
_x_length = np.array([_x_reshape_i.shape[0] for _x_reshape_i in _x_reshaped],
dtype=np.int64)
_begin = np.cumsum(np.pad(_x_length, [[1, 0]], mode='constant')[:-1])
_end = np.cumsum(_x_length) - 1
_p = np.concatenate(_x_reshaped)
_p = _p / np.sum(_p)
# Count NaN values
num_nan = np.isnan(_p).sum()
if num_nan > 0:
print(f'Number of NaN values in _p: {num_nan}')
_rand_index = np.random.choice(np.arange(len(_p)), 1)[0]
else:
_rand_index = np.random.choice(np.arange(_p.shape[0]), 1, p=_p)[0]
_p_step = _p[_rand_index]
_x_index = np.where(np.logical_and(_begin <= _rand_index, _end >= _rand_index))[0][0]
_rand_index = _rand_index - _begin[_x_index]
_rand_index = np.unravel_index(_rand_index, _x[_x_index].shape)
return _x_index, _rand_index, _p_step
action_type, action_index, p_step = _rand_id(append_i, connect_i, np.array([end_i]))
else:
_argmax = lambda _x: np.unravel_index(np.argmax(_x), _x.shape)
append_id, append_val = _argmax(append_i), np.max(append_i)
connect_id, connect_val = _argmax(connect_i), np.max(connect_i)
end_val = end_i
if end_val >= append_val and end_val >= connect_val:
action_type = 2
action_index = None
elif append_val >= connect_val and append_val >= end_val:
action_type = 0
action_index = append_id
else:
action_type = 1
action_index = connect_id
if action_type == 2:
# finish growth
finished_ids.append(unfinished_id)
elif action_type == 0:
# append action
append_pos, atom_type, bond_type = action_index
X_insert.append(atom_type)
X_insert_ids.append(unfinished_id)
A_insert.append([append_pos, NX[unfinished_id], bond_type])
A_insert_ids.append(unfinished_id)
else:
# connect
connect_ps, bond_type = action_index
A_insert.append([NX[unfinished_id] - 1, connect_ps, bond_type])
A_insert_ids.append(unfinished_id)
if len(A_insert_ids) > 0:
A = np.insert(A, cumsum_a[A_insert_ids], A_insert, axis=0)
NA[A_insert_ids] += 1
last_action[A_insert_ids] = 0
if len(X_insert_ids) > 0:
X = np.insert(X, cumsum[X_insert_ids], X_insert, axis=0)
NX[X_insert_ids] += 1
last_action[X_insert_ids] = 1
if len(finished_ids) > 0:
finished[finished_ids] = True
# print finished
return X, A, NX, NA, last_action, finished
class Builder(metaclass=ABCMeta):
def __init__(self, model_loc, gpu_id=None):
with open(os.path.join(model_loc, 'configs.json')) as f:
configs = json.load(f)
self.mdl = self.__class__._get_model(configs)
self.ctx = mx.gpu(gpu_id) if gpu_id is not None else mx.cpu()
self.mdl.load_parameters(os.path.join(model_loc, 'ckpt.params'), ctx=self.ctx, allow_missing=True)
@staticmethod
def _get_model(configs):
raise NotImplementedError
@abstractmethod
def sample(self, num_samples, *args, **kwargs):
raise NotImplementedError
class CVanilla_RNN_Builder(Builder):
@staticmethod
def _get_model(configs):
return CVanillaMolGen_RNN(get_mol_spec().num_atom_types, get_mol_spec().num_bond_types, D=2, **configs)
def sample(self, num_samples, c, output_type='mol', sanitize=True, random=True):
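        """Sample `num_samples` molecules conditioned on the embedding `c`.

        `c` has shape [N_C] (broadcast to all samples) or [num_samples, N_C];
        `output_type` is one of 'mol', 'graph' or 'smiles'."""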
if len(c.shape) == 1:
c = np.stack([c, ]*num_samples, axis=0)
with autograd.predict_mode():
# step one
finished = [False, ] * num_samples
def get_init():
self.mdl.mode = 'decode_0'
_c = nd.array(c, dtype='float32', ctx=self.ctx)
init = self.mdl(_c).asnumpy()
return init
outputs = _decode_step(X=None, A=None, NX=None, NA=None, last_action=None, finished=finished,
get_init=get_init, get_action=None,
n_node_types=self.mdl.N_A, n_edge_types=self.mdl.N_B,
random=random)
            # defensive: stop if the initial decode step produced nothing
if outputs is None:
return None
X, A, NX, NA, last_action, finished = outputs
count = 1
h = np.zeros([self.mdl.N_rnn, num_samples, self.mdl.F_c[-1]], dtype=np.float32)
while not np.all(finished) and count < 100:
def get_action(inputs):
self.mdl.mode = 'decode_step'
_h = nd.array(h[:, np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
_c = nd.array(c[np.logical_not(finished), :], ctx=self.ctx, dtype='float32')
_X, _A_sparse, _NX, _NX_rep, _mask, _NX_cum = self.to_nd(inputs)
_append, _connect, _end, _h = self.mdl(_X, _A_sparse, _NX, _NX_rep, _mask, _NX_cum, _h, _c, _NX_rep)
h[:, np.logical_not(finished), :] = _h[0].asnumpy()
return _append.asnumpy(), _connect.asnumpy(), _end.asnumpy()
outputs = _decode_step(X, A, NX, NA, last_action, finished,
get_init=None, get_action=get_action,
n_node_types=self.mdl.N_A, n_edge_types=self.mdl.N_B,
random=random)
X, A, NX, NA, last_action, finished = outputs
count += 1
graph_list = []
cumsum_X_ = np.cumsum(np.pad(NX, [[1, 0]], mode='constant')).tolist()
cumsum_A_ = np.cumsum(np.pad(NA, [[1, 0]], mode='constant')).tolist()
for cumsum_A_pre, cumsum_A_post, \
cumsum_X_pre, cumsum_X_post in zip(cumsum_A_[:-1], cumsum_A_[1:],
cumsum_X_[:-1], cumsum_X_[1:]):
graph_list.append([X[cumsum_X_pre:cumsum_X_post], A[cumsum_A_pre:cumsum_A_post, :]])
if output_type=='graph':
return graph_list
elif output_type == 'mol':
return get_mol_from_graph_list(graph_list, sanitize)
elif output_type == 'smiles':
mol_list = get_mol_from_graph_list(graph_list, sanitize=True)
smiles_list = [Chem.MolToSmiles(m) if m is not None else None for m in mol_list]
return smiles_list
else:
raise ValueError('Unrecognized output type')
def to_nd(self, inputs):
X, A, NX, NX_rep, mask = inputs[:-2]
NX_cum = np.cumsum(NX)
# convert to ndarray
_to_ndarray = lambda _x: nd.array(_x, self.ctx, 'int64')
X, NX, NX_rep, mask, NX_cum = \
_to_ndarray(X), _to_ndarray(NX), _to_ndarray(NX_rep), _to_ndarray(mask), _to_ndarray(NX_cum)
A_sparse = []
for _A_i in A:
if _A_i.shape[0] == 0:
A_sparse.append(None)
else:
                # symmetrize edges explicitly (sparse transpose may not be supported on GPU)
                _A_i = np.concatenate([_A_i, _A_i[:, [1, 0]]], axis=0)
# construct csr matrix ...
_data = np.ones((_A_i.shape[0],), dtype=np.float32)
_row, _col = _A_i[:, 0], _A_i[:, 1]
_A_sparse_i = nd.sparse.csr_matrix((_data, (_row, _col)),
shape=tuple([int(X.shape[0]), ] * 2),
ctx=self.ctx, dtype='float32')
# append to list
A_sparse.append(_A_sparse_i)
return X, A_sparse, NX, NX_rep, mask, NX_cum
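# Illustrative usage (assumes a trained model directory containing configs.json
# and ckpt.params, and a conditional embedding `c` of length N_C):
#   builder = CVanilla_RNN_Builder('models_folder', gpu_id=None)
#   smiles = builder.sample(10, c, output_type='smiles', random=True)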