audrey06100 commited on
Commit
b9816ab
·
1 Parent(s): 5e68958
Files changed (3) hide show
  1. app.py +65 -73
  2. app_utils.py +359 -358
  3. utils.py +2 -1
app.py CHANGED
@@ -14,11 +14,9 @@ This tool is designed to assist you with two main tasks:
14
  - **Channel locations**: If you don't have the channel location file, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to adjust them to your required montage.
15
  - **Raw data**: Your data format must be a two-dimensional array (channels, timepoints).
16
  - **Channel requirements**: Your data must include some channels that correspond to our template channels, which include: ``Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2``. At least some of them need to be present for successful mapping. Additionally, please remove any reference, ECG, EOG, EMG... channels before uploading your files.
17
-
18
  """
19
 
20
  readme = """
21
-
22
  ## 1. Channel Mapping
23
  The following steps will guide you through the process of mapping your EEG channels to our template channels.
24
 
@@ -44,10 +42,10 @@ Once all template channels are filled, you will be directed to **Mapping Results
44
  ### Mapping Results
45
  After completing the previous steps, your channels will be aligned with the template channels required by our models.
46
  - In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final results.
47
- - Once the mapping process is completed, a **JSON file** containing the mapping results will be generated. This file is necessary only if you plan to run the models using the <a href="">source code</a>; otherwise, you can ignore it.
48
 
49
  ## 2. Decode data
50
- After clicking on ``Run`` button, we will process your EEG data based on the mapping results. If necessary, your data will be devided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
51
  """
52
 
53
  icunet = """
@@ -62,14 +60,16 @@ init_js = """
62
  channel_info = JSON.parse(JSON.stringify(channel_info));
63
  stage1_info = app_info.stage1
64
 
65
- let selector, attribute;
66
  let channel, left, bottom;
67
 
68
  if(stage1_info.state == "step2-selecting"){
69
  selector = "#radio-group > div:nth-of-type(2)";
 
70
  attribute = "value";
71
  }else if(stage1_info.state == "step3-2-selecting"){
72
  selector = "#chkbox-group > div:nth-of-type(2)";
 
73
  attribute = "name";
74
  }else return;
75
 
@@ -93,7 +93,7 @@ init_js = """
93
  bottom = channel_info.inputDict[channel].css_position[1];
94
 
95
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
96
- item.className = "";
97
  item.querySelector(":scope > span").innerText = "";
98
  });
99
 
@@ -217,21 +217,22 @@ update_js = """
217
  """
218
 
219
  with gr.Blocks() as demo:
220
-
221
  app_info_json = gr.JSON(visible=False)
222
  channel_info_json = gr.JSON(visible=False)
223
 
224
  gr.Markdown(intro)
225
  with gr.Row():
226
 
227
- with gr.Column(variant='panel'):
228
  gr.Markdown("# 1.Channel Mapping")
229
- # --------------------input----------------------
230
  in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
231
  file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
232
  map_btn = gr.Button("Map")
233
- # -------------------mapping---------------------
234
  desc_md = gr.Markdown(visible=False)
 
 
235
  # step1 : initial matching and scaling
236
  with gr.Row():
237
  tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
@@ -247,24 +248,17 @@ with gr.Blocks() as demo:
247
  scale=2)
248
  fillmode_btn = gr.Button("OK", visible=False, scale=1)
249
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
250
- # step4 : mapping results
251
- out_json_file = gr.File(visible=False)
252
- res_md = gr.Markdown(
253
- """
254
- (Download this file if you plan to run the models using the <a href="">source code</a>.)
255
- """,
256
- visible=False)
257
-
258
  with gr.Row():
259
  clear_btn = gr.Button("Clear", visible=False)
260
  step2_btn = gr.Button("Next", visible=False)
261
  step3_btn = gr.Button("Next", visible=False)
262
  next_btn = gr.Button("Next step", visible=False)
263
  # -----------------------------------------------
264
-
265
- with gr.Column(variant='panel'):
266
  gr.Markdown("# 2.Decode Data")
267
- # --------------------input----------------------
268
  with gr.Row():
269
  in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
270
  with gr.Column():
@@ -274,15 +268,17 @@ with gr.Blocks() as demo:
274
  ("IC-U-Net", "ICUNet"),
275
  ("IC-U-Net++", "UNetpp"),
276
  ("IC-U-Net-Attn", "AttUnet")],
277
- #"(mapped data)"],
 
278
  value="EEGART",
279
  label="Model")
280
- run_btn = gr.Button(interactive=False)
281
- # --------------------output---------------------
 
282
  batch_md = gr.Markdown(visible=False)
283
  out_data_file = gr.File(label="Denoised data", visible=False)
284
  # -----------------------------------------------
285
-
286
  with gr.Row():
287
  with gr.Tab("README"):
288
  gr.Markdown(readme)
@@ -340,8 +336,8 @@ with gr.Blocks() as demo:
340
  "input_data" : "",
341
  "output_data" : ""
342
  },
343
- "sampleRate" : None,
344
- "totalBatchNum" : None
345
  }
346
  }
347
  return {app_info_json : app_info,
@@ -360,11 +356,11 @@ with gr.Blocks() as demo:
360
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
361
  step3_btn : gr.Button(visible=False),
362
  out_json_file : gr.File(value=None, visible=False),
363
- res_md : gr.Markdown(visible=False),
364
  # --------------------Stage2-------------------------
365
  in_data_file : gr.File(value=None),
366
  in_samplerate : gr.Textbox(value=None),
367
  run_btn : gr.Button(interactive=False),
 
368
  batch_md : gr.Markdown(visible=False),
369
  out_data_file : gr.File(value=None, visible=False)}
370
 
@@ -427,7 +423,8 @@ with gr.Blocks() as demo:
427
  if matched_num == 30:
428
  md = """
429
  ### Mapping Results
430
- The mapping process has been finished.
 
431
  """
432
  # finalize and save the mapping results
433
  filename = filepath+"mapping_result.json"
@@ -443,7 +440,6 @@ with gr.Blocks() as demo:
443
  mapped_img : gr.Image(visible=False),
444
  next_btn : gr.Button(visible=False),
445
  out_json_file : gr.File(filename, visible=True),
446
- res_md : gr.Markdown(visible=True),
447
  run_btn : gr.Button(interactive=True)}
448
 
449
  # step1 to step2
@@ -487,7 +483,7 @@ with gr.Blocks() as demo:
487
  elif in_num == matched_num:
