Spaces:
Sleeping
Sleeping
Commit
·
b9816ab
1
Parent(s):
5e68958
update
Browse files- app.py +65 -73
- app_utils.py +359 -358
- utils.py +2 -1
app.py
CHANGED
@@ -14,11 +14,9 @@ This tool is designed to assist you with two main tasks:
|
|
14 |
- **Channel locations**: If you don't have the channel location file, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to adjust them to your required montage.
|
15 |
- **Raw data**: Your data format must be a two-dimensional array (channels, timepoints).
|
16 |
- **Channel requirements**: Your data must include some channels that correspond to our template channels, which include: ``Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2``. At least some of them need to be present for successful mapping. Additionally, please remove any reference, ECG, EOG, EMG... channels before uploading your files.
|
17 |
-
|
18 |
"""
|
19 |
|
20 |
readme = """
|
21 |
-
|
22 |
## 1. Channel Mapping
|
23 |
The following steps will guide you through the process of mapping your EEG channels to our template channels.
|
24 |
|
@@ -44,10 +42,10 @@ Once all template channels are filled, you will be directed to **Mapping Results
|
|
44 |
### Mapping Results
|
45 |
After completing the previous steps, your channels will be aligned with the template channels required by our models.
|
46 |
- In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final results.
|
47 |
-
- Once the mapping process is completed, a
|
48 |
|
49 |
## 2. Decode data
|
50 |
-
After clicking on ``Run`` button, we will process your EEG data based on the mapping results. If necessary, your data will be
|
51 |
"""
|
52 |
|
53 |
icunet = """
|
@@ -62,14 +60,16 @@ init_js = """
|
|
62 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
63 |
stage1_info = app_info.stage1
|
64 |
|
65 |
-
let selector, attribute;
|
66 |
let channel, left, bottom;
|
67 |
|
68 |
if(stage1_info.state == "step2-selecting"){
|
69 |
selector = "#radio-group > div:nth-of-type(2)";
|
|
|
70 |
attribute = "value";
|
71 |
}else if(stage1_info.state == "step3-2-selecting"){
|
72 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
|
|
73 |
attribute = "name";
|
74 |
}else return;
|
75 |
|
@@ -93,7 +93,7 @@ init_js = """
|
|
93 |
bottom = channel_info.inputDict[channel].css_position[1];
|
94 |
|
95 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
96 |
-
item.className = "";
|
97 |
item.querySelector(":scope > span").innerText = "";
|
98 |
});
|
99 |
|
@@ -217,21 +217,22 @@ update_js = """
|
|
217 |
"""
|
218 |
|
219 |
with gr.Blocks() as demo:
|
220 |
-
|
221 |
app_info_json = gr.JSON(visible=False)
|
222 |
channel_info_json = gr.JSON(visible=False)
|
223 |
|
224 |
gr.Markdown(intro)
|
225 |
with gr.Row():
|
226 |
|
227 |
-
with gr.Column(variant=
|
228 |
gr.Markdown("# 1.Channel Mapping")
|
229 |
-
#
|
230 |
in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
|
231 |
file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
|
232 |
map_btn = gr.Button("Map")
|
233 |
-
#
|
234 |
desc_md = gr.Markdown(visible=False)
|
|
|
|
|
235 |
# step1 : initial matching and scaling
|
236 |
with gr.Row():
|
237 |
tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
|
@@ -247,24 +248,17 @@ with gr.Blocks() as demo:
|
|
247 |
scale=2)
|
248 |
fillmode_btn = gr.Button("OK", visible=False, scale=1)
|
249 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
250 |
-
|
251 |
-
out_json_file = gr.File(visible=False)
|
252 |
-
res_md = gr.Markdown(
|
253 |
-
"""
|
254 |
-
(Download this file if you plan to run the models using the <a href="">source code</a>.)
|
255 |
-
""",
|
256 |
-
visible=False)
|
257 |
-
|
258 |
with gr.Row():
|
259 |
clear_btn = gr.Button("Clear", visible=False)
|
260 |
step2_btn = gr.Button("Next", visible=False)
|
261 |
step3_btn = gr.Button("Next", visible=False)
|
262 |
next_btn = gr.Button("Next step", visible=False)
|
263 |
# -----------------------------------------------
|
264 |
-
|
265 |
-
with gr.Column(variant=
|
266 |
gr.Markdown("# 2.Decode Data")
|
267 |
-
#
|
268 |
with gr.Row():
|
269 |
in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
|
270 |
with gr.Column():
|
@@ -274,15 +268,17 @@ with gr.Blocks() as demo:
|
|
274 |
("IC-U-Net", "ICUNet"),
|
275 |
("IC-U-Net++", "UNetpp"),
|
276 |
("IC-U-Net-Attn", "AttUnet")],
|
277 |
-
#"(mapped data)"
|
|
|
278 |
value="EEGART",
|
279 |
label="Model")
|
280 |
-
run_btn = gr.Button(interactive=False)
|
281 |
-
|
|
|
282 |
batch_md = gr.Markdown(visible=False)
|
283 |
out_data_file = gr.File(label="Denoised data", visible=False)
|
284 |
# -----------------------------------------------
|
285 |
-
|
286 |
with gr.Row():
|
287 |
with gr.Tab("README"):
|
288 |
gr.Markdown(readme)
|
@@ -340,8 +336,8 @@ with gr.Blocks() as demo:
|
|
340 |
"input_data" : "",
|
341 |
"output_data" : ""
|
342 |
},
|
343 |
-
"
|
344 |
-
"
|
345 |
}
|
346 |
}
|
347 |
return {app_info_json : app_info,
|
@@ -360,11 +356,11 @@ with gr.Blocks() as demo:
|
|
360 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
361 |
step3_btn : gr.Button(visible=False),
|
362 |
out_json_file : gr.File(value=None, visible=False),
|
363 |
-
res_md : gr.Markdown(visible=False),
|
364 |
# --------------------Stage2-------------------------
|
365 |
in_data_file : gr.File(value=None),
|
366 |
in_samplerate : gr.Textbox(value=None),
|
367 |
run_btn : gr.Button(interactive=False),
|
|
|
368 |
batch_md : gr.Markdown(visible=False),
|
369 |
out_data_file : gr.File(value=None, visible=False)}
|
370 |
|
@@ -427,7 +423,8 @@ with gr.Blocks() as demo:
|
|
427 |
if matched_num == 30:
|
428 |
md = """
|
429 |
### Mapping Results
|
430 |
-
The mapping process has been finished.
|
|
|
431 |
"""
|
432 |
# finalize and save the mapping results
|
433 |
filename = filepath+"mapping_result.json"
|
@@ -443,7 +440,6 @@ with gr.Blocks() as demo:
|
|
443 |
mapped_img : gr.Image(visible=False),
|
444 |
next_btn : gr.Button(visible=False),
|
445 |
out_json_file : gr.File(filename, visible=True),
|
446 |
-
res_md : gr.Markdown(visible=True),
|
447 |
run_btn : gr.Button(interactive=True)}
|
448 |
|
449 |
# step1 to step2
|
@@ -487,7 +483,7 @@ with gr.Blocks() as demo:
|
|
487 |
elif in_num == matched_num:
|
488 |
md = """
|
489 |
### Step3: Filling Remaining Template Channels
|
490 |
-
Select one of the methods provided below to fill the remaining
|
491 |
"""
|
492 |
stage1_info["state"] = "step3-select-method"
|
493 |
app_info["stage1"] = stage1_info
|
@@ -503,7 +499,7 @@ with gr.Blocks() as demo:
|
|
503 |
elif stage1_info["state"] == "step2-selecting":
|
504 |
|
505 |
# --------------------store information before the button click---------------------
|
506 |
-
#
|
507 |
if selected_radio != []:
|
508 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
509 |
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
@@ -528,7 +524,8 @@ with gr.Blocks() as demo:
|
|
528 |
if len(stage1_info["missingTemplates"]) == 0:
|
529 |
md = """
|
530 |
### Mapping Results
|
531 |
-
The mapping process has been finished.
|
|
|
532 |
"""
|
533 |
# finalize and save the mapping results
|
534 |
filename = filepath+"mapping_result.json"
|
@@ -542,7 +539,6 @@ with gr.Blocks() as demo:
|
|
542 |
desc_md : gr.Markdown(md),
|
543 |
radio_group : gr.Radio(visible=False),
|
544 |
out_json_file : gr.File(filename, visible=True),
|
545 |
-
res_md : gr.Markdown(visible=True),
|
546 |
clear_btn : gr.Button(visible=False),
|
547 |
next_btn : gr.Button(visible=False),
|
548 |
run_btn : gr.Button(interactive=True)}
|
@@ -550,7 +546,7 @@ with gr.Blocks() as demo:
|
|
550 |
else:
|
551 |
md = """
|
552 |
### Step3: Filling Remaining Template Channels
|
553 |
-
Select one of the methods provided below to fill the remaining
|
554 |
"""
|
555 |
stage1_info["state"] = "step3-select-method"
|
556 |
app_info["stage1"] = stage1_info
|
@@ -570,7 +566,8 @@ with gr.Blocks() as demo:
|
|
570 |
if fillmode == "zero":
|
571 |
md = """
|
572 |
### Mapping Results
|
573 |
-
The mapping process has been finished.
|
|
|
574 |
"""
|
575 |
# finalize and save the mapping results
|
576 |
filename = filepath+"mapping_result.json"
|
@@ -585,7 +582,6 @@ with gr.Blocks() as demo:
|
|
585 |
in_fillmode : gr.Dropdown(visible=False),
|
586 |
fillmode_btn : gr.Button(visible=False),
|
587 |
out_json_file : gr.File(filename, visible=True),
|
588 |
-
res_md : gr.Markdown(visible=True),
|
589 |
run_btn : gr.Button(interactive=True)}
|
590 |
# step3-1 to step3-2
|
591 |
elif fillmode == "mean":
|
@@ -636,18 +632,16 @@ with gr.Blocks() as demo:
|
|
636 |
elif stage1_info["state"] == "step3-2-selecting":
|
637 |
|
638 |
# --------------------store information before the button click---------------------
|
639 |
-
|
640 |
-
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
646 |
-
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
647 |
# ----------------------------------------------------------------------------------
|
648 |
md = """
|
649 |
### Mapping Results
|
650 |
-
The mapping process has been finished.
|
|
|
651 |
"""
|
652 |
# finalize and save the mapping results
|
653 |
filename = filepath+"mapping_result.json"
|
@@ -662,14 +656,13 @@ with gr.Blocks() as demo:
|
|
662 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
663 |
next_btn : gr.Button(visible=False),
|
664 |
out_json_file : gr.File(filename, visible=True),
|
665 |
-
res_md : gr.Markdown(visible=True),
|
666 |
run_btn : gr.Button(interactive=True)}
|
667 |
|
668 |
next_btn.click(
|
669 |
fn = init_next_step,
|
670 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
671 |
outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
|
672 |
-
in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file,
|
673 |
).success(
|
674 |
fn = None,
|
675 |
js = init_js,
|
@@ -686,7 +679,7 @@ with gr.Blocks() as demo:
|
|
686 |
inputs = in_loc_file,
|
687 |
outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
|
688 |
radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
689 |
-
out_json_file,
|
690 |
).success(
|
691 |
fn = init_next_step,
|
692 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
@@ -720,7 +713,7 @@ with gr.Blocks() as demo:
|
|
720 |
def update_radio(app_info, channel_info, selected):
|
721 |
stage1_info = app_info["stage1"]
|
722 |
# ----------------------store information before the button click-----------------------
|
723 |
-
#
|
724 |
if selected != []:
|
725 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
726 |
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
@@ -774,14 +767,11 @@ with gr.Blocks() as demo:
|
|
774 |
def update_chkbox(app_info, channel_info, selected):
|
775 |
stage1_info = app_info["stage1"]
|
776 |
# ----------------------store information before the button click-----------------------
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
784 |
-
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
785 |
|
786 |
# ------------------------update information for the new round--------------------------
|
787 |
stage1_info["fillingCount"] += 1
|
@@ -808,7 +798,7 @@ with gr.Blocks() as demo:
|
|
808 |
fn = init_next_step,
|
809 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
810 |
outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
811 |
-
out_json_file,
|
812 |
).success(
|
813 |
fn = None,
|
814 |
js = init_js,
|
@@ -855,8 +845,9 @@ with gr.Blocks() as demo:
|
|
855 |
})
|
856 |
app_info["stage2"] = stage2_info
|
857 |
return {app_info_json : app_info,
|
858 |
-
|
859 |
-
|
|
|
860 |
out_data_file : gr.File(visible=False)}
|
861 |
|
862 |
def run_model(app_info, modelname):
|
@@ -873,19 +864,14 @@ with gr.Blocks() as demo:
|
|
873 |
# establish a temp folder
|
874 |
try:
|
875 |
os.mkdir(filepath+"temp_data/")
|
876 |
-
#except FileExistsError:
|
877 |
-
#utils.dataDelete(filepath+"temp_data/")
|
878 |
-
#os.mkdir(filepath+"temp_data/")
|
879 |
except FileNotFoundError:
|
880 |
print('break1')
|
881 |
break_flag = True
|
882 |
break
|
883 |
-
except OSError as e:
|
884 |
-
print(e)
|
885 |
|
886 |
# update the running status
|
887 |
md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
|
888 |
-
yield {batch_md : gr.Markdown(md
|
889 |
|
890 |
# get the mapped index order and the filled status for each tpl_channels
|
891 |
new_idx = stage1_info["mappingData"][i]["newOrder"]
|
@@ -908,23 +894,29 @@ with gr.Blocks() as demo:
|
|
908 |
utils.dataDelete(filepath+"temp_data/")
|
909 |
|
910 |
if break_flag == True:
|
911 |
-
yield {
|
|
|
912 |
else:
|
913 |
-
yield {
|
|
|
914 |
batch_md : gr.Markdown(visible=False),
|
915 |
out_data_file : gr.File(new_filename, visible=True)}
|
916 |
|
|
|
|
|
|
|
|
|
|
|
917 |
run_btn.click(
|
918 |
fn = reset_run,
|
919 |
inputs = [app_info_json, in_data_file, in_samplerate, in_modelname],
|
920 |
-
outputs = [app_info_json, run_btn, batch_md, out_data_file]
|
921 |
-
|
922 |
).success(
|
923 |
fn = run_model,
|
924 |
inputs = [app_info_json, in_modelname],
|
925 |
-
outputs = [run_btn, batch_md, out_data_file]
|
926 |
)
|
927 |
|
928 |
if __name__ == "__main__":
|
929 |
-
demo.launch()
|
930 |
|
|
|
14 |
- **Channel locations**: If you don't have the channel location file, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to adjust them to your required montage.
|
15 |
- **Raw data**: Your data format must be a two-dimensional array (channels, timepoints).
|
16 |
- **Channel requirements**: Your data must include some channels that correspond to our template channels, which include: ``Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2``. At least some of them need to be present for successful mapping. Additionally, please remove any reference, ECG, EOG, EMG... channels before uploading your files.
|
|
|
17 |
"""
|
18 |
|
19 |
readme = """
|
|
|
20 |
## 1. Channel Mapping
|
21 |
The following steps will guide you through the process of mapping your EEG channels to our template channels.
|
22 |
|
|
|
42 |
### Mapping Results
|
43 |
After completing the previous steps, your channels will be aligned with the template channels required by our models.
|
44 |
- In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final results.
|
45 |
+
- Once the mapping process is completed, a JSON file containing the mapping results will be generated. This file is necessary only if you plan to run the models using the <a href="">source code</a>; otherwise, you can ignore it.
|
46 |
|
47 |
## 2. Decode data
|
48 |
+
After clicking on ``Run`` button, we will process your EEG data based on the mapping results. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
|
49 |
"""
|
50 |
|
51 |
icunet = """
|
|
|
60 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
61 |
stage1_info = app_info.stage1
|
62 |
|
63 |
+
let selector, attribute; //, classname;
|
64 |
let channel, left, bottom;
|
65 |
|
66 |
if(stage1_info.state == "step2-selecting"){
|
67 |
selector = "#radio-group > div:nth-of-type(2)";
|
68 |
+
//classname = "radio";
|
69 |
attribute = "value";
|
70 |
}else if(stage1_info.state == "step3-2-selecting"){
|
71 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
72 |
+
//classname = "chkbox";
|
73 |
attribute = "name";
|
74 |
}else return;
|
75 |
|
|
|
93 |
bottom = channel_info.inputDict[channel].css_position[1];
|
94 |
|
95 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
96 |
+
item.className = ""; //classname;
|
97 |
item.querySelector(":scope > span").innerText = "";
|
98 |
});
|
99 |
|
|
|
217 |
"""
|
218 |
|
219 |
with gr.Blocks() as demo:
|
|
|
220 |
app_info_json = gr.JSON(visible=False)
|
221 |
channel_info_json = gr.JSON(visible=False)
|
222 |
|
223 |
gr.Markdown(intro)
|
224 |
with gr.Row():
|
225 |
|
226 |
+
with gr.Column(variant="panel"):
|
227 |
gr.Markdown("# 1.Channel Mapping")
|
228 |
+
# ---------------------input---------------------
|
229 |
in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
|
230 |
file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
|
231 |
map_btn = gr.Button("Map")
|
232 |
+
# ---------------------output--------------------
|
233 |
desc_md = gr.Markdown(visible=False)
|
234 |
+
out_json_file = gr.File(visible=False)
|
235 |
+
# --------------------mapping--------------------
|
236 |
# step1 : initial matching and scaling
|
237 |
with gr.Row():
|
238 |
tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
|
|
|
248 |
scale=2)
|
249 |
fillmode_btn = gr.Button("OK", visible=False, scale=1)
|
250 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
251 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
with gr.Row():
|
253 |
clear_btn = gr.Button("Clear", visible=False)
|
254 |
step2_btn = gr.Button("Next", visible=False)
|
255 |
step3_btn = gr.Button("Next", visible=False)
|
256 |
next_btn = gr.Button("Next step", visible=False)
|
257 |
# -----------------------------------------------
|
258 |
+
|
259 |
+
with gr.Column(variant="panel"):
|
260 |
gr.Markdown("# 2.Decode Data")
|
261 |
+
# ---------------------input---------------------
|
262 |
with gr.Row():
|
263 |
in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
|
264 |
with gr.Column():
|
|
|
268 |
("IC-U-Net", "ICUNet"),
|
269 |
("IC-U-Net++", "UNetpp"),
|
270 |
("IC-U-Net-Attn", "AttUnet")],
|
271 |
+
#"(mapped data)",
|
272 |
+
#"(denoised data)"],
|
273 |
value="EEGART",
|
274 |
label="Model")
|
275 |
+
run_btn = gr.Button("Run", interactive=False)
|
276 |
+
cancel_btn = gr.Button("Cancel", visible=False)
|
277 |
+
# ---------------------output--------------------
|
278 |
batch_md = gr.Markdown(visible=False)
|
279 |
out_data_file = gr.File(label="Denoised data", visible=False)
|
280 |
# -----------------------------------------------
|
281 |
+
|
282 |
with gr.Row():
|
283 |
with gr.Tab("README"):
|
284 |
gr.Markdown(readme)
|
|
|
336 |
"input_data" : "",
|
337 |
"output_data" : ""
|
338 |
},
|
339 |
+
"totalBatchNum" : None,
|
340 |
+
"sampleRate" : None
|
341 |
}
|
342 |
}
|
343 |
return {app_info_json : app_info,
|
|
|
356 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
357 |
step3_btn : gr.Button(visible=False),
|
358 |
out_json_file : gr.File(value=None, visible=False),
|
|
|
359 |
# --------------------Stage2-------------------------
|
360 |
in_data_file : gr.File(value=None),
|
361 |
in_samplerate : gr.Textbox(value=None),
|
362 |
run_btn : gr.Button(interactive=False),
|
363 |
+
cancel_btn : gr.Button(interactive=False),
|
364 |
batch_md : gr.Markdown(visible=False),
|
365 |
out_data_file : gr.File(value=None, visible=False)}
|
366 |
|
|
|
423 |
if matched_num == 30:
|
424 |
md = """
|
425 |
### Mapping Results
|
426 |
+
The mapping process has been finished.
|
427 |
+
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
428 |
"""
|
429 |
# finalize and save the mapping results
|
430 |
filename = filepath+"mapping_result.json"
|
|
|
440 |
mapped_img : gr.Image(visible=False),
|
441 |
next_btn : gr.Button(visible=False),
|
442 |
out_json_file : gr.File(filename, visible=True),
|
|
|
443 |
run_btn : gr.Button(interactive=True)}
|
444 |
|
445 |
# step1 to step2
|
|
|
483 |
elif in_num == matched_num:
|
484 |
md = """
|
485 |
### Step3: Filling Remaining Template Channels
|
486 |
+
Select one of the methods provided below to fill the remaining template channels.
|
487 |
"""
|
488 |
stage1_info["state"] = "step3-select-method"
|
489 |
app_info["stage1"] = stage1_info
|
|
|
499 |
elif stage1_info["state"] == "step2-selecting":
|
500 |
|
501 |
# --------------------store information before the button click---------------------
|
502 |
+
# if the user has selected an in_channel to forward to the previous target tpl_channel
|
503 |
if selected_radio != []:
|
504 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
505 |
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
|
|
524 |
if len(stage1_info["missingTemplates"]) == 0:
|
525 |
md = """
|
526 |
### Mapping Results
|
527 |
+
The mapping process has been finished.
|
528 |
+
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
529 |
"""
|
530 |
# finalize and save the mapping results
|
531 |
filename = filepath+"mapping_result.json"
|
|
|
539 |
desc_md : gr.Markdown(md),
|
540 |
radio_group : gr.Radio(visible=False),
|
541 |
out_json_file : gr.File(filename, visible=True),
|
|
|
542 |
clear_btn : gr.Button(visible=False),
|
543 |
next_btn : gr.Button(visible=False),
|
544 |
run_btn : gr.Button(interactive=True)}
|
|
|
546 |
else:
|
547 |
md = """
|
548 |
### Step3: Filling Remaining Template Channels
|
549 |
+
Select one of the methods provided below to fill the remaining template channels.
|
550 |
"""
|
551 |
stage1_info["state"] = "step3-select-method"
|
552 |
app_info["stage1"] = stage1_info
|
|
|
566 |
if fillmode == "zero":
|
567 |
md = """
|
568 |
### Mapping Results
|
569 |
+
The mapping process has been finished.
|
570 |
+
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
571 |
"""
|
572 |
# finalize and save the mapping results
|
573 |
filename = filepath+"mapping_result.json"
|
|
|
582 |
in_fillmode : gr.Dropdown(visible=False),
|
583 |
fillmode_btn : gr.Button(visible=False),
|
584 |
out_json_file : gr.File(filename, visible=True),
|
|
|
585 |
run_btn : gr.Button(interactive=True)}
|
586 |
# step3-1 to step3-2
|
587 |
elif fillmode == "mean":
|
|
|
632 |
elif stage1_info["state"] == "step3-2-selecting":
|
633 |
|
634 |
# --------------------store information before the button click---------------------
|
635 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
636 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
637 |
+
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
|
638 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
639 |
+
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
|
|
|
|
|
|
640 |
# ----------------------------------------------------------------------------------
|
641 |
md = """
|
642 |
### Mapping Results
|
643 |
+
The mapping process has been finished.
|
644 |
+
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
645 |
"""
|
646 |
# finalize and save the mapping results
|
647 |
filename = filepath+"mapping_result.json"
|
|
|
656 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
657 |
next_btn : gr.Button(visible=False),
|
658 |
out_json_file : gr.File(filename, visible=True),
|
|
|
659 |
run_btn : gr.Button(interactive=True)}
|
660 |
|
661 |
next_btn.click(
|
662 |
fn = init_next_step,
|
663 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
664 |
outputs = [app_info_json, channel_info_json, desc_md, tpl_img, mapped_img, radio_group, clear_btn, step2_btn,
|
665 |
+
in_fillmode, fillmode_btn, chkbox_group, step3_btn, out_json_file, next_btn, run_btn]
|
666 |
).success(
|
667 |
fn = None,
|
668 |
js = init_js,
|
|
|
679 |
inputs = in_loc_file,
|
680 |
outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
|
681 |
radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
682 |
+
out_json_file, in_data_file, in_samplerate, run_btn, cancel_btn, batch_md, out_data_file]
|
683 |
).success(
|
684 |
fn = init_next_step,
|
685 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
|
|
713 |
def update_radio(app_info, channel_info, selected):
|
714 |
stage1_info = app_info["stage1"]
|
715 |
# ----------------------store information before the button click-----------------------
|
716 |
+
# if the user has selected an in_channel to forward to the previous target tpl_channel
|
717 |
if selected != []:
|
718 |
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
719 |
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
|
|
767 |
def update_chkbox(app_info, channel_info, selected):
|
768 |
stage1_info = app_info["stage1"]
|
769 |
# ----------------------store information before the button click-----------------------
|
770 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
771 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
772 |
+
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
|
773 |
+
stage1_info["mappingData"][0]["newOrder"][prev_target_idx] = selected_indices
|
774 |
+
#print(f'{prev_target_name}({prev_target_idx}): {selected_indices}')
|
|
|
|
|
|
|
775 |
|
776 |
# ------------------------update information for the new round--------------------------
|
777 |
stage1_info["fillingCount"] += 1
|
|
|
798 |
fn = init_next_step,
|
799 |
inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
|
800 |
outputs = [app_info_json, channel_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
|
801 |
+
out_json_file, next_btn, run_btn]
|
802 |
).success(
|
803 |
fn = None,
|
804 |
js = init_js,
|
|
|
845 |
})
|
846 |
app_info["stage2"] = stage2_info
|
847 |
return {app_info_json : app_info,
|
848 |
+
run_btn : gr.Button(visible=False),
|
849 |
+
cancel_btn : gr.Button(visible=True, interactive=True),
|
850 |
+
batch_md : gr.Markdown("", visible=True),
|
851 |
out_data_file : gr.File(visible=False)}
|
852 |
|
853 |
def run_model(app_info, modelname):
|
|
|
864 |
# establish a temp folder
|
865 |
try:
|
866 |
os.mkdir(filepath+"temp_data/")
|
|
|
|
|
|
|
867 |
except FileNotFoundError:
|
868 |
print('break1')
|
869 |
break_flag = True
|
870 |
break
|
|
|
|
|
871 |
|
872 |
# update the running status
|
873 |
md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
|
874 |
+
yield {batch_md : gr.Markdown(md)}
|
875 |
|
876 |
# get the mapped index order and the filled status for each tpl_channels
|
877 |
new_idx = stage1_info["mappingData"][i]["newOrder"]
|
|
|
894 |
utils.dataDelete(filepath+"temp_data/")
|
895 |
|
896 |
if break_flag == True:
|
897 |
+
yield {run_btn : gr.Button(visible=True),
|
898 |
+
cancel_btn : gr.Button(visible=False)}
|
899 |
else:
|
900 |
+
yield {run_btn : gr.Button(visible=True),
|
901 |
+
cancel_btn : gr.Button(visible=False),
|
902 |
batch_md : gr.Markdown(visible=False),
|
903 |
out_data_file : gr.File(new_filename, visible=True)}
|
904 |
|
905 |
+
@cancel_btn.click(inputs = app_info_json, outputs = [cancel_btn, batch_md])
|
906 |
+
def stop_processing(app_info):
|
907 |
+
utils.dataDelete(app_info["stage2"]["filePath"])
|
908 |
+
return gr.Button(interactive=False), gr.Markdown(visible=False)
|
909 |
+
|
910 |
run_btn.click(
|
911 |
fn = reset_run,
|
912 |
inputs = [app_info_json, in_data_file, in_samplerate, in_modelname],
|
913 |
+
outputs = [app_info_json, run_btn, cancel_btn, batch_md, out_data_file]
|
|
|
914 |
).success(
|
915 |
fn = run_model,
|
916 |
inputs = [app_info_json, in_modelname],
|
917 |
+
outputs = [run_btn, cancel_btn, batch_md, out_data_file]
|
918 |
)
|
919 |
|
920 |
if __name__ == "__main__":
|
921 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
922 |
|
app_utils.py
CHANGED
@@ -1,358 +1,359 @@
|
|
1 |
-
import utils
|
2 |
-
import os
|
3 |
-
import math
|
4 |
-
import json
|
5 |
-
import numpy as np
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
-
import mne
|
8 |
-
from mne.channels import read_custom_montage
|
9 |
-
from scipy.interpolate import Rbf
|
10 |
-
from scipy.optimize import linear_sum_assignment
|
11 |
-
from sklearn.neighbors import NearestNeighbors
|
12 |
-
|
13 |
-
def reorder_data(idx_order, fill_flags, filename, new_filename):
|
14 |
-
# read the input data
|
15 |
-
raw_data = utils.read_train_data(filename)
|
16 |
-
#print(raw_data.shape)
|
17 |
-
new_data = np.zeros((30, raw_data.shape[1]))
|
18 |
-
|
19 |
-
zero_arr = np.zeros((1, raw_data.shape[1]))
|
20 |
-
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
21 |
-
if flag == False:
|
22 |
-
new_data[i, :] = raw_data[idx_set[0], :]
|
23 |
-
elif idx_set == []:
|
24 |
-
new_data[i, :] = zero_arr
|
25 |
-
else:
|
26 |
-
tmp_data = [raw_data[j, :] for j in idx_set]
|
27 |
-
new_data[i, :] = np.mean(tmp_data, axis=0)
|
28 |
-
|
29 |
-
utils.save_data(new_data, new_filename)
|
30 |
-
return raw_data.shape
|
31 |
-
|
32 |
-
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
|
33 |
-
# read the denoised data
|
34 |
-
d_data = utils.read_train_data(filename)
|
35 |
-
if batch_cnt == 0:
|
36 |
-
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
37 |
-
#print(new_data.shape)
|
38 |
-
else:
|
39 |
-
new_data = utils.read_train_data(new_filename)
|
40 |
-
|
41 |
-
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
42 |
-
# ignore if this channel was filled using "fillmode"
|
43 |
-
if flag == False:
|
44 |
-
new_data[idx_set[0], :] = d_data[i, :]
|
45 |
-
|
46 |
-
utils.save_data(new_data, new_filename)
|
47 |
-
return
|
48 |
-
|
49 |
-
def get_matched(tpl_order, tpl_dict):
|
50 |
-
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
|
51 |
-
|
52 |
-
def get_empty_templates(tpl_order, tpl_dict):
|
53 |
-
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
54 |
-
|
55 |
-
def get_unassigned_inputs(in_order, in_dict):
|
56 |
-
return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
|
57 |
-
|
58 |
-
def read_montage_data(loc_file):
|
59 |
-
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
-
in_montage = read_custom_montage(loc_file)
|
61 |
-
tpl_order = tpl_montage.ch_names
|
62 |
-
in_order = in_montage.ch_names
|
63 |
-
tpl_dict = {}
|
64 |
-
in_dict = {}
|
65 |
-
|
66 |
-
# convert all channel names to uppercase and store the channel information
|
67 |
-
for i, channel in enumerate(tpl_order):
|
68 |
-
up_channel = str.upper(channel)
|
69 |
-
tpl_montage.rename_channels({channel: up_channel})
|
70 |
-
tpl_dict[up_channel] = {
|
71 |
-
"index" : i,
|
72 |
-
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
|
73 |
-
"matched" : False
|
74 |
-
}
|
75 |
-
for i, channel in enumerate(in_order):
|
76 |
-
up_channel = str.upper(channel)
|
77 |
-
in_montage.rename_channels({channel: up_channel})
|
78 |
-
in_dict[up_channel] = {
|
79 |
-
"index" : i,
|
80 |
-
"coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
|
81 |
-
"assigned" : False
|
82 |
-
}
|
83 |
-
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
-
|
85 |
-
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
-
tpl_order = channel_info["templateOrder"]
|
87 |
-
in_order = channel_info["inputOrder"]
|
88 |
-
tpl_dict = channel_info["templateDict"]
|
89 |
-
in_dict = channel_info["inputDict"]
|
90 |
-
|
91 |
-
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
92 |
-
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
93 |
-
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
94 |
-
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
95 |
-
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
-
in_coords = np.vstack((in_x, in_y)).T
|
97 |
-
|
98 |
-
# extract template's head figure
|
99 |
-
tpl_fig = tpl_montage.plot()
|
100 |
-
tpl_ax = tpl_fig.axes[0]
|
101 |
-
lines = tpl_ax.lines
|
102 |
-
head_lines = []
|
103 |
-
for line in lines:
|
104 |
-
x, y = line.get_data()
|
105 |
-
head_lines.append((x,y))
|
106 |
-
plt.close()
|
107 |
-
|
108 |
-
# -------------------------plot input montage------------------------------
|
109 |
-
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
110 |
-
ax = fig.add_subplot(111)
|
111 |
-
fig.tight_layout()
|
112 |
-
ax.set_aspect('equal')
|
113 |
-
ax.axis('off')
|
114 |
-
|
115 |
-
# plot template's head
|
116 |
-
for x, y in head_lines:
|
117 |
-
ax.plot(x, y, color='black', linewidth=1.0)
|
118 |
-
# plot in_channels on it
|
119 |
-
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
120 |
-
for i, channel in enumerate(in_order):
|
121 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
122 |
-
# save input_montage
|
123 |
-
fig.savefig(filename1)
|
124 |
-
|
125 |
-
# ---------------------------add indications-------------------------------
|
126 |
-
# plot unmatched input channels in red
|
127 |
-
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
#
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
"
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
"
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
knn
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
'
|
244 |
-
'
|
245 |
-
'
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
channel =
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
"
|
265 |
-
"
|
266 |
-
|
267 |
-
|
268 |
-
"
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
"
|
275 |
-
|
276 |
-
"
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
#
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
"
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
"
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
#"
|
350 |
-
"
|
351 |
-
"
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
|
|
|
1 |
+
import utils
|
2 |
+
import os
|
3 |
+
import math
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import mne
|
8 |
+
from mne.channels import read_custom_montage
|
9 |
+
from scipy.interpolate import Rbf
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
from sklearn.neighbors import NearestNeighbors
|
12 |
+
|
13 |
+
def reorder_data(idx_order, fill_flags, filename, new_filename):
|
14 |
+
# read the input data
|
15 |
+
raw_data = utils.read_train_data(filename)
|
16 |
+
#print(raw_data.shape)
|
17 |
+
new_data = np.zeros((30, raw_data.shape[1]))
|
18 |
+
|
19 |
+
zero_arr = np.zeros((1, raw_data.shape[1]))
|
20 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
21 |
+
if flag == False:
|
22 |
+
new_data[i, :] = raw_data[idx_set[0], :]
|
23 |
+
elif idx_set == []:
|
24 |
+
new_data[i, :] = zero_arr
|
25 |
+
else:
|
26 |
+
tmp_data = [raw_data[j, :] for j in idx_set]
|
27 |
+
new_data[i, :] = np.mean(tmp_data, axis=0)
|
28 |
+
|
29 |
+
utils.save_data(new_data, new_filename)
|
30 |
+
return raw_data.shape
|
31 |
+
|
32 |
+
def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, new_filename):
|
33 |
+
# read the denoised data
|
34 |
+
d_data = utils.read_train_data(filename)
|
35 |
+
if batch_cnt == 0:
|
36 |
+
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
37 |
+
#print(new_data.shape)
|
38 |
+
else:
|
39 |
+
new_data = utils.read_train_data(new_filename)
|
40 |
+
|
41 |
+
for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
|
42 |
+
# ignore if this channel was filled using "fillmode"
|
43 |
+
if flag == False:
|
44 |
+
new_data[idx_set[0], :] = d_data[i, :]
|
45 |
+
|
46 |
+
utils.save_data(new_data, new_filename)
|
47 |
+
return
|
48 |
+
|
49 |
+
def get_matched(tpl_order, tpl_dict):
|
50 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==True]
|
51 |
+
|
52 |
+
def get_empty_templates(tpl_order, tpl_dict):
|
53 |
+
return [channel for channel in tpl_order if tpl_dict[channel]["matched"]==False]
|
54 |
+
|
55 |
+
def get_unassigned_inputs(in_order, in_dict):
|
56 |
+
return [channel for channel in in_order if in_dict[channel]["assigned"]==False]
|
57 |
+
|
58 |
+
def read_montage_data(loc_file):
|
59 |
+
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
+
in_montage = read_custom_montage(loc_file)
|
61 |
+
tpl_order = tpl_montage.ch_names
|
62 |
+
in_order = in_montage.ch_names
|
63 |
+
tpl_dict = {}
|
64 |
+
in_dict = {}
|
65 |
+
|
66 |
+
# convert all channel names to uppercase and store the channel information
|
67 |
+
for i, channel in enumerate(tpl_order):
|
68 |
+
up_channel = str.upper(channel)
|
69 |
+
tpl_montage.rename_channels({channel: up_channel})
|
70 |
+
tpl_dict[up_channel] = {
|
71 |
+
"index" : i,
|
72 |
+
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_channel],
|
73 |
+
"matched" : False
|
74 |
+
}
|
75 |
+
for i, channel in enumerate(in_order):
|
76 |
+
up_channel = str.upper(channel)
|
77 |
+
in_montage.rename_channels({channel: up_channel})
|
78 |
+
in_dict[up_channel] = {
|
79 |
+
"index" : i,
|
80 |
+
"coord_3d" : in_montage.get_positions()['ch_pos'][up_channel],
|
81 |
+
"assigned" : False
|
82 |
+
}
|
83 |
+
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
+
|
85 |
+
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
+
tpl_order = channel_info["templateOrder"]
|
87 |
+
in_order = channel_info["inputOrder"]
|
88 |
+
tpl_dict = channel_info["templateDict"]
|
89 |
+
in_dict = channel_info["inputDict"]
|
90 |
+
|
91 |
+
tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
|
92 |
+
tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
|
93 |
+
in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
|
94 |
+
in_y = [in_dict[channel]["coord_2d"][1] for channel in in_order]
|
95 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
+
in_coords = np.vstack((in_x, in_y)).T
|
97 |
+
|
98 |
+
# extract template's head figure
|
99 |
+
tpl_fig = tpl_montage.plot()
|
100 |
+
tpl_ax = tpl_fig.axes[0]
|
101 |
+
lines = tpl_ax.lines
|
102 |
+
head_lines = []
|
103 |
+
for line in lines:
|
104 |
+
x, y = line.get_data()
|
105 |
+
head_lines.append((x,y))
|
106 |
+
plt.close()
|
107 |
+
|
108 |
+
# -------------------------plot input montage------------------------------
|
109 |
+
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
110 |
+
ax = fig.add_subplot(111)
|
111 |
+
fig.tight_layout()
|
112 |
+
ax.set_aspect('equal')
|
113 |
+
ax.axis('off')
|
114 |
+
|
115 |
+
# plot template's head
|
116 |
+
for x, y in head_lines:
|
117 |
+
ax.plot(x, y, color='black', linewidth=1.0)
|
118 |
+
# plot in_channels on it
|
119 |
+
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
120 |
+
for i, channel in enumerate(in_order):
|
121 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
122 |
+
# save input_montage
|
123 |
+
fig.savefig(filename1)
|
124 |
+
|
125 |
+
# ---------------------------add indications-------------------------------
|
126 |
+
# plot unmatched input channels in red
|
127 |
+
indices = [in_dict[channel]["index"] for channel in in_order if in_dict[channel]["assigned"]==False]
|
128 |
+
if indices != []:
|
129 |
+
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
130 |
+
for i in indices:
|
131 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_order[i], color='red', fontsize=10.0, va='center')
|
132 |
+
# save mapped_montage
|
133 |
+
fig.savefig(filename2)
|
134 |
+
|
135 |
+
# -------------------------------------------------------------------------
|
136 |
+
# store the tpl and in_channels' display positions (in px).
|
137 |
+
tpl_coords = ax.transData.transform(tpl_coords)
|
138 |
+
in_coords = ax.transData.transform(in_coords)
|
139 |
+
plt.close()
|
140 |
+
|
141 |
+
for i, channel in enumerate(tpl_order):
|
142 |
+
css_left = (tpl_coords[i,0]-11)/6.4
|
143 |
+
css_bottom = (tpl_coords[i,1]-7)/6.4
|
144 |
+
tpl_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
145 |
+
for i, channel in enumerate(in_order):
|
146 |
+
css_left = (in_coords[i,0]-11)/6.4
|
147 |
+
css_bottom = (in_coords[i,1]-7)/6.4
|
148 |
+
in_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
149 |
+
|
150 |
+
channel_info.update({
|
151 |
+
"templateDict" : tpl_dict,
|
152 |
+
"inputDict" : in_dict
|
153 |
+
})
|
154 |
+
return channel_info
|
155 |
+
|
156 |
+
def align_coords(channel_info, tpl_montage, in_montage):
|
157 |
+
tpl_order = channel_info["templateOrder"]
|
158 |
+
in_order = channel_info["inputOrder"]
|
159 |
+
tpl_dict = channel_info["templateDict"]
|
160 |
+
in_dict = channel_info["inputDict"]
|
161 |
+
matched = get_matched(tpl_order, tpl_dict)
|
162 |
+
|
163 |
+
# 2D alignment (for visualization purposes)
|
164 |
+
fig = [tpl_montage.plot(), in_montage.plot()]
|
165 |
+
ax = [fig[0].axes[0], fig[1].axes[0]]
|
166 |
+
|
167 |
+
# extract the displayed 2D coordinates from the plots
|
168 |
+
all_tpl = ax[0].collections[0].get_offsets().data
|
169 |
+
all_in= ax[1].collections[0].get_offsets().data
|
170 |
+
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
171 |
+
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
172 |
+
|
173 |
+
# apply TPS to transform in_channels positions to align with tpl_channels positions
|
174 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
175 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
176 |
+
|
177 |
+
# apply the transformation to all in_channels
|
178 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
179 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
180 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
181 |
+
|
182 |
+
# store the 2D positions
|
183 |
+
for i, channel in enumerate(tpl_order):
|
184 |
+
tpl_dict[channel]["coord_2d"] = all_tpl[i]
|
185 |
+
for i, channel in enumerate(in_order):
|
186 |
+
in_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
187 |
+
|
188 |
+
|
189 |
+
# 3D alignment
|
190 |
+
all_tpl = np.array([tpl_dict[channel]["coord_3d"].tolist() for channel in tpl_order])
|
191 |
+
all_in = np.array([in_dict[channel]["coord_3d"].tolist() for channel in in_order])
|
192 |
+
matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
|
193 |
+
matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
|
194 |
+
|
195 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
196 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
197 |
+
rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
|
198 |
+
|
199 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
|
200 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
+
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
202 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
203 |
+
|
204 |
+
# update in_channels' 3D positions
|
205 |
+
for i, channel in enumerate(in_order):
|
206 |
+
in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
207 |
+
|
208 |
+
channel_info.update({
|
209 |
+
"templateDict" : tpl_dict,
|
210 |
+
"inputDict" : in_dict
|
211 |
+
})
|
212 |
+
return channel_info
|
213 |
+
|
214 |
+
def find_neighbors(channel_info, missing_channels, new_idx):
|
215 |
+
in_order = channel_info["inputOrder"]
|
216 |
+
tpl_dict = channel_info["templateDict"]
|
217 |
+
in_dict = channel_info["inputDict"]
|
218 |
+
|
219 |
+
all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
|
220 |
+
empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
|
221 |
+
|
222 |
+
# use KNN to choose k nearest channels
|
223 |
+
k = 4 if len(in_order)>4 else len(in_order)
|
224 |
+
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
225 |
+
knn.fit(all_in)
|
226 |
+
for i, channel in enumerate(missing_channels):
|
227 |
+
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
228 |
+
idx = tpl_dict[channel]["index"]
|
229 |
+
new_idx[idx] = indices[0].tolist()
|
230 |
+
|
231 |
+
return new_idx
|
232 |
+
|
233 |
+
def match_names(stage1_info, channel_info):
|
234 |
+
# read the location file
|
235 |
+
loc_file = stage1_info["fileNames"]["input_loc"]
|
236 |
+
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
237 |
+
tpl_order = tpl_montage.ch_names
|
238 |
+
in_order = in_montage.ch_names
|
239 |
+
new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channels
|
240 |
+
fill_flags = [True]*30 # record if each tpl_channel's data is filled by "fillmode"
|
241 |
+
|
242 |
+
alias_dict = {
|
243 |
+
'T3': 'T7',
|
244 |
+
'T4': 'T8',
|
245 |
+
'T5': 'P7',
|
246 |
+
'T6': 'P8'
|
247 |
+
}
|
248 |
+
for i, channel in enumerate(tpl_order):
|
249 |
+
if channel in alias_dict and alias_dict[channel] in in_dict:
|
250 |
+
tpl_montage.rename_channels({channel: alias_dict[channel]})
|
251 |
+
tpl_dict[alias_dict[channel]] = tpl_dict.pop(channel)
|
252 |
+
channel = alias_dict[channel]
|
253 |
+
|
254 |
+
if channel in in_dict:
|
255 |
+
new_idx[i] = [in_dict[channel]["index"]]
|
256 |
+
fill_flags[i] = False
|
257 |
+
tpl_dict[channel]["matched"] = True
|
258 |
+
in_dict[channel]["assigned"] = True
|
259 |
+
|
260 |
+
# update the names
|
261 |
+
tpl_order = tpl_montage.ch_names
|
262 |
+
|
263 |
+
stage1_info.update({
|
264 |
+
"unassignedInputs" : get_unassigned_inputs(in_order, in_dict),
|
265 |
+
"missingTemplates" : get_empty_templates(tpl_order, tpl_dict),
|
266 |
+
"mappingData" : [
|
267 |
+
{
|
268 |
+
"newOrder" : new_idx,
|
269 |
+
"fillFlags" : fill_flags
|
270 |
+
}
|
271 |
+
]
|
272 |
+
})
|
273 |
+
channel_info.update({
|
274 |
+
"templateOrder" : tpl_order,
|
275 |
+
"inputOrder" : in_order,
|
276 |
+
"templateDict" : tpl_dict,
|
277 |
+
"inputDict" : in_dict
|
278 |
+
})
|
279 |
+
return stage1_info, channel_info, tpl_montage, in_montage
|
280 |
+
|
281 |
+
def optimal_mapping(channel_info):
|
282 |
+
tpl_order = channel_info["templateOrder"]
|
283 |
+
in_order = channel_info["inputOrder"]
|
284 |
+
tpl_dict = channel_info["templateDict"]
|
285 |
+
in_dict = channel_info["inputDict"]
|
286 |
+
unassigned = get_unassigned_inputs(in_order, in_dict)
|
287 |
+
# reset all tpl.matched to False
|
288 |
+
for channel in tpl_dict:
|
289 |
+
tpl_dict[channel]["matched"] = False
|
290 |
+
|
291 |
+
all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
|
292 |
+
unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
|
293 |
+
|
294 |
+
# initialize the cost matrix for the Hungarian algorithm
|
295 |
+
if len(unassigned) < 30:
|
296 |
+
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
297 |
+
else:
|
298 |
+
cost_matrix = np.zeros((30, len(unassigned)))
|
299 |
+
# fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
|
300 |
+
for i in range(30):
|
301 |
+
for j in range(len(unassigned)):
|
302 |
+
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
|
303 |
+
|
304 |
+
# apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
|
305 |
+
# by minimizing the total distances between their positions.
|
306 |
+
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
307 |
+
|
308 |
+
# store the mapping results
|
309 |
+
new_idx = [[]]*30
|
310 |
+
fill_flags = [True]*30
|
311 |
+
for i, j in zip(row_idx, col_idx):
|
312 |
+
if j < len(unassigned): # filter out dummy channels
|
313 |
+
tpl_channel = tpl_order[i]
|
314 |
+
in_channel = unassigned[j]
|
315 |
+
|
316 |
+
new_idx[i] = [in_dict[in_channel]["index"]]
|
317 |
+
fill_flags[i] = False
|
318 |
+
tpl_dict[tpl_channel]["matched"] = True
|
319 |
+
in_dict[in_channel]["assigned"] = True
|
320 |
+
#print(f'{tpl_channel}({i}) <- {in_channel}({j})')
|
321 |
+
|
322 |
+
# fill the remaining empty tpl_channels
|
323 |
+
missing_channels = get_empty_templates(tpl_order, tpl_dict)
|
324 |
+
if missing_channels != []:
|
325 |
+
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
326 |
+
|
327 |
+
mapping_data = {
|
328 |
+
"newOrder" : new_idx,
|
329 |
+
"fillFlags" : fill_flags
|
330 |
+
}
|
331 |
+
channel_info.update({
|
332 |
+
"templateDict" : tpl_dict,
|
333 |
+
"inputDict" : in_dict
|
334 |
+
})
|
335 |
+
return mapping_data, channel_info
|
336 |
+
|
337 |
+
def mapping_result(stage1_info, stage2_info, channel_info, filename):
|
338 |
+
unassigned_num = len(stage1_info["unassignedInputs"])
|
339 |
+
batch_num = math.ceil(unassigned_num/30) + 1
|
340 |
+
|
341 |
+
# map the remaining in_channels
|
342 |
+
for i in range(1, batch_num):
|
343 |
+
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
344 |
+
new_mapping_data, channel_info = optimal_mapping(channel_info)
|
345 |
+
stage1_info["mappingData"] += [new_mapping_data]
|
346 |
+
|
347 |
+
# save the mapping results
|
348 |
+
new_dict = {
|
349 |
+
#"templateOrder" : channel_info["templateOrder"],
|
350 |
+
#"inputOrder" : channel_info["inputOrder"],
|
351 |
+
"batchNum" : batch_num,
|
352 |
+
"mappingData" : stage1_info["mappingData"]
|
353 |
+
}
|
354 |
+
with open(filename, 'w') as jsonfile:
|
355 |
+
jsonfile.write(json.dumps(new_dict))
|
356 |
+
|
357 |
+
stage2_info["totalBatchNum"] = batch_num
|
358 |
+
return stage1_info, stage2_info, channel_info
|
359 |
+
|
utils.py
CHANGED
@@ -143,7 +143,8 @@ def dataDelete(path):
|
|
143 |
try:
|
144 |
shutil.rmtree(path)
|
145 |
except OSError as e:
|
146 |
-
|
|
|
147 |
else:
|
148 |
pass
|
149 |
#print("The directory is deleted successfully")
|
|
|
143 |
try:
|
144 |
shutil.rmtree(path)
|
145 |
except OSError as e:
|
146 |
+
pass
|
147 |
+
#print(e)
|
148 |
else:
|
149 |
pass
|
150 |
#print("The directory is deleted successfully")
|