# -*- coding:utf-8 -*-
# Rhizome
# Version beta 0.0, August 2023
# Property of IBM Research, Accelerated Discovery
#
"""
PLEASE NOTE THIS IMPLEMENTATION INCLUDES ADAPTED SOURCE CODE
OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL, WHICH IS ALREADY PUBLICLY AVAILABLE,
E.G., GRUEncoder/GRUDecoder, GrammarSeq2SeqVAE AND EVEN SOME METHODS OF GrammarGINVAE.
THIS MIGHT INFLUENCE THE DECISION ON THE FINAL LICENSE, SO A CAREFUL CHECK NEEDS TO BE DONE.
"""
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.modules.loss import _Loss
from torch_geometric.nn import MessagePassing, global_add_pool

from ..graph_grammar.graph_grammar.symbols import NTSymbol
from ..graph_grammar.nn.decoder import DecoderBase
from ..graph_grammar.nn.encoder import EncoderBase


def get_atom_edge_feature_dims():
    """Return the number of categories of each atom feature and each bond feature."""
    from torch_geometric.utils.smiles import x_map, e_map
    return [len(choices) for choices in x_map.values()], [len(choices) for choices in e_map.values()]
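# Note (added for clarity, not part of the original code): `x_map` and `e_map` are the
# categorical vocabularies used by torch_geometric's SMILES featurizer, so the two
# returned lists hold the per-column category counts for node features (atomic number,
# chirality, degree, ...) and for edge features (bond type, stereo, conjugation).
# The exact lengths depend on the installed torch_geometric version.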


class FeatureEmbedding(nn.Module):

    def __init__(self, input_dims, embedded_dim):
        super().__init__()
        self.embedding_list = nn.ModuleList()
        for dim in input_dims:
            embedding = nn.Embedding(dim, embedded_dim)
            self.embedding_list.append(embedding)

    def forward(self, x):
        output = 0
        for i in range(x.shape[1]):
            input = x[:, i].to(torch.int)
            device = next(self.parameters()).device
            if device != input.device:
                input = input.to(device)
            emb = self.embedding_list[i](input)
            output += emb
        return output
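# Minimal usage sketch (illustrative only; the dimensions and values below are
# placeholders, not taken from the code above):
#
#     embed = FeatureEmbedding([119, 9, 11], embedded_dim=64)
#     x = torch.tensor([[6, 0, 4], [8, 0, 2]])   # two nodes, three categorical columns
#     out = embed(x)                             # shape (2, 64): per-column embeddings summed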


class GRUEncoder(EncoderBase):

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=self.dropout if not no_dropout else 0)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.model = self.model.to(rank)
            else:
                # support Mac MPS
                self.model = self.model.to(torch.device("mps", rank))
        self.init_hidden(self.batch_size)
    def init_hidden(self, bsize):
        self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers,
                               min(self.batch_size, bsize),
                               self.hidden_dim),
                              requires_grad=False)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.h0 = self.h0.to(self.rank)
            else:
                # support Mac MPS
                self.h0 = self.h0.to(torch.device("mps", self.rank))

    def to(self, device):
        newself = super().to(device)
        newself.model = newself.model.to(device)
        newself.h0 = newself.h0.to(device)
        newself.rank = next(newself.parameters()).get_device()
        return newself
    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        # Kishi: I think the original MHG had this init_hidden()
        self.init_hidden(in_seq_emb.size(0))
        max_len = in_seq_emb.size(1)
        hidden_seq_emb, self.h0 = self.model(in_seq_emb, self.h0)
        # Reshape hidden_seq_emb: (batch_size, seq_len, (1 or 2) * hidden_size)
        # --> (batch_size, seq_len, 1 or 2, hidden_size).
        # A bidirectional GRU/LSTM returns the forward and backward outputs concatenated
        # along the last dimension (first half forward, second half backward), so unpack
        # them into a separate per-direction axis.
        hidden_seq_emb = hidden_seq_emb.view(-1,
                                             max_len,
                                             1 + self.bidirectional,
                                             self.hidden_dim)
        return hidden_seq_emb
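# Minimal usage sketch (illustrative only; all sizes are placeholders):
#
#     enc = GRUEncoder(input_dim=128, hidden_dim=384, num_layers=3,
#                      bidirectional=True, dropout=0.1, batch_size=64)
#     seq = torch.randn(64, 80, 128)    # (batch, max_len, input_dim)
#     h = enc(seq)                      # (64, 80, 2, 384)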


class GRUDecoder(DecoderBase):

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, rank: int=-1,
                 no_dropout: bool=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.rank = rank
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=self.dropout if not no_dropout else 0)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.model = self.model.to(self.rank)
            else:
                # support Mac MPS
                self.model = self.model.to(torch.device("mps", self.rank))
        self.init_hidden(self.batch_size)
    def init_hidden(self, bsize):
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             min(self.batch_size, bsize),
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.rank >= 0:
            if torch.cuda.is_available():
                self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank)
            else:
                # support Mac MPS
                self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank))

    def to(self, device):
        newself = super().to(device)
        newself.model = newself.model.to(device)
        for k in self.hidden_dict.keys():
            newself.hidden_dict[k] = newself.hidden_dict[k].to(device)
        newself.rank = next(newself.parameters()).get_device()
        return newself

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, hidden_dim)
        '''
        bsize = tgt_emb_in.size(0)
        tgt_emb_out, self.hidden_dict['h'] \
            = self.model(tgt_emb_in.view(bsize, 1, -1),
                         self.hidden_dict['h'])
        return tgt_emb_out


class NodeMLP(nn.Module):

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.nbat = nn.BatchNorm1d(hidden_size)
        self.lin2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.lin1(x)
        x = self.nbat(x)
        x = x.relu()
        x = self.lin2(x)
        return x


class GINLayer(MessagePassing):

    def __init__(self, node_input_size, node_output_size, node_hidden_size, edge_input_size):
        super().__init__()
        self.node_mlp = NodeMLP(node_input_size, node_output_size, node_hidden_size)
        self.edge_mlp = FeatureEmbedding(edge_input_size, node_output_size)
        self.eps = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x, edge_index, edge_attr):
        msg = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        x = (1.0 + self.eps) * x + msg
        x = x.relu()
        x = self.node_mlp(x)
        return x

    def message(self, x_j, edge_attr):
        edge_attr = self.edge_mlp(edge_attr)
        x_j = x_j + edge_attr
        x_j = x_j.relu()
        return x_j

    def update(self, aggr_out):
        return aggr_out
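# Note (added for clarity, not part of the original code): GINLayer implements an
# edge-aware GIN update in which bond embeddings are injected into the messages,
#
#     h_i' = MLP( ReLU( (1 + eps) * h_i + sum_{j in N(i)} ReLU(h_j + e_ij) ) ),
#
# which is close to the GINE update of Hu et al. (2020), "Strategies for Pre-training
# Graph Neural Networks"; `eps` is a learnable scalar and the sum comes from the
# default "add" aggregation of MessagePassing.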


#TODO implement the case where features of atoms and edges are considered
# Check GraphMVP and ogb (open graph benchmark) to realize this
class GIN(torch.nn.Module):

    def __init__(self, node_feature_size, edge_feature_size, hidden_channels=64,
                 proximity_size=3, dropout=0.1):
        super().__init__()
        #print("(num node features, num edge features)=", (node_feature_size, edge_feature_size))
        hsize = hidden_channels * 2
        atom_dim, edge_dim = get_atom_edge_feature_dims()
        self.trans = FeatureEmbedding(atom_dim, hidden_channels)
        ml = []
        for _ in range(proximity_size):
            ml.append(GINLayer(hidden_channels, hidden_channels, hsize, edge_dim))
        self.mlist = nn.ModuleList(ml)
        # It is possible to calculate relu with x.relu() where x is an output
        #self.activations = nn.ModuleList(actl)
        self.dropout = dropout
        self.proximity_size = proximity_size

    def forward(self, x, edge_index, edge_attr, batch_size):
        # NOTE: despite its name, `batch_size` is the PyG batch-assignment vector
        # (one graph id per node), as expected by global_add_pool.
        x = x.to(torch.float)
        #print("before: edge_weight.shape=", edge_attr.shape)
        edge_attr = edge_attr.to(torch.float)
        #print("after: edge_weight.shape=", edge_attr.shape)
        x = self.trans(x)
        # TODO Check if this x is consistent with global_add_pool
        hlist = [global_add_pool(x, batch_size)]
        for id, m in enumerate(self.mlist):
            x = m(x, edge_index=edge_index, edge_attr=edge_attr)
            #print("Done with one layer")
            ###if id != self.proximity_size - 1:
            x = x.relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
            #h = global_mean_pool(x, batch_size)
            h = global_add_pool(x, batch_size)
            hlist.append(h)
            #print("Done with one relu call: x.shape=", x.shape)
            #print("calling global mean pool")
            #print("calling dropout x.shape=", x.shape)
            #print("x=", x)
        #print("hlist[0].shape=", hlist[0].shape)
        x = torch.cat(hlist, dim=1)
        #print("x.shape=", x.shape)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x
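# Minimal usage sketch (illustrative only; assumes rdkit and torch_geometric are
# installed and that `from_smiles` is available in the installed torch_geometric):
#
#     from torch_geometric.utils import from_smiles
#     from torch_geometric.data import Batch
#
#     batch = Batch.from_data_list([from_smiles('CCO'), from_smiles('c1ccccc1')])
#     gin = GIN(node_feature_size=9, edge_feature_size=3)
#     emb = gin(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
#     # emb has shape (2, hidden_channels * (1 + proximity_size)) = (2, 256) with defaults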


# TODO copied from MHG implementation and adapted here.
class GrammarSeq2SeqVAE(nn.Module):
    '''
    Variational seq2seq with grammar.
    TODO: rewrite this class using mixin
    '''

    def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True,
                                 'dropout': 0.1},
                 decoder_params={'hidden_dim': 384, #'num_layers': 2,
                                 'num_layers': 3,
                                 'dropout': 0.1},
                 prod_rule_embed_params={'out_dim': 128},
                 no_dropout=False):
        super().__init__()
        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout

        self.latent_dim = latent_dim
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        # use MolecularProdRuleEmbedding later on
        self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        # Use a GRU-based encoder as in MHG
        self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout,
                                  **self.encoder_params)
        lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
        lin_out_dim = self.latent_dim
        self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False)
        self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim)

        # Use a GRU-based decoder as in MHG
        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        self.latent2hidden_dict = nn.ModuleDict()
        dec_lin_out_dim = self.decoder_params['hidden_dim']
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support Mac MPS
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)

        # TODO Do we need this?
        if hasattr(self.src_embedding, 'weight'):
            self.src_embedding.weight.data.uniform_(-0.1, 0.1)
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)
    def to(self, device):
        newself = super().to(device)
        newself.src_embedding = newself.src_embedding.to(device)
        newself.tgt_embedding = newself.tgt_embedding.to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.dec2vocab = newself.dec2vocab.to(device)
        newself.hidden2mean = newself.hidden2mean.to(device)
        newself.hidden2logvar = newself.hidden2logvar.to(device)
        newself.latent2tgt_emb = newself.latent2tgt_emb.to(device)
        newself.latent2hidden_dict = newself.latent2hidden_dict.to(device)
        return newself
    def forward(self, in_seq, out_seq):
        ''' forward model

        Parameters
        ----------
        in_seq : Variable, shape (batch_size, length)
            each element corresponds to a word index,
            where the index should be less than `vocab_size`
        out_seq : Variable, shape (batch_size, length)
            target sequence fed to the decoder (teacher forcing)

        Returns
        -------
        Variable, shape (batch_size, length, vocab_size)
            logit of each word (applying softmax yields the probability),
            together with the mean and log variance of the latent code
        '''
        mu, logvar = self.encode(in_seq)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq), mu, logvar

    def encode(self, in_seq):
        src_emb = self.src_embedding(in_seq)
        src_h = self.encoder.forward(src_emb)
        if self.encoder_params.get('bidirectional', False):
            concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
            return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h)
        else:
            return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :])
    def reparameterize(self, mu, logvar, training=True):
        if training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            device = next(self.parameters()).device
            if eps.device != device:
                eps = eps.to(device)
            return eps.mul(std).add_(mu)
        else:
            return mu
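    # Note (added for clarity, not part of the original code): this is the standard VAE
    # reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I) and
    # sigma = exp(0.5 * logvar), which keeps sampling differentiable w.r.t. mu and logvar.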
    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        self.eval()
        # this class has no init_hidden() of its own; reset the encoder/decoder states instead
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)
        if sample_size == -1:
            sample_size = self.batch_size
        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size, self.latent_dim))).cuda()
            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        else:
            return hg_list
    def decode(self, z=None, out_seq=None, deterministic=True):
        if z is None:
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size, self.latent_dim)))
        if self.rank >= 0:
            z = z.to(next(self.parameters()).device)

        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            # teacher forcing: shift the target sequence right and prepend an embedding of z
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            for each_idx in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
                tgt_emb_pred_list.append(tgt_emb_pred)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return vocab_logit
        else:
            with torch.no_grad():
                tgt_emb = self.latent2tgt_emb(z)
                tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
                tgt_emb_pred_list = []
                stack_list = []
                hg_list = []
                nt_symbol_list = []
                nt_edge_list = []
                gen_finish_list = []
                for _ in range(bsize):
                    stack_list.append([])
                    hg_list.append(None)
                    nt_symbol_list.append(NTSymbol(degree=0,
                                                   is_aromatic=False,
                                                   bond_symbol_list=[]))
                    nt_edge_list.append(None)
                    gen_finish_list.append(False)
                for idx in range(self.max_len):
                    tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                    tgt_emb_pred_list.append(tgt_emb_pred)
                    vocab_logit = self.dec2vocab(tgt_emb_pred)
                    for each_batch_idx in range(bsize):
                        if not gen_finish_list[each_batch_idx]:  # if generation has not finished
                            # sample a production rule (greedily if deterministic)
                            prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                                                                         nt_symbol_list[each_batch_idx],
                                                                         deterministic=deterministic)
                            # convert the production rule into an index
                            tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                            # apply the production rule
                            hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                            # add non-terminals to the stack
                            stack_list[each_batch_idx].extend(nt_edges[::-1])
                            # if the stack size is 0, generation has finished!
                            if len(stack_list[each_batch_idx]) == 0:
                                gen_finish_list[each_batch_idx] = True
                            else:
                                nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                                nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                        else:
                            tgt_id = np.mod(self.padding_idx, self.vocab_size)
                        indice_tensor = torch.LongTensor([tgt_id])
                        device = next(self.parameters()).device
                        if indice_tensor.device != device:
                            indice_tensor = indice_tensor.to(device)
                        tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
                vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
                #for id, v in enumerate(gen_finish_list):
                #    if not v:
                #        print("batch id={} not finished generating a sequence".format(id))
                return gen_finish_list, vocab_logit, hg_list
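# Minimal usage sketch (illustrative only): `hrg` is a HyperedgeReplacementGrammar built
# elsewhere in this package, and `in_seq`/`out_seq` are (batch_size, max_len) LongTensors
# of production-rule indices padded with `padding_idx`:
#
#     model = GrammarSeq2SeqVAE(hrg)
#     logits, mu, logvar = model(in_seq, out_seq)
#     loss = GrammarVAELoss(rank=-1, hrg=hrg, beta=0.01)(mu, logvar, logits, out_seq)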


# TODO A lot of duplicates with GrammarVAE. Clean it up if necessary.
class GrammarGINVAE(nn.Module):
    '''
    Variational autoencoder based on GIN and grammar
    '''

    def __init__(self, hrg, rank=-1, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params={'node_feature_size': 4, 'edge_feature_size': 3,
                                 'hidden_channels': 64, 'proximity_size': 3,
                                 'dropout': 0.1},
                 decoder_params={'hidden_dim': 384, 'num_layers': 3,
                                 'dropout': 0.1},
                 prod_rule_embed_params={'out_dim': 128},
                 no_dropout=False):
        super().__init__()
        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        # use MolecularProdRuleEmbedding later on
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)
        self.encoder = GIN(**self.encoder_params)
        self.latent_dim = self.encoder_params['hidden_channels']
        self.proximity_size = self.encoder_params['proximity_size']
        hidden_dim = self.decoder_params['hidden_dim']
        self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
        self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)

        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        self.latent2hidden_dict = nn.ModuleDict()
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support Mac MPS
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)
        self.decoder.init_hidden(self.batch_size)

        # TODO Do we need this?
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
        self.decoder.init_hidden(self.batch_size)
    def to(self, device):
        newself = super().to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.rank = next(newself.encoder.parameters()).get_device()
        return newself

    def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob=None):
        mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar
    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        self.eval()
        # reset the decoder state (the GIN encoder is stateless)
        self.decoder.init_hidden(self.batch_size)
        if sample_size == -1:
            sample_size = self.batch_size
        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size, self.latent_dim))).cuda()
            _, _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        else:
            return hg_list
    def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
        if z is None:
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size, self.latent_dim)))
        if self.rank >= 0:
            z = z.to(next(self.parameters()).device)

        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            tgt_emb_pred = None
            for each_idx in range(self.max_len):
                # scheduled sampling: with probability sched_prob feed the ground-truth
                # embedding (teacher forcing), otherwise feed the embedding of the
                # previously predicted rule
                if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
                    inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
                else:
                    cur_logit = self.dec2vocab(tgt_emb_pred)
                    yi = torch.argmax(cur_logit, dim=2)
                    inp = self.tgt_embedding(yi)
                tgt_emb_pred = self.decoder.forward_one_step(inp)
                tgt_emb_pred_list.append(tgt_emb_pred)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return vocab_logit
        else:
            with torch.no_grad():
                tgt_emb = self.latent2tgt_emb(z)
                tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
                tgt_emb_pred_list = []
                stack_list = []
                hg_list = []
                nt_symbol_list = []
                nt_edge_list = []
                gen_finish_list = []
                for _ in range(bsize):
                    stack_list.append([])
                    hg_list.append(None)
                    nt_symbol_list.append(NTSymbol(degree=0,
                                                   is_aromatic=False,
                                                   bond_symbol_list=[]))
                    nt_edge_list.append(None)
                    gen_finish_list.append(False)
                for _ in range(self.max_len):
                    tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                    tgt_emb_pred_list.append(tgt_emb_pred)
                    vocab_logit = self.dec2vocab(tgt_emb_pred)
                    for each_batch_idx in range(bsize):
                        if not gen_finish_list[each_batch_idx]:  # if generation has not finished
                            # sample a production rule (greedily if deterministic)
                            prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                                                                         nt_symbol_list[each_batch_idx],
                                                                         deterministic=deterministic)
                            # convert the production rule into an index
                            tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                            # apply the production rule
                            hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                            # add non-terminals to the stack
                            stack_list[each_batch_idx].extend(nt_edges[::-1])
                            # if the stack size is 0, generation has finished!
                            if len(stack_list[each_batch_idx]) == 0:
                                gen_finish_list[each_batch_idx] = True
                            else:
                                nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                                nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                        else:
                            tgt_id = np.mod(self.padding_idx, self.vocab_size)
                        indice_tensor = torch.LongTensor([tgt_id])
                        if self.rank >= 0:
                            indice_tensor = indice_tensor.to(next(self.parameters()).device)
                        tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
                vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
                return gen_finish_list, vocab_logit, hg_list
    #TODO Not tested. Need to implement this in case of molecular structure generation
    def conditional_distribution(self, z, tgt_id_list):
        self.eval()
        # reset the decoder state (this class has no init_hidden() of its own)
        self.decoder.init_hidden(self.batch_size)
        z = z.cuda()
        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        self.decoder.feed_hidden(hidden_dict_0)

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            nt_symbol_list = []
            stack_list = []
            hg_list = []
            nt_edge_list = []
            gen_finish_list = []
            for _ in range(self.batch_size):
                nt_symbol_list.append(NTSymbol(degree=0,
                                               is_aromatic=False,
                                               bond_symbol_list=[]))
                stack_list.append([])
                hg_list.append(None)
                nt_edge_list.append(None)
                gen_finish_list.append(False)

            for each_position in range(len(tgt_id_list[0])):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                for each_batch_idx in range(self.batch_size):
                    if not gen_finish_list[each_batch_idx]:  # if generation has not finished
                        # use the prespecified target ids
                        tgt_id = tgt_id_list[each_batch_idx][each_position]
                        prod_rule = self.hrg.prod_rule_list[tgt_id]
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # add non-terminals to the stack
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        # if the stack size is 0, generation has finished!
                        if len(stack_list[each_batch_idx]) == 0:
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = np.mod(self.padding_idx, self.vocab_size)
                    indice_tensor = torch.LongTensor([tgt_id])
                    indice_tensor = indice_tensor.cuda()
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)

            # last one step
            conditional_logprob_list = []
            tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
            vocab_logit = self.dec2vocab(tgt_emb_pred)
            for each_batch_idx in range(self.batch_size):
                if not gen_finish_list[each_batch_idx]:  # if generation has not finished
                    # compute the masked log-probabilities over production rules
                    masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
                        vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                        nt_symbol_list[each_batch_idx])
                    conditional_logprob_list.append(masked_logprob)
                else:
                    conditional_logprob_list.append(None)
        return conditional_logprob_list
    #TODO Not tested. Need to implement this in case of molecular structure generation
    def decode_with_beam_search(self, z, beam_width=1):
        ''' Decode a latent vector using beam search.

        Parameters
        ----------
        z
            latent vector
        beam_width : int
            parameter for beam search

        Returns
        -------
        List of Hypergraphs
        '''
        if self.batch_size != 1:
            raise ValueError('this method works only under batch_size=1')
        if self.padding_idx != -1:
            raise ValueError('this method works only under padding_idx=-1')
        top_k_tgt_id_list = [[]] * beam_width
        logprob_list = [0.] * beam_width

        for each_len in range(self.max_len):
            expanded_logprob_list = np.repeat(logprob_list, self.vocab_size)  # including padding_idx
            expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
            for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
                conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
                if conditional_logprob is None:
                    # this candidate has already finished: keep it alive only through the padding token
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx]
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate)
                else:
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx] + conditional_logprob
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate) + 1
            # length-normalized log-probability
            score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
            if each_len == 0:
                top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
            else:
                top_k_list = np.argsort(score_list)[::-1][:beam_width]
            next_top_k_tgt_id_list = []
            next_logprob_list = []
            for each_top_k in top_k_list:
                beam_idx = each_top_k // self.vocab_size
                vocab_idx = each_top_k % self.vocab_size
                if vocab_idx == self.vocab_size - 1:
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
                    next_logprob_list.append(expanded_logprob_list[each_top_k])
                else:
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
                    next_logprob_list.append(expanded_logprob_list[each_top_k])
            top_k_tgt_id_list = next_top_k_tgt_id_list
            logprob_list = next_logprob_list

        # construct hypergraphs
        hg_list = []
        for each_tgt_id_list in top_k_tgt_id_list:
            hg = None
            stack = []
            nt_edge = None
            for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
                prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
                hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
                stack.extend(nt_edges[::-1])
                try:
                    nt_edge = stack.pop()
                except IndexError:
                    if each_idx == len(each_tgt_id_list) - 1:
                        break
                    else:
                        raise ValueError('the stack became empty before the candidate sequence was fully applied')
            hg_list.append(hg)
        return hg_list
    def graph_embed(self, x, edge_index, edge_attr, batch_size):
        src_h = self.encoder.forward(x, edge_index, edge_attr, batch_size)
        return src_h

    def encode(self, x, edge_index, edge_attr, batch_size):
        #print("device for self.encoder=", next(self.encoder.parameters()).get_device())
        src_h = self.graph_embed(x, edge_index, edge_attr, batch_size)
        mu, lv = self.get_mean_var(src_h)
        return mu, lv

    def get_mean_var(self, src_h):
        #src_h = torch.tanh(src_h)
        mu = self.hidden2mean(src_h)
        lv = self.hidden2logvar(src_h)
        mu = torch.tanh(mu)
        lv = torch.tanh(lv)
        return mu, lv

    def reparameterize(self, mu, logvar, training=True):
        if training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            if self.rank >= 0:
                eps = eps.to(next(self.parameters()).device)
            return eps.mul(std).add_(mu)
        else:
            return mu


# Copied from the MHG implementation and adapted
class GrammarVAELoss(_Loss):
    '''
    A loss function for the grammar VAE.

    Attributes
    ----------
    hrg : HyperedgeReplacementGrammar
    beta : float
        coefficient of the KL divergence term
    '''

    def __init__(self, rank, hrg, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hrg = hrg
        self.beta = beta
        self.rank = rank

    def forward(self, mu, logvar, in_seq_pred, in_seq):
        ''' compute the VAE loss

        Parameters
        ----------
        in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
            logit
        in_seq : torch.Tensor, shape (batch_size, max_len)
            each element corresponds to a word id in the vocabulary.
        mu : torch.Tensor, shape (batch_size, hidden_dim)
        logvar : torch.Tensor, shape (batch_size, hidden_dim)
            mean and log variance of the normal distribution
        '''
        batch_size = in_seq_pred.shape[0]
        max_len = in_seq_pred.shape[1]
        vocab_size = in_seq_pred.shape[2]

        # Mask the logits so that, at each step, only production rules whose left-hand
        # side matches the current non-terminal (plus the padding index) can contribute.
        mask = torch.zeros(in_seq_pred.shape)
        for each_batch in range(batch_size):
            for each_idx in range(max_len):
                prod_rule_idx = in_seq[each_batch, each_idx]
                if prod_rule_idx == vocab_size - 1:
                    #### DETERMINE WHETHER THIS SHOULD BE SKIPPED OR NOT
                    mask[each_batch, each_idx, prod_rule_idx] = 1
                    #break
                    continue
                lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
                lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs)
                mask[each_batch, each_idx, :-1] = torch.FloatTensor(self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx])
        if self.rank >= 0:
            # this module has no parameters of its own, so follow the device of the logits
            mask = mask.to(in_seq_pred.device)
        in_seq_pred = mask * in_seq_pred

        cross_entropy = F.cross_entropy(
            in_seq_pred.view(-1, vocab_size),
            in_seq.view(-1),
            reduction='sum',
            #ignore_index=self.ignore_index if self.ignore_index is not None else -100
        )
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return cross_entropy + self.beta * kl_div
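# Note (added for clarity, not part of the original code): the returned value is the
# usual beta-weighted negative ELBO,
#
#     loss = CE(masked logits, in_seq) + beta * KL( N(mu, exp(logvar)) || N(0, I) ),
#
# with KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).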


class VAELoss(_Loss):

    def __init__(self, beta=0.01):
        super().__init__()
        self.beta = beta

    def forward(self, mean, log_var, dec_outputs, targets):
        device = mean.get_device()
        if device >= 0:
            targets = targets.to(device)
        reconstruction = F.cross_entropy(dec_outputs.view(-1, dec_outputs.size(2)), targets.view(-1), reduction='sum')
        # KL here is the *negative* KL divergence, hence the minus sign below
        KL = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
        loss = -self.beta * KL + reconstruction
        return loss
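# Minimal usage sketch (illustrative only; `model` is e.g. a GrammarGINVAE, `batch` a
# torch_geometric Batch, and `out_seq` a (batch, max_len) LongTensor of rule indices):
#
#     criterion = VAELoss(beta=0.01)
#     logits, mu, logvar = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch, out_seq)
#     loss = criterion(mu, logvar, logits, out_seq)
#     loss.backward()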