488
  md = """
489
  ### Step3: Filling Remaining Template Channels
490
- Select one of the methods provided below to fill the remaining empty template channels.
491
  """
492
  stage1_info["state"] = "step3-select-method"
493
  app_info["stage1"] = stage1_info
@@ -503,7 +499,7 @@ with gr.Blocks() as demo:
503
  elif stage1_info["state"] == "step2-selecting":
504
 
505
  # --------------------store information before the button click---------------------
506
- # check if the user has selected an in_channel to forward to the previous target tpl_channel
507
  if selected_radio != []:
508
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
509
  prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
@@ -528,7 +524,8 @@ with gr.Blocks() as demo:
528
  if len(stage1_info["missingTemplates"]) == 0:
529
  md = """
530
  ### Mapping Results
531
- The mapping process has been finished.
 
532
  """
533
  # finalize and save the mapping results
534
  filename = filepath+"mapping_result.json"
@@ -542,7 +539,6 @@ with gr.Blocks() as demo:
542
  desc_md : gr.Markdown(md),
543
  radio_group : gr.Radio(visible=False),
544
  out_json_file : gr.File(filename, visible=True),
545
- res_md : gr.Markdown(visible=True),
546
  clear_btn : gr.Button(visible=False),
547
  next_btn : gr.Button(visible=False),
548
  run_btn : gr.Button(interactive=True)}
@@ -550,7 +546,7 @@ with gr.Blocks() as demo:
550
  else:
551
  md = """
552
  ### Step3: Filling Remaining Template Channels
553
- Select one of the methods provided below to fill the remaining empty template channels.
554
  """
555
  stage1_info["state"] = "step3-select-method"
556
  app_info["stage1"] = stage1_info
@@ -570,7 +566,8 @@ with gr.Blocks() as demo:
570
  if fillmode == "zero":
571
  md = """
572
  ### Mapping Results
573
- The mapping process has been finished.
 
574
  """
575
  # finalize and save the mapping results
576
  filename = filepath+"mapping_result.json"
@@ -585,7 +582,6 @@ with gr.Blocks() as demo:
585
  in_fillmode : gr.Dropdown(visible=False),
586
  fillmode_btn : gr.Button(visible=False),
587
  out_json_file : gr.File(filename, visible=True),
588
- res_md : gr.Markdown(visible=True),
589
  run_btn : gr.Button(interactive=True)}
590
  # step3-1 to step3-2
591
  elif fillmode == "mean":
@@ -636,18 +632,16 @@ with gr.Blocks() as demo:
636
  elif stage1_info["state"] == "step3-2-selecting":
637
 
638
  # --------------------store information before the button click---------------------
639
- # check if the user didn't uncheck all in_channel checkboxes
640
- if selected_chkbox != []:
641
- prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
642
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
643
- selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
644
-
645
- stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
646
- #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
647
  # ----------------------------------------------------------------------------------
648
  md = """
649
  ### Mapping Results
650
- The mapping process has been finished.
 
651
  """
652
  # finalize and save the mapping results
653
  filename = filepath+"mapping_result.json"
@@ -662,14 +656,13 @@ with gr.Blocks() as demo:
662
  chkbox_group : gr.CheckboxGroup(visible=False),
663
  next_btn : gr.Button(visible=False),
664
  out_json_file : gr.File(filename, visible=True),
665
- res_md : gr.Markdown(visible=True),
666
  run_btn : gr.Button(interactive=True)}
667
 
668
  next_btn.click(
669
  fn = init_next_step,
670
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
671
  outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
672
- in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file, res_md, next_btn, run_btn]
673
  ).success(
674
  fn = None,
675
  js = init_js,
@@ -686,7 +679,7 @@ with gr.Blocks() as demo:
686
  inputs = in_loc_file,
687
  outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
688
  radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
689
- out_json_file, res_md, in_data_file, in_samplerate, run_btn, batch_md, out_data_file]
690
  ).success(
691
  fn = init_next_step,
692
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
@@ -720,7 +713,7 @@ with gr.Blocks() as demo:
720
  def update_radio(app_info, channel_info, selected):
721
  stage1_info = app_info["stage1"]
722
  # ----------------------store information before the button click-----------------------
723
- # check if the user has selected an in_channel to forward to the previous target tpl_channel
724
  if selected != []:
725
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
726
  prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
@@ -774,14 +767,11 @@ with gr.Blocks() as demo:
774
  def update_chkbox(app_info, channel_info, selected):
775
  stage1_info = app_info["stage1"]
776
  # ----------------------store information before the button click-----------------------
777
- # check if the user didn't uncheck all in_channel checkboxes
778
- if selected != []:
779
- prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
780
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
781
- selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
782
-
783
- stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
784
- #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
785
 
786
  # ------------------------update information for the new round--------------------------
787
  stage1_info["fillingCount"] += 1
@@ -808,7 +798,7 @@ with gr.Blocks() as demo:
808
  fn = init_next_step,
809
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
810
  outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
811
- out_json_file, res_md, next_btn, run_btn]
812
  ).success(
813
  fn = None,
814
  js = init_js,
@@ -855,8 +845,9 @@ with gr.Blocks() as demo:
855
  })
856
  app_info["stage2"] = stage2_info
857
  return {app_info_json : app_info,
858
- #run_btn : gr.Button(interactive=False),
859
- batch_md : gr.Markdown(visible=False),
 
860
  out_data_file : gr.File(visible=False)}
861
 
862
  def run_model(app_info, modelname):
@@ -873,19 +864,14 @@ with gr.Blocks() as demo:
873
  # establish a temp folder
874
  try:
875
  os.mkdir(filepath+"temp_data/")
876
- #except FileExistsError:
877
- #utils.dataDelete(filepath+"temp_data/")
878
- #os.mkdir(filepath+"temp_data/")
879
  except FileNotFoundError:
880
  print('break1')
881
  break_flag = True
882
  break
883
- except OSError as e:
884
- print(e)
885
 
886
  # update the running status
887
  md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
888
- yield {batch_md : gr.Markdown(md, visible=True)}
889
 
890
  # get the mapped index order and the filled status for each tpl_channels
891
  new_idx = stage1_info["mappingData"][i]["newOrder"]
@@ -908,23 +894,29 @@ with gr.Blocks() as demo:
908
  utils.dataDelete(filepath+"temp_data/")
909
 
910
  if break_flag == True:
911
- yield {batch_md : gr.Markdown(visible=False)}
 
