audrey06100 commited on
Commit
71e584e
·
verified ·
1 Parent(s): 69ee4d8

Update app_utils.py

Browse files
Files changed (1) hide show
  1. app_utils.py +358 -360
app_utils.py CHANGED
@@ -1,360 +1,358 @@
1
- import utils
2
- import os
3
- import math
4
- import json
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import mne
8
- from mne.channels import read_custom_montage
9
- from scipy.interpolate import Rbf
10
- from scipy.optimize import linear_sum_assignment
11
- from sklearn.neighbors import NearestNeighbors
12
-
13
- def reorder_data(idx_order, fill_flags, filename, new_filename):
14
- # read the input data
15
- raw_data = utils.read_train_data(filename)
16
- #print(raw_data.shape)
17
- new_data = np.zeros((30, raw_data.shape[1]))
18
-
19
- zero_arr = np.zeros((1, raw_data.shape[1]))
20
- for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
21
- if flag == False:
22
- new_data[i, :] = raw_data[idx_set[0], :]
23
- elif idx_set == []:
24
- new_data[i, :] = zero_arr
25
- else:
26
- tmp_data = [raw_data[j, :] for j in idx_set]
27
- new_data[i, :] = np.mean(tmp_data, axis=0)
28
-
29
- utils.save_data(new_data, new_filename)
30
- return raw_data.shape
31
-
32
- def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
33
- # read the denoised data
34
- d_data = utils.read_train_data(filename)
35
- if batch_cnt == 0:
36
- new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
37
- #print(new_data.shape)
38
- else:
39
- new_data = utils.read_train_data(new_filename)
40
-
41
- for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
- # ignore if this channel was filled using "fillmode"
43
- if flag == False:
44
- new_data[idx_set[0], :] = d_data[i, :]
45
-
46
- utils.save_data(new_data, new_filename)
47
- return
48
-
49
- def get_matched(tpl_order, tpl_dict):
50
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
51
-
52
- def get_empty_templates(tpl_order, tpl_dict):
53
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
54
-
55
- def get_unassigned_inputs(in_order, in_dict):
56
- return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
57
-
58
- def read_montage_data(loc_file):
59
- tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
- in_montage = read_custom_montage(loc_file)
61
- tpl_order = tpl_montage.ch_names
62
- in_order = in_montage.ch_names
63
- tpl_dict = {}
64
- in_dict = {}
65
-
66
- # convert all channel names to uppercase and store the channel information
67
- for i, channel in enumerate(tpl_order):
68
- up_channel = str.upper(channel)
69
- tpl_montage.rename_channels({channel: up_channel})
70
- tpl_dict[up_channel] = {
71
- "index" : i,
72
- "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
73
- "matched" : False
74
- }
75
- for i, channel in enumerate(in_order):
76
- up_channel = str.upper(channel)
77
- in_montage.rename_channels({channel: up_channel})
78
- in_dict[up_channel] = {
79
- "index" : i,
80
- "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
81
- "assigned" : False
82
- }
83
- return tpl_montage, in_montage, tpl_dict, in_dict
84
-
85
- def save_figures(channel_info, tpl_montage, filename1, filename2):
86
- tpl_order = channel_info["templateOrder"]
87
- in_order = channel_info["inputOrder"]
88
- tpl_dict = channel_info["templateDict"]
89
- in_dict = channel_info["inputDict"]
90
-
91
- tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
92
- tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
93
- in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
94
- in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
95
- tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
- in_coords = np.vstack((in_x, in_y)).T
97
-
98
- # extract template's head figure
99
- tpl_fig = tpl_montage.plot()
100
- tpl_ax = tpl_fig.axes[0]
101
- lines = tpl_ax.lines
102
- head_lines = []
103
- for line in lines:
104
- x, y = line.get_data()
105
- head_lines.append((x,y))
106
- plt.close()
107
-
108
- # -------------------------plot input montage------------------------------
109
- fig = plt.figure(figsize=(6.4,6.4), dpi=100)
110
- ax = fig.add_subplot(111)
111
- fig.tight_layout()
112
- ax.set_aspect('equal')
113
- ax.axis('off')
114
-
115
- # plot template's head
116
- for x, y in head_lines:
117
- ax.plot(x, y, color='black', linewidth=1.0)
118
- # plot in_channels on it
119
- ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
120
- for i, channel in enumerate(in_order):
121
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
122
- # save input_montage
123
- fig.savefig(filename1)
124
-
125
- # ---------------------------add indications-------------------------------
126
- # plot unmatched input channels in red
127
- indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
128
- ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
129
- for i in indices:
130
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
131
- # save mapped_montage
132
- fig.savefig(filename2)
133
-
134
- # -------------------------------------------------------------------------
135
- # store the tpl and in_channels' display positions (in px).
136
- tpl_coords = ax.transData.transform(tpl_coords)
137
- in_coords = ax.transData.transform(in_coords)
138
- plt.close()
139
-
140
- for i, channel in enumerate(tpl_order):
141
- css_left = (tpl_coords[i,0]-11)/6.4
142
- css_bottom = (tpl_coords[i,1]-7)/6.4
143
- tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
144
- for i, channel in enumerate(in_order):
145
- css_left = (in_coords[i,0]-11)/6.4
146
- css_bottom = (in_coords[i,1]-7)/6.4
147
- in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
148
-
149
- channel_info.update({
150
- "templateDict" : tpl_dict,
151
- "inputDict" : in_dict
152
- })
153
- return channel_info
154
-
155
- def align_coords(channel_info, tpl_montage, in_montage):
156
- tpl_order = channel_info["templateOrder"]
157
- in_order = channel_info["inputOrder"]
158
- tpl_dict = channel_info["templateDict"]
159
- in_dict = channel_info["inputDict"]
160
- matched = get_matched(tpl_order, tpl_dict)
161
-
162
- # 2D alignment (for visualization purposes)
163
- fig = [tpl_montage.plot(), in_montage.plot()]
164
- ax = [fig[0].axes[0], fig[1].axes[0]]
165
-
166
- # extract the displayed 2D coordinates from the plots
167
- all_tpl = ax[0].collections[0].get_offsets().data
168
- all_in= ax[1].collections[0].get_offsets().data
169
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
170
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
171
-
172
- # apply TPS to transform in_channels positions to align with tpl_channels positions
173
- rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
174
- rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
175
-
176
- # apply the transformation to all in_channels
177
- transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
178
- transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
179
- transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
180
-
181
- # store the 2D positions
182
- for i, channel in enumerate(tpl_order):
183
- tpl_dict[channel]["coord_2d"] = all_tpl[i]
184
- for i, channel in enumerate(in_order):
185
- in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
186
-
187
-
188
- # 3D alignment
189
- all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
- all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
192
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
193
-
194
- rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
- rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
196
- rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
197
-
198
- transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
199
- transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
200
- transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
201
- transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
202
-
203
- # update in_channels' 3D positions
204
- for i, channel in enumerate(in_order):
205
- in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
206
-
207
- channel_info.update({
208
- "templateDict" : tpl_dict,
209
- "inputDict" : in_dict
210
- })
211
- return channel_info
212
-
213
- def find_neighbors(channel_info, missing_channels, new_idx):
214
- in_order = channel_info["inputOrder"]
215
- tpl_dict = channel_info["templateDict"]
216
- in_dict = channel_info["inputDict"]
217
-
218
- all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
219
- empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
220
-
221
- # use KNN to choose k nearest channels
222
- k = 4 if len(in_order)>4 else len(in_order)
223
- knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
224
- knn.fit(all_in)
225
- for i, channel in enumerate(missing_channels):
226
- distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
227
- idx = tpl_dict[channel]["index"]
228
- new_idx[idx] = indices[0].tolist()
229
-
230
- return new_idx
231
-
232
- def match_names(stage1_info, channel_info):
233
- # read the location file
234
- loc_file = stage1_info["fileNames"]["input_loc"]
235
- tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
236
- tpl_order = tpl_montage.ch_names
237
- in_order = in_montage.ch_names
238
- new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channls
239
- fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
240
-
241
- alias_dict = {
242
- 'T3': 'T7',
243
- 'T4': 'T8',
244
- 'T5': 'P7',
245
- 'T6': 'P8'
246
- }
247
- for i, channel in enumerate(tpl_order):
248
- if channel in alias_dict and alias_dict[channel] in in_dict:
249
- tpl_montage.rename_channels({channel: alias_dict[channel]})
250
- tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
251
- channel = alias_dict[channel]
252
-
253
- if channel in in_dict:
254
- new_idx[i] = [in_dict[channel]["index"]]
255
- fill_flags[i] = False
256
- tpl_dict[channel]["matched"] = True
257
- in_dict[channel]["assigned"] = True
258
-
259
- # update the names
260
- tpl_order = tpl_montage.ch_names
261
-
262
- stage1_info.update({
263
- "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
264
- "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
265
- "mappingData" : [
266
- {
267
- "newOrder" : new_idx,
268
- "fillFlags" : fill_flags
269
- }
270
- ]
271
- })
272
- channel_info.update({
273
- "templateOrder" : tpl_order,
274
- "inputOrder" : in_order,
275
- "templateDict" : tpl_dict,
276
- "inputDict" : in_dict
277
- })
278
- return stage1_info, channel_info, tpl_montage, in_montage
279
-
280
- def optimal_mapping(channel_info):
281
- tpl_order = channel_info["templateOrder"]
282
- in_order = channel_info["inputOrder"]
283
- tpl_dict = channel_info["templateDict"]
284
- in_dict = channel_info["inputDict"]
285
- unassigned = get_unassigned_inputs(in_order, in_dict)
286
- # reset all tpl.matched to False
287
- for channel in tpl_dict:
288
- tpl_dict[channel]["matched"] = False
289
-
290
- all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
291
- unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
292
-
293
- # initialize the cost matrix for the Hungarian algorithm
294
- if len(unassigned) < 30:
295
- cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
296
- else:
297
- cost_matrix = np.zeros((30, len(unassigned)))
298
- # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
299
- for i in range(30):
300
- for j in range(len(unassigned)):
301
- cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
302
-
303
- # apply the Hungarian algorithm to optimally assign each in_channel to a tpl_channel
304
- # by minimizing the total distances between their positions.
305
- row_idx, col_idx = linear_sum_assignment(cost_matrix)
306
-
307
- # store the mapping results
308
- new_idx = [[]]*30
309
- fill_flags = [True]*30
310
- for i in range(30):
311
- if col_idx[i] < len(unassigned): # filter out dummy channels
312
- tpl_channel = tpl_order[row_idx[i]]
313
- in_channel = unassigned[col_idx[i]]
314
-
315
- new_idx[row_idx[i]] = [in_dict[in_channel]["index"]]
316
- fill_flags[row_idx[i]] = False
317
- tpl_dict[tpl_channel]["matched"] = True
318
- in_dict[in_channel]["assigned"] = True
319
- #print(f'{tpl_channel}({row_idx[i]}) <- {in_channel}({col_idx[i]})')
320
-
321
- # fill the remaining empty tpl_channels
322
- missing_channels = get_empty_templates(tpl_order, tpl_dict)
323
- if missing_channels != []:
324
- new_idx = find_neighbors(channel_info, missing_channels, new_idx)
325
-
326
- mapping_data = {
327
- "newOrder" : new_idx,
328
- "fillFlags" : fill_flags
329
- }
330
- channel_info.update({
331
- "templateDict" : tpl_dict,
332
- "inputDict" : in_dict
333
- })
334
- return mapping_data, channel_info
335
-
336
- def mapping_result(stage1_info, stage2_info, channel_info, filename):
337
- unassigned_num = len(stage1_info["unassignedInputs"])
338
- batch_num = math.ceil(unassigned_num/30) + 1
339
-
340
- # map the remaining in_channels
341
- unassigned_num = len(stage1_info["unassignedInputs"])
342
- batch_num = math.ceil(unassigned_num/30) + 1
343
- for i in range(1, batch_num):
344
- # optimally select 30 in_channels to map to the tpl_channels based on proximity
345
- new_mapping_data, channel_info = optimal_mapping(channel_info)
346
- stage1_info["mappingData"] += [new_mapping_data]
347
-
348
- # save the mapping results
349
- new_dict = {
350
- #"templateOrder" : channel_info["templateOrder"],
351
- #"inputOrder" : channel_info["inputOrder"],
352
- "batchNum" : batch_num,
353
- "mappingData" : stage1_info["mappingData"]
354
- }
355
- with open(filename, 'w') as jsonfile:
356
- jsonfile.write(json.dumps(new_dict))
357
-
358
- stage2_info["totalBatchNum"] = batch_num
359
- return stage1_info, stage2_info, channel_info
360
-
 
