Spaces:
Sleeping
Sleeping
Commit
·
857bb0f
1
Parent(s):
8b18526
update comment
Browse files- app.py +487 -337
- channel_mapping.py +46 -66
- utils.py +9 -8
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
|
|
|
3 |
import os
|
4 |
import random
|
5 |
import math
|
@@ -9,7 +10,7 @@ import mne
|
|
9 |
from mne.channels import read_custom_montage
|
10 |
|
11 |
import utils
|
12 |
-
from channel_mapping import mapping_stage1, mapping_stage2,
|
13 |
|
14 |
|
15 |
quickstart = """
|
@@ -20,7 +21,8 @@ quickstart = """
|
|
20 |
|
21 |
## Channel locations
|
22 |
Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
|
23 |
-
|
|
|
24 |
|
25 |
## Mapping
|
26 |
(...)
|
@@ -55,32 +57,33 @@ Electroencephalography (EEG) signals are often contaminated with artifacts. It i
|
|
55 |
"""
|
56 |
|
57 |
init_js = """
|
58 |
-
(
|
59 |
-
|
60 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
|
|
61 |
|
62 |
let selector, classname, attribute;
|
63 |
let channel, left, bottom;
|
64 |
|
65 |
-
if(
|
66 |
selector = "#radio-group > div:nth-of-type(2)";
|
67 |
//classname = "radio";
|
68 |
attribute = "value";
|
69 |
-
}else if(
|
70 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
71 |
//classname = "chkbox";
|
72 |
attribute = "name";
|
73 |
}else return;
|
74 |
|
75 |
|
76 |
-
// add figure of the
|
77 |
document.querySelector(selector).style.cssText = `
|
78 |
position: relative;
|
79 |
width: 100%;
|
80 |
aspect-ratio: 1;
|
81 |
//width: 560px;
|
82 |
//height: 560px;
|
83 |
-
background: url("file=${
|
84 |
background-size: contain;
|
85 |
|
86 |
`;
|
@@ -99,13 +102,13 @@ init_js = """
|
|
99 |
|
100 |
|
101 |
// add indication for the missing channels
|
102 |
-
channel =
|
103 |
left = channel_info.templateDict[channel].css_position[0];
|
104 |
bottom = channel_info.templateDict[channel].css_position[1];
|
105 |
|
106 |
let dot_rule = `
|
107 |
${selector}::before {
|
108 |
-
content:
|
109 |
position: absolute;
|
110 |
background-color: red;
|
111 |
width: 10px;
|
@@ -145,14 +148,15 @@ init_js = """
|
|
145 |
"""
|
146 |
|
147 |
update_js = """
|
148 |
-
(
|
149 |
-
|
150 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
|
|
151 |
|
152 |
let selector;
|
153 |
let channel, left, bottom;
|
154 |
|
155 |
-
if(
|
156 |
selector = "#radio-group > div:nth-of-type(2)";
|
157 |
|
158 |
// update the radios
|
@@ -166,12 +170,12 @@ update_js = """
|
|
166 |
item.className = "";
|
167 |
item.querySelector(":scope > span").innerText = "";
|
168 |
});
|
169 |
-
}else if(
|
170 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
171 |
}else return;
|
172 |
|
173 |
// update indication
|
174 |
-
channel =
|
175 |
left = channel_info.templateDict[channel].css_position[0];
|
176 |
bottom = channel_info.templateDict[channel].css_position[1];
|
177 |
|
@@ -202,7 +206,7 @@ update_js = """
|
|
202 |
}
|
203 |
`;
|
204 |
|
205 |
-
//
|
206 |
const styleSheet = document.styleSheets[0];
|
207 |
for(let i=0; i<styleSheet.cssRules.length; i++){
|
208 |
let tmp = styleSheet.cssRules[i].selectorText;
|
@@ -219,7 +223,7 @@ update_js = """
|
|
219 |
|
220 |
with gr.Blocks() as demo:
|
221 |
|
222 |
-
|
223 |
channel_info_json = gr.JSON(visible=False)
|
224 |
|
225 |
with gr.Row():
|
@@ -234,27 +238,25 @@ with gr.Blocks() as demo:
|
|
234 |
gr.Markdown("# 1.Channel Mapping")
|
235 |
# ------------------------input--------------------------
|
236 |
with gr.Row():
|
237 |
-
|
238 |
-
|
|
|
239 |
with gr.Row():
|
240 |
in_samplerate = gr.Textbox(label="Sampling rate (Hz)", scale=2)
|
241 |
-
map_btn = gr.Button("Mapping", scale=1)
|
242 |
|
243 |
# ------------------------mapping------------------------
|
244 |
# description for stage1-123
|
245 |
-
desc_md = gr.Markdown(
|
246 |
-
|
247 |
# stage1-1 : mapping result
|
248 |
with gr.Row():
|
249 |
-
tpl_montage = gr.Image("./template_montage.png", label="Template
|
250 |
mapped_montage = gr.Image(label="Input channels", visible=False)
|
251 |
-
|
252 |
# stage1-2 : assign unmatched input channels to empty template channels
|
253 |
radio_group = gr.Radio(elem_id="radio-group", visible=False)
|
254 |
-
|
255 |
# stage1-3 : select a way to fill the empty template channels
|
256 |
with gr.Row():
|
257 |
-
|
258 |
value="mean",
|
259 |
label="Filling method",
|
260 |
visible=False,
|
@@ -263,7 +265,7 @@ with gr.Blocks() as demo:
|
|
263 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
264 |
|
265 |
with gr.Row():
|
266 |
-
clear_btn = gr.Button("Clear", visible=False)
|
267 |
step2_btn = gr.Button("Next", visible=False)
|
268 |
step3_btn = gr.Button("Next", visible=False)
|
269 |
next_btn = gr.Button("Next step", visible=False)
|
@@ -273,13 +275,11 @@ with gr.Blocks() as demo:
|
|
273 |
gr.Markdown("# 2.Decode Data")
|
274 |
# ------------------------input--------------------------
|
275 |
with gr.Row():
|
276 |
-
|
277 |
("ART", "EEGART"),
|
278 |
("IC-U-Net", "ICUNet"),
|
279 |
("IC-U-Net++", "UNetpp"),
|
280 |
-
("IC-U-Net-Attn", "AttUnet"),
|
281 |
-
"(mapped data)",
|
282 |
-
"(denoised data)"],
|
283 |
value="EEGART",
|
284 |
label="Model",
|
285 |
scale=2)
|
@@ -287,8 +287,7 @@ with gr.Blocks() as demo:
|
|
287 |
|
288 |
# ------------------------output-------------------------
|
289 |
batch_md = gr.Markdown(visible=False)
|
290 |
-
|
291 |
-
|
292 |
# -------------------------------------------------------
|
293 |
|
294 |
with gr.Row():
|
@@ -302,59 +301,91 @@ with gr.Blocks() as demo:
|
|
302 |
gr.Markdown()
|
303 |
with gr.Tab("QuickStart"):
|
304 |
gr.Markdown(quickstart)
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
|
|
|
|
|
|
320 |
try:
|
321 |
-
os.mkdir(filepath+"/
|
322 |
except OSError as e:
|
323 |
-
utils.dataDelete(filepath+"/
|
324 |
-
os.mkdir(filepath+"/
|
325 |
-
|
|
|
|
|
|
|
326 |
|
327 |
-
# initialize channel_info,
|
328 |
channel_info = {}
|
329 |
-
|
330 |
-
"
|
331 |
-
"
|
332 |
-
"
|
333 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
}
|
335 |
-
|
336 |
# reset layout
|
337 |
-
return {
|
338 |
channel_info_json : channel_info,
|
339 |
# ------------------stage1-----------------------
|
340 |
-
|
|
|
341 |
tpl_montage : gr.Image(visible=False),
|
342 |
mapped_montage : gr.Image(value=None, visible=False),
|
343 |
radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
|
344 |
-
|
345 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
346 |
fillmode_btn : gr.Button(visible=False),
|
347 |
clear_btn : gr.Button(visible=False),
|
348 |
step2_btn : gr.Button(visible=False),
|
349 |
step3_btn : gr.Button(visible=False),
|
350 |
-
next_btn : gr.Button(visible=False),
|
351 |
# ------------------stage2-----------------------
|
352 |
run_btn : gr.Button(interactive=False),
|
353 |
batch_md : gr.Markdown(visible=False),
|
354 |
-
|
355 |
|
356 |
|
357 |
-
#
|
|
|
|
|
358 |
def save_figures(channel_info, filename1, filename2):
|
359 |
|
360 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
@@ -363,6 +394,14 @@ with gr.Blocks() as demo:
|
|
363 |
template_order = channel_info["templateOrder"]
|
364 |
input_order = channel_info["inputOrder"]
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
# get template's head figure
|
367 |
tpl_fig = template_montage.plot()
|
368 |
tpl_ax = tpl_fig.axes[0]
|
@@ -373,14 +412,6 @@ with gr.Blocks() as demo:
|
|
373 |
head_lines.append((x,y))
|
374 |
plt.close()
|
375 |
|
376 |
-
# get template's and input's 2d coords
|
377 |
-
tpl_x = [template_dict[channel]["coord_2d"][0] for channel in template_order]
|
378 |
-
tpl_y = [template_dict[channel]["coord_2d"][1] for channel in template_order]
|
379 |
-
in_x = [input_dict[channel]["coord_2d"][0] for channel in input_order]
|
380 |
-
in_y = [input_dict[channel]["coord_2d"][1] for channel in input_order]
|
381 |
-
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
382 |
-
in_coords = np.vstack((in_x, in_y)).T
|
383 |
-
|
384 |
# -------------------------plot input montage------------------------------
|
385 |
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
386 |
ax = fig.add_subplot(111)
|
@@ -391,12 +422,11 @@ with gr.Blocks() as demo:
|
|
391 |
# plot template's head
|
392 |
for x, y in head_lines:
|
393 |
ax.plot(x, y, color='black', linewidth=1.0)
|
394 |
-
# plot input channels
|
395 |
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
396 |
for i, channel in enumerate(input_order):
|
397 |
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
398 |
-
|
399 |
-
# save raw_montage
|
400 |
fig.savefig(filename1)
|
401 |
|
402 |
# ---------------------------add indications-------------------------------
|
@@ -406,15 +436,14 @@ with gr.Blocks() as demo:
|
|
406 |
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
407 |
for i in indices:
|
408 |
ax.text(in_coords[i,0]+0.003, in_coords[i,1], input_order[i], color='red', fontsize=10.0, va='center')
|
409 |
-
|
410 |
# save mapped_montage
|
411 |
fig.savefig(filename2)
|
412 |
-
plt.close()
|
413 |
|
414 |
# -------------------------------------------------------------------------
|
415 |
-
#
|
416 |
tpl_coords = ax.transData.transform(tpl_coords)
|
417 |
in_coords = ax.transData.transform(in_coords)
|
|
|
418 |
|
419 |
for i, channel in enumerate(template_order):
|
420 |
css_left = (tpl_coords[i,0]-11)/6.4
|
@@ -432,156 +461,189 @@ with gr.Blocks() as demo:
|
|
432 |
})
|
433 |
return channel_info
|
434 |
|
435 |
-
def mapping_result(
|
436 |
-
|
437 |
-
|
|
|
|
|
|
|
438 |
filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
439 |
channel_info = save_figures(channel_info, filename1, filename2)
|
440 |
-
|
441 |
-
|
442 |
-
"raw_montage" : filename1,
|
443 |
"mapped_montage" : filename2
|
444 |
})
|
445 |
|
446 |
-
#
|
447 |
|
448 |
-
|
449 |
-
matched_num = 30 - len(
|
450 |
|
451 |
-
# if the
|
452 |
# -> stage2
|
453 |
if matched_num == 30:
|
454 |
-
|
455 |
gr.Info('The mapping process has been finished.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
-
|
|
|
458 |
channel_info_json : channel_info,
|
459 |
-
|
|
|
460 |
tpl_montage : gr.Image(visible=True),
|
461 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
462 |
run_btn : gr.Button(interactive=True)}
|
463 |
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
|
481 |
-
map_btn.click(
|
482 |
fn = reset_all,
|
483 |
-
inputs = [
|
484 |
-
outputs = [
|
485 |
-
|
486 |
-
run_btn, batch_md,
|
487 |
).success(
|
488 |
fn = mapping_stage1,
|
489 |
-
inputs = [
|
490 |
-
outputs = [
|
491 |
-
|
492 |
).success(
|
493 |
fn = mapping_result,
|
494 |
-
inputs = [
|
495 |
-
outputs = [
|
496 |
)
|
497 |
|
498 |
-
|
499 |
-
|
|
|
|
|
|
|
|
|
500 |
|
501 |
# stage1-1 -> stage1-2
|
502 |
-
if
|
503 |
-
print('step1 -> step2')
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
|
|
|
|
|
|
|
|
508 |
"fillingCount" : 1,
|
509 |
-
"totalFillingNum" : len(
|
510 |
})
|
|
|
|
|
511 |
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
return {app_state_json : app_state,
|
517 |
channel_info_json : channel_info,
|
518 |
-
desc_md : gr.Markdown(
|
519 |
tpl_montage : gr.Image(visible=False),
|
520 |
mapped_montage : gr.Image(visible=False),
|
521 |
-
radio_group : gr.Radio(choices=
|
522 |
clear_btn : gr.Button(visible=True),
|
523 |
next_btn : gr.Button("Next step")}
|
524 |
else:
|
525 |
-
return {
|
526 |
channel_info_json : channel_info,
|
527 |
-
desc_md : gr.Markdown(
|
528 |
tpl_montage : gr.Image(visible=False),
|
529 |
mapped_montage : gr.Image(visible=False),
|
530 |
-
radio_group : gr.Radio(choices=
|
531 |
clear_btn : gr.Button(visible=True),
|
532 |
step2_btn : gr.Button(visible=True),
|
533 |
next_btn : gr.Button(visible=False)}
|
534 |
|
535 |
# stage1-1 -> stage1-3
|
536 |
-
elif
|
537 |
-
print('step1 -> step3')
|
538 |
-
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
"totalFillingNum" : len(app_state["missingTemplates"])
|
544 |
-
})
|
545 |
-
return {app_state_json : app_state,
|
546 |
-
channel_info_json : channel_info,
|
547 |
-
desc_md : gr.Markdown("### Step3: Fill the remaining template channels"),
|
548 |
tpl_montage : gr.Image(visible=False),
|
549 |
mapped_montage : gr.Image(visible=False),
|
550 |
-
|
551 |
fillmode_btn : gr.Button(visible=True),
|
552 |
next_btn : gr.Button(visible=False)}
|
553 |
|
554 |
# stage1-2 -> stage1-3 or stage2
|
555 |
-
elif
|
556 |
|
557 |
-
#
|
558 |
-
prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
|
559 |
-
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
560 |
-
if selected_radio == []:
|
561 |
-
app_state["stage1NewOrder"][prev_target_idx] = []
|
562 |
-
else:
|
563 |
-
selected_idx = channel_info["inputDict"][selected_radio]["index"]
|
564 |
-
app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
|
565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
566 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
567 |
channel_info["inputDict"][selected_radio]["assigned"] = True
|
568 |
-
|
569 |
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
|
|
|
|
576 |
|
577 |
-
# if all the unmatched
|
578 |
# -> stage2
|
579 |
-
if len(
|
580 |
-
print('step2 -> stage2')
|
|
|
581 |
gr.Info('The mapping process has been finished.')
|
582 |
-
app_state["stage1State"] = "finished"
|
583 |
|
584 |
-
|
|
|
585 |
channel_info_json : channel_info,
|
586 |
desc_md : gr.Markdown(visible=False),
|
587 |
radio_group : gr.Radio(visible=False),
|
@@ -591,42 +653,43 @@ with gr.Blocks() as demo:
|
|
591 |
|
592 |
# -> stage1-3
|
593 |
else:
|
594 |
-
print('step2 -> step3')
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
|
|
601 |
channel_info_json : channel_info,
|
602 |
-
desc_md : gr.Markdown(
|
603 |
radio_group : gr.Radio(visible=False),
|
604 |
-
|
605 |
fillmode_btn : gr.Button(visible=True),
|
606 |
clear_btn : gr.Button(visible=False),
|
607 |
next_btn : gr.Button(visible=False)}
|
608 |
|
609 |
# stage1-3 -> stage2
|
610 |
-
elif
|
611 |
-
|
612 |
-
|
613 |
-
prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
|
614 |
-
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
615 |
-
if selected_chkbox == []:
|
616 |
-
app_state["stage1NewOrder"][prev_target_idx] = []
|
617 |
-
else:
|
618 |
-
selected_idx = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
|
619 |
-
app_state["stage1NewOrder"][prev_target_idx] = selected_idx
|
620 |
-
#print(f'{prev_target_name}({prev_target_idx}): {selected_chkbox}')
|
621 |
-
|
622 |
gr.Info('The mapping process has been finished.')
|
623 |
-
app_state["stage1State"] = "finished"
|
624 |
-
print('step3 -> stage2')
|
625 |
|
626 |
-
|
627 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
628 |
|
629 |
-
|
|
|
630 |
desc_md : gr.Markdown(visible=False),
|
631 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
632 |
next_btn : gr.Button(visible=False),
|
@@ -634,62 +697,71 @@ with gr.Blocks() as demo:
|
|
634 |
|
635 |
next_btn.click(
|
636 |
fn = init_next_step,
|
637 |
-
inputs = [
|
638 |
-
outputs = [
|
639 |
-
|
640 |
).success(
|
641 |
fn = None,
|
642 |
js = init_js,
|
643 |
-
inputs = [
|
644 |
outputs = []
|
645 |
)
|
646 |
|
647 |
-
|
648 |
-
|
|
|
|
|
|
|
|
|
649 |
|
650 |
-
#
|
651 |
-
prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
|
652 |
-
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
653 |
-
if selected == []:
|
654 |
-
app_state["stage1NewOrder"][prev_target_idx] = []
|
655 |
-
else:
|
656 |
-
selected_idx = channel_info["inputDict"][selected]["index"]
|
657 |
-
app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
|
658 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
659 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
660 |
channel_info["inputDict"][selected]["assigned"] = True
|
661 |
-
|
|
|
|
|
|
|
662 |
|
663 |
-
# update the
|
664 |
-
|
665 |
-
app_state["stage1UnassignedInputs"] = [channel for channel in channel_info["inputOrder"]
|
666 |
if channel_info["inputDict"][channel]["assigned"]==False]
|
|
|
|
|
|
|
667 |
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
return {app_state_json : app_state,
|
673 |
channel_info_json : channel_info,
|
674 |
-
radio_group : gr.Radio(choices=
|
675 |
value=[], label=radio_label),
|
676 |
step2_btn : gr.Button(visible=False),
|
677 |
next_btn : gr.Button("Next step", visible=True)}
|
678 |
else:
|
679 |
-
return {
|
680 |
channel_info_json : channel_info,
|
681 |
-
radio_group : gr.Radio(choices=
|
682 |
value=[], label=radio_label)}
|
683 |
|
684 |
step2_btn.click(
|
685 |
fn = update_radio,
|
686 |
-
inputs = [
|
687 |
-
outputs = [
|
688 |
-
|
689 |
).success(
|
690 |
fn = None,
|
691 |
js = update_js,
|
692 |
-
inputs = [
|
693 |
outputs = []
|
694 |
)
|
695 |
|
@@ -700,183 +772,261 @@ with gr.Blocks() as demo:
|
|
700 |
)
|
701 |
|
702 |
|
703 |
-
#
|
704 |
-
|
|
|
|
|
|
|
705 |
|
706 |
-
if
|
707 |
-
|
708 |
gr.Info('The mapping process has been finished.')
|
709 |
|
710 |
-
|
|
|
711 |
desc_md : gr.Markdown(visible=False),
|
712 |
-
|
713 |
fillmode_btn : gr.Button(visible=False),
|
714 |
run_btn : gr.Button(interactive=True)}
|
715 |
|
716 |
-
elif
|
717 |
-
|
718 |
-
|
|
|
|
|
719 |
|
720 |
-
#
|
721 |
-
|
722 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
723 |
|
724 |
-
|
|
|
|
|
|
|
725 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
726 |
-
chkbox_label =
|
727 |
|
728 |
-
|
729 |
-
|
730 |
-
|
|
|
|
|
|
|
731 |
fillmode_btn : gr.Button(visible=False),
|
732 |
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
733 |
value=chkbox_value, label=chkbox_label, visible=True),
|
734 |
next_btn : gr.Button(visible=True)}
|
735 |
else:
|
736 |
-
return {
|
737 |
-
|
|
|
738 |
fillmode_btn : gr.Button(visible=False),
|
739 |
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
740 |
value=chkbox_value, label=chkbox_label, visible=True),
|
741 |
step3_btn : gr.Button(visible=True)}
|
742 |
|
743 |
-
def update_chkbox(
|
744 |
-
|
745 |
-
# save info before clicking on next_btn
|
746 |
-
prev_target_name = app_state["missingTemplates"][app_state["fillingCount"]-1]
|
747 |
-
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
748 |
-
if selected == []:
|
749 |
-
app_state["stage1NewOrder"][prev_target_idx] = []
|
750 |
-
else:
|
751 |
-
selected_idx = [channel_info["inputDict"][channel]["index"] for channel in selected]
|
752 |
-
app_state["stage1NewOrder"][prev_target_idx] = selected_idx
|
753 |
-
#print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
|
754 |
|
755 |
-
#
|
756 |
-
app_state["fillingCount"] += 1
|
757 |
|
758 |
-
|
759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
760 |
|
761 |
-
|
|
|
|
|
|
|
762 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
763 |
-
chkbox_label =
|
764 |
|
765 |
-
|
766 |
-
|
|
|
|
|
767 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
|
768 |
step3_btn : gr.Button(visible=False),
|
769 |
next_btn : gr.Button("Submit", visible=True)}
|
770 |
else:
|
771 |
-
return {
|
772 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
|
773 |
|
774 |
fillmode_btn.click(
|
775 |
fn = fill_value,
|
776 |
-
inputs = [
|
777 |
-
outputs = [
|
778 |
).success(
|
779 |
fn = None,
|
780 |
js = init_js,
|
781 |
-
inputs = [
|
782 |
outputs = []
|
783 |
)
|
784 |
|
785 |
step3_btn.click(
|
786 |
fn = update_chkbox,
|
787 |
-
inputs = [
|
788 |
-
outputs = [
|
789 |
-
|
790 |
).success(
|
791 |
fn = None,
|
792 |
js = update_js,
|
793 |
-
inputs = [
|
794 |
outputs = []
|
795 |
)
|
796 |
|
797 |
-
# -------------------------stage2: decode data-------------------------------
|
798 |
-
def delete_file(filename):
|
799 |
-
try:
|
800 |
-
os.remove(filename)
|
801 |
-
except OSError as e:
|
802 |
-
print(e)
|
803 |
|
804 |
-
|
|
|
|
|
|
|
|
|
|
|
805 |
|
806 |
-
#
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
|
817 |
-
|
818 |
-
|
819 |
-
|
|
|
|
|
|
|
820 |
|
821 |
-
|
822 |
-
|
823 |
-
"
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
|
|
|
|
|
|
|
|
|
|
828 |
})
|
829 |
-
return {
|
830 |
channel_info_json : channel_info,
|
831 |
-
run_btn : gr.Button(interactive=False),
|
832 |
batch_md : gr.Markdown(visible=False),
|
833 |
-
|
834 |
|
835 |
-
def run_model(
|
836 |
-
|
837 |
-
|
838 |
-
new_filename = app_state["filenames"]["denoised"]
|
839 |
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
#
|
854 |
-
|
855 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
856 |
|
857 |
-
#
|
858 |
-
|
859 |
-
|
860 |
-
#return {out_denoised_data : filepath + 'denoised.csv'}
|
861 |
|
862 |
-
|
863 |
-
|
864 |
-
|
|
|
|
|
|
|
|
|
|
|
865 |
|
866 |
-
|
867 |
-
|
868 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
869 |
|
870 |
run_btn.click(
|
871 |
fn = reset_run,
|
872 |
-
inputs = [
|
873 |
-
outputs = [
|
874 |
|
875 |
).success(
|
876 |
fn = run_model,
|
877 |
-
inputs = [
|
878 |
-
outputs = [run_btn, batch_md,
|
879 |
)
|
880 |
|
881 |
if __name__ == "__main__":
|
882 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
+
import time
|
4 |
import os
|
5 |
import random
|
6 |
import math
|
|
|
10 |
from mne.channels import read_custom_montage
|
11 |
|
12 |
import utils
|
13 |
+
from channel_mapping import mapping_stage1, mapping_stage2, reorder_input_data, restore_original_order, find_neighbors
|
14 |
|
15 |
|
16 |
quickstart = """
|
|
|
21 |
|
22 |
## Channel locations
|
23 |
Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
|
24 |
+
**Note:**
|
25 |
+
If you cannot obtain it, we recommend you to download the standard montage <a href="">here</a>. If the channels in those files doesn't match yours, you can use **EEGLAB** to modify them to your needed montage.
|
26 |
|
27 |
## Mapping
|
28 |
(...)
|
|
|
57 |
"""
|
58 |
|
59 |
init_js = """
|
60 |
+
(app_info, channel_info) => {
|
61 |
+
app_info = JSON.parse(JSON.stringify(app_info));
|
62 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
63 |
+
stage1_info = app_info.stage1
|
64 |
|
65 |
let selector, classname, attribute;
|
66 |
let channel, left, bottom;
|
67 |
|
68 |
+
if(stage1_info.state == "step2-selecting"){
|
69 |
selector = "#radio-group > div:nth-of-type(2)";
|
70 |
//classname = "radio";
|
71 |
attribute = "value";
|
72 |
+
}else if(stage1_info.state == "step3-selecting"){
|
73 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
74 |
//classname = "chkbox";
|
75 |
attribute = "name";
|
76 |
}else return;
|
77 |
|
78 |
|
79 |
+
// add figure of the input montage
|
80 |
document.querySelector(selector).style.cssText = `
|
81 |
position: relative;
|
82 |
width: 100%;
|
83 |
aspect-ratio: 1;
|
84 |
//width: 560px;
|
85 |
//height: 560px;
|
86 |
+
background: url("file=${stage1_info.filenames.input_montage}");
|
87 |
background-size: contain;
|
88 |
|
89 |
`;
|
|
|
102 |
|
103 |
|
104 |
// add indication for the missing channels
|
105 |
+
channel = stage1_info.missingTemplates[0];
|
106 |
left = channel_info.templateDict[channel].css_position[0];
|
107 |
bottom = channel_info.templateDict[channel].css_position[1];
|
108 |
|
109 |
let dot_rule = `
|
110 |
${selector}::before {
|
111 |
+
content: "";
|
112 |
position: absolute;
|
113 |
background-color: red;
|
114 |
width: 10px;
|
|
|
148 |
"""
|
149 |
|
150 |
update_js = """
|
151 |
+
(app_info, channel_info) => {
|
152 |
+
app_info = JSON.parse(JSON.stringify(app_info));
|
153 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
154 |
+
stage1_info = app_info.stage1
|
155 |
|
156 |
let selector;
|
157 |
let channel, left, bottom;
|
158 |
|
159 |
+
if(stage1_info.state == "step2-selecting"){
|
160 |
selector = "#radio-group > div:nth-of-type(2)";
|
161 |
|
162 |
// update the radios
|
|
|
170 |
item.className = "";
|
171 |
item.querySelector(":scope > span").innerText = "";
|
172 |
});
|
173 |
+
}else if(stage1_info.state == "step3-selecting"){
|
174 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
175 |
}else return;
|
176 |
|
177 |
// update indication
|
178 |
+
channel = stage1_info.missingTemplates[stage1_info["fillingCount"]-1];
|
179 |
left = channel_info.templateDict[channel].css_position[0];
|
180 |
bottom = channel_info.templateDict[channel].css_position[1];
|
181 |
|
|
|
206 |
}
|
207 |
`;
|
208 |
|
209 |
+
// update the rules
|
210 |
const styleSheet = document.styleSheets[0];
|
211 |
for(let i=0; i<styleSheet.cssRules.length; i++){
|
212 |
let tmp = styleSheet.cssRules[i].selectorText;
|
|
|
223 |
|
224 |
with gr.Blocks() as demo:
|
225 |
|
226 |
+
app_info_json = gr.JSON(visible=False)
|
227 |
channel_info_json = gr.JSON(visible=False)
|
228 |
|
229 |
with gr.Row():
|
|
|
238 |
gr.Markdown("# 1.Channel Mapping")
|
239 |
# ------------------------input--------------------------
|
240 |
with gr.Row():
|
241 |
+
in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
|
242 |
+
in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
|
243 |
+
file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
|
244 |
with gr.Row():
|
245 |
in_samplerate = gr.Textbox(label="Sampling rate (Hz)", scale=2)
|
246 |
+
map_btn = gr.Button("Mapping", interactive=False, scale=1)
|
247 |
|
248 |
# ------------------------mapping------------------------
|
249 |
# description for stage1-123
|
250 |
+
desc_md = gr.Markdown(visible=False)
|
|
|
251 |
# stage1-1 : mapping result
|
252 |
with gr.Row():
|
253 |
+
tpl_montage = gr.Image("./template_montage.png", label="Template channels", visible=False)
|
254 |
mapped_montage = gr.Image(label="Input channels", visible=False)
|
|
|
255 |
# stage1-2 : assign unmatched input channels to empty template channels
|
256 |
radio_group = gr.Radio(elem_id="radio-group", visible=False)
|
|
|
257 |
# stage1-3 : select a way to fill the empty template channels
|
258 |
with gr.Row():
|
259 |
+
in_fillmode = gr.Dropdown(choices=["mean", "zero"],
|
260 |
value="mean",
|
261 |
label="Filling method",
|
262 |
visible=False,
|
|
|
265 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
|
266 |
|
267 |
with gr.Row():
|
268 |
+
clear_btn = gr.Button("Clear", visible=False)
|
269 |
step2_btn = gr.Button("Next", visible=False)
|
270 |
step3_btn = gr.Button("Next", visible=False)
|
271 |
next_btn = gr.Button("Next step", visible=False)
|
|
|
275 |
gr.Markdown("# 2.Decode Data")
|
276 |
# ------------------------input--------------------------
|
277 |
with gr.Row():
|
278 |
+
in_modelname = gr.Dropdown(choices=[
|
279 |
("ART", "EEGART"),
|
280 |
("IC-U-Net", "ICUNet"),
|
281 |
("IC-U-Net++", "UNetpp"),
|
282 |
+
("IC-U-Net-Attn", "AttUnet")],
|
|
|
|
|
283 |
value="EEGART",
|
284 |
label="Model",
|
285 |
scale=2)
|
|
|
287 |
|
288 |
# ------------------------output-------------------------
|
289 |
batch_md = gr.Markdown(visible=False)
|
290 |
+
out_data_file = gr.File(label="Denoised data", visible=False)
|
|
|
291 |
# -------------------------------------------------------
|
292 |
|
293 |
with gr.Row():
|
|
|
301 |
gr.Markdown()
|
302 |
with gr.Tab("QuickStart"):
|
303 |
gr.Markdown(quickstart)
|
304 |
+
|
305 |
+
|
306 |
+
# verify that all required inputs have been provided
|
307 |
+
@gr.on(triggers = [in_data_file.upload, in_data_file.clear, in_loc_file.upload, in_loc_file.clear, in_samplerate.change],
|
308 |
+
inputs = [in_data_file, in_loc_file, in_samplerate], outputs = map_btn)
|
309 |
+
def check_input(in_data, in_loc, samplerate):
|
310 |
+
if in_data!=None and in_loc!=None and samplerate!="":
|
311 |
+
return gr.Button(interactive=True)
|
312 |
+
else:
|
313 |
+
return gr.Button(interactive=False)
|
314 |
+
|
315 |
+
|
316 |
+
# +========================================================================================+
|
317 |
+
# | stage1: channel mapping |
|
318 |
+
# +========================================================================================+
|
319 |
+
def reset_all(in_data, in_loc, samplerate):
|
320 |
+
# establish a new folder for the current session
|
321 |
+
filepath = os.path.dirname(str(in_data))
|
322 |
try:
|
323 |
+
os.mkdir(filepath+"/session_data/")
|
324 |
except OSError as e:
|
325 |
+
utils.dataDelete(filepath+"/session_data/")
|
326 |
+
os.mkdir(filepath+"/session_data/")
|
327 |
+
print(e)
|
328 |
+
# establish new folders for stage1 and stage2
|
329 |
+
os.mkdir(filepath+"/session_data/stage1/")
|
330 |
+
os.mkdir(filepath+"/session_data/stage2/")
|
331 |
|
332 |
+
# initialize channel_info, app_info
|
333 |
channel_info = {}
|
334 |
+
app_info = {
|
335 |
+
"rootFilepath" : filepath+"/session_data/",
|
336 |
+
"sampleRate" : int(samplerate),
|
337 |
+
#"currentStage" : "stage1",
|
338 |
+
"stage1" : {
|
339 |
+
"filepath" : filepath+"/session_data/stage1/",
|
340 |
+
"filenames" : {
|
341 |
+
"input_data" : in_data,
|
342 |
+
"input_loc" : in_loc,
|
343 |
+
"input_montage" : "",
|
344 |
+
"mapped_montage" : ""
|
345 |
+
},
|
346 |
+
"state" : None,
|
347 |
+
"fillingCount" : None,
|
348 |
+
"totalFillingNum" : None,
|
349 |
+
"newOrder" : None,
|
350 |
+
"unassignedInputs" : None,
|
351 |
+
"missingTemplates" : None
|
352 |
+
},
|
353 |
+
"stage2" : {
|
354 |
+
"filepath" : filepath+"/session_data/stage2/",
|
355 |
+
"filenames" : {
|
356 |
+
"output_data" : ""
|
357 |
+
},
|
358 |
+
#"state" : None,
|
359 |
+
"totalBatchNum" : None,
|
360 |
+
"newOrder" : None,
|
361 |
+
"unassignedInputs" : None
|
362 |
+
}
|
363 |
}
|
|
|
364 |
# reset layout
|
365 |
+
return {app_info_json : app_info,
|
366 |
channel_info_json : channel_info,
|
367 |
# ------------------stage1-----------------------
|
368 |
+
map_btn : gr.Button(interactive=False),
|
369 |
+
desc_md : gr.Markdown(visible=False),
|
370 |
tpl_montage : gr.Image(visible=False),
|
371 |
mapped_montage : gr.Image(value=None, visible=False),
|
372 |
radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
|
373 |
+
in_fillmode : gr.Dropdown(value="mean", visible=False),
|
374 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
375 |
fillmode_btn : gr.Button(visible=False),
|
376 |
clear_btn : gr.Button(visible=False),
|
377 |
step2_btn : gr.Button(visible=False),
|
378 |
step3_btn : gr.Button(visible=False),
|
379 |
+
next_btn : gr.Button("Next step", visible=False),
|
380 |
# ------------------stage2-----------------------
|
381 |
run_btn : gr.Button(interactive=False),
|
382 |
batch_md : gr.Markdown(visible=False),
|
383 |
+
out_data_file : gr.File(visible=False)}
|
384 |
|
385 |
|
386 |
+
# +========================================================================================+
|
387 |
+
# | stage1-1 |
|
388 |
+
# +========================================================================================+
|
389 |
def save_figures(channel_info, filename1, filename2):
|
390 |
|
391 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
|
|
394 |
template_order = channel_info["templateOrder"]
|
395 |
input_order = channel_info["inputOrder"]
|
396 |
|
397 |
+
# get template and input's 2d coords
|
398 |
+
tpl_x = [template_dict[channel]["coord_2d"][0] for channel in template_order]
|
399 |
+
tpl_y = [template_dict[channel]["coord_2d"][1] for channel in template_order]
|
400 |
+
in_x = [input_dict[channel]["coord_2d"][0] for channel in input_order]
|
401 |
+
in_y = [input_dict[channel]["coord_2d"][1] for channel in input_order]
|
402 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
403 |
+
in_coords = np.vstack((in_x, in_y)).T
|
404 |
+
|
405 |
# get template's head figure
|
406 |
tpl_fig = template_montage.plot()
|
407 |
tpl_ax = tpl_fig.axes[0]
|
|
|
412 |
head_lines.append((x,y))
|
413 |
plt.close()
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
# -------------------------plot input montage------------------------------
|
416 |
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
417 |
ax = fig.add_subplot(111)
|
|
|
422 |
# plot template's head
|
423 |
for x, y in head_lines:
|
424 |
ax.plot(x, y, color='black', linewidth=1.0)
|
425 |
+
# plot input channels on it
|
426 |
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
427 |
for i, channel in enumerate(input_order):
|
428 |
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
429 |
+
# save input_montage
|
|
|
430 |
fig.savefig(filename1)
|
431 |
|
432 |
# ---------------------------add indications-------------------------------
|
|
|
436 |
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
437 |
for i in indices:
|
438 |
ax.text(in_coords[i,0]+0.003, in_coords[i,1], input_order[i], color='red', fontsize=10.0, va='center')
|
|
|
439 |
# save mapped_montage
|
440 |
fig.savefig(filename2)
|
|
|
441 |
|
442 |
# -------------------------------------------------------------------------
|
443 |
+
# store the template and input channels' display position (in px).
|
444 |
tpl_coords = ax.transData.transform(tpl_coords)
|
445 |
in_coords = ax.transData.transform(in_coords)
|
446 |
+
plt.close()
|
447 |
|
448 |
for i, channel in enumerate(template_order):
|
449 |
css_left = (tpl_coords[i,0]-11)/6.4
|
|
|
461 |
})
|
462 |
return channel_info
|
463 |
|
464 |
+
def mapping_result(app_info, channel_info):
|
465 |
+
stage1_info = app_info["stage1"]
|
466 |
+
filepath = stage1_info["filepath"]
|
467 |
+
|
468 |
+
# generate and save figures of the input montage and the mapped montage
|
469 |
+
filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
|
470 |
filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
471 |
channel_info = save_figures(channel_info, filename1, filename2)
|
472 |
+
stage1_info["filenames"].update({
|
473 |
+
"input_montage" : filename1,
|
|
|
474 |
"mapped_montage" : filename2
|
475 |
})
|
476 |
|
477 |
+
# -------------------------determine the next step--------------------------
|
478 |
|
479 |
+
input_num = len(channel_info["inputOrder"])
|
480 |
+
matched_num = 30 - len(stage1_info["missingTemplates"])
|
481 |
|
482 |
+
# if the in_channels has all the 30 tpl_channels (input_num>=30)
|
483 |
# -> stage2
|
484 |
if matched_num == 30:
|
485 |
+
stage1_info["state"] = "finished"
|
486 |
gr.Info('The mapping process has been finished.')
|
487 |
+
if input_num > 30:
|
488 |
+
md = """
|
489 |
+
### Mapping result
|
490 |
+
(...in red...)
|
491 |
+
"""
|
492 |
+
else:
|
493 |
+
md = """
|
494 |
+
### Mapping result
|
495 |
+
(...)
|
496 |
+
"""
|
497 |
|
498 |
+
app_info["stage1"] = stage1_info
|
499 |
+
return {app_info_json : app_info,
|
500 |
channel_info_json : channel_info,
|
501 |
+
map_btn : gr.Button(interactive=True),
|
502 |
+
desc_md : gr.Markdown(md, visible=True),
|
503 |
tpl_montage : gr.Image(visible=True),
|
504 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
505 |
run_btn : gr.Button(interactive=True)}
|
506 |
|
507 |
+
else:
|
508 |
+
# if matched_num < 30, and there're still some unmatched in_channels
|
509 |
+
# -> assign these in_channels to nearby unmatched tpl_channels
|
510 |
+
if input_num > matched_num:
|
511 |
+
stage1_info["state"] = "step2-initializing"
|
512 |
+
md = """
|
513 |
+
### Step1: Mapping result
|
514 |
+
(...in red...)
|
515 |
+
"""
|
516 |
+
|
517 |
+
# if input_num < 30, but all of them can match to some tpl_channels
|
518 |
+
# -> directly use fillmode to fill the remaining tpl_channels
|
519 |
+
elif input_num == matched_num:
|
520 |
+
stage1_info["state"] = "step3-initializing"
|
521 |
+
md = """
|
522 |
+
### Step1: Mapping result
|
523 |
+
(...)
|
524 |
+
"""
|
525 |
+
|
526 |
+
app_info["stage1"] = stage1_info
|
527 |
+
return {app_info_json : app_info,
|
528 |
+
channel_info_json : channel_info,
|
529 |
+
map_btn : gr.Button(interactive=True),
|
530 |
+
desc_md : gr.Markdown(md, visible=True),
|
531 |
+
tpl_montage : gr.Image(visible=True),
|
532 |
+
mapped_montage : gr.Image(value=filename2, visible=True),
|
533 |
+
next_btn : gr.Button(visible=True)}
|
534 |
|
535 |
+
start_stage1 = map_btn.click(
|
536 |
fn = reset_all,
|
537 |
+
inputs = [in_data_file, in_loc_file, in_samplerate],
|
538 |
+
outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_montage, mapped_montage, radio_group,
|
539 |
+
in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
|
540 |
+
run_btn, batch_md, out_data_file]
|
541 |
).success(
|
542 |
fn = mapping_stage1,
|
543 |
+
inputs = [app_info_json, channel_info_json],
|
544 |
+
outputs = [app_info_json, channel_info_json, desc_md]
|
|
|
545 |
).success(
|
546 |
fn = mapping_result,
|
547 |
+
inputs = [app_info_json, channel_info_json],
|
548 |
+
outputs = [app_info_json, channel_info_json, map_btn, desc_md, tpl_montage, mapped_montage, next_btn, run_btn]
|
549 |
)
|
550 |
|
551 |
+
|
552 |
+
# +========================================================================================+
|
553 |
+
# | manage step transition |
|
554 |
+
# +========================================================================================+
|
555 |
+
def init_next_step(app_info, channel_info, selected_radio, selected_chkbox):
|
556 |
+
stage1_info = app_info["stage1"]
|
557 |
|
558 |
# stage1-1 -> stage1-2
|
559 |
+
if stage1_info["state"] == "step2-initializing":
|
560 |
+
#print('step1 -> step2')
|
561 |
+
md = """
|
562 |
+
### Step2: Assign unmatched input channels
|
563 |
+
(...)
|
564 |
+
"""
|
565 |
+
|
566 |
+
# initialize the progress indication label for step2
|
567 |
+
stage1_info.update({
|
568 |
+
"state" : "step2-selecting",
|
569 |
"fillingCount" : 1,
|
570 |
+
"totalFillingNum" : len(stage1_info["missingTemplates"])
|
571 |
})
|
572 |
+
name = stage1_info["missingTemplates"][0]
|
573 |
+
label = "{} (1/{})".format(name, stage1_info["totalFillingNum"])
|
574 |
|
575 |
+
app_info["stage1"] = stage1_info
|
576 |
+
# determine which button to display
|
577 |
+
if len(stage1_info["unassignedInputs"])==1 or stage1_info["totalFillingNum"]==1:
|
578 |
+
return {app_info_json : app_info,
|
|
|
579 |
channel_info_json : channel_info,
|
580 |
+
desc_md : gr.Markdown(md),
|
581 |
tpl_montage : gr.Image(visible=False),
|
582 |
mapped_montage : gr.Image(visible=False),
|
583 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
|
584 |
clear_btn : gr.Button(visible=True),
|
585 |
next_btn : gr.Button("Next step")}
|
586 |
else:
|
587 |
+
return {app_info_json : app_info,
|
588 |
channel_info_json : channel_info,
|
589 |
+
desc_md : gr.Markdown(md),
|
590 |
tpl_montage : gr.Image(visible=False),
|
591 |
mapped_montage : gr.Image(visible=False),
|
592 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"], value=[], label=label, visible=True),
|
593 |
clear_btn : gr.Button(visible=True),
|
594 |
step2_btn : gr.Button(visible=True),
|
595 |
next_btn : gr.Button(visible=False)}
|
596 |
|
597 |
# stage1-1 -> stage1-3
|
598 |
+
elif stage1_info["state"] == "step3-initializing":
|
599 |
+
#print('step1 -> step3')
|
600 |
+
md = """
|
601 |
+
### Step3: Fill the remaining template channels
|
602 |
+
(...)
|
603 |
+
"""
|
604 |
+
return {desc_md : gr.Markdown(md),
|
|
|
|
|
|
|
|
|
|
|
605 |
tpl_montage : gr.Image(visible=False),
|
606 |
mapped_montage : gr.Image(visible=False),
|
607 |
+
in_fillmode : gr.Dropdown(visible=True),
|
608 |
fillmode_btn : gr.Button(visible=True),
|
609 |
next_btn : gr.Button(visible=False)}
|
610 |
|
611 |
# stage1-2 -> stage1-3 or stage2
|
612 |
+
elif stage1_info["state"] == "step2-selecting":
|
613 |
|
614 |
+
# ----------------------store information before the button click----------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
615 |
|
616 |
+
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
617 |
+
if selected_radio != []:
|
618 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
619 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
620 |
+
|
621 |
+
# store the index of the in_channel
|
622 |
+
selected_idx = channel_info["inputDict"][selected_radio]["index"]
|
623 |
+
stage1_info["newOrder"][prev_target_idx] = [selected_idx]
|
624 |
+
# mark the in_channel as assigned and tpl_channel as matched
|
625 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
626 |
channel_info["inputDict"][selected_radio]["assigned"] = True
|
627 |
+
print(prev_target_name, '<-', selected_radio)
|
628 |
|
629 |
+
# ------------------------update information for the next step-------------------------
|
630 |
+
|
631 |
+
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
632 |
+
stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
|
633 |
+
if channel_info["inputDict"][channel]["assigned"]==False]
|
634 |
+
# update the list of missingTemplates to exclude those filled in step2
|
635 |
+
stage1_info["missingTemplates"] = [channel for channel in channel_info["templateOrder"]
|
636 |
+
if channel_info["templateDict"][channel]["matched"]==False]
|
637 |
|
638 |
+
# if all the unmatched tpl_channels were filled by in_channels
|
639 |
# -> stage2
|
640 |
+
if len(stage1_info["missingTemplates"]) == 0:
|
641 |
+
#print('step2 -> stage2')
|
642 |
+
stage1_info["state"] = "finished"
|
643 |
gr.Info('The mapping process has been finished.')
|
|
|
644 |
|
645 |
+
app_info["stage1"] = stage1_info
|
646 |
+
return {app_info_json : app_info,
|
647 |
channel_info_json : channel_info,
|
648 |
desc_md : gr.Markdown(visible=False),
|
649 |
radio_group : gr.Radio(visible=False),
|
|
|
653 |
|
654 |
# -> stage1-3
|
655 |
else:
|
656 |
+
#print('step2 -> step3')
|
657 |
+
md = """
|
658 |
+
### Step3: Fill the remaining template channels
|
659 |
+
(...)
|
660 |
+
"""
|
661 |
+
|
662 |
+
app_info["stage1"] = stage1_info
|
663 |
+
return {app_info_json : app_info,
|
664 |
channel_info_json : channel_info,
|
665 |
+
desc_md : gr.Markdown(md),
|
666 |
radio_group : gr.Radio(visible=False),
|
667 |
+
in_fillmode : gr.Dropdown(visible=True),
|
668 |
fillmode_btn : gr.Button(visible=True),
|
669 |
clear_btn : gr.Button(visible=False),
|
670 |
next_btn : gr.Button(visible=False)}
|
671 |
|
672 |
# stage1-3 -> stage2
|
673 |
+
elif stage1_info["state"] == "step3-selecting":
|
674 |
+
#print('step3 -> stage2')
|
675 |
+
stage1_info["state"] = "finished"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
676 |
gr.Info('The mapping process has been finished.')
|
|
|
|
|
677 |
|
678 |
+
# ----------------------store information before the button click----------------------
|
679 |
+
|
680 |
+
# check if the user has not unchecked all in_channel checkboxes
|
681 |
+
if selected_chkbox != []:
|
682 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
683 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
684 |
+
|
685 |
+
# store the indices of the in_channels
|
686 |
+
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected_chkbox]
|
687 |
+
stage1_info["newOrder"][prev_target_idx] = selected_indices
|
688 |
+
#print(f'{prev_target_name}({prev_target_idx}): {selected_chkbox}')
|
689 |
+
# -------------------------------------------------------------------------------------
|
690 |
|
691 |
+
app_info["stage1"] = stage1_info
|
692 |
+
return {app_info_json : app_info,
|
693 |
desc_md : gr.Markdown(visible=False),
|
694 |
chkbox_group : gr.CheckboxGroup(visible=False),
|
695 |
next_btn : gr.Button(visible=False),
|
|
|
697 |
|
698 |
next_btn.click(
|
699 |
fn = init_next_step,
|
700 |
+
inputs = [app_info_json, channel_info_json, radio_group, chkbox_group],
|
701 |
+
outputs = [app_info_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
|
702 |
+
in_fillmode, chkbox_group, fillmode_btn, clear_btn, step2_btn, next_btn, run_btn]
|
703 |
).success(
|
704 |
fn = None,
|
705 |
js = init_js,
|
706 |
+
inputs = [app_info_json, channel_info_json],
|
707 |
outputs = []
|
708 |
)
|
709 |
|
710 |
+
|
711 |
+
# +========================================================================================+
|
712 |
+
# | stage1-2 |
|
713 |
+
# +========================================================================================+
|
714 |
+
def update_radio(app_info, channel_info, selected):
|
715 |
+
stage1_info = app_info["stage1"]
|
716 |
|
717 |
+
# ----------------------store information before the button click----------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
718 |
|
719 |
+
# check if the user has selected an in_channel to forward to the previous target tpl_channel
|
720 |
+
if selected != []:
|
721 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
722 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
723 |
+
|
724 |
+
# store the index of the selected in_channel
|
725 |
+
selected_idx = channel_info["inputDict"][selected]["index"]
|
726 |
+
stage1_info["newOrder"][prev_target_idx] = [selected_idx]
|
727 |
+
# mark the in_channel as assigned and tpl_channel as matched
|
728 |
channel_info["templateDict"][prev_target_name]["matched"] = True
|
729 |
channel_info["inputDict"][selected]["assigned"] = True
|
730 |
+
print(prev_target_name, '<-', selected)
|
731 |
+
|
732 |
+
# ------------------------update information for the new round-------------------------
|
733 |
+
stage1_info["fillingCount"] += 1
|
734 |
|
735 |
+
# update the list of unassignedInputs to exclude the selected in_channel of the previous round
|
736 |
+
stage1_info["unassignedInputs"] = [channel for channel in channel_info["inputOrder"]
|
|
|
737 |
if channel_info["inputDict"][channel]["assigned"]==False]
|
738 |
+
# update the progress indication label
|
739 |
+
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
740 |
+
radio_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
741 |
|
742 |
+
app_info["stage1"] = stage1_info
|
743 |
+
# determine which button to display
|
744 |
+
if len(stage1_info["unassignedInputs"])==1 or stage1_info["fillingCount"]==stage1_info["totalFillingNum"]:
|
745 |
+
return {app_info_json : app_info,
|
|
|
746 |
channel_info_json : channel_info,
|
747 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
|
748 |
value=[], label=radio_label),
|
749 |
step2_btn : gr.Button(visible=False),
|
750 |
next_btn : gr.Button("Next step", visible=True)}
|
751 |
else:
|
752 |
+
return {app_info_json : app_info,
|
753 |
channel_info_json : channel_info,
|
754 |
+
radio_group : gr.Radio(choices=stage1_info["unassignedInputs"],
|
755 |
value=[], label=radio_label)}
|
756 |
|
757 |
step2_btn.click(
|
758 |
fn = update_radio,
|
759 |
+
inputs = [app_info_json, channel_info_json, radio_group],
|
760 |
+
outputs = [app_info_json, channel_info_json, radio_group, step2_btn, next_btn]
|
|
|
761 |
).success(
|
762 |
fn = None,
|
763 |
js = update_js,
|
764 |
+
inputs = [app_info_json, channel_info_json],
|
765 |
outputs = []
|
766 |
)
|
767 |
|
|
|
772 |
)
|
773 |
|
774 |
|
775 |
+
# +========================================================================================+
|
776 |
+
# | stage1-3 |
|
777 |
+
# +========================================================================================+
|
778 |
+
def fill_value(app_info, channel_info, fillmode):
|
779 |
+
stage1_info = app_info["stage1"]
|
780 |
|
781 |
+
if fillmode == "zero":
|
782 |
+
stage1_info["state"] = "finished"
|
783 |
gr.Info('The mapping process has been finished.')
|
784 |
|
785 |
+
app_info["stage1"] = stage1_info
|
786 |
+
return {app_info_json : app_info,
|
787 |
desc_md : gr.Markdown(visible=False),
|
788 |
+
in_fillmode : gr.Dropdown(visible=False),
|
789 |
fillmode_btn : gr.Button(visible=False),
|
790 |
run_btn : gr.Button(interactive=True)}
|
791 |
|
792 |
+
elif fillmode == "mean":
|
793 |
+
md = """
|
794 |
+
### Step3: Fill the remaining template channels
|
795 |
+
(...)
|
796 |
+
"""
|
797 |
|
798 |
+
# find the 4-NN in_channels for each of the unmatched tpl_channels
|
799 |
+
new_idx = find_neighbors(channel_info, stage1_info["missingTemplates"], stage1_info["newOrder"])
|
800 |
+
|
801 |
+
stage1_info.update({
|
802 |
+
"state" : "step3-selecting",
|
803 |
+
"newOrder" : new_idx,
|
804 |
+
"fillingCount" : 1,
|
805 |
+
"totalFillingNum" : len(stage1_info["missingTemplates"])
|
806 |
+
})
|
807 |
|
808 |
+
# initialize the progress indicator label
|
809 |
+
target_name = stage1_info["missingTemplates"][0]
|
810 |
+
target_idx = channel_info["templateDict"][target_name]["index"]
|
811 |
+
chkbox_value = stage1_info["newOrder"][target_idx]
|
812 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
813 |
+
chkbox_label = "{} (1/{})".format(target_name, stage1_info["totalFillingNum"])
|
814 |
|
815 |
+
app_info["stage1"] = stage1_info
|
816 |
+
# determine which button to display
|
817 |
+
if stage1_info["totalFillingNum"] == 1:
|
818 |
+
return {app_info_json : app_info,
|
819 |
+
desc_md : gr.Markdown(md),
|
820 |
+
in_fillmode : gr.Dropdown(visible=False),
|
821 |
fillmode_btn : gr.Button(visible=False),
|
822 |
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
823 |
value=chkbox_value, label=chkbox_label, visible=True),
|
824 |
next_btn : gr.Button(visible=True)}
|
825 |
else:
|
826 |
+
return {app_info_json : app_info,
|
827 |
+
desc_md : gr.Markdown(md),
|
828 |
+
in_fillmode : gr.Dropdown(visible=False),
|
829 |
fillmode_btn : gr.Button(visible=False),
|
830 |
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputOrder"],
|
831 |
value=chkbox_value, label=chkbox_label, visible=True),
|
832 |
step3_btn : gr.Button(visible=True)}
|
833 |
|
834 |
+
def update_chkbox(app_info, channel_info, selected):
|
835 |
+
stage1_info = app_info["stage1"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
836 |
|
837 |
+
# ----------------------store information before the button click----------------------
|
|
|
838 |
|
839 |
+
# check if the user has not unchecked all in_channel checkboxes
|
840 |
+
if selected != []:
|
841 |
+
prev_target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
842 |
+
prev_target_idx = channel_info["templateDict"][prev_target_name]["index"]
|
843 |
+
|
844 |
+
# store the indices of the selected in_channels
|
845 |
+
selected_indices = [channel_info["inputDict"][channel]["index"] for channel in selected]
|
846 |
+
stage1_info["newOrder"][prev_target_idx] = selected_indices
|
847 |
+
#print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
|
848 |
+
|
849 |
+
# ------------------------update information for the new round-------------------------
|
850 |
+
stage1_info["fillingCount"] += 1
|
851 |
|
852 |
+
# update the progress indication label
|
853 |
+
target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
|
854 |
+
target_idx = channel_info["templateDict"][target_name]["index"]
|
855 |
+
chkbox_value = stage1_info["newOrder"][target_idx]
|
856 |
chkbox_value = [channel_info["inputOrder"][i] for i in chkbox_value]
|
857 |
+
chkbox_label = "{} ({}/{})".format(target_name, stage1_info["fillingCount"], stage1_info["totalFillingNum"])
|
858 |
|
859 |
+
app_info["stage1"] = stage1_info
|
860 |
+
# determine which button to display
|
861 |
+
if stage1_info["fillingCount"] == stage1_info["totalFillingNum"]:
|
862 |
+
return {app_info_json : app_info,
|
863 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label),
|
864 |
step3_btn : gr.Button(visible=False),
|
865 |
next_btn : gr.Button("Submit", visible=True)}
|
866 |
else:
|
867 |
+
return {app_info_json : app_info,
|
868 |
chkbox_group : gr.CheckboxGroup(value=chkbox_value, label=chkbox_label)}
|
869 |
|
870 |
fillmode_btn.click(
|
871 |
fn = fill_value,
|
872 |
+
inputs = [app_info_json, channel_info_json, in_fillmode],
|
873 |
+
outputs = [app_info_json, desc_md, in_fillmode, fillmode_btn, chkbox_group, step3_btn, next_btn, run_btn]
|
874 |
).success(
|
875 |
fn = None,
|
876 |
js = init_js,
|
877 |
+
inputs = [app_info_json, channel_info_json],
|
878 |
outputs = []
|
879 |
)
|
880 |
|
881 |
step3_btn.click(
|
882 |
fn = update_chkbox,
|
883 |
+
inputs = [app_info_json, channel_info_json, chkbox_group],
|
884 |
+
outputs = [app_info_json, chkbox_group, step3_btn, next_btn]
|
|
|
885 |
).success(
|
886 |
fn = None,
|
887 |
js = update_js,
|
888 |
+
inputs = [app_info_json, channel_info_json],
|
889 |
outputs = []
|
890 |
)
|
891 |
|
|
|
|
|
|
|
|
|
|
|
|
|
892 |
|
893 |
+
# +========================================================================================+
|
894 |
+
# | stage2: decode data |
|
895 |
+
# +========================================================================================+
|
896 |
+
def reset_run(app_info, channel_info, modelname):
|
897 |
+
stage1_info = app_info["stage1"]
|
898 |
+
stage2_info = app_info["stage2"]
|
899 |
|
900 |
+
# delete the previous folder of stage2 if it exists
|
901 |
+
filepath = stage2_info["filepath"]
|
902 |
+
utils.dataDelete(filepath)
|
903 |
+
# establish a new folder for stage2
|
904 |
+
new_filepath = app_info["rootFilepath"]+"stage2_"+str(random.randint(1,10000))+"/"
|
905 |
+
os.mkdir(new_filepath)
|
906 |
+
# generate the output filename
|
907 |
+
filename = stage1_info["filenames"]["input_data"]
|
908 |
+
filename = os.path.basename(str(filename))
|
909 |
+
new_filename = os.path.splitext(filename)[0]+'_'+modelname+'.csv'
|
910 |
|
911 |
+
# reset inputChannel.assigned back to the state after stage1
|
912 |
+
for channel in stage1_info["unassignedInputs"]:
|
913 |
+
channel_info["inputDict"][channel]["assigned"] = False
|
914 |
+
# calculate how many times the model needs to be run
|
915 |
+
unassigned_num = len(stage1_info["unassignedInputs"])
|
916 |
+
batch_num = math.ceil(unassigned_num/30) + 1
|
917 |
|
918 |
+
app_info.update({
|
919 |
+
#"currentStage" : "stage2",
|
920 |
+
"stage2" : {
|
921 |
+
"filepath" : new_filepath,
|
922 |
+
"filenames" : {
|
923 |
+
"output_data" : new_filepath + new_filename
|
924 |
+
},
|
925 |
+
#"state" : "initializing",
|
926 |
+
"totalBatchNum" : batch_num,
|
927 |
+
"newOrder" : [[]]*30,
|
928 |
+
"unassignedInputs" : stage1_info["unassignedInputs"]
|
929 |
+
}
|
930 |
})
|
931 |
+
return {app_info_json : app_info,
|
932 |
channel_info_json : channel_info,
|
933 |
+
#run_btn : gr.Button(interactive=False),
|
934 |
batch_md : gr.Markdown(visible=False),
|
935 |
+
out_data_file : gr.File(visible=False)}
|
936 |
|
937 |
+
def run_model(app_info, channel_info, modelname):
|
938 |
+
stage1_info = app_info["stage1"]
|
939 |
+
stage2_info = app_info["stage2"]
|
|
|
940 |
|
941 |
+
filepath = stage2_info["filepath"]
|
942 |
+
samplerate = app_info["sampleRate"]
|
943 |
+
filename = stage1_info["filenames"]["input_data"]
|
944 |
+
new_filename = stage2_info["filenames"]["output_data"]
|
945 |
+
|
946 |
+
# set a flag to record whether the user has clicked the map_btn or run_btn while running the model
|
947 |
+
break_flag = False
|
948 |
+
|
949 |
+
# run the model multiple times until all in_channels are reconstructed
|
950 |
+
for i in range(stage2_info["totalBatchNum"]):
|
951 |
+
# establish a temp folder
|
952 |
+
try:
|
953 |
+
os.mkdir(filepath+"temp_data/")
|
954 |
+
#except FileExistsError:
|
955 |
+
#utils.dataDelete(filepath+"temp_data/")
|
956 |
+
#os.mkdir(filepath+"temp_data/")
|
957 |
+
except FileNotFoundError:
|
958 |
+
#print('break1!!')
|
959 |
+
break_flag = True
|
960 |
+
break
|
961 |
+
except OSError as e:
|
962 |
+
print(e)
|
963 |
|
964 |
+
# update the running status
|
965 |
+
md = "Running model({}/{})...".format(i+1, stage2_info["totalBatchNum"])
|
966 |
+
yield {batch_md : gr.Markdown(md, visible=True)}
|
|
|
967 |
|
968 |
+
if i == 0:
|
969 |
+
new_idx = stage1_info["newOrder"]
|
970 |
+
else:
|
971 |
+
# if this is not the first time running the model, the in_channels that have
|
972 |
+
# not been reconstructed yet will be optimally mapped to the template.
|
973 |
+
stage2_info, channel_info = mapping_stage2(stage2_info, channel_info)
|
974 |
+
new_idx = stage2_info["newOrder"]
|
975 |
+
#print('unassigned num:', len(stage2_info["unassignedInputs"]))
|
976 |
|
977 |
+
# ----------------------------------------------------------------------
|
978 |
+
try:
|
979 |
+
# step1: Reorder input data
|
980 |
+
reorder_input_data(new_idx, filename, filepath+"temp_data/mapped.csv")
|
981 |
+
# step2: Data preprocessing
|
982 |
+
total_file_num = utils.preprocessing(filepath+"temp_data/", "mapped.csv", samplerate)
|
983 |
+
# step3: Signal reconstruction
|
984 |
+
utils.reconstruct(modelname, total_file_num, filepath+"temp_data/", "denoised.csv", samplerate)
|
985 |
+
# step4: Restore original order
|
986 |
+
restore_original_order(channel_info, i, new_idx, filepath+"temp_data/denoised.csv", new_filename)
|
987 |
+
except FileNotFoundError:
|
988 |
+
#print('break2!!')
|
989 |
+
break_flag = True
|
990 |
+
break
|
991 |
+
# ----------------------------------------------------------------------
|
992 |
+
utils.dataDelete(filepath+"temp_data/")
|
993 |
+
app_info["stage2"] = stage2_info
|
994 |
+
|
995 |
+
if break_flag == True:
|
996 |
+
yield {batch_md : gr.Markdown(visible=False)}
|
997 |
+
else:
|
998 |
+
yield {#run_btn : gr.Button(interactive=True),
|
999 |
+
batch_md : gr.Markdown(visible=False),
|
1000 |
+
out_data_file : gr.File(new_filename, visible=True)}
|
1001 |
|
1002 |
run_btn.click(
|
1003 |
fn = reset_run,
|
1004 |
+
inputs = [app_info_json, channel_info_json, in_modelname],
|
1005 |
+
outputs = [app_info_json, channel_info_json, run_btn, batch_md, out_data_file]
|
1006 |
|
1007 |
).success(
|
1008 |
fn = run_model,
|
1009 |
+
inputs = [app_info_json, channel_info_json, in_modelname],
|
1010 |
+
outputs = [run_btn, batch_md, out_data_file]
|
1011 |
)
|
1012 |
|
1013 |
if __name__ == "__main__":
|
1014 |
demo.launch()
|
1015 |
+
|
1016 |
+
|
1017 |
+
"""
|
1018 |
+
--------
|
1019 |
+
|----(inputname).csv
|
1020 |
+
|----session_data
|
1021 |
+
|----stage1
|
1022 |
+
|----input_montage.png
|
1023 |
+
|----mapped_montage.png
|
1024 |
+
|----stage2_(...)
|
1025 |
+
|----temp_data
|
1026 |
+
|----mapped.csv
|
1027 |
+
|----denoised.csv
|
1028 |
+
|----temp2
|
1029 |
+
|...
|
1030 |
+
|----(outputname).csv
|
1031 |
+
"""
|
1032 |
+
|
channel_mapping.py
CHANGED
@@ -10,13 +10,10 @@ from scipy.interpolate import Rbf
|
|
10 |
from scipy.optimize import linear_sum_assignment
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
|
13 |
-
def
|
14 |
-
old_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
|
15 |
old_data = utils.read_train_data(filename) # original raw data
|
16 |
-
new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
|
17 |
-
|
18 |
-
#print('new order 1:', app_state["stage1NewOrder"])
|
19 |
-
#print('new order 2:', app_state["stage2NewOrder"])
|
20 |
|
21 |
zero_arr = np.zeros((1, old_data.shape[1]))
|
22 |
old_data = np.concatenate((old_data, zero_arr), axis=0)
|
@@ -31,25 +28,25 @@ def reorder_to_template(app_state, filename):
|
|
31 |
tmp_data = [old_data[j, :] for j in idx_set]
|
32 |
new_data[i, :] = np.mean(tmp_data, axis=0)
|
33 |
|
34 |
-
|
|
|
35 |
utils.save_data(new_data, new_filename)
|
36 |
return
|
37 |
|
38 |
-
def
|
39 |
-
filename = app_state["filepath"]+'denoised.csv'
|
40 |
-
old_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
|
41 |
old_data = utils.read_train_data(filename) # denoised data
|
42 |
template_order = channel_info["templateOrder"]
|
|
|
43 |
|
44 |
-
if
|
45 |
-
new_data = np.zeros((len(
|
46 |
else:
|
47 |
new_data = utils.read_train_data(new_filename)
|
48 |
|
49 |
for i, channel in enumerate(template_order):
|
50 |
idx_set = old_idx[i]
|
51 |
|
52 |
-
# ignore if this channel was filled with
|
53 |
if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
|
54 |
new_data[idx_set[0], :] = old_data[i, :]
|
55 |
|
@@ -97,9 +94,9 @@ def align_coords(channel_info, template_montage, input_montage):
|
|
97 |
|
98 |
|
99 |
# --------------------------------2-D------------------------------------
|
100 |
-
# (for the
|
101 |
|
102 |
-
fig = [template_montage.plot(), input_montage.plot()]
|
103 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
104 |
|
105 |
# get the original coords
|
@@ -154,48 +151,38 @@ def align_coords(channel_info, template_montage, input_montage):
|
|
154 |
})
|
155 |
return channel_info
|
156 |
|
157 |
-
def find_neighbors(
|
158 |
-
new_idx = app_state["stage1NewOrder"] if app_state["runningState"]=="stage1" else app_state["stage2NewOrder"]
|
159 |
template_dict = channel_info["templateDict"]
|
160 |
input_dict = channel_info["inputDict"]
|
161 |
-
template_order = channel_info["templateOrder"]
|
162 |
input_order = channel_info["inputOrder"]
|
163 |
-
missing_channels = app_state["missingTemplates"]
|
164 |
-
if missing_channels == []:
|
165 |
-
return app_state # change nothing
|
166 |
|
167 |
-
|
168 |
-
|
169 |
-
in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
|
170 |
|
171 |
# use KNN to choose k nearest channels
|
172 |
k = 4 if len(input_order)>4 else len(input_order)
|
173 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
174 |
-
knn.fit(
|
175 |
|
176 |
-
for channel in missing_channels:
|
177 |
-
distances, indices = knn.kneighbors(
|
178 |
-
selected = [input_order[
|
179 |
#print(channel, ':', selected)
|
180 |
|
181 |
idx = template_dict[channel]["index"]
|
182 |
new_idx[idx] = indices[0].tolist()
|
183 |
-
|
184 |
-
if app_state["runningState"] == "stage1":
|
185 |
-
app_state["stage1NewOrder"] = new_idx
|
186 |
-
else:
|
187 |
-
app_state["stage2NewOrder"] = new_idx
|
188 |
|
189 |
-
return
|
190 |
|
191 |
-
def mapping_stage1(
|
192 |
-
yield
|
193 |
second1 = time.time()
|
194 |
|
|
|
195 |
template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
|
196 |
template_order = template_montage.ch_names
|
197 |
input_order = input_montage.ch_names
|
198 |
-
new_idx = [[]]*30
|
199 |
alias_dict = {
|
200 |
'T3': 'T7',
|
201 |
'T4': 'T8',
|
@@ -203,21 +190,20 @@ def mapping_stage1(app_state, channel_info, loc_file):
|
|
203 |
'T6': 'P8'
|
204 |
}
|
205 |
|
206 |
-
# match the names of input channels
|
207 |
for i, channel in enumerate(template_order):
|
208 |
if channel in alias_dict and alias_dict[channel] in input_dict:
|
209 |
-
template_montage.rename_channels({channel: alias_dict[channel]})
|
210 |
template_dict[alias_dict[channel]] = template_dict.pop(channel)
|
211 |
channel = alias_dict[channel]
|
212 |
-
|
213 |
if channel in input_dict:
|
214 |
new_idx[i] = [input_dict[channel]["index"]]
|
215 |
template_dict[channel]["matched"] = True
|
216 |
input_dict[channel]["assigned"] = True
|
217 |
|
218 |
-
# update names
|
219 |
template_order = template_montage.ch_names
|
220 |
-
input_order = input_montage.ch_names
|
221 |
|
222 |
channel_info.update({
|
223 |
"templateDict" : template_dict,
|
@@ -225,10 +211,9 @@ def mapping_stage1(app_state, channel_info, loc_file):
|
|
225 |
"templateOrder" : template_order,
|
226 |
"inputOrder" : input_order
|
227 |
})
|
228 |
-
|
229 |
-
"
|
230 |
-
"
|
231 |
-
"stage1UnassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
|
232 |
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
233 |
})
|
234 |
|
@@ -237,19 +222,16 @@ def mapping_stage1(app_state, channel_info, loc_file):
|
|
237 |
|
238 |
second2 = time.time()
|
239 |
print('Mapping (stage1) finished in',second2 - second1,'s.')
|
240 |
-
yield
|
241 |
|
242 |
-
def mapping_stage2(
|
243 |
second1 = time.time()
|
244 |
|
245 |
template_dict = channel_info["templateDict"]
|
246 |
input_dict = channel_info["inputDict"]
|
247 |
template_order = channel_info["templateOrder"]
|
248 |
input_order = channel_info["inputOrder"]
|
249 |
-
unassigned =
|
250 |
-
if unassigned == []:
|
251 |
-
app_state["runningState"] = "finished"
|
252 |
-
return app_state, channel_info
|
253 |
|
254 |
tpl_coords = np.array([template_dict[channel]["coord_3d"] for channel in template_order])
|
255 |
unassigned_coords = np.array([input_dict[channel]["coord_3d"] for channel in unassigned])
|
@@ -274,31 +256,29 @@ def mapping_stage2(app_state, channel_info):
|
|
274 |
new_idx = [[]]*30
|
275 |
for i in range(30):
|
276 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
277 |
-
print(f'({row_idx[i]}, {col_idx[i]})')
|
278 |
-
|
279 |
tpl_channel = template_order[row_idx[i]]
|
280 |
in_channel = unassigned[col_idx[i]]
|
281 |
template_dict[tpl_channel]["matched"] = True
|
282 |
input_dict[in_channel]["assigned"] = True
|
283 |
new_idx[row_idx[i]] = [input_dict[in_channel]["index"]]
|
284 |
|
285 |
-
print(template_order[row_idx[i]]
|
286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
channel_info.update({
|
288 |
"templateDict" : template_dict,
|
289 |
"inputDict" : input_dict
|
290 |
})
|
291 |
-
app_state.update({
|
292 |
-
"stage2NewOrder" : new_idx,
|
293 |
-
"runningState" : "stage2",
|
294 |
-
"stage2UnassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
|
295 |
-
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
296 |
-
})
|
297 |
-
|
298 |
-
# fill the missing_channels
|
299 |
-
app_state = find_neighbors(app_state, channel_info)
|
300 |
|
301 |
second2 = time.time()
|
302 |
-
print(
|
303 |
-
return
|
304 |
|
|
|
10 |
from scipy.optimize import linear_sum_assignment
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
|
13 |
+
def reorder_input_data(old_idx, filename, new_filename):
|
|
|
14 |
old_data = utils.read_train_data(filename) # original raw data
|
15 |
+
new_data = np.zeros((30, old_data.shape[1])) # to store reordered raw data
|
16 |
+
print('new index order:', old_idx)
|
|
|
|
|
17 |
|
18 |
zero_arr = np.zeros((1, old_data.shape[1]))
|
19 |
old_data = np.concatenate((old_data, zero_arr), axis=0)
|
|
|
28 |
tmp_data = [old_data[j, :] for j in idx_set]
|
29 |
new_data[i, :] = np.mean(tmp_data, axis=0)
|
30 |
|
31 |
+
old_shape = (old_data.shape[0]-1, old_data.shape[1])
|
32 |
+
print('old.shape, new.shape: ', old_shape, new_data.shape)
|
33 |
utils.save_data(new_data, new_filename)
|
34 |
return
|
35 |
|
36 |
+
def restore_original_order(channel_info, cnt, old_idx, filename, new_filename):
|
|
|
|
|
37 |
old_data = utils.read_train_data(filename) # denoised data
|
38 |
template_order = channel_info["templateOrder"]
|
39 |
+
input_order = channel_info["inputOrder"]
|
40 |
|
41 |
+
if cnt == 0:
|
42 |
+
new_data = np.zeros((len(input_order), old_data.shape[1]))
|
43 |
else:
|
44 |
new_data = utils.read_train_data(new_filename)
|
45 |
|
46 |
for i, channel in enumerate(template_order):
|
47 |
idx_set = old_idx[i]
|
48 |
|
49 |
+
# ignore if this channel was filled with fillmode ('mean' or 'zero')
|
50 |
if len(idx_set)==1 and channel_info["templateDict"][channel]["matched"]==True:
|
51 |
new_data[idx_set[0], :] = old_data[i, :]
|
52 |
|
|
|
94 |
|
95 |
|
96 |
# --------------------------------2-D------------------------------------
|
97 |
+
# (for the indicate the location missing template channel's position when fill_mode:'mean')
|
98 |
|
99 |
+
fig = [template_montage.plot(), input_montage.plot()]
|
100 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
101 |
|
102 |
# get the original coords
|
|
|
151 |
})
|
152 |
return channel_info
|
153 |
|
154 |
+
def find_neighbors(channel_info, missing_channels, new_idx):
|
|
|
155 |
template_dict = channel_info["templateDict"]
|
156 |
input_dict = channel_info["inputDict"]
|
|
|
157 |
input_order = channel_info["inputOrder"]
|
|
|
|
|
|
|
158 |
|
159 |
+
all_in = [np.array(input_dict[channel]["coord_3d"]) for channel in input_order]
|
160 |
+
missing_tpl = [np.array(template_dict[channel]["coord_3d"]) for channel in missing_channels]
|
|
|
161 |
|
162 |
# use KNN to choose k nearest channels
|
163 |
k = 4 if len(input_order)>4 else len(input_order)
|
164 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
165 |
+
knn.fit(all_in)
|
166 |
|
167 |
+
for i, channel in enumerate(missing_channels):
|
168 |
+
distances, indices = knn.kneighbors(missing_tpl[i].reshape(1,-1))
|
169 |
+
#selected = [input_order[j] for j in indices[0]]
|
170 |
#print(channel, ':', selected)
|
171 |
|
172 |
idx = template_dict[channel]["index"]
|
173 |
new_idx[idx] = indices[0].tolist()
|
|
|
|
|
|
|
|
|
|
|
174 |
|
175 |
+
return new_idx
|
176 |
|
177 |
+
def mapping_stage1(app_info, channel_info):
|
178 |
+
yield app_info, channel_info, gr.Markdown("Mapping...", visible=True)
|
179 |
second1 = time.time()
|
180 |
|
181 |
+
loc_file = app_info["stage1"]["filenames"]["input_loc"]
|
182 |
template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
|
183 |
template_order = template_montage.ch_names
|
184 |
input_order = input_montage.ch_names
|
185 |
+
new_idx = [[]]*30 # store the indices of the in_channels in the order of tpl_channls
|
186 |
alias_dict = {
|
187 |
'T3': 'T7',
|
188 |
'T4': 'T8',
|
|
|
190 |
'T6': 'P8'
|
191 |
}
|
192 |
|
193 |
+
# match the names of input channels and template channels
|
194 |
for i, channel in enumerate(template_order):
|
195 |
if channel in alias_dict and alias_dict[channel] in input_dict:
|
196 |
+
template_montage.rename_channels({channel: alias_dict[channel]}) # rename the current tpl_channel
|
197 |
template_dict[alias_dict[channel]] = template_dict.pop(channel)
|
198 |
channel = alias_dict[channel]
|
199 |
+
|
200 |
if channel in input_dict:
|
201 |
new_idx[i] = [input_dict[channel]["index"]]
|
202 |
template_dict[channel]["matched"] = True
|
203 |
input_dict[channel]["assigned"] = True
|
204 |
|
205 |
+
# update the names
|
206 |
template_order = template_montage.ch_names
|
|
|
207 |
|
208 |
channel_info.update({
|
209 |
"templateDict" : template_dict,
|
|
|
211 |
"templateOrder" : template_order,
|
212 |
"inputOrder" : input_order
|
213 |
})
|
214 |
+
app_info["stage1"].update({
|
215 |
+
"newOrder" : new_idx,
|
216 |
+
"unassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
|
|
|
217 |
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
218 |
})
|
219 |
|
|
|
222 |
|
223 |
second2 = time.time()
|
224 |
print('Mapping (stage1) finished in',second2 - second1,'s.')
|
225 |
+
yield app_info, channel_info, gr.Markdown("", visible=False)
|
226 |
|
227 |
+
def mapping_stage2(stage2_info, channel_info):
|
228 |
second1 = time.time()
|
229 |
|
230 |
template_dict = channel_info["templateDict"]
|
231 |
input_dict = channel_info["inputDict"]
|
232 |
template_order = channel_info["templateOrder"]
|
233 |
input_order = channel_info["inputOrder"]
|
234 |
+
unassigned = stage2_info["unassignedInputs"]
|
|
|
|
|
|
|
235 |
|
236 |
tpl_coords = np.array([template_dict[channel]["coord_3d"] for channel in template_order])
|
237 |
unassigned_coords = np.array([input_dict[channel]["coord_3d"] for channel in unassigned])
|
|
|
256 |
new_idx = [[]]*30
|
257 |
for i in range(30):
|
258 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
|
|
|
|
259 |
tpl_channel = template_order[row_idx[i]]
|
260 |
in_channel = unassigned[col_idx[i]]
|
261 |
template_dict[tpl_channel]["matched"] = True
|
262 |
input_dict[in_channel]["assigned"] = True
|
263 |
new_idx[row_idx[i]] = [input_dict[in_channel]["index"]]
|
264 |
|
265 |
+
print(f'{template_order[row_idx[i]]}({row_idx[i]}) <- {unassigned[col_idx[i]]}({col_idx[i]})')
|
266 |
|
267 |
+
# fill the missing_channels
|
268 |
+
missing_channels = [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
269 |
+
if missing_channels != []:
|
270 |
+
new_idx = find_neighbors(channel_info, missing_channels, new_idx)
|
271 |
+
|
272 |
+
stage2_info.update({
|
273 |
+
"newOrder" : new_idx,
|
274 |
+
"unassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False]
|
275 |
+
})
|
276 |
channel_info.update({
|
277 |
"templateDict" : template_dict,
|
278 |
"inputDict" : input_dict
|
279 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
280 |
|
281 |
second2 = time.time()
|
282 |
+
print("The mapping process has been finished in", second2 - second1, "s.")
|
283 |
+
return stage2_info, channel_info
|
284 |
|
utils.py
CHANGED
@@ -98,7 +98,7 @@ def cut_data(filepath, raw_data):
|
|
98 |
total = int(len(raw_data[0]) / 1024)
|
99 |
for i in range(total):
|
100 |
table = raw_data[:, i * 1024:(i + 1) * 1024]
|
101 |
-
filename = filepath + '
|
102 |
with open(filename, 'w', newline='') as csvfile:
|
103 |
writer = csv.writer(csvfile)
|
104 |
writer.writerows(table)
|
@@ -213,10 +213,10 @@ def decode_data(data, std_num, mode=5):
|
|
213 |
def preprocessing(filepath, filename, samplerate):
|
214 |
# establish temp folder
|
215 |
try:
|
216 |
-
os.mkdir(filepath+"
|
217 |
except OSError as e:
|
218 |
-
dataDelete(filepath+"
|
219 |
-
os.mkdir(filepath+"
|
220 |
print(e)
|
221 |
|
222 |
# read data
|
@@ -239,7 +239,7 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
|
|
239 |
# -------------------decode_data---------------------------
|
240 |
second1 = time.time()
|
241 |
for i in range(total):
|
242 |
-
file_name = filepath + '
|
243 |
data_noise = read_train_data(file_name)
|
244 |
|
245 |
std = np.std(data_noise)
|
@@ -251,17 +251,18 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
|
|
251 |
d_data = decode_data(data_noise, std, model_name)
|
252 |
d_data = d_data[0]
|
253 |
|
254 |
-
outputname = filepath + '
|
255 |
save_data(d_data, outputname)
|
256 |
|
257 |
# --------------------glue_data----------------------------
|
258 |
-
signal = glue_data(filepath+"
|
259 |
#print(signal.shape)
|
260 |
# -------------------delete_data---------------------------
|
261 |
-
dataDelete(filepath+"
|
262 |
# --------------------resample-----------------------------
|
263 |
signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
|
264 |
#print(signal.shape)
|
|
|
265 |
save_data(signal, filepath+outputfile)
|
266 |
second2 = time.time()
|
267 |
|
|
|
98 |
total = int(len(raw_data[0]) / 1024)
|
99 |
for i in range(total):
|
100 |
table = raw_data[:, i * 1024:(i + 1) * 1024]
|
101 |
+
filename = filepath + 'temp2/' + str(i) + '.csv'
|
102 |
with open(filename, 'w', newline='') as csvfile:
|
103 |
writer = csv.writer(csvfile)
|
104 |
writer.writerows(table)
|
|
|
213 |
def preprocessing(filepath, filename, samplerate):
|
214 |
# establish temp folder
|
215 |
try:
|
216 |
+
os.mkdir(filepath+"temp2/")
|
217 |
except OSError as e:
|
218 |
+
dataDelete(filepath+"temp2/")
|
219 |
+
os.mkdir(filepath+"temp2/")
|
220 |
print(e)
|
221 |
|
222 |
# read data
|
|
|
239 |
# -------------------decode_data---------------------------
|
240 |
second1 = time.time()
|
241 |
for i in range(total):
|
242 |
+
file_name = filepath + 'temp2/{}.csv'.format(str(i))
|
243 |
data_noise = read_train_data(file_name)
|
244 |
|
245 |
std = np.std(data_noise)
|
|
|
251 |
d_data = decode_data(data_noise, std, model_name)
|
252 |
d_data = d_data[0]
|
253 |
|
254 |
+
outputname = filepath + 'temp2/output{}.csv'.format(str(i))
|
255 |
save_data(d_data, outputname)
|
256 |
|
257 |
# --------------------glue_data----------------------------
|
258 |
+
signal = glue_data(filepath+"temp2/", total, filepath+outputfile)
|
259 |
#print(signal.shape)
|
260 |
# -------------------delete_data---------------------------
|
261 |
+
dataDelete(filepath+"temp2/")
|
262 |
# --------------------resample-----------------------------
|
263 |
signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
|
264 |
#print(signal.shape)
|
265 |
+
# --------------------save_data----------------------------
|
266 |
save_data(signal, filepath+outputfile)
|
267 |
second2 = time.time()
|
268 |
|