# File size: 7,032 Bytes | fc0a115 6b33608 (web-viewer page header; line-number gutter removed)
import os
import numpy as np
import networkx as nx
import pygmtools as pygm
import torch
# Import torch_geometric's Data container. If the import fails, install the
# compiled companion wheels (CPU builds for torch 2.0.0) from the PyG wheel
# index and try the import once more.
try:
    from torch_geometric.data import Data
except Exception:
    # `Exception` keeps the best-effort fallback for any import-time failure
    # while letting KeyboardInterrupt / SystemExit propagate.
    _PYG_WHEEL_INDEX = "https://pytorch-geometric.com/whl/torch-2.0.0%2Bcpu.html"
    for _pkg in ("torch-sparse", "torch-scatter", "torch-spline-conv", "torch-cluster"):
        os.system(f"pip install --no-index {_pkg} -f {_PYG_WHEEL_INDEX}")
    # A second failure is not caught: the module cannot work without Data.
    from torch_geometric.data import Data
from one_hot import one_hot
from torch_geometric.transforms import OneHotDegree
import matplotlib.pyplot as plt
import pygmtools as pygm
pygm.set_backend('pytorch')
######################################################
# Constant Variable #
######################################################
# Node type names (chemical element symbols). from_gexf uses a node's position
# in this list as its one-hot feature index when no custom list is given.
AIDS700NEF_TYPE = [
    'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',
    'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',
    'Sb', 'Se', 'Ni', 'Te'
]
# Plot colors, index-aligned with AIDS700NEF_TYPE: COLOR[i] is the node color
# used for type AIDS700NEF_TYPE[i]. Some colors are shared by several types.
COLOR = [
    '#FF69B4',  # O - hot pink
    '#00CED1',  # S - dark turquoise
    '#FFD700',  # C - gold
    '#FFA500',  # N - orange
    '#FF6347',  # Cl - tomato red
    '#8B008B',  # Br - dark magenta
    '#00FF7F',  # B - spring green
    '#40E0D0',  # Si - turquoise
    '#FF4500',  # Hg - orange-red
    '#9932CC',  # I - dark orchid
    '#9370DB',  # Bi - medium purple
    '#FFA500',  # P - orange
    '#FFFF00',  # F - yellow
    '#B8860B',  # Cu - dark goldenrod
    '#7FFFD4',  # Ho - aquamarine
    '#FFD700',  # Pd - gold
    '#B22222',  # Ru - firebrick red
    '#E5E4E2',  # Pt - light gray
    '#A9A9A9',  # Sn - dark gray
    '#32CD32',  # Li - lime green
    '#CD853F',  # Ga - peru
    '#7FFFD4',  # Tb - aquamarine
    '#8A2BE2',  # As - blue violet
    '#FFD700',  # Co - gold
    '#808080',  # Pb - gray
    '#A9A9A9',  # Sb - dark gray
    '#FA8072',  # Se - salmon
    '#BEBEBE',  # Ni - light gray
    '#800080'  # Te - purple
]
######################################################
# Utils Func #
######################################################
def from_gexf(filename: str, node_types: list = None):
    r"""
    Read a graph from a GEXF file and prepare it for matching and plotting.

    Nodes are relabelled to consecutive integers ``0..n-1`` in the order
    networkx reports them.

    Args:
        filename: Path of the graph file; must end with ``.gexf``.
        node_types: Ordered list of node type names; a node's one-hot feature
            index is the position of its ``type`` attribute in this list.
            Defaults to ``AIDS700NEF_TYPE``.

    Returns:
        A tuple ``(graph, data, labels, colors)``:

        * ``graph``: the relabelled networkx graph.
        * ``data``: a ``Data`` object with ``edge_index`` (shape ``2 x E``)
          and the one-hot type features ``x`` (``None`` when the nodes carry
          no ``type`` attribute or the graph is empty).
        * ``labels``: dict mapping node id to ``"<id><type>"``, or ``None``
          for untyped / empty graphs.
        * ``colors``: list with one ``COLOR`` entry per node, or ``None`` for
          untyped / empty graphs.

    Raises:
        ValueError: If ``filename`` does not end with ``.gexf`` or a node has
            a type that is not listed in ``node_types``.
    """
    if not filename.endswith('.gexf'):
        raise ValueError("File type error, 'from_gexf' function only supports GEXF files")

    graph = nx.read_gexf(filename)
    mapping = {name: j for j, name in enumerate(graph.nodes())}
    graph = nx.relabel_nodes(graph, mapping)

    # `edges()` always yields (u, v) pairs, and reshape(-1, 2) keeps the array
    # two-dimensional for an edgeless graph, so edge_index is always 2 x E.
    edge_array = np.array(list(graph.edges()), dtype=np.int64).reshape(-1, 2)
    edge_index = torch.from_numpy(edge_array.transpose())

    x = None
    labels = None
    colors = None
    num_nodes = graph.number_of_nodes()
    # Looking up node 0 on an empty graph raises KeyError, so test the size first.
    if num_nodes > 0 and 'type' in graph.nodes[0]:
        labels = dict()
        colors = list()
        x = torch.zeros(num_nodes, dtype=torch.long)
        node_types = AIDS700NEF_TYPE if node_types is None else node_types
        for node, info in graph.nodes(data=True):
            node = int(node)
            try:
                type_id = node_types.index(info['type'])
            except ValueError:
                raise ValueError(
                    f"Unknown node type {info['type']!r} on node {node} in '{filename}'"
                ) from None
            x[node] = type_id
            labels[node] = str(node) + info['type']
            # Wrap around so a custom node_types list longer than the palette
            # still gets a color instead of an IndexError.
            colors.append(COLOR[type_id % len(COLOR)])
        x = one_hot(x, num_classes=len(node_types))

    # NOTE(review): the original indentation was lost, so it is unclear whether
    # Data was only built for typed graphs. It is built unconditionally here
    # (with x=None for untyped graphs) -- confirm no caller relies on data=None.
    data = Data(x=x, edge_index=edge_index, edge_attr=None)
    return graph, data, labels, colors
