audrey06100 commited on
Commit
857bb0f
·
1 Parent(s): 8b18526

update comment

Browse files
Files changed (3) hide show
  1. app.py +487 -337
  2. channel_mapping.py +46 -66
  3. utils.py +9 -8
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
 
 
3
  import os
4
  import random
5
  import math
@@ -9,7 +10,7 @@ import mne
9
  from mne.channels import read_custom_montage
10
 
11
  import utils
12
- from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin, find_neighbors
13
 
14
 
15
  quickstart = """
@@ -20,7 +21,8 @@ quickstart = """
20
 
21
  ## Channel locations
22
  Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
23
- >If you cannot obtain it, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files doesn't match yours, you can use **EEGLAB** to modify them to your needed montage.
 
24
 
25
  ## Mapping
26
  (...)
@@ -55,32 +57,33 @@ Electroencephalography (EEG) signals are often contaminated with artifacts. It i
55
  """
56
 
57
  init_js = """
58
- (app_state, channel_info) => {
59
- app_state = JSON.parse(JSON.stringify(app_state));
60
  channel_info = JSON.parse(JSON.stringify(channel_info));
 
61
 
62
  let selector, classname, attribute;
63
  let channel, left, bottom;
64
 
65
- if(app_state.stage1State == "step2-selecting"){
66
  selector = "#radio-group > div:nth-of-type(2)";
67
  //classname = "radio";
68
  attribute = "value";
69
- }else if(app_state.stage1State == "step3-selecting"){
70
  selector = "#chkbox-group > div:nth-of-type(2)";
71
  //classname = "chkbox";
72
  attribute = "name";
73
  }else return;
74
 
75
 
76
- // add figure of the mapping result
77
  document.querySelector(selector).style.cssText = `
78
  position: relative;
79
  width: 100%;
80
  aspect-ratio: 1;
81
  //width: 560px;
82
  //height: 560px;
83
- background: url("file=${app_state.filenames.raw_montage}");
84
  background-size: contain;
85
 
86
  `;
@@ -99,13 +102,13 @@ init_js = """
99
 
100
 
101
  // add indication for the missing channels
102
- channel = app_state.missingTemplates[0];
103
  left = channel_info.templateDict[channel].css_position[0];
104
  bottom = channel_info.templateDict[channel].css_position[1];
105
 
106
  let dot_rule = `
107
  ${selector}::before {
108
- content: '';
109
  position: absolute;
110
  background-color: red;
111
  width: 10px;
@@ -145,14 +148,15 @@ init_js = """
145
  """
146
 
147
  update_js = """
148
- (app_state, channel_info) => {
149
- app_state = JSON.parse(JSON.stringify(app_state));
150
  channel_info = JSON.parse(JSON.stringify(channel_info));
 
151
 
152
  let selector;
153
  let channel, left, bottom;
154
 
155
- if(app_state.stage1State == "step2-selecting"){
156
  selector = "#radio-group > div:nth-of-type(2)";
157
 
158
  // update the radios
@@ -166,12 +170,12 @@ update_js = """
166
  item.className = "";
167
  item.querySelector(":scope > span").innerText = "";
168
  });
169
- }else if(app_state.stage1State == "step3-selecting"){
170
  selector = "#chkbox-group > div:nth-of-type(2)";
171
  }else return;
172
 
173
  // update indication
174
- channel = app_state.missingTemplates[app_state["fillingCount"]-1];
175
  left = channel_info.templateDict[channel].css_position[0];
176
  bottom = channel_info.templateDict[channel].css_position[1];
177
 
@@ -202,7 +206,7 @@ update_js = """
202
  }
203
  `;
204
 
205
- // check if indicator already exist
206
  const styleSheet = document.styleSheets[0];
207
  for(let i=0; i<styleSheet.cssRules.length; i++){
208
  let tmp = styleSheet.cssRules[i].selectorText;
@@ -219,7 +223,7 @@ update_js = """
219
 
220
  with gr.Blocks() as demo:
221
 
222
- app_state_json = gr.JSON(visible=False)
223
  channel_info_json = gr.JSON(visible=False)
224
 
225
  with gr.Row():
@@ -234,27 +238,25 @@ with gr.Blocks() as demo:
234
  gr.Markdown("# 1.Channel Mapping")
235
  # ------------------------input--------------------------
236
  with gr.Row():
237
- in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
238
- in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
 
239
  with gr.Row():
240
  in_samplerate = gr.Textbox(label="Sampling rate (Hz)", scale=2)
241
- map_btn = gr.Button("Mapping", scale=1)
242
 
243
  # ------------------------mapping------------------------
244
  # description for stage1-123
245
- desc_md = gr.Markdown("### Step1: Mapping result", visible=False) # """??? # test
246
-
247
  # stage1-1 : mapping result
248
  with gr.Row():
249
- tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
250
  mapped_montage = gr.Image(label="Input channels", visible=False)
251
-
252
  # stage1-2 : assign unmatched input channels to empty template channels
253
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
254
-
255
  # stage1-3 : select a way to fill the empty template channels
256
  with gr.Row():
257
- in_fill_mode = gr.Dropdown(choices=["mean", "zero"],
258
  value="mean",
259
  label="Filling method",
260
  visible=False,
@@ -263,7 +265,7 @@ with gr.Blocks() as demo:
263
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
264
 
265
  with gr.Row():
266
- clear_btn = gr.Button("Clear", visible=False) #, interactive=False
267
  step2_btn = gr.Button("Next", visible=False)
268
  step3_btn = gr.Button("Next", visible=False)
269
  next_btn = gr.Button("Next step", visible=False)
@@ -273,13 +275,11 @@ with gr.Blocks() as demo:
273
  gr.Markdown("# 2.Decode Data")
274
  # ------------------------input--------------------------
275
  with gr.Row():
276
- in_model_name = gr.Dropdown(choices=[
277
  ("ART", "EEGART"),
278
  ("IC-U-Net", "ICUNet"),
279
  ("IC-U-Net++", "UNetpp"),
280
- ("IC-U-Net-Attn", "AttUnet"),
281
- "(mapped data)",
282
- "(denoised data)"],
283
  value="EEGART",
284
  label="Model",
285
  scale=2)
@@ -287,8 +287,7 @@ with gr.Blocks() as demo:
287
 
288
  # ------------------------output-------------------------
289
  batch_md = gr.Markdown(visible=False)
290
- out_denoised_data = gr.File(label="Denoised data", visible=False)
291
-
292
  # -------------------------------------------------------
293
 
294
  with gr.Row():
@@ -302,59 +301,91 @@ with gr.Blocks() as demo:
302
  gr.Markdown()
303
  with gr.Tab("QuickStart"):
304
  gr.Markdown(quickstart)
305
-
306
- #demo.load(js=tmp_js)
307
-
308
- # -------------------------stage1: channel mapping-------------------------------
309
- def reset_all(raw_data, raw_loc, samplerate):
310
- # verify that all required inputs have been provided
311
- if raw_data == None or raw_loc == None:
312
- gr.Warning('Please upload both the raw data and the channel location files.')
313
- return
314
- if samplerate == "":
315
- gr.Warning('Please enter the sampling rate.')
316
- return
317
-
318
- # establish temp folder
319
- filepath = os.path.dirname(str(raw_data))
 
 
 
320
  try:
321
- os.mkdir(filepath+"/temp_data/")
322
  except OSError as e:
323
- utils.dataDelete(filepath+"/temp_data/")
324
- os.mkdir(filepath+"/temp_data/")
325
- #print(e)
 
 
 
326
 
327
- # initialize channel_info, app_state
328
  channel_info = {}
329
- app_state = {
330
- "filepath": filepath+"/temp_data/",
331
- "filenames": {},
332
- "sampleRate": int(samplerate),
333
- "stage1State" : "step1"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  }
335
-
336
  # reset layout
337
- return {app_state_json : app_state,
338
  channel_info_json : channel_info,
339
  # ------------------stage1-----------------------
340
- desc_md : gr.Markdown("### Step1: Mapping result", visible=False),
 
341
  tpl_montage : gr.Image(visible=False),
342
  mapped_montage : gr.Image(value=None, visible=False),
343
  radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
344
- in_fill_mode : gr.Dropdown(value="mean", visible=False),
345
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
346
  fillmode_btn : gr.Button(visible=False),
347
  clear_btn : gr.Button(visible=False),
348
  step2_btn : gr.Button(visible=False),
349
  step3_btn : gr.Button(visible=False),
350
- next_btn : gr.Button(visible=False),
351
  # ------------------stage2-----------------------
352
  run_btn : gr.Button(interactive=False),
353
  batch_md : gr.Markdown(visible=False),
354
- out_denoised_data : gr.File(visible=False)}
355
 
356
 
357
- # ---------------------------stage1-1-------------------------------
 
 
358
  def save_figures(channel_info, filename1, filename2):
359
 
360
  template_montage = read_custom_montage("./template_chanlocs.loc")
@@ -363,6 +394,14 @@ with gr.Blocks() as demo:
363
  template_order = channel_info["templateOrder"]
364
  input_order = channel_info["inputOrder"]
365
 
 
 
 
 
 
 
 
 
366
  # get template's head figure
367
  tpl_fig = template_montage.plot()
368
  tpl_ax = tpl_fig.axes[0]
@@ -373,14 +412,6 @@ with gr.Blocks() as demo:
373
  head_lines.append((x,y))
374
  plt.close()
375
 
376
- # get template's and input's 2d coords
377
- tpl_x = [template_dict[channel]["coord_2d"][0] for channel in template_order]
378
- tpl_y = [template_dict[channel]["coord_2d"][1] for channel in template_order]
379
- in_x = [input_dict[channel]["coord_2d"][0] for channel in input_order]
380
- in_y = [input_dict[channel]["coord_2d"][1] for channel in input_order]
381
- tpl_coords = np.vstack((tpl_x, tpl_y)).T
382
- in_coords = np.vstack((in_x, in_y)).T
383
-
384
  # -------------------------plot input montage------------------------------
385
  fig = plt.figure(figsize=(6.4,6.4), dpi=100)
386
  ax = fig.add_subplot(111)
@@ -391,12 +422,11 @@ with gr.Blocks() as demo:
391
  # plot template's head
392
  for x, y in head_lines:
393
  ax.plot(x, y, color='black', linewidth=1.0)
394
- # plot input channels
395
  ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
396
  for i, channel in enumerate(input_order):
397
  ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
398
-
399
- # save raw_montage
400
  fig.savefig(filename1)
401
 
402
  # ---------------------------add indications-------------------------------
@@ -406,15 +436,14 @@ with gr.Blocks() as demo:
406
  ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
407
  for i in indices:
408
  ax.text(in_coords[i,0]+0.003, in_coords[i,1], input_order[i], color='red', fontsize=10.0, va='center')
409
-
410
  # save mapped_montage
411
  fig.savefig(filename2)
412
- plt.close()
413
 
414
  # -------------------------------------------------------------------------
415
- # save the template and input channels' display position (in px).
416
  tpl_coords = ax.transData.transform(tpl_coords)
417
  in_coords = ax.transData.transform(in_coords)
 
418
 
419
  for i, channel in enumerate(template_order):
420
  css_left = (tpl_coords[i,0]-11)/6.4
@@ -432,156 +461,189 @@ with gr.Blocks() as demo:
432
  })
433
  return channel_info
434
 
435
- def mapping_result(app_state, channel_info):
436
- filepath = app_state["filepath"]
437
- filename1 = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
 
 
 
438
  filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
439
  channel_info = save_figures(channel_info, filename1, filename2)
440
-
441
- app_state["filenames"].update({
442
- "raw_montage" : filename1,
443
  "mapped_montage" : filename2
444
  })
445
 
446
- # ------------------determine the next step-----------------------
447
 
448
- in_num = len(channel_info["inputOrder"])
449
- matched_num = 30 - len(app_state["missingTemplates"])
450
 
451
- # if the input channels(>=30) has all the 30 template channels
452
  # -> stage2
453
  if matched_num == 30:
454
- app_state["stage1State"] = "finished"
455
  gr.Info('The mapping process has been finished.')
 
 
 
 
 
 
 
 
 
 
456
 
457
- return {app_state_json : app_state,
 
458
  channel_info_json : channel_info,
459
- desc_md : gr.Markdown("### Mapping result", visible=True),
 
460
  tpl_montage : gr.Image(visible=True),
461
  mapped_montage : gr.Image(value=filename2, visible=True),
462
  run_btn : gr.Button(interactive=True)}
463
 
464
- # if matched channels < 30, and there're still some unmatched input channels
465
- # -> assign these input channels to nearby unmatched/empty template channels
466
- if in_num > matched_num:
467
- app_state["stage1State"] = "step2-initializing"
468
-
469
- # if input channels < 30, but all of them can match to some template channels
470
- # -> directly use fill_mode to fill the remaining channels
471
- elif in_num == matched_num:
472
- app_state["stage1State"] = "step3-initializing"
473
-
474
- return {app_state_json : app_state,
475
- channel_info_json : channel_info,
476
- desc_md : gr.Markdown("### Step1: Mapping result", visible=True),
477
- tpl_montage : gr.Image(visible=True),
478
- mapped_montage : gr.Image(value=filename2, visible=True),
479
- next_btn : gr.Button("Next step", visible=True)}
 
 
 
 
 
 
 
 
 
 
 
480
 
481
- map_btn.click(
482
  fn = reset_all,
483
- inputs = [in_raw_data, in_raw_loc, in_samplerate],
484
- outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
485
- in_fill_mode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
486
- run_btn, batch_md, out_denoised_data]
487
  ).success(
488
  fn = mapping_stage1,
489
- inputs = [app_state_json, channel_info_json, in_raw_loc],
490
- outputs = [app_state_json, channel_info_json, desc_md]
491
-
492
  ).success(
493
  fn = mapping_result,
494
- inputs = [app_state_json, channel_info_json],
495
- outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, next_btn, run_btn]
496
  )
497
 
498
-
499
- def init_next_step(app_state, channel_info, selected_radio, selected_chkbox):
 
 
 
 
500
 
501
  # stage1-1 -> stage1-2
502
- if app_state["stage1State"] == "step2-initializing":
503
- print('step1 -> step2')
504
- app_state["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
505
- if channel_info["templateDict"][channel]["matched"]==False]
506
- app_state.update({
507
- "stage1State" : "step2-selecting",
 
 
 
 
508
  "fillingCount" : 1,
509
- "totalFillingNum" : len(app_state["missingTemplates"])
510
  })
 
 
511
 
512
- name = app_state["missingTemplates"][0]
513
- label = name+' (1/'+str(app_state["totalFillingNum"])+')'
514
-
515
- if len(app_state["stage1UnassignedInputs"])==1 or app_state["totalFillingNum"]==1:
516
- return {app_state_json : app_state,
517
  channel_info_json : channel_info,
518
- desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
519
  tpl_montage : gr.Image(visible=False),
520
  mapped_montage : gr.Image(visible=False),
521
- radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=label, visible=True),
522
  clear_btn : gr.Button(visible=True),
523
  next_btn : gr.Button("Next step")}
524
  else:
525
- return {app_state_json : app_state,
526
  channel_info_json : channel_info,
527
- desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
528
  tpl_montage : gr.Image(visible=False),
529
  mapped_montage : gr.Image(visible=False),
530
- radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=label, visible=True),
531
  clear_btn : gr.Button(visible=True),
532
  step2_btn : gr.Button(visible=True),
533
  next_btn : gr.Button(visible=False)}
534
 
535
  # stage1-1 -> stage1-3
536
- elif app_state["stage1State"] == "step3-initializing":
537
- print('step1 -> step3')
538
- app_state["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
539
- if channel_info["templateDict"][channel]["matched"]==False]
540
- app_state.update({
541
- "stage1State" : "step3-initializing",
542
- "fillingCount" : 1,
543
- "totalFillingNum" : len(app_state["missingTemplates"])
544
- })
545
- return {app_state_json : app_state,
546
- channel_info_json : channel_info,
547
- desc_md : gr.Markdown("### Step3: Fill the remaining template channels"),
548
  tpl_montage : gr.Image(visible=False),
549
  mapped_montage : gr.Image(visible=False),
550
- in_fill_mode : gr.Dropdown(visible=True),
551
  fillmode_btn : gr.Button(visible=True),
552
  next_btn : gr.Button(visible=False)}
553
 
554
  # stage1-2 -> stage1-3 or stage2
555
- elif app_state["stage1State"] == "step2-selecting":
556
 
557
- # save info before clicking on next_btn
558
- prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
559
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
560
- if selected_radio == []:
561
- app_state["stage1NewOrder"][prev_target_idx] = []
562
- else:
563
- selected_idx = channel_info["inputDict"][selected_radio]["index"]
564
- app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
565
 
 
 
 
 
 
 
 
 
 
566
  channel_info["templateDict"][prev_target_name]["matched"] = True
567
  channel_info["inputDict"][selected_radio]["assigned"] = True
568
- print(prev_target_name, '<-', selected_radio)
569
 
570
- app_state.update({
571
- "stage1UnassignedInputs" : [channel for channel in channel_info["inputOrder"]
572
- if channel_info["inputDict"][channel]["assigned"]==False],
573
- "missingTemplates" : [channel for channel in channel_info["templateOrder"]
574
- if channel_info["templateDict"][channel]["matched"]==False]
575
- })
 
 
576
 
577
- # if all the unmatched template channels were filled by input channels
578
  # -> stage2
579
- if len(app_state["missingTemplates"]) == 0:
580
- print('step2 -> stage2')
 
581
  gr.Info('The mapping process has been finished.')
582
- app_state["stage1State"] = "finished"
583
 
584
- return {app_state_json : app_state,
 
585
  channel_info_json : channel_info,
586
  desc_md : gr.Markdown(visible=False),
587
  radio_group : gr.Radio(visible=False),
@@ -591,42 +653,43 @@ with gr.Blocks() as demo:
591
 
592
  # -> stage1-3
593
  else:
594
- print('step2 -> step3')
595
- app_state.update({
596
- "stage1State" : "step3-initializing",
597
- "fillingCount" : 1,
598
- "totalFillingNum" : len(app_state["missingTemplates"])
599
- })
600
- return {app_state_json : app_state,
 
601
  channel_info_json : channel_info,
602
- desc_md : gr.Markdown("### Step3: Fill the remaining template channels"),
603
  radio_group : gr.Radio(visible=False),
604
- in_fill_mode : gr.Dropdown(visible=True),
605
  fillmode_btn : gr.Button(visible=True),
606
  clear_btn : gr.Button(visible=False),
607
  next_btn : gr.Button(visible=False)}
608
 
609
  # stage1-3 -> stage2
610
- elif app_state["stage1State"] == "step3-selecting":
611
-
612
- # save info before clicking on next_btn
613
- prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
614
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
615
- if selected_chkbox == []:
616
- app_state["stage1NewOrder"][prev_target_idx] = []
617
- else:
618
- selected_idx = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
619
- app_state["stage1NewOrder"][prev_target_idx] = selected_idx
620
- #print(f'{prev_target_name}({prev_target_idx}): {selected_chkbox}')
621
-
622
  gr.Info('The mapping process has been finished.')
623
- app_state["stage1State"] = "finished"
624
- print('step3 -> stage2')
625
 
626
- app_state["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
627
- if channel_info["templateDict"][channel]["matched"]==False]
 
 
 
 
 
 
 
 
 
 
628
 
629
- return {app_state_json : app_state,
 
630
  desc_md : gr.Markdown(visible=False),
631
  chkbox_group : gr.CheckboxGroup(visible=False),
632
  next_btn : gr.Button(visible=False),
@@ -634,62 +697,71 @@ with gr.Blocks() as demo:
634
 
635
  next_btn.click(
636
  fn = init_next_step,
637
- inputs = [app_state_json, channel_info_json, radio_group, chkbox_group],
638
- outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
639
- in_fill_mode, chkbox_group, fillmode_btn, clear_btn, step2_btn, next_btn, run_btn]
640
  ).success(
641
  fn = None,
642
  js = init_js,
643
- inputs = [app_state_json, channel_info_json],
644
  outputs = []
645
  )
646
 
647
- # ---------------------------stage1-2-------------------------------
648
- def update_radio(app_state, channel_info, selected):
 
 
 
 
649
 
650
- # save info before clicking on next_btn
651
- prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
652
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
653
- if selected == []:
654
- app_state["stage1NewOrder"][prev_target_idx] = []
655
- else:
656
- selected_idx = channel_info["inputDict"][selected]["index"]
657
- app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
658
 
 
 
 
 
 
 
 
 
 
659
  channel_info["templateDict"][prev_target_name]["matched"] = True
660
  channel_info["inputDict"][selected]["assigned"] = True
661
- print(prev_target_name, '<-', selected)
 
 
 
662
 
663
- # update the current round
664
- app_state["fillingCount"] += 1
665
- app_state["stage1UnassignedInputs"] = [channel for channel in channel_info["inputOrder"]
666
  if channel_info["inputDict"][channel]["assigned"]==False]
 
 
 
667
 
668
- target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
669
- radio_label = target_name+' ('+str(app_state["fillingCount"])+'/'+str(app_state["totalFillingNum"])+')'
670
-
671
- if len(app_state["stage1UnassignedInputs"])==1 or app_state["fillingCount"]==app_state["totalFillingNum"]:
672
- return {app_state_json : app_state,
673
  channel_info_json : channel_info,
674
- radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"],
675
  value=[], label=radio_label),
676
  step2_btn : gr.Button(visible=False),
677
  next_btn : gr.Button("Next step", visible=True)}
678
  else:
679
- return {app_state_json : app_state,
680
  channel_info_json : channel_info,
681
- radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"],
682
  value=[], label=radio_label)}
683
 
684
  step2_btn.click(
685
  fn = update_radio,
686
- inputs = [app_state_json, channel_info_json, radio_group],
687
- outputs = [app_state_json, channel_info_json, radio_group, step2_btn, next_btn]
688
-
689
  ).success(
690
  fn = None,
691
  js = update_js,
692
- inputs = [app_state_json, channel_info_json],
693
  outputs = []
694
  )
695
 
@@ -700,183 +772,261 @@ with gr.Blocks() as demo:
700
  )
701
 
702
 
703
- # ---------------------------stage1-3-------------------------------
704
- def fill_value(app_state, channel_info, fill_mode):
 
 
 
705
 
706
- if fill_mode == 'zero':
707
- app_state["stage1State"] = "finished"
708
  gr.Info('The mapping process has been finished.')
709
 
710
- return {app_state_json : app_state,
 
711
  desc_md : gr.Markdown(visible=False),
712
- in_fill_mode : gr.Dropdown(visible=False),
713
  fillmode_btn : gr.Button(visible=False),
714
  run_btn : gr.Button(interactive=True)}
715
 
716
- elif fill_mode == 'mean':
717
- app_state["stage1State"] = "step3-selecting"
718
- app_state = find_neighbors(app_state, channel_info)
 
 
719
 
720
- # init stage1-3-selecting
721
- target_name = app_state["missingTemplates"][0]
722
- target_idx = channel_info["templateDict"][target_name]["index"]
 
 
 
 
 
 
723
 
724
- chkbox_value = app_state["stage1NewOrder"][target_idx]
 
 
 
725
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
726
- chkbox_label = target_name+' (1/'+str(app_state["totalFillingNum"])+')'
727
 
728
- if app_state["totalFillingNum"] == 1:
729
- return {app_state_json : app_state,
730
- in_fill_mode : gr.Dropdown(visible=False),
 
 
 
731
  fillmode_btn : gr.Button(visible=False),
732
  chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
733
  value=chkbox_value, label=chkbox_label, visible=True),
734
  next_btn : gr.Button(visible=True)}
735
  else:
736
- return {app_state_json : app_state,
737
- in_fill_mode : gr.Dropdown(visible=False),
 
738
  fillmode_btn : gr.Button(visible=False),
739
  chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
740
  value=chkbox_value, label=chkbox_label, visible=True),
741
  step3_btn : gr.Button(visible=True)}
742
 
743
- def update_chkbox(app_state, channel_info, selected):
744
-
745
- # save info before clicking on next_btn
746
- prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
747
- prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
748
- if selected == []:
749
- app_state["stage1NewOrder"][prev_target_idx] = []
750
- else:
751
- selected_idx = [channel_info["inputDict"][channel]["index"] for channel in selected]
752
- app_state["stage1NewOrder"][prev_target_idx] = selected_idx
753
- #print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
754
 
755
- # update the current round
756
- app_state["fillingCount"] += 1
757
 
758
- target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
759
- target_idx = channel_info["templateDict"][target_name]["index"]
 
 
 
 
 
 
 
 
 
 
760
 
761
- chkbox_value = app_state["stage1NewOrder"][target_idx]
 
 
 
762
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
763
- chkbox_label = target_name+' ('+str(app_state["fillingCount"])+'/'+str(app_state["totalFillingNum"])+')'
764
 
765
- if app_state["fillingCount"] == app_state["totalFillingNum"]:
766
- return {app_state_json : app_state,
 
 
767
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
768
  step3_btn : gr.Button(visible=False),
769
  next_btn : gr.Button("Submit", visible=True)}
770
  else:
771
- return {app_state_json : app_state,
772
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
773
 
774
  fillmode_btn.click(
775
  fn = fill_value,
776
- inputs = [app_state_json, channel_info_json, in_fill_mode],
777
- outputs = [app_state_json, desc_md, in_fill_mode, fillmode_btn, chkbox_group, step3_btn, next_btn, run_btn]
778
  ).success(
779
  fn = None,
780
  js = init_js,
781
- inputs = [app_state_json, channel_info_json],
782
  outputs = []
783
  )
784
 
785
  step3_btn.click(
786
  fn = update_chkbox,
787
- inputs = [app_state_json, channel_info_json, chkbox_group],
788
- outputs = [app_state_json, chkbox_group, step3_btn, next_btn]
789
-
790
  ).success(
791
  fn = None,
792
  js = update_js,
793
- inputs = [app_state_json, channel_info_json],
794
  outputs = []
795
  )
796
 
797
- # -------------------------stage2: decode data-------------------------------
798
- def delete_file(filename):
799
- try:
800
- os.remove(filename)
801
- except OSError as e:
802
- print(e)
803
 
804
- def reset_run(app_state, channel_info, raw_data, model_name):
 
 
 
 
 
805
 
806
- # reset in.assigned back to the state after stage1
807
- for channel in app_state["stage1UnassignedInputs"]:
808
- channel_info["inputDict"][channel]["assigned"] = False
809
-
810
- filepath = app_state["filepath"]
811
- delete_file(filepath+'mapped.csv')
812
- delete_file(filepath+'denoised.csv')
813
-
814
- input_name = os.path.basename(str(raw_data))
815
- output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
816
 
817
- in_num = len(channel_info["inputOrder"])
818
- assigned_num = len(app_state["stage1UnassignedInputs"])
819
- batch_num = math.ceil((in_num-assigned_num)/30) + 1
 
 
 
820
 
821
- app_state["filenames"]["denoised"] = filepath + output_name
822
- app_state.update({
823
- "runningState" : "stage1",
824
- "batchCount" : 1,
825
- "totalBatchNum" : batch_num,
826
- "stage2UnassignedInputs" : app_state["stage1UnassignedInputs"],
827
- "stage2NewOrder" : [[]]*30,
 
 
 
 
 
828
  })
829
- return {app_state_json : app_state,
830
  channel_info_json : channel_info,
831
- run_btn : gr.Button(interactive=False),
832
  batch_md : gr.Markdown(visible=False),
833
- out_denoised_data : gr.File(visible=False)}
834
 
835
- def run_model(app_state, channel_info, raw_data, model_name):
836
- filepath = app_state["filepath"]
837
- samplerate = app_state["sampleRate"]
838
- new_filename = app_state["filenames"]["denoised"]
839
 
840
- while app_state["runningState"] != "finished":
841
- md = 'Running model('+str(app_state["batchCount"])+'/'+str(app_state["totalBatchNum"])+')...'
842
- yield {batch_md : gr.Markdown(md, visible=True)}
843
-
844
- if app_state["batchCount"] > 1:
845
- app_state, channel_info = mapping_stage2(app_state, channel_info)
846
- if app_state["runningState"] == "finished":
847
- #yield {batch_md : gr.Markdown("error", visible=True)}
848
- break
849
-
850
- reorder_to_template(app_state, raw_data)
851
- # step1: Data preprocessing
852
- total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
853
- # step2: Signal reconstruction
854
- utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
855
- reorder_to_origin(app_state, channel_info, new_filename)
 
 
 
 
 
 
856
 
857
- #if model_name == "(mapped data)":
858
- #return {out_denoised_data : filepath + 'mapped.csv'}
859
- #elif model_name == "(denoised data)":
860
- #return {out_denoised_data : filepath + 'denoised.csv'}
861
 
862
- delete_file(filepath+'mapped.csv')
863
- delete_file(filepath+'denoised.csv')
864
- app_state["batchCount"] += 1
 
 
 
 
 
865
 
866
- yield {run_btn : gr.Button(interactive=True),
867
- batch_md : gr.Markdown(visible=False),
868
- out_denoised_data : gr.File(new_filename, visible=True)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
869
 
870
  run_btn.click(
871
  fn = reset_run,
872
- inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name],
873
- outputs = [app_state_json, channel_info_json, run_btn, batch_md, out_denoised_data]
874
 
875
  ).success(
876
  fn = run_model,
877
- inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name],
878
- outputs = [run_btn, batch_md, out_denoised_data]
879
  )
880
 
881
  if __name__ == "__main__":
882
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
+ import time
4
  import os
5
  import random
6
  import math
 
10
  from mne.channels import read_custom_montage
11
 
12
  import utils
13
+ from channel_mapping import mapping_stage1, mapping_stage2, reorder_input_data, restore_original_order, find_neighbors
14
 
15
 
16
  quickstart = """
 
21
 
22
  ## Channel locations
23
  Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
24
+ **Note:**
25
+ If you cannot obtain it, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files doesn't match yours, you can use **EEGLAB** to modify them to your needed montage.
26
 
27
  ## Mapping
28
  (...)
 
57
  """
58
 
59
  init_js = """
60
+ (app_info, channel_info) => {
61
+ app_info = JSON.parse(JSON.stringify(app_info));
62
  channel_info = JSON.parse(JSON.stringify(channel_info));
63
+ stage1_info = app_info.stage1
64
 
65
  let selector, classname, attribute;
66
  let channel, left, bottom;
67
 
68
+ if(stage1_info.state == "step2-selecting"){
69
  selector = "#radio-group > div:nth-of-type(2)";
70
  //classname = "radio";
71
  attribute = "value";
72
+ }else if(stage1_info.state == "step3-selecting"){
73
  selector = "#chkbox-group > div:nth-of-type(2)";
74
  //classname = "chkbox";
75
  attribute = "name";
76
  }else return;
77
 
78
 
79
+ // add figure of the input montage
80
  document.querySelector(selector).style.cssText = `
81
  position: relative;
82
  width: 100%;
83
  aspect-ratio: 1;
84
  //width: 560px;
85
  //height: 560px;
86
+ background: url("file=${stage1_info.filenames.input_montage}");
87
  background-size: contain;
88
 
89
  `;
 
102
 
103
 
104
  // add indication for the missing channels
105
+ channel = stage1_info.missingTemplates[0];
106
  left = channel_info.templateDict[channel].css_position[0];
107
  bottom = channel_info.templateDict[channel].css_position[1];
108
 
109
  let dot_rule = `
110
  ${selector}::before {
111
+ content: "";
112
  position: absolute;
113
  background-color: red;
114
  width: 10px;
 
148
  """
149
 
150
  update_js = """
151
+ (app_info, channel_info) => {
152
+ app_info = JSON.parse(JSON.stringify(app_info));
153
  channel_info = JSON.parse(JSON.stringify(channel_info));
154
+ stage1_info = app_info.stage1
155
 
156
  let selector;
157
  let channel, left, bottom;
158
 
159
+ if(stage1_info.state == "step2-selecting"){
160
  selector = "#radio-group > div:nth-of-type(2)";
161
 
162
  // update the radios
 
170
  item.className = "";
171
  item.querySelector(":scope > span").innerText = "";
172
  });
173
+ }else if(stage1_info.state == "step3-selecting"){
174
  selector = "#chkbox-group > div:nth-of-type(2)";
175
  }else return;
176
 
177
  // update indication
178
+ channel = stage1_info.missingTemplates[stage1_info["fillingCount"]-1];
179
  left = channel_info.templateDict[channel].css_position[0];
180
  bottom = channel_info.templateDict[channel].css_position[1];
181
 
 
206
  }
207
  `;
208
 
209
+ // update the rules
210
  const styleSheet = document.styleSheets[0];
211
  for(let i=0; i<styleSheet.cssRules.length; i++){
212
  let tmp = styleSheet.cssRules[i].selectorText;
 
223
 
224
  with gr.Blocks() as demo:
225
 
226
+ app_info_json = gr.JSON(visible=False)
227
  channel_info_json = gr.JSON(visible=False)
228
 
229
  with gr.Row():
 
238
  gr.Markdown("# 1.Channel Mapping")
239
  # ------------------------input--------------------------
240
  with gr.Row():
241
+ in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
242
+ in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
243
+ file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
244
  with gr.Row():
245
  in_samplerate = gr.Textbox(label="Sampling rate (Hz)", scale=2)
246
+ map_btn = gr.Button("Mapping", interactive=False, scale=1)
247
 
248
  # ------------------------mapping------------------------
249
  # description for stage1-123
250
+ desc_md = gr.Markdown(visible=False)
 
251
  # stage1-1 : mapping result
252
  with gr.Row():
253
+ tpl_montage = gr.Image("./template_montage.png", label="Template channels", visible=False)
254
  mapped_montage = gr.Image(label="Input channels", visible=False)
 
255
  # stage1-2 : assign unmatched input channels to empty template channels
256
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
 
257
  # stage1-3 : select a way to fill the empty template channels
258
  with gr.Row():
259
+ in_fillmode = gr.Dropdown(choices=["mean", "zero"],
260
  value="mean",
261
  label="Filling method",
262
  visible=False,
 
265
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
266
 
267
  with gr.Row():
268
+ clear_btn = gr.Button("Clear", visible=False)
269
  step2_btn = gr.Button("Next", visible=False)
270
  step3_btn = gr.Button("Next", visible=False)
271
  next_btn = gr.Button("Next step", visible=False)
 
275
  gr.Markdown("# 2.Decode Data")
276
  # ------------------------input--------------------------
277
  with gr.Row():
278
+ in_modelname = gr.Dropdown(choices=[
279
  ("ART", "EEGART"),
280
  ("IC-U-Net", "ICUNet"),
281
  ("IC-U-Net++", "UNetpp"),
282
+ ("IC-U-Net-Attn", "AttUnet")],
 
 
283
  value="EEGART",
284
  label="Model",
285
  scale=2)
 
287
 
288
  # ------------------------output-------------------------
289
  batch_md = gr.Markdown(visible=False)
290
+ out_data_file = gr.File(label="Denoised data", visible=False)
 
291
  # -------------------------------------------------------
292
 
293
  with gr.Row():
 
301
  gr.Markdown()
302
  with gr.Tab("QuickStart"):
303
  gr.Markdown(quickstart)
304
+
305
+
306
+ # verify that all required inputs have been provided
307
+ @gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
308
+ inputs = [in_data_file, in_loc_file, in_samplerate], outputs = map_btn)
309
+ def check_input(in_data, in_loc, samplerate):
310
+ if in_data!=None and in_loc!=None and samplerate!="":
311
+ return gr.Button(interactive=True)
312
+ else:
313
+ return gr.Button(interactive=False)
314
+
315
+
316
+ # +========================================================================================+
317
+ # | stage1: channel mapping |
318
+ # +========================================================================================+
319
+ def reset_all(in_data, in_loc, samplerate):
320
+ # establish a new folder for the current session
321
+ filepath = os.path.dirname(str(in_data))
322
  try:
323
+ os.mkdir(filepath+"/session_data/")
324
  except OSError as e:
325
+ utils.dataDelete(filepath+"/session_data/")
326
+ os.mkdir(filepath+"/session_data/")
327
+ print(e)
328
+ # establish new folders for stage1 and stage2
329
+ os.mkdir(filepath+"/session_data/stage1/")
330
+ os.mkdir(filepath+"/session_data/stage2/")
331
 
332
+ # initialize channel_info, app_info
333
  channel_info = {}
334
+ app_info = {
335
+ "rootFilepath" : filepath+"/session_data/",
336
+ "sampleRate" : int(samplerate),
337
+ #"currentStage" : "stage1",
338
+ "stage1" : {
339
+ "filepath" : filepath+"/session_data/stage1/",
340
+ "filenames" : {
341
+ "input_data" : in_data,
342
+ "input_loc" : in_loc,
343
+ "input_montage" : "",
344
+ "mapped_montage" : ""
345
+ },
346
+ "state" : None,
347
+ "fillingCount" : None,
348
+ "totalFillingNum" : None,
349
+ "newOrder" : None,
350
+ "unassignedInputs" : None,
351
+ "missingTemplates" : None
352
+ },
353
+ "stage2" : {
354
+ "filepath" : filepath+"/session_data/stage2/",
355
+ "filenames" : {
356
+ "output_data" : ""
357
+ },
358
+ #"state" : None,
359
+ "totalBatchNum" : None,
360
+ "newOrder" : None,
361
+ "unassignedInputs" : None
362
+ }
363
  }
 
364
  # reset layout
365
+ return {app_info_json : app_info,
366
  channel_info_json : channel_info,
367
  # ------------------stage1-----------------------
368
+ map_btn : gr.Button(interactive=False),
369
+ desc_md : gr.Markdown(visible=False),
370
  tpl_montage : gr.Image(visible=False),
371
  mapped_montage : gr.Image(value=None, visible=False),
372
  radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
373
+ in_fillmode : gr.Dropdown(value="mean", visible=False),
374
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
375
  fillmode_btn : gr.Button(visible=False),
376
  clear_btn : gr.Button(visible=False),
377
  step2_btn : gr.Button(visible=False),
378
  step3_btn : gr.Button(visible=False),
379
+ next_btn : gr.Button("Next step", visible=False),
380
  # ------------------stage2-----------------------
381
  run_btn : gr.Button(interactive=False),
382
  batch_md : gr.Markdown(visible=False),
383
+ out_data_file : gr.File(visible=False)}
384
 
385
 
386
+ # +========================================================================================+
387
+ # | stage1-1 |
388
+ # +========================================================================================+
389
  def save_figures(channel_info, filename1, filename2):
390
 
391
  template_montage = read_custom_montage("./template_chanlocs.loc")
 
394
  template_order = channel_info["templateOrder"]
395
  input_order = channel_info["inputOrder"]
396
 
397
+ # get template and input's 2d coords
398
+ tpl_x = [template_dict[channel]["coord_2d"][0] for channel in template_order]
399
+ tpl_y = [template_dict[channel]["coord_2d"][1] for channel in template_order]
400
+ in_x = [input_dict[channel]["coord_2d"][0] for channel in input_order]
401
+ in_y = [input_dict[channel]["coord_2d"][1] for channel in input_order]
402
+ tpl_coords = np.vstack((tpl_x, tpl_y)).T
403
+ in_coords = np.vstack((in_x, in_y)).T
404
+
405
  # get template's head figure
406
  tpl_fig = template_montage.plot()
407
  tpl_ax = tpl_fig.axes[0]
 
412
  head_lines.append((x,y))
413
  plt.close()
414
 
 
 
 
 
 
 
 
 
415
  # -------------------------plot input montage------------------------------
416
  fig = plt.figure(figsize=(6.4,6.4), dpi=100)
417
  ax = fig.add_subplot(111)
 
422
  # plot template's head
423
  for x, y in head_lines:
424
  ax.plot(x, y, color='black', linewidth=1.0)
425
+ # plot input channels on it
426
  ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
427
  for i, channel in enumerate(input_order):
428
  ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
429
+ # save input_montage
 
430
  fig.savefig(filename1)
431
 
432
  # ---------------------------add indications-------------------------------
 
436
  ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
437
  for i in indices:
438
  ax.text(in_coords[i,0]+0.003, in_coords[i,1], input_order[i], color='red', fontsize=10.0, va='center')
 
439
  # save mapped_montage
440
  fig.savefig(filename2)
 
441
 
442
  # -------------------------------------------------------------------------
443
+ # store the template and input channels' display position (in px).
444
  tpl_coords = ax.transData.transform(tpl_coords)
445
  in_coords = ax.transData.transform(in_coords)
446
+ plt.close()
447
 
448
  for i, channel in enumerate(template_order):
449
  css_left = (tpl_coords[i,0]-11)/6.4
 
461
  })
462
  return channel_info
463
 
464
+ def mapping_result(app_info, channel_info):
465
+ stage1_info = app_info["stage1"]
466
+ filepath = stage1_info["filepath"]
467
+
468
+ # generate and save figures of the input montage and the mapped montage
469
+ filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
470
  filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
471
  channel_info = save_figures(channel_info, filename1, filename2)
472
+ stage1_info["filenames"].update({
473
+ "input_montage" : filename1,
 
474
  "mapped_montage" : filename2
475
  })
476
 
477
+ # -------------------------determine the next step--------------------------
478
 
479
+ input_num = len(channel_info["inputOrder"])
480
+ matched_num = 30 - len(stage1_info["missingTemplates"])
481
 
482
+ # if the in_channels has all the 30 tpl_channels (input_num>=30)
483
  # -> stage2
484
  if matched_num == 30:
485
+ stage1_info["state"] = "finished"
486
  gr.Info('The mapping process has been finished.')
487
+ if input_num > 30:
488
+ md = """
489
+ ### Mapping result
490
+ (...in red...)
491
+ """
492
+ else:
493
+ md = """
494
+ ### Mapping result
495
+ (...)
496
+ """
497
 
498
+ app_info["stage1"] = stage1_info
499
+ return {app_info_json : app_info,
500
  channel_info_json : channel_info,
501
+ map_btn : gr.Button(interactive=True),
502
+ desc_md : gr.Markdown(md, visible=True),
503
  tpl_montage : gr.Image(visible=True),
504
  mapped_montage : gr.Image(value=filename2, visible=True),
505
  run_btn : gr.Button(interactive=True)}
506
 
507
+ else:
508
+ # if matched_num < 30, and there're still some unmatched in_channels
509
+ # -> assign these in_channels to nearby unmatched tpl_channels
510
+ if input_num > matched_num:
511
+ stage1_info["state"] = "step2-initializing"
512
+ md = """
513
+ ### Step1: Mapping result
514
+ (...in red...)
515
+ """
516
+
517
+ # if input_num < 30, but all of them can match to some tpl_channels
518
+ # -> directly use fillmode to fill the remaining tpl_channels
519
+ elif input_num == matched_num:
520
+ stage1_info["state"] = "step3-initializing"
521
+ md = """
522
+ ### Step1: Mapping result
523
+ (...)
524
+ """
525
+
526
+ app_info["stage1"] = stage1_info
527
+ return {app_info_json : app_info,
528
+ channel_info_json : channel_info,
529
+ map_btn : gr.Button(interactive=True),
530
+ desc_md : gr.Markdown(md, visible=True),
531
+ tpl_montage : gr.Image(visible=True),
532
+ mapped_montage : gr.Image(value=filename2, visible=True),
533
+ next_btn : gr.Button(visible=True)}
534
 
535
+ start_stage1 = map_btn.click(
536
  fn = reset_all,
537
+ inputs = [in_data_file, in_loc_file, in_samplerate],
538
+ outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_montage, mapped_montage, radio_group,
539
+ in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
540
+ run_btn, batch_md, out_data_file]
541
  ).success(
542
  fn = mapping_stage1,
543
+ inputs = [app_info_json, channel_info_json],
544
+ outputs = [app_info_json, channel_info_json, desc_md]
 
545
  ).success(
546
  fn = mapping_result,
547
+ inputs = [app_info_json, channel_info_json],
548
+ outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_montage, mapped_montage, next_btn, run_btn]
549
  )
550
 
551
+
552
+ # +========================================================================================+
553
+ # | manage step transition |
554
+ # +========================================================================================+
555
+ def init_next_step(app_info, channel_info, selected_radio, selected_chkbox):
556
+ stage1_info = app_info["stage1"]
557
 
558
  # stage1-1 -> stage1-2
559
+ if stage1_info["state"] == "step2-initializing":
560
+ #print('step1 -> step2')
561
+ md = """
562
+ ### Step2: Assign unmatched input channels
563
+ (...)
564
+ """
565
+
566
+ # initialize the progress indication label for step2
567
+ stage1_info.update({
568
+ "state" : "step2-selecting",
569
  "fillingCount" : 1,
570
+ "totalFillingNum" : len(stage1_info["missingTemplates"])
571
  })
572
+ name = stage1_info["missingTemplates"][0]
573
+ label = "{} (1/{})".format(name, stage1_info["totalFillingNum"])
574
 
575
+ app_info["stage1"] = stage1_info
576
+ # determine which button to display
577
+ if len(stage1_info["unassignedInputs"])==1 or stage1_info["totalFillingNum"]==1:
578
+ return {app_info_json : app_info,
 
579
  channel_info_json : channel_info,
580
+ desc_md : gr.Markdown(md),
581
  tpl_montage : gr.Image(visible=False),
582
  mapped_montage : gr.Image(visible=False),
583
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
584
  clear_btn : gr.Button(visible=True),
585
  next_btn : gr.Button("Next step")}
586
  else:
587
+ return {app_info_json : app_info,
588
  channel_info_json : channel_info,
589
+ desc_md : gr.Markdown(md),
590
  tpl_montage : gr.Image(visible=False),
591
  mapped_montage : gr.Image(visible=False),
592
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
593
  clear_btn : gr.Button(visible=True),
594
  step2_btn : gr.Button(visible=True),
595
  next_btn : gr.Button(visible=False)}
596
 
597
  # stage1-1 -> stage1-3
598
+ elif stage1_info["state"] == "step3-initializing":
599
+ #print('step1 -> step3')
600
+ md = """
601
+ ### Step3: Fill the remaining template channels
602
+ (...)
603
+ """
604
+ return {desc_md : gr.Markdown(md),
 
 
 
 
 
605
  tpl_montage : gr.Image(visible=False),
606
  mapped_montage : gr.Image(visible=False),
607
+ in_fillmode : gr.Dropdown(visible=True),
608
  fillmode_btn : gr.Button(visible=True),
609
  next_btn : gr.Button(visible=False)}
610
 
611
  # stage1-2 -> stage1-3 or stage2
612
+ elif stage1_info["state"] == "step2-selecting":
613
 
614
+ # ----------------------store information before the button click----------------------
 
 
 
 
 
 
 
615
 
616
+ # check if the user has selected an in_channel to forward to the previous target tpl_channel
617
+ if selected_radio != []:
618
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
619
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
620
+
621
+ # store the index of the in_channel
622
+ selected_idx = channel_info["inputDict"][selected_radio]["index"]
623
+ stage1_info["newOrder"][prev_target_idx] = [selected_idx]
624
+ # mark the in_channel as assigned and tpl_channel as matched
625
  channel_info["templateDict"][prev_target_name]["matched"] = True
626
  channel_info["inputDict"][selected_radio]["assigned"] = True
627
+ print(prev_target_name, '<-', selected_radio)
628
 
629
+ # ------------------------update information for the next step-------------------------
630
+
631
+ # update the list of unassignedInputs to exclude the selected in_channel of the previous round
632
+ stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
633
+ if channel_info["inputDict"][channel]["assigned"]==False]
634
+ # update the list of missingTemplates to exclude those filled in step2
635
+ stage1_info["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
636
+ if channel_info["templateDict"][channel]["matched"]==False]
637
 
638
+ # if all the unmatched tpl_channels were filled by in_channels
639
  # -> stage2
640
+ if len(stage1_info["missingTemplates"]) == 0:
641
+ #print('step2 -> stage2')
642
+ stage1_info["state"] = "finished"
643
  gr.Info('The mapping process has been finished.')
 
644
 
645
+ app_info["stage1"] = stage1_info
646
+ return {app_info_json : app_info,
647
  channel_info_json : channel_info,
648
  desc_md : gr.Markdown(visible=False),
649
  radio_group : gr.Radio(visible=False),
 
653
 
654
  # -> stage1-3
655
  else:
656
+ #print('step2 -> step3')
657
+ md = """
658
+ ### Step3: Fill the remaining template channels
659
+ (...)
660
+ """
661
+
662
+ app_info["stage1"] = stage1_info
663
+ return {app_info_json : app_info,
664
  channel_info_json : channel_info,
665
+ desc_md : gr.Markdown(md),
666
  radio_group : gr.Radio(visible=False),
667
+ in_fillmode : gr.Dropdown(visible=True),
668
  fillmode_btn : gr.Button(visible=True),
669
  clear_btn : gr.Button(visible=False),
670
  next_btn : gr.Button(visible=False)}
671
 
672
  # stage1-3 -> stage2
673
+ elif stage1_info["state"] == "step3-selecting":
674
+ #print('step3 -> stage2')
675
+ stage1_info["state"] = "finished"
 
 
 
 
 
 
 
 
 
676
  gr.Info('The mapping process has been finished.')
 
 
677
 
678
+ # ----------------------store information before the button click----------------------
679
+
680
+ # check if the user has not unchecked all in_channel checkboxes
681
+ if selected_chkbox != []:
682
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
683
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
684
+
685
+ # store the indices of the in_channels
686
+ selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
687
+ stage1_info["newOrder"][prev_target_idx] = selected_indices
688
+ #print(f'{prev_target_name}({prev_target_idx}): {selected_chkbox}')
689
+ # -------------------------------------------------------------------------------------
690
 
691
+ app_info["stage1"] = stage1_info
692
+ return {app_info_json : app_info,
693
  desc_md : gr.Markdown(visible=False),
694
  chkbox_group : gr.CheckboxGroup(visible=False),
695
  next_btn : gr.Button(visible=False),
 
697
 
698
  next_btn.click(
699
  fn = init_next_step,
700
+ inputs = [app_info_json, channel_info_json, radio_group, chkbox_group],
701
+ outputs = [app_info_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
702
+ in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, next_btn, run_btn]
703
  ).success(
704
  fn = None,
705
  js = init_js,
706
+ inputs = [app_info_json, channel_info_json],
707
  outputs = []
708
  )
709
 
710
+
711
+ # +========================================================================================+
712
+ # | stage1-2 |
713
+ # +========================================================================================+
714
+ def update_radio(app_info, channel_info, selected):
715
+ stage1_info = app_info["stage1"]
716
 
717
+ # ----------------------store information before the button click----------------------
 
 
 
 
 
 
 
718
 
719
+ # check if the user has selected an in_channel to forward to the previous target tpl_channel
720
+ if selected != []:
721
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
722
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
723
+
724
+ # store the index of the selected in_channel
725
+ selected_idx = channel_info["inputDict"][selected]["index"]
726
+ stage1_info["newOrder"][prev_target_idx] = [selected_idx]
727
+ # mark the in_channel as assigned and tpl_channel as matched
728
  channel_info["templateDict"][prev_target_name]["matched"] = True
729
  channel_info["inputDict"][selected]["assigned"] = True
730
+ print(prev_target_name, '<-', selected)
731
+
732
+ # ------------------------update information for the new round-------------------------
733
+ stage1_info["fillingCount"] += 1
734
 
735
+ # update the list of unassignedInputs to exclude the selected in_channel of the previous round
736
+ stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
 
737
  if channel_info["inputDict"][channel]["assigned"]==False]
738
+ # update the progress indication label
739
+ target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
740
+ radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
741
 
742
+ app_info["stage1"] = stage1_info
743
+ # determine which button to display
744
+ if len(stage1_info["unassignedInputs"])==1 or stage1_info["fillingCount"]==stage1_info["totalFillingNum"]:
745
+ return {app_info_json : app_info,
 
746
  channel_info_json : channel_info,
747
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
748
  value=[], label=radio_label),
749
  step2_btn : gr.Button(visible=False),
750
  next_btn : gr.Button("Next step", visible=True)}
751
  else:
752
+ return {app_info_json : app_info,
753
  channel_info_json : channel_info,
754
+ radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
755
  value=[], label=radio_label)}
756
 
757
  step2_btn.click(
758
  fn = update_radio,
759
+ inputs = [app_info_json, channel_info_json, radio_group],
760
+ outputs = [app_info_json, channel_info_json, radio_group, step2_btn, next_btn]
 
761
  ).success(
762
  fn = None,
763
  js = update_js,
764
+ inputs = [app_info_json, channel_info_json],
765
  outputs = []
766
  )
767
 
 
772
  )
773
 
774
 
775
+ # +========================================================================================+
776
+ # | stage1-3 |
777
+ # +========================================================================================+
778
+ def fill_value(app_info, channel_info, fillmode):
779
+ stage1_info = app_info["stage1"]
780
 
781
+ if fillmode == "zero":
782
+ stage1_info["state"] = "finished"
783
  gr.Info('The mapping process has been finished.')
784
 
785
+ app_info["stage1"] = stage1_info
786
+ return {app_info_json : app_info,
787
  desc_md : gr.Markdown(visible=False),
788
+ in_fillmode : gr.Dropdown(visible=False),
789
  fillmode_btn : gr.Button(visible=False),
790
  run_btn : gr.Button(interactive=True)}
791
 
792
+ elif fillmode == "mean":
793
+ md = """
794
+ ### Step3: Fill the remaining template channels
795
+ (...)
796
+ """
797
 
798
+ # find the 4-NN in_channels for each of the unmatched tpl_channels
799
+ new_idx = find_neighbors(channel_info, stage1_info["missingTemplates"], stage1_info["newOrder"])
800
+
801
+ stage1_info.update({
802
+ "state" : "step3-selecting",
803
+ "newOrder" : new_idx,
804
+ "fillingCount" : 1,
805
+ "totalFillingNum" : len(stage1_info["missingTemplates"])
806
+ })
807
 
808
+ # initialize the progress indicator label
809
+ target_name = stage1_info["missingTemplates"][0]
810
+ target_idx = channel_info["templateDict"][target_name]["index"]
811
+ chkbox_value = stage1_info["newOrder"][target_idx]
812
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
813
+ chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
814
 
815
+ app_info["stage1"] = stage1_info
816
+ # determine which button to display
817
+ if stage1_info["totalFillingNum"] == 1:
818
+ return {app_info_json : app_info,
819
+ desc_md : gr.Markdown(md),
820
+ in_fillmode : gr.Dropdown(visible=False),
821
  fillmode_btn : gr.Button(visible=False),
822
  chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
823
  value=chkbox_value, label=chkbox_label, visible=True),
824
  next_btn : gr.Button(visible=True)}
825
  else:
826
+ return {app_info_json : app_info,
827
+ desc_md : gr.Markdown(md),
828
+ in_fillmode : gr.Dropdown(visible=False),
829
  fillmode_btn : gr.Button(visible=False),
830
  chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
831
  value=chkbox_value, label=chkbox_label, visible=True),
832
  step3_btn : gr.Button(visible=True)}
833
 
834
+ def update_chkbox(app_info, channel_info, selected):
835
+ stage1_info = app_info["stage1"]
 
 
 
 
 
 
 
 
 
836
 
837
+ # ----------------------store information before the button click----------------------
 
838
 
839
+ # check if the user has not unchecked all in_channel checkboxes
840
+ if selected != []:
841
+ prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
842
+ prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
843
+
844
+ # store the indices of the selected in_channels
845
+ selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
846
+ stage1_info["newOrder"][prev_target_idx] = selected_indices
847
+ #print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
848
+
849
+ # ------------------------update information for the new round-------------------------
850
+ stage1_info["fillingCount"] += 1
851
 
852
+ # update the progress indication label
853
+ target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
854
+ target_idx = channel_info["templateDict"][target_name]["index"]
855
+ chkbox_value = stage1_info["newOrder"][target_idx]
856
  chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
857
+ chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
858
 
859
+ app_info["stage1"] = stage1_info
860
+ # determine which button to display
861
+ if stage1_info["fillingCount"] == stage1_info["totalFillingNum"]:
862
+ return {app_info_json : app_info,
863
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
864
  step3_btn : gr.Button(visible=False),
865
  next_btn : gr.Button("Submit", visible=True)}
866
  else:
867
+ return {app_info_json : app_info,
868
  chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
869
 
870
  fillmode_btn.click(
871
  fn = fill_value,
872
+ inputs = [app_info_json, channel_info_json, in_fillmode],
873
+ outputs = [app_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn, next_btn, run_btn]
874
  ).success(
875
  fn = None,
876
  js = init_js,
877
+ inputs = [app_info_json, channel_info_json],
878
  outputs = []
879
  )
880
 
881
  step3_btn.click(
882
  fn = update_chkbox,
883
+ inputs = [app_info_json, channel_info_json, chkbox_group],
884
+ outputs = [app_info_json, chkbox_group, step3_btn, next_btn]
 
885
  ).success(
886
  fn = None,
887
  js = update_js,
888
+ inputs = [app_info_json, channel_info_json],
889
  outputs = []
890
  )
891
 
 
 
 
 
 
 
892
 
893
+ # +========================================================================================+
894
+ # | stage2: decode data |
895
+ # +========================================================================================+
896
+ def reset_run(app_info, channel_info, modelname):
897
+ stage1_info = app_info["stage1"]
898
+ stage2_info = app_info["stage2"]
899
 
900
+ # delete the previous folder of stage2 if it exists
901
+ filepath = stage2_info["filepath"]
902
+ utils.dataDelete(filepath)
903
+ # establish a new folder for stage2
904
+ new_filepath = app_info["rootFilepath"]+"stage2_"+str(random.randint(1,10000))+"/"
905
+ os.mkdir(new_filepath)
906
+ # generate the output filename
907
+ filename = stage1_info["filenames"]["input_data"]
908
+ filename = os.path.basename(str(filename))
909
+ new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
910
 
911
+ # reset inputChannel.assigned back to the state after stage1
912
+ for channel in stage1_info["unassignedInputs"]:
913
+ channel_info["inputDict"][channel]["assigned"] = False
914
+ # calculate how many times the model needs to be run
915
+ unassigned_num = len(stage1_info["unassignedInputs"])
916
+ batch_num = math.ceil(unassigned_num/30) + 1
917
 
918
+ app_info.update({
919
+ #"currentStage" : "stage2",
920
+ "stage2" : {
921
+ "filepath" : new_filepath,
922
+ "filenames" : {
923
+ "output_data" : new_filepath + new_filename
924
+ },
925
+ #"state" : "initializing",
926
+ "totalBatchNum" : batch_num,
927
+ "newOrder" : [[]]*30,
928
+ "unassignedInputs" : stage1_info["unassignedInputs"]
929
+ }
930
  })
931
+ return {app_info_json : app_info,
932
  channel_info_json : channel_info,
933
+ #run_btn : gr.Button(interactive=False),
934
  batch_md : gr.Markdown(visible=False),
935
+ out_data_file : gr.File(visible=False)}
936
 
937
+ def run_model(app_info, channel_info, modelname):
938
+ stage1_info = app_info["stage1"]
939
+ stage2_info = app_info["stage2"]
 
940
 
941
+ filepath = stage2_info["filepath"]
942
+ samplerate = app_info["sampleRate"]
943
+ filename = stage1_info["filenames"]["input_data"]
944
+ new_filename = stage2_info["filenames"]["output_data"]
945
+
946
+ # set a flag to record whether the user has clicked the map_btn or run_btn while running the model
947
+ break_flag = False
948
+
949
+ # run the model multiple times until all in_channels are reconstructed
950
+ for i in range(stage2_info["totalBatchNum"]):
951
+ # establish a temp folder
952
+ try:
953
+ os.mkdir(filepath+"temp_data/")
954
+ #except FileExistsError:
955
+ #utils.dataDelete(filepath+"temp_data/")
956
+ #os.mkdir(filepath+"temp_data/")
957
+ except FileNotFoundError:
958
+ #print('break1!!')
959
+ break_flag = True
960
+ break
961
+ except OSError as e:
962
+ print(e)
963
 
964
+ # update the running status
965
+ md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
966
+ yield {batch_md : gr.Markdown(md, visible=True)}
 
967
 
968
+ if i == 0:
969
+ new_idx = stage1_info["newOrder"]
970
+ else:
971
+ # if this is not the first time running the model, the in_channels that have
972
+ # not been reconstructed yet will be optimally mapped to the template.
973
+ stage2_info, channel_info = mapping_stage2(stage2_info, channel_info)
974
+ new_idx = stage2_info["newOrder"]
975
+ #print('unassigned num:', len(stage2_info["unassignedInputs"]))
976
 
977
+ # ----------------------------------------------------------------------
978
+ try:
979
+ # step1: Reorder input data
980
+ reorder_input_data(new_idx, filename, filepath+"temp_data/mapped.csv")
981
+ # step2: Data preprocessing
982
+ total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
983
+ # step3: Signal reconstruction
984
+ utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
985
+ # step4: Restore original order
986
+ restore_original_order(channel_info, i, new_idx, filepath+"temp_data/denoised.csv", new_filename)
987
+ except FileNotFoundError:
988
+ #print('break2!!')
989
+ break_flag = True
990
+ break
991
+ # ----------------------------------------------------------------------
992
+ utils.dataDelete(filepath+"temp_data/")
993
+ app_info["stage2"] = stage2_info
994
+
995
+ if break_flag == True:
996
+ yield {batch_md : gr.Markdown(visible=False)}
997
+ else:
998
+ yield {#run_btn : gr.Button(interactive=True),
999
+ batch_md : gr.Markdown(visible=False),
1000
+ out_data_file : gr.File(new_filename, visible=True)}
1001
 
1002
  run_btn.click(
1003
  fn = reset_run,
1004
+ inputs = [app_info_json, channel_info_json, in_modelname],
1005
+ outputs = [app_info_json, channel_info_json, run_btn, batch_md, out_data_file]
1006
 
1007
  ).success(
1008
  fn = run_model,
1009
+ inputs = [app_info_json, channel_info_json, in_modelname],
1010
+ outputs = [run_btn, batch_md, out_data_file]
1011
  )
1012
 
1013
  if __name__ == "__main__":
1014
  demo.launch()
1015
+
1016
+
1017
+ """
1018
+ --------
1019
+ |----(inputname).csv
1020
+ |----session_data
1021
+ |----stage1
1022
+ |----input_montage.png
1023
+ |----mapped_montage.png
1024
+ |----stage2_(...)
1025
+ |----temp_data
1026
+ |----mapped.csv
1027
+ |----denoised.csv
1028
+ |----temp2
1029
+ |...
1030
+ |----(outputname).csv
1031
+ """
1032
+
channel_mapping.py CHANGED
@@ -10,13 +10,10 @@ from scipy.interpolate import Rbf
10
  from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
- def reorder_to_template(app_state, filename):
14
- old_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
15
  old_data = utils.read_train_data(filename) # original raw data
16
- new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
- new_filename = app_state["filepath"]+'mapped.csv'
18
- #print('new order 1:', app_state["stage1NewOrder"])
19
- #print('new order 2:', app_state["stage2NewOrder"])
20
 
21
  zero_arr = np.zeros((1, old_data.shape[1]))
22
  old_data = np.concatenate((old_data, zero_arr), axis=0)
@@ -31,25 +28,25 @@ def reorder_to_template(app_state, filename):
31
  tmp_data = [old_data[j, :] for j in idx_set]
32
  new_data[i, :] = np.mean(tmp_data, axis=0)
33
 
34
- print('old.shape, new.shape: ', old_data.shape, new_data.shape)
 
35
  utils.save_data(new_data, new_filename)
36
  return
37
 
38
- def reorder_to_origin(app_state, channel_info, new_filename):
39
- filename = app_state["filepath"]+'denoised.csv'
40
- old_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
41
  old_data = utils.read_train_data(filename) # denoised data
42
  template_order = channel_info["templateOrder"]
 
43
 
44
- if app_state["runningState"] == "stage1":
45
- new_data = np.zeros((len(channel_info["inputOrder"]), old_data.shape[1]))
46
  else:
47
  new_data = utils.read_train_data(new_filename)
48
 
49
  for i, channel in enumerate(template_order):
50
  idx_set = old_idx[i]
51
 
52
- # ignore if this channel was filled with fill_mode ('mean' or 'zero')
53
  if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
54
  new_data[idx_set[0], :] = old_data[i, :]
55
 
@@ -97,9 +94,9 @@ def align_coords(channel_info, template_montage, input_montage):
97
 
98
 
99
  # --------------------------------2-D------------------------------------
100
- # (for the indication of missing template channel's position when fill_mode:'mean')
101
 
102
- fig = [template_montage.plot(), input_montage.plot()]
103
  ax = [fig[0].axes[0], fig[1].axes[0]]
104
 
105
  # get the original coords
@@ -154,48 +151,38 @@ def align_coords(channel_info, template_montage, input_montage):
154
  })
155
  return channel_info
156
 
157
- def find_neighbors(app_state, channel_info):
158
- new_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
159
  template_dict = channel_info["templateDict"]
160
  input_dict = channel_info["inputDict"]
161
- template_order = channel_info["templateOrder"]
162
  input_order = channel_info["inputOrder"]
163
- missing_channels = app_state["missingTemplates"]
164
- if missing_channels == []:
165
- return app_state # change nothing
166
 
167
-
168
- in_coords = [input_dict[channel]["coord_3d"] for channel in input_order]
169
- in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
170
 
171
  # use KNN to choose k nearest channels
172
  k = 4 if len(input_order)>4 else len(input_order)
173
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
174
- knn.fit(in_coords)
175
 
176
- for channel in missing_channels:
177
- distances, indices = knn.kneighbors(np.array(template_dict[channel]["coord_3d"]).reshape(1,-1))
178
- selected = [input_order[i] for i in indices[0]]
179
  #print(channel, ':', selected)
180
 
181
  idx = template_dict[channel]["index"]
182
  new_idx[idx] = indices[0].tolist()
183
-
184
- if app_state["runningState"] == "stage1":
185
- app_state["stage1NewOrder"] = new_idx
186
- else:
187
- app_state["stage2NewOrder"] = new_idx
188
 
189
- return app_state
190
 
191
- def mapping_stage1(app_state, channel_info, loc_file):
192
- yield app_state, channel_info, gr.Markdown("Mapping...", visible=True) # app_state, channel_info???
193
  second1 = time.time()
194
 
 
195
  template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
196
  template_order = template_montage.ch_names
197
  input_order = input_montage.ch_names
198
- new_idx = [[]]*30
199
  alias_dict = {
200
  'T3': 'T7',
201
  'T4': 'T8',
@@ -203,21 +190,20 @@ def mapping_stage1(app_state, channel_info, loc_file):
203
  'T6': 'P8'
204
  }
205
 
206
- # match the names of input channels -> template channels
207
  for i, channel in enumerate(template_order):
208
  if channel in alias_dict and alias_dict[channel] in input_dict:
209
- template_montage.rename_channels({channel: alias_dict[channel]})
210
  template_dict[alias_dict[channel]] = template_dict.pop(channel)
211
  channel = alias_dict[channel]
212
-
213
  if channel in input_dict:
214
  new_idx[i] = [input_dict[channel]["index"]]
215
  template_dict[channel]["matched"] = True
216
  input_dict[channel]["assigned"] = True
217
 
218
- # update names
219
  template_order = template_montage.ch_names
220
- input_order = input_montage.ch_names
221
 
222
  channel_info.update({
223
  "templateDict" : template_dict,
@@ -225,10 +211,9 @@ def mapping_stage1(app_state, channel_info, loc_file):
225
  "templateOrder" : template_order,
226
  "inputOrder" : input_order
227
  })
228
- app_state.update({
229
- "stage1NewOrder" : new_idx,
230
- "runningState" : "stage1",
231
- "stage1UnassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
232
  "missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
233
  })
234
 
@@ -237,19 +222,16 @@ def mapping_stage1(app_state, channel_info, loc_file):
237
 
238
  second2 = time.time()
239
  print('Mapping (stage1) finished in',second2 - second1,'s.')
240
- yield app_state, channel_info, gr.Markdown("", visible=False)
241
 
242
- def mapping_stage2(app_state, channel_info):
243
  second1 = time.time()
244
 
245
  template_dict = channel_info["templateDict"]
246
  input_dict = channel_info["inputDict"]
247
  template_order = channel_info["templateOrder"]
248
  input_order = channel_info["inputOrder"]
249
- unassigned = app_state["stage2UnassignedInputs"]
250
- if unassigned == []:
251
- app_state["runningState"] = "finished"
252
- return app_state, channel_info
253
 
254
  tpl_coords = np.array([template_dict[channel]["coord_3d"] for channel in template_order])
255
  unassigned_coords = np.array([input_dict[channel]["coord_3d"] for channel in unassigned])
@@ -274,31 +256,29 @@ def mapping_stage2(app_state, channel_info):
274
  new_idx = [[]]*30
275
  for i in range(30):
276
  if col_idx[i] < len(unassigned): # filter out dummy channels
277
- print(f'({row_idx[i]}, {col_idx[i]})')
278
-
279
  tpl_channel = template_order[row_idx[i]]
280
  in_channel = unassigned[col_idx[i]]
281
  template_dict[tpl_channel]["matched"] = True
282
  input_dict[in_channel]["assigned"] = True
283
  new_idx[row_idx[i]] = [input_dict[in_channel]["index"]]
284
 
285
- print(template_order[row_idx[i]], '<-', unassigned[col_idx[i]])
286
 
 
 
 
 
 
 
 
 
 
287
  channel_info.update({
288
  "templateDict" : template_dict,
289
  "inputDict" : input_dict
290
  })
291
- app_state.update({
292
- "stage2NewOrder" : new_idx,
293
- "runningState" : "stage2",
294
- "stage2UnassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
295
- "missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
296
- })
297
-
298
- # fill the missing_channels
299
- app_state = find_neighbors(app_state, channel_info)
300
 
301
  second2 = time.time()
302
- print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {second2 - second1}s.')
303
- return app_state, channel_info
304
 
 
10
  from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
+ def reorder_input_data(old_idx, filename, new_filename):
 
14
  old_data = utils.read_train_data(filename) # original raw data
15
+ new_data = np.zeros((30, old_data.shape[1])) # to store reordered raw data
16
+ print('new index order:', old_idx)
 
 
17
 
18
  zero_arr = np.zeros((1, old_data.shape[1]))
19
  old_data = np.concatenate((old_data, zero_arr), axis=0)
 
28
  tmp_data = [old_data[j, :] for j in idx_set]
29
  new_data[i, :] = np.mean(tmp_data, axis=0)
30
 
31
+ old_shape = (old_data.shape[0]-1, old_data.shape[1])
32
+ print('old.shape, new.shape: ', old_shape, new_data.shape)
33
  utils.save_data(new_data, new_filename)
34
  return
35
 
36
+ def restore_original_order(channel_info, cnt, old_idx, filename, new_filename):
 
 
37
  old_data = utils.read_train_data(filename) # denoised data
38
  template_order = channel_info["templateOrder"]
39
+ input_order = channel_info["inputOrder"]
40
 
41
+ if cnt == 0:
42
+ new_data = np.zeros((len(input_order), old_data.shape[1]))
43
  else:
44
  new_data = utils.read_train_data(new_filename)
45
 
46
  for i, channel in enumerate(template_order):
47
  idx_set = old_idx[i]
48
 
49
+ # ignore if this channel was filled with fillmode ('mean' or 'zero')
50
  if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
51
  new_data[idx_set[0], :] = old_data[i, :]
52
 
 
94
 
95
 
96
  # --------------------------------2-D------------------------------------
97
+ # (for the indicate the location missing template channel's position when fill_mode:'mean')
98
 
99
+ fig = [template_montage.plot(), input_montage.plot()]
100
  ax = [fig[0].axes[0], fig[1].axes[0]]
101
 
102
  # get the original coords
 
151
  })
152
  return channel_info
153
 
154
+ def find_neighbors(channel_info, missing_channels, new_idx):
 
155
  template_dict = channel_info["templateDict"]
156
  input_dict = channel_info["inputDict"]
 
157
  input_order = channel_info["inputOrder"]
 
 
 
158
 
159
+ all_in = [np.array(input_dict[channel]["coord_3d"]) for channel in input_order]
160
+ missing_tpl = [np.array(template_dict[channel]["coord_3d"]) for channel in missing_channels]
 
161
 
162
  # use KNN to choose k nearest channels
163
  k = 4 if len(input_order)>4 else len(input_order)
164
  knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
165
+ knn.fit(all_in)
166
 
167
+ for i, channel in enumerate(missing_channels):
168
+ distances, indices = knn.kneighbors(missing_tpl[i].reshape(1,-1))
169
+ #selected = [input_order[j] for j in indices[0]]
170
  #print(channel, ':', selected)
171
 
172
  idx = template_dict[channel]["index"]
173
  new_idx[idx] = indices[0].tolist()
 
 
 
 
 
174
 
175
+ return new_idx
176
 
177
+ def mapping_stage1(app_info, channel_info):
178
+ yield app_info, channel_info, gr.Markdown("Mapping...", visible=True)
179
  second1 = time.time()
180
 
181
+ loc_file = app_info["stage1"]["filenames"]["input_loc"]
182
  template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
183
  template_order = template_montage.ch_names
184
  input_order = input_montage.ch_names
185
+ new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channls
186
  alias_dict = {
187
  'T3': 'T7',
188
  'T4': 'T8',
 
190
  'T6': 'P8'
191
  }
192
 
193
+ # match the names of input channels and template channels
194
  for i, channel in enumerate(template_order):
195
  if channel in alias_dict and alias_dict[channel] in input_dict:
196
+ template_montage.rename_channels({channel: alias_dict[channel]}) # rename the current tpl_channel
197
  template_dict[alias_dict[channel]] = template_dict.pop(channel)
198
  channel = alias_dict[channel]
199
+
200
  if channel in input_dict:
201
  new_idx[i] = [input_dict[channel]["index"]]
202
  template_dict[channel]["matched"] = True
203
  input_dict[channel]["assigned"] = True
204
 
205
+ # update the names
206
  template_order = template_montage.ch_names
 
207
 
208
  channel_info.update({
209
  "templateDict" : template_dict,
 
211
  "templateOrder" : template_order,
212
  "inputOrder" : input_order
213
  })
214
+ app_info["stage1"].update({
215
+ "newOrder" : new_idx,
216
+ "unassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
 
217
  "missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
218
  })
219
 
 
222
 
223
  second2 = time.time()
224
  print('Mapping (stage1) finished in',second2 - second1,'s.')
225
+ yield app_info, channel_info, gr.Markdown("", visible=False)
226
 
227
+ def mapping_stage2(stage2_info, channel_info):
228
  second1 = time.time()
229
 
230
  template_dict = channel_info["templateDict"]
231
  input_dict = channel_info["inputDict"]
232
  template_order = channel_info["templateOrder"]
233
  input_order = channel_info["inputOrder"]
234
+ unassigned = stage2_info["unassignedInputs"]
 
 
 
235
 
236
  tpl_coords = np.array([template_dict[channel]["coord_3d"] for channel in template_order])
237
  unassigned_coords = np.array([input_dict[channel]["coord_3d"] for channel in unassigned])
 
256
  new_idx = [[]]*30
257
  for i in range(30):
258
  if col_idx[i] < len(unassigned): # filter out dummy channels
 
 
259
  tpl_channel = template_order[row_idx[i]]
260
  in_channel = unassigned[col_idx[i]]
261
  template_dict[tpl_channel]["matched"] = True
262
  input_dict[in_channel]["assigned"] = True
263
  new_idx[row_idx[i]] = [input_dict[in_channel]["index"]]
264
 
265
+ print(f'{template_order[row_idx[i]]}({row_idx[i]}) <- {unassigned[col_idx[i]]}({col_idx[i]})')
266
 
267
+ # fill the missing_channels
268
+ missing_channels = [channel for channel in template_order if template_dict[channel]["matched"]==False]
269
+ if missing_channels != []:
270
+ new_idx = find_neighbors(channel_info, missing_channels, new_idx)
271
+
272
+ stage2_info.update({
273
+ "newOrder" : new_idx,
274
+ "unassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False]
275
+ })
276
  channel_info.update({
277
  "templateDict" : template_dict,
278
  "inputDict" : input_dict
279
  })
 
 
 
 
 
 
 
 
 
280
 
281
  second2 = time.time()
282
+ print("The mapping process has been finished in", second2 - second1, "s.")
283
+ return stage2_info, channel_info
284
 
utils.py CHANGED
@@ -98,7 +98,7 @@ def cut_data(filepath, raw_data):
98
  total = int(len(raw_data[0]) / 1024)
99
  for i in range(total):
100
  table = raw_data[:, i * 1024:(i + 1) * 1024]
101
- filename = filepath + '/temp2/' + str(i) + '.csv'
102
  with open(filename, 'w', newline='') as csvfile:
103
  writer = csv.writer(csvfile)
104
  writer.writerows(table)
@@ -213,10 +213,10 @@ def decode_data(data, std_num, mode=5):
213
  def preprocessing(filepath, filename, samplerate):
214
  # establish temp folder
215
  try:
216
- os.mkdir(filepath+"/temp2/")
217
  except OSError as e:
218
- dataDelete(filepath+"/temp2/")
219
- os.mkdir(filepath+"/temp2/")
220
  print(e)
221
 
222
  # read data
@@ -239,7 +239,7 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
239
  # -------------------decode_data---------------------------
240
  second1 = time.time()
241
  for i in range(total):
242
- file_name = filepath + '/temp2/{}.csv'.format(str(i))
243
  data_noise = read_train_data(file_name)
244
 
245
  std = np.std(data_noise)
@@ -251,17 +251,18 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
251
  d_data = decode_data(data_noise, std, model_name)
252
  d_data = d_data[0]
253
 
254
- outputname = filepath + '/temp2/output{}.csv'.format(str(i))
255
  save_data(d_data, outputname)
256
 
257
  # --------------------glue_data----------------------------
258
- signal = glue_data(filepath+"/temp2/", total, filepath+outputfile)
259
  #print(signal.shape)
260
  # -------------------delete_data---------------------------
261
- dataDelete(filepath+"/temp2/")
262
  # --------------------resample-----------------------------
263
  signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
264
  #print(signal.shape)
 
265
  save_data(signal, filepath+outputfile)
266
  second2 = time.time()
267
 
 
98
  total = int(len(raw_data[0]) / 1024)
99
  for i in range(total):
100
  table = raw_data[:, i * 1024:(i + 1) * 1024]
101
+ filename = filepath + 'temp2/' + str(i) + '.csv'
102
  with open(filename, 'w', newline='') as csvfile:
103
  writer = csv.writer(csvfile)
104
  writer.writerows(table)
 
213
  def preprocessing(filepath, filename, samplerate):
214
  # establish temp folder
215
  try:
216
+ os.mkdir(filepath+"temp2/")
217
  except OSError as e:
218
+ dataDelete(filepath+"temp2/")
219
+ os.mkdir(filepath+"temp2/")
220
  print(e)
221
 
222
  # read data
 
239
  # -------------------decode_data---------------------------
240
  second1 = time.time()
241
  for i in range(total):
242
+ file_name = filepath + 'temp2/{}.csv'.format(str(i))
243
  data_noise = read_train_data(file_name)
244
 
245
  std = np.std(data_noise)
 
251
  d_data = decode_data(data_noise, std, model_name)
252
  d_data = d_data[0]
253
 
254
+ outputname = filepath + 'temp2/output{}.csv'.format(str(i))
255
  save_data(d_data, outputname)
256
 
257
  # --------------------glue_data----------------------------
258
+ signal = glue_data(filepath+"temp2/", total, filepath+outputfile)
259
  #print(signal.shape)
260
  # -------------------delete_data---------------------------
261
+ dataDelete(filepath+"temp2/")
262
  # --------------------resample-----------------------------
263
  signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
264
  #print(signal.shape)
265
+ # --------------------save_data----------------------------
266
  save_data(signal, filepath+outputfile)
267
  second2 = time.time()
268