audrey06100 commited on
Commit
ed22689
·
1 Parent(s): 732255a
Files changed (2) hide show
  1. app.py +18 -18
  2. app_utils.py +20 -20
app.py CHANGED
@@ -93,8 +93,8 @@ init_js = """
93
  item.querySelector(":scope > span").innerText = "";
94
  });
95
 
96
- // add indication for the missing channels
97
- channel = stage1_info.missingTemplates[0];
98
  left = channel_info.templateDict[channel].css_position[0];
99
  bottom = channel_info.templateDict[channel].css_position[1];
100
  let dot_rule = `
@@ -167,7 +167,7 @@ update_js = """
167
  }else return;
168
 
169
  // update the indication
170
- channel = stage1_info.missingTemplates[cnt-1];
171
  left = channel_info.templateDict[channel].css_position[0];
172
  bottom = channel_info.templateDict[channel].css_position[1];
173
  let dot_rule = `
@@ -323,7 +323,7 @@ with gr.Blocks() as demo:
323
  "totalNum" : None
324
  },
325
  "unassignedInputs" : None,
326
- "missingTemplates" : None,
327
  "batchNum" : None,
328
  "mappingResults" : [
329
  {
@@ -404,7 +404,7 @@ with gr.Blocks() as demo:
404
  # ========================================step1=========================================
405
  elif stage1_info["state"] == "step1-finished":
406
  in_num = len(channel_info["inputOrder"])
407
- matched_num = 30 - len(stage1_info["missingTemplates"])
408
 
409
  # step1 to step4
410
  if matched_num == 30:
@@ -435,9 +435,9 @@ with gr.Blocks() as demo:
435
  # initialize the progress indication label
436
  stage1_info["step2"] = {
437
  "count" : 1,
438
- "totalNum" : len(stage1_info["missingTemplates"])
439
  }
440
- tpl_name = stage1_info["missingTemplates"][0]
441
  label = '{} (1/{})'.format(tpl_name, stage1_info["step2"]["totalNum"])
442
 
443
  stage1_info["state"] = "step2-selecting"
@@ -478,7 +478,7 @@ with gr.Blocks() as demo:
478
  # --------------------------------store information---------------------------------
479
  # if the user has selected an in_channel to forward to the previous target tpl_channel
480
  if sel_radio != []:
481
- prev_tpl_name = stage1_info["missingTemplates"][stage1_info["step2"]["count"]-1]
482
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
483
  sel_idx = channel_info["inputDict"][sel_radio]["index"]
484
 
@@ -493,11 +493,11 @@ with gr.Blocks() as demo:
493
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
494
  channel_info["inputDict"])
495
  # exclude the tpl_channels filled in step2
496
- stage1_info["missingTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
497
  channel_info["templateDict"])
498
  # -----------------------------determine the next step------------------------------
499
  # step2 to step4
500
- if len(stage1_info["missingTemplates"]) == 0:
501
  md = """
502
  ### Mapping Results
503
  The mapping process has been finished.
@@ -559,14 +559,14 @@ with gr.Blocks() as demo:
559
  # find the 4 nearest in_channels for each unmatched tpl_channels
560
  stage1_info["mappingResults"][0]["newOrder"] = app_utils.find_neighbors(
561
  channel_info,
562
- stage1_info["missingTemplates"],
563
  stage1_info["mappingResults"][0]["newOrder"])
564
  # initialize the progress indication label
565
  stage1_info["step3"] = {
566
  "count" : 1,
567
- "totalNum" : len(stage1_info["missingTemplates"])
568
  }
569
- tpl_name = stage1_info["missingTemplates"][0]
570
  label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
571
 
572
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
@@ -596,7 +596,7 @@ with gr.Blocks() as demo:
596
  # step3-2 to step4
597
  elif stage1_info["state"] == "step3-2-selecting":
598
  # --------------------------------store information---------------------------------
599
- prev_tpl_name = stage1_info["missingTemplates"][stage1_info["step3"]["count"]-1]
600
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
601
  sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_chkbox]
602
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
@@ -673,7 +673,7 @@ with gr.Blocks() as demo:
673
  # ----------------------------------store information-----------------------------------
674
  # if the user has selected an in_channel to forward to the previous target tpl_channel
675
  if sel_name != []:
676
- prev_tpl_name = stage1_info["missingTemplates"][step2["count"]-1]
677
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
678
  sel_idx = channel_info["inputDict"][sel_name]["index"]
679
 
@@ -689,7 +689,7 @@ with gr.Blocks() as demo:
689
  # exclude the selected in_channel of the previous round
690
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
691
 
692
- tpl_name = stage1_info["missingTemplates"][step2["count"]-1]
693
  label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
694
 
695
  stage1_info["step2"] = step2
@@ -724,7 +724,7 @@ with gr.Blocks() as demo:
724
  def update_chkbox(stage1_info, channel_info, sel_name):
725
  step3 = stage1_info["step3"]
726
  # ----------------------------------store information-----------------------------------
727
- prev_tpl_name = stage1_info["missingTemplates"][step3["count"]-1]
728
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
729
  sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_name]
730
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
@@ -733,7 +733,7 @@ with gr.Blocks() as demo:
733
  # ---------------------------------update the new round---------------------------------
734
  step3["count"] += 1
735
 
736
- tpl_name = stage1_info["missingTemplates"][step3["count"]-1]
737
  label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
738
 
739
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
 
93
  item.querySelector(":scope > span").innerText = "";
94
  });
95
 
96
+ // add indication for the empty tpl_channels
97
+ channel = stage1_info.emptyTemplates[0];
98
  left = channel_info.templateDict[channel].css_position[0];
99
  bottom = channel_info.templateDict[channel].css_position[1];
100
  let dot_rule = `
 
