audrey06100 commited on
Commit
3566452
·
1 Parent(s): 356e41c
Files changed (2) hide show
  1. app.py +383 -426
  2. channel_mapping.py → app_utils.py +206 -122
app.py CHANGED
@@ -1,16 +1,8 @@
 
 
1
  import gradio as gr
2
-
3
  import os
4
  import random
5
- import math
6
- import numpy as np
7
- import matplotlib.pyplot as plt
8
- import mne
9
- from mne.channels import read_custom_montage
10
-
11
- import utils
12
- from channel_mapping import mapping_stage1, mapping_stage2, reorder_input_data, restore_original_order, find_neighbors
13
-
14
 
15
  readme = """
16
 
@@ -35,8 +27,14 @@ Your unmatched channels, previously highlighted in red, will be shown on your mo
35
  ### Step3: Filling Remaining Template Channels
36
  To run the models successfully, we need to ensure that all 30 template channels are filled. In this step, you are required to select one of the methods provided below to fill the remaining empty template channels:
37
  - **Mean** method: Each empty template channel is filled with the average value of data from the nearest input channels. By default, the 4 closest input channels (determined after aligning your montage to the template's scale using TPS) are selected for this averaging process. On the interface, you will see checkboxes displayed above each of your channel. The 4 nearest channels are pre-selected by default for each empty template channels, but you can modify these selections as needed. If you uncheck all the checkboxes for a particular template channel, it will be filled with zeros.
38
- - **Zero** method: All empty template channels are filled with zeros.
39
- Choose he method that best suits your needs, considering that the model's performance may vary depending on the method used. Once all template channels are filled, you can proceed to run the models.
 
 
 
 
 
 
40
 
41
  ## 2. Decode data
42
  In this phase, you can select which model to use for denoising your EEG data. Detailed information about the models can be found in the other tabs.
@@ -64,7 +62,7 @@ init_js = """
64
  selector = "#radio-group > div:nth-of-type(2)";
65
  //classname = "radio";
66
  attribute = "value";
67
- }else if(stage1_info.state == "step3-selecting"){
68
  selector = "#chkbox-group > div:nth-of-type(2)";
69
  //classname = "chkbox";
70
  attribute = "name";
@@ -78,7 +76,7 @@ init_js = """
78
  aspect-ratio: 1;
79
  //width: 560px;
80
  //height: 560px;
81
- background: url("file=${stage1_info.filenames.input_montage}");
82
  background-size: contain;
83
 
84
  `;
@@ -165,7 +163,7 @@ update_js = """
165
  item.className = "";
166
  item.querySelector(":scope > span").innerText = "";
167
  });
168
- }else if(stage1_info.state == "step3-selecting"){
169
  selector = "#chkbox-group > div:nth-of-type(2)";
170
  }else return;
171
 
@@ -247,15 +245,14 @@ with gr.Blocks() as demo:
247
  map_btn = gr.Button("Mapping", interactive=False, scale=1)
248
 
249
  # ------------------------mapping------------------------
250
- # description for stage1-123
251
  desc_md = gr.Markdown(visible=False)
252
- # stage1-1 : mapping result
253
  with gr.Row():
254
  tpl_img = gr.Image("./template_montage.png", label="Template channels", visible=False)
255
  mapped_img = gr.Image(label="Input channels", visible=False)
256
- # stage1-2 : assign unmatched input channels to empty template channels
257
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
258
- # stage1-3 : select a way to fill the empty template channels
259
  with gr.Row():
260
  in_fillmode = gr.Dropdown(choices=["mean", "zero"],
261
  value="mean",
@@ -264,6 +261,8 @@ with gr.Blocks() as demo:
264
  scale=2)
265
  fillmode_btn = gr.Button("OK", visible=False, scale=1)
266
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
 
 
267
 
268
  with gr.Row():
269
  clear_btn = gr.Button("Clear", visible=False)
@@ -280,7 +279,9 @@ with gr.Blocks() as demo:
280
  ("ART", "EEGART"),
281
  ("IC-U-Net", "ICUNet"),
282
  ("IC-U-Net++", "UNetpp"),
283
- ("IC-U-Net-Attn", "AttUnet")],
 
 
284
  value="EEGART",
285
  label="Model",
286
  scale=2)
@@ -303,7 +304,7 @@ with gr.Blocks() as demo:
303
  with gr.Tab("README"):
304
  gr.Markdown(readme)
305
 
306
- #demo.load(js=tmp_js)
307
 
308
  # verify that all required inputs have been provided
309
  @gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
@@ -316,317 +317,233 @@ with gr.Blocks() as demo:
316
 
317
 
318
  # +========================================================================================+
319
- # | stage1: channel mapping |
320
  # +========================================================================================+
321
  def reset_all(in_data, in_loc, samplerate):
322
  # establish a new folder for the current session
323
- filepath = os.path.dirname(str(in_data))
324
  try:
325
- os.mkdir(filepath+"/session_data/")
326
  except OSError as e:
327
- utils.dataDelete(filepath+"/session_data/")
328
- os.mkdir(filepath+"/session_data/")
329
  print(e)
330
  # establish new folders for stage1 and stage2
331
- os.mkdir(filepath+"/session_data/stage1/")
332
- os.mkdir(filepath+"/session_data/stage2/")
333
 
334
  # initialize channel_info, app_info
335
  channel_info = {}
336
  app_info = {
337
- "rootFilepath" : filepath+"/session_data/",
338
  "sampleRate" : int(samplerate),
339
- #"currentStage" : "stage1",
340
  "stage1" : {
341
- "filepath" : filepath+"/session_data/stage1/",
342
- "filenames" : {
343
  "input_data" : in_data,
344
  "input_loc" : in_loc,
345
  "input_montage" : "",
346
  "mapped_montage" : ""
347
  },
348
- "state" : None,
349
  "fillingCount" : None,
350
  "totalFillingNum" : None,
351
- "newOrder" : None,
352
  "unassignedInputs" : None,
353
- "missingTemplates" : None
 
 
 
 
 
 
 
354
  },
355
  "stage2" : {
356
- "filepath" : filepath+"/session_data/stage2/",
357
- "filenames" : {
358
  "output_data" : ""
359
  },
360
- #"state" : None,
361
- "totalBatchNum" : None,
362
- "newOrder" : None,
363
- "unassignedInputs" : None
364
  }
365
  }
366
  # reset layout
367
  return {app_info_json : app_info,
368
  channel_info_json : channel_info,
369
- # ------------------stage1-----------------------
370
  map_btn : gr.Button(interactive=False),
371
  desc_md : gr.Markdown(visible=False),
 
372
  tpl_img : gr.Image(visible=False),
373
  mapped_img : gr.Image(value=None, visible=False),
374
  radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
375
- in_fillmode : gr.Dropdown(value="mean", visible=False),
376
- chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
377
- fillmode_btn : gr.Button(visible=False),
378
  clear_btn : gr.Button(visible=False),
379
  step2_btn : gr.Button(visible=False),
 
 
 
380
  step3_btn : gr.Button(visible=False),
381
- next_btn : gr.Button("Next step", visible=False),
382
- # ------------------stage2-----------------------
383
  run_btn : gr.Button(interactive=False),
384
  batch_md : gr.Markdown(visible=False),
385
  out_data_file : gr.File(visible=False)}
386
 
387
 
388
  # +========================================================================================+
389
- # | stage1-1 |
390
  # +========================================================================================+
391
- def save_figures(channel_info, filename1, filename2):
392
-
393
- tpl_montage = read_custom_montage("./template_chanlocs.loc")
394
- tpl_dict = channel_info["templateDict"]
395
- in_dict = channel_info["inputDict"]
396
- tpl_order = channel_info["templateOrder"]
397
- in_order = channel_info["inputOrder"]
398
-
399
- # get template and input's 2d coords
400
- tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
401
- tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
402
- in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
403
- in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
404
- tpl_coords = np.vstack((tpl_x, tpl_y)).T
405
- in_coords = np.vstack((in_x, in_y)).T
406
-
407
- # get template's head figure
408
- tpl_fig = tpl_montage.plot()
409
- tpl_ax = tpl_fig.axes[0]
410
- lines = tpl_ax.lines
411
- head_lines = []
412
- for line in lines:
413
- x, y = line.get_data()
414
- head_lines.append((x,y))
415
- plt.close()
416
-
417
- # -------------------------plot input montage------------------------------
418
- fig = plt.figure(figsize=(6.4,6.4), dpi=100)
419
- ax = fig.add_subplot(111)
420
- fig.tight_layout()
421
- ax.set_aspect('equal')
422
- ax.axis('off')
423
-
424
- # plot template's head
425
- for x, y in head_lines:
426
- ax.plot(x, y, color='black', linewidth=1.0)
427
- # plot input channels on it
428
- ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
429
- for i, channel in enumerate(in_order):
430
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
431
- # save input_montage
432
- fig.savefig(filename1)
433
-
434
- # ---------------------------add indications-------------------------------
435
- indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
436
-
437
- # plot unmatched input channels in red
438
- ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
439
- for i in indices:
440
- ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
441
- # save mapped_montage
442
- fig.savefig(filename2)
443
-
444
- # -------------------------------------------------------------------------
445
- # store the template and input channels' display position (in px).
446
- tpl_coords = ax.transData.transform(tpl_coords)
447
- in_coords = ax.transData.transform(in_coords)
448
- plt.close()
449
-
450
- for i, channel in enumerate(tpl_order):
451
- css_left = (tpl_coords[i,0]-11)/6.4
452
- css_bottom = (tpl_coords[i,1]-7)/6.4
453
- tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
454
-
455
- for i, channel in enumerate(in_order):
456
- css_left = (in_coords[i,0]-11)/6.4
457
- css_bottom = (in_coords[i,1]-7)/6.4
458
- in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
459
-
460
- channel_info.update({
461
- "templateDict" : tpl_dict,
462
- "inputDict" : in_dict
463
- })
464
- return channel_info
465
-
466
- def mapping_result(app_info, channel_info):
467
  stage1_info = app_info["stage1"]