def draw(graph, colors, labels, filename, title, pos_type=None):
    """
    Draw ``graph`` with networkx/matplotlib and save the picture to ``filename``.

    Args:
        graph: networkx graph to draw.
        colors: Node colors, forwarded to ``nx.draw`` as ``node_color``.
        labels: Mapping from node to label text, forwarded as ``labels``.
        filename: Output path handed to ``plt.savefig``.
        title: Title placed on the axes.
        pos_type: Layout selector. ``None`` uses ``nx.kamada_kawai_layout``;
            ``"spring"`` uses ``nx.spring_layout``.

    Raises:
        ValueError: If ``pos_type`` is neither ``None`` nor ``"spring"``.
    """
    if pos_type is None:
        pos = nx.kamada_kawai_layout(graph)
    elif pos_type == "spring":
        pos = nx.spring_layout(graph)
    else:
        # Previously this fell through and failed later with an unbound `pos`.
        raise ValueError(
            f"Unsupported pos_type {pos_type!r}; expected None or 'spring'"
        )

    fig = plt.figure()
    try:
        plt.gca().set_title(title)
        nx.draw(graph, pos, with_labels=True, node_color=colors, edge_color='gray', labels=labels)
        plt.savefig(filename)
    finally:
        # Close (not just clear) the figure so repeated calls do not pile up
        # open figures inside pyplot.
        plt.close(fig)
