Spaces:
Sleeping
Sleeping
Update app_utils.py
Browse files- app_utils.py +358 -360
app_utils.py
CHANGED
@@ -1,360 +1,358 @@
|
|
1 |
-
import utils
|
2 |
-
import os
|
3 |
-
import math
|
4 |
-
import json
|
5 |
-
import numpy as np
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import mne
|
8 |
-
from mne.channels import read_custom_montage
|
9 |
-
from scipy.interpolate import Rbf
|
10 |
-
from scipy.optimize import linear_sum_assignment
|
11 |
-
from sklearn.neighbors import NearestNeighbors
|
12 |
-
|
13 |
-
def reorder_data(idx_order, fill_flags, filename, new_filename):
|
14 |
-
# read the input data
|
15 |
-
raw_data = utils.read_train_data(filename)
|
16 |
-
#print(raw_data.shape)
|
17 |
-
new_data = np.zeros((30, raw_data.shape[1]))
|
18 |
-
|
19 |
-
zero_arr = np.zeros((1, raw_data.shape[1]))
|
20 |
-
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
21 |
-
if flag == False:
|
22 |
-
new_data[i, :] = raw_data[idx_set[0], :]
|
23 |
-
elif idx_set == []:
|
24 |
-
new_data[i, :] = zero_arr
|
25 |
-
else:
|
26 |
-
tmp_data = [raw_data[j, :] for j in idx_set]
|
27 |
-
new_data[i, :] = np.mean(tmp_data, axis=0)
|
28 |
-
|
29 |
-
utils.save_data(new_data, new_filename)
|
30 |
-
return raw_data.shape
|
31 |
-
|
32 |
-
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
|
33 |
-
# read the denoised data
|
34 |
-
d_data = utils.read_train_data(filename)
|
35 |
-
if batch_cnt == 0:
|
36 |
-
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
37 |
-
#print(new_data.shape)
|
38 |
-
else:
|
39 |
-
new_data = utils.read_train_data(new_filename)
|
40 |
-
|
41 |
-
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
42 |
-
# ignore if this channel was filled using "fillmode"
|
43 |
-
if flag == False:
|
44 |
-
new_data[idx_set[0], :] = d_data[i, :]
|
45 |
-
|
46 |
-
utils.save_data(new_data, new_filename)
|
47 |
-
return
|
48 |
-
|
49 |
-
def get_matched(tpl_order, tpl_dict):
|
50 |
-
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
|
51 |
-
|
52 |
-
def get_empty_templates(tpl_order, tpl_dict):
|
53 |
-
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
54 |
-
|
55 |
-
def get_unassigned_inputs(in_order, in_dict):
|
56 |
-
return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
|
57 |
-
|
58 |
-
def read_montage_data(loc_file):
|
59 |
-
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
-
in_montage = read_custom_montage(loc_file)
|
61 |
-
tpl_order = tpl_montage.ch_names
|
62 |
-
in_order = in_montage.ch_names
|
63 |
-
tpl_dict = {}
|
64 |
-
in_dict = {}
|
65 |
-
|
66 |
-
# convert all channel names to uppercase and store the channel information
|
67 |
-
for i, channel in enumerate(tpl_order):
|
68 |
-
up_channel = str.upper(channel)
|
69 |
-
tpl_montage.rename_channels({channel: up_channel})
|
70 |
-
tpl_dict[up_channel] = {
|
71 |
-
"index" : i,
|
72 |
-
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
|
73 |
-
"matched" : False
|
74 |
-
}
|
75 |
-
for i, channel in enumerate(in_order):
|
76 |
-
up_channel = str.upper(channel)
|
77 |
-
in_montage.rename_channels({channel: up_channel})
|
78 |
-
in_dict[up_channel] = {
|
79 |
-
"index" : i,
|
80 |
-
"coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
|
81 |
-
"assigned" : False
|
82 |
-
}
|
83 |
-
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
-
|
85 |
-
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
-
tpl_order = channel_info["templateOrder"]
|
87 |
-
in_order = channel_info["inputOrder"]
|
88 |
-
tpl_dict = channel_info["templateDict"]
|
89 |
-
in_dict = channel_info["inputDict"]
|
90 |
-
|
91 |
-
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
92 |
-
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
93 |
-
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
94 |
-
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
95 |
-
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
-
in_coords = np.vstack((in_x, in_y)).T
|
97 |
-
|
98 |
-
# extract template's head figure
|
99 |
-
tpl_fig = tpl_montage.plot()
|
100 |
-
tpl_ax = tpl_fig.axes[0]
|
101 |
-
lines = tpl_ax.lines
|
102 |
-
head_lines = []
|
103 |
-
for line in lines:
|
104 |
-
x, y = line.get_data()
|
105 |
-
head_lines.append((x,y))
|
106 |
-
plt.close()
|
107 |
-
|
108 |
-
# -------------------------plot input montage------------------------------
|
109 |
-
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
110 |
-
ax = fig.add_subplot(111)
|
111 |
-
fig.tight_layout()
|
112 |
-
ax.set_aspect('equal')
|
113 |
-
ax.axis('off')
|
114 |
-
|
115 |
-
# plot template's head
|
116 |
-
for x, y in head_lines:
|
117 |
-
ax.plot(x, y, color='black', linewidth=1.0)
|
118 |
-
# plot in_channels on it
|
119 |
-
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
120 |
-
for i, channel in enumerate(in_order):
|
121 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
122 |
-
# save input_montage
|
123 |
-
fig.savefig(filename1)
|
124 |
-
|
125 |
-
# ---------------------------add indications-------------------------------
|
126 |
-
# plot unmatched input channels in red
|
127 |
-
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
128 |
-
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
129 |
-
for i in indices:
|
130 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
|
131 |
-
# save mapped_montage
|
132 |
-
fig.savefig(filename2)
|
133 |
-
|
134 |
-
# -------------------------------------------------------------------------
|
135 |
-
# store the tpl and in_channels' display positions (in px).
|
136 |
-
tpl_coords = ax.transData.transform(tpl_coords)
|
137 |
-
in_coords = ax.transData.transform(in_coords)
|
138 |
-
plt.close()
|
139 |
-
|
140 |
-
for i, channel in enumerate(tpl_order):
|
141 |
-
css_left = (tpl_coords[i,0]-11)/6.4
|
142 |
-
css_bottom = (tpl_coords[i,1]-7)/6.4
|
143 |
-
tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
144 |
-
for i, channel in enumerate(in_order):
|
145 |
-
css_left = (in_coords[i,0]-11)/6.4
|
146 |
-
css_bottom = (in_coords[i,1]-7)/6.4
|
147 |
-
in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
148 |
-
|
149 |
-
channel_info.update({
|
150 |
-
"templateDict" : tpl_dict,
|
151 |
-
"inputDict" : in_dict
|
152 |
-
})
|
153 |
-
return channel_info
|
154 |
-
|
155 |
-
def align_coords(channel_info, tpl_montage, in_montage):
|
156 |
-
tpl_order = channel_info["templateOrder"]
|
157 |
-
in_order = channel_info["inputOrder"]
|
158 |
-
tpl_dict = channel_info["templateDict"]
|
159 |
-
in_dict = channel_info["inputDict"]
|
160 |
-
matched = get_matched(tpl_order, tpl_dict)
|
161 |
-
|
162 |
-
# 2D alignment (for visualization purposes)
|
163 |
-
fig = [tpl_montage.plot(), in_montage.plot()]
|
164 |
-
ax = [fig[0].axes[0], fig[1].axes[0]]
|
165 |
-
|
166 |
-
# extract the displayed 2D coordinates from the plots
|
167 |
-
all_tpl = ax[0].collections[0].get_offsets().data
|
168 |
-
all_in= ax[1].collections[0].get_offsets().data
|
169 |
-
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
170 |
-
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
171 |
-
|
172 |
-
# apply TPS to transform in_channels positions to align with tpl_channels positions
|
173 |
-
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
174 |
-
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
175 |
-
|
176 |
-
# apply the transformation to all in_channels
|
177 |
-
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
178 |
-
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
179 |
-
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
180 |
-
|
181 |
-
# store the 2D positions
|
182 |
-
for i, channel in enumerate(tpl_order):
|
183 |
-
tpl_dict[channel]["coord_2d"] = all_tpl[i]
|
184 |
-
for i, channel in enumerate(in_order):
|
185 |
-
in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
186 |
-
|
187 |
-
|
188 |
-
# 3D alignment
|
189 |
-
all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
|
190 |
-
all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
|
191 |
-
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
192 |
-
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
193 |
-
|
194 |
-
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
195 |
-
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
196 |
-
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
197 |
-
|
198 |
-
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
199 |
-
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
200 |
-
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
-
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
202 |
-
|
203 |
-
# update in_channels' 3D positions
|
204 |
-
for i, channel in enumerate(in_order):
|
205 |
-
in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
206 |
-
|
207 |
-
channel_info.update({
|
208 |
-
"templateDict" : tpl_dict,
|
209 |
-
"inputDict" : in_dict
|
210 |
-
})
|
211 |
-
return channel_info
|
212 |
-
|
213 |
-
def find_neighbors(channel_info, missing_channels, new_idx):
|
214 |
-
in_order = channel_info["inputOrder"]
|
215 |
-
tpl_dict = channel_info["templateDict"]
|
216 |
-
in_dict = channel_info["inputDict"]
|
217 |
-
|
218 |
-
all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
|
219 |
-
empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
|
220 |
-
|
221 |
-
# use KNN to choose k nearest channels
|
222 |
-
k = 4 if len(in_order)>4 else len(in_order)
|
223 |
-
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
224 |
-
knn.fit(all_in)
|
225 |
-
for i, channel in enumerate(missing_channels):
|
226 |
-
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
227 |
-
idx = tpl_dict[channel]["index"]
|
228 |
-
new_idx[idx] = indices[0].tolist()
|
229 |
-
|
230 |
-
return new_idx
|
231 |
-
|
232 |
-
def match_names(stage1_info, channel_info):
|
233 |
-
# read the location file
|
234 |
-
loc_file = stage1_info["fileNames"]["input_loc"]
|
235 |
-
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
236 |
-
tpl_order = tpl_montage.ch_names
|
237 |
-
in_order = in_montage.ch_names
|
238 |
-
new_idx = [[]]*30 # store the indices of the in_channels in the order of
|
239 |
-
fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
|
240 |
-
|
241 |
-
alias_dict = {
|
242 |
-
'T3': 'T7',
|
243 |
-
'T4': 'T8',
|
244 |
-
'T5': 'P7',
|
245 |
-
'T6': 'P8'
|
246 |
-
}
|
247 |
-
for i, channel in enumerate(tpl_order):
|
248 |
-
if channel in alias_dict and alias_dict[channel] in in_dict:
|
249 |
-
tpl_montage.rename_channels({channel: alias_dict[channel]})
|
250 |
-
tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
|
251 |
-
channel = alias_dict[channel]
|
252 |
-
|
253 |
-
if channel in in_dict:
|
254 |
-
new_idx[i] = [in_dict[channel]["index"]]
|
255 |
-
fill_flags[i] = False
|
256 |
-
tpl_dict[channel]["matched"] = True
|
257 |
-
in_dict[channel]["assigned"] = True
|
258 |
-
|
259 |
-
# update the names
|
260 |
-
tpl_order = tpl_montage.ch_names
|
261 |
-
|
262 |
-
stage1_info.update({
|
263 |
-
"unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
|
264 |
-
"missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
|
265 |
-
"mappingData" : [
|
266 |
-
{
|
267 |
-
"newOrder" : new_idx,
|
268 |
-
"fillFlags" : fill_flags
|
269 |
-
}
|
270 |
-
]
|
271 |
-
})
|
272 |
-
channel_info.update({
|
273 |
-
"templateOrder" : tpl_order,
|
274 |
-
"inputOrder" : in_order,
|
275 |
-
"templateDict" : tpl_dict,
|
276 |
-
"inputDict" : in_dict
|
277 |
-
})
|
278 |
-
return stage1_info, channel_info, tpl_montage, in_montage
|
279 |
-
|
280 |
-
def optimal_mapping(channel_info):
|
281 |
-
tpl_order = channel_info["templateOrder"]
|
282 |
-
in_order = channel_info["inputOrder"]
|
283 |
-
tpl_dict = channel_info["templateDict"]
|
284 |
-
in_dict = channel_info["inputDict"]
|
285 |
-
unassigned = get_unassigned_inputs(in_order, in_dict)
|
286 |
-
# reset all tpl.matched to False
|
287 |
-
for channel in tpl_dict:
|
288 |
-
tpl_dict[channel]["matched"] = False
|
289 |
-
|
290 |
-
all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
|
291 |
-
unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
|
292 |
-
|
293 |
-
# initialize the cost matrix for the Hungarian algorithm
|
294 |
-
if len(unassigned) < 30:
|
295 |
-
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
296 |
-
else:
|
297 |
-
cost_matrix = np.zeros((30, len(unassigned)))
|
298 |
-
# fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
|
299 |
-
for i in range(30):
|
300 |
-
for j in range(len(unassigned)):
|
301 |
-
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
|
302 |
-
|
303 |
-
# apply the Hungarian algorithm to optimally assign
|
304 |
-
# by minimizing the total distances between their positions.
|
305 |
-
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
306 |
-
|
307 |
-
# store the mapping results
|
308 |
-
new_idx = [[]]*30
|
309 |
-
fill_flags = [True]*30
|
310 |
-
|
311 |
-
if
|
312 |
-
tpl_channel = tpl_order[
|
313 |
-
in_channel = unassigned[
|
314 |
-
|
315 |
-
new_idx[
|
316 |
-
fill_flags[
|
317 |
-
tpl_dict[tpl_channel]["matched"] = True
|
318 |
-
in_dict[in_channel]["assigned"] = True
|
319 |
-
#print(f'{tpl_channel}({
|
320 |
-
|
321 |
-
# fill the remaining empty tpl_channels
|
322 |
-
missing_channels = get_empty_templates(tpl_order, tpl_dict)
|
323 |
-
if missing_channels != []:
|
324 |
-
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
325 |
-
|
326 |
-
mapping_data = {
|
327 |
-
"newOrder" : new_idx,
|
328 |
-
"fillFlags" : fill_flags
|
329 |
-
}
|
330 |
-
channel_info.update({
|
331 |
-
"templateDict" : tpl_dict,
|
332 |
-
"inputDict" : in_dict
|
333 |
-
})
|
334 |
-
return mapping_data, channel_info
|
335 |
-
|
336 |
-
def mapping_result(stage1_info, stage2_info, channel_info, filename):
|
337 |
-
unassigned_num = len(stage1_info["unassignedInputs"])
|
338 |
-
batch_num = math.ceil(unassigned_num/30) + 1
|
339 |
-
|
340 |
-
# map the remaining in_channels
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
return stage1_info, stage2_info, channel_info
|
360 |
-
|
|
|
1 |
+
import utils
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import mne
|
8 |
+
from mne.channels import read_custom_montage
|
9 |
+
from scipy.interpolate import Rbf
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
from sklearn.neighbors import NearestNeighbors
|
12 |
+
|
13 |
+
def reorder_data(idx_order, fill_flags, filename, new_filename):
|
14 |
+
# read the input data
|
15 |
+
raw_data = utils.read_train_data(filename)
|
16 |
+
#print(raw_data.shape)
|
17 |
+
new_data = np.zeros((30, raw_data.shape[1]))
|
18 |
+
|
19 |
+
zero_arr = np.zeros((1, raw_data.shape[1]))
|
20 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
21 |
+
if flag == False:
|
22 |
+
new_data[i, :] = raw_data[idx_set[0], :]
|
23 |
+
elif idx_set == []:
|
24 |
+
new_data[i, :] = zero_arr
|
25 |
+
else:
|
26 |
+
tmp_data = [raw_data[j, :] for j in idx_set]
|
27 |
+
new_data[i, :] = np.mean(tmp_data, axis=0)
|
28 |
+
|
29 |
+
utils.save_data(new_data, new_filename)
|
30 |
+
return raw_data.shape
|
31 |
+
|
32 |
+
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
|
33 |
+
# read the denoised data
|
34 |
+
d_data = utils.read_train_data(filename)
|
35 |
+
if batch_cnt == 0:
|
36 |
+
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
37 |
+
#print(new_data.shape)
|
38 |
+
else:
|
39 |
+
new_data = utils.read_train_data(new_filename)
|
40 |
+
|
41 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
42 |
+
# ignore if this channel was filled using "fillmode"
|
43 |
+
if flag == False:
|
44 |
+
new_data[idx_set[0], :] = d_data[i, :]
|
45 |
+
|
46 |
+
utils.save_data(new_data, new_filename)
|
47 |
+
return
|
48 |
+
|
49 |
+
def get_matched(tpl_order, tpl_dict):
|
50 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
|
51 |
+
|
52 |
+
def get_empty_templates(tpl_order, tpl_dict):
|
53 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
54 |
+
|
55 |
+
def get_unassigned_inputs(in_order, in_dict):
|
56 |
+
return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
|
57 |
+
|
58 |
+
def read_montage_data(loc_file):
|
59 |
+
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
+
in_montage = read_custom_montage(loc_file)
|
61 |
+
tpl_order = tpl_montage.ch_names
|
62 |
+
in_order = in_montage.ch_names
|
63 |
+
tpl_dict = {}
|
64 |
+
in_dict = {}
|
65 |
+
|
66 |
+
# convert all channel names to uppercase and store the channel information
|
67 |
+
for i, channel in enumerate(tpl_order):
|
68 |
+
up_channel = str.upper(channel)
|
69 |
+
tpl_montage.rename_channels({channel: up_channel})
|
70 |
+
tpl_dict[up_channel] = {
|
71 |
+
"index" : i,
|
72 |
+
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
|
73 |
+
"matched" : False
|
74 |
+
}
|
75 |
+
for i, channel in enumerate(in_order):
|
76 |
+
up_channel = str.upper(channel)
|
77 |
+
in_montage.rename_channels({channel: up_channel})
|
78 |
+
in_dict[up_channel] = {
|
79 |
+
"index" : i,
|
80 |
+
"coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
|
81 |
+
"assigned" : False
|
82 |
+
}
|
83 |
+
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
+
|
85 |
+
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
+
tpl_order = channel_info["templateOrder"]
|
87 |
+
in_order = channel_info["inputOrder"]
|
88 |
+
tpl_dict = channel_info["templateDict"]
|
89 |
+
in_dict = channel_info["inputDict"]
|
90 |
+
|
91 |
+
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
92 |
+
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
93 |
+
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
94 |
+
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
95 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
+
in_coords = np.vstack((in_x, in_y)).T
|
97 |
+
|
98 |
+
# extract template's head figure
|
99 |
+
tpl_fig = tpl_montage.plot()
|
100 |
+
tpl_ax = tpl_fig.axes[0]
|
101 |
+
lines = tpl_ax.lines
|
102 |
+
head_lines = []
|
103 |
+
for line in lines:
|
104 |
+
x, y = line.get_data()
|
105 |
+
head_lines.append((x,y))
|
106 |
+
plt.close()
|
107 |
+
|
108 |
+
# -------------------------plot input montage------------------------------
|
109 |
+
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
110 |
+
ax = fig.add_subplot(111)
|
111 |
+
fig.tight_layout()
|
112 |
+
ax.set_aspect('equal')
|
113 |
+
ax.axis('off')
|
114 |
+
|
115 |
+
# plot template's head
|
116 |
+
for x, y in head_lines:
|
117 |
+
ax.plot(x, y, color='black', linewidth=1.0)
|
118 |
+
# plot in_channels on it
|
119 |
+
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
120 |
+
for i, channel in enumerate(in_order):
|
121 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
122 |
+
# save input_montage
|
123 |
+
fig.savefig(filename1)
|
124 |
+
|
125 |
+
# ---------------------------add indications-------------------------------
|
126 |
+
# plot unmatched input channels in red
|
127 |
+
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
128 |
+
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
129 |
+
for i in indices:
|
130 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
|
131 |
+
# save mapped_montage
|
132 |
+
fig.savefig(filename2)
|
133 |
+
|
134 |
+
# -------------------------------------------------------------------------
|
135 |
+
# store the tpl and in_channels' display positions (in px).
|
136 |
+
tpl_coords = ax.transData.transform(tpl_coords)
|
137 |
+
in_coords = ax.transData.transform(in_coords)
|
138 |
+
plt.close()
|
139 |
+
|
140 |
+
for i, channel in enumerate(tpl_order):
|
141 |
+
css_left = (tpl_coords[i,0]-11)/6.4
|
142 |
+
css_bottom = (tpl_coords[i,1]-7)/6.4
|
143 |
+
tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
144 |
+
for i, channel in enumerate(in_order):
|
145 |
+
css_left = (in_coords[i,0]-11)/6.4
|
146 |
+
css_bottom = (in_coords[i,1]-7)/6.4
|
147 |
+
in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
148 |
+
|
149 |
+
channel_info.update({
|
150 |
+
"templateDict" : tpl_dict,
|
151 |
+
"inputDict" : in_dict
|
152 |
+
})
|
153 |
+
return channel_info
|
154 |
+
|
155 |
+
def align_coords(channel_info, tpl_montage, in_montage):
|
156 |
+
tpl_order = channel_info["templateOrder"]
|
157 |
+
in_order = channel_info["inputOrder"]
|
158 |
+
tpl_dict = channel_info["templateDict"]
|
159 |
+
in_dict = channel_info["inputDict"]
|
160 |
+
matched = get_matched(tpl_order, tpl_dict)
|
161 |
+
|
162 |
+
# 2D alignment (for visualization purposes)
|
163 |
+
fig = [tpl_montage.plot(), in_montage.plot()]
|
164 |
+
ax = [fig[0].axes[0], fig[1].axes[0]]
|
165 |
+
|
166 |
+
# extract the displayed 2D coordinates from the plots
|
167 |
+
all_tpl = ax[0].collections[0].get_offsets().data
|
168 |
+
all_in= ax[1].collections[0].get_offsets().data
|
169 |
+
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
170 |
+
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
171 |
+
|
172 |
+
# apply TPS to transform in_channels positions to align with tpl_channels positions
|
173 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
174 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
175 |
+
|
176 |
+
# apply the transformation to all in_channels
|
177 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
178 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
179 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
180 |
+
|
181 |
+
# store the 2D positions
|
182 |
+
for i, channel in enumerate(tpl_order):
|
183 |
+
tpl_dict[channel]["coord_2d"] = all_tpl[i]
|
184 |
+
for i, channel in enumerate(in_order):
|
185 |
+
in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
186 |
+
|
187 |
+
|
188 |
+
# 3D alignment
|
189 |
+
all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
|
190 |
+
all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
|
191 |
+
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
192 |
+
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
193 |
+
|
194 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
195 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
196 |
+
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
197 |
+
|
198 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
199 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
200 |
+
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
202 |
+
|
203 |
+
# update in_channels' 3D positions
|
204 |
+
for i, channel in enumerate(in_order):
|
205 |
+
in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
206 |
+
|
207 |
+
channel_info.update({
|
208 |
+
"templateDict" : tpl_dict,
|
209 |
+
"inputDict" : in_dict
|
210 |
+
})
|
211 |
+
return channel_info
|
212 |
+
|
213 |
+
def find_neighbors(channel_info, missing_channels, new_idx):
|
214 |
+
in_order = channel_info["inputOrder"]
|
215 |
+
tpl_dict = channel_info["templateDict"]
|
216 |
+
in_dict = channel_info["inputDict"]
|
217 |
+
|
218 |
+
all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
|
219 |
+
empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
|
220 |
+
|
221 |
+
# use KNN to choose k nearest channels
|
222 |
+
k = 4 if len(in_order)>4 else len(in_order)
|
223 |
+
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
224 |
+
knn.fit(all_in)
|
225 |
+
for i, channel in enumerate(missing_channels):
|
226 |
+
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
227 |
+
idx = tpl_dict[channel]["index"]
|
228 |
+
new_idx[idx] = indices[0].tolist()
|
229 |
+
|
230 |
+
return new_idx
|
231 |
+
|
232 |
+
def match_names(stage1_info, channel_info):
|
233 |
+
# read the location file
|
234 |
+
loc_file = stage1_info["fileNames"]["input_loc"]
|
235 |
+
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
236 |
+
tpl_order = tpl_montage.ch_names
|
237 |
+
in_order = in_montage.ch_names
|
238 |
+
new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channels
|
239 |
+
fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
|
240 |
+
|
241 |
+
alias_dict = {
|
242 |
+
'T3': 'T7',
|
243 |
+
'T4': 'T8',
|
244 |
+
'T5': 'P7',
|
245 |
+
'T6': 'P8'
|
246 |
+
}
|
247 |
+
for i, channel in enumerate(tpl_order):
|
248 |
+
if channel in alias_dict and alias_dict[channel] in in_dict:
|
249 |
+
tpl_montage.rename_channels({channel: alias_dict[channel]})
|
250 |
+
tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
|
251 |
+
channel = alias_dict[channel]
|
252 |
+
|
253 |
+
if channel in in_dict:
|
254 |
+
new_idx[i] = [in_dict[channel]["index"]]
|
255 |
+
fill_flags[i] = False
|
256 |
+
tpl_dict[channel]["matched"] = True
|
257 |
+
in_dict[channel]["assigned"] = True
|
258 |
+
|
259 |
+
# update the names
|
260 |
+
tpl_order = tpl_montage.ch_names
|
261 |
+
|
262 |
+
stage1_info.update({
|
263 |
+
"unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
|
264 |
+
"missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
|
265 |
+
"mappingData" : [
|
266 |
+
{
|
267 |
+
"newOrder" : new_idx,
|
268 |
+
"fillFlags" : fill_flags
|
269 |
+
}
|
270 |
+
]
|
271 |
+
})
|
272 |
+
channel_info.update({
|
273 |
+
"templateOrder" : tpl_order,
|
274 |
+
"inputOrder" : in_order,
|
275 |
+
"templateDict" : tpl_dict,
|
276 |
+
"inputDict" : in_dict
|
277 |
+
})
|
278 |
+
return stage1_info, channel_info, tpl_montage, in_montage
|
279 |
+
|
280 |
+
def optimal_mapping(channel_info):
|
281 |
+
tpl_order = channel_info["templateOrder"]
|
282 |
+
in_order = channel_info["inputOrder"]
|
283 |
+
tpl_dict = channel_info["templateDict"]
|
284 |
+
in_dict = channel_info["inputDict"]
|
285 |
+
unassigned = get_unassigned_inputs(in_order, in_dict)
|
286 |
+
# reset all tpl.matched to False
|
287 |
+
for channel in tpl_dict:
|
288 |
+
tpl_dict[channel]["matched"] = False
|
289 |
+
|
290 |
+
all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
|
291 |
+
unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
|
292 |
+
|
293 |
+
# initialize the cost matrix for the Hungarian algorithm
|
294 |
+
if len(unassigned) < 30:
|
295 |
+
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
296 |
+
else:
|
297 |
+
cost_matrix = np.zeros((30, len(unassigned)))
|
298 |
+
# fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
|
299 |
+
for i in range(30):
|
300 |
+
for j in range(len(unassigned)):
|
301 |
+
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
|
302 |
+
|
303 |
+
# apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
|
304 |
+
# by minimizing the total distances between their positions.
|
305 |
+
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
306 |
+
|
307 |
+
# store the mapping results
|
308 |
+
new_idx = [[]]*30
|
309 |
+
fill_flags = [True]*30
|
310 |
+
for i, j in zip(row_idx, col_idx):
|
311 |
+
if j < len(unassigned): # filter out dummy channels
|
312 |
+
tpl_channel = tpl_order[i]
|
313 |
+
in_channel = unassigned[j]
|
314 |
+
|
315 |
+
new_idx[i] = [in_dict[in_channel]["index"]]
|
316 |
+
fill_flags[i] = False
|
317 |
+
tpl_dict[tpl_channel]["matched"] = True
|
318 |
+
in_dict[in_channel]["assigned"] = True
|
319 |
+
#print(f'{tpl_channel}({i}) <- {in_channel}({j})')
|
320 |
+
|
321 |
+
# fill the remaining empty tpl_channels
|
322 |
+
missing_channels = get_empty_templates(tpl_order, tpl_dict)
|
323 |
+
if missing_channels != []:
|
324 |
+
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
325 |
+
|
326 |
+
mapping_data = {
|
327 |
+
"newOrder" : new_idx,
|
328 |
+
"fillFlags" : fill_flags
|
329 |
+
}
|
330 |
+
channel_info.update({
|
331 |
+
"templateDict" : tpl_dict,
|
332 |
+
"inputDict" : in_dict
|
333 |
+
})
|
334 |
+
return mapping_data, channel_info
|
335 |
+
|
336 |
+
def mapping_result(stage1_info, stage2_info, channel_info, filename):
|
337 |
+
unassigned_num = len(stage1_info["unassignedInputs"])
|
338 |
+
batch_num = math.ceil(unassigned_num/30) + 1
|
339 |
+
|
340 |
+
# map the remaining in_channels
|
341 |
+
for i in range(1, batch_num):
|
342 |
+
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
343 |
+
new_mapping_data, channel_info = optimal_mapping(channel_info)
|
344 |
+
stage1_info["mappingData"] += [new_mapping_data]
|
345 |
+
|
346 |
+
# save the mapping results
|
347 |
+
new_dict = {
|
348 |
+
#"templateOrder" : channel_info["templateOrder"],
|
349 |
+
#"inputOrder" : channel_info["inputOrder"],
|
350 |
+
"batchNum" : batch_num,
|
351 |
+
"mappingData" : stage1_info["mappingData"]
|
352 |
+
}
|
353 |
+
with open(filename, 'w') as jsonfile:
|
354 |
+
jsonfile.write(json.dumps(new_dict))
|
355 |
+
|
356 |
+
stage2_info["totalBatchNum"] = batch_num
|
357 |
+
return stage1_info, stage2_info, channel_info
|
358 |
+
|
|
|
|