audrey06100 commited on
Commit
df70562
·
1 Parent(s): ed22689
Files changed (2) hide show
  1. app.py +26 -26
  2. app_utils.py +92 -92
app.py CHANGED
@@ -60,7 +60,7 @@ init_js = """
60
  channel_info = JSON.parse(JSON.stringify(channel_info));
61
 
62
  let selector, attribute;
63
- let channel, left, bottom;
64
 
65
  if(stage1_info.state == "step2-selecting"){
66
  selector = "#radio-group > div:nth-of-type(2)";
@@ -84,9 +84,9 @@ init_js = """
84
  // move the radios/checkboxes
85
  let all_elem = document.querySelectorAll(selector+" > label");
86
  Array.from(all_elem).forEach(item => {
87
- channel = item.querySelector("input").getAttribute(attribute);
88
- left = channel_info.inputDict[channel].css_position[0];
89
- bottom = channel_info.inputDict[channel].css_position[1];
90
 
91
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
92
  item.className = "";
@@ -94,9 +94,9 @@ init_js = """
94
  });
95
 
96
  // add indication for the empty tpl_channels
97
- channel = stage1_info.emptyTemplates[0];
98
- left = channel_info.templateDict[channel].css_position[0];
99
- bottom = channel_info.templateDict[channel].css_position[1];
100
  let dot_rule = `
101
  ${selector}::before {
102
  content: "";
@@ -116,7 +116,7 @@ init_js = """
116
  bottom = bottom.toString()+"%";
117
  let txt_rule = `
118
  ${selector}::after {
119
- content: "${channel}";
120
  position: absolute;
121
  color: red;
122
  left: ${left};
@@ -144,7 +144,7 @@ update_js = """
144
  channel_info = JSON.parse(JSON.stringify(channel_info));
145
 
146
  let selector;
147
- let cnt, channel, left, bottom;
148
 
149
  if(stage1_info.state == "step2-selecting"){
150
  selector = "#radio-group > div:nth-of-type(2)";
@@ -153,9 +153,9 @@ update_js = """
153
  // update the radios
154
  let all_elem = document.querySelectorAll(selector+" > label");
155
  Array.from(all_elem).forEach(item => {
156
- channel = item.querySelector("input").value;
157
- left = channel_info.inputDict[channel].css_position[0];
158
- bottom = channel_info.inputDict[channel].css_position[1];
159
 
160
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
161
  item.className = "";
@@ -167,9 +167,9 @@ update_js = """
167
  }else return;
168
 
169
  // update the indication
170
- channel = stage1_info.emptyTemplates[cnt-1];
171
- left = channel_info.templateDict[channel].css_position[0];
172
- bottom = channel_info.templateDict[channel].css_position[1];
173
  let dot_rule = `
174
  ${selector}::before {
175
  content: "";
@@ -189,7 +189,7 @@ update_js = """
189
  bottom = bottom.toString()+"%";
190
  let txt_rule = `
191
  ${selector}::after {
192
- content: "${channel}";
193
  position: absolute;
194
  color: red;
195
  left: ${left};
@@ -403,7 +403,7 @@ with gr.Blocks() as demo:
403
 
404
  # ========================================step1=========================================
405
  elif stage1_info["state"] == "step1-finished":
406
- in_num = len(channel_info["inputOrder"])
407
  matched_num = 30 - len(stage1_info["emptyTemplates"])
408
 
409
  # step1 to step4
@@ -490,10 +490,10 @@ with gr.Blocks() as demo:
490
 
491
  # --------------------------------update information--------------------------------
492
  # exclude the selected in_channel of the previous round
493
- stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
494
  channel_info["inputDict"])
495
  # exclude the tpl_channels filled in step2
496
- stage1_info["emptyTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
497
  channel_info["templateDict"])
498
  # -----------------------------determine the next step------------------------------
499
  # step2 to step4
@@ -571,7 +571,7 @@ with gr.Blocks() as demo:
571
 
572
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
573
  value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
574
- value = [channel_info["inputOrder"][i] for i in value]
575
 
576
  stage1_info["state"] = "step3-2-selecting"
577
  # determine which button to display
@@ -580,7 +580,7 @@ with gr.Blocks() as demo:
580
  desc_md : gr.Markdown(md),
581
  in_fillmode : gr.Dropdown(visible=False),
582
  fillmode_btn : gr.Button(visible=False),
583
- chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
584
  value=value, label=label, visible=True),
585
  next_btn : gr.Button(visible=True)}
586
  else:
@@ -588,7 +588,7 @@ with gr.Blocks() as demo:
588
  desc_md : gr.Markdown(md),
589
  in_fillmode : gr.Dropdown(visible=False),
590
  fillmode_btn : gr.Button(visible=False),
591
- chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
592
  value=value, label=label, visible=True),
593
  step3_btn : gr.Button(visible=True)}
594
 
@@ -598,7 +598,7 @@ with gr.Blocks() as demo:
598
  # --------------------------------store information---------------------------------
599
  prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
600
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
601
- sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_chkbox]
602
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
603
  #print(prev_tpl_name, '<-', sel_chkbox)
604
  # ----------------------------------------------------------------------------------
@@ -687,7 +687,7 @@ with gr.Blocks() as demo:
687
  step2["count"] += 1
688
 
689
  # exclude the selected in_channel of the previous round
690
- stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
691
 
692
  tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
693
  label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
@@ -726,7 +726,7 @@ with gr.Blocks() as demo:
726
  # ----------------------------------store information-----------------------------------
727
  prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
728
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
729
- sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_name]
730
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
731
  #print(prev_tpl_name, '<-', sel_name)
732
 
@@ -738,7 +738,7 @@ with gr.Blocks() as demo:
738
 
739
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
740
  value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
741
- value = [channel_info["inputOrder"][i] for i in value]
742
 
743
  stage1_info["step3"] = step3
744
  # determine which button to display
 
60
  channel_info = JSON.parse(JSON.stringify(channel_info));
61
 
62
  let selector, attribute;
63
+ let name, left, bottom;
64
 
65
  if(stage1_info.state == "step2-selecting"){
66
  selector = "#radio-group > div:nth-of-type(2)";
 
84
  // move the radios/checkboxes
85
  let all_elem = document.querySelectorAll(selector+" > label");
86
  Array.from(all_elem).forEach(item => {
87
+ name = item.querySelector("input").getAttribute(attribute);
88
+ left = channel_info.inputDict[name].css_position[0];
89
+ bottom = channel_info.inputDict[name].css_position[1];
90
 
91
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
92
  item.className = "";
 
94
  });
95
 
96
  // add indication for the empty tpl_channels
97
+ name = stage1_info.emptyTemplates[0];
98
+ left = channel_info.templateDict[name].css_position[0];
99
+ bottom = channel_info.templateDict[name].css_position[1];
100
  let dot_rule = `
101
  ${selector}::before {
102
  content: "";
 
116
  bottom = bottom.toString()+"%";
117
  let txt_rule = `
118
  ${selector}::after {
119
+ content: "${name}";
120
  position: absolute;
121
  color: red;
122
  left: ${left};
 
144
  channel_info = JSON.parse(JSON.stringify(channel_info));
145
 
146
  let selector;
147
+ let cnt, name, left, bottom;
148
 
149
  if(stage1_info.state == "step2-selecting"){
150
  selector = "#radio-group > div:nth-of-type(2)";
 
153
  // update the radios
154
  let all_elem = document.querySelectorAll(selector+" > label");
155
  Array.from(all_elem).forEach(item => {
156
+ name = item.querySelector("input").value;
157
+ left = channel_info.inputDict[name].css_position[0];
158
+ bottom = channel_info.inputDict[name].css_position[1];
159
 
160
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
161
  item.className = "";
 
167
  }else return;
168
 
169
  // update the indication
170
+ name = stage1_info.emptyTemplates[cnt-1];
171
+ left = channel_info.templateDict[name].css_position[0];
172
+ bottom = channel_info.templateDict[name].css_position[1];
173
  let dot_rule = `
174
  ${selector}::before {
175
  content: "";
 
189
  bottom = bottom.toString()+"%";
190
  let txt_rule = `
191
  ${selector}::after {
192
+ content: "${name}";
193
  position: absolute;
194
  color: red;
195
  left: ${left};
 
403
 
404
  # ========================================step1=========================================
405
  elif stage1_info["state"] == "step1-finished":
406
+ in_num = len(channel_info["inputNames"])
407
  matched_num = 30 - len(stage1_info["emptyTemplates"])
408
 
409
  # step1 to step4
 
490
 
491
  # --------------------------------update information--------------------------------
492
  # exclude the selected in_channel of the previous round
493
+ stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputNames"],
494
  channel_info["inputDict"])
495
  # exclude the tpl_channels filled in step2
496
+ stage1_info["emptyTemplates"] = app_utils.get_empty_templates(channel_info["templateNames"],
497
  channel_info["templateDict"])
498
  # -----------------------------determine the next step------------------------------
499
  # step2 to step4
 
571
 
572
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
573
  value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
574
+ value = [channel_info["inputNames"][i] for i in value]
575
 
576
  stage1_info["state"] = "step3-2-selecting"
577
  # determine which button to display
 
580
  desc_md : gr.Markdown(md),
581
  in_fillmode : gr.Dropdown(visible=False),
582
  fillmode_btn : gr.Button(visible=False),
583
+ chkbox_group : gr.CheckboxGroup(choices=channel_info["inputNames"],
584
  value=value, label=label, visible=True),
585
  next_btn : gr.Button(visible=True)}
586
  else:
 
588
  desc_md : gr.Markdown(md),
589
  in_fillmode : gr.Dropdown(visible=False),
590
  fillmode_btn : gr.Button(visible=False),
591
+ chkbox_group : gr.CheckboxGroup(choices=channel_info["inputNames"],
592
  value=value, label=label, visible=True),
593
  step3_btn : gr.Button(visible=True)}
594
 
 
598
  # --------------------------------store information---------------------------------
599
  prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
600
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
601
+ sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
602
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
603
  #print(prev_tpl_name, '<-', sel_chkbox)
604
  # ----------------------------------------------------------------------------------
 
687
  step2["count"] += 1
688
 
689
  # exclude the selected in_channel of the previous round
690
+ stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputNames"], channel_info["inputDict"])
691
 
692
  tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
693
  label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
 
726
  # ----------------------------------store information-----------------------------------
727
  prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
728
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
729
+ sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
730
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
731
  #print(prev_tpl_name, '<-', sel_name)
732
 
 
738
 
739
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
740
  value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
741
+ value = [channel_info["inputNames"][i] for i in value]
742
 
743
  stage1_info["step3"] = step3
744
  # determine which button to display
app_utils.py CHANGED
@@ -46,52 +46,52 @@ def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, ou
46
  utils.save_data(new_data, outputname)
47
  return
48
 
49
- def get_matched(tpl_order, tpl_dict):
50
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
51
 
52
- def get_empty_templates(tpl_order, tpl_dict):
53
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
54
 
55
- def get_unassigned_inputs(in_order, in_dict):
56
- return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
57
 
58
  def read_montage_data(loc_file):
59
  tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
  in_montage = read_custom_montage(loc_file)
61
- tpl_order = tpl_montage.ch_names
62
- in_order = in_montage.ch_names
63
  tpl_dict = {}
64
  in_dict = {}
65
 
66
  # convert all channel names to uppercase and store their information
67
- for i, channel in enumerate(tpl_order):
68
- up_channel = str.upper(channel)
69
- tpl_montage.rename_channels({channel: up_channel})
70
- tpl_dict[up_channel] = {
71
  "index" : i,
72
- "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
73
  "matched" : False
74
  }
75
- for i, channel in enumerate(in_order):
76
- up_channel = str.upper(channel)
77
- in_montage.rename_channels({channel: up_channel})
78
- in_dict[up_channel] = {
79
  "index" : i,
80
- "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
81
  "assigned" : False
82
  }
83
  return tpl_montage, in_montage, tpl_dict, in_dict
84
 
85
  def save_figures(channel_info, tpl_montage, filename1, filename2):
86
- tpl_order = channel_info["templateOrder"]
87
- in_order = channel_info["inputOrder"]
88
  tpl_dict = channel_info["templateDict"]
89
  in_dict = channel_info["inputDict"]
90
 
91
- tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
92
- tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
93
- in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
94
- in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
95
  tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
  in_coords = np.vstack((in_x, in_y)).T
97
 
@@ -116,18 +116,18 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
116
  ax.plot(x, y, color='black', linewidth=1.0)
117
  # plot in_channels on it
118
  ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
119
- for i, channel in enumerate(in_order):
120
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
121
  # save input_montage
122
  fig.savefig(filename1)
123
 
124
  # ---------------------------add indications-------------------------------
125
  # plot unmatched input channels in red
126
- indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
127
  if indices != []:
128
  ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
129
  for i in indices:
130
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
131
  # save mapped_montage
132
  fig.savefig(filename2)
133
 
@@ -137,14 +137,14 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
137
  in_coords = ax.transData.transform(in_coords)
138
  plt.close('all')
139
 
140
- for i, channel in enumerate(tpl_order):
141
  css_left = (tpl_coords[i,0]-11)/6.4
142
  css_bottom = (tpl_coords[i,1]-7)/6.4
143
- tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
144
- for i, channel in enumerate(in_order):
145
  css_left = (in_coords[i,0]-11)/6.4
146
  css_bottom = (in_coords[i,1]-7)/6.4
147
- in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
148
 
149
  channel_info.update({
150
  "templateDict" : tpl_dict,
@@ -153,11 +153,11 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
153
  return channel_info
154
 
155
  def align_coords(channel_info, tpl_montage, in_montage):
156
- tpl_order = channel_info["templateOrder"]
157
- in_order = channel_info["inputOrder"]
158
  tpl_dict = channel_info["templateDict"]
159
  in_dict = channel_info["inputDict"]
160
- matched_order = get_matched(tpl_order, tpl_dict)
161
 
162
  # 2D alignment (for visualization purposes)
163
  fig = [tpl_montage.plot(), in_montage.plot()]
@@ -166,8 +166,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
166
  # extract the displayed 2D coordinates
167
  all_tpl = ax[0].collections[0].get_offsets().data
168
  all_in= ax[1].collections[0].get_offsets().data
169
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched_order])
170
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched_order])
171
  plt.close('all')
172
 
173
  # apply TPS to transform in_channels to align with tpl_channels positions
@@ -179,17 +179,17 @@ def align_coords(channel_info, tpl_montage, in_montage):
179
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
180
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
181
 
182
- for i, channel in enumerate(tpl_order):
183
- tpl_dict[channel]["coord_2d"] = all_tpl[i]
184
- for i, channel in enumerate(in_order):
185
- in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
186
 
187
 
188
  # 3D alignment
189
- all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
- all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched_order])
192
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched_order])
193
 
194
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
@@ -200,8 +200,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
200
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
201
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
202
 
203
- for i, channel in enumerate(in_order):
204
- in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
205
 
206
  channel_info.update({
207
  "templateDict" : tpl_dict,
@@ -209,21 +209,21 @@ def align_coords(channel_info, tpl_montage, in_montage):
209
  })
210
  return channel_info
211
 
212
- def find_neighbors(channel_info, empty_tpl_order, new_idx):
213
- in_order = channel_info["inputOrder"]
214
  tpl_dict = channel_info["templateDict"]
215
  in_dict = channel_info["inputDict"]
216
 
217
- all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
218
- empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in empty_tpl_order]
219
 
220
  # use KNN to choose k nearest channels
221
- k = 4 if len(in_order)>4 else len(in_order)
222
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
223
  knn.fit(all_in)
224
- for i, channel in enumerate(empty_tpl_order):
225
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
226
- idx = tpl_dict[channel]["index"]
227
  new_idx[idx] = indices[0].tolist()
228
 
229
  return new_idx
@@ -232,8 +232,8 @@ def match_names(stage1_info):
232
  # read the location file
233
  loc_file = stage1_info["fileNames"]["inputLocation"]
234
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
235
- tpl_order = tpl_montage.ch_names
236
- in_order = in_montage.ch_names
237
  new_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
238
  orig_flags = [False]*30
239
 
@@ -243,24 +243,24 @@ def match_names(stage1_info):
243
  'T5': 'P7',
244
  'T6': 'P8'
245
  }
246
- for i, channel in enumerate(tpl_order):
247
- if channel in alias_dict and alias_dict[channel] in in_dict:
248
- tpl_montage.rename_channels({channel: alias_dict[channel]})
249
- tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
250
- channel = alias_dict[channel]
251
 
252
- if channel in in_dict:
253
- new_idx[i] = [in_dict[channel]["index"]]
254
  orig_flags[i] = True
255
- tpl_dict[channel]["matched"] = True
256
- in_dict[channel]["assigned"] = True
257
 
258
  # update the names
259
- tpl_order = tpl_montage.ch_names
260
 
261
  stage1_info.update({
262
- "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
263
- "emptyTemplates" : get_empty_templates(tpl_order, tpl_dict),
264
  "mappingResults" : [
265
  {
266
  "newOrder" : new_idx,
@@ -269,34 +269,34 @@ def match_names(stage1_info):
269
  ]
270
  })
271
  channel_info = {
272
- "templateOrder" : tpl_order,
273
- "inputOrder" : in_order,
274
  "templateDict" : tpl_dict,
275
  "inputDict" : in_dict
276
  }
277
  return stage1_info, channel_info, tpl_montage, in_montage
278
 
279
  def optimal_mapping(channel_info):
280
- tpl_order = channel_info["templateOrder"]
281
- in_order = channel_info["inputOrder"]
282
  tpl_dict = channel_info["templateDict"]
283
  in_dict = channel_info["inputDict"]
284
- unass_in_order = get_unassigned_inputs(in_order, in_dict)
285
  # reset all tpl.matched to False
286
- for channel in tpl_dict:
287
- tpl_dict[channel]["matched"] = False
288
 
289
- all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
290
- unass_in = np.array([in_dict[channel]["coord_3d"] for channel in unass_in_order])
291
 
292
  # initialize the cost matrix for the Hungarian algorithm
293
- if len(unass_in_order) < 30:
294
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
295
  else:
296
- cost_matrix = np.zeros((30, len(unass_in_order)))
297
  # fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
298
  for i in range(30):
299
- for j in range(len(unass_in_order)):
300
  cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
301
 
302
  # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
@@ -307,20 +307,20 @@ def optimal_mapping(channel_info):
307
  new_idx = [[None]]*30
308
  orig_flags = [False]*30
309
  for i, j in zip(row_idx, col_idx):
310
- if j < len(unass_in_order): # filter out dummy channels
311
- tpl_channel = tpl_order[i]
312
- in_channel = unass_in_order[j]
313
 
314
- new_idx[i] = [in_dict[in_channel]["index"]]
315
  orig_flags[i] = True
316
- tpl_dict[tpl_channel]["matched"] = True
317
- in_dict[in_channel]["assigned"] = True
318
- #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
319
 
320
  # fill the remaining empty tpl_channels
321
- empty_tpl_order = get_empty_templates(tpl_order, tpl_dict)
322
- if empty_tpl_order != []:
323
- new_idx = find_neighbors(channel_info, empty_tpl_order, new_idx)
324
 
325
  result = {
326
  "newOrder" : new_idx,
@@ -344,8 +344,8 @@ def mapping_result(stage1_info, channel_info, filename):
344
  results += [result]
345
 
346
  data = {
347
- #"templateOrder" : channel_info["templateOrder"],
348
- #"inputOrder" : channel_info["inputOrder"],
349
  "batchNum" : batch_num,
350
  "mappingResults" : results
351
  }
 
46
  utils.save_data(new_data, outputname)
47
  return
48
 
49
+ def get_matched(tpl_names, tpl_dict):
50
+ return [name for name in tpl_names if tpl_dict[name]["matched"]==True]
51
 
52
+ def get_empty_templates(tpl_names, tpl_dict):
53
+ return [name for name in tpl_names if tpl_dict[name]["matched"]==False]
54
 
55
+ def get_unassigned_inputs(in_names, in_dict):
56
+ return [name for name in in_names if in_dict[name]["assigned"]==False]
57
 
58
  def read_montage_data(loc_file):
59
  tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
  in_montage = read_custom_montage(loc_file)
61
+ tpl_names = tpl_montage.ch_names
62
+ in_names = in_montage.ch_names
63
  tpl_dict = {}
64
  in_dict = {}
65
 
66
  # convert all channel names to uppercase and store their information
67
+ for i, name in enumerate(tpl_names):
68
+ up_name = str.upper(name)
69
+ tpl_montage.rename_channels({name: up_name})
70
+ tpl_dict[up_name] = {
71
  "index" : i,
72
+ "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_name],
73
  "matched" : False
74
  }
75
+ for i, name in enumerate(in_names):
76
+ up_name = str.upper(name)
77
+ in_montage.rename_channels({name: up_name})
78
+ in_dict[up_name] = {
79
  "index" : i,
80
+ "coord_3d" : in_montage.get_positions()['ch_pos'][up_name],
81
  "assigned" : False
82
  }
83
  return tpl_montage, in_montage, tpl_dict, in_dict
84
 
85
  def save_figures(channel_info, tpl_montage, filename1, filename2):
86
+ tpl_names = channel_info["templateNames"]
87
+ in_names = channel_info["inputNames"]
88
  tpl_dict = channel_info["templateDict"]
89
  in_dict = channel_info["inputDict"]
90
 
91
+ tpl_x = [tpl_dict[name]["coord_2d"][0] for name in tpl_names]
92
+ tpl_y = [tpl_dict[name]["coord_2d"][1] for name in tpl_names]
93
+ in_x = [in_dict[name]["coord_2d"][0] for name in in_names]
94
+ in_y = [in_dict[name]["coord_2d"][1] for name in in_names]
95
  tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
  in_coords = np.vstack((in_x, in_y)).T
97
 
 
116
  ax.plot(x, y, color='black', linewidth=1.0)
117
  # plot in_channels on it
118
  ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
119
+ for i, name in enumerate(in_names):
120
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], name, color='black', fontsize=10.0, va='center')
121
  # save input_montage
122
  fig.savefig(filename1)
123
 
124
  # ---------------------------add indications-------------------------------
125
  # plot unmatched input channels in red
126
+ indices = [in_dict[name]["index"] for name in in_names if in_dict[name]["assigned"]==False]
127
  if indices != []:
128
  ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
129
  for i in indices:
130
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_names[i], color='red', fontsize=10.0, va='center')
131
  # save mapped_montage
132
  fig.savefig(filename2)
133
 
 
137
  in_coords = ax.transData.transform(in_coords)
138
  plt.close('all')
139
 
140
+ for i, name in enumerate(tpl_names):
141
  css_left = (tpl_coords[i,0]-11)/6.4
142
  css_bottom = (tpl_coords[i,1]-7)/6.4
143
+ tpl_dict[name]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
144
+ for i, name in enumerate(in_names):
145
  css_left = (in_coords[i,0]-11)/6.4
146
  css_bottom = (in_coords[i,1]-7)/6.4
147
+ in_dict[name]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
148
 
149
  channel_info.update({
150
  "templateDict" : tpl_dict,
 
153
  return channel_info
154
 
155
  def align_coords(channel_info, tpl_montage, in_montage):
156
+ tpl_names = channel_info["templateNames"]
157
+ in_names = channel_info["inputNames"]
158
  tpl_dict = channel_info["templateDict"]
159
  in_dict = channel_info["inputDict"]
160
+ matched_names = get_matched(tpl_names, tpl_dict)
161
 
162
  # 2D alignment (for visualization purposes)
163
  fig = [tpl_montage.plot(), in_montage.plot()]
 
166
  # extract the displayed 2D coordinates
167
  all_tpl = ax[0].collections[0].get_offsets().data
168
  all_in= ax[1].collections[0].get_offsets().data
169
+ matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
170
+ matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
171
  plt.close('all')
172
 
173
  # apply TPS to transform in_channels to align with tpl_channels positions
 
179
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
180
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
181
 
182
+ for i, name in enumerate(tpl_names):
183
+ tpl_dict[name]["coord_2d"] = all_tpl[i]
184
+ for i, name in enumerate(in_names):
185
+ in_dict[name]["coord_2d"] = transformed_in[i].tolist()
186
 
187
 
188
  # 3D alignment
189
+ all_tpl = np.array([tpl_dict[name]["coord_3d"].tolist() for name in tpl_names])
190
+ all_in = np.array([in_dict[name]["coord_3d"].tolist() for name in in_names])
191
+ matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
192
+ matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
193
 
194
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
 
200
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
201
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
202
 
203
+ for i, name in enumerate(in_names):
204
+ in_dict[name]["coord_3d"] = transformed_in[i].tolist()
205
 
206
  channel_info.update({
207
  "templateDict" : tpl_dict,
 
209
  })
210
  return channel_info
211
 
212
+ def find_neighbors(channel_info, empty_tpl_names, new_idx):
213
+ in_names = channel_info["inputNames"]
214
  tpl_dict = channel_info["templateDict"]
215
  in_dict = channel_info["inputDict"]
216
 
217
+ all_in = [np.array(in_dict[name]["coord_3d"]) for name in in_names]
218
+ empty_tpl = [np.array(tpl_dict[name]["coord_3d"]) for name in empty_tpl_names]
219
 
220
  # use KNN to choose k nearest channels
221
+ k = 4 if len(in_names)>4 else len(in_names)
222
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
223
  knn.fit(all_in)
224
+ for i, name in enumerate(empty_tpl_names):
225
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
226
+ idx = tpl_dict[name]["index"]
227
  new_idx[idx] = indices[0].tolist()
228
 
229
  return new_idx
 
232
  # read the location file
233
  loc_file = stage1_info["fileNames"]["inputLocation"]
234
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
235
+ tpl_names = tpl_montage.ch_names
236
+ in_names = in_montage.ch_names
237
  new_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
238
  orig_flags = [False]*30
239
 
 
243
  'T5': 'P7',
244
  'T6': 'P8'
245
  }
246
+ for i, name in enumerate(tpl_names):
247
+ if name in alias_dict and alias_dict[name] in in_dict:
248
+ tpl_montage.rename_channels({name: alias_dict[name]})
249
+ tpl_dict[alias_dict[name]] = tpl_dict.pop(name)
250
+ name = alias_dict[name]
251
 
252
+ if name in in_dict:
253
+ new_idx[i] = [in_dict[name]["index"]]
254
  orig_flags[i] = True
255
+ tpl_dict[name]["matched"] = True
256
+ in_dict[name]["assigned"] = True
257
 
258
  # update the names
259
+ tpl_names = tpl_montage.ch_names
260
 
261
  stage1_info.update({
262
+ "unassignedInputs" : get_unassigned_inputs(in_names, in_dict),
263
+ "emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
264
  "mappingResults" : [
265
  {
266
  "newOrder" : new_idx,
 
269
  ]
270
  })
271
  channel_info = {
272
+ "templateNames" : tpl_names,
273
+ "inputNames" : in_names,
274
  "templateDict" : tpl_dict,
275
  "inputDict" : in_dict
276
  }
277
  return stage1_info, channel_info, tpl_montage, in_montage
278
 
279
  def optimal_mapping(channel_info):
280
+ tpl_names = channel_info["templateNames"]
281
+ in_names = channel_info["inputNames"]
282
  tpl_dict = channel_info["templateDict"]
283
  in_dict = channel_info["inputDict"]
284
+ unass_in_names = get_unassigned_inputs(in_names, in_dict)
285
  # reset all tpl.matched to False
286
+ for name in tpl_dict:
287
+ tpl_dict[name]["matched"] = False
288
 
289
+ all_tpl = np.array([tpl_dict[name]["coord_3d"] for name in tpl_names])
290
+ unass_in = np.array([in_dict[name]["coord_3d"] for name in unass_in_names])
291
 
292
  # initialize the cost matrix for the Hungarian algorithm
293
+ if len(unass_in_names) < 30:
294
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
295
  else:
296
+ cost_matrix = np.zeros((30, len(unass_in_names)))
297
  # fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
298
  for i in range(30):
299
+ for j in range(len(unass_in_names)):
300
  cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
301
 
302
  # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
 
307
  new_idx = [[None]]*30
308
  orig_flags = [False]*30
309
  for i, j in zip(row_idx, col_idx):
310
+ if j < len(unass_in_names): # filter out dummy channels
311
+ tpl_name = tpl_names[i]
312
+ in_name = unass_in_names[j]
313
 
314
+ new_idx[i] = [in_dict[in_name]["index"]]
315
  orig_flags[i] = True
316
+ tpl_dict[tpl_name]["matched"] = True
317
+ in_dict[in_name]["assigned"] = True
318
+ #print(f'{tpl_name}({i}) <- {in_name}({j})')
319
 
320
  # fill the remaining empty tpl_channels
321
+ empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
322
+ if empty_tpl_names != []:
323
+ new_idx = find_neighbors(channel_info, empty_tpl_names, new_idx)
324
 
325
  result = {
326
  "newOrder" : new_idx,
 
344
  results += [result]
345
 
346
  data = {
347
+ #"templateNames" : channel_info["templateNames"],
348
+ #"inputNames" : channel_info["inputNames"],
349
  "batchNum" : batch_num,
350
  "mappingResults" : results
351
  }