Spaces:
Sleeping
Sleeping
Commit
·
3566452
1
Parent(s):
356e41c
update v4
Browse files- app.py +383 -426
- channel_mapping.py → app_utils.py +206 -122
app.py
CHANGED
@@ -1,16 +1,8 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
import os
|
4 |
import random
|
5 |
-
import math
|
6 |
-
import numpy as np
|
7 |
-
import matplotlib.pyplot as plt
|
8 |
-
import mne
|
9 |
-
from mne.channels import read_custom_montage
|
10 |
-
|
11 |
-
import utils
|
12 |
-
from channel_mapping import mapping_stage1, mapping_stage2, reorder_input_data, restore_original_order, find_neighbors
|
13 |
-
|
14 |
|
15 |
readme = """
|
16 |
|
@@ -35,8 +27,14 @@ Your unmatched channels, previously highlighted in red, will be shown on your mo
|
|
35 |
### Step3: Filling Remaining Template Channels
|
36 |
To run the models successfully, we need to ensure that all 30 template channels are filled. In this step, you are required to select one of the methods provided below to fill the remaining empty template channels:
|
37 |
- **Mean** method: Each empty template channel is filled with the average value of data from the nearest input channels. By default, the 4 closest input channels (determined after aligning your montage to the template's scale using TPS) are selected for this averaging process. On the interface, you will see checkboxes displayed above each of your channel. The 4 nearest channels are pre-selected by default for each empty template channels, but you can modify these selections as needed. If you uncheck all the checkboxes for a particular template channel, it will be filled with zeros.
|
38 |
-
- **Zero** method: All empty template channels are filled with zeros.
|
39 |
-
Choose
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
## 2. Decode data
|
42 |
In this phase, you can select which model to use for denoising your EEG data. Detailed information about the models can be found in the other tabs.
|
@@ -64,7 +62,7 @@ init_js = """
|
|
64 |
selector = "#radio-group > div:nth-of-type(2)";
|
65 |
//classname = "radio";
|
66 |
attribute = "value";
|
67 |
-
}else if(stage1_info.state == "step3-selecting"){
|
68 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
69 |
//classname = "chkbox";
|
70 |
attribute = "name";
|
@@ -78,7 +76,7 @@ init_js = """
|
|
78 |
aspect-ratio: 1;
|
79 |
//width: 560px;
|
80 |
//height: 560px;
|
81 |
-
background: url("file=${stage1_info.
|
82 |
background-size: contain;
|
83 |
|
84 |
`;
|
@@ -165,7 +163,7 @@ update_js = """
|
|
165 |
item.className = "";
|
166 |
item.querySelector(":scope > span").innerText = "";
|
167 |
});
|
168 |
-
}else if(stage1_info.state == "step3-selecting"){
|
169 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
170 |
}else return;
|
171 |
|
@@ -247,15 +245,14 @@ with gr.Blocks() as demo:
|
|
247 |
map_btn = gr.Button("Mapping", interactive=False, scale=1)
|
248 |
|
249 |
# ------------------------mapping------------------------
|
250 |
-
# description for stage1-123
|
251 |
desc_md = gr.Markdown(visible=False)
|
252 |
-
#
|
253 |
with gr.Row():
|
254 |
tpl_img = gr.Image("./template_montage.png", label="Template channels", visible=False)
|
255 |
mapped_img = gr.Image(label="Input channels", visible=False)
|
256 |
-
#
|
257 |
radio_group = gr.Radio(elem_id="radio-group", visible=False)
|
258 |
-
#
|
259 |
with gr.Row():
|
260 |
in_fillmode = gr.Dropdown(choices=["mean", "zero"],
|
261 |
value="mean",
|
@@ -264,6 +261,8 @@ with gr.Blocks() as demo:
|
|
264 |
scale=2)
|
265 |
fillmode_btn = gr.Button("OK", visible=False, scale=1)
|
266 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
|
|
|
|
267 |
|
268 |
with gr.Row():
|
269 |
clear_btn = gr.Button("Clear", visible=False)
|
@@ -280,7 +279,9 @@ with gr.Blocks() as demo:
|
|
280 |
("ART", "EEGART"),
|
281 |
("IC-U-Net", "ICUNet"),
|
282 |
("IC-U-Net++", "UNetpp"),
|
283 |
-
("IC-U-Net-Attn", "AttUnet")
|
|
|
|
|
284 |
value="EEGART",
|
285 |
label="Model",
|
286 |
scale=2)
|
@@ -303,7 +304,7 @@ with gr.Blocks() as demo:
|
|
303 |
with gr.Tab("README"):
|
304 |
gr.Markdown(readme)
|
305 |
|
306 |
-
#demo.load(js=
|
307 |
|
308 |
# verify that all required inputs have been provided
|
309 |
@gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
|
@@ -316,317 +317,233 @@ with gr.Blocks() as demo:
|
|
316 |
|
317 |
|
318 |
# +========================================================================================+
|
319 |
-
# |
|
320 |
# +========================================================================================+
|
321 |
def reset_all(in_data, in_loc, samplerate):
|
322 |
# establish a new folder for the current session
|
323 |
-
|
324 |
try:
|
325 |
-
os.mkdir(
|
326 |
except OSError as e:
|
327 |
-
utils.dataDelete(
|
328 |
-
os.mkdir(
|
329 |
print(e)
|
330 |
# establish new folders for stage1 and stage2
|
331 |
-
os.mkdir(
|
332 |
-
os.mkdir(
|
333 |
|
334 |
# initialize channel_info, app_info
|
335 |
channel_info = {}
|
336 |
app_info = {
|
337 |
-
"
|
338 |
"sampleRate" : int(samplerate),
|
339 |
-
#"currentStage" : "stage1",
|
340 |
"stage1" : {
|
341 |
-
"
|
342 |
-
"
|
343 |
"input_data" : in_data,
|
344 |
"input_loc" : in_loc,
|
345 |
"input_montage" : "",
|
346 |
"mapped_montage" : ""
|
347 |
},
|
348 |
-
"state" :
|
349 |
"fillingCount" : None,
|
350 |
"totalFillingNum" : None,
|
351 |
-
"newOrder" : None,
|
352 |
"unassignedInputs" : None,
|
353 |
-
"missingTemplates" : None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
354 |
},
|
355 |
"stage2" : {
|
356 |
-
"
|
357 |
-
"
|
358 |
"output_data" : ""
|
359 |
},
|
360 |
-
|
361 |
-
"totalBatchNum" : None,
|
362 |
-
"newOrder" : None,
|
363 |
-
"unassignedInputs" : None
|
364 |
}
|
365 |
}
|
366 |
# reset layout
|
367 |
return {app_info_json : app_info,
|
368 |
channel_info_json : channel_info,
|
369 |
-
#
|
370 |
map_btn : gr.Button(interactive=False),
|
371 |
desc_md : gr.Markdown(visible=False),
|
|
|
372 |
tpl_img : gr.Image(visible=False),
|
373 |
mapped_img : gr.Image(value=None, visible=False),
|
374 |
radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
|
375 |
-
in_fillmode : gr.Dropdown(value="mean", visible=False),
|
376 |
-
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
377 |
-
fillmode_btn : gr.Button(visible=False),
|
378 |
clear_btn : gr.Button(visible=False),
|
379 |
step2_btn : gr.Button(visible=False),
|
|
|
|
|
|
|
380 |
step3_btn : gr.Button(visible=False),
|
381 |
-
|
382 |
-
#
|
383 |
run_btn : gr.Button(interactive=False),
|
384 |
batch_md : gr.Markdown(visible=False),
|
385 |
out_data_file : gr.File(visible=False)}
|
386 |
|
387 |
|
388 |
# +========================================================================================+
|
389 |
-
# |
|
390 |
# +========================================================================================+
|
391 |
-
def
|
392 |
-
|
393 |
-
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
394 |
-
tpl_dict = channel_info["templateDict"]
|
395 |
-
in_dict = channel_info["inputDict"]
|
396 |
-
tpl_order = channel_info["templateOrder"]
|
397 |
-
in_order = channel_info["inputOrder"]
|
398 |
-
|
399 |
-
# get template and input's 2d coords
|
400 |
-
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
401 |
-
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
402 |
-
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
403 |
-
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
404 |
-
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
405 |
-
in_coords = np.vstack((in_x, in_y)).T
|
406 |
-
|
407 |
-
# get template's head figure
|
408 |
-
tpl_fig = tpl_montage.plot()
|
409 |
-
tpl_ax = tpl_fig.axes[0]
|
410 |
-
lines = tpl_ax.lines
|
411 |
-
head_lines = []
|
412 |
-
for line in lines:
|
413 |
-
x, y = line.get_data()
|
414 |
-
head_lines.append((x,y))
|
415 |
-
plt.close()
|
416 |
-
|
417 |
-
# -------------------------plot input montage------------------------------
|
418 |
-
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
419 |
-
ax = fig.add_subplot(111)
|
420 |
-
fig.tight_layout()
|
421 |
-
ax.set_aspect('equal')
|
422 |
-
ax.axis('off')
|
423 |
-
|
424 |
-
# plot template's head
|
425 |
-
for x, y in head_lines:
|
426 |
-
ax.plot(x, y, color='black', linewidth=1.0)
|
427 |
-
# plot input channels on it
|
428 |
-
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
429 |
-
for i, channel in enumerate(in_order):
|
430 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
431 |
-
# save input_montage
|
432 |
-
fig.savefig(filename1)
|
433 |
-
|
434 |
-
# ---------------------------add indications-------------------------------
|
435 |
-
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
436 |
-
|
437 |
-
# plot unmatched input channels in red
|
438 |
-
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
439 |
-
for i in indices:
|
440 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
|
441 |
-
# save mapped_montage
|
442 |
-
fig.savefig(filename2)
|
443 |
-
|
444 |
-
# -------------------------------------------------------------------------
|
445 |
-
# store the template and input channels' display position (in px).
|
446 |
-
tpl_coords = ax.transData.transform(tpl_coords)
|
447 |
-
in_coords = ax.transData.transform(in_coords)
|
448 |
-
plt.close()
|
449 |
-
|
450 |
-
for i, channel in enumerate(tpl_order):
|
451 |
-
css_left = (tpl_coords[i,0]-11)/6.4
|
452 |
-
css_bottom = (tpl_coords[i,1]-7)/6.4
|
453 |
-
tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
454 |
-
|
455 |
-
for i, channel in enumerate(in_order):
|
456 |
-
css_left = (in_coords[i,0]-11)/6.4
|
457 |
-
css_bottom = (in_coords[i,1]-7)/6.4
|
458 |
-
in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
459 |
-
|
460 |
-
channel_info.update({
|
461 |
-
"templateDict" : tpl_dict,
|
462 |
-
"inputDict" : in_dict
|
463 |
-
})
|
464 |
-
return channel_info
|
465 |
-
|
466 |
-
def mapping_result(app_info, channel_info):
|
467 |
stage1_info = app_info["stage1"]
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
# -------------------------determine the next step--------------------------
|
480 |
-
|
481 |
-
in_num = len(channel_info["inputOrder"])
|
482 |
-
matched_num = 30 - len(stage1_info["missingTemplates"])
|
483 |
-
|
484 |
-
# if the in_channels has all the 30 tpl_channels (in_num>=30)
|
485 |
-
# -> stage2
|
486 |
-
if matched_num == 30:
|
487 |
-
stage1_info["state"] = "finished"
|
488 |
-
gr.Info('The mapping process has been finished.')
|
489 |
-
|
490 |
-
if in_num == 30:
|
491 |
-
md = """
|
492 |
-
---
|
493 |
-
### Step1: Initial Matching and Rescaling
|
494 |
-
Below is the result of mapping your channels to our template channels based on their names.
|
495 |
-
"""
|
496 |
-
else:
|
497 |
-
md = """
|
498 |
-
---
|
499 |
-
### Step1: Initial Matching and Rescaling
|
500 |
-
Below is the result of mapping your channels to our template channels based on their names.
|
501 |
-
- channels highlighted in red are those that do not match any template channels.
|
502 |
-
"""
|
503 |
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
|
|
|
|
|
|
518 |
md = """
|
519 |
---
|
520 |
### Step1: Initial Matching and Rescaling
|
521 |
-
Below is the result of mapping your channels to our template channels based on their names.
|
522 |
-
- channels highlighted in red are those that do not match any template channels.
|
523 |
"""
|
524 |
-
|
525 |
-
# if in_num < 30, but all of them can match to some tpl_channels
|
526 |
-
# -> directly use fillmode to fill the remaining tpl_channels
|
527 |
-
elif in_num == matched_num:
|
528 |
-
stage1_info["state"] = "step3-initializing"
|
529 |
md = """
|
530 |
---
|
531 |
### Step1: Initial Matching and Rescaling
|
532 |
-
Below is the result of mapping your channels to our template channels based on their names.
|
|
|
533 |
"""
|
534 |
|
|
|
535 |
app_info["stage1"] = stage1_info
|
536 |
-
|
537 |
channel_info_json : channel_info,
|
538 |
map_btn : gr.Button(interactive=True),
|
539 |
-
desc_md : gr.Markdown(md
|
540 |
tpl_img : gr.Image(visible=True),
|
541 |
mapped_img : gr.Image(value=filename2, visible=True),
|
542 |
next_btn : gr.Button(visible=True)}
|
543 |
-
|
544 |
-
start_stage1 = map_btn.click(
|
545 |
-
fn = reset_all,
|
546 |
-
inputs = [in_data_file, in_loc_file, in_samplerate],
|
547 |
-
outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, radio_group,
|
548 |
-
in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
|
549 |
-
run_btn, batch_md, out_data_file]
|
550 |
-
).success(
|
551 |
-
fn = mapping_stage1,
|
552 |
-
inputs = [app_info_json, channel_info_json],
|
553 |
-
outputs = [app_info_json, channel_info_json, desc_md]
|
554 |
-
).success(
|
555 |
-
fn = mapping_result,
|
556 |
-
inputs = [app_info_json, channel_info_json],
|
557 |
-
outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, next_btn, run_btn]
|
558 |
-
)
|
559 |
-
|
560 |
-
|
561 |
-
# +========================================================================================+
|
562 |
-
# | manage step transition |
|
563 |
-
# +========================================================================================+
|
564 |
-
def init_next_step(app_info, channel_info, selected_radio, selected_chkbox):
|
565 |
-
stage1_info = app_info["stage1"]
|
566 |
|
567 |
-
#
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
---
|
572 |
-
### Step2: Forwarding Unmatched Channels
|
573 |
-
Select one of your unmatched channels to forward its data to the empty template channel
|
574 |
-
currently indicated in red.
|
575 |
-
"""
|
576 |
|
577 |
-
#
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
585 |
|
586 |
-
|
587 |
-
#
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
-
#
|
626 |
elif stage1_info["state"] == "step2-selecting":
|
627 |
|
628 |
-
#
|
629 |
-
|
630 |
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
631 |
if selected_radio != []:
|
632 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
@@ -634,40 +551,51 @@ with gr.Blocks() as demo:
|
|
634 |
|
635 |
# store the index of the in_channel
|
636 |
selected_idx = channel_info["inputDict"][selected_radio]["index"]
|
637 |
-
stage1_info["newOrder"][prev_target_idx] = [selected_idx]
|
|
|
638 |
# mark the in_channel as assigned and tpl_channel as matched
|
639 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
640 |
channel_info["inputDict"][selected_radio]["assigned"] = True
|
641 |
-
print(prev_target_name, '<-', selected_radio)
|
642 |
-
|
643 |
-
# ------------------------update information for the next step-------------------------
|
644 |
|
|
|
645 |
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
646 |
-
stage1_info["unassignedInputs"] =
|
647 |
-
|
648 |
# update the list of missingTemplates to exclude those filled in step2
|
649 |
-
stage1_info["missingTemplates"] =
|
650 |
-
|
651 |
|
652 |
-
#
|
653 |
-
#
|
|
|
654 |
if len(stage1_info["missingTemplates"]) == 0:
|
655 |
-
#print('step2 ->
|
656 |
-
|
657 |
-
|
|
|
|
|
|
|
658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
659 |
app_info["stage1"] = stage1_info
|
660 |
-
|
|
|
661 |
channel_info_json : channel_info,
|
662 |
-
desc_md : gr.Markdown(
|
663 |
radio_group : gr.Radio(visible=False),
|
|
|
664 |
clear_btn : gr.Button(visible=False),
|
665 |
next_btn : gr.Button(visible=False),
|
666 |
run_btn : gr.Button(interactive=True)}
|
667 |
-
|
668 |
-
# -> stage1-3
|
669 |
else:
|
670 |
-
#print('step2 -> step3')
|
671 |
md = """
|
672 |
---
|
673 |
### Step3: Filling Remaining Template Channels
|
@@ -676,8 +604,9 @@ with gr.Blocks() as demo:
|
|
676 |
remaining empty template channels.
|
677 |
"""
|
678 |
|
|
|
679 |
app_info["stage1"] = stage1_info
|
680 |
-
|
681 |
channel_info_json : channel_info,
|
682 |
desc_md : gr.Markdown(md),
|
683 |
radio_group : gr.Radio(visible=False),
|
@@ -686,14 +615,86 @@ with gr.Blocks() as demo:
|
|
686 |
clear_btn : gr.Button(visible=False),
|
687 |
next_btn : gr.Button(visible=False)}
|
688 |
|
689 |
-
#
|
690 |
-
elif stage1_info["state"] == "step3-
|
691 |
-
#print('step3 -> stage2')
|
692 |
-
stage1_info["state"] = "finished"
|
693 |
-
gr.Info('The mapping process has been finished.')
|
694 |
|
695 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
696 |
|
|
|
697 |
# if the user didn't uncheck all in_channel checkboxes
|
698 |
if selected_chkbox != []:
|
699 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
@@ -701,22 +702,36 @@ with gr.Blocks() as demo:
|
|
701 |
|
702 |
# store the indices of the in_channels
|
703 |
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
|
704 |
-
stage1_info["newOrder"][prev_target_idx] = selected_indices
|
705 |
-
#print(f'{prev_target_name}({prev_target_idx}): {
|
706 |
-
#
|
|
|
|
|
|
|
|
|
|
|
707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
708 |
app_info["stage1"] = stage1_info
|
709 |
-
|
710 |
-
|
|
|
|
|
711 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
712 |
next_btn : gr.Button(visible=False),
|
|
|
713 |
run_btn : gr.Button(interactive=True)}
|
714 |
|
715 |
next_btn.click(
|
716 |
fn = init_next_step,
|
717 |
-
inputs = [app_info_json, channel_info_json, radio_group, chkbox_group],
|
718 |
-
outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group,
|
719 |
-
in_fillmode,
|
720 |
).success(
|
721 |
fn = None,
|
722 |
js = init_js,
|
@@ -726,9 +741,25 @@ with gr.Blocks() as demo:
|
|
726 |
|
727 |
|
728 |
# +========================================================================================+
|
729 |
-
# |
|
730 |
# +========================================================================================+
|
731 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
732 |
@radio_group.select(inputs = app_info_json, outputs = [step2_btn, next_btn])
|
733 |
def determine_button(app_info):
|
734 |
stage1_info = app_info["stage1"]
|
@@ -751,8 +782,7 @@ with gr.Blocks() as demo:
|
|
751 |
def update_radio(app_info, channel_info, selected):
|
752 |
stage1_info = app_info["stage1"]
|
753 |
|
754 |
-
# ----------------------store information before the button click
|
755 |
-
|
756 |
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
757 |
if selected != []:
|
758 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
@@ -760,18 +790,18 @@ with gr.Blocks() as demo:
|
|
760 |
|
761 |
# store the index of the selected in_channel
|
762 |
selected_idx = channel_info["inputDict"][selected]["index"]
|
763 |
-
stage1_info["newOrder"][prev_target_idx] = [selected_idx]
|
|
|
764 |
# mark the in_channel as assigned and tpl_channel as matched
|
765 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
766 |
channel_info["inputDict"][selected]["assigned"] = True
|
767 |
-
print(prev_target_name, '<-', selected)
|
768 |
|
769 |
-
# ------------------------update information for the new round
|
770 |
stage1_info["fillingCount"] += 1
|
771 |
|
772 |
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
773 |
-
stage1_info["unassignedInputs"] = [
|
774 |
-
if channel_info["inputDict"][channel]["assigned"]==False]
|
775 |
# update the progress indication label
|
776 |
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
777 |
radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
@@ -784,7 +814,7 @@ with gr.Blocks() as demo:
|
|
784 |
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
|
785 |
value=[], label=radio_label),
|
786 |
step2_btn : gr.Button(visible=False),
|
787 |
-
next_btn : gr.Button(
|
788 |
else:
|
789 |
return {app_info_json : app_info,
|
790 |
channel_info_json : channel_info,
|
@@ -804,71 +834,12 @@ with gr.Blocks() as demo:
|
|
804 |
|
805 |
|
806 |
# +========================================================================================+
|
807 |
-
# |
|
808 |
-
# +========================================================================================+
|
809 |
-
def fill_value(app_info, channel_info, fillmode):
|
810 |
-
stage1_info = app_info["stage1"]
|
811 |
-
|
812 |
-
if fillmode == "zero":
|
813 |
-
stage1_info["state"] = "finished"
|
814 |
-
gr.Info('The mapping process has been finished.')
|
815 |
-
|
816 |
-
app_info["stage1"] = stage1_info
|
817 |
-
return {app_info_json : app_info,
|
818 |
-
desc_md : gr.Markdown(visible=False),
|
819 |
-
in_fillmode : gr.Dropdown(visible=False),
|
820 |
-
fillmode_btn : gr.Button(visible=False),
|
821 |
-
run_btn : gr.Button(interactive=True)}
|
822 |
-
|
823 |
-
elif fillmode == "mean":
|
824 |
-
md = """
|
825 |
-
---
|
826 |
-
### Step3: Fill the remaining template channels
|
827 |
-
The current empty template channel, indicated in red, will be filled with the average
|
828 |
-
value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
|
829 |
-
"""
|
830 |
-
|
831 |
-
# find the 4-NN in_channels for each of the unmatched tpl_channels
|
832 |
-
new_idx = find_neighbors(channel_info, stage1_info["missingTemplates"], stage1_info["newOrder"])
|
833 |
-
|
834 |
-
stage1_info.update({
|
835 |
-
"state" : "step3-selecting",
|
836 |
-
"newOrder" : new_idx,
|
837 |
-
"fillingCount" : 1,
|
838 |
-
"totalFillingNum" : len(stage1_info["missingTemplates"])
|
839 |
-
})
|
840 |
-
|
841 |
-
# initialize the progress indication label
|
842 |
-
target_name = stage1_info["missingTemplates"][0]
|
843 |
-
target_idx = channel_info["templateDict"][target_name]["index"]
|
844 |
-
chkbox_value = stage1_info["newOrder"][target_idx]
|
845 |
-
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
846 |
-
chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
|
847 |
-
|
848 |
-
app_info["stage1"] = stage1_info
|
849 |
-
# determine which button to display
|
850 |
-
if stage1_info["totalFillingNum"] == 1:
|
851 |
-
return {app_info_json : app_info,
|
852 |
-
desc_md : gr.Markdown(md),
|
853 |
-
in_fillmode : gr.Dropdown(visible=False),
|
854 |
-
fillmode_btn : gr.Button(visible=False),
|
855 |
-
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
856 |
-
value=chkbox_value, label=chkbox_label, visible=True),
|
857 |
-
next_btn : gr.Button(visible=True)}
|
858 |
-
else:
|
859 |
-
return {app_info_json : app_info,
|
860 |
-
desc_md : gr.Markdown(md),
|
861 |
-
in_fillmode : gr.Dropdown(visible=False),
|
862 |
-
fillmode_btn : gr.Button(visible=False),
|
863 |
-
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
864 |
-
value=chkbox_value, label=chkbox_label, visible=True),
|
865 |
-
step3_btn : gr.Button(visible=True)}
|
866 |
-
|
867 |
def update_chkbox(app_info, channel_info, selected):
|
868 |
stage1_info = app_info["stage1"]
|
869 |
|
870 |
-
# ----------------------store information before the button click
|
871 |
-
|
872 |
# if the user didn't uncheck all in_channel checkboxes
|
873 |
if selected != []:
|
874 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
@@ -876,16 +847,16 @@ with gr.Blocks() as demo:
|
|
876 |
|
877 |
# store the indices of the selected in_channels
|
878 |
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
|
879 |
-
stage1_info["newOrder"][prev_target_idx] = selected_indices
|
880 |
-
#print('
|
881 |
|
882 |
-
# ------------------------update information for the new round
|
883 |
stage1_info["fillingCount"] += 1
|
884 |
|
885 |
# update the progress indication label
|
886 |
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
887 |
target_idx = channel_info["templateDict"][target_name]["index"]
|
888 |
-
chkbox_value = stage1_info["newOrder"][target_idx]
|
889 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
890 |
chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
891 |
|
@@ -895,15 +866,16 @@ with gr.Blocks() as demo:
|
|
895 |
return {app_info_json : app_info,
|
896 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
|
897 |
step3_btn : gr.Button(visible=False),
|
898 |
-
next_btn : gr.Button(
|
899 |
else:
|
900 |
return {app_info_json : app_info,
|
901 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
|
902 |
|
903 |
fillmode_btn.click(
|
904 |
-
fn =
|
905 |
-
|
906 |
-
outputs = [app_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
|
|
907 |
).success(
|
908 |
fn = None,
|
909 |
js = init_js,
|
@@ -924,59 +896,45 @@ with gr.Blocks() as demo:
|
|
924 |
|
925 |
|
926 |
# +========================================================================================+
|
927 |
-
# |
|
928 |
# +========================================================================================+
|
929 |
-
def reset_run(app_info,
|
930 |
stage1_info = app_info["stage1"]
|
931 |
stage2_info = app_info["stage2"]
|
932 |
|
933 |
-
# delete the previous folder of
|
934 |
-
filepath = stage2_info["
|
935 |
utils.dataDelete(filepath)
|
936 |
-
# establish a new folder for
|
937 |
-
new_filepath = app_info["
|
938 |
os.mkdir(new_filepath)
|
939 |
# generate the output filename
|
940 |
-
filename = stage1_info["
|
941 |
filename = os.path.basename(str(filename))
|
942 |
new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
|
943 |
|
944 |
-
|
945 |
-
|
946 |
-
|
947 |
-
|
948 |
-
unassigned_num = len(stage1_info["unassignedInputs"])
|
949 |
-
batch_num = math.ceil(unassigned_num/30) + 1
|
950 |
-
|
951 |
-
app_info.update({
|
952 |
-
#"currentStage" : "stage2",
|
953 |
-
"stage2" : {
|
954 |
-
"filepath" : new_filepath,
|
955 |
-
"filenames" : {
|
956 |
-
"output_data" : new_filepath + new_filename
|
957 |
-
},
|
958 |
-
#"state" : "initializing",
|
959 |
-
"totalBatchNum" : batch_num,
|
960 |
-
"newOrder" : [[]]*30,
|
961 |
-
"unassignedInputs" : stage1_info["unassignedInputs"]
|
962 |
}
|
963 |
})
|
|
|
964 |
return {app_info_json : app_info,
|
965 |
-
channel_info_json : channel_info,
|
966 |
#run_btn : gr.Button(interactive=False),
|
967 |
batch_md : gr.Markdown(visible=False),
|
968 |
out_data_file : gr.File(visible=False)}
|
969 |
|
970 |
-
def run_model(app_info,
|
971 |
stage1_info = app_info["stage1"]
|
972 |
stage2_info = app_info["stage2"]
|
973 |
|
974 |
-
filepath = stage2_info["
|
975 |
samplerate = app_info["sampleRate"]
|
976 |
-
filename = stage1_info["
|
977 |
-
new_filename = stage2_info["
|
978 |
|
979 |
-
#
|
980 |
break_flag = False
|
981 |
|
982 |
# run the model multiple times until all in_channels are reconstructed
|
@@ -988,7 +946,7 @@ with gr.Blocks() as demo:
|
|
988 |
#utils.dataDelete(filepath+"temp_data/")
|
989 |
#os.mkdir(filepath+"temp_data/")
|
990 |
except FileNotFoundError:
|
991 |
-
|
992 |
break_flag = True
|
993 |
break
|
994 |
except OSError as e:
|
@@ -998,33 +956,32 @@ with gr.Blocks() as demo:
|
|
998 |
md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
|
999 |
yield {batch_md : gr.Markdown(md, visible=True)}
|
1000 |
|
1001 |
-
#
|
1002 |
-
|
1003 |
-
|
1004 |
-
else:
|
1005 |
-
# if this is not the first time running the model, the in_channels that have
|
1006 |
-
# not been reconstructed yet will be optimally mapped to the template.
|
1007 |
-
stage2_info, channel_info = mapping_stage2(stage2_info, channel_info)
|
1008 |
-
new_idx = stage2_info["newOrder"]
|
1009 |
-
print('unassigned num:', len(stage2_info["unassignedInputs"]))
|
1010 |
-
|
1011 |
# ----------------------------------------------------------------------
|
1012 |
try:
|
1013 |
# step1: Reorder input data
|
1014 |
-
|
|
|
|
|
|
|
1015 |
# step2: Data preprocessing
|
1016 |
total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
|
1017 |
# step3: Signal reconstruction
|
1018 |
utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
|
|
|
|
|
|
|
1019 |
# step4: Restore original order
|
1020 |
-
|
|
|
1021 |
except FileNotFoundError:
|
1022 |
-
|
1023 |
break_flag = True
|
1024 |
break
|
1025 |
# ----------------------------------------------------------------------
|
1026 |
utils.dataDelete(filepath+"temp_data/")
|
1027 |
-
app_info["stage2"] = stage2_info
|
1028 |
|
1029 |
if break_flag == True:
|
1030 |
yield {batch_md : gr.Markdown(visible=False)}
|
@@ -1035,12 +992,12 @@ with gr.Blocks() as demo:
|
|
1035 |
|
1036 |
run_btn.click(
|
1037 |
fn = reset_run,
|
1038 |
-
inputs = [app_info_json,
|
1039 |
-
outputs = [app_info_json,
|
1040 |
|
1041 |
).success(
|
1042 |
fn = run_model,
|
1043 |
-
inputs = [app_info_json,
|
1044 |
outputs = [run_btn, batch_md, out_data_file]
|
1045 |
)
|
1046 |
|
|
|
1 |
+
import utils
|
2 |
+
import app_utils
|
3 |
import gradio as gr
|
|
|
4 |
import os
|
5 |
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
readme = """
|
8 |
|
|
|
27 |
### Step3: Filling Remaining Template Channels
|
28 |
To run the models successfully, we need to ensure that all 30 template channels are filled. In this step, you are required to select one of the methods provided below to fill the remaining empty template channels:
|
29 |
- **Mean** method: Each empty template channel is filled with the average value of data from the nearest input channels. By default, the 4 closest input channels (determined after aligning your montage to the template's scale using TPS) are selected for this averaging process. On the interface, you will see checkboxes displayed above each of your channel. The 4 nearest channels are pre-selected by default for each empty template channels, but you can modify these selections as needed. If you uncheck all the checkboxes for a particular template channel, it will be filled with zeros.
|
30 |
+
- **Zero** method: All empty template channels are filled with zeros.
|
31 |
+
Choose the method that best suits your needs, considering that the model's performance may vary depending on the method used.
|
32 |
+
|
33 |
+
### Step4: Auto-mapping Remaining Channels
|
34 |
+
After completing the initial mapping steps, any channels that are not yet assigned to a template will be processed in this step. These remaining channels will be automatically mapped in batches, with a batch size of up to 30 channels. If the final batch contains fewer than 30 channels, the **Mean** method from Step3 will be applied to fill the remaining template channels.
|
35 |
+
|
36 |
+
|
37 |
+
### Mapping Result
|
38 |
|
39 |
## 2. Decode data
|
40 |
In this phase, you can select which model to use for denoising your EEG data. Detailed information about the models can be found in the other tabs.
|
|
|
62 |
selector = "#radio-group > div:nth-of-type(2)";
|
63 |
//classname = "radio";
|
64 |
attribute = "value";
|
65 |
+
}else if(stage1_info.state == "step3-2-selecting"){
|
66 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
67 |
//classname = "chkbox";
|
68 |
attribute = "name";
|
|
|
76 |
aspect-ratio: 1;
|
77 |
//width: 560px;
|
78 |
//height: 560px;
|
79 |
+
background: url("file=${stage1_info.fileNames.input_montage}");
|
80 |
background-size: contain;
|
81 |
|
82 |
`;
|
|
|
163 |
item.className = "";
|
164 |
item.querySelector(":scope > span").innerText = "";
|
165 |
});
|
166 |
+
}else if(stage1_info.state == "step3-2-selecting"){
|
167 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
168 |
}else return;
|
169 |
|
|
|
245 |
map_btn = gr.Button("Mapping", interactive=False, scale=1)
|
246 |
|
247 |
# ------------------------mapping------------------------
|
|
|
248 |
desc_md = gr.Markdown(visible=False)
|
249 |
+
# step1 : initial mapping abd rescaling
|
250 |
with gr.Row():
|
251 |
tpl_img = gr.Image("./template_montage.png", label="Template channels", visible=False)
|
252 |
mapped_img = gr.Image(label="Input channels", visible=False)
|
253 |
+
# step2 : forward unmatched input channels to empty template channels
|
254 |
radio_group = gr.Radio(elem_id="radio-group", visible=False)
|
255 |
+
# step3 : fill the remaining template channels
|
256 |
with gr.Row():
|
257 |
in_fillmode = gr.Dropdown(choices=["mean", "zero"],
|
258 |
value="mean",
|
|
|
261 |
scale=2)
|
262 |
fillmode_btn = gr.Button("OK", visible=False, scale=1)
|
263 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
264 |
+
# step4 : mapping result
|
265 |
+
out_json_file = gr.File(label="Mapping result", visible=False)
|
266 |
|
267 |
with gr.Row():
|
268 |
clear_btn = gr.Button("Clear", visible=False)
|
|
|
279 |
("ART", "EEGART"),
|
280 |
("IC-U-Net", "ICUNet"),
|
281 |
("IC-U-Net++", "UNetpp"),
|
282 |
+
("IC-U-Net-Attn", "AttUnet"),
|
283 |
+
"(mapped data)",
|
284 |
+
"(denoised data)"],
|
285 |
value="EEGART",
|
286 |
label="Model",
|
287 |
scale=2)
|
|
|
304 |
with gr.Tab("README"):
|
305 |
gr.Markdown(readme)
|
306 |
|
307 |
+
#demo.load(js=js)
|
308 |
|
309 |
# verify that all required inputs have been provided
|
310 |
@gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
|
|
|
317 |
|
318 |
|
319 |
# +========================================================================================+
|
320 |
+
# | Stage1: channel mapping |
|
321 |
# +========================================================================================+
|
322 |
def reset_all(in_data, in_loc, samplerate):
|
323 |
# establish a new folder for the current session
|
324 |
+
rootpath = os.path.dirname(str(in_data))
|
325 |
try:
|
326 |
+
os.mkdir(rootpath+"/session_data/")
|
327 |
except OSError as e:
|
328 |
+
utils.dataDelete(rootpath+"/session_data/")
|
329 |
+
os.mkdir(rootpath+"/session_data/")
|
330 |
print(e)
|
331 |
# establish new folders for stage1 and stage2
|
332 |
+
os.mkdir(rootpath+"/session_data/stage1/")
|
333 |
+
os.mkdir(rootpath+"/session_data/stage2/")
|
334 |
|
335 |
# initialize channel_info, app_info
|
336 |
channel_info = {}
|
337 |
app_info = {
|
338 |
+
"rootPath" : rootpath+"/session_data/",
|
339 |
"sampleRate" : int(samplerate),
|
|
|
340 |
"stage1" : {
|
341 |
+
"filePath" : rootpath+"/session_data/stage1/",
|
342 |
+
"fileNames" : {
|
343 |
"input_data" : in_data,
|
344 |
"input_loc" : in_loc,
|
345 |
"input_montage" : "",
|
346 |
"mapped_montage" : ""
|
347 |
},
|
348 |
+
"state" : "step1-initializing",
|
349 |
"fillingCount" : None,
|
350 |
"totalFillingNum" : None,
|
|
|
351 |
"unassignedInputs" : None,
|
352 |
+
"missingTemplates" : None,
|
353 |
+
"mappingData" : [
|
354 |
+
{
|
355 |
+
"newOrder" : None,
|
356 |
+
"fillFlags" : None,
|
357 |
+
#"channelUsageNum" : None
|
358 |
+
}
|
359 |
+
]
|
360 |
},
|
361 |
"stage2" : {
|
362 |
+
"filePath" : rootpath+"/session_data/stage2/",
|
363 |
+
"fileNames" : {
|
364 |
"output_data" : ""
|
365 |
},
|
366 |
+
"totalBatchNum" : None
|
|
|
|
|
|
|
367 |
}
|
368 |
}
|
369 |
# reset layout
|
370 |
return {app_info_json : app_info,
|
371 |
channel_info_json : channel_info,
|
372 |
+
# --------------------Stage1-------------------------
|
373 |
map_btn : gr.Button(interactive=False),
|
374 |
desc_md : gr.Markdown(visible=False),
|
375 |
+
next_btn : gr.Button(visible=False),
|
376 |
tpl_img : gr.Image(visible=False),
|
377 |
mapped_img : gr.Image(value=None, visible=False),
|
378 |
radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
|
|
|
|
|
|
|
379 |
clear_btn : gr.Button(visible=False),
|
380 |
step2_btn : gr.Button(visible=False),
|
381 |
+
in_fillmode : gr.Dropdown(value="mean", visible=False),
|
382 |
+
fillmode_btn : gr.Button(visible=False),
|
383 |
+
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
384 |
step3_btn : gr.Button(visible=False),
|
385 |
+
out_json_file : gr.File(value=None, visible=False),
|
386 |
+
# --------------------Stage2-------------------------
|
387 |
run_btn : gr.Button(interactive=False),
|
388 |
batch_md : gr.Markdown(visible=False),
|
389 |
out_data_file : gr.File(visible=False)}
|
390 |
|
391 |
|
392 |
# +========================================================================================+
|
393 |
+
# | manage step transition |
|
394 |
# +========================================================================================+
|
395 |
+
def init_next_step(app_info, channel_info, fillmode, selected_radio, selected_chkbox):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
stage1_info = app_info["stage1"]
|
397 |
+
stage2_info = app_info["stage2"]
|
398 |
+
filepath = stage1_info["filePath"]
|
399 |
+
|
400 |
+
# ========================================step0=========================================
|
401 |
+
# step0 to step1
|
402 |
+
if stage1_info["state"] == "step1-initializing":
|
403 |
+
#print('step0 -> step1')
|
404 |
+
|
405 |
+
# 1. match the names of in_channels and tpl_channels
|
406 |
+
yield {desc_md : gr.Markdown("Mapping...", visible=True)}
|
407 |
+
stage1_info, channel_info, tpl_montage, in_montage = app_utils.match_names(stage1_info, channel_info)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
|
409 |
+
# 2. rescale coordinates
|
410 |
+
yield {desc_md : gr.Markdown("Rescaling...")}
|
411 |
+
channel_info = app_utils.align_coords(channel_info, tpl_montage, in_montage)
|
412 |
+
|
413 |
+
# 3. generate and save figures of the montages
|
414 |
+
filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
|
415 |
+
filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
416 |
+
channel_info = app_utils.save_figures(channel_info, tpl_montage, filename1, filename2)
|
417 |
+
stage1_info["fileNames"].update({
|
418 |
+
"input_montage" : filename1,
|
419 |
+
"mapped_montage" : filename2
|
420 |
+
})
|
421 |
+
|
422 |
+
# 4. matching result
|
423 |
+
# check if there are red dots (unmatched in_channels) on the input montage
|
424 |
+
unassigned_num = len(stage1_info["unassignedInputs"])
|
425 |
+
if unassigned_num == 0:
|
426 |
md = """
|
427 |
---
|
428 |
### Step1: Initial Matching and Rescaling
|
429 |
+
Below is the result of mapping your channels to our template channels based on their names.
|
|
|
430 |
"""
|
431 |
+
else:
|
|
|
|
|
|
|
|
|
432 |
md = """
|
433 |
---
|
434 |
### Step1: Initial Matching and Rescaling
|
435 |
+
Below is the result of mapping your channels to our template channels based on their names.
|
436 |
+
- channels highlighted in red are those that do not match any template channels.
|
437 |
"""
|
438 |
|
439 |
+
stage1_info["state"] = "step1-finished"
|
440 |
app_info["stage1"] = stage1_info
|
441 |
+
yield {app_info_json : app_info,
|
442 |
channel_info_json : channel_info,
|
443 |
map_btn : gr.Button(interactive=True),
|
444 |
+
desc_md : gr.Markdown(md),
|
445 |
tpl_img : gr.Image(visible=True),
|
446 |
mapped_img : gr.Image(value=filename2, visible=True),
|
447 |
next_btn : gr.Button(visible=True)}
|
448 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
|
450 |
+
# ========================================step1=========================================
|
451 |
+
elif stage1_info["state"] == "step1-finished":
|
452 |
+
in_num = len(channel_info["inputOrder"])
|
453 |
+
matched_num = 30 - len(stage1_info["missingTemplates"])
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
+
# step1 to step4
|
456 |
+
# the in_channels has all the 30 tpl_channels (in_num>=30)
|
457 |
+
if matched_num == 30:
|
458 |
+
#print('step1 -> step4')
|
459 |
+
md = """
|
460 |
+
---
|
461 |
+
### Mapping result
|
462 |
+
(...)
|
463 |
+
"""
|
464 |
+
|
465 |
+
# finalize and save the mapping results
|
466 |
+
filename = filepath+"mapping_result.json"
|
467 |
+
stage1_info, stage2_info, channel_info = app_utils.mapping_result(
|
468 |
+
stage1_info, stage2_info, channel_info, filename)
|
469 |
+
#gr.Info('The mapping process has been finished.')
|
470 |
+
stage1_info["state"] = "finished"
|
471 |
+
app_info["stage1"] = stage1_info
|
472 |
+
app_info["stage2"] = stage2_info
|
473 |
+
yield {app_info_json : app_info,
|
474 |
+
channel_info_json : channel_info,
|
475 |
+
desc_md : gr.Markdown(md),
|
476 |
+
tpl_img : gr.Image(visible=False),
|
477 |
+
mapped_img : gr.Image(visible=False),
|
478 |
+
next_btn : gr.Button(visible=False),
|
479 |
+
out_json_file : gr.File(filename, visible=True),
|
480 |
+
run_btn : gr.Button(interactive=True)}
|
481 |
|
482 |
+
# step1 to step2
|
483 |
+
# matched_num < 30, and there're still some unmatched in_channels
|
484 |
+
elif in_num > matched_num:
|
485 |
+
#print('step1 -> step2')
|
486 |
+
md = """
|
487 |
+
---
|
488 |
+
### Step2: Forwarding Unmatched Channels
|
489 |
+
Select one of your unmatched channels to forward its data to the empty template channel
|
490 |
+
currently indicated in red.
|
491 |
+
"""
|
492 |
+
|
493 |
+
# initialize the progress indication label for step2
|
494 |
+
stage1_info.update({
|
495 |
+
"fillingCount" : 1,
|
496 |
+
"totalFillingNum" : len(stage1_info["missingTemplates"])
|
497 |
+
})
|
498 |
+
name = stage1_info["missingTemplates"][0]
|
499 |
+
label = "{} (1/{})".format(name, stage1_info["totalFillingNum"])
|
500 |
+
|
501 |
+
stage1_info["state"] = "step2-selecting"
|
502 |
+
app_info["stage1"] = stage1_info
|
503 |
+
# determine which button to display
|
504 |
+
if stage1_info["totalFillingNum"] == 1:
|
505 |
+
yield {app_info_json : app_info,
|
506 |
+
desc_md : gr.Markdown(md),
|
507 |
+
tpl_img : gr.Image(visible=False),
|
508 |
+
mapped_img : gr.Image(visible=False),
|
509 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
|
510 |
+
clear_btn : gr.Button(visible=True)}
|
511 |
+
else:
|
512 |
+
yield {app_info_json : app_info,
|
513 |
+
desc_md : gr.Markdown(md),
|
514 |
+
tpl_img : gr.Image(visible=False),
|
515 |
+
mapped_img : gr.Image(visible=False),
|
516 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
|
517 |
+
clear_btn : gr.Button(visible=True),
|
518 |
+
step2_btn : gr.Button(visible=True),
|
519 |
+
next_btn : gr.Button(visible=False)}
|
520 |
+
|
521 |
+
# step1 to step3-1
|
522 |
+
# in_num < 30, but all of them can match to some tpl_channels
|
523 |
+
elif in_num == matched_num:
|
524 |
+
#print('step1 -> step3-1')
|
525 |
+
md = """
|
526 |
+
---
|
527 |
+
### Step3: Filling Remaining Template Channels
|
528 |
+
To run the model successfully, we need to ensure that all 30 template channels are filled.
|
529 |
+
In this step, you are required to select one of the methods provided below to fill the
|
530 |
+
remaining empty template channels.
|
531 |
+
"""
|
532 |
+
|
533 |
+
stage1_info["state"] = "step3-select-method"
|
534 |
+
app_info["stage1"] = stage1_info
|
535 |
+
yield {app_info_json : app_info,
|
536 |
+
desc_md : gr.Markdown(md),
|
537 |
+
tpl_img : gr.Image(visible=False),
|
538 |
+
mapped_img : gr.Image(visible=False),
|
539 |
+
in_fillmode : gr.Dropdown(visible=True),
|
540 |
+
fillmode_btn : gr.Button(visible=True),
|
541 |
+
next_btn : gr.Button(visible=False)}
|
542 |
|
543 |
+
# ========================================step2=========================================
|
544 |
elif stage1_info["state"] == "step2-selecting":
|
545 |
|
546 |
+
# --------------------store information before the button click---------------------
|
|
|
547 |
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
548 |
if selected_radio != []:
|
549 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
|
|
551 |
|
552 |
# store the index of the in_channel
|
553 |
selected_idx = channel_info["inputDict"][selected_radio]["index"]
|
554 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = [selected_idx]
|
555 |
+
stage1_info["mappingData"][0]["fillFlags"][prev_target_idx] = False
|
556 |
# mark the in_channel as assigned and tpl_channel as matched
|
557 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
558 |
channel_info["inputDict"][selected_radio]["assigned"] = True
|
559 |
+
#print(prev_target_name, '<-', selected_radio)
|
|
|
|
|
560 |
|
561 |
+
# -----------------------update information for the next step-----------------------
|
562 |
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
563 |
+
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
|
564 |
+
channel_info["inputDict"])
|
565 |
# update the list of missingTemplates to exclude those filled in step2
|
566 |
+
stage1_info["missingTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
|
567 |
+
channel_info["templateDict"])
|
568 |
|
569 |
+
# -----------------------------determine the next step------------------------------
|
570 |
+
# step2 to step4
|
571 |
+
# all the unmatched tpl_channels were filled by in_channels
|
572 |
if len(stage1_info["missingTemplates"]) == 0:
|
573 |
+
#print('step2 -> step4')
|
574 |
+
md = """
|
575 |
+
---
|
576 |
+
### Mapping result
|
577 |
+
(...)
|
578 |
+
"""
|
579 |
|
580 |
+
# finalize and save the mapping results
|
581 |
+
filename = filepath+"mapping_result.json"
|
582 |
+
stage1_info, stage2_info, channel_info = app_utils.mapping_result(
|
583 |
+
stage1_info, stage2_info, channel_info, filename)
|
584 |
+
#gr.Info('The mapping process has been finished.')
|
585 |
+
stage1_info["state"] = "finished"
|
586 |
app_info["stage1"] = stage1_info
|
587 |
+
app_info["stage2"] = stage2_info
|
588 |
+
yield {app_info_json : app_info,
|
589 |
channel_info_json : channel_info,
|
590 |
+
desc_md : gr.Markdown(md),
|
591 |
radio_group : gr.Radio(visible=False),
|
592 |
+
out_json_file : gr.File(filename, visible=True),
|
593 |
clear_btn : gr.Button(visible=False),
|
594 |
next_btn : gr.Button(visible=False),
|
595 |
run_btn : gr.Button(interactive=True)}
|
596 |
+
# step2 to step3-1
|
|
|
597 |
else:
|
598 |
+
#print('step2 -> step3-1')
|
599 |
md = """
|
600 |
---
|
601 |
### Step3: Filling Remaining Template Channels
|
|
|
604 |
remaining empty template channels.
|
605 |
"""
|
606 |
|
607 |
+
stage1_info["state"] = "step3-select-method"
|
608 |
app_info["stage1"] = stage1_info
|
609 |
+
yield {app_info_json : app_info,
|
610 |
channel_info_json : channel_info,
|
611 |
desc_md : gr.Markdown(md),
|
612 |
radio_group : gr.Radio(visible=False),
|
|
|
615 |
clear_btn : gr.Button(visible=False),
|
616 |
next_btn : gr.Button(visible=False)}
|
617 |
|
618 |
+
# =======================================step3-1========================================
|
619 |
+
elif stage1_info["state"] == "step3-select-method":
|
|
|
|
|
|
|
620 |
|
621 |
+
# step3-1 to step4
|
622 |
+
if fillmode == "zero":
|
623 |
+
#print('step3-1 -> step4')
|
624 |
+
md = """
|
625 |
+
---
|
626 |
+
### Mapping result
|
627 |
+
(...)
|
628 |
+
"""
|
629 |
+
|
630 |
+
# finalize and save the mapping results
|
631 |
+
filename = filepath+"mapping_result.json"
|
632 |
+
stage1_info, stage2_info, channel_info = app_utils.mapping_result(
|
633 |
+
stage1_info, stage2_info, channel_info, filename)
|
634 |
+
#gr.Info('The mapping process has been finished.')
|
635 |
+
stage1_info["state"] = "finished"
|
636 |
+
app_info["stage1"] = stage1_info
|
637 |
+
app_info["stage2"] = stage2_info
|
638 |
+
yield {app_info_json : app_info,
|
639 |
+
channel_info_json : channel_info,
|
640 |
+
desc_md : gr.Markdown(md),
|
641 |
+
in_fillmode : gr.Dropdown(visible=False),
|
642 |
+
fillmode_btn : gr.Button(visible=False),
|
643 |
+
out_json_file : gr.File(filename, visible=True),
|
644 |
+
run_btn : gr.Button(interactive=True)}
|
645 |
+
# step3-1 to step3-2
|
646 |
+
elif fillmode == "mean":
|
647 |
+
#print('step3-1 -> step3-2')
|
648 |
+
md = """
|
649 |
+
---
|
650 |
+
### Step3: Fill the remaining template channels
|
651 |
+
The current empty template channel, indicated in red, will be filled with the average
|
652 |
+
value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
|
653 |
+
"""
|
654 |
+
|
655 |
+
# find the 4 nearest in_channels for each unmatched tpl_channels
|
656 |
+
stage1_info["mappingData"][0]["newOrder"] = app_utils.find_neighbors(
|
657 |
+
channel_info,
|
658 |
+
stage1_info["missingTemplates"],
|
659 |
+
stage1_info["mappingData"][0]["newOrder"])
|
660 |
+
|
661 |
+
# initialize the progress indication label
|
662 |
+
stage1_info.update({
|
663 |
+
"fillingCount" : 1,
|
664 |
+
"totalFillingNum" : len(stage1_info["missingTemplates"])
|
665 |
+
})
|
666 |
+
target_name = stage1_info["missingTemplates"][0]
|
667 |
+
target_idx = channel_info["templateDict"][target_name]["index"]
|
668 |
+
chkbox_value = stage1_info["mappingData"][0]["newOrder"][target_idx]
|
669 |
+
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
670 |
+
chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
|
671 |
+
|
672 |
+
stage1_info["state"] = "step3-2-selecting"
|
673 |
+
app_info["stage1"] = stage1_info
|
674 |
+
# determine which button to display
|
675 |
+
if stage1_info["totalFillingNum"] == 1:
|
676 |
+
yield {app_info_json : app_info,
|
677 |
+
desc_md : gr.Markdown(md),
|
678 |
+
in_fillmode : gr.Dropdown(visible=False),
|
679 |
+
fillmode_btn : gr.Button(visible=False),
|
680 |
+
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
681 |
+
value=chkbox_value, label=chkbox_label, visible=True),
|
682 |
+
next_btn : gr.Button(visible=True)}
|
683 |
+
else:
|
684 |
+
yield {app_info_json : app_info,
|
685 |
+
desc_md : gr.Markdown(md),
|
686 |
+
in_fillmode : gr.Dropdown(visible=False),
|
687 |
+
fillmode_btn : gr.Button(visible=False),
|
688 |
+
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
689 |
+
value=chkbox_value, label=chkbox_label, visible=True),
|
690 |
+
step3_btn : gr.Button(visible=True)}
|
691 |
+
|
692 |
+
# =======================================step3-2========================================
|
693 |
+
# step3-2 to step4
|
694 |
+
elif stage1_info["state"] == "step3-2-selecting":
|
695 |
+
#print('step3-2 -> step4')
|
696 |
|
697 |
+
# --------------------store information before the button click---------------------
|
698 |
# if the user didn't uncheck all in_channel checkboxes
|
699 |
if selected_chkbox != []:
|
700 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
|
|
702 |
|
703 |
# store the indices of the in_channels
|
704 |
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
|
705 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
706 |
+
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
707 |
+
# ----------------------------------------------------------------------------------
|
708 |
+
md = """
|
709 |
+
---
|
710 |
+
### Mapping result
|
711 |
+
(...)
|
712 |
+
"""
|
713 |
|
714 |
+
# finalize and save the mapping results
|
715 |
+
filename = filepath+"mapping_result.json"
|
716 |
+
stage1_info, stage2_info, channel_info = app_utils.mapping_result(
|
717 |
+
stage1_info, stage2_info, channel_info, filename)
|
718 |
+
#gr.Info('The mapping process has been finished.')
|
719 |
+
stage1_info["state"] = "finished"
|
720 |
app_info["stage1"] = stage1_info
|
721 |
+
app_info["stage2"] = stage2_info
|
722 |
+
yield {app_info_json : app_info,
|
723 |
+
channel_info_json : channel_info,
|
724 |
+
desc_md : gr.Markdown(md),
|
725 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
726 |
next_btn : gr.Button(visible=False),
|
727 |
+
out_json_file : gr.File(filename, visible=True),
|
728 |
run_btn : gr.Button(interactive=True)}
|
729 |
|
730 |
next_btn.click(
|
731 |
fn = init_next_step,
|
732 |
+
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
733 |
+
outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
|
734 |
+
in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file, next_btn, run_btn]
|
735 |
).success(
|
736 |
fn = None,
|
737 |
js = init_js,
|
|
|
741 |
|
742 |
|
743 |
# +========================================================================================+
|
744 |
+
# | Stage1-step0 |
|
745 |
# +========================================================================================+
|
746 |
+
map_btn.click(
|
747 |
+
fn = reset_all,
|
748 |
+
inputs = [in_data_file, in_loc_file, in_samplerate],
|
749 |
+
outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
|
750 |
+
radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file,
|
751 |
+
run_btn, batch_md, out_data_file]
|
752 |
+
).success(
|
753 |
+
fn = init_next_step,
|
754 |
+
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
755 |
+
outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_img, mapped_img, next_btn]
|
756 |
+
)
|
757 |
+
|
758 |
+
|
759 |
+
# +========================================================================================+
|
760 |
+
# | Stage1-step2 |
|
761 |
+
# +========================================================================================+
|
762 |
+
# ...
|
763 |
@radio_group.select(inputs = app_info_json, outputs = [step2_btn, next_btn])
|
764 |
def determine_button(app_info):
|
765 |
stage1_info = app_info["stage1"]
|
|
|
782 |
def update_radio(app_info, channel_info, selected):
|
783 |
stage1_info = app_info["stage1"]
|
784 |
|
785 |
+
# ----------------------store information before the button click-----------------------
|
|
|
786 |
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
787 |
if selected != []:
|
788 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
|
|
790 |
|
791 |
# store the index of the selected in_channel
|
792 |
selected_idx = channel_info["inputDict"][selected]["index"]
|
793 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = [selected_idx]
|
794 |
+
stage1_info["mappingData"][0]["fillFlags"][prev_target_idx] = False
|
795 |
# mark the in_channel as assigned and tpl_channel as matched
|
796 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
797 |
channel_info["inputDict"][selected]["assigned"] = True
|
798 |
+
#print(prev_target_name, '<-', selected)
|
799 |
|
800 |
+
# ------------------------update information for the new round--------------------------
|
801 |
stage1_info["fillingCount"] += 1
|
802 |
|
803 |
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
804 |
+
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
|
|
|
805 |
# update the progress indication label
|
806 |
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
807 |
radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
|
|
814 |
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
|
815 |
value=[], label=radio_label),
|
816 |
step2_btn : gr.Button(visible=False),
|
817 |
+
next_btn : gr.Button(visible=True)}
|
818 |
else:
|
819 |
return {app_info_json : app_info,
|
820 |
channel_info_json : channel_info,
|
|
|
834 |
|
835 |
|
836 |
# +========================================================================================+
|
837 |
+
# | Stage1-step3 |
|
838 |
+
# +========================================================================================+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
839 |
def update_chkbox(app_info, channel_info, selected):
|
840 |
stage1_info = app_info["stage1"]
|
841 |
|
842 |
+
# ----------------------store information before the button click-----------------------
|
|
|
843 |
# if the user didn't uncheck all in_channel checkboxes
|
844 |
if selected != []:
|
845 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
|
|
847 |
|
848 |
# store the indices of the selected in_channels
|
849 |
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
|
850 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
851 |
+
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
852 |
|
853 |
+
# ------------------------update information for the new round--------------------------
|
854 |
stage1_info["fillingCount"] += 1
|
855 |
|
856 |
# update the progress indication label
|
857 |
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
858 |
target_idx = channel_info["templateDict"][target_name]["index"]
|
859 |
+
chkbox_value = stage1_info["mappingData"][0]["newOrder"][target_idx]
|
860 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
861 |
chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
862 |
|
|
|
866 |
return {app_info_json : app_info,
|
867 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
|
868 |
step3_btn : gr.Button(visible=False),
|
869 |
+
next_btn : gr.Button(visible=True)}
|
870 |
else:
|
871 |
return {app_info_json : app_info,
|
872 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
|
873 |
|
874 |
fillmode_btn.click(
|
875 |
+
fn = init_next_step,
|
876 |
+
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
877 |
+
outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
878 |
+
out_json_file, next_btn, run_btn]
|
879 |
).success(
|
880 |
fn = None,
|
881 |
js = init_js,
|
|
|
896 |
|
897 |
|
898 |
# +========================================================================================+
|
899 |
+
# | Stage2: decode data |
|
900 |
# +========================================================================================+
|
901 |
+
def reset_run(app_info, modelname):
|
902 |
stage1_info = app_info["stage1"]
|
903 |
stage2_info = app_info["stage2"]
|
904 |
|
905 |
+
# delete the previous folder of Stage2
|
906 |
+
filepath = stage2_info["filePath"]
|
907 |
utils.dataDelete(filepath)
|
908 |
+
# establish a new folder for Stage2
|
909 |
+
new_filepath = app_info["rootPath"]+"stage2_"+str(random.randint(1,10000))+"/"
|
910 |
os.mkdir(new_filepath)
|
911 |
# generate the output filename
|
912 |
+
filename = stage1_info["fileNames"]["input_data"]
|
913 |
filename = os.path.basename(str(filename))
|
914 |
new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
|
915 |
|
916 |
+
stage2_info.update({
|
917 |
+
"filePath" : new_filepath,
|
918 |
+
"fileNames" : {
|
919 |
+
"output_data" : new_filepath + new_filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
920 |
}
|
921 |
})
|
922 |
+
app_info["stage2"] = stage2_info
|
923 |
return {app_info_json : app_info,
|
|
|
924 |
#run_btn : gr.Button(interactive=False),
|
925 |
batch_md : gr.Markdown(visible=False),
|
926 |
out_data_file : gr.File(visible=False)}
|
927 |
|
928 |
+
def run_model(app_info, modelname):
|
929 |
stage1_info = app_info["stage1"]
|
930 |
stage2_info = app_info["stage2"]
|
931 |
|
932 |
+
filepath = stage2_info["filePath"]
|
933 |
samplerate = app_info["sampleRate"]
|
934 |
+
filename = stage1_info["fileNames"]["input_data"]
|
935 |
+
new_filename = stage2_info["fileNames"]["output_data"]
|
936 |
|
937 |
+
# flag to indicate if the process has been interrupted by the user
|
938 |
break_flag = False
|
939 |
|
940 |
# run the model multiple times until all in_channels are reconstructed
|
|
|
946 |
#utils.dataDelete(filepath+"temp_data/")
|
947 |
#os.mkdir(filepath+"temp_data/")
|
948 |
except FileNotFoundError:
|
949 |
+
print('break1!!')
|
950 |
break_flag = True
|
951 |
break
|
952 |
except OSError as e:
|
|
|
956 |
md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
|
957 |
yield {batch_md : gr.Markdown(md, visible=True)}
|
958 |
|
959 |
+
# get the mapped index order and the filled status for each tpl_channels
|
960 |
+
new_idx = stage1_info["mappingData"][i]["newOrder"]
|
961 |
+
fill_flags = stage1_info["mappingData"][i]["fillFlags"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
962 |
# ----------------------------------------------------------------------
|
963 |
try:
|
964 |
# step1: Reorder input data
|
965 |
+
data_shape = app_utils.reorder_data(new_idx, fill_flags, filename, filepath+"temp_data/mapped.csv")
|
966 |
+
if modelname == "(mapped data)":
|
967 |
+
new_filename = filepath+"temp_data/mapped.csv"
|
968 |
+
break
|
969 |
# step2: Data preprocessing
|
970 |
total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
|
971 |
# step3: Signal reconstruction
|
972 |
utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
|
973 |
+
if modelname == "(denoised data)":
|
974 |
+
new_filename = filepath+"temp_data/denoised.csv"
|
975 |
+
break
|
976 |
# step4: Restore original order
|
977 |
+
app_utils.restore_order(i, data_shape, new_idx, fill_flags, filepath+"temp_data/denoised.csv", new_filename)
|
978 |
+
break
|
979 |
except FileNotFoundError:
|
980 |
+
print('break2!!')
|
981 |
break_flag = True
|
982 |
break
|
983 |
# ----------------------------------------------------------------------
|
984 |
utils.dataDelete(filepath+"temp_data/")
|
|
|
985 |
|
986 |
if break_flag == True:
|
987 |
yield {batch_md : gr.Markdown(visible=False)}
|
|
|
992 |
|
993 |
run_btn.click(
|
994 |
fn = reset_run,
|
995 |
+
inputs = [app_info_json, in_modelname],
|
996 |
+
outputs = [app_info_json, run_btn, batch_md, out_data_file]
|
997 |
|
998 |
).success(
|
999 |
fn = run_model,
|
1000 |
+
inputs = [app_info_json, in_modelname],
|
1001 |
outputs = [run_btn, batch_md, out_data_file]
|
1002 |
)
|
1003 |
|
channel_mapping.py → app_utils.py
RENAMED
@@ -1,144 +1,208 @@
|
|
1 |
import utils
|
2 |
-
import time
|
3 |
import os
|
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
-
import
|
6 |
-
|
7 |
import mne
|
8 |
from mne.channels import read_custom_montage
|
9 |
from scipy.interpolate import Rbf
|
10 |
from scipy.optimize import linear_sum_assignment
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
|
13 |
-
def
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
idx_set = old_idx[i]
|
24 |
-
#print("channel_{}'s index set: {}".format(i, idx_set))
|
25 |
-
|
26 |
-
if idx_set == []:
|
27 |
new_data[i, :] = zero_arr
|
28 |
else:
|
29 |
-
tmp_data = [
|
30 |
new_data[i, :] = np.mean(tmp_data, axis=0)
|
31 |
|
|
|
32 |
utils.save_data(new_data, new_filename)
|
33 |
-
return
|
34 |
|
35 |
-
def
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
if cnt == 0:
|
41 |
-
new_data = np.zeros((len(in_order), old_data.shape[1]))
|
42 |
else:
|
43 |
new_data = utils.read_train_data(new_filename)
|
44 |
|
45 |
-
for i,
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
|
50 |
-
new_data[idx_set[0], :] = old_data[i, :]
|
51 |
|
52 |
-
print(
|
53 |
utils.save_data(new_data, new_filename)
|
54 |
return
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
def read_montage_data(loc_file):
|
57 |
-
|
58 |
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
59 |
in_montage = read_custom_montage(loc_file)
|
|
|
|
|
60 |
tpl_dict = {}
|
61 |
in_dict = {}
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
tpl_dict[channel] = {
|
69 |
"index" : i,
|
70 |
-
"coord_3d" : tpl_montage.get_positions()['ch_pos'][
|
71 |
"matched" : False
|
72 |
}
|
73 |
-
for i in
|
74 |
-
|
75 |
-
in_montage.rename_channels({channel:
|
76 |
-
|
77 |
-
channel = str.upper(channel)
|
78 |
-
in_dict[channel] = {
|
79 |
"index" : i,
|
80 |
-
"coord_3d" : in_montage.get_positions()['ch_pos'][
|
81 |
"assigned" : False
|
82 |
}
|
83 |
-
|
84 |
return tpl_montage, in_montage, tpl_dict, in_dict
|
85 |
|
86 |
-
def
|
87 |
-
|
|
|
88 |
tpl_dict = channel_info["templateDict"]
|
89 |
in_dict = channel_info["inputDict"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
tpl_order = channel_info["templateOrder"]
|
91 |
in_order = channel_info["inputOrder"]
|
92 |
-
|
93 |
-
|
|
|
94 |
|
95 |
-
#
|
96 |
fig = [tpl_montage.plot(), in_montage.plot()]
|
97 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
98 |
|
99 |
-
#
|
100 |
-
#all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # displayed coords (px)
|
101 |
-
#all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
|
102 |
all_tpl = ax[0].collections[0].get_offsets().data
|
103 |
all_in= ax[1].collections[0].get_offsets().data
|
104 |
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
105 |
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
106 |
|
107 |
-
# transform
|
108 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
109 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
110 |
|
111 |
-
# apply to all
|
112 |
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
113 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
114 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
115 |
|
116 |
-
#
|
117 |
for i, channel in enumerate(tpl_order):
|
118 |
tpl_dict[channel]["coord_2d"] = all_tpl[i]
|
119 |
for i, channel in enumerate(in_order):
|
120 |
in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
121 |
|
122 |
|
123 |
-
#
|
124 |
-
# get the original coords
|
125 |
all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
|
126 |
all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
|
127 |
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
128 |
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
129 |
|
130 |
-
# transform the xyz axis (input's -> template's)
|
131 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
132 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
133 |
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
134 |
|
135 |
-
# apply to all input channels
|
136 |
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
137 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
138 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
139 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
140 |
|
141 |
-
# update
|
142 |
for i, channel in enumerate(in_order):
|
143 |
in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
144 |
|
@@ -149,133 +213,153 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
149 |
return channel_info
|
150 |
|
151 |
def find_neighbors(channel_info, missing_channels, new_idx):
|
|
|
152 |
tpl_dict = channel_info["templateDict"]
|
153 |
in_dict = channel_info["inputDict"]
|
154 |
-
in_order = channel_info["inputOrder"]
|
155 |
|
|
|
156 |
all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
|
157 |
-
|
158 |
|
159 |
# use KNN to choose k nearest channels
|
160 |
k = 4 if len(in_order)>4 else len(in_order)
|
161 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
162 |
knn.fit(all_in)
|
163 |
-
|
164 |
for i, channel in enumerate(missing_channels):
|
165 |
-
distances, indices = knn.kneighbors(
|
166 |
-
#selected = [in_order[j] for j in indices[0]]
|
167 |
-
#print(channel, ':', selected)
|
168 |
-
|
169 |
idx = tpl_dict[channel]["index"]
|
170 |
new_idx[idx] = indices[0].tolist()
|
171 |
|
172 |
return new_idx
|
173 |
|
174 |
-
def
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
loc_file = app_info["stage1"]["filenames"]["input_loc"]
|
179 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
180 |
tpl_order = tpl_montage.ch_names
|
181 |
in_order = in_montage.ch_names
|
182 |
-
|
|
|
|
|
|
|
|
|
183 |
alias_dict = {
|
184 |
'T3': 'T7',
|
185 |
'T4': 'T8',
|
186 |
'T5': 'P7',
|
187 |
'T6': 'P8'
|
188 |
}
|
189 |
-
|
190 |
-
# match the names of input channels and template channels
|
191 |
for i, channel in enumerate(tpl_order):
|
192 |
if channel in alias_dict and alias_dict[channel] in in_dict:
|
193 |
-
tpl_montage.rename_channels({channel: alias_dict[channel]})
|
194 |
tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
|
195 |
channel = alias_dict[channel]
|
196 |
|
197 |
if channel in in_dict:
|
198 |
new_idx[i] = [in_dict[channel]["index"]]
|
|
|
199 |
tpl_dict[channel]["matched"] = True
|
200 |
in_dict[channel]["assigned"] = True
|
201 |
|
202 |
# update the names
|
203 |
tpl_order = tpl_montage.ch_names
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
channel_info.update({
|
|
|
|
|
206 |
"templateDict" : tpl_dict,
|
207 |
-
"inputDict" : in_dict
|
208 |
-
"templateOrder" : tpl_order,
|
209 |
-
"inputOrder" : in_order
|
210 |
-
})
|
211 |
-
app_info["stage1"].update({
|
212 |
-
"newOrder" : new_idx,
|
213 |
-
"unassignedInputs" : [channel for channel in in_order if in_dict[channel]["assigned"]==False],
|
214 |
-
"missingTemplates" : [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
215 |
})
|
216 |
-
|
217 |
-
# align input, template's coordinates
|
218 |
-
channel_info = align_coords(channel_info, tpl_montage, in_montage)
|
219 |
-
|
220 |
-
second2 = time.time()
|
221 |
-
print('Mapping (stage1) finished in',second2 - second1,'s.')
|
222 |
-
yield app_info, channel_info, gr.Markdown("", visible=False)
|
223 |
|
224 |
-
def
|
225 |
-
second1 = time.time()
|
226 |
-
|
227 |
-
tpl_dict = channel_info["templateDict"]
|
228 |
-
in_dict = channel_info["inputDict"]
|
229 |
tpl_order = channel_info["templateOrder"]
|
230 |
in_order = channel_info["inputOrder"]
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
unassigned_coords = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
|
235 |
-
|
236 |
# reset all tpl.matched to False
|
237 |
for channel in tpl_dict:
|
238 |
tpl_dict[channel]["matched"] = False
|
239 |
|
240 |
-
#
|
|
|
|
|
|
|
|
|
241 |
if len(unassigned) < 30:
|
242 |
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
243 |
else:
|
244 |
cost_matrix = np.zeros((30, len(unassigned)))
|
|
|
245 |
for i in range(30):
|
246 |
for j in range(len(unassigned)):
|
247 |
-
cost_matrix[i][j] = np.linalg.norm((
|
248 |
-
#print(cost_matrix[i][j], tpl_coords[i] - unassigned_coords[j])
|
249 |
|
250 |
-
#
|
|
|
251 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
252 |
|
|
|
253 |
new_idx = [[]]*30
|
|
|
254 |
for i in range(30):
|
255 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
256 |
tpl_channel = tpl_order[row_idx[i]]
|
257 |
in_channel = unassigned[col_idx[i]]
|
|
|
|
|
|
|
258 |
tpl_dict[tpl_channel]["matched"] = True
|
259 |
in_dict[in_channel]["assigned"] = True
|
260 |
-
|
261 |
-
|
262 |
-
print(f'{tpl_order[row_idx[i]]}({row_idx[i]}) <- {unassigned[col_idx[i]]}({col_idx[i]})')
|
263 |
|
264 |
-
# fill the
|
265 |
-
missing_channels =
|
266 |
if missing_channels != []:
|
267 |
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
268 |
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
}
|
273 |
channel_info.update({
|
274 |
"templateDict" : tpl_dict,
|
275 |
"inputDict" : in_dict
|
276 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
|
278 |
-
|
279 |
-
|
280 |
-
return stage2_info, channel_info
|
281 |
|
|
|
1 |
import utils
|
|
|
2 |
import os
|
3 |
+
import time
|
4 |
+
import math
|
5 |
+
import json
|
6 |
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
|
|
8 |
import mne
|
9 |
from mne.channels import read_custom_montage
|
10 |
from scipy.interpolate import Rbf
|
11 |
from scipy.optimize import linear_sum_assignment
|
12 |
from sklearn.neighbors import NearestNeighbors
|
13 |
|
14 |
+
def reorder_data(idx_order, fill_flags, filename, new_filename):
|
15 |
+
# read the input data
|
16 |
+
raw_data = utils.read_train_data(filename)
|
17 |
+
new_data = np.zeros((30, raw_data.shape[1]))
|
18 |
+
|
19 |
+
zero_arr = np.zeros((1, raw_data.shape[1]))
|
20 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
21 |
+
if flag == False:
|
22 |
+
new_data[i, :] = raw_data[idx_set[0], :]
|
23 |
+
elif idx_set == []:
|
|
|
|
|
|
|
|
|
24 |
new_data[i, :] = zero_arr
|
25 |
else:
|
26 |
+
tmp_data = [raw_data[j, :] for j in idx_set]
|
27 |
new_data[i, :] = np.mean(tmp_data, axis=0)
|
28 |
|
29 |
+
#print(raw_data.shape, new_data.shape)
|
30 |
utils.save_data(new_data, new_filename)
|
31 |
+
return raw_data.shape
|
32 |
|
33 |
+
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
|
34 |
+
# read the denoised data
|
35 |
+
d_data = utils.read_train_data(filename)
|
36 |
+
if batch_cnt == 0:
|
37 |
+
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
|
|
|
|
38 |
else:
|
39 |
new_data = utils.read_train_data(new_filename)
|
40 |
|
41 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
42 |
+
# ignore if this channel was filled using "fillmode"
|
43 |
+
if flag == False:
|
44 |
+
new_data[idx_set[0], :] = d_data[i, :]
|
|
|
|
|
45 |
|
46 |
+
#print(d_data.shape, new_data.shape)
|
47 |
utils.save_data(new_data, new_filename)
|
48 |
return
|
49 |
|
50 |
+
def get_matched(tpl_order, tpl_dict):
|
51 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
|
52 |
+
|
53 |
+
def get_empty_templates(tpl_order, tpl_dict):
|
54 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
55 |
+
|
56 |
+
def get_unassigned_inputs(in_order, in_dict):
|
57 |
+
return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
|
58 |
+
|
59 |
def read_montage_data(loc_file):
|
|
|
60 |
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
61 |
in_montage = read_custom_montage(loc_file)
|
62 |
+
tpl_order = tpl_montage.ch_names
|
63 |
+
in_order = in_montage.ch_names
|
64 |
tpl_dict = {}
|
65 |
in_dict = {}
|
66 |
|
67 |
+
# convert all channel names to uppercase and store the channel information
|
68 |
+
for i, channel in enumerate(tpl_order):
|
69 |
+
up_channel = str.upper(channel)
|
70 |
+
tpl_montage.rename_channels({channel: up_channel})
|
71 |
+
tpl_dict[up_channel] = {
|
|
|
72 |
"index" : i,
|
73 |
+
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
|
74 |
"matched" : False
|
75 |
}
|
76 |
+
for i, channel in enumerate(in_order):
|
77 |
+
up_channel = str.upper(channel)
|
78 |
+
in_montage.rename_channels({channel: up_channel})
|
79 |
+
in_dict[up_channel] = {
|
|
|
|
|
80 |
"index" : i,
|
81 |
+
"coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
|
82 |
"assigned" : False
|
83 |
}
|
|
|
84 |
return tpl_montage, in_montage, tpl_dict, in_dict
|
85 |
|
86 |
+
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
87 |
+
tpl_order = channel_info["templateOrder"]
|
88 |
+
in_order = channel_info["inputOrder"]
|
89 |
tpl_dict = channel_info["templateDict"]
|
90 |
in_dict = channel_info["inputDict"]
|
91 |
+
|
92 |
+
# get the 2D coordinates
|
93 |
+
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
94 |
+
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
95 |
+
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
96 |
+
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
97 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
98 |
+
in_coords = np.vstack((in_x, in_y)).T
|
99 |
+
|
100 |
+
# extract template's head figure
|
101 |
+
tpl_fig = tpl_montage.plot()
|
102 |
+
tpl_ax = tpl_fig.axes[0]
|
103 |
+
lines = tpl_ax.lines
|
104 |
+
head_lines = []
|
105 |
+
for line in lines:
|
106 |
+
x, y = line.get_data()
|
107 |
+
head_lines.append((x,y))
|
108 |
+
plt.close()
|
109 |
+
|
110 |
+
# -------------------------plot input montage------------------------------
|
111 |
+
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
112 |
+
ax = fig.add_subplot(111)
|
113 |
+
fig.tight_layout()
|
114 |
+
ax.set_aspect('equal')
|
115 |
+
ax.axis('off')
|
116 |
+
|
117 |
+
# plot template's head
|
118 |
+
for x, y in head_lines:
|
119 |
+
ax.plot(x, y, color='black', linewidth=1.0)
|
120 |
+
# plot in_channels on it
|
121 |
+
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
122 |
+
for i, channel in enumerate(in_order):
|
123 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
124 |
+
# save input_montage
|
125 |
+
fig.savefig(filename1)
|
126 |
+
|
127 |
+
# ---------------------------add indications-------------------------------
|
128 |
+
# plot unmatched input channels in red
|
129 |
+
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
130 |
+
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
131 |
+
for i in indices:
|
132 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
|
133 |
+
# save mapped_montage
|
134 |
+
fig.savefig(filename2)
|
135 |
+
|
136 |
+
# -------------------------------------------------------------------------
|
137 |
+
# store the tpl and in_channels' display positions (in px).
|
138 |
+
tpl_coords = ax.transData.transform(tpl_coords)
|
139 |
+
in_coords = ax.transData.transform(in_coords)
|
140 |
+
plt.close()
|
141 |
+
|
142 |
+
for i, channel in enumerate(tpl_order):
|
143 |
+
css_left = (tpl_coords[i,0]-11)/6.4
|
144 |
+
css_bottom = (tpl_coords[i,1]-7)/6.4
|
145 |
+
tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
146 |
+
for i, channel in enumerate(in_order):
|
147 |
+
css_left = (in_coords[i,0]-11)/6.4
|
148 |
+
css_bottom = (in_coords[i,1]-7)/6.4
|
149 |
+
in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
150 |
+
|
151 |
+
channel_info.update({
|
152 |
+
"templateDict" : tpl_dict,
|
153 |
+
"inputDict" : in_dict
|
154 |
+
})
|
155 |
+
return channel_info
|
156 |
+
|
157 |
+
def align_coords(channel_info, tpl_montage, in_montage):
|
158 |
tpl_order = channel_info["templateOrder"]
|
159 |
in_order = channel_info["inputOrder"]
|
160 |
+
tpl_dict = channel_info["templateDict"]
|
161 |
+
in_dict = channel_info["inputDict"]
|
162 |
+
matched = get_matched(tpl_order, tpl_dict)
|
163 |
|
164 |
+
# 2D alignment (for visualization purposes)
|
165 |
fig = [tpl_montage.plot(), in_montage.plot()]
|
166 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
167 |
|
168 |
+
# extract the displayed 2D coordinates from the plots
|
|
|
|
|
169 |
all_tpl = ax[0].collections[0].get_offsets().data
|
170 |
all_in= ax[1].collections[0].get_offsets().data
|
171 |
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
172 |
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
173 |
|
174 |
+
# apply TPS to transform in_channels positions to align with tpl_channels positions
|
175 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
176 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
177 |
|
178 |
+
# apply the transformation to all in_channels
|
179 |
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
180 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
181 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
182 |
|
183 |
+
# store the 2D positions
|
184 |
for i, channel in enumerate(tpl_order):
|
185 |
tpl_dict[channel]["coord_2d"] = all_tpl[i]
|
186 |
for i, channel in enumerate(in_order):
|
187 |
in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
188 |
|
189 |
|
190 |
+
# 3D alignment
|
|
|
191 |
all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
|
192 |
all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
|
193 |
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
194 |
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
195 |
|
|
|
196 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
197 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
198 |
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
199 |
|
|
|
200 |
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
202 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
203 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
204 |
|
205 |
+
# update in_channels' 3D positions
|
206 |
for i, channel in enumerate(in_order):
|
207 |
in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
208 |
|
|
|
213 |
return channel_info
|
214 |
|
215 |
def find_neighbors(channel_info, missing_channels, new_idx):
|
216 |
+
in_order = channel_info["inputOrder"]
|
217 |
tpl_dict = channel_info["templateDict"]
|
218 |
in_dict = channel_info["inputDict"]
|
|
|
219 |
|
220 |
+
# get the 3D coordinates
|
221 |
all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
|
222 |
+
empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
|
223 |
|
224 |
# use KNN to choose k nearest channels
|
225 |
k = 4 if len(in_order)>4 else len(in_order)
|
226 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
227 |
knn.fit(all_in)
|
|
|
228 |
for i, channel in enumerate(missing_channels):
|
229 |
+
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
|
|
|
|
|
|
230 |
idx = tpl_dict[channel]["index"]
|
231 |
new_idx[idx] = indices[0].tolist()
|
232 |
|
233 |
return new_idx
|
234 |
|
235 |
+
def match_names(stage1_info, channel_info):
|
236 |
+
# read the location file
|
237 |
+
loc_file = stage1_info["fileNames"]["input_loc"]
|
|
|
|
|
238 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
239 |
tpl_order = tpl_montage.ch_names
|
240 |
in_order = in_montage.ch_names
|
241 |
+
# list to store the indices of the in_channels in the order of tpl_channls
|
242 |
+
new_idx = [[]]*30
|
243 |
+
# flags to record if each tpl_channel's data is filled by "fillmode"
|
244 |
+
fill_flags = [True]*30
|
245 |
+
|
246 |
alias_dict = {
|
247 |
'T3': 'T7',
|
248 |
'T4': 'T8',
|
249 |
'T5': 'P7',
|
250 |
'T6': 'P8'
|
251 |
}
|
|
|
|
|
252 |
for i, channel in enumerate(tpl_order):
|
253 |
if channel in alias_dict and alias_dict[channel] in in_dict:
|
254 |
+
tpl_montage.rename_channels({channel: alias_dict[channel]})
|
255 |
tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
|
256 |
channel = alias_dict[channel]
|
257 |
|
258 |
if channel in in_dict:
|
259 |
new_idx[i] = [in_dict[channel]["index"]]
|
260 |
+
fill_flags[i] = False
|
261 |
tpl_dict[channel]["matched"] = True
|
262 |
in_dict[channel]["assigned"] = True
|
263 |
|
264 |
# update the names
|
265 |
tpl_order = tpl_montage.ch_names
|
266 |
|
267 |
+
stage1_info.update({
|
268 |
+
"unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
|
269 |
+
"missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
|
270 |
+
"mappingData" : [
|
271 |
+
{
|
272 |
+
"newOrder" : new_idx,
|
273 |
+
"fillFlags" : fill_flags
|
274 |
+
}
|
275 |
+
]
|
276 |
+
})
|
277 |
channel_info.update({
|
278 |
+
"templateOrder" : tpl_order,
|
279 |
+
"inputOrder" : in_order,
|
280 |
"templateDict" : tpl_dict,
|
281 |
+
"inputDict" : in_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
})
|
283 |
+
return stage1_info, channel_info, tpl_montage, in_montage
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
|
285 |
+
def optimal_mapping(channel_info):
|
|
|
|
|
|
|
|
|
286 |
tpl_order = channel_info["templateOrder"]
|
287 |
in_order = channel_info["inputOrder"]
|
288 |
+
tpl_dict = channel_info["templateDict"]
|
289 |
+
in_dict = channel_info["inputDict"]
|
290 |
+
unassigned = get_unassigned_inputs(in_order, in_dict)
|
|
|
|
|
291 |
# reset all tpl.matched to False
|
292 |
for channel in tpl_dict:
|
293 |
tpl_dict[channel]["matched"] = False
|
294 |
|
295 |
+
# get the 3D coordinates
|
296 |
+
all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
|
297 |
+
unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
|
298 |
+
|
299 |
+
# initialize the cost matrix for the Hungarian algorithm
|
300 |
if len(unassigned) < 30:
|
301 |
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
302 |
else:
|
303 |
cost_matrix = np.zeros((30, len(unassigned)))
|
304 |
+
# fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
|
305 |
for i in range(30):
|
306 |
for j in range(len(unassigned)):
|
307 |
+
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
|
|
|
308 |
|
309 |
+
# apply the Hungarian algorithm to optimally assign each in_channel to a tpl_channel
|
310 |
+
# by minimizing the total distances between their positions.
|
311 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
312 |
|
313 |
+
# store the mapping result
|
314 |
new_idx = [[]]*30
|
315 |
+
fill_flags = [True]*30
|
316 |
for i in range(30):
|
317 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
318 |
tpl_channel = tpl_order[row_idx[i]]
|
319 |
in_channel = unassigned[col_idx[i]]
|
320 |
+
|
321 |
+
new_idx[row_idx[i]] = [in_dict[in_channel]["index"]]
|
322 |
+
fill_flags[row_idx[i]] = False
|
323 |
tpl_dict[tpl_channel]["matched"] = True
|
324 |
in_dict[in_channel]["assigned"] = True
|
325 |
+
#print(f'{tpl_channel}({row_idx[i]}) <- {in_channel}({col_idx[i]})')
|
|
|
|
|
326 |
|
327 |
+
# fill the remaining empty tpl_channels
|
328 |
+
missing_channels = get_empty_templates(tpl_order, tpl_dict)
|
329 |
if missing_channels != []:
|
330 |
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
331 |
|
332 |
+
mapping_data = {
|
333 |
+
"newOrder" : new_idx,
|
334 |
+
"fillFlags" : fill_flags
|
335 |
+
}
|
336 |
channel_info.update({
|
337 |
"templateDict" : tpl_dict,
|
338 |
"inputDict" : in_dict
|
339 |
})
|
340 |
+
return mapping_data, channel_info
|
341 |
+
|
342 |
+
def mapping_result(stage1_info, stage2_info, channel_info, filename):
|
343 |
+
# 1. calculate how many times the model needs to be run
|
344 |
+
unassigned_num = len(stage1_info["unassignedInputs"])
|
345 |
+
batch_num = math.ceil(unassigned_num/30) + 1
|
346 |
+
|
347 |
+
# 2. map the remaining in_channels
|
348 |
+
for i in range(1, batch_num):
|
349 |
+
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
350 |
+
new_mapping_data, channel_info = optimal_mapping(channel_info)
|
351 |
+
stage1_info["mappingData"] += [new_mapping_data]
|
352 |
+
|
353 |
+
# 3. save the mapping result
|
354 |
+
new_dict = {
|
355 |
+
#"templateOrder" : channel_info["templateOrder"],
|
356 |
+
#"inputOrder" : channel_info["inputOrder"],
|
357 |
+
"batchNum" : batch_num,
|
358 |
+
"mappingData" : stage1_info["mappingData"]
|
359 |
+
}
|
360 |
+
with open(filename, 'w') as jsonfile:
|
361 |
+
jsonfile.write(json.dumps(new_dict))
|
362 |
|
363 |
+
stage2_info["totalBatchNum"] = batch_num
|
364 |
+
return stage1_info, stage2_info, channel_info
|
|
|
365 |
|