912
  else:
913
- yield {#run_btn : gr.Button(interactive=True),
 
914
  batch_md : gr.Markdown(visible=False),
915
  out_data_file : gr.File(new_filename, visible=True)}
916
 
 
 
 
 
 
917
  run_btn.click(
918
  fn = reset_run,
919
  inputs = [app_info_json, in_data_file, in_samplerate, in_modelname],
920
- outputs = [app_info_json, run_btn, batch_md, out_data_file]
921
-
922
  ).success(
923
  fn = run_model,
924
  inputs = [app_info_json, in_modelname],
925
- outputs = [run_btn, batch_md, out_data_file]
926
  )
927
 
928
  if __name__ == "__main__":
929
- demo.launch()
930
 
 
14
  - **Channel locations**: If you don't have the channel location file, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to adjust them to your required montage.
15
  - **Raw data**: Your data format must be a two-dimensional array (channels, timepoints).
16
  - **Channel requirements**: Your data must include some channels that correspond to our template channels, which include: ``Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2``. At least some of them need to be present for successful mapping. Additionally, please remove any reference, ECG, EOG, EMG... channels before uploading your files.
 
17
  """
18
 
19
  readme = """
 
20
  ## 1. Channel Mapping
21
  The following steps will guide you through the process of mapping your EEG channels to our template channels.
22
 
 
42
  ### Mapping Results
43
  After completing the previous steps, your channels will be aligned with the template channels required by our models.
44
  - In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final results.
45
+ - Once the mapping process is completed, a JSON file containing the mapping results will be generated. This file is necessary only if you plan to run the models using the <a href="">source code</a>; otherwise, you can ignore it.
46
 
47
  ## 2. Decode data
48
+ After clicking on ``Run`` button, we will process your EEG data based on the mapping results. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
49
  """
50
 
51
  icunet = """
 
60
  channel_info = JSON.parse(JSON.stringify(channel_info));
61
  stage1_info = app_info.stage1
62
 
63
+ let selector, attribute; //, classname;
64
  let channel, left, bottom;
65
 
66
  if(stage1_info.state == "step2-selecting"){
67
  selector = "#radio-group > div:nth-of-type(2)";
68
+ //classname = "radio";
69
  attribute = "value";
70
  }else if(stage1_info.state == "step3-2-selecting"){
71
  selector = "#chkbox-group > div:nth-of-type(2)";
72
+ //classname = "chkbox";
73
  attribute = "name";
74
  }else return;
75
 
 
93
  bottom = channel_info.inputDict[channel].css_position[1];
94
 
95
  item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
96
+ item.className = ""; //classname;
97
  item.querySelector(":scope > span").innerText = "";
98
  });
99
 
 
217
  """
218
 
219
  with gr.Blocks() as demo:
 
220
  app_info_json = gr.JSON(visible=False)
221
  channel_info_json = gr.JSON(visible=False)
222
 
223
  gr.Markdown(intro)
224
  with gr.Row():
225
 
226
+ with gr.Column(variant="panel"):
227
  gr.Markdown("# 1.Channel Mapping")
228
+ # ---------------------input---------------------
229
  in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
230
  file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
231
  map_btn = gr.Button("Map")
232
+ # ---------------------output--------------------
233
  desc_md = gr.Markdown(visible=False)
234
+ out_json_file = gr.File(visible=False)
235
+ # --------------------mapping--------------------
236
  # step1 : initial matching and scaling
237
  with gr.Row():
238
  tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
 
248
  scale=2)
249
  fillmode_btn = gr.Button("OK", visible=False, scale=1)
250
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
251
+
 
 
 
 
 
 
 
252
  with gr.Row():
253
  clear_btn = gr.Button("Clear", visible=False)
254
  step2_btn = gr.Button("Next", visible=False)
255
  step3_btn = gr.Button("Next", visible=False)
256
  next_btn = gr.Button("Next step", visible=False)
257
  # -----------------------------------------------
258
+
259
+ with gr.Column(variant="panel"):
260
  gr.Markdown("# 2.Decode Data")
261
+ # ---------------------input---------------------
262
  with gr.Row():
263
  in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
264
  with gr.Column():
 
268
  ("IC-U-Net", "ICUNet"),
269
  ("IC-U-Net++", "UNetpp"),
270
  ("IC-U-Net-Attn", "AttUnet")],
271
+ #"(mapped data)",
272
+ #"(denoised data)"],
273
  value="EEGART",
274
  label="Model")
275
+ run_btn = gr.Button("Run", interactive=False)
276
+ cancel_btn = gr.Button("Cancel", visible=False)
277
+ # ---------------------output--------------------
278
  batch_md = gr.Markdown(visible=False)
279
  out_data_file = gr.File(label="Denoised data", visible=False)
280
  # -----------------------------------------------
281
+
282
  with gr.Row():
283
  with gr.Tab("README"):
284
  gr.Markdown(readme)
 
336
  "input_data" : "",
337
  "output_data" : ""
338
  },
339
+ "totalBatchNum" : None,
340
+ "sampleRate" : None
341
  }
342
  }
343
  return {app_info_json : app_info,
 
356
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
357
  step3_btn : gr.Button(visible=False),
358
  out_json_file : gr.File(value=None, visible=False),
 
359
  # --------------------Stage2-------------------------
360
  in_data_file : gr.File(value=None),
361
  in_samplerate : gr.Textbox(value=None),
362
  run_btn : gr.Button(interactive=False),
363
+ cancel_btn : gr.Button(interactive=False),
364
  batch_md : gr.Markdown(visible=False),
365
  out_data_file : gr.File(value=None, visible=False)}
366
 
 
423
  if matched_num == 30:
424
  md = """
425
  ### Mapping Results
426
+ The mapping process has been finished.
427
+ Download the file below if you plan to run the models using the <a href="">source code</a>.
428
  """
429
  # finalize and save the mapping results
430
  filename = filepath+"mapping_result.json"
 
440
  mapped_img : gr.Image(visible=False),
441
  next_btn : gr.Button(visible=False),
442
  out_json_file : gr.File(filename, visible=True),
 
443
  run_btn : gr.Button(interactive=True)}
444
 
445
  # step1 to step2
 
483
  elif in_num == matched_num:
484
  md = """
485
  ### Step3: Filling Remaining Template Channels