1
+ import utils
2
+ import os
3
+ import math
4
+ import json
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import mne
8
+ from mne.channels import read_custom_montage
9
+ from scipy.interpolate import Rbf
10
+ from scipy.optimize import linear_sum_assignment
11
+ from sklearn.neighbors import NearestNeighbors
12
+
13
+ def reorder_data(idx_order, fill_flags, filename, new_filename):
14
+ # read the input data
15
+ raw_data = utils.read_train_data(filename)
16
+ #print(raw_data.shape)
17
+ new_data = np.zeros((30, raw_data.shape[1]))
18
+
19
+ zero_arr = np.zeros((1, raw_data.shape[1]))
20
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
21
+ if flag == False:
22
+ new_data[i, :] = raw_data[idx_set[0], :]
23
+ elif idx_set == []:
24
+ new_data[i, :] = zero_arr
25
+ else:
26
+ tmp_data = [raw_data[j, :] for j in idx_set]
27
+ new_data[i, :] = np.mean(tmp_data, axis=0)
28
+
29
+ utils.save_data(new_data, new_filename)
30
+ return raw_data.shape
31
+
32
+ def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
33
+ # read the denoised data
34
+ d_data = utils.read_train_data(filename)
35
+ if batch_cnt == 0:
36
+ new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
37
+ #print(new_data.shape)
38
+ else:
39
+ new_data = utils.read_train_data(new_filename)
40
+
41
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
+ # ignore if this channel was filled using "fillmode"
43
+ if flag == False:
44
+ new_data[idx_set[0], :] = d_data[i, :]
45
+
46
+ utils.save_data(new_data, new_filename)
47
+ return
48
+
49
+ def get_matched(tpl_order, tpl_dict):
50
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
51
+
52
+ def get_empty_templates(tpl_order, tpl_dict):
53
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
54
+
55
+ def get_unassigned_inputs(in_order, in_dict):
56
+ return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
57
+
58
+ def read_montage_data(loc_file):
59
+ tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
+ in_montage = read_custom_montage(loc_file)
61
+ tpl_order = tpl_montage.ch_names
62
+ in_order = in_montage.ch_names
63
+ tpl_dict = {}
64
+ in_dict = {}
65
+
66
+ # convert all channel names to uppercase and store the channel information
67
+ for i, channel in enumerate(tpl_order):
68
+ up_channel = str.upper(channel)
69
+ tpl_montage.rename_channels({channel: up_channel})
70
+ tpl_dict[up_channel] = {
71
+ "index" : i,
72
+ "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
73
+ "matched" : False
74
+ }
75
+ for i, channel in enumerate(in_order):
76
+ up_channel = str.upper(channel)
77
+ in_montage.rename_channels({channel: up_channel})
78
+ in_dict[up_channel] = {
79
+ "index" : i,
80
+ "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
81
+ "assigned" : False
82
+ }
83
+ return tpl_montage, in_montage, tpl_dict, in_dict
84
+
85
+ def save_figures(channel_info, tpl_montage, filename1, filename2):
86
+ tpl_order = channel_info["templateOrder"]
87
+ in_order = channel_info["inputOrder"]
88
+ tpl_dict = channel_info["templateDict"]
89
+ in_dict = channel_info["inputDict"]
90
+
91
+ tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
92
+ tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
93
+ in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
94
+ in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
95
+ tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
+ in_coords = np.vstack((in_x, in_y)).T
97
+
98
+ # extract template's head figure
99
+ tpl_fig = tpl_montage.plot()
100
+ tpl_ax = tpl_fig.axes[0]
101
+ lines = tpl_ax.lines
102
+ head_lines = []
103
+ for line in lines:
104
+ x, y = line.get_data()
105
+ head_lines.append((x,y))
106
+ plt.close()
107
+
108
+ # -------------------------plot input montage------------------------------
109
+ fig = plt.figure(figsize=(6.4,6.4), dpi=100)
110
+ ax = fig.add_subplot(111)
111
+ fig.tight_layout()
112
+ ax.set_aspect('equal')
113
+ ax.axis('off')
114
+
115
+ # plot template's head
116
+ for x, y in head_lines:
117
+ ax.plot(x, y, color='black', linewidth=1.0)
118
+ # plot in_channels on it
119
+ ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
120
+ for i, channel in enumerate(in_order):
121
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
122
+ # save input_montage
123
+ fig.savefig(filename1)
124
+
125
+ # ---------------------------add indications-------------------------------
126
+ # plot unmatched input channels in red
127
+ indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
128
+ ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
129
+ for i in indices:
130
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
131
+ # save mapped_montage
132
+ fig.savefig(filename2)
133
+
134
+ # -------------------------------------------------------------------------
135
+ # store the tpl and in_channels' display positions (in px).
136
+ tpl_coords = ax.transData.transform(tpl_coords)
137
+ in_coords = ax.transData.transform(in_coords)
138
+ plt.close()
139
+
140
+ for i, channel in enumerate(tpl_order):
141
+ css_left = (tpl_coords[i,0]-11)/6.4
142
+ css_bottom = (tpl_coords[i,1]-7)/6.4
143
+ tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
144
+ for i, channel in enumerate(in_order):
145
+ css_left = (in_coords[i,0]-11)/6.4
146
+ css_bottom = (in_coords[i,1]-7)/6.4
147
+ in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
148
+
149
+ channel_info.update({
150
+ "templateDict" : tpl_dict,
151
+ "inputDict" : in_dict
152
+ })
153
+ return channel_info
154
+
155
+ def align_coords(channel_info, tpl_montage, in_montage):
156
+ tpl_order = channel_info["templateOrder"]
157
+ in_order = channel_info["inputOrder"]
158
+ tpl_dict = channel_info["templateDict"]
159
+ in_dict = channel_info["inputDict"]
160
+ matched = get_matched(tpl_order, tpl_dict)
161
+
162
+ # 2D alignment (for visualization purposes)
163
+ fig = [tpl_montage.plot(), in_montage.plot()]
164
+ ax = [fig[0].axes[0], fig[1].axes[0]]
165
+
166
+ # extract the displayed 2D coordinates from the plots
167
+ all_tpl = ax[0].collections[0].get_offsets().data
168
+ all_in= ax[1].collections[0].get_offsets().data
169
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
170
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
171
+
172
+ # apply TPS to transform in_channels positions to align with tpl_channels positions
173
+ rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
174
+ rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
175
+
176
+ # apply the transformation to all in_channels
177
+ transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
178
+ transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
179
+ transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
180
+
181
+ # store the 2D positions
182
+ for i, channel in enumerate(tpl_order):
183
+ tpl_dict[channel]["coord_2d"] = all_tpl[i]
184
+ for i, channel in enumerate(in_order):
185
+ in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
186
+
187
+
188
+ # 3D alignment
189
+ all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
+ all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
192
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
193
+
194
+ rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
+ rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
196
+ rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
197
+
198
+ transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
199
+ transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
200
+ transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
201
+ transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
202
+
203
+ # update in_channels' 3D positions
204
+ for i, channel in enumerate(in_order):
205
+ in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
206
+
207
+ channel_info.update({
208
+ "templateDict" : tpl_dict,
209
+ "inputDict" : in_dict
210
+ })
211
+ return channel_info
212
+
213
+ def find_neighbors(channel_info, missing_channels, new_idx):
214
+ in_order = channel_info["inputOrder"]
215
+ tpl_dict = channel_info["templateDict"]
216
+ in_dict = channel_info["inputDict"]
217
+
218
+ all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
219
+ empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
220
+
221
+ # use KNN to choose k nearest channels
222
+ k = 4 if len(in_order)>4 else len(in_order)
223
+ knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
224
+ knn.fit(all_in)
225
+ for i, channel in enumerate(missing_channels):
226
+ distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
227
+ idx = tpl_dict[channel]["index"]
228
+ new_idx[idx] = indices[0].tolist()
229
+
230
+ return new_idx
231
+
232
+ def match_names(stage1_info, channel_info):
233
+ # read the location file
234
+ loc_file = stage1_info["fileNames"]["input_loc"]
235
+ tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
236
+ tpl_order = tpl_montage.ch_names
237
+ in_order = in_montage.ch_names
238
+ new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channels
239
+ fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
240
+
241
+ alias_dict = {
242
+ 'T3': 'T7',
243
+ 'T4': 'T8',
244
+ 'T5': 'P7',
245
+ 'T6': 'P8'
246
+ }
247
+ for i, channel in enumerate(tpl_order):
248
+ if channel in alias_dict and alias_dict[channel] in in_dict:
249
+ tpl_montage.rename_channels({channel: alias_dict[channel]})
250
+ tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
251
+ channel = alias_dict[channel]
252
+
253
+ if channel in in_dict:
254
+ new_idx[i] = [in_dict[channel]["index"]]
255
+ fill_flags[i] = False
256
+ tpl_dict[channel]["matched"] = True
257
+ in_dict[channel]["assigned"] = True
258
+
259
+ # update the names
260
+ tpl_order = tpl_montage.ch_names
261
+
262
+ stage1_info.update({
263
+ "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
264
+ "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
265
+ "mappingData" : [
266
+ {
267
+ "newOrder" : new_idx,
268
+ "fillFlags" : fill_flags
269
+ }
270
+ ]
271
+ })
272
+ channel_info.update({
273
+ "templateOrder" : tpl_order,
274
+ "inputOrder" : in_order,
275
+ "templateDict" : tpl_dict,
276
+ "inputDict" : in_dict
277
+ })
278
+ return stage1_info, channel_info, tpl_montage, in_montage
279
+
280
+ def optimal_mapping(channel_info):
281
+ tpl_order = channel_info["templateOrder"]
282
+ in_order = channel_info["inputOrder"]
283
+ tpl_dict = channel_info["templateDict"]
284
+ in_dict = channel_info["inputDict"]
285
+ unassigned = get_unassigned_inputs(in_order, in_dict)
286
+ # reset all tpl.matched to False
287
+ for channel in tpl_dict:
288
+ tpl_dict[channel]["matched"] = False
289
+
290
+ all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
291
+ unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
292
+
293
+ # initialize the cost matrix for the Hungarian algorithm
294
+ if len(unassigned) < 30:
295
+ cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
296
+ else:
297
+ cost_matrix = np.zeros((30, len(unassigned)))
298
+ # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
299
+ for i in range(30):
300
+ for j in range(len(unassigned)):
301
+ cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
302
+
303
+ # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
304
+ # by minimizing the total distances between their positions.
305
+ row_idx, col_idx = linear_sum_assignment(cost_matrix)
306
+
307
+ # store the mapping results
308
+ new_idx = [[]]*30
309
+ fill_flags = [True]*30
310
+ for i, j in zip(row_idx, col_idx):
311
+ if j < len(unassigned): # filter out dummy channels
312
+ tpl_channel = tpl_order[i]
313
+ in_channel = unassigned[j]
314
+
315
+ new_idx[i] = [in_dict[in_channel]["index"]]
316
+ fill_flags[i] = False
317
+ tpl_dict[tpl_channel]["matched"] = True
318
+ in_dict[in_channel]["assigned"] = True
319
+ #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
320
+
321
+ # fill the remaining empty tpl_channels
322
+ missing_channels = get_empty_templates(tpl_order, tpl_dict)
323
+ if missing_channels != []:
324
+ new_idx = find_neighbors(channel_info, missing_channels, new_idx)
325
+
326
+ mapping_data = {
327
+ "newOrder" : new_idx,
328
+ "fillFlags" : fill_flags
329
+ }
330
+ channel_info.update({
331
+ "templateDict" : tpl_dict,
332
+ "inputDict" : in_dict
333
+ })
334
+ return mapping_data, channel_info
335
+
336
+ def mapping_result(stage1_info, stage2_info, channel_info, filename):
337
+ unassigned_num = len(stage1_info["unassignedInputs"])
338
+ batch_num = math.ceil(unassigned_num/30) + 1
339
+
340
+ # map the remaining in_channels
341
+ for i in range(1, batch_num):
342
+ # optimally select 30 in_channels to map to the tpl_channels based on proximity
343
+ new_mapping_data, channel_info = optimal_mapping(channel_info)
344
+ stage1_info["mappingData"] += [new_mapping_data]
345
+
346
+ # save the mapping results
347
+ new_dict = {
348
+ #"templateOrder" : channel_info["templateOrder"],
349
+ #"inputOrder" : channel_info["inputOrder"],
350
+ "batchNum" : batch_num,
351
+ "mappingData" : stage1_info["mappingData"]
352
+ }
353
+ with open(filename, 'w') as jsonfile:
354
+ jsonfile.write(json.dumps(new_dict))
355
+
356
+ stage2_info["totalBatchNum"] = batch_num
357
+ return stage1_info, stage2_info, channel_info
358
+