468
- filepath = stage1_info["filepath"]
469
-
470
- # generate and save figures of the input montage and the mapped montage
471
- filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
472
- filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
473
- channel_info = save_figures(channel_info, filename1, filename2)
474
- stage1_info["filenames"].update({
475
- "input_montage" : filename1,
476
- "mapped_montage" : filename2
477
- })
478
-
479
- # -------------------------determine the next step--------------------------
480
-
481
- in_num = len(channel_info["inputOrder"])
482
- matched_num = 30 - len(stage1_info["missingTemplates"])
483
-
484
- # if the in_channels has all the 30 tpl_channels (in_num>=30)
485
- # -> stage2
486
- if matched_num == 30:
487
- stage1_info["state"] = "finished"
488
- gr.Info('The mapping process has been finished.')
489
-
490
- if in_num == 30:
491
- md = """
492
- ---
493
- ### Step1: Initial Matching and Rescaling
494
- Below is the result of mapping your channels to our template channels based on their names.
495
- """
496
- else:
497
- md = """
498
- ---
499
- ### Step1: Initial Matching and Rescaling
500
- Below is the result of mapping your channels to our template channels based on their names.
501
- - channels highlighted in red are those that do not match any template channels.
502
- """
503
 
504
- app_info["stage1"] = stage1_info
505
- return {app_info_json : app_info,
506
- channel_info_json : channel_info,
507
- map_btn : gr.Button(interactive=True),
508
- desc_md : gr.Markdown(md, visible=True),
509
- tpl_img : gr.Image(visible=True),
510
- mapped_img : gr.Image(value=filename2, visible=True),
511
- run_btn : gr.Button(interactive=True)}
512
-
513
- else:
514
- # if matched_num < 30, and there're still some unmatched in_channels
515
- # -> assign these in_channels to nearby unmatched tpl_channels
516
- if in_num > matched_num:
517
- stage1_info["state"] = "step2-initializing"
 
 
 
518
  md = """
519
  ---
520
  ### Step1: Initial Matching and Rescaling
521
- Below is the result of mapping your channels to our template channels based on their names.
522
- - channels highlighted in red are those that do not match any template channels.
523
  """
524
-
525
- # if in_num < 30, but all of them can match to some tpl_channels
526
- # -> directly use fillmode to fill the remaining tpl_channels
527
- elif in_num == matched_num:
528
- stage1_info["state"] = "step3-initializing"
529
  md = """
530
  ---
531
  ### Step1: Initial Matching and Rescaling
532
- Below is the result of mapping your channels to our template channels based on their names.
 
533
  """
534
 
 
535
  app_info["stage1"] = stage1_info
536
- return {app_info_json : app_info,
537
  channel_info_json : channel_info,
538
  map_btn : gr.Button(interactive=True),
539
- desc_md : gr.Markdown(md, visible=True),
540
  tpl_img : gr.Image(visible=True),
541
  mapped_img : gr.Image(value=filename2, visible=True),
542
  next_btn : gr.Button(visible=True)}
543
-
544
- start_stage1 = map_btn.click(
545
- fn = reset_all,
546
- inputs = [in_data_file, in_loc_file, in_samplerate],
547
- outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, radio_group,
548
- in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
549
- run_btn, batch_md, out_data_file]
550
- ).success(
551
- fn = mapping_stage1,
552
- inputs = [app_info_json, channel_info_json],
553
- outputs = [app_info_json, channel_info_json, desc_md]
554
- ).success(
555
- fn = mapping_result,
556
- inputs = [app_info_json, channel_info_json],
557
- outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, next_btn, run_btn]
558
- )
559
-
560
-
561
- # +========================================================================================+
562
- # | manage step transition |
563
- # +========================================================================================+
564
- def init_next_step(app_info, channel_info, selected_radio, selected_chkbox):
565
- stage1_info = app_info["stage1"]
566
 
567
- # stage1-1 -> stage1-2
568
- if stage1_info["state"] == "step2-initializing":
569
- #print('step1 -> step2')
570
- md = """
571
- ---
572
- ### Step2: Forwarding Unmatched Channels
573
- Select one of your unmatched channels to forward its data to the empty template channel
574
- currently indicated in red.
575
- """
576
 
577
- # initialize the progress indication label for step2
578
- stage1_info.update({
579
- "state" : "step2-selecting",
580
- "fillingCount" : 1,
581
- "totalFillingNum" : len(stage1_info["missingTemplates"])
582
- })
583
- name = stage1_info["missingTemplates"][0]
584
- label = "{} (1/{})".format(name, stage1_info["totalFillingNum"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
- app_info["stage1"] = stage1_info
587
- # determine which button to display
588
- if stage1_info["totalFillingNum"] == 1:
589
- return {app_info_json : app_info,
590
- channel_info_json : channel_info,
591
- desc_md : gr.Markdown(md),
592
- tpl_img : gr.Image(visible=False),
593
- mapped_img : gr.Image(visible=False),
594
- radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
595
- clear_btn : gr.Button(visible=True),
596
- next_btn : gr.Button("Next step")}
597
- else:
598
- return {app_info_json : app_info,
599
- channel_info_json : channel_info,
600
- desc_md : gr.Markdown(md),
601
- tpl_img : gr.Image(visible=False),
602
- mapped_img : gr.Image(visible=False),
603
- radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
604
- clear_btn : gr.Button(visible=True),
605
- step2_btn : gr.Button(visible=True),
606
- next_btn : gr.Button(visible=False)}
607
-
608
- # stage1-1 -> stage1-3
609
- elif stage1_info["state"] == "step3-initializing":
610
- #print('step1 -> step3')
611
- md = """
612
- ---
613
- ### Step3: Filling Remaining Template Channels
614
- To run the model successfully, we need to ensure that all 30 template channels are filled.
615
- In this step, you are required to select one of the methods provided below to fill the
616
- remaining empty template channels.
617
- """
618
- return {desc_md : gr.Markdown(md),
619
- tpl_img : gr.Image(visible=False),
620
- mapped_img : gr.Image(visible=False),
621
- in_fillmode : gr.Dropdown(visible=True),
622
- fillmode_btn : gr.Button(visible=True),
623
- next_btn : gr.Button(visible=False)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
- # stage1-2 -> stage1-3 or stage2
626
  elif stage1_info["state"] == "step2-selecting":
627
 
628
- # ----------------------store information before the button click----------------------
629
-
630
  # check if the user has selected an in_channel to forward to the previous target tpl_channel
631
  if selected_radio != []:
632
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
@@ -634,40 +551,51 @@ with gr.Blocks() as demo:
634
 
635
  # store the index of the in_channel
636
  selected_idx = channel_info["inputDict"][selected_radio]["index"]
637
- stage1_info["newOrder"][prev_target_idx] = [selected_idx]
 
638
  # mark the in_channel as assigned and tpl_channel as matched
639
  channel_info["templateDict"][prev_target_name]["matched"] = True
640
  channel_info["inputDict"][selected_radio]["assigned"] = True
641
- print(prev_target_name, '<-', selected_radio)
642
-
643
- # ------------------------update information for the next step-------------------------
644
 
 
645
  # update the list of unassignedInputs to exclude the selected in_channel of the previous round
646
- stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
647
- if channel_info["inputDict"][channel]["assigned"]==False]
648
  # update the list of missingTemplates to exclude those filled in step2
649
- stage1_info["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
650
- if channel_info["templateDict"][channel]["matched"]==False]
651
 
652
- # if all the unmatched tpl_channels were filled by in_channels
653
- # -> stage2
 
654
  if len(stage1_info["missingTemplates"]) == 0:
655
- #print('step2 -> stage2')
656
- stage1_info["state"] = "finished"
657
- gr.Info('The mapping process has been finished.')
 
 
 
658
 
 
 
 
 
 
 
659
  app_info["stage1"] = stage1_info
660
- return {app_info_json : app_info,
 
661
  channel_info_json : channel_info,
662
- desc_md : gr.Markdown(visible=False),
663
  radio_group : gr.Radio(visible=False),
 
664
  clear_btn : gr.Button(visible=False),
665
  next_btn : gr.Button(visible=False),
666
  run_btn : gr.Button(interactive=True)}
667
-
668
- # -> stage1-3
669
  else:
670
- #print('step2 -> step3')
671
  md = """
672
  ---
673
  ### Step3: Filling Remaining Template Channels
@@ -676,8 +604,9 @@ with gr.Blocks() as demo:
676
  remaining empty template channels.
677
  """
678
 
 
679
  app_info["stage1"] = stage1_info
680
- return {app_info_json : app_info,
681
  channel_info_json : channel_info,
682
  desc_md : gr.Markdown(md),
683
  radio_group : gr.Radio(visible=False),
@@ -686,14 +615,86 @@ with gr.Blocks() as demo:
686
  clear_btn : gr.Button(visible=False),
687
  next_btn : gr.Button(visible=False)}
688
 
689
- # stage1-3 -> stage2
690
- elif stage1_info["state"] == "step3-selecting":
691
- #print('step3 -> stage2')
692
- stage1_info["state"] = "finished"
693
- gr.Info('The mapping process has been finished.')
694
 
695
- # ----------------------store information before the button click----------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
 
 
697
  # if the user didn't uncheck all in_channel checkboxes
698
  if selected_chkbox != []:
699
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
@@ -701,22 +702,36 @@ with gr.Blocks() as demo:
701
 
702
  # store the indices of the in_channels
703
  selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
704
- stage1_info["newOrder"][prev_target_idx] = selected_indices
705
- #print(f'{prev_target_name}({prev_target_idx}): {selected_chkbox}')
706
- # -------------------------------------------------------------------------------------
 
 
 
 
 
707
 
 
 
 
 
 
 
708
  app_info["stage1"] = stage1_info
709
- return {app_info_json : app_info,
710
- desc_md : gr.Markdown(visible=False),
 
 
711
  chkbox_group : gr.CheckboxGroup(visible=False),
712
  next_btn : gr.Button(visible=False),
 
713
  run_btn : gr.Button(interactive=True)}
714
 
715
  next_btn.click(
716
  fn = init_next_step,
717
- inputs = [app_info_json, channel_info_json, radio_group, chkbox_group],
718
- outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group,
719
- in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, next_btn, run_btn]
720
  ).success(
721
  fn = None,
722
  js = init_js,
@@ -726,9 +741,25 @@ with gr.Blocks() as demo:
726
 
727
 
728
  # +========================================================================================+
729
- # | stage1-2 |
730
  # +========================================================================================+
731
- # ....
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
732
  @radio_group.select(inputs = app_info_json, outputs = [step2_btn, next_btn])
733
  def determine_button(app_info):
734
  stage1_info = app_info["stage1"]
@@ -751,8 +782,7 @@ with gr.Blocks() as demo:
751
  def update_radio(app_info, channel_info, selected):
752
  stage1_info = app_info["stage1"]
753
 
754
- # ----------------------store information before the button click----------------------
755
-
756
  # check if the user has selected an in_channel to forward to the previous target tpl_channel
757
  if selected != []:
758
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
@@ -760,18 +790,18 @@ with gr.Blocks() as demo:
760
 
761
  # store the index of the selected in_channel
762
  selected_idx = channel_info["inputDict"][selected]["index"]
763
- stage1_info["newOrder"][prev_target_idx] = [selected_idx]
 
764
  # mark the in_channel as assigned and tpl_channel as matched
765
  channel_info["templateDict"][prev_target_name]["matched"] = True
766
  channel_info["inputDict"][selected]["assigned"] = True
767
- print(prev_target_name, '<-', selected)
768
 
769
- # ------------------------update information for the new round-------------------------
770
  stage1_info["fillingCount"] += 1
771
 
772
  # update the list of unassignedInputs to exclude the selected in_channel of the previous round
773
- stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
774
- if channel_info["inputDict"][channel]["assigned"]==False]
775
  # update the progress indication label
776
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
777
  radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
@@ -784,7 +814,7 @@ with gr.Blocks() as demo:
784
  radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
785
  value=[], label=radio_label),