######################################################
# GED UI #
######################################################
def astar(
    g1_path: str,
    g2_path: str,
    output_path: str = "examples",
    filename: str = "example",
    device='cpu'
):
    """
    Match two GEXF graphs with ``pygm.genn_astar`` and save pictures of the edit steps.

    The graph with fewer nodes is used as "Graph1" and is edited towards the
    other one ("Graph2"). Five images are written to
    ``<output_path>/<filename>_<k>.png``:

    * ``1``: Graph1 as loaded.
    * ``2``: "Node Match" - Graph1 with each node index replaced by the index
      of its matched Graph2 node.
    * ``3``: "Node Change" - Graph1 nodes carrying the labels and colors of
      their matched Graph2 nodes.
    * ``4``: "Edge Change" - Graph1 nodes connected by the Graph2 edges whose
      two endpoints are both matched.
    * ``5``: Graph2 as loaded.

    Args:
        g1_path: Path of the first ``.gexf`` file.
        g2_path: Path of the second ``.gexf`` file.
        output_path: Directory for the images; created if missing.
        filename: File name prefix of the images.
        device: Torch device the features and adjacency matrices are moved to.

    Returns:
        None. The images on disk are the only output.

    Raises:
        ValueError: If a graph has no node ``type`` attributes (or a file is
            rejected by ``from_gexf``).
    """
    # makedirs also handles nested paths and an already existing directory.
    os.makedirs(output_path, exist_ok=True)
    output_filename = os.path.join(output_path, filename) + "_{}.png"

    # Load data
    g1, d1, l1, c1 = from_gexf(g1_path)
    g2, d2, l2, c2 = from_gexf(g2_path)
    if l1 is None or l2 is None:
        raise ValueError("Both graphs need a 'type' attribute on their nodes")

    # Keep the smaller graph (fewer nodes) as graph1.
    if len(c1) > len(c2):
        graph1, data1, labels1, colors1 = g2, d2, l2, c2
        graph2, data2, labels2, colors2 = g1, d1, l1, c1
    else:
        graph1, data1, labels1, colors1 = g1, d1, l1, c1
        graph2, data2, labels2, colors2 = g2, d2, l2, c2

    # Build node features and adjacency matrices.
    # NOTE(review): max_degree=6 presumably matches the feature size the
    # pretrained GENN model expects -- confirm before changing.
    data1 = OneHotDegree(max_degree=6)(data1)
    data2 = OneHotDegree(max_degree=6)(data2)
    feat1 = data1.x.to(device)
    feat2 = data2.x.to(device)
    A1 = torch.tensor(pygm.utils.from_networkx(graph1)).float().to(device)
    A2 = torch.tensor(pygm.utils.from_networkx(graph2)).float().to(device)

    # Calculate the matching used for the graph edit distance
    x_pred = pygm.genn_astar(feat1, feat2, A1, A2, return_network=False)

    # Plot the two input graphs (graph2 is stored as image 5).
    draw(graph1, colors1, labels1, output_filename.format(1), "Graph1")
    draw(graph2, colors2, labels2, output_filename.format(5), "Graph2")

    # For every graph1 node, the index of the first non-zero entry in its row
    # of x_pred is the matched graph2 node.
    matches = [torch.nonzero(row)[0].item() for row in x_pred]

    # Match Process: show the matched graph2 index on every graph1 node.
    labels1_1 = labels1.copy()
    for i, target in enumerate(matches):
        # count=1 replaces only the leading index, never digits in the type.
        labels1_1[i] = labels1[i].replace(str(i), str(target), 1)
    draw(graph1, colors1, labels1_1, output_filename.format(2), "Node Match")

    # Node Change: give every graph1 node the label and color of its match.
    # NOTE(review): the original indentation was lost; it may have updated
    # only the nodes whose label differs from the match. All nodes are updated
    # here so images 2-4 use the same (graph2) numbering -- confirm.
    labels1_2 = labels1.copy()
    colors1_2 = colors1.copy()
    target2ori = dict()
    for i, target in enumerate(matches):
        labels1_2[i] = labels2[target]
        colors1_2[i] = colors2[target]
        target2ori[target] = i
    draw(graph1, colors1_2, labels1_2, output_filename.format(3), "Node Change")

    # Edge Change: keep graph1's nodes (same order, so colors stay aligned) and
    # connect them with the graph2 edges whose endpoints are both matched.
    # A fresh graph is built instead of assigning to `graph1.edges`, which is
    # not a supported way to change a networkx graph.
    edited = nx.Graph()
    edited.add_nodes_from(graph1.nodes(data=True))
    edited.add_edges_from(
        (target2ori[u], target2ori[v])
        for u, v in graph2.edges()
        if u in target2ori and v in target2ori
    )
    draw(edited, colors1_2, labels1_2, output_filename.format(4), "Edge Change", pos_type="spring")