486
+ Select one of the methods provided below to fill the remaining template channels.
487
  """
488
  stage1_info["state"] = "step3-select-method"
489
  app_info["stage1"] = stage1_info
 
499
  elif stage1_info["state"] == "step2-selecting":
500
 
501
  # --------------------store information before the button click---------------------
502
+ # if the user has selected an in_channel to forward to the previous target tpl_channel
503
  if selected_radio != []:
504
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
505
  prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
 
524
  if len(stage1_info["missingTemplates"]) == 0:
525
  md = """
526
  ### Mapping Results
527
+ The mapping process has been finished.
528
+ Download the file below if you plan to run the models using the <a href="">source code</a>.
529
  """
530
  # finalize and save the mapping results
531
  filename = filepath+"mapping_result.json"
 
539
  desc_md : gr.Markdown(md),
540
  radio_group : gr.Radio(visible=False),
541
  out_json_file : gr.File(filename, visible=True),
 
542
  clear_btn : gr.Button(visible=False),
543
  next_btn : gr.Button(visible=False),
544
  run_btn : gr.Button(interactive=True)}
 
546
  else:
547
  md = """
548
  ### Step3: Filling Remaining Template Channels
549
+ Select one of the methods provided below to fill the remaining template channels.
550
  """
551
  stage1_info["state"] = "step3-select-method"
552
  app_info["stage1"] = stage1_info
 
566
  if fillmode == "zero":
567
  md = """
568
  ### Mapping Results
569
+ The mapping process has been finished.
570
+ Download the file below if you plan to run the models using the <a href="">source code</a>.
571
  """
572
  # finalize and save the mapping results
573
  filename = filepath+"mapping_result.json"
 
582
  in_fillmode : gr.Dropdown(visible=False),
583
  fillmode_btn : gr.Button(visible=False),
584
  out_json_file : gr.File(filename, visible=True),
 
585
  run_btn : gr.Button(interactive=True)}
586
  # step3-1 to step3-2
587
  elif fillmode == "mean":
 
632
  elif stage1_info["state"] == "step3-2-selecting":
633
 
634
  # --------------------store information before the button click---------------------
635
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
636
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
637
+ selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
638
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
639
+ #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
 
 
 
640
  # ----------------------------------------------------------------------------------
641
  md = """
642
  ### Mapping Results