786
  step2_btn : gr.Button(visible=False),
787
- next_btn : gr.Button("Next step", visible=True)}
788
  else:
789
  return {app_info_json : app_info,
790
  channel_info_json : channel_info,
@@ -804,71 +834,12 @@ with gr.Blocks() as demo:
804
 
805
 
806
  # +========================================================================================+
807
- # | stage1-3 |
808
- # +========================================================================================+
809
- def fill_value(app_info, channel_info, fillmode):
810
- stage1_info = app_info["stage1"]
811
-
812
- if fillmode == "zero":
813
- stage1_info["state"] = "finished"
814
- gr.Info('The mapping process has been finished.')
815
-
816
- app_info["stage1"] = stage1_info
817
- return {app_info_json : app_info,
818
- desc_md : gr.Markdown(visible=False),
819
- in_fillmode : gr.Dropdown(visible=False),
820
- fillmode_btn : gr.Button(visible=False),
821
- run_btn : gr.Button(interactive=True)}
822
-
823
- elif fillmode == "mean":
824
- md = """
825
- ---
826
- ### Step3: Fill the remaining template channels
827
- The current empty template channel, indicated in red, will be filled with the average
828
- value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
829
- """
830
-
831
- # find the 4-NN in_channels for each of the unmatched tpl_channels
832
- new_idx = find_neighbors(channel_info, stage1_info["missingTemplates"], stage1_info["newOrder"])
833
-
834
- stage1_info.update({
835
- "state" : "step3-selecting",
836
- "newOrder" : new_idx,
837
- "fillingCount" : 1,
838
- "totalFillingNum" : len(stage1_info["missingTemplates"])
839
- })
840
-
841
- # initialize the progress indication label
842
- target_name = stage1_info["missingTemplates"][0]
843
- target_idx = channel_info["templateDict"][target_name]["index"]
844
- chkbox_value = stage1_info["newOrder"][target_idx]
845
- chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
846
- chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
847
-
848
- app_info["stage1"] = stage1_info
849
- # determine which button to display
850
- if stage1_info["totalFillingNum"] == 1:
851
- return {app_info_json : app_info,
852
- desc_md : gr.Markdown(md),
853
- in_fillmode : gr.Dropdown(visible=False),
854
- fillmode_btn : gr.Button(visible=False),
855
- chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
856
- value=chkbox_value, label=chkbox_label, visible=True),
857
- next_btn : gr.Button(visible=True)}
858
- else:
859
- return {app_info_json : app_info,
860
- desc_md : gr.Markdown(md),
861
- in_fillmode : gr.Dropdown(visible=False),
862
- fillmode_btn : gr.Button(visible=False),
863
- chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
864
- value=chkbox_value, label=chkbox_label, visible=True),
865
- step3_btn : gr.Button(visible=True)}
866
-
867
  def update_chkbox(app_info, channel_info, selected):
868
  stage1_info = app_info["stage1"]
869
 
870
- # ----------------------store information before the button click----------------------
871
-
872
  # if the user didn't uncheck all in_channel checkboxes
873
  if selected != []:
874
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
@@ -876,16 +847,16 @@ with gr.Blocks() as demo:
876
 
877
  # store the indices of the selected in_channels
878
  selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
879
- stage1_info["newOrder"][prev_target_idx] = selected_indices
880
- #print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
881
 
882
- # ------------------------update information for the new round-------------------------
883
  stage1_info["fillingCount"] += 1
884
 
885
  # update the progress indication label
886
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
887
  target_idx = channel_info["templateDict"][target_name]["index"]
888
- chkbox_value = stage1_info["newOrder"][target_idx]
889
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
890
  chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
891
 
@@ -895,15 +866,16 @@ with gr.Blocks() as demo:
895
  return {app_info_json : app_info,
896
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
897
  step3_btn : gr.Button(visible=False),
898
- next_btn : gr.Button("Submit", visible=True)}
899
  else:
900
  return {app_info_json : app_info,
901
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
902
 
903
  fillmode_btn.click(
904
- fn = fill_value,
905
- inputs = [app_info_json, channel_info_json, in_fillmode],
906
- outputs = [app_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn, next_btn, run_btn]
 
907
  ).success(
908
  fn = None,
909
  js = init_js,
@@ -924,59 +896,45 @@ with gr.Blocks() as demo:
924
 
925
 
926
  # +========================================================================================+
927
- # | stage2: decode data |
928
  # +========================================================================================+
929
- def reset_run(app_info, channel_info, modelname):
930
  stage1_info = app_info["stage1"]
931
  stage2_info = app_info["stage2"]
932
 
933
- # delete the previous folder of stage2
934
- filepath = stage2_info["filepath"]
935
  utils.dataDelete(filepath)
936
- # establish a new folder for stage2
937
- new_filepath = app_info["rootFilepath"]+"stage2_"+str(random.randint(1,10000))+"/"
938
  os.mkdir(new_filepath)
939
  # generate the output filename
940
- filename = stage1_info["filenames"]["input_data"]
941
  filename = os.path.basename(str(filename))
942
  new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
943
 
944
- # reset inputChannel.assigned back to the state after stage1
945
- for channel in stage1_info["unassignedInputs"]:
946
- channel_info["inputDict"][channel]["assigned"] = False
947
- # calculate how many times the model needs to be run
948
- unassigned_num = len(stage1_info["unassignedInputs"])
949
- batch_num = math.ceil(unassigned_num/30) + 1
950
-
951
- app_info.update({
952
- #"currentStage" : "stage2",
953
- "stage2" : {
954
- "filepath" : new_filepath,
955
- "filenames" : {
956
- "output_data" : new_filepath + new_filename
957
- },
958
- #"state" : "initializing",
959
- "totalBatchNum" : batch_num,
960
- "newOrder" : [[]]*30,
961
- "unassignedInputs" : stage1_info["unassignedInputs"]
962
  }
963
  })
 
964
  return {app_info_json : app_info,
965
- channel_info_json : channel_info,
966
  #run_btn : gr.Button(interactive=False),
967
  batch_md : gr.Markdown(visible=False),
968
  out_data_file : gr.File(visible=False)}
969
 
970
- def run_model(app_info, channel_info, modelname):
971
  stage1_info = app_info["stage1"]
972
  stage2_info = app_info["stage2"]
973
 
974
- filepath = stage2_info["filepath"]
975
  samplerate = app_info["sampleRate"]
976
- filename = stage1_info["filenames"]["input_data"]
977
- new_filename = stage2_info["filenames"]["output_data"]
978
 
979
- # set a flag to record whether the user has clicked the map_btn or run_btn while running the model
980
  break_flag = False
981
 
982
  # run the model multiple times until all in_channels are reconstructed
@@ -988,7 +946,7 @@ with gr.Blocks() as demo:
988
  #utils.dataDelete(filepath+"temp_data/")
989
  #os.mkdir(filepath+"temp_data/")
990
  except FileNotFoundError:
991
- #print('break1!!')
992
  break_flag = True
993
  break
994
  except OSError as e:
@@ -998,33 +956,32 @@ with gr.Blocks() as demo:
998
  md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
999
  yield {batch_md : gr.Markdown(md, visible=True)}
1000
 
1001
- # ....
1002
- if i == 0:
1003
- new_idx = stage1_info["newOrder"]
1004
- else:
1005
- # if this is not the first time running the model, the in_channels that have
1006
- # not been reconstructed yet will be optimally mapped to the template.
1007
- stage2_info, channel_info = mapping_stage2(stage2_info, channel_info)
1008
- new_idx = stage2_info["newOrder"]
1009
- print('unassigned num:', len(stage2_info["unassignedInputs"]))
1010
-
1011
  # ----------------------------------------------------------------------
1012
  try:
1013
  # step1: Reorder input data
1014
- reorder_input_data(new_idx, filename, filepath+"temp_data/mapped.csv")
 
 
 
1015
  # step2: Data preprocessing
1016
  total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
1017
  # step3: Signal reconstruction
1018
  utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
 
 
 
1019
  # step4: Restore original order
1020
- restore_original_order(channel_info, i, new_idx, filepath+"temp_data/denoised.csv", new_filename)
 
1021
  except FileNotFoundError:
1022
- #print('break2!!')
1023
  break_flag = True
1024
  break
1025
  # ----------------------------------------------------------------------
1026
  utils.dataDelete(filepath+"temp_data/")
1027
- app_info["stage2"] = stage2_info
1028
 
1029
  if break_flag == True:
1030
  yield {batch_md : gr.Markdown(visible=False)}
@@ -1035,12 +992,12 @@ with gr.Blocks() as demo:
1035
 
1036
  run_btn.click(
1037
  fn = reset_run,
1038
- inputs = [app_info_json, channel_info_json, in_modelname],
1039
- outputs = [app_info_json, channel_info_json, run_btn, batch_md, out_data_file]
1040
 
1041
  ).success(
1042
  fn = run_model,
1043
- inputs = [app_info_json, channel_info_json, in_modelname],
1044
  outputs = [run_btn, batch_md, out_data_file]
1045
  )
1046
 
 
1
+ import utils
2
+ import app_utils
3
  import gradio as gr
 
4
  import os
5
  import random
 
 
 
 
 
 
 
 
 
6
 
7
  readme = """
8
 
 
27
  ### Step3: Filling Remaining Template Channels
28
  To run the models successfully, we need to ensure that all 30 template channels are filled. In this step, you are required to select one of the methods provided below to fill the remaining empty template channels:
29
  - **Mean** method: Each empty template channel is filled with the average value of data from the nearest input channels. By default, the 4 closest input channels (determined after aligning your montage to the template's scale using TPS) are selected for this averaging process. On the interface, you will see checkboxes displayed above each of your channel. The 4 nearest channels are pre-selected by default for each empty template channels, but you can modify these selections as needed. If you uncheck all the checkboxes for a particular template channel, it will be filled with zeros.
30
+ - **Zero** method: All empty template channels are filled with zeros.
31
+ Choose the method that best suits your needs, considering that the model's performance may vary depending on the method used.
32
+
33
+ ### Step4: Auto-mapping Remaining Channels
34
+ After completing the initial mapping steps, any channels that are not yet assigned to a template will be processed in this step. These remaining channels will be automatically mapped in batches, with a batch size of up to 30 channels. If the final batch contains fewer than 30 channels, the **Mean** method from Step3 will be applied to fill the remaining template channels.
35
+
36
+
37
+ ### Mapping Result
38
 
39
  ## 2. Decode data
40
  In this phase, you can select which model to use for denoising your EEG data. Detailed information about the models can be found in the other tabs.
 
62
  selector = "#radio-group > div:nth-of-type(2)";
63
  //classname = "radio";
64
  attribute = "value";
65
+ }else if(stage1_info.state == "step3-2-selecting"){
66
  selector = "#chkbox-group > div:nth-of-type(2)";
67
  //classname = "chkbox";
68
  attribute = "name";
 
76
  aspect-ratio: 1;
77
  //width: 560px;
78
  //height: 560px;
79
+ background: url("file=${stage1_info.fileNames.input_montage}");
80
  background-size: contain;
81
 
82
  `;
 
163
  item.className = "";
164
  item.querySelector(":scope > span").innerText = "";
165
  });
166
+ }else if(stage1_info.state == "step3-2-selecting"){
167
  selector = "#chkbox-group > div:nth-of-type(2)";
168
  }else return;
169
 
 
245
  map_btn = gr.Button("Mapping", interactive=False, scale=1)
246
 
247
  # ------------------------mapping------------------------
 
248
  desc_md = gr.Markdown(visible=False)
249
+ # step1 : initial mapping abd rescaling
250
  with gr.Row():
251
  tpl_img = gr.Image("./template_montage.png", label="Template channels", visible=False)
252
  mapped_img = gr.Image(label="Input channels", visible=False)
253
+ # step2 : forward unmatched input channels to empty template channels
254
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
255
+ # step3 : fill the remaining template channels
256
  with gr.Row():
257
  in_fillmode = gr.Dropdown(choices=["mean", "zero"],
258
  value="mean",
 
261
  scale=2)
262
  fillmode_btn = gr.Button("OK", visible=False, scale=1)
263
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
264
+ # step4 : mapping result
265
+ out_json_file = gr.File(label="Mapping result", visible=False)
266
 
267
  with gr.Row():
268
  clear_btn = gr.Button("Clear", visible=False)
 
279
  ("ART", "EEGART"),
280
  ("IC-U-Net", "ICUNet"),
281
  ("IC-U-Net++", "UNetpp"),
282
+ ("IC-U-Net-Attn", "AttUnet"),
283
+ "(mapped data)",
284
+ "(denoised data)"],
285
  value="EEGART",
286
  label="Model",
287
  scale=2)
 
304
  with gr.Tab("README"):
305
  gr.Markdown(readme)
306
 
307
+ #demo.load(js=js)
308
 
309
  # verify that all required inputs have been provided
310
  @gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
 
317
 
318
 
319
  # +========================================================================================+
320
+ # | Stage1: channel mapping |
321
  # +========================================================================================+
322
  def reset_all(in_data, in_loc, samplerate):
323
  # establish a new folder for the current session
324
+ rootpath = os.path.dirname(str(in_data))
325
  try:
326
+ os.mkdir(rootpath+"/session_data/")
327
  except OSError as e:
328
+ utils.dataDelete(rootpath+"/session_data/")
329
+ os.mkdir(rootpath+"/session_data/")
330
  print(e)
331
  # establish new folders for stage1 and stage2
332
+ os.mkdir(rootpath+"/session_data/stage1/")
333
+ os.mkdir(rootpath+"/session_data/stage2/")
334
 
335
  # initialize channel_info, app_info
336
  channel_info = {}
337
  app_info = {
338
+ "rootPath" : rootpath+"/session_data/",
339
  "sampleRate" : int(samplerate),
 
340
  "stage1" : {
341
+ "filePath" : rootpath+"/session_data/stage1/",
342
+ "fileNames" : {
343
  "input_data" : in_data,
344
  "input_loc" : in_loc,
345
  "input_montage" : "",
346
  "mapped_montage" : ""
347
  },
348
+ "state" : "step1-initializing",
349
  "fillingCount" : None,
350
  "totalFillingNum" : None,
 
351
  "unassignedInputs" : None,
352
+ "missingTemplates" : None,
353
+ "mappingData" : [
354
+ {
355
+ "newOrder" : None,
356
+ "fillFlags" : None,
357
+ #"channelUsageNum" : None
358
+ }
359
+ ]
360
  },
361
  "stage2" : {
362
+ "filePath" : rootpath+"/session_data/stage2/",
363
+ "fileNames" : {
364
  "output_data" : ""
365
  },
366
+ "totalBatchNum" : None
 
 
 
367
  }
368
  }
369
  # reset layout
370
  return {app_info_json : app_info,
371
  channel_info_json : channel_info,
372
+ # --------------------Stage1-------------------------
373
  map_btn : gr.Button(interactive=False),
374
  desc_md : gr.Markdown(visible=False),
375
+ next_btn : gr.Button(visible=False),
376
  tpl_img : gr.Image(visible=False),
377
  mapped_img : gr.Image(value=None, visible=False),
378
  radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
 
 
 
379
  clear_btn : gr.Button(visible=False),
380
  step2_btn : gr.Button(visible=False),
381
+ in_fillmode : gr.Dropdown(value="mean", visible=False),
382
+ fillmode_btn : gr.Button(visible=False),
383
+ chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
384
  step3_btn : gr.Button(visible=False),
385
+ out_json_file : gr.File(value=None, visible=False),
386
+ # --------------------Stage2-------------------------
387
  run_btn : gr.Button(interactive=False),
388
  batch_md : gr.Markdown(visible=False),
389
  out_data_file : gr.File(visible=False)}
390
 
391
 
392
  # +========================================================================================+
393
+ # | manage step transition |
394
  # +========================================================================================+
395
+ def init_next_step(app_info, channel_info, fillmode, selected_radio, selected_chkbox):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  stage1_info = app_info["stage1"]
397
+ stage2_info = app_info["stage2"]
398
+ filepath = stage1_info["filePath"]
399
+
400
+ # ========================================step0=========================================
401
+ # step0 to step1
402
+ if stage1_info["state"] == "step1-initializing":
403
+ #print('step0 -> step1')
404
+
405
+ # 1. match the names of in_channels and tpl_channels
406
+ yield {desc_md : gr.Markdown("Mapping...", visible=True)}
407
+ stage1_info, channel_info, tpl_montage, in_montage = app_utils.match_names(stage1_info, channel_info)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
 
409
+ # 2. rescale coordinates
410
+ yield {desc_md : gr.Markdown("Rescaling...")}
411
+ channel_info = app_utils.align_coords(channel_info, tpl_montage, in_montage)
412
+
413
+ # 3. generate and save figures of the montages
414
+ filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
415
+ filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
416
+ channel_info = app_utils.save_figures(channel_info, tpl_montage, filename1, filename2)
417
+ stage1_info["fileNames"].update({
418
+ "input_montage" : filename1,
419
+ "mapped_montage" : filename2
420
+ })
421
+
422
+ # 4. matching result
423
+ # check if there are red dots (unmatched in_channels) on the input montage
424
+ unassigned_num = len(stage1_info["unassignedInputs"])
425
+ if unassigned_num == 0:
426
  md = """
427
  ---
428
  ### Step1: Initial Matching and Rescaling
429
+ Below is the result of mapping your channels to our template channels based on their names.
 
430
  """
431
+ else:
 
 
 
 
432
  md = """
433
  ---
434
  ### Step1: Initial Matching and Rescaling
435
+ Below is the result of mapping your channels to our template channels based on their names.
436
+ - channels highlighted in red are those that do not match any template channels.
437
  """
438
 
439
+ stage1_info["state"] = "step1-finished"
440
  app_info["stage1"] = stage1_info
441
+ yield {app_info_json : app_info,
442
  channel_info_json : channel_info,
443
  map_btn : gr.Button(interactive=True),
444
+ desc_md : gr.Markdown(md),
445
  tpl_img : gr.Image(visible=True),
446
  mapped_img : gr.Image(value=filename2, visible=True),
447
  next_btn : gr.Button(visible=True)}
448
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
450
+ # ========================================step1=========================================
451
+ elif stage1_info["state"] == "step1-finished":
452
+ in_num = len(channel_info["inputOrder"])
453
+ matched_num = 30 - len(stage1_info["missingTemplates"])
 
 
 
 
 
454
 
455
+ # step1 to step4
456
+ # the in_channels has all the 30 tpl_channels (in_num>=30)
457
+ if matched_num == 30:
458
+ #print('step1 -> step4')
459
+ md = """
460
+ ---
461
+ ### Mapping result
462
+ (...)
463
+ """
464
+
465
+ # finalize and save the mapping results
466
+ filename = filepath+"mapping_result.json"
467
+ stage1_info, stage2_info, channel_info = app_utils.mapping_result(
468
+ stage1_info, stage2_info, channel_info, filename)
469
+ #gr.Info('The mapping process has been finished.')
470
+ stage1_info["state"] = "finished"
471
+ app_info["stage1"] = stage1_info
472
+ app_info["stage2"] = stage2_info
473
+ yield {app_info_json : app_info,
474
+ channel_info_json : channel_info,
475
+ desc_md : gr.Markdown(md),
476
+ tpl_img : gr.Image(visible=False),
477
+ mapped_img : gr.Image(visible=False),
478
+ next_btn : gr.Button(visible=False),
479
+ out_json_file : gr.File(filename, visible=True),
480
+ run_btn : gr.Button(interactive=True)}
481
 
482
+ # step1 to step2
483
+ # matched_num < 30, and there're still some unmatched in_channels
484
+ elif in_num > matched_num:
485
+ #print('step1 -> step2')
486
+ md = """
487
+ ---
488
+ ### Step2: Forwarding Unmatched Channels
489
+ Select one of your unmatched channels to forward its data to the empty template channel
490
+ currently indicated in red.
491
+ """
492
+
493
+ # initialize the progress indication label for step2
494
+ stage1_info.update({
495
+ "fillingCount" : 1,
496
+ "totalFillingNum" : len(stage1_info["missingTemplates"])
497
+ })
498
+ name = stage1_info["missingTemplates"][0]
499
+ label = "{} (1/{})".format(name, stage1_info["totalFillingNum"])
500
+
501
+ stage1_info["state"] = "step2-selecting"
502
+ app_info["stage1"] = stage1_info
503
+ # determine which button to display
504
+ if stage1_info["totalFillingNum"] == 1:
505
+ yield {app_info_json : app_info,
506
+ desc_md : gr.Markdown(md),
507
+ tpl_img : gr.Image(visible=False),
508
+ mapped_img : gr.Image(visible=False),
509
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
510
+ clear_btn : gr.Button(visible=True)}
511
+ else:
512
+ yield {app_info_json : app_info,
513
+ desc_md : gr.Markdown(md),
514
+ tpl_img : gr.Image(visible=False),
515
+ mapped_img : gr.Image(visible=False),
516
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
517
+ clear_btn : gr.Button(visible=True),
518
+ step2_btn : gr.Button(visible=True),
519
+ next_btn : gr.Button(visible=False)}
520
+
521
+ # step1 to step3-1
522
+ # in_num < 30, but all of them can match to some tpl_channels
523
+ elif in_num == matched_num:
524
+ #print('step1 -> step3-1')
525
+ md = """
526
+ ---
527
+ ### Step3: Filling Remaining Template Channels
528
+ To run the model successfully, we need to ensure that all 30 template channels are filled.
529
+ In this step, you are required to select one of the methods provided below to fill the
530
+ remaining empty template channels.
531
+ """
532
+
533
+ stage1_info["state"] = "step3-select-method"
534
+ app_info["stage1"] = stage1_info
535
+ yield {app_info_json : app_info,
536
+ desc_md : gr.Markdown(md),
537
+ tpl_img : gr.Image(visible=False),
538
+ mapped_img : gr.Image(visible=False),
539
+ in_fillmode : gr.Dropdown(visible=True),
540
+ fillmode_btn : gr.Button(visible=True),
541
+ next_btn : gr.Button(visible=False)}
542
 
543
+ # ========================================step2=========================================
544
  elif stage1_info["state"] == "step2-selecting":
545
 
546
+ # --------------------store information before the button click---------------------
 
547
  # check if the user has selected an in_channel to forward to the previous target tpl_channel
548
  if selected_radio != []:
549
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
 
551
 
552
  # store the index of the in_channel
553
  selected_idx = channel_info["inputDict"][selected_radio]["index"]
554
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = [selected_idx]
555
+ stage1_info["mappingData"][0]["fillFlags"][prev_target_idx] = False
556
  # mark the in_channel as assigned and tpl_channel as matched
557
  channel_info["templateDict"][prev_target_name]["matched"] = True
558
  channel_info["inputDict"][selected_radio]["assigned"] = True
559
+ #print(prev_target_name, '<-', selected_radio)
 
 
560
 
561
+ # -----------------------update information for the next step-----------------------
562
  # update the list of unassignedInputs to exclude the selected in_channel of the previous round
563
+ stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
564
+ channel_info["inputDict"])
565
  # update the list of missingTemplates to exclude those filled in step2
566
+ stage1_info["missingTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
567
+ channel_info["templateDict"])
568
 
569
+ # -----------------------------determine the next step------------------------------
570
+ # step2 to step4
571
+ # all the unmatched tpl_channels were filled by in_channels
572
  if len(stage1_info["missingTemplates"]) == 0:
573
+ #print('step2 -> step4')
574
+ md = """
575
+ ---
576
+ ### Mapping result
577
+ (...)
578
+ """
579
 
580
+ # finalize and save the mapping results
581
+ filename = filepath+"mapping_result.json"
582
+ stage1_info, stage2_info, channel_info = app_utils.mapping_result(
583
+ stage1_info, stage2_info, channel_info, filename)
584
+ #gr.Info('The mapping process has been finished.')
585
+ stage1_info["state"] = "finished"
586
  app_info["stage1"] = stage1_info
587
+ app_info["stage2"] = stage2_info
588
+ yield {app_info_json : app_info,
589
  channel_info_json : channel_info,
590
+ desc_md : gr.Markdown(md),
591
  radio_group : gr.Radio(visible=False),
592
+ out_json_file : gr.File(filename, visible=True),
593
  clear_btn : gr.Button(visible=False),
594
  next_btn : gr.Button(visible=False),
595
  run_btn : gr.Button(interactive=True)}
596
+ # step2 to step3-1
 
597
  else:
598
+ #print('step2 -> step3-1')
599
  md = """
600
  ---
601
  ### Step3: Filling Remaining Template Channels
 
604
  remaining empty template channels.
605
  """
606
 
607
+ stage1_info["state"] = "step3-select-method"
608
  app_info["stage1"] = stage1_info
609
+ yield {app_info_json : app_info,
610
  channel_info_json : channel_info,
611
  desc_md : gr.Markdown(md),
612
  radio_group : gr.Radio(visible=False),
 
615
  clear_btn : gr.Button(visible=False),
616
  next_btn : gr.Button(visible=False)}
617
 
618
+ # =======================================step3-1========================================
619
+ elif stage1_info["state"] == "step3-select-method":
 
 
 
620
 
621
+ # step3-1 to step4
622
+ if fillmode == "zero":
623
+ #print('step3-1 -> step4')
624
+ md = """
625
+ ---
626
+ ### Mapping result
627
+ (...)
628
+ """
629
+
630
+ # finalize and save the mapping results
631
+ filename = filepath+"mapping_result.json"
632
+ stage1_info, stage2_info, channel_info = app_utils.mapping_result(
633
+ stage1_info, stage2_info, channel_info, filename)
634
+ #gr.Info('The mapping process has been finished.')
635
+ stage1_info["state"] = "finished"
636
+ app_info["stage1"] = stage1_info
637
+ app_info["stage2"] = stage2_info
638
+ yield {app_info_json : app_info,
639
+ channel_info_json : channel_info,
640
+ desc_md : gr.Markdown(md),
641
+ in_fillmode : gr.Dropdown(visible=False),
642
+ fillmode_btn : gr.Button(visible=False),
643
+ out_json_file : gr.File(filename, visible=True),
644
+ run_btn : gr.Button(interactive=True)}
645
+ # step3-1 to step3-2
646
+ elif fillmode == "mean":
647
+ #print('step3-1 -> step3-2')
648
+ md = """
649
+ ---
650
+ ### Step3: Fill the remaining template channels
651
+ The current empty template channel, indicated in red, will be filled with the average
652
+ value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
653
+ """
654
+
655
+ # find the 4 nearest in_channels for each unmatched tpl_channels
656
+ stage1_info["mappingData"][0]["newOrder"] = app_utils.find_neighbors(
657
+ channel_info,
658
+ stage1_info["missingTemplates"],
659
+ stage1_info["mappingData"][0]["newOrder"])
660
+
661
+ # initialize the progress indication label
662
+ stage1_info.update({
663
+ "fillingCount" : 1,
664
+ "totalFillingNum" : len(stage1_info["missingTemplates"])
665
+ })
666
+ target_name = stage1_info["missingTemplates"][0]
667
+ target_idx = channel_info["templateDict"][target_name]["index"]
668
+ chkbox_value = stage1_info["mappingData"][0]["newOrder"][target_idx]
669
+ chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
670
+ chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
671
+
672
+ stage1_info["state"] = "step3-2-selecting"
673
+ app_info["stage1"] = stage1_info
674
+ # determine which button to display
675
+ if stage1_info["totalFillingNum"] == 1:
676
+ yield {app_info_json : app_info,
677
+ desc_md : gr.Markdown(md),
678
+ in_fillmode : gr.Dropdown(visible=False),
679
+ fillmode_btn : gr.Button(visible=False),
680
+ chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
681
+ value=chkbox_value, label=chkbox_label, visible=True),
682
+ next_btn : gr.Button(visible=True)}
683
+ else:
684
+ yield {app_info_json : app_info,
685
+ desc_md : gr.Markdown(md),
686
+ in_fillmode : gr.Dropdown(visible=False),
687
+ fillmode_btn : gr.Button(visible=False),
688
+ chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
689
+ value=chkbox_value, label=chkbox_label, visible=True),
690
+ step3_btn : gr.Button(visible=True)}
691
+
692
+ # =======================================step3-2========================================
693
+ # step3-2 to step4
694
+ elif stage1_info["state"] == "step3-2-selecting":
695
+ #print('step3-2 -> step4')
696
 
697
+ # --------------------store information before the button click---------------------
698
  # if the user didn't uncheck all in_channel checkboxes
699
  if selected_chkbox != []:
700
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
 
702
 
703
  # store the indices of the in_channels
704
  selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
705
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
706
+ #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
707
+ # ----------------------------------------------------------------------------------
708
+ md = """
709
+ ---
710
+ ### Mapping result
711
+ (...)
712
+ """
713
 
714
+ # finalize and save the mapping results
715
+ filename = filepath+"mapping_result.json"
716
+ stage1_info, stage2_info, channel_info = app_utils.mapping_result(
717
+ stage1_info, stage2_info, channel_info, filename)
718
+ #gr.Info('The mapping process has been finished.')
719
+ stage1_info["state"] = "finished"
720
  app_info["stage1"] = stage1_info
721
+ app_info["stage2"] = stage2_info
722
+ yield {app_info_json : app_info,
723
+ channel_info_json : channel_info,
724
+ desc_md : gr.Markdown(md),
725
  chkbox_group : gr.CheckboxGroup(visible=False),
726
  next_btn : gr.Button(visible=False),
727
+ out_json_file : gr.File(filename, visible=True),
728
  run_btn : gr.Button(interactive=True)}
729
 
730
  next_btn.click(
731
  fn = init_next_step,
732
+ inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
733
+ outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
734
+ in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file, next_btn, run_btn]
735
  ).success(
736
  fn = None,
737
  js = init_js,
 
741
 
742
 
743
  # +========================================================================================+
744
+ # | Stage1-step0 |
745
  # +========================================================================================+
746
+ map_btn.click(
747
+ fn = reset_all,
748
+ inputs = [in_data_file, in_loc_file, in_samplerate],
749
+ outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
750
+ radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file,
751
+ run_btn, batch_md, out_data_file]
752
+ ).success(
753
+ fn = init_next_step,
754
+ inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
755
+ outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, next_btn]
756
+ )
757
+
758
+
759
+ # +========================================================================================+
760
+ # | Stage1-step2 |
761
+ # +========================================================================================+
762
+ # ...
763
  @radio_group.select(inputs = app_info_json, outputs = [step2_btn, next_btn])
764
  def determine_button(app_info):
765
  stage1_info = app_info["stage1"]
 
782
  def update_radio(app_info, channel_info, selected):
783
  stage1_info = app_info["stage1"]
784
 
785
+ # ----------------------store information before the button click-----------------------
 
786
  # check if the user has selected an in_channel to forward to the previous target tpl_channel
787
  if selected != []:
788
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
 
790
 
791
  # store the index of the selected in_channel
792
  selected_idx = channel_info["inputDict"][selected]["index"]
793
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = [selected_idx]
794
+ stage1_info["mappingData"][0]["fillFlags"][prev_target_idx] = False
795
  # mark the in_channel as assigned and tpl_channel as matched
796
  channel_info["templateDict"][prev_target_name]["matched"] = True
797
  channel_info["inputDict"][selected]["assigned"] = True
798
+ #print(prev_target_name, '<-', selected)
799
 
800
+ # ------------------------update information for the new round--------------------------
801
  stage1_info["fillingCount"] += 1
802
 
803
  # update the list of unassignedInputs to exclude the selected in_channel of the previous round
804
+ stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
 
805
  # update the progress indication label
806
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
807
  radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
 
814
  radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
815
  value=[], label=radio_label),
816
  step2_btn : gr.Button(visible=False),
817
+ next_btn : gr.Button(visible=True)}
818
  else:
819
  return {app_info_json : app_info,
820
  channel_info_json : channel_info,
 
834
 
835
 
836
  # +========================================================================================+
837
+ # | Stage1-step3 |
838
+ # +========================================================================================+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
  def update_chkbox(app_info, channel_info, selected):
840
  stage1_info = app_info["stage1"]
841
 
842
+ # ----------------------store information before the button click-----------------------
 
843
  # if the user didn't uncheck all in_channel checkboxes
844
  if selected != []:
845
  prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
 
847
 
848
  # store the indices of the selected in_channels
849
  selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
850
+ stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
851
+ #print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
852
 
853
+ # ------------------------update information for the new round--------------------------
854
  stage1_info["fillingCount"] += 1
855
 
856
  # update the progress indication label
857
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
858
  target_idx = channel_info["templateDict"][target_name]["index"]
859
+ chkbox_value = stage1_info["mappingData"][0]["newOrder"][target_idx]
860
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
861
  chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
862
 
 
866
  return {app_info_json : app_info,
867
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
868
  step3_btn : gr.Button(visible=False),
869
+ next_btn : gr.Button(visible=True)}
870
  else:
871
  return {app_info_json : app_info,
872
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
873
 
874
  fillmode_btn.click(
875
+ fn = init_next_step,
876
+ inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
877
+ outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
878
+ out_json_file, next_btn, run_btn]
879
  ).success(
880
  fn = None,
881
  js = init_js,
 
896
 
897
 
898
  # +========================================================================================+
899
+ # | Stage2: decode data |
900
  # +========================================================================================+
901
+ def reset_run(app_info, modelname):
902
  stage1_info = app_info["stage1"]
903
  stage2_info = app_info["stage2"]
904
 
905
+ # delete the previous folder of Stage2
906
+ filepath = stage2_info["filePath"]
907
  utils.dataDelete(filepath)
908
+ # establish a new folder for Stage2
909
+ new_filepath = app_info["rootPath"]+"stage2_"+str(random.randint(1,10000))+"/"
910
  os.mkdir(new_filepath)
911
  # generate the output filename
912
+ filename = stage1_info["fileNames"]["input_data"]
913
  filename = os.path.basename(str(filename))
914
  new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
915
 
916
+ stage2_info.update({
917
+ "filePath" : new_filepath,
918
+ "fileNames" : {
919
+ "output_data" : new_filepath + new_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
920
  }
921
  })
922
+ app_info["stage2"] = stage2_info
923
  return {app_info_json : app_info,
 
924
  #run_btn : gr.Button(interactive=False),
925
  batch_md : gr.Markdown(visible=False),
926
  out_data_file : gr.File(visible=False)}
927
 
928
+ def run_model(app_info, modelname):
929
  stage1_info = app_info["stage1"]
930
  stage2_info = app_info["stage2"]
931
 
932
+ filepath = stage2_info["filePath"]
933
  samplerate = app_info["sampleRate"]
934
+ filename = stage1_info["fileNames"]["input_data"]
935
+ new_filename = stage2_info["fileNames"]["output_data"]
936
 
937
+ # flag to indicate if the process has been interrupted by the user
938
  break_flag = False
939
 
940
  # run the model multiple times until all in_channels are reconstructed
 
946
  #utils.dataDelete(filepath+"temp_data/")
947
  #os.mkdir(filepath+"temp_data/")
948
  except FileNotFoundError:
949
+ print('break1!!')
950
  break_flag = True
951
  break
952
  except OSError as e:
 
956
  md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
957
  yield {batch_md : gr.Markdown(md, visible=True)}
958
 
959
+ # get the mapped index order and the filled status for each tpl_channels
960
+ new_idx = stage1_info["mappingData"][i]["newOrder"]
961
+ fill_flags = stage1_info["mappingData"][i]["fillFlags"]
 
 
 
 
 
 
 
962
  # ----------------------------------------------------------------------
963
  try:
964
  # step1: Reorder input data
965
+ data_shape = app_utils.reorder_data(new_idx, fill_flags, filename, filepath+"temp_data/mapped.csv")
966
+ if modelname == "(mapped data)":
967
+ new_filename = filepath+"temp_data/mapped.csv"
968
+ break
969
  # step2: Data preprocessing
970
  total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
971
  # step3: Signal reconstruction
972
  utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
973
+ if modelname == "(denoised data)":
974
+ new_filename = filepath+"temp_data/denoised.csv"
975
+ break
976
  # step4: Restore original order
977
+ app_utils.restore_order(i, data_shape, new_idx, fill_flags, filepath+"temp_data/denoised.csv", new_filename)
978
+ break
979
  except FileNotFoundError:
980
+ print('break2!!')
981
  break_flag = True
982
  break
983
  # ----------------------------------------------------------------------
984
  utils.dataDelete(filepath+"temp_data/")
 
985
 
986
  if break_flag == True:
987
  yield {batch_md : gr.Markdown(visible=False)}
 
992
 
993
  run_btn.click(
994
  fn = reset_run,
995
+ inputs = [app_info_json, in_modelname],
996
+ outputs = [app_info_json, run_btn, batch_md, out_data_file]
997
 
998
  ).success(
999
  fn = run_model,
1000
+ inputs = [app_info_json, in_modelname],
1001
  outputs = [run_btn, batch_md, out_data_file]
1002
  )
1003
 
channel_mapping.py → app_utils.py RENAMED
@@ -1,144 +1,208 @@
1
  import utils
2
- import time
3
  import os
 
 
 
4
  import numpy as np
5
- import gradio as gr
6
-
7
  import mne
8
  from mne.channels import read_custom_montage
9
  from scipy.interpolate import Rbf
10
  from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
- def reorder_input_data(old_idx, filename, new_filename):
14
- old_data = utils.read_train_data(filename) # original raw data
15
- new_data = np.zeros((30, old_data.shape[1])) # to store reordered raw data
16
- print('old.shape, new.shape: ', old_data.shape, new_data.shape)
17
- print('new index order:', old_idx)
18
-
19
- zero_arr = np.zeros((1, old_data.shape[1]))
20
- old_data = np.concatenate((old_data, zero_arr), axis=0)
21
-
22
- for i in range(30):
23
- idx_set = old_idx[i]
24
- #print("channel_{}'s index set: {}".format(i, idx_set))
25
-
26
- if idx_set == []:
27
  new_data[i, :] = zero_arr
28
  else:
29
- tmp_data = [old_data[j, :] for j in idx_set]
30
  new_data[i, :] = np.mean(tmp_data, axis=0)
31
 
 
32
  utils.save_data(new_data, new_filename)
33
- return
34
 
35
- def restore_original_order(channel_info, cnt, old_idx, filename, new_filename):
36
- old_data = utils.read_train_data(filename) # denoised data
37
- tpl_order = channel_info["templateOrder"]
38
- in_order = channel_info["inputOrder"]
39
-
40
- if cnt == 0:
41
- new_data = np.zeros((len(in_order), old_data.shape[1]))
42
  else:
43
  new_data = utils.read_train_data(new_filename)
44
 
45
- for i, channel in enumerate(tpl_order):
46
- idx_set = old_idx[i]
47
-
48
- # ignore if this channel was filled with fillmode ('mean' or 'zero')
49
- if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
50
- new_data[idx_set[0], :] = old_data[i, :]
51
 
52
- print('old.shape, new.shape: ', old_data.shape, new_data.shape)
53
  utils.save_data(new_data, new_filename)
54
  return
55
 
 
 
 
 
 
 
 
 
 
56
  def read_montage_data(loc_file):
57
-
58
  tpl_montage = read_custom_montage("./template_chanlocs.loc")
59
  in_montage = read_custom_montage(loc_file)
 
 
60
  tpl_dict = {}
61
  in_dict = {}
62
 
63
- for i in range(30):
64
- channel = tpl_montage.ch_names[i]
65
- tpl_montage.rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
66
-
67
- channel = str.upper(channel)
68
- tpl_dict[channel] = {
69
  "index" : i,
70
- "coord_3d" : tpl_montage.get_positions()['ch_pos'][channel],
71
  "matched" : False
72
  }
73
- for i in range(len(in_montage.ch_names)):
74
- channel = in_montage.ch_names[i]
75
- in_montage.rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
76
-
77
- channel = str.upper(channel)
78
- in_dict[channel] = {
79
  "index" : i,
80
- "coord_3d" : in_montage.get_positions()['ch_pos'][channel],
81
  "assigned" : False
82
  }
83
-
84
  return tpl_montage, in_montage, tpl_dict, in_dict
85
 
86
- def align_coords(channel_info, tpl_montage, in_montage):
87
-
 
88
  tpl_dict = channel_info["templateDict"]
89
  in_dict = channel_info["inputDict"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  tpl_order = channel_info["templateOrder"]
91
  in_order = channel_info["inputOrder"]
92
- matched = [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
93
-
 
94
 
95
- # 2-D (to indicate the location of the missing template channel when fill_mode:'mean')
96
  fig = [tpl_montage.plot(), in_montage.plot()]
97
  ax = [fig[0].axes[0], fig[1].axes[0]]
98
 
99
- # get the original coords
100
- #all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # displayed coords (px)
101
- #all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
102
  all_tpl = ax[0].collections[0].get_offsets().data
103
  all_in= ax[1].collections[0].get_offsets().data
104
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
105
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
106
 
107
- # transform the xy axis (input's -> template's)
108
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
109
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
110
 
111
- # apply to all input channels
112
  transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
113
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
114
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
115
 
116
- # save template's and input's 2d position
117
  for i, channel in enumerate(tpl_order):
118
  tpl_dict[channel]["coord_2d"] = all_tpl[i]
119
  for i, channel in enumerate(in_order):
120
  in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
121
 
122
 
123
- # 3-d (to use KNN)
124
- # get the original coords
125
  all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
126
  all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
127
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
128
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
129
 
130
- # transform the xyz axis (input's -> template's)
131
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
132
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
133
  rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
134
 
135
- # apply to all input channels
136
  transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
137
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
138
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
139
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
140
 
141
- # update input's 3d position
142
  for i, channel in enumerate(in_order):
143
  in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
144
 
@@ -149,133 +213,153 @@ def align_coords(channel_info, tpl_montage, in_montage):
149
  return channel_info
150
 
151
  def find_neighbors(channel_info, missing_channels, new_idx):
 
152
  tpl_dict = channel_info["templateDict"]
153
  in_dict = channel_info["inputDict"]
154
- in_order = channel_info["inputOrder"]
155
 
 
156
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
157
- missing_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
158
 
159
  # use KNN to choose k nearest channels
160
  k = 4 if len(in_order)>4 else len(in_order)
161
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
162
  knn.fit(all_in)
163
-
164
  for i, channel in enumerate(missing_channels):
165
- distances, indices = knn.kneighbors(missing_tpl[i].reshape(1,-1))
166
- #selected = [in_order[j] for j in indices[0]]
167
- #print(channel, ':', selected)
168
-
169
  idx = tpl_dict[channel]["index"]
170
  new_idx[idx] = indices[0].tolist()
171
 
172
  return new_idx
173
 
174
- def mapping_stage1(app_info, channel_info):
175
- yield app_info, channel_info, gr.Markdown("Mapping...", visible=True)
176
- second1 = time.time()
177
-
178
- loc_file = app_info["stage1"]["filenames"]["input_loc"]
179
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
180
  tpl_order = tpl_montage.ch_names
181
  in_order = in_montage.ch_names
182
- new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channls
 
 
 
 
183
  alias_dict = {
184
  'T3': 'T7',
185
  'T4': 'T8',
186
  'T5': 'P7',
187
  'T6': 'P8'
188
  }
189
-
190
- # match the names of input channels and template channels
191
  for i, channel in enumerate(tpl_order):
192
  if channel in alias_dict and alias_dict[channel] in in_dict:
193
- tpl_montage.rename_channels({channel: alias_dict[channel]}) # rename the current tpl_channel
194
  tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
195
  channel = alias_dict[channel]
196
 
197
  if channel in in_dict:
198
  new_idx[i] = [in_dict[channel]["index"]]
 
199
  tpl_dict[channel]["matched"] = True
200
  in_dict[channel]["assigned"] = True
201
 
202
  # update the names
203
  tpl_order = tpl_montage.ch_names
204
 
 
 
 
 
 
 
 
 
 
 
205
  channel_info.update({
 
 
206
  "templateDict" : tpl_dict,
207
- "inputDict" : in_dict,
208
- "templateOrder" : tpl_order,
209
- "inputOrder" : in_order
210
- })
211
- app_info["stage1"].update({
212
- "newOrder" : new_idx,
213
- "unassignedInputs" : [channel for channel in in_order if in_dict[channel]["assigned"]==False],
214
- "missingTemplates" : [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
215
  })
216
-
217
- # align input, template's coordinates
218
- channel_info = align_coords(channel_info, tpl_montage, in_montage)
219
-
220
- second2 = time.time()
221
- print('Mapping (stage1) finished in',second2 - second1,'s.')
222
- yield app_info, channel_info, gr.Markdown("", visible=False)
223
 
224
- def mapping_stage2(stage2_info, channel_info):
225
- second1 = time.time()
226
-
227
- tpl_dict = channel_info["templateDict"]
228
- in_dict = channel_info["inputDict"]
229
  tpl_order = channel_info["templateOrder"]
230
  in_order = channel_info["inputOrder"]
231
- unassigned = stage2_info["unassignedInputs"]
232
-
233
- tpl_coords = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
234
- unassigned_coords = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
235
-
236
  # reset all tpl.matched to False
237
  for channel in tpl_dict:
238
  tpl_dict[channel]["matched"] = False
239
 
240
- # initialize the cost matrix
 
 
 
 
241
  if len(unassigned) < 30:
242
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
243
  else:
244
  cost_matrix = np.zeros((30, len(unassigned)))
 
245
  for i in range(30):
246
  for j in range(len(unassigned)):
247
- cost_matrix[i][j] = np.linalg.norm((tpl_coords[i]-unassigned_coords[j])*1000) # Euclidean distance
248
- #print(cost_matrix[i][j], tpl_coords[i] - unassigned_coords[j])
249
 
250
- # use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
 
251
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
252
 
 
253
  new_idx = [[]]*30
 
254
  for i in range(30):
255
  if col_idx[i] < len(unassigned): # filter out dummy channels
256
  tpl_channel = tpl_order[row_idx[i]]
257
  in_channel = unassigned[col_idx[i]]
 
 
 
258
  tpl_dict[tpl_channel]["matched"] = True
259
  in_dict[in_channel]["assigned"] = True
260
- new_idx[row_idx[i]] = [in_dict[in_channel]["index"]]
261
-
262
- print(f'{tpl_order[row_idx[i]]}({row_idx[i]}) <- {unassigned[col_idx[i]]}({col_idx[i]})')
263
 
264
- # fill the missing_channels
265
- missing_channels = [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
266
  if missing_channels != []:
267
  new_idx = find_neighbors(channel_info, missing_channels, new_idx)
268
 
269
- stage2_info.update({
270
- "newOrder" : new_idx,
271
- "unassignedInputs" : [channel for channel in in_order if in_dict[channel]["assigned"]==False]
272
- })
273
  channel_info.update({
274
  "templateDict" : tpl_dict,
275
  "inputDict" : in_dict
276
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
- second2 = time.time()
279
- print("The mapping process has been finished in", second2 - second1, "s.")
280
- return stage2_info, channel_info
281
 
 
1
  import utils
 
2
  import os
3
+ import time
4
+ import math
5
+ import json
6
  import numpy as np
7
+ import matplotlib.pyplot as plt
 
8
  import mne
9
  from mne.channels import read_custom_montage
10
  from scipy.interpolate import Rbf
11
  from scipy.optimize import linear_sum_assignment
12
  from sklearn.neighbors import NearestNeighbors
13
 
14
+ def reorder_data(idx_order, fill_flags, filename, new_filename):
15
+ # read the input data
16
+ raw_data = utils.read_train_data(filename)
17
+ new_data = np.zeros((30, raw_data.shape[1]))
18
+
19
+ zero_arr = np.zeros((1, raw_data.shape[1]))
20
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
21
+ if flag == False:
22
+ new_data[i, :] = raw_data[idx_set[0], :]
23
+ elif idx_set == []:
 
 
 
 
24
  new_data[i, :] = zero_arr
25
  else:
26
+ tmp_data = [raw_data[j, :] for j in idx_set]
27
  new_data[i, :] = np.mean(tmp_data, axis=0)
28
 
29
+ #print(raw_data.shape, new_data.shape)
30
  utils.save_data(new_data, new_filename)
31
+ return raw_data.shape
32
 
33
+ def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
34
+ # read the denoised data
35
+ d_data = utils.read_train_data(filename)
36
+ if batch_cnt == 0:
37
+ new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
 
 
38
  else:
39
  new_data = utils.read_train_data(new_filename)
40
 
41
+ for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
+ # ignore if this channel was filled using "fillmode"
43
+ if flag == False:
44
+ new_data[idx_set[0], :] = d_data[i, :]
 
 
45
 
46
+ #print(d_data.shape, new_data.shape)
47
  utils.save_data(new_data, new_filename)
48
  return
49
 
50
+ def get_matched(tpl_order, tpl_dict):
51
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
52
+
53
+ def get_empty_templates(tpl_order, tpl_dict):
54
+ return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
55
+
56
+ def get_unassigned_inputs(in_order, in_dict):
57
+ return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
58
+
59
  def read_montage_data(loc_file):
 
60
  tpl_montage = read_custom_montage("./template_chanlocs.loc")
61
  in_montage = read_custom_montage(loc_file)
62
+ tpl_order = tpl_montage.ch_names
63
+ in_order = in_montage.ch_names
64
  tpl_dict = {}
65
  in_dict = {}
66
 
67
+ # convert all channel names to uppercase and store the channel information
68
+ for i, channel in enumerate(tpl_order):
69
+ up_channel = str.upper(channel)
70
+ tpl_montage.rename_channels({channel: up_channel})
71
+ tpl_dict[up_channel] = {
 
72
  "index" : i,
73
+ "coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
74
  "matched" : False
75
  }
76
+ for i, channel in enumerate(in_order):
77
+ up_channel = str.upper(channel)
78
+ in_montage.rename_channels({channel: up_channel})
79
+ in_dict[up_channel] = {
 
 
80
  "index" : i,
81
+ "coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
82
  "assigned" : False
83
  }
 
84
  return tpl_montage, in_montage, tpl_dict, in_dict
85
 
86
+ def save_figures(channel_info, tpl_montage, filename1, filename2):
87
+ tpl_order = channel_info["templateOrder"]
88
+ in_order = channel_info["inputOrder"]
89
  tpl_dict = channel_info["templateDict"]
90
  in_dict = channel_info["inputDict"]
91
+
92
+ # get the 2D coordinates
93
+ tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
94
+ tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
95
+ in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
96
+ in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
97
+ tpl_coords = np.vstack((tpl_x, tpl_y)).T
98
+ in_coords = np.vstack((in_x, in_y)).T
99
+
100
+ # extract template's head figure
101
+ tpl_fig = tpl_montage.plot()
102
+ tpl_ax = tpl_fig.axes[0]
103
+ lines = tpl_ax.lines
104
+ head_lines = []
105
+ for line in lines:
106
+ x, y = line.get_data()
107
+ head_lines.append((x,y))
108
+ plt.close()
109
+
110
+ # -------------------------plot input montage------------------------------
111
+ fig = plt.figure(figsize=(6.4,6.4), dpi=100)
112
+ ax = fig.add_subplot(111)
113
+ fig.tight_layout()
114
+ ax.set_aspect('equal')
115
+ ax.axis('off')
116
+
117
+ # plot template's head
118
+ for x, y in head_lines:
119
+ ax.plot(x, y, color='black', linewidth=1.0)
120
+ # plot in_channels on it
121
+ ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
122
+ for i, channel in enumerate(in_order):
123
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
124
+ # save input_montage
125
+ fig.savefig(filename1)
126
+
127
+ # ---------------------------add indications-------------------------------
128
+ # plot unmatched input channels in red
129
+ indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
130
+ ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
131
+ for i in indices:
132
+ ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
133
+ # save mapped_montage
134
+ fig.savefig(filename2)
135
+
136
+ # -------------------------------------------------------------------------
137
+ # store the tpl and in_channels' display positions (in px).
138
+ tpl_coords = ax.transData.transform(tpl_coords)
139
+ in_coords = ax.transData.transform(in_coords)
140
+ plt.close()
141
+
142
+ for i, channel in enumerate(tpl_order):
143
+ css_left = (tpl_coords[i,0]-11)/6.4
144
+ css_bottom = (tpl_coords[i,1]-7)/6.4
145
+ tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
146
+ for i, channel in enumerate(in_order):
147
+ css_left = (in_coords[i,0]-11)/6.4
148
+ css_bottom = (in_coords[i,1]-7)/6.4
149
+ in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
150
+
151
+ channel_info.update({
152
+ "templateDict" : tpl_dict,
153
+ "inputDict" : in_dict
154
+ })
155
+ return channel_info
156
+
157
+ def align_coords(channel_info, tpl_montage, in_montage):
158
  tpl_order = channel_info["templateOrder"]
159
  in_order = channel_info["inputOrder"]
160
+ tpl_dict = channel_info["templateDict"]
161
+ in_dict = channel_info["inputDict"]
162
+ matched = get_matched(tpl_order, tpl_dict)
163
 
164
+ # 2D alignment (for visualization purposes)
165
  fig = [tpl_montage.plot(), in_montage.plot()]
166
  ax = [fig[0].axes[0], fig[1].axes[0]]
167
 
168
+ # extract the displayed 2D coordinates from the plots
 
 
169
  all_tpl = ax[0].collections[0].get_offsets().data
170
  all_in= ax[1].collections[0].get_offsets().data
171
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
172
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
173
 
174
+ # apply TPS to transform in_channels positions to align with tpl_channels positions
175
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
176
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
177
 
178
+ # apply the transformation to all in_channels
179
  transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
180
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
181
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
182
 
183
+ # store the 2D positions
184
  for i, channel in enumerate(tpl_order):
185
  tpl_dict[channel]["coord_2d"] = all_tpl[i]
186
  for i, channel in enumerate(in_order):
187
  in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
188
 
189
 
190
+ # 3D alignment
 
191
  all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
192
  all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
193
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
194
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
195
 
 
196
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
197
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
198
  rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
199
 
 
200
  transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
201
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
202
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
203
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
204
 
205
+ # update in_channels' 3D positions
206
  for i, channel in enumerate(in_order):
207
  in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
208
 
 
213
  return channel_info
214
 
215
  def find_neighbors(channel_info, missing_channels, new_idx):
216
+ in_order = channel_info["inputOrder"]
217
  tpl_dict = channel_info["templateDict"]
218
  in_dict = channel_info["inputDict"]
 
219
 
220
+ # get the 3D coordinates
221
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
222
+ empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
223
 
224
  # use KNN to choose k nearest channels
225
  k = 4 if len(in_order)>4 else len(in_order)
226
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
227
  knn.fit(all_in)
 
228
  for i, channel in enumerate(missing_channels):
229
+ distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
 
 
 
230
  idx = tpl_dict[channel]["index"]
231
  new_idx[idx] = indices[0].tolist()
232
 
233
  return new_idx
234
 
235
+ def match_names(stage1_info, channel_info):
236
+ # read the location file
237
+ loc_file = stage1_info["fileNames"]["input_loc"]
 
 
238
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
239
  tpl_order = tpl_montage.ch_names
240
  in_order = in_montage.ch_names
241
+ # list to store the indices of the in_channels in the order of tpl_channls
242
+ new_idx = [[]]*30
243
+ # flags to record if each tpl_channel's data is filled by "fillmode"
244
+ fill_flags = [True]*30
245
+
246
  alias_dict = {
247
  'T3': 'T7',
248
  'T4': 'T8',
249
  'T5': 'P7',
250
  'T6': 'P8'
251
  }
 
 
252
  for i, channel in enumerate(tpl_order):
253
  if channel in alias_dict and alias_dict[channel] in in_dict:
254
+ tpl_montage.rename_channels({channel: alias_dict[channel]})
255
  tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
256
  channel = alias_dict[channel]
257
 
258
  if channel in in_dict:
259
  new_idx[i] = [in_dict[channel]["index"]]
260
+ fill_flags[i] = False
261
  tpl_dict[channel]["matched"] = True
262
  in_dict[channel]["assigned"] = True
263
 
264
  # update the names
265
  tpl_order = tpl_montage.ch_names
266
 
267
+ stage1_info.update({
268
+ "unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
269
+ "missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
270
+ "mappingData" : [
271
+ {
272
+ "newOrder" : new_idx,
273
+ "fillFlags" : fill_flags
274
+ }
275
+ ]
276
+ })
277
  channel_info.update({
278
+ "templateOrder" : tpl_order,
279
+ "inputOrder" : in_order,
280
  "templateDict" : tpl_dict,
281
+ "inputDict" : in_dict
 
 
 
 
 
 
 
282
  })
283
+ return stage1_info, channel_info, tpl_montage, in_montage
 
 
 
 
 
 
284
 
285
+ def optimal_mapping(channel_info):
 
 
 
 
286
  tpl_order = channel_info["templateOrder"]
287
  in_order = channel_info["inputOrder"]
288
+ tpl_dict = channel_info["templateDict"]
289
+ in_dict = channel_info["inputDict"]
290
+ unassigned = get_unassigned_inputs(in_order, in_dict)
 
 
291
  # reset all tpl.matched to False
292
  for channel in tpl_dict:
293
  tpl_dict[channel]["matched"] = False
294
 
295
+ # get the 3D coordinates
296
+ all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
297
+ unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
298
+
299
+ # initialize the cost matrix for the Hungarian algorithm
300
  if len(unassigned) < 30:
301
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
302
  else:
303
  cost_matrix = np.zeros((30, len(unassigned)))
304
+ # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
305
  for i in range(30):
306
  for j in range(len(unassigned)):
307
+ cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
 
308
 
309
+ # apply the Hungarian algorithm to optimally assign each in_channel to a tpl_channel
310
+ # by minimizing the total distances between their positions.
311
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
312
 
313
+ # store the mapping result
314
  new_idx = [[]]*30
315
+ fill_flags = [True]*30
316
  for i in range(30):
317
  if col_idx[i] < len(unassigned): # filter out dummy channels
318
  tpl_channel = tpl_order[row_idx[i]]
319
  in_channel = unassigned[col_idx[i]]
320
+
321
+ new_idx[row_idx[i]] = [in_dict[in_channel]["index"]]
322
+ fill_flags[row_idx[i]] = False
323
  tpl_dict[tpl_channel]["matched"] = True
324
  in_dict[in_channel]["assigned"] = True
325
+ #print(f'{tpl_channel}({row_idx[i]}) <- {in_channel}({col_idx[i]})')
 
 
326
 
327
+ # fill the remaining empty tpl_channels
328
+ missing_channels = get_empty_templates(tpl_order, tpl_dict)
329
  if missing_channels != []:
330
  new_idx = find_neighbors(channel_info, missing_channels, new_idx)
331
 
332
+ mapping_data = {
333
+ "newOrder" : new_idx,
334
+ "fillFlags" : fill_flags
335
+ }
336
  channel_info.update({
337
  "templateDict" : tpl_dict,
338
  "inputDict" : in_dict
339
  })
340
+ return mapping_data, channel_info
341
+
342
+ def mapping_result(stage1_info, stage2_info, channel_info, filename):
343
+ # 1. calculate how many times the model needs to be run
344
+ unassigned_num = len(stage1_info["unassignedInputs"])
345
+ batch_num = math.ceil(unassigned_num/30) + 1
346
+
347
+ # 2. map the remaining in_channels
348
+ for i in range(1, batch_num):
349
+ # optimally select 30 in_channels to map to the tpl_channels based on proximity
350
+ new_mapping_data, channel_info = optimal_mapping(channel_info)
351
+ stage1_info["mappingData"] += [new_mapping_data]
352
+
353
+ # 3. save the mapping result
354
+ new_dict = {
355
+ #"templateOrder" : channel_info["templateOrder"],
356
+ #"inputOrder" : channel_info["inputOrder"],
357
+ "batchNum" : batch_num,
358
+ "mappingData" : stage1_info["mappingData"]
359
+ }
360
+ with open(filename, 'w') as jsonfile:
361
+ jsonfile.write(json.dumps(new_dict))
362
 
363
+ stage2_info["totalBatchNum"] = batch_num
364
+ return stage1_info, stage2_info, channel_info
 
365