167
  }else return;
168
 
169
  // update the indication
170
+ channel = stage1_info.emptyTemplates[cnt-1];
171
  left = channel_info.templateDict[channel].css_position[0];
172
  bottom = channel_info.templateDict[channel].css_position[1];
173
  let dot_rule = `
 
323
  "totalNum" : None
324
  },
325
  "unassignedInputs" : None,
326
+ "emptyTemplates" : None,
327
  "batchNum" : None,
328
  "mappingResults" : [
329
  {
 
404
  # ========================================step1=========================================
405
  elif stage1_info["state"] == "step1-finished":
406
  in_num = len(channel_info["inputOrder"])
407
+ matched_num = 30 - len(stage1_info["emptyTemplates"])
408
 
409
  # step1 to step4
410
  if matched_num == 30:
 
435
  # initialize the progress indication label
436
  stage1_info["step2"] = {
437
  "count" : 1,
438
+ "totalNum" : len(stage1_info["emptyTemplates"])
439
  }
440
+ tpl_name = stage1_info["emptyTemplates"][0]
441
  label = '{} (1/{})'.format(tpl_name, stage1_info["step2"]["totalNum"])
442
 
443
  stage1_info["state"] = "step2-selecting"
 
478
  # --------------------------------store information---------------------------------
479
  # if the user has selected an in_channel to forward to the previous target tpl_channel
480
  if sel_radio != []:
481
+ prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step2"]["count"]-1]
482
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
483
  sel_idx = channel_info["inputDict"][sel_radio]["index"]
484
 
 
493
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
494
  channel_info["inputDict"])
495
  # exclude the tpl_channels filled in step2
496
+ stage1_info["emptyTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
497
  channel_info["templateDict"])
498
  # -----------------------------determine the next step------------------------------
499
  # step2 to step4
500
+ if len(stage1_info["emptyTemplates"]) == 0:
501
  md = """
502
  ### Mapping Results
503
  The mapping process has been finished.
 
559
  # find the 4 nearest in_channels for each unmatched tpl_channels
560
  stage1_info["mappingResults"][0]["newOrder"] = app_utils.find_neighbors(
561
  channel_info,
562
+ stage1_info["emptyTemplates"],
563
  stage1_info["mappingResults"][0]["newOrder"])
564
  # initialize the progress indication label
565
  stage1_info["step3"] = {
566
  "count" : 1,
567
+ "totalNum" : len(stage1_info["emptyTemplates"])
568
  }
569
+ tpl_name = stage1_info["emptyTemplates"][0]
570
  label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
571
 
572
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
 
596
  # step3-2 to step4
597
  elif stage1_info["state"] == "step3-2-selecting":
598
  # --------------------------------store information---------------------------------
599
+ prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
600
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
601
  sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_chkbox]
602
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
 
673
  # ----------------------------------store information-----------------------------------
674
  # if the user has selected an in_channel to forward to the previous target tpl_channel
675
  if sel_name != []:
676
+ prev_tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
677
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
678
  sel_idx = channel_info["inputDict"][sel_name]["index"]
679
 
 
689
  # exclude the selected in_channel of the previous round
690
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
691
 
692
+ tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
693
  label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
694
 
695
  stage1_info["step2"] = step2
 
724
  def update_chkbox(stage1_info, channel_info, sel_name):
725
  step3 = stage1_info["step3"]
726
  # ----------------------------------store information-----------------------------------
727
+ prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
728
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
729
  sel_idx = [channel_info["inputDict"][channel]["index"] for channel in sel_name]
730
  stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
 
733
  # ---------------------------------update the new round---------------------------------
734
  step3["count"] += 1
735
 
736
+ tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
737
  label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
738
 
739
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
app_utils.py CHANGED
@@ -157,7 +157,7 @@ def align_coords(channel_info, tpl_montage, in_montage):
157
  in_order = channel_info["inputOrder"]
158
  tpl_dict = channel_info["templateDict"]
159
  in_dict = channel_info["inputDict"]
160
- matched = get_matched(tpl_order, tpl_dict)
161
 
162
  # 2D alignment (for visualization purposes)
163
  fig = [tpl_montage.plot(), in_montage.plot()]
@@ -166,8 +166,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
166
  # extract the displayed 2D coordinates
167
  all_tpl = ax[0].collections[0].get_offsets().data
168
  all_in= ax[1].collections[0].get_offsets().data
169
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
170
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
171
  plt.close('all')
172
 
173
  # apply TPS to transform in_channels to align with tpl_channels positions
@@ -188,8 +188,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
188
  # 3D alignment
189
  all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
  all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
192
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
193
 
194
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
@@ -209,19 +209,19 @@ def align_coords(channel_info, tpl_montage, in_montage):
209
  })
210
  return channel_info
211
 
212
- def find_neighbors(channel_info, missing_channels, new_idx):
213
  in_order = channel_info["inputOrder"]
214
  tpl_dict = channel_info["templateDict"]
215
  in_dict = channel_info["inputDict"]
216
 
217
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
218
- empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
219
 
220
  # use KNN to choose k nearest channels
221
  k = 4 if len(in_order)>4 else len(in_order)
222
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
223
  knn.fit(all_in)
224
- for i, channel in enumerate(missing_channels):
225
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
226
  idx = tpl_dict[channel]["index"]
227
  new_idx[idx] = indices[0].tolist()
@@ -260,7 +260,7 @@ def match_names(stage1_info):
260
 
261
  stage1_info.update({
262
  "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
263
- "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
264
  "mappingResults" : [
265
  {
266
  "newOrder" : new_idx,
@@ -281,23 +281,23 @@ def optimal_mapping(channel_info):
281
  in_order = channel_info["inputOrder"]
282
  tpl_dict = channel_info["templateDict"]
283
  in_dict = channel_info["inputDict"]
284
- unassigned = get_unassigned_inputs(in_order, in_dict)
285
  # reset all tpl.matched to False
286
  for channel in tpl_dict:
287
  tpl_dict[channel]["matched"] = False
288
 
289
  all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
290
- unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
291
 
292
  # initialize the cost matrix for the Hungarian algorithm
293
- if len(unassigned) < 30:
294
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
295
  else:
296
- cost_matrix = np.zeros((30, len(unassigned)))
297
  # fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
298
  for i in range(30):
299
- for j in range(len(unassigned)):
300
- cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
301
 
302
  # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
303
  # by minimizing the total distances between their positions.
@@ -307,9 +307,9 @@ def optimal_mapping(channel_info):
307
  new_idx = [[None]]*30
308
  orig_flags = [False]*30
309
  for i, j in zip(row_idx, col_idx):
310
- if j < len(unassigned): # filter out dummy channels
311
  tpl_channel = tpl_order[i]
312
- in_channel = unassigned[j]
313
 
314
  new_idx[i] = [in_dict[in_channel]["index"]]
315
  orig_flags[i] = True
@@ -318,9 +318,9 @@ def optimal_mapping(channel_info):
318
  #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
319
 
320
  # fill the remaining empty tpl_channels
321
- missing_channels = get_empty_templates(tpl_order, tpl_dict)
322
- if missing_channels != []:
323
- new_idx = find_neighbors(channel_info, missing_channels, new_idx)
324
 
325
  result = {
326
  "newOrder" : new_idx,
 
157
  in_order = channel_info["inputOrder"]
158
  tpl_dict = channel_info["templateDict"]
159
  in_dict = channel_info["inputDict"]
160
+ matched_order = get_matched(tpl_order, tpl_dict)
161
 
162
  # 2D alignment (for visualization purposes)
163
  fig = [tpl_montage.plot(), in_montage.plot()]
 
166
  # extract the displayed 2D coordinates
167
  all_tpl = ax[0].collections[0].get_offsets().data
168
  all_in= ax[1].collections[0].get_offsets().data
169
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched_order])
170
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched_order])
171
  plt.close('all')
172
 
173
  # apply TPS to transform in_channels to align with tpl_channels positions
 
188
  # 3D alignment
189
  all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
  all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched_order])
192
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched_order])
193
 
194
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
 
209
  })
210
  return channel_info
211
 
212
+ def find_neighbors(channel_info, empty_tpl_order, new_idx):
213
  in_order = channel_info["inputOrder"]
214
  tpl_dict = channel_info["templateDict"]
215
  in_dict = channel_info["inputDict"]
216
 
217
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
218
+ empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in empty_tpl_order]
219
 
220
  # use KNN to choose k nearest channels
221
  k = 4 if len(in_order)>4 else len(in_order)
222
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
223
  knn.fit(all_in)
224
+ for i, channel in enumerate(empty_tpl_order):
225
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
226
  idx = tpl_dict[channel]["index"]
227
  new_idx[idx] = indices[0].tolist()
 
260
 
261
  stage1_info.update({
262
  "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
263
+ "emptyTemplates" : get_empty_templates(tpl_order, tpl_dict),
264
  "mappingResults" : [
265
  {
266
  "newOrder" : new_idx,
 
281
  in_order = channel_info["inputOrder"]
282
  tpl_dict = channel_info["templateDict"]
283
  in_dict = channel_info["inputDict"]
284
+ unass_in_order = get_unassigned_inputs(in_order, in_dict)
285
  # reset all tpl.matched to False
286
  for channel in tpl_dict:
287
  tpl_dict[channel]["matched"] = False
288
 
289
  all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
290
+ unass_in = np.array([in_dict[channel]["coord_3d"] for channel in unass_in_order])
291
 
292
  # initialize the cost matrix for the Hungarian algorithm
293
+ if len(unass_in_order) < 30:
294
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
295
  else:
296
+ cost_matrix = np.zeros((30, len(unass_in_order)))
297
  # fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
298
  for i in range(30):
299
+ for j in range(len(unass_in_order)):
300
+ cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
301
 
302
  # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
303
  # by minimizing the total distances between their positions.
 
307
  new_idx = [[None]]*30
308
  orig_flags = [False]*30
309
  for i, j in zip(row_idx, col_idx):
310
+ if j < len(unass_in_order): # filter out dummy channels
311
  tpl_channel = tpl_order[i]
312
+ in_channel = unass_in_order[j]
313
 
314
  new_idx[i] = [in_dict[in_channel]["index"]]
315
  orig_flags[i] = True
 
318
  #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
319
 
320
  # fill the remaining empty tpl_channels
321
+ empty_tpl_order = get_empty_templates(tpl_order, tpl_dict)
322
+ if empty_tpl_order != []:
323
+ new_idx = find_neighbors(channel_info, empty_tpl_order, new_idx)
324
 
325
  result = {
326
  "newOrder" : new_idx,