643
+ The mapping process has been finished.
644
+ Download the file below if you plan to run the models using the <a href="">source code</a>.
645
  """
646
  # finalize and save the mapping results
647
  filename = filepath+"mapping_result.json"
 
656
  chkbox_group : gr.CheckboxGroup(visible=False),
657
  next_btn : gr.Button(visible=False),
658
  out_json_file : gr.File(filename, visible=True),
 
659
  run_btn : gr.Button(interactive=True)}
660
 
661
  next_btn.click(
662
  fn = init_next_step,
663
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
664
  outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
665
+ in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file, next_btn, run_btn]
666
  ).success(
667
  fn = None,
668
  js = init_js,
 
679
  inputs = in_loc_file,
680
  outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
681
  radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
682
+ out_json_file, in_data_file, in_samplerate, run_btn, cancel_btn, batch_md, out_data_file]
683
  ).success(
684
  fn = init_next_step,
685
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
 
713
  def update_radio(app_info, channel_info, selected):
714
  stage1_info = app_info["stage1"]
715
  # ----------------------store information before the button click-----------------------
716
+ # if the user has selected an in_channel to forward to the previous target tpl_channel
717
  if selected != []:
718
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
719
  prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
 
767
  def update_chkbox(app_info, channel_info, selected):
768
  stage1_info = app_info["stage1"]
769
  # ----------------------store information before the button click-----------------------
770
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
771
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
772
+ selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
773
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
774
+ #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
 
 
 
775
 
776
  # ------------------------update information for the new round--------------------------
777
  stage1_info["fillingCount"] += 1
 
798
  fn = init_next_step,
799
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
800
  outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
801
+ out_json_file, next_btn, run_btn]
802
  ).success(
803
  fn = None,
804
  js = init_js,
 
845
  })
846
  app_info["stage2"] = stage2_info
847
  return {app_info_json : app_info,
848
+ run_btn : gr.Button(visible=False),
849
+ cancel_btn : gr.Button(visible=True, interactive=True),
850
+ batch_md : gr.Markdown("", visible=True),
851
  out_data_file : gr.File(visible=False)}
852
 
853
  def run_model(app_info, modelname):
 
864
  # establish a temp folder
865
  try:
866
  os.mkdir(filepath+"temp_data/")
 
 
 
867
  except FileNotFoundError:
868
  print('break1')
869
  break_flag = True
870
  break
 
 
871
 
872
  # update the running status
873
  md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
874
+ yield {batch_md : gr.Markdown(md)}
875
 
876
  # get the mapped index order and the filled status for each tpl_channels
877
  new_idx = stage1_info["mappingData"][i]["newOrder"]
 
894
  utils.dataDelete(filepath+"temp_data/")
895
 
896
  if break_flag == True:
897
+ yield {run_btn : gr.Button(visible=True),
898
+ cancel_btn : gr.Button(visible=False)}
899
  else:
900
+ yield {run_btn : gr.Button(visible=True),
901
+ cancel_btn : gr.Button(visible=False),
902
  batch_md : gr.Markdown(visible=False),
903
  out_data_file : gr.File(new_filename, visible=True)}
904
 
905
+ @cancel_btn.click(inputs = app_info_json, outputs = [cancel_btn, batch_md])
906
+ def stop_processing(app_info):
907
+ utils.dataDelete(app_info["stage2"]["filePath"])
908
+ return gr.Button(interactive=False), gr.Markdown(visible=False)
909
+
910
  run_btn.click(
911
  fn = reset_run,
912
  inputs = [app_info_json, in_data_file, in_samplerate, in_modelname],
913
+ outputs = [app_info_json, run_btn, cancel_btn, batch_md, out_data_file]
 
914
  ).success(
915
  fn = run_model,
916
  inputs = [app_info_json, in_modelname],
917
+ outputs = [run_btn, cancel_btn, batch_md, out_data_file]
918
  )
919
 
920
  if __name__ == "__main__":
921
+ demo.launch(server_name="0.0.0.0", server_port=7860)
922
 
app_utils.py CHANGED
@@ -1,358 +1,359 @@
1
- import utils
2
- import os
3
- import math
4
- import json
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import mne
8
- from mne.channels import read_custom_montage
9
- from scipy.interpolate import Rbf
10
- from scipy.optimize import linear_sum_assignment
11
- from sklearn.neighbors import NearestNeighbors
12
-
13
- def reorder_data(idx_order, fill_flags, filename, new_filename):
14
- # read the input data
15
- raw_data = utils.read_train_data(filename)
16
- #print(raw_data.shape)
17
- new_data = np.zeros((30, raw_data.shape[1]))
18
-
19
- zero_arr = np.zeros((1, raw_data.shape[1]))
20
- for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
21
- if flag == False:
22
- new_data[i, :] = raw_data[idx_set[0], :]
23
- elif idx_set == []:
24
- new_data[i, :] = zero_arr
25
- else:
26
- tmp_data = [raw_data[j, :] for j in idx_set]
27
- new_data[i, :] = np.mean(tmp_data, axis=0)
28
-
29
- utils.save_data(new_data, new_filename)
30
- return raw_data.shape
31
-
32
- def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
33
- # read the denoised data
34
- d_data = utils.read_train_data(filename)
35
- if batch_cnt == 0:
36
- new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
37
- #print(new_data.shape)
38
- else:
39
- new_data = utils.read_train_data(new_filename)
40
-
41
- for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
- # ignore if this channel was filled using "fillmode"
43
- if flag == False:
44
- new_data[idx_set[0], :] = d_data[i, :]
45
-
46
- utils.save_data(new_data, new_filename)
47
- return
48
-
49
- def get_matched(tpl_order, tpl_dict):
50
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
51
-
52
- def get_empty_templates(tpl_order, tpl_dict):
53
- return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
54
-
55
- def get_unassigned_inputs(in_order, in_dict):
56
- return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
57
-
58
- def read_montage_data(loc_file):
59
- tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
- in_montage = read_custom_montage(loc_file)
61
- tpl_order = tpl_montage.ch_names
62
- in_order = in_montage.ch_names
63
- tpl_dict = {}
64
- in_dict = {}
65
-
66
- # convert all channel names to uppercase and store the channel information
67
- for i, channel in enumerate(tpl_order):
68
- up_channel = str.upper(channel)
69
- tpl_montage.rename_channels({channel: up_channel})
70
- tpl_dict[up_channel] = {
71
- "index" : i,
72
- "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
73
- "matched" : False
74
- }
75
- for i, channel in enumerate(in_order):
76
- up_channel = str.upper(channel)
77
- in_montage.rename_channels({channel: up_channel})
78
- in_dict[up_channel] = {
79
- "index" : i,
80
- "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
81
- "assigned" : False
82
- }
83
- return tpl_montage, in_montage, tpl_dict, in_dict
84
-
85
- def save_figures(channel_info, tpl_montage, filename1, filename2):
86
- tpl_order = channel_info["templateOrder"]
87
- in_order = channel_info["inputOrder"]
88
- tpl_dict = channel_info["templateDict"]
89
- in_dict = channel_info["inputDict"]
90
-
91
- tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
92
- tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
93
- in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
94
- in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
95
- tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
- in_coords = np.vstack((in_x, in_y)).T
97
-
98
- # extract template's head figure
99
- tpl_fig = tpl_montage.plot()
100
- tpl_ax = tpl_fig.axes[0]
101
- lines = tpl_ax.lines
102
- head_lines = []
103
- for line in lines:
104
- x, y = line.get_data()
105
- head_lines.append((x,y))
106
- plt.close()
107
-
108
- # -------------------------plot input montage------------------------------
109
- fig = plt.figure(figsize=(6.4,6.4), dpi=100)
110
- ax = fig.add_subplot(111)
111
- fig.tight_layout()
112
- ax.set_aspect('equal')
113
- ax.axis('off')
114
-
115
- # plot template's head
116
- for x, y in head_lines:
117
- ax.plot(x, y, color='black', linewidth=1.0)
118
- # plot in_channels on it
119
- ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
120
- for i, channel in enumerate(in_order):
121
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
122
- # save input_montage
123
- fig.savefig(filename1)
124
-
125
- # ---------------------------add indications-------------------------------
126
- # plot unmatched input channels in red
127
- indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
128
- ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
129
- for i in indices:
130
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
131
- # save mapped_montage
132
- fig.savefig(filename2)
133
-
134
- # -------------------------------------------------------------------------
135
- # store the tpl and in_channels' display positions (in px).
136
- tpl_coords = ax.transData.transform(tpl_coords)
137
- in_coords = ax.transData.transform(in_coords)
138
- plt.close()
139
-
140
- for i, channel in enumerate(tpl_order):
141
- css_left = (tpl_coords[i,0]-11)/6.4
142
- css_bottom = (tpl_coords[i,1]-7)/6.4
143
- tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
144
- for i, channel in enumerate(in_order):
145
- css_left = (in_coords[i,0]-11)/6.4
146
- css_bottom = (in_coords[i,1]-7)/6.4
147
- in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
148
-
149
- channel_info.update({
150
- "templateDict" : tpl_dict,
151
- "inputDict" : in_dict
152
- })
153
- return channel_info
154
-
155
- def align_coords(channel_info, tpl_montage, in_montage):
156
- tpl_order = channel_info["templateOrder"]
157
- in_order = channel_info["inputOrder"]
158
- tpl_dict = channel_info["templateDict"]
159
- in_dict = channel_info["inputDict"]
160
- matched = get_matched(tpl_order, tpl_dict)
161
-
162
- # 2D alignment (for visualization purposes)
163
- fig = [tpl_montage.plot(), in_montage.plot()]
164
- ax = [fig[0].axes[0], fig[1].axes[0]]
165
-
166
- # extract the displayed 2D coordinates from the plots
167
- all_tpl = ax[0].collections[0].get_offsets().data
168
- all_in= ax[1].collections[0].get_offsets().data
169
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
170
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
171
-
172
- # apply TPS to transform in_channels positions to align with tpl_channels positions
173
- rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
174
- rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
175
-
176
- # apply the transformation to all in_channels
177
- transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
178
- transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
179
- transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
180
-
181
- # store the 2D positions
182
- for i, channel in enumerate(tpl_order):
183
- tpl_dict[channel]["coord_2d"] = all_tpl[i]
184
- for i, channel in enumerate(in_order):
185
- in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
186
-
187
-
188
- # 3D alignment
189
- all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
190
- all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
191
- matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
192
- matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
193
-
194
- rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
195
- rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
196
- rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
197
-
198
- transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
199
- transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
200
- transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
201
- transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
202
-
203
- # update in_channels' 3D positions
204
- for i, channel in enumerate(in_order):
205
- in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
206
-
207
- channel_info.update({
208
- "templateDict" : tpl_dict,
209
- "inputDict" : in_dict
210
- })
211
- return channel_info
212
-
213
- def find_neighbors(channel_info, missing_channels, new_idx):
214
- in_order = channel_info["inputOrder"]
215
- tpl_dict = channel_info["templateDict"]
216
- in_dict = channel_info["inputDict"]
217
-
218
- all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
219
- empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
220
-
221
- # use KNN to choose k nearest channels
222
- k = 4 if len(in_order)>4 else len(in_order)
223
- knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
224
- knn.fit(all_in)
225
- for i, channel in enumerate(missing_channels):
226
- distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
227
- idx = tpl_dict[channel]["index"]
228
- new_idx[idx] = indices[0].tolist()
229
-
230
- return new_idx
231
-
232
- def match_names(stage1_info, channel_info):
233
- # read the location file
234
- loc_file = stage1_info["fileNames"]["input_loc"]
235
- tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
236
- tpl_order = tpl_montage.ch_names
237
- in_order = in_montage.ch_names
238
- new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channels
239
- fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
240
-
241
- alias_dict = {
242
- 'T3': 'T7',
243
- 'T4': 'T8',
244
- 'T5': 'P7',
245
- 'T6': 'P8'
246
- }
247
- for i, channel in enumerate(tpl_order):
248
- if channel in alias_dict and alias_dict[channel] in in_dict:
249
- tpl_montage.rename_channels({channel: alias_dict[channel]})
250
- tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
251
- channel = alias_dict[channel]
252
-
253
- if channel in in_dict:
254
- new_idx[i] = [in_dict[channel]["index"]]
255
- fill_flags[i] = False
256
- tpl_dict[channel]["matched"] = True
257
- in_dict[channel]["assigned"] = True
258
-
259
- # update the names
260
- tpl_order = tpl_montage.ch_names
261
-
262
- stage1_info.update({
263
- "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
264
- "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
265
- "mappingData" : [
266
- {
267
- "newOrder" : new_idx,
268
- "fillFlags" : fill_flags
269
- }
270
- ]
271
- })
272
- channel_info.update({
273
- "templateOrder" : tpl_order,
274
- "inputOrder" : in_order,
275
- "templateDict" : tpl_dict,
276
- "inputDict" : in_dict
277
- })
278
- return stage1_info, channel_info, tpl_montage, in_montage
279
-
280
- def optimal_mapping(channel_info):
281
- tpl_order = channel_info["templateOrder"]
282
- in_order = channel_info["inputOrder"]
283
- tpl_dict = channel_info["templateDict"]
284
- in_dict = channel_info["inputDict"]
285
- unassigned = get_unassigned_inputs(in_order, in_dict)
286
- # reset all tpl.matched to False
287
- for channel in tpl_dict:
288
- tpl_dict[channel]["matched"] = False
289
-
290
- all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
291
- unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
292
-
293
- # initialize the cost matrix for the Hungarian algorithm
294
- if len(unassigned) < 30:
295
- cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
296
- else:
297
- cost_matrix = np.zeros((30, len(unassigned)))
298
- # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
299
- for i in range(30):
300
- for j in range(len(unassigned)):
301
- cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
302
-
303
- # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
304
- # by minimizing the total distances between their positions.
305
- row_idx, col_idx = linear_sum_assignment(cost_matrix)
306
-
307
- # store the mapping results
308
- new_idx = [[]]*30
309
- fill_flags = [True]*30
310
- for i, j in zip(row_idx, col_idx):
311
- if j < len(unassigned): # filter out dummy channels
312
- tpl_channel = tpl_order[i]
313
- in_channel = unassigned[j]
314
-
315
- new_idx[i] = [in_dict[in_channel]["index"]]
316
- fill_flags[i] = False
317
- tpl_dict[tpl_channel]["matched"] = True
318
- in_dict[in_channel]["assigned"] = True
319
- #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
320
-
321
- # fill the remaining empty tpl_channels
322
- missing_channels = get_empty_templates(tpl_order, tpl_dict)
323
- if missing_channels != []:
324
- new_idx = find_neighbors(channel_info, missing_channels, new_idx)
325
-
326
- mapping_data = {
327
- "newOrder" : new_idx,
328
- "fillFlags" : fill_flags
329
- }
330
- channel_info.update({
331
- "templateDict" : tpl_dict,
332
- "inputDict" : in_dict
333
- })
334
- return mapping_data, channel_info
335
-
336
- def mapping_result(stage1_info, stage2_info, channel_info, filename):
337
- unassigned_num = len(stage1_info["unassignedInputs"])
338
- batch_num = math.ceil(unassigned_num/30) + 1
339
-
340
- # map the remaining in_channels
341
- for i in range(1, batch_num):
342
- # optimally select 30 in_channels to map to the tpl_channels based on proximity
343
- new_mapping_data, channel_info = optimal_mapping(channel_info)
344
- stage1_info["mappingData"] += [new_mapping_data]
345
-
346
- # save the mapping results
347
- new_dict = {
348
- #"templateOrder" : channel_info["templateOrder"],
349
- #"inputOrder" : channel_info["inputOrder"],
350
- "batchNum" : batch_num,
351
- "mappingData" : stage1_info["mappingData"]
352
- }
353
- with open(filename, 'w') as jsonfile:
354
- jsonfile.write(json.dumps(new_dict))
355
-
356
- stage2_info["totalBatchNum"] = batch_num
357
- return stage1_info, stage2_info, channel_info
358
-
 
 
1
+ import utils
2
+ import os
3
+ import math
4
+ import json
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import mne
8
+ from mne.channels import read_custom_montage
9
+ from scipy.interpolate import Rbf
10
+ from scipy.optimize import linear_sum_assignment
11
+ from sklearn.neighbors import NearestNeighbors
12
+
13
+ def reorder_data(idx_order, fill_flags, filename, new_filename):
14
+ # read the input data
15
+ raw_data = utils.read_train_data(filename)
16
+ #print(raw_data.shape)
17
+ new_data = np.zeros((30, raw_data.shape[1]))
18
+
19
+ zero_arr = np.zeros((1, raw_data.shape[1]))
20
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
21
+ if flag == False:
22
+ new_data[i, :] = raw_data[idx_set[0], :]
23
+ elif idx_set == []:
24
+ new_data[i, :] = zero_arr
25
+ else:
26
+ tmp_data = [raw_data[j, :] for j in idx_set]
27
+ new_data[i, :] = np.mean(tmp_data, axis=0)
28
+
29
+ utils.save_data(new_data, new_filename)
30
+ return raw_data.shape
31
+
32
+ def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
33
+ # read the denoised data
34
+ d_data = utils.read_train_data(filename)
35
+ if batch_cnt == 0:
36
+ new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
37
+ #print(new_data.shape)
38
+ else:
39
+ new_data = utils.read_train_data(new_filename)
40
+
41
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
+ # ignore if this channel was filled using "fillmode"
43
+ if flag == False:
44
+ new_data[idx_set[0], :] = d_data[i, :]
45
+
46
+ utils.save_data(new_data, new_filename)
47
+ return
48
+
49
+ def get_matched(tpl_order, tpl_dict):
50
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
51
+
52
+ def get_empty_templates(tpl_order, tpl_dict):
53
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
54
+
55
+ def get_unassigned_inputs(in_order, in_dict):
56
+ return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
57
+
58
+ def read_montage_data(loc_file):
59
+ tpl_montage = read_custom_montage("./template_chanlocs.loc")
60
+ in_montage = read_custom_montage(loc_file)
61
+ tpl_order = tpl_montage.ch_names
62
+ in_order = in_montage.ch_names
63
+ tpl_dict = {}
64
+ in_dict = {}
65
+
66
+ # convert all channel names to uppercase and store the channel information
67
+ for i, channel in enumerate(tpl_order):
68
+ up_channel = str.upper(channel)
69
+ tpl_montage.rename_channels({channel: up_channel})
70
+ tpl_dict[up_channel] = {
71
+ "index" : i,
72
+ "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
73
+ "matched" : False
74
+ }
75
+ for i, channel in enumerate(in_order):
76
+ up_channel = str.upper(channel)
77
+ in_montage.rename_channels({channel: up_channel})
78
+ in_dict[up_channel] = {
79
+ "index" : i,
80
+ "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
81
+ "assigned" : False
82
+ }
83
+ return tpl_montage, in_montage, tpl_dict, in_dict
84
+
85
+ def save_figures(channel_info, tpl_montage, filename1, filename2):
86
+ tpl_order = channel_info["templateOrder"]
87
+ in_order = channel_info["inputOrder"]
88
+ tpl_dict = channel_info["templateDict"]
89
+ in_dict = channel_info["inputDict"]
90
+
91
+ tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
92
+ tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
93
+ in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
94
+ in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
95
+ tpl_coords = np.vstack((tpl_x, tpl_y)).T
96
+ in_coords = np.vstack((in_x, in_y)).T
97
+
98
+ # extract template's head figure
99
+ tpl_fig = tpl_montage.plot()
100
+ tpl_ax = tpl_fig.axes[0]
101
+ lines = tpl_ax.lines
102
+ head_lines = []
103
+ for line in lines:
104
+ x, y = line.get_data()
105
+ head_lines.append((x,y))
106
+ plt.close()
107
+
108
+ # -------------------------plot input montage------------------------------
109
+ fig = plt.figure(figsize=(6.4,6.4), dpi=100)
110
+ ax = fig.add_subplot(111)
111
+ fig.tight_layout()
112
+ ax.set_aspect('equal')
113
+ ax.axis('off')
114
+
115
+ # plot template's head
116
+ for x, y in head_lines:
117
+ ax.plot(x, y, color='black', linewidth=1.0)
118
+ # plot in_channels on it
119
+ ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
120
+ for i, channel in enumerate(in_order):
121
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
122
+ # save input_montage
123
+ fig.savefig(filename1)
124
+
125
+ # ---------------------------add indications-------------------------------
126
+ # plot unmatched input channels in red
127
+ indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
128
+ if indices != []:
129
+ ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
130
+ for i in indices:
131
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
132
+ # save mapped_montage
133
+ fig.savefig(filename2)
134
+
135
+ # -------------------------------------------------------------------------
136
+ # store the tpl and in_channels' display positions (in px).
137
+ tpl_coords = ax.transData.transform(tpl_coords)
138
+ in_coords = ax.transData.transform(in_coords)
139
+ plt.close()
140
+
141
+ for i, channel in enumerate(tpl_order):
142
+ css_left = (tpl_coords[i,0]-11)/6.4
143
+ css_bottom = (tpl_coords[i,1]-7)/6.4
144
+ tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
145
+ for i, channel in enumerate(in_order):
146
+ css_left = (in_coords[i,0]-11)/6.4
147
+ css_bottom = (in_coords[i,1]-7)/6.4
148
+ in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
149
+
150
+ channel_info.update({
151
+ "templateDict" : tpl_dict,
152
+ "inputDict" : in_dict
153
+ })
154
+ return channel_info
155
+
156
+ def align_coords(channel_info, tpl_montage, in_montage):
157
+ tpl_order = channel_info["templateOrder"]
158
+ in_order = channel_info["inputOrder"]
159
+ tpl_dict = channel_info["templateDict"]
160
+ in_dict = channel_info["inputDict"]
161
+ matched = get_matched(tpl_order, tpl_dict)
162
+
163
+ # 2D alignment (for visualization purposes)
164
+ fig = [tpl_montage.plot(), in_montage.plot()]
165
+ ax = [fig[0].axes[0], fig[1].axes[0]]
166
+
167
+ # extract the displayed 2D coordinates from the plots
168
+ all_tpl = ax[0].collections[0].get_offsets().data
169
+ all_in= ax[1].collections[0].get_offsets().data
170
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
171
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
172
+
173
+ # apply TPS to transform in_channels positions to align with tpl_channels positions
174
+ rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
175
+ rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
176
+
177
+ # apply the transformation to all in_channels
178
+ transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
179
+ transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
180
+ transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
181
+
182
+ # store the 2D positions
183
+ for i, channel in enumerate(tpl_order):
184
+ tpl_dict[channel]["coord_2d"] = all_tpl[i]
185
+ for i, channel in enumerate(in_order):
186
+ in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
187
+
188
+
189
+ # 3D alignment
190
+ all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
191
+ all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
192
+ matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
193
+ matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
194
+
195
+ rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
196
+ rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
197
+ rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
198
+
199
+ transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
200
+ transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
201
+ transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
202
+ transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
203
+
204
+ # update in_channels' 3D positions
205
+ for i, channel in enumerate(in_order):
206
+ in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
207
+
208
+ channel_info.update({
209
+ "templateDict" : tpl_dict,
210
+ "inputDict" : in_dict
211
+ })
212
+ return channel_info
213
+
214
+ def find_neighbors(channel_info, missing_channels, new_idx):
215
+ in_order = channel_info["inputOrder"]
216
+ tpl_dict = channel_info["templateDict"]
217
+ in_dict = channel_info["inputDict"]
218
+
219
+ all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
220
+ empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
221
+
222
+ # use KNN to choose k nearest channels
223
+ k = 4 if len(in_order)>4 else len(in_order)
224
+ knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
225
+ knn.fit(all_in)
226
+ for i, channel in enumerate(missing_channels):
227
+ distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
228
+ idx = tpl_dict[channel]["index"]
229
+ new_idx[idx] = indices[0].tolist()
230
+
231
+ return new_idx
232
+
233
+ def match_names(stage1_info, channel_info):
234
+ # read the location file
235
+ loc_file = stage1_info["fileNames"]["input_loc"]
236
+ tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
237
+ tpl_order = tpl_montage.ch_names
238
+ in_order = in_montage.ch_names
239
+ new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channels
240
+ fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
241
+
242
+ alias_dict = {
243
+ 'T3': 'T7',
244
+ 'T4': 'T8',
245
+ 'T5': 'P7',
246
+ 'T6': 'P8'
247
+ }
248
+ for i, channel in enumerate(tpl_order):
249
+ if channel in alias_dict and alias_dict[channel] in in_dict:
250
+ tpl_montage.rename_channels({channel: alias_dict[channel]})
251
+ tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
252
+ channel = alias_dict[channel]
253
+
254
+ if channel in in_dict:
255
+ new_idx[i] = [in_dict[channel]["index"]]
256
+ fill_flags[i] = False
257
+ tpl_dict[channel]["matched"] = True
258
+ in_dict[channel]["assigned"] = True
259
+
260
+ # update the names
261
+ tpl_order = tpl_montage.ch_names
262
+
263
+ stage1_info.update({
264
+ "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
265
+ "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
266
+ "mappingData" : [
267
+ {
268
+ "newOrder" : new_idx,
269
+ "fillFlags" : fill_flags
270
+ }
271
+ ]
272
+ })
273
+ channel_info.update({
274
+ "templateOrder" : tpl_order,
275
+ "inputOrder" : in_order,
276
+ "templateDict" : tpl_dict,
277
+ "inputDict" : in_dict
278
+ })
279
+ return stage1_info, channel_info, tpl_montage, in_montage
280
+
281
+ def optimal_mapping(channel_info):
282
+ tpl_order = channel_info["templateOrder"]
283
+ in_order = channel_info["inputOrder"]
284
+ tpl_dict = channel_info["templateDict"]
285
+ in_dict = channel_info["inputDict"]
286
+ unassigned = get_unassigned_inputs(in_order, in_dict)
287
+ # reset all tpl.matched to False
288
+ for channel in tpl_dict:
289
+ tpl_dict[channel]["matched"] = False
290
+
291
+ all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
292
+ unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
293
+
294
+ # initialize the cost matrix for the Hungarian algorithm
295
+ if len(unassigned) < 30:
296
+ cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
297
+ else:
298
+ cost_matrix = np.zeros((30, len(unassigned)))
299
+ # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
300
+ for i in range(30):
301
+ for j in range(len(unassigned)):
302
+ cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
303
+
304
+ # apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
305
+ # by minimizing the total distances between their positions.
306
+ row_idx, col_idx = linear_sum_assignment(cost_matrix)
307
+
308
+ # store the mapping results
309
+ new_idx = [[]]*30
310
+ fill_flags = [True]*30
311
+ for i, j in zip(row_idx, col_idx):
312
+ if j < len(unassigned): # filter out dummy channels
313
+ tpl_channel = tpl_order[i]
314
+ in_channel = unassigned[j]
315
+
316
+ new_idx[i] = [in_dict[in_channel]["index"]]
317
+ fill_flags[i] = False
318
+ tpl_dict[tpl_channel]["matched"] = True
319
+ in_dict[in_channel]["assigned"] = True
320
+ #print(f'{tpl_channel}({i}) <- {in_channel}({j})')
321
+
322
+ # fill the remaining empty tpl_channels
323
+ missing_channels = get_empty_templates(tpl_order, tpl_dict)
324
+ if missing_channels != []:
325
+ new_idx = find_neighbors(channel_info, missing_channels, new_idx)
326
+
327
+ mapping_data = {
328
+ "newOrder" : new_idx,
329
+ "fillFlags" : fill_flags
330
+ }
331
+ channel_info.update({
332
+ "templateDict" : tpl_dict,
333
+ "inputDict" : in_dict
334
+ })
335
+ return mapping_data, channel_info
336
+
337
+ def mapping_result(stage1_info, stage2_info, channel_info, filename):
338
+ unassigned_num = len(stage1_info["unassignedInputs"])
339
+ batch_num = math.ceil(unassigned_num/30) + 1
340
+
341
+ # map the remaining in_channels
342
+ for i in range(1, batch_num):
343
+ # optimally select 30 in_channels to map to the tpl_channels based on proximity
344
+ new_mapping_data, channel_info = optimal_mapping(channel_info)
345
+ stage1_info["mappingData"] += [new_mapping_data]
346
+
347
+ # save the mapping results
348
+ new_dict = {
349
+ #"templateOrder" : channel_info["templateOrder"],
350
+ #"inputOrder" : channel_info["inputOrder"],
351
+ "batchNum" : batch_num,
352
+ "mappingData" : stage1_info["mappingData"]
353
+ }
354
+ with open(filename, 'w') as jsonfile:
355
+ jsonfile.write(json.dumps(new_dict))
356
+
357
+ stage2_info["totalBatchNum"] = batch_num
358
+ return stage1_info, stage2_info, channel_info
359
+
utils.py CHANGED
@@ -143,7 +143,8 @@ def dataDelete(path):
143
  try:
144
  shutil.rmtree(path)
145
  except OSError as e:
146
- print(e)
 
147
  else:
148
  pass
149
  #print("The directory is deleted successfully")
 
143
  try:
144
  shutil.rmtree(path)
145
  except OSError as e:
146
+ pass
147
+ #print(e)
148
  else:
149
  pass
150
  #print("The directory is deleted successfully")