import torch
import torch.nn as nn
from .layer import GVP, GVPConvLayer, LayerNorm
from torch_scatter import scatter_mean


class AttentionPooling(nn.Module):
    def __init__(self, input_dim, attention_dim):
        super(AttentionPooling, self).__init__()
        self.attention_dim = attention_dim
        self.query_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.key_layer = nn.Linear(input_dim, attention_dim, bias=True)
        self.value_layer = nn.Linear(input_dim, 1, bias=True)  # value layer outputs one score per node
        self.softmax = nn.Softmax(dim=1)

    def forward(self, nodes_features1, nodes_features2):
        # nodes_features1 and nodes_features2 are both of shape [node_num, input_dim]
        nodes_features = nodes_features1 + nodes_features2  # could also be concatenation or another merge
        query = self.query_layer(nodes_features)
        key = self.key_layer(nodes_features)
        value = self.value_layer(nodes_features)
        attention_scores = torch.matmul(query, key.transpose(-2, -1))  # [node_num, node_num]
        attention_scores = self.softmax(attention_scores)
        pooled_features = torch.matmul(attention_scores, value)  # [node_num, 1]
        return pooled_features
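
# Shape sketch (illustrative, not part of the original API): despite its name,
# AttentionPooling as written returns one attention-weighted score per node
# rather than a single graph-level vector. With hypothetical 128-dim inputs:
#
#   pool = AttentionPooling(input_dim=128, attention_dim=64)
#   f1, f2 = torch.randn(10, 128), torch.randn(10, 128)
#   scores = pool(f1, f2)  # -> [10, 1]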


class AutoGraphEncoder(nn.Module):
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1) -> None:
        super().__init__()
        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.dense = nn.Sequential(
            nn.Linear(ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, node_in_dim[0])  # number of node labels
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, h_V, edge_index, h_E, node_s_labels):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)

        logits = self.dense(out)
        loss = self.loss_fn(logits, node_s_labels)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        return out
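
# Usage sketch (illustrative, not from the original repo): AutoGraphEncoder does
# node-level classification, so node_in_dim[0] doubles as the number of classes.
# With the standard (6, 3) node / (32, 1) edge GVP features, assumed hidden dims
# (100, 16), and hypothetical sizes n nodes and e edges:
#
#   model = AutoGraphEncoder((6, 3), (100, 16), (32, 1), (32, 1))
#   h_V = (torch.randn(n, 6), torch.randn(n, 3, 3))     # (scalar, vector) node feats
#   h_E = (torch.randn(e, 32), torch.randn(e, 1, 3))    # (scalar, vector) edge feats
#   edge_index = torch.randint(0, n, (2, e))
#   labels = torch.randint(0, 6, (n,))
#   loss, logits = model(h_V, edge_index, h_E, labels)  # logits: [n, 6]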


class SubgraphClassficationModel(nn.Module):
    '''
    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using original features
    :param node_h_dim: node dimensions to use in GVP-GNN layers
    :param edge_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features
    :param edge_h_dim: edge dimensions to embed to before use
                       in GVP-GNN layers
    :param num_layers: number of GVP-GNN layers
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim,
                 edge_in_dim, edge_h_dim, attention_dim=64,
                 num_layers=4, drop_rate=0.1):
        super(SubgraphClassficationModel, self).__init__()

        self.W_v = nn.Sequential(
            LayerNorm(node_in_dim),
            GVP(node_in_dim, node_h_dim, activations=(None, None))
        )
        self.W_e = nn.Sequential(
            LayerNorm(edge_in_dim),
            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
        )
        self.layers = nn.ModuleList(
            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
            for _ in range(num_layers))

        ns, _ = node_h_dim
        self.W_out = nn.Sequential(
            LayerNorm(node_h_dim),
            GVP(node_h_dim, (ns, 0)))

        self.attention_classifier = AttentionPooling(ns, attention_dim)
        self.dense = nn.Sequential(
            nn.Linear(2 * ns, 2 * ns),
            nn.ReLU(inplace=True),
            nn.Dropout(p=drop_rate),
            nn.Linear(2 * ns, 1)
        )

        self.loss_fn = nn.BCEWithLogitsLoss()
    def forward(self, h_V_parent, edge_index_parent, h_E_parent, batch_parent,
                h_V_subgraph, edge_index_subgraph, h_E_subgraph, batch_subgraph,
                labels):
        '''
        :param h_V_*: tuple (s, V) of node embeddings
        :param edge_index_*: `torch.Tensor` of shape [2, num_edges]
        :param h_E_*: tuple (s, V) of edge embeddings
        :param batch_*: `torch.Tensor` assigning each node to a graph in the batch
        :param labels: binary labels, one per parent/subgraph pair
        '''
        # Encode the parent graph and mean-pool node embeddings per graph
        h_V_parent = self.W_v(h_V_parent)
        h_E_parent = self.W_e(h_E_parent)
        for layer in self.layers:
            h_V_parent = layer(h_V_parent, edge_index_parent, h_E_parent)
        out_parent = self.W_out(h_V_parent)
        out_parent = scatter_mean(out_parent, batch_parent, dim=0)

        # Encode the candidate subgraph with the same (shared) encoder
        h_V_subgraph = self.W_v(h_V_subgraph)
        h_E_subgraph = self.W_e(h_E_subgraph)
        for layer in self.layers:
            h_V_subgraph = layer(h_V_subgraph, edge_index_subgraph, h_E_subgraph)
        out_subgraph = self.W_out(h_V_subgraph)
        out_subgraph = scatter_mean(out_subgraph, batch_subgraph, dim=0)

        labels = labels.float()
        out = torch.cat([out_parent, out_subgraph], dim=1)
        logits = self.dense(out)
        # logits = self.attention_classifier(out_parent, out_subgraph)
        loss = self.loss_fn(logits.squeeze(-1), labels)
        return loss, logits

    def get_embedding(self, h_V, edge_index, h_E, batch):
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        for layer in self.layers:
            h_V = layer(h_V, edge_index, h_E)
        out = self.W_out(h_V)
        out = scatter_mean(out, batch, dim=0)
        return out
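

if __name__ == '__main__':
    # Smoke test with random tensors (a minimal sketch, not part of the
    # original training pipeline). Feature dimensions follow the docstring:
    # (6, 3) node and (32, 1) edge features; the (100, 16) hidden dims are an
    # assumption. Because of the relative import above, run this as a module,
    # e.g. `python -m <package>.<this_module>`.
    node_in_dim, node_h_dim = (6, 3), (100, 16)
    edge_in_dim, edge_h_dim = (32, 1), (32, 1)

    def random_graph(n_nodes, n_edges):
        # Build (scalar, vector) node/edge features and a random topology.
        h_V = (torch.randn(n_nodes, node_in_dim[0]),
               torch.randn(n_nodes, node_in_dim[1], 3))
        h_E = (torch.randn(n_edges, edge_in_dim[0]),
               torch.randn(n_edges, edge_in_dim[1], 3))
        edge_index = torch.randint(0, n_nodes, (2, n_edges))
        return h_V, edge_index, h_E

    model = SubgraphClassficationModel(node_in_dim, node_h_dim,
                                       edge_in_dim, edge_h_dim)
    h_V_p, ei_p, h_E_p = random_graph(12, 40)
    h_V_s, ei_s, h_E_s = random_graph(8, 24)
    batch_p = torch.zeros(12, dtype=torch.long)  # one parent graph in the batch
    batch_s = torch.zeros(8, dtype=torch.long)   # one subgraph in the batch
    labels = torch.ones(1)                       # positive parent/subgraph pair
    loss, logits = model(h_V_p, ei_p, h_E_p, batch_p,
                         h_V_s, ei_s, h_E_s, batch_s, labels)
    print(loss.item(), logits.shape)             # scalar loss, logits [1, 1]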