import torch
import torch.nn as nn
from torch_scatter import scatter_mean

from .layer import GVP, GVPConvLayer, LayerNorm


class AttentionPooling(nn.Module):
    '''
    Attention-based combination of two node feature sets. Instantiated in
    `SubgraphClassficationModel`, but its call is currently commented out in
    that model's `forward`.
    '''
    def __init__(self, input_dim, attention_dim):
        super(AttentionPooling, self).__init__()
        self.attention_dim = attention_dim
        self.query_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.key_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.value_layer = nn.Linear(input_dim, 1, bias=True)  # value layer outputs one score
        self.softmax = nn.Softmax(dim=1)

    def forward(self, nodes_features1, nodes_features2):
        # nodes_features1 and nodes_features2 are both of shape [node_num, input_dim]
        nodes_features = nodes_features1 + nodes_features2  # could also be concatenation or another combination
        query = self.query_layer(nodes_features)
        key = self.key_layer(nodes_features)
        value = self.value_layer(nodes_features)

        attention_scores = torch.matmul(query, key.transpose(-2, -1))  # [node_num, node_num]
        attention_scores = self.softmax(attention_scores)

        pooled_features = torch.matmul(attention_scores, value)  # [node_num, 1]
        return pooled_features


class AutoGraphEncoder(nn.Module):
    '''
    GVP-GNN encoder trained with a node-level classification loss
    (cross-entropy over `node_in_dim[0]` label classes).

    :param node_in_dim: node dimensions (s, V) in the input graph
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions (s, V) in the input graph
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                 attention_dim=64, num_layers=4, drop_rate=0.1) -> None:
        # attention_dim is unused in this encoder; kept for interface parity
        super().__init__()
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.dense = nn.Sequential(
            nn.Linear(ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, node_in_dim[0])  # label num
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, h_V, edge_index, h_E, node_s_labels):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)      # scalar-only node embeddings, [node_num, ns]
        logits = self.dense(out)   # per-node class logits
        loss = self.loss_fn(logits, node_s_labels)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        return out


class SubgraphClassficationModel(nn.Module):
    '''
    GVP-GNN model that encodes a parent graph and a candidate subgraph with a
    shared encoder and predicts, via a binary classifier, whether the pair
    belongs together.

    :param node_in_dim: node dimensions in input graph, should be (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
                 attention_dim=64, num_layers=4, drop_rate=0.1):
        super(SubgraphClassficationModel, self).__init__()
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.attention_classifier = AttentionPooling(ns, attention_dim)
        self.dense = nn.Sequential(
            nn.Linear(2 * ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, 1)
        )
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, h_V_parent, edge_index_parent, h_E_parent, batch_parent,
                h_V_subgraph, edge_index_subgraph, h_E_subgraph, batch_subgraph,
                labels):
        '''
        :param h_V_parent / h_V_subgraph: tuple (s, V) of node embeddings
        :param edge_index_parent / edge_index_subgraph: `torch.Tensor` of shape [2, num_edges]
        :param h_E_parent / h_E_subgraph: tuple (s, V) of edge embeddings
        :param batch_parent / batch_subgraph: graph index of each node, as in PyG batching
        :param labels: binary label for each (parent, subgraph) pair
        '''
        # encode the parent graph and mean-pool node embeddings per graph
        h_V_parent = self.W_v(h_V_parent)
        h_E_parent = self.W_e(h_E_parent)
        for layer in self.layers:
            h_V_parent = layer(h_V_parent, edge_index_parent, h_E_parent)
        out_parent = self.W_out(h_V_parent)
        out_parent = scatter_mean(out_parent, batch_parent, dim=0)

        # encode the subgraph with the same (shared) encoder
        h_V_subgraph = self.W_v(h_V_subgraph)
        h_E_subgraph = self.W_e(h_E_subgraph)
        for layer in self.layers:
            h_V_subgraph = layer(h_V_subgraph, edge_index_subgraph, h_E_subgraph)
        out_subgraph = self.W_out(h_V_subgraph)
        out_subgraph = scatter_mean(out_subgraph, batch_subgraph, dim=0)

        labels = labels.float()
        out = torch.cat([out_parent, out_subgraph], dim=1)
        logits = self.dense(out)
        # logits = self.attention_classifier(out_parent, out_subgraph)
        loss = self.loss_fn(logits.squeeze(-1), labels)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E, batch):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        out = scatter_mean(out, batch, dim=0)
        return out
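

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes the
# GVP convention that an (s, V) feature pair is a pair of tensors shaped
# ([n, n_scalar], [n, n_vector, 3]) and uses the "original" input dims from
# the docstrings above: (6, 3) node features and (32, 1) edge features. The
# hidden dims, graph sizes, and the helper name `_example_usage` are
# illustrative choices, not values taken from this code base.
# ---------------------------------------------------------------------------
def _example_usage():
    num_nodes, num_edges = 20, 60

    def random_graph():
        h_V = (torch.randn(num_nodes, 6), torch.randn(num_nodes, 3, 3))
        edge_index = torch.randint(0, num_nodes, (2, num_edges))
        h_E = (torch.randn(num_edges, 32), torch.randn(num_edges, 1, 3))
        batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes in graph 0
        return h_V, edge_index, h_E, batch

    # node-level pretraining encoder: predicts one of node_in_dim[0] classes per node
    encoder = AutoGraphEncoder(node_in_dim=(6, 3), node_h_dim=(128, 16),
                               edge_in_dim=(32, 1), edge_h_dim=(32, 1))
    h_V, edge_index, h_E, batch = random_graph()
    node_labels = torch.randint(0, 6, (num_nodes,))
    enc_loss, node_logits = encoder(h_V, edge_index, h_E, node_labels)
    print(enc_loss.item(), node_logits.shape)  # [num_nodes, 6]

    # parent/subgraph binary classifier: one label per (parent, subgraph) pair
    model = SubgraphClassficationModel(node_in_dim=(6, 3), node_h_dim=(128, 16),
                                       edge_in_dim=(32, 1), edge_h_dim=(32, 1))
    h_V_s, edge_index_s, h_E_s, batch_s = random_graph()
    labels = torch.randint(0, 2, (1,))
    loss, logits = model(h_V, edge_index, h_E, batch,
                         h_V_s, edge_index_s, h_E_s, batch_s, labels)
    print(loss.item(), logits.shape)  # logits: [num_graphs, 1]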