Spaces:
Sleeping
Sleeping
Commit
·
8b18526
1
Parent(s):
a2a070a
update
Browse files- app.py +144 -85
- channel_mapping.py +65 -82
- template_montage.png +0 -0
app.py
CHANGED
@@ -1,13 +1,16 @@
|
|
1 |
import gradio as gr
|
2 |
-
|
3 |
import os
|
4 |
import random
|
5 |
import math
|
|
|
|
|
|
|
|
|
|
|
6 |
import utils
|
7 |
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin, find_neighbors
|
8 |
|
9 |
-
import mne
|
10 |
-
from mne.channels import read_custom_montage
|
11 |
|
12 |
quickstart = """
|
13 |
|
@@ -56,51 +59,52 @@ init_js = """
|
|
56 |
app_state = JSON.parse(JSON.stringify(app_state));
|
57 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
58 |
|
59 |
-
let selector, attribute;
|
60 |
let channel, left, bottom;
|
61 |
|
62 |
if(app_state.stage1State == "step2-selecting"){
|
63 |
-
selector = "#radio > div:nth-of-type(2)";
|
|
|
64 |
attribute = "value";
|
65 |
}else if(app_state.stage1State == "step3-selecting"){
|
66 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
|
|
67 |
attribute = "name";
|
68 |
}else return;
|
69 |
|
70 |
-
|
|
|
71 |
document.querySelector(selector).style.cssText = `
|
72 |
position: relative;
|
73 |
-
width:
|
74 |
-
|
|
|
|
|
75 |
background: url("file=${app_state.filenames.raw_montage}");
|
|
|
|
|
76 |
`;
|
77 |
|
78 |
-
|
79 |
// move the radios/checkboxes
|
80 |
let all_elem = document.querySelectorAll(selector+" > label");
|
81 |
-
Array.from(all_elem).forEach(
|
82 |
channel = item.querySelector("input").getAttribute(attribute);
|
83 |
left = channel_info.inputDict[channel].css_position[0];
|
84 |
bottom = channel_info.inputDict[channel].css_position[1];
|
85 |
-
//console.log(`channel: ${channel}, left: ${left}, bottom: ${bottom}`);
|
86 |
|
87 |
-
item.style.cssText = `
|
88 |
-
|
89 |
-
left: ${left};
|
90 |
-
bottom: ${bottom};
|
91 |
-
`;
|
92 |
-
item.className = "";
|
93 |
item.querySelector(":scope > span").innerText = "";
|
94 |
});
|
95 |
|
96 |
|
97 |
// add indication for the missing channels
|
98 |
-
channel = app_state.missingTemplates[0]
|
99 |
left = channel_info.templateDict[channel].css_position[0];
|
100 |
bottom = channel_info.templateDict[channel].css_position[1];
|
101 |
|
102 |
let dot_rule = `
|
103 |
-
${selector}::before{
|
104 |
content: '';
|
105 |
position: absolute;
|
106 |
background-color: red;
|
@@ -112,12 +116,12 @@ init_js = """
|
|
112 |
}
|
113 |
`;
|
114 |
|
115 |
-
left = parseFloat(left.slice(0, -1))+2.5
|
116 |
-
left = left.toString()+"%"
|
117 |
-
bottom = parseFloat(bottom.slice(0, -1))-1
|
118 |
-
bottom = bottom.toString()+"%"
|
119 |
let txt_rule = `
|
120 |
-
${selector}::after{
|
121 |
content: "${channel}";
|
122 |
position: absolute;
|
123 |
color: red;
|
@@ -149,21 +153,16 @@ update_js = """
|
|
149 |
let channel, left, bottom;
|
150 |
|
151 |
if(app_state.stage1State == "step2-selecting"){
|
152 |
-
selector = "#radio > div:nth-of-type(2)";
|
153 |
|
154 |
// update the radios
|
155 |
let all_elem = document.querySelectorAll(selector+" > label");
|
156 |
-
Array.from(all_elem).forEach(
|
157 |
channel = item.querySelector("input").value;
|
158 |
left = channel_info.inputDict[channel].css_position[0];
|
159 |
bottom = channel_info.inputDict[channel].css_position[1];
|
160 |
-
//console.log(`channel: ${channel}, left: ${left}, bottom: ${bottom}`);
|
161 |
|
162 |
-
item.style.cssText = `
|
163 |
-
position: absolute;
|
164 |
-
left: ${left};
|
165 |
-
bottom: ${bottom};
|
166 |
-
`;
|
167 |
item.className = "";
|
168 |
item.querySelector(":scope > span").innerText = "";
|
169 |
});
|
@@ -172,12 +171,12 @@ update_js = """
|
|
172 |
}else return;
|
173 |
|
174 |
// update indication
|
175 |
-
channel = app_state.missingTemplates[app_state["fillingCount"]-1]
|
176 |
left = channel_info.templateDict[channel].css_position[0];
|
177 |
bottom = channel_info.templateDict[channel].css_position[1];
|
178 |
|
179 |
let dot_rule = `
|
180 |
-
${selector}::before{
|
181 |
content: "";
|
182 |
position: absolute;
|
183 |
background-color: red;
|
@@ -189,12 +188,12 @@ update_js = """
|
|
189 |
}
|
190 |
`;
|
191 |
|
192 |
-
left = parseFloat(left.slice(0, -1))+2.5
|
193 |
-
left = left.toString()+"%"
|
194 |
-
bottom = parseFloat(bottom.slice(0, -1))-1
|
195 |
-
bottom = bottom.toString()+"%"
|
196 |
let txt_rule = `
|
197 |
-
${selector}::after{
|
198 |
content: "${channel}";
|
199 |
position: absolute;
|
200 |
color: red;
|
@@ -208,7 +207,6 @@ update_js = """
|
|
208 |
for(let i=0; i<styleSheet.cssRules.length; i++){
|
209 |
let tmp = styleSheet.cssRules[i].selectorText;
|
210 |
if(tmp==selector+"::before" || tmp==selector+"::after"){
|
211 |
-
console.log('exist!!', tmp);
|
212 |
styleSheet.deleteRule(i);
|
213 |
i--;
|
214 |
}
|
@@ -249,10 +247,10 @@ with gr.Blocks() as demo:
|
|
249 |
# stage1-1 : mapping result
|
250 |
with gr.Row():
|
251 |
tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
|
252 |
-
mapped_montage = gr.Image(
|
253 |
|
254 |
# stage1-2 : assign unmatched input channels to empty template channels
|
255 |
-
|
256 |
|
257 |
# stage1-3 : select a way to fill the empty template channels
|
258 |
with gr.Row():
|
@@ -291,10 +289,6 @@ with gr.Blocks() as demo:
|
|
291 |
batch_md = gr.Markdown(visible=False)
|
292 |
out_denoised_data = gr.File(label="Denoised data", visible=False)
|
293 |
|
294 |
-
#files = []
|
295 |
-
#for i in range():
|
296 |
-
#f = gr.File()
|
297 |
-
#files.append(f)
|
298 |
# -------------------------------------------------------
|
299 |
|
300 |
with gr.Row():
|
@@ -309,7 +303,7 @@ with gr.Blocks() as demo:
|
|
309 |
with gr.Tab("QuickStart"):
|
310 |
gr.Markdown(quickstart)
|
311 |
|
312 |
-
#demo.load(js=
|
313 |
|
314 |
# -------------------------stage1: channel mapping-------------------------------
|
315 |
def reset_all(raw_data, raw_loc, samplerate):
|
@@ -330,17 +324,14 @@ with gr.Blocks() as demo:
|
|
330 |
os.mkdir(filepath+"/temp_data/")
|
331 |
#print(e)
|
332 |
|
333 |
-
# initialize
|
334 |
-
|
335 |
app_state = {
|
336 |
"filepath": filepath+"/temp_data/",
|
337 |
"filenames": {},
|
338 |
"sampleRate": int(samplerate),
|
339 |
"stage1State" : "step1"
|
340 |
}
|
341 |
-
channel_info = {
|
342 |
-
#"dataShape" : data.shape
|
343 |
-
}
|
344 |
|
345 |
# reset layout
|
346 |
return {app_state_json : app_state,
|
@@ -349,7 +340,7 @@ with gr.Blocks() as demo:
|
|
349 |
desc_md : gr.Markdown("### Step1: Mapping result", visible=False),
|
350 |
tpl_montage : gr.Image(visible=False),
|
351 |
mapped_montage : gr.Image(value=None, visible=False),
|
352 |
-
|
353 |
in_fill_mode : gr.Dropdown(value="mean", visible=False),
|
354 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
355 |
fillmode_btn : gr.Button(visible=False),
|
@@ -364,28 +355,94 @@ with gr.Blocks() as demo:
|
|
364 |
|
365 |
|
366 |
# ---------------------------stage1-1-------------------------------
|
367 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
filepath = app_state["filepath"]
|
369 |
filename1 = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
|
370 |
filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
|
|
371 |
|
372 |
app_state["filenames"].update({
|
373 |
"raw_montage" : filename1,
|
374 |
"mapped_montage" : filename2
|
375 |
})
|
376 |
|
377 |
-
raw_montage = read_custom_montage(raw_loc)
|
378 |
-
raw_fig = raw_montage.plot()
|
379 |
-
raw_fig.set_size_inches(5.6, 5.6)
|
380 |
-
raw_fig.savefig(filename1, pad_inches=0)
|
381 |
-
|
382 |
-
# plot red dots on unmatched input channels
|
383 |
-
ax = raw_fig.axes[0]
|
384 |
-
coords = ax.collections[0].get_offsets().data
|
385 |
-
idx = [channel_info["inputDict"][channel]["index"] for channel in app_state["stage1UnassignedInputs"]]
|
386 |
-
ax.scatter(coords[idx,0]-0.0001, coords[idx,1]+0.0001, color='red')
|
387 |
-
raw_fig.savefig(filename2, pad_inches=0)
|
388 |
-
|
389 |
# ------------------determine the next step-----------------------
|
390 |
|
391 |
in_num = len(channel_info["inputOrder"])
|
@@ -398,6 +455,7 @@ with gr.Blocks() as demo:
|
|
398 |
gr.Info('The mapping process has been finished.')
|
399 |
|
400 |
return {app_state_json : app_state,
|
|
|
401 |
desc_md : gr.Markdown("### Mapping result", visible=True),
|
402 |
tpl_montage : gr.Image(visible=True),
|
403 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
@@ -414,6 +472,7 @@ with gr.Blocks() as demo:
|
|
414 |
app_state["stage1State"] = "step3-initializing"
|
415 |
|
416 |
return {app_state_json : app_state,
|
|
|
417 |
desc_md : gr.Markdown("### Step1: Mapping result", visible=True),
|
418 |
tpl_montage : gr.Image(visible=True),
|
419 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
@@ -422,8 +481,9 @@ with gr.Blocks() as demo:
|
|
422 |
map_btn.click(
|
423 |
fn = reset_all,
|
424 |
inputs = [in_raw_data, in_raw_loc, in_samplerate],
|
425 |
-
outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage,
|
426 |
-
|
|
|
427 |
).success(
|
428 |
fn = mapping_stage1,
|
429 |
inputs = [app_state_json, channel_info_json, in_raw_loc],
|
@@ -431,8 +491,8 @@ with gr.Blocks() as demo:
|
|
431 |
|
432 |
).success(
|
433 |
fn = mapping_result,
|
434 |
-
inputs = [app_state_json, channel_info_json
|
435 |
-
outputs = [app_state_json, desc_md, tpl_montage, mapped_montage, next_btn, run_btn]
|
436 |
)
|
437 |
|
438 |
|
@@ -458,7 +518,7 @@ with gr.Blocks() as demo:
|
|
458 |
desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
|
459 |
tpl_montage : gr.Image(visible=False),
|
460 |
mapped_montage : gr.Image(visible=False),
|
461 |
-
|
462 |
clear_btn : gr.Button(visible=True),
|
463 |
next_btn : gr.Button("Next step")}
|
464 |
else:
|
@@ -467,7 +527,7 @@ with gr.Blocks() as demo:
|
|
467 |
desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
|
468 |
tpl_montage : gr.Image(visible=False),
|
469 |
mapped_montage : gr.Image(visible=False),
|
470 |
-
|
471 |
clear_btn : gr.Button(visible=True),
|
472 |
step2_btn : gr.Button(visible=True),
|
473 |
next_btn : gr.Button(visible=False)}
|
@@ -524,7 +584,7 @@ with gr.Blocks() as demo:
|
|
524 |
return {app_state_json : app_state,
|
525 |
channel_info_json : channel_info,
|
526 |
desc_md : gr.Markdown(visible=False),
|
527 |
-
|
528 |
clear_btn : gr.Button(visible=False),
|
529 |
next_btn : gr.Button(visible=False),
|
530 |
run_btn : gr.Button(interactive=True)}
|
@@ -540,7 +600,7 @@ with gr.Blocks() as demo:
|
|
540 |
return {app_state_json : app_state,
|
541 |
channel_info_json : channel_info,
|
542 |
desc_md : gr.Markdown("### Step3: Fill the remaining template channels"),
|
543 |
-
|
544 |
in_fill_mode : gr.Dropdown(visible=True),
|
545 |
fillmode_btn : gr.Button(visible=True),
|
546 |
clear_btn : gr.Button(visible=False),
|
@@ -574,9 +634,9 @@ with gr.Blocks() as demo:
|
|
574 |
|
575 |
next_btn.click(
|
576 |
fn = init_next_step,
|
577 |
-
inputs = [app_state_json, channel_info_json,
|
578 |
-
outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage,
|
579 |
-
|
580 |
).success(
|
581 |
fn = None,
|
582 |
js = init_js,
|
@@ -611,18 +671,20 @@ with gr.Blocks() as demo:
|
|
611 |
if len(app_state["stage1UnassignedInputs"])==1 or app_state["fillingCount"]==app_state["totalFillingNum"]:
|
612 |
return {app_state_json : app_state,
|
613 |
channel_info_json : channel_info,
|
614 |
-
|
|
|
615 |
step2_btn : gr.Button(visible=False),
|
616 |
next_btn : gr.Button("Next step", visible=True)}
|
617 |
else:
|
618 |
return {app_state_json : app_state,
|
619 |
channel_info_json : channel_info,
|
620 |
-
|
|
|
621 |
|
622 |
step2_btn.click(
|
623 |
fn = update_radio,
|
624 |
-
inputs = [app_state_json, channel_info_json,
|
625 |
-
outputs = [app_state_json, channel_info_json,
|
626 |
|
627 |
).success(
|
628 |
fn = None,
|
@@ -634,7 +696,7 @@ with gr.Blocks() as demo:
|
|
634 |
clear_btn.click(
|
635 |
fn = lambda : gr.Radio(value=[]),
|
636 |
inputs = [],
|
637 |
-
outputs =
|
638 |
)
|
639 |
|
640 |
|
@@ -776,9 +838,6 @@ with gr.Blocks() as demo:
|
|
776 |
new_filename = app_state["filenames"]["denoised"]
|
777 |
|
778 |
while app_state["runningState"] != "finished":
|
779 |
-
#if app_state["batchCount"] > app_state["totalBatchNum"]:
|
780 |
-
#app_state["runningState"] = "finished"
|
781 |
-
#break
|
782 |
md = 'Running model('+str(app_state["batchCount"])+'/'+str(app_state["totalBatchNum"])+')...'
|
783 |
yield {batch_md : gr.Markdown(md, visible=True)}
|
784 |
|
|
|
1 |
import gradio as gr
|
2 |
+
|
3 |
import os
|
4 |
import random
|
5 |
import math
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import mne
|
9 |
+
from mne.channels import read_custom_montage
|
10 |
+
|
11 |
import utils
|
12 |
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin, find_neighbors
|
13 |
|
|
|
|
|
14 |
|
15 |
quickstart = """
|
16 |
|
|
|
59 |
app_state = JSON.parse(JSON.stringify(app_state));
|
60 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
61 |
|
62 |
+
let selector, classname, attribute;
|
63 |
let channel, left, bottom;
|
64 |
|
65 |
if(app_state.stage1State == "step2-selecting"){
|
66 |
+
selector = "#radio-group > div:nth-of-type(2)";
|
67 |
+
//classname = "radio";
|
68 |
attribute = "value";
|
69 |
}else if(app_state.stage1State == "step3-selecting"){
|
70 |
selector = "#chkbox-group > div:nth-of-type(2)";
|
71 |
+
//classname = "chkbox";
|
72 |
attribute = "name";
|
73 |
}else return;
|
74 |
|
75 |
+
|
76 |
+
// add figure of the mapping result
|
77 |
document.querySelector(selector).style.cssText = `
|
78 |
position: relative;
|
79 |
+
width: 100%;
|
80 |
+
aspect-ratio: 1;
|
81 |
+
//width: 560px;
|
82 |
+
//height: 560px;
|
83 |
background: url("file=${app_state.filenames.raw_montage}");
|
84 |
+
background-size: contain;
|
85 |
+
|
86 |
`;
|
87 |
|
|
|
88 |
// move the radios/checkboxes
|
89 |
let all_elem = document.querySelectorAll(selector+" > label");
|
90 |
+
Array.from(all_elem).forEach(item => {
|
91 |
channel = item.querySelector("input").getAttribute(attribute);
|
92 |
left = channel_info.inputDict[channel].css_position[0];
|
93 |
bottom = channel_info.inputDict[channel].css_position[1];
|
|
|
94 |
|
95 |
+
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
96 |
+
item.className = ""; //classname;
|
|
|
|
|
|
|
|
|
97 |
item.querySelector(":scope > span").innerText = "";
|
98 |
});
|
99 |
|
100 |
|
101 |
// add indication for the missing channels
|
102 |
+
channel = app_state.missingTemplates[0];
|
103 |
left = channel_info.templateDict[channel].css_position[0];
|
104 |
bottom = channel_info.templateDict[channel].css_position[1];
|
105 |
|
106 |
let dot_rule = `
|
107 |
+
${selector}::before {
|
108 |
content: '';
|
109 |
position: absolute;
|
110 |
background-color: red;
|
|
|
116 |
}
|
117 |
`;
|
118 |
|
119 |
+
left = parseFloat(left.slice(0, -1))+2.5;
|
120 |
+
left = left.toString()+"%";
|
121 |
+
bottom = parseFloat(bottom.slice(0, -1))-1;
|
122 |
+
bottom = bottom.toString()+"%";
|
123 |
let txt_rule = `
|
124 |
+
${selector}::after {
|
125 |
content: "${channel}";
|
126 |
position: absolute;
|
127 |
color: red;
|
|
|
153 |
let channel, left, bottom;
|
154 |
|
155 |
if(app_state.stage1State == "step2-selecting"){
|
156 |
+
selector = "#radio-group > div:nth-of-type(2)";
|
157 |
|
158 |
// update the radios
|
159 |
let all_elem = document.querySelectorAll(selector+" > label");
|
160 |
+
Array.from(all_elem).forEach(item => {
|
161 |
channel = item.querySelector("input").value;
|
162 |
left = channel_info.inputDict[channel].css_position[0];
|
163 |
bottom = channel_info.inputDict[channel].css_position[1];
|
|
|
164 |
|
165 |
+
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
|
|
|
|
|
|
|
|
166 |
item.className = "";
|
167 |
item.querySelector(":scope > span").innerText = "";
|
168 |
});
|
|
|
171 |
}else return;
|
172 |
|
173 |
// update indication
|
174 |
+
channel = app_state.missingTemplates[app_state["fillingCount"]-1];
|
175 |
left = channel_info.templateDict[channel].css_position[0];
|
176 |
bottom = channel_info.templateDict[channel].css_position[1];
|
177 |
|
178 |
let dot_rule = `
|
179 |
+
${selector}::before {
|
180 |
content: "";
|
181 |
position: absolute;
|
182 |
background-color: red;
|
|
|
188 |
}
|
189 |
`;
|
190 |
|
191 |
+
left = parseFloat(left.slice(0, -1))+2.5;
|
192 |
+
left = left.toString()+"%";
|
193 |
+
bottom = parseFloat(bottom.slice(0, -1))-1;
|
194 |
+
bottom = bottom.toString()+"%";
|
195 |
let txt_rule = `
|
196 |
+
${selector}::after {
|
197 |
content: "${channel}";
|
198 |
position: absolute;
|
199 |
color: red;
|
|
|
207 |
for(let i=0; i<styleSheet.cssRules.length; i++){
|
208 |
let tmp = styleSheet.cssRules[i].selectorText;
|
209 |
if(tmp==selector+"::before" || tmp==selector+"::after"){
|
|
|
210 |
styleSheet.deleteRule(i);
|
211 |
i--;
|
212 |
}
|
|
|
247 |
# stage1-1 : mapping result
|
248 |
with gr.Row():
|
249 |
tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
|
250 |
+
mapped_montage = gr.Image(label="Input channels", visible=False)
|
251 |
|
252 |
# stage1-2 : assign unmatched input channels to empty template channels
|
253 |
+
radio_group = gr.Radio(elem_id="radio-group", visible=False)
|
254 |
|
255 |
# stage1-3 : select a way to fill the empty template channels
|
256 |
with gr.Row():
|
|
|
289 |
batch_md = gr.Markdown(visible=False)
|
290 |
out_denoised_data = gr.File(label="Denoised data", visible=False)
|
291 |
|
|
|
|
|
|
|
|
|
292 |
# -------------------------------------------------------
|
293 |
|
294 |
with gr.Row():
|
|
|
303 |
with gr.Tab("QuickStart"):
|
304 |
gr.Markdown(quickstart)
|
305 |
|
306 |
+
#demo.load(js=tmp_js)
|
307 |
|
308 |
# -------------------------stage1: channel mapping-------------------------------
|
309 |
def reset_all(raw_data, raw_loc, samplerate):
|
|
|
324 |
os.mkdir(filepath+"/temp_data/")
|
325 |
#print(e)
|
326 |
|
327 |
+
# initialize channel_info, app_state
|
328 |
+
channel_info = {}
|
329 |
app_state = {
|
330 |
"filepath": filepath+"/temp_data/",
|
331 |
"filenames": {},
|
332 |
"sampleRate": int(samplerate),
|
333 |
"stage1State" : "step1"
|
334 |
}
|
|
|
|
|
|
|
335 |
|
336 |
# reset layout
|
337 |
return {app_state_json : app_state,
|
|
|
340 |
desc_md : gr.Markdown("### Step1: Mapping result", visible=False),
|
341 |
tpl_montage : gr.Image(visible=False),
|
342 |
mapped_montage : gr.Image(value=None, visible=False),
|
343 |
+
radio_group : gr.Radio(choices=[], value=[], label="", visible=False),
|
344 |
in_fill_mode : gr.Dropdown(value="mean", visible=False),
|
345 |
chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
|
346 |
fillmode_btn : gr.Button(visible=False),
|
|
|
355 |
|
356 |
|
357 |
# ---------------------------stage1-1-------------------------------
|
358 |
+
def save_figures(channel_info, filename1, filename2):
|
359 |
+
|
360 |
+
template_montage = read_custom_montage("./template_chanlocs.loc")
|
361 |
+
template_dict = channel_info["templateDict"]
|
362 |
+
input_dict = channel_info["inputDict"]
|
363 |
+
template_order = channel_info["templateOrder"]
|
364 |
+
input_order = channel_info["inputOrder"]
|
365 |
+
|
366 |
+
# get template's head figure
|
367 |
+
tpl_fig = template_montage.plot()
|
368 |
+
tpl_ax = tpl_fig.axes[0]
|
369 |
+
lines = tpl_ax.lines
|
370 |
+
head_lines = []
|
371 |
+
for line in lines:
|
372 |
+
x, y = line.get_data()
|
373 |
+
head_lines.append((x,y))
|
374 |
+
plt.close()
|
375 |
+
|
376 |
+
# get template's and input's 2d coords
|
377 |
+
tpl_x = [template_dict[channel]["coord_2d"][0] for channel in template_order]
|
378 |
+
tpl_y = [template_dict[channel]["coord_2d"][1] for channel in template_order]
|
379 |
+
in_x = [input_dict[channel]["coord_2d"][0] for channel in input_order]
|
380 |
+
in_y = [input_dict[channel]["coord_2d"][1] for channel in input_order]
|
381 |
+
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
382 |
+
in_coords = np.vstack((in_x, in_y)).T
|
383 |
+
|
384 |
+
# -------------------------plot input montage------------------------------
|
385 |
+
fig = plt.figure(figsize=(6.4,6.4), dpi=100)
|
386 |
+
ax = fig.add_subplot(111)
|
387 |
+
fig.tight_layout()
|
388 |
+
ax.set_aspect('equal')
|
389 |
+
ax.axis('off')
|
390 |
+
|
391 |
+
# plot template's head
|
392 |
+
for x, y in head_lines:
|
393 |
+
ax.plot(x, y, color='black', linewidth=1.0)
|
394 |
+
# plot input channels
|
395 |
+
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
396 |
+
for i, channel in enumerate(input_order):
|
397 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], channel, color='black', fontsize=10.0, va='center')
|
398 |
+
|
399 |
+
# save raw_montage
|
400 |
+
fig.savefig(filename1)
|
401 |
+
|
402 |
+
# ---------------------------add indications-------------------------------
|
403 |
+
indices = [input_dict[channel]["index"] for channel in input_order if input_dict[channel]["assigned"]==False]
|
404 |
+
|
405 |
+
# plot unmatched input channels in red
|
406 |
+
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
407 |
+
for i in indices:
|
408 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], input_order[i], color='red', fontsize=10.0, va='center')
|
409 |
+
|
410 |
+
# save mapped_montage
|
411 |
+
fig.savefig(filename2)
|
412 |
+
plt.close()
|
413 |
+
|
414 |
+
# -------------------------------------------------------------------------
|
415 |
+
# save the template and input channels' display position (in px).
|
416 |
+
tpl_coords = ax.transData.transform(tpl_coords)
|
417 |
+
in_coords = ax.transData.transform(in_coords)
|
418 |
+
|
419 |
+
for i, channel in enumerate(template_order):
|
420 |
+
css_left = (tpl_coords[i,0]-11)/6.4
|
421 |
+
css_bottom = (tpl_coords[i,1]-7)/6.4
|
422 |
+
template_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
423 |
+
|
424 |
+
for i, channel in enumerate(input_order):
|
425 |
+
css_left = (in_coords[i,0]-11)/6.4
|
426 |
+
css_bottom = (in_coords[i,1]-7)/6.4
|
427 |
+
input_dict[channel]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
428 |
+
|
429 |
+
channel_info.update({
|
430 |
+
"templateDict" : template_dict,
|
431 |
+
"inputDict" : input_dict
|
432 |
+
})
|
433 |
+
return channel_info
|
434 |
+
|
435 |
+
def mapping_result(app_state, channel_info):
|
436 |
filepath = app_state["filepath"]
|
437 |
filename1 = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
|
438 |
filename2 = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
439 |
+
channel_info = save_figures(channel_info, filename1, filename2)
|
440 |
|
441 |
app_state["filenames"].update({
|
442 |
"raw_montage" : filename1,
|
443 |
"mapped_montage" : filename2
|
444 |
})
|
445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
446 |
# ------------------determine the next step-----------------------
|
447 |
|
448 |
in_num = len(channel_info["inputOrder"])
|
|
|
455 |
gr.Info('The mapping process has been finished.')
|
456 |
|
457 |
return {app_state_json : app_state,
|
458 |
+
channel_info_json : channel_info,
|
459 |
desc_md : gr.Markdown("### Mapping result", visible=True),
|
460 |
tpl_montage : gr.Image(visible=True),
|
461 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
|
|
472 |
app_state["stage1State"] = "step3-initializing"
|
473 |
|
474 |
return {app_state_json : app_state,
|
475 |
+
channel_info_json : channel_info,
|
476 |
desc_md : gr.Markdown("### Step1: Mapping result", visible=True),
|
477 |
tpl_montage : gr.Image(visible=True),
|
478 |
mapped_montage : gr.Image(value=filename2, visible=True),
|
|
|
481 |
map_btn.click(
|
482 |
fn = reset_all,
|
483 |
inputs = [in_raw_data, in_raw_loc, in_samplerate],
|
484 |
+
outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
|
485 |
+
in_fill_mode, chkbox_group, fillmode_btn, clear_btn, step2_btn, step3_btn, next_btn,
|
486 |
+
run_btn, batch_md, out_denoised_data]
|
487 |
).success(
|
488 |
fn = mapping_stage1,
|
489 |
inputs = [app_state_json, channel_info_json, in_raw_loc],
|
|
|
491 |
|
492 |
).success(
|
493 |
fn = mapping_result,
|
494 |
+
inputs = [app_state_json, channel_info_json],
|
495 |
+
outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, next_btn, run_btn]
|
496 |
)
|
497 |
|
498 |
|
|
|
518 |
desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
|
519 |
tpl_montage : gr.Image(visible=False),
|
520 |
mapped_montage : gr.Image(visible=False),
|
521 |
+
radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=label, visible=True),
|
522 |
clear_btn : gr.Button(visible=True),
|
523 |
next_btn : gr.Button("Next step")}
|
524 |
else:
|
|
|
527 |
desc_md : gr.Markdown("### Step2: Assign unmatched input channels"),
|
528 |
tpl_montage : gr.Image(visible=False),
|
529 |
mapped_montage : gr.Image(visible=False),
|
530 |
+
radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=label, visible=True),
|
531 |
clear_btn : gr.Button(visible=True),
|
532 |
step2_btn : gr.Button(visible=True),
|
533 |
next_btn : gr.Button(visible=False)}
|
|
|
584 |
return {app_state_json : app_state,
|
585 |
channel_info_json : channel_info,
|
586 |
desc_md : gr.Markdown(visible=False),
|
587 |
+
radio_group : gr.Radio(visible=False),
|
588 |
clear_btn : gr.Button(visible=False),
|
589 |
next_btn : gr.Button(visible=False),
|
590 |
run_btn : gr.Button(interactive=True)}
|
|
|
600 |
return {app_state_json : app_state,
|
601 |
channel_info_json : channel_info,
|
602 |
desc_md : gr.Markdown("### Step3: Fill the remaining template channels"),
|
603 |
+
radio_group : gr.Radio(visible=False),
|
604 |
in_fill_mode : gr.Dropdown(visible=True),
|
605 |
fillmode_btn : gr.Button(visible=True),
|
606 |
clear_btn : gr.Button(visible=False),
|
|
|
634 |
|
635 |
next_btn.click(
|
636 |
fn = init_next_step,
|
637 |
+
inputs = [app_state_json, channel_info_json, radio_group, chkbox_group],
|
638 |
+
outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, mapped_montage, radio_group,
|
639 |
+
in_fill_mode, chkbox_group, fillmode_btn, clear_btn, step2_btn, next_btn, run_btn]
|
640 |
).success(
|
641 |
fn = None,
|
642 |
js = init_js,
|
|
|
671 |
if len(app_state["stage1UnassignedInputs"])==1 or app_state["fillingCount"]==app_state["totalFillingNum"]:
|
672 |
return {app_state_json : app_state,
|
673 |
channel_info_json : channel_info,
|
674 |
+
radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"],
|
675 |
+
value=[], label=radio_label),
|
676 |
step2_btn : gr.Button(visible=False),
|
677 |
next_btn : gr.Button("Next step", visible=True)}
|
678 |
else:
|
679 |
return {app_state_json : app_state,
|
680 |
channel_info_json : channel_info,
|
681 |
+
radio_group : gr.Radio(choices=app_state["stage1UnassignedInputs"],
|
682 |
+
value=[], label=radio_label)}
|
683 |
|
684 |
step2_btn.click(
|
685 |
fn = update_radio,
|
686 |
+
inputs = [app_state_json, channel_info_json, radio_group],
|
687 |
+
outputs = [app_state_json, channel_info_json, radio_group, step2_btn, next_btn]
|
688 |
|
689 |
).success(
|
690 |
fn = None,
|
|
|
696 |
clear_btn.click(
|
697 |
fn = lambda : gr.Radio(value=[]),
|
698 |
inputs = [],
|
699 |
+
outputs = radio_group
|
700 |
)
|
701 |
|
702 |
|
|
|
838 |
new_filename = app_state["filenames"]["denoised"]
|
839 |
|
840 |
while app_state["runningState"] != "finished":
|
|
|
|
|
|
|
841 |
md = 'Running model('+str(app_state["batchCount"])+'/'+str(app_state["totalBatchNum"])+')...'
|
842 |
yield {batch_md : gr.Markdown(md, visible=True)}
|
843 |
|
channel_mapping.py
CHANGED
@@ -57,17 +57,6 @@ def reorder_to_origin(app_state, channel_info, new_filename):
|
|
57 |
utils.save_data(new_data, new_filename)
|
58 |
return
|
59 |
|
60 |
-
class Channel:
|
61 |
-
|
62 |
-
def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
|
63 |
-
self.name = name
|
64 |
-
self.index = index
|
65 |
-
self.matched = matched
|
66 |
-
self.assigned = assigned # for input channels
|
67 |
-
self.coord = coord
|
68 |
-
self.css_position = css_position
|
69 |
-
|
70 |
-
|
71 |
def read_montage_data(loc_file):
|
72 |
|
73 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
@@ -75,17 +64,26 @@ def read_montage_data(loc_file):
|
|
75 |
template_dict = {}
|
76 |
input_dict = {}
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
return template_montage, input_montage, template_dict, input_dict
|
91 |
|
@@ -95,49 +93,43 @@ def align_coords(channel_info, template_montage, input_montage):
|
|
95 |
input_dict = channel_info["inputDict"]
|
96 |
template_order = channel_info["templateOrder"]
|
97 |
input_order = channel_info["inputOrder"]
|
98 |
-
matched = [channel for channel in
|
|
|
99 |
|
100 |
-
# 2-
|
101 |
-
|
102 |
-
fig[0].set_size_inches(5.6, 5.6)
|
103 |
-
fig[1].set_size_inches(5.6, 5.6)
|
104 |
|
|
|
105 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
106 |
-
ax[0].set_aspect('equal')
|
107 |
-
ax[1].set_aspect('equal')
|
108 |
-
ax[0].figure.canvas.draw() #update the figure
|
109 |
-
ax[1].figure.canvas.draw()
|
110 |
|
111 |
# get the original coords
|
112 |
-
all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) #
|
113 |
-
all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
|
|
|
|
|
114 |
matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
|
115 |
matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
|
116 |
|
117 |
-
# transform the xy axis (
|
118 |
-
rbf_x = Rbf(
|
119 |
-
rbf_y = Rbf(
|
120 |
|
121 |
-
# apply to all
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
-
#
|
127 |
for i, channel in enumerate(template_order):
|
128 |
-
|
129 |
-
css_bottom = (transformed_tpl_y[i]-7)/560
|
130 |
-
template_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
|
131 |
for i, channel in enumerate(input_order):
|
132 |
-
|
133 |
-
css_bottom = (all_in[i][1]-7)/560
|
134 |
-
input_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
|
135 |
|
136 |
|
137 |
# 3-d (to use KNN)
|
138 |
# get the original coords
|
139 |
-
all_tpl = np.array([template_dict[channel]["
|
140 |
-
all_in = np.array([input_dict[channel]["
|
141 |
matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
|
142 |
matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
|
143 |
|
@@ -152,13 +144,13 @@ def align_coords(channel_info, template_montage, input_montage):
|
|
152 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
153 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
154 |
|
155 |
-
# update input's position
|
156 |
for i, channel in enumerate(input_order):
|
157 |
-
input_dict[channel]["
|
158 |
|
159 |
channel_info.update({
|
160 |
"templateDict" : template_dict,
|
161 |
-
"inputDict" : input_dict
|
162 |
})
|
163 |
return channel_info
|
164 |
|
@@ -168,13 +160,12 @@ def find_neighbors(app_state, channel_info):
|
|
168 |
input_dict = channel_info["inputDict"]
|
169 |
template_order = channel_info["templateOrder"]
|
170 |
input_order = channel_info["inputOrder"]
|
171 |
-
#z_row_idx = channel_info["dataShape"][0]
|
172 |
missing_channels = app_state["missingTemplates"]
|
173 |
if missing_channels == []:
|
174 |
return app_state # change nothing
|
175 |
|
176 |
|
177 |
-
in_coords = [input_dict[channel]["
|
178 |
in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
|
179 |
|
180 |
# use KNN to choose k nearest channels
|
@@ -183,7 +174,7 @@ def find_neighbors(app_state, channel_info):
|
|
183 |
knn.fit(in_coords)
|
184 |
|
185 |
for channel in missing_channels:
|
186 |
-
distances, indices = knn.kneighbors(np.array(template_dict[channel]["
|
187 |
selected = [input_order[i] for i in indices[0]]
|
188 |
#print(channel, ':', selected)
|
189 |
|
@@ -205,51 +196,44 @@ def mapping_stage1(app_state, channel_info, loc_file):
|
|
205 |
template_order = template_montage.ch_names
|
206 |
input_order = input_montage.ch_names
|
207 |
new_idx = [[]]*30
|
208 |
-
|
209 |
'T3': 'T7',
|
210 |
'T4': 'T8',
|
211 |
'T5': 'P7',
|
212 |
-
'T6': 'P8'
|
213 |
-
#'TP7': 'T5\'',
|
214 |
-
#'TP8': 'T6\'',
|
215 |
}
|
216 |
|
217 |
# match the names of input channels -> template channels
|
218 |
for i, channel in enumerate(template_order):
|
219 |
-
if channel in
|
220 |
-
template_montage.rename_channels({channel:
|
221 |
-
template_dict[
|
222 |
-
|
223 |
-
channel = alias[channel]
|
224 |
|
225 |
if channel in input_dict:
|
226 |
-
new_idx[i] = [input_dict[channel]
|
227 |
-
|
228 |
-
|
229 |
-
input_dict[channel].matched = True
|
230 |
-
input_dict[channel].assigned = True
|
231 |
|
232 |
# update names
|
233 |
template_order = template_montage.ch_names
|
234 |
input_order = input_montage.ch_names
|
235 |
|
236 |
channel_info.update({
|
237 |
-
"templateDict" :
|
238 |
-
"inputDict" :
|
239 |
"templateOrder" : template_order,
|
240 |
"inputOrder" : input_order
|
241 |
})
|
242 |
app_state.update({
|
243 |
"stage1NewOrder" : new_idx,
|
244 |
"runningState" : "stage1",
|
245 |
-
"stage1UnassignedInputs" : [channel for channel in input_order if input_dict[channel]
|
246 |
-
"missingTemplates" : [channel for channel in template_order if template_dict[channel]
|
247 |
})
|
248 |
|
249 |
# align input, template's coordinates
|
250 |
channel_info = align_coords(channel_info, template_montage, input_montage)
|
251 |
-
# fill the unmatched channels
|
252 |
-
#app_state = fill_channels(app_state, channel_info, fill_mode)
|
253 |
|
254 |
second2 = time.time()
|
255 |
print('Mapping (stage1) finished in',second2 - second1,'s.')
|
@@ -267,8 +251,8 @@ def mapping_stage2(app_state, channel_info):
|
|
267 |
app_state["runningState"] = "finished"
|
268 |
return app_state, channel_info
|
269 |
|
270 |
-
tpl_coords = np.array([template_dict[channel]["
|
271 |
-
unassigned_coords = np.array([input_dict[channel]["
|
272 |
|
273 |
# reset all tpl.matched to False
|
274 |
for channel in template_dict:
|
@@ -276,7 +260,7 @@ def mapping_stage2(app_state, channel_info):
|
|
276 |
|
277 |
# initialize the cost matrix
|
278 |
if len(unassigned) < 30:
|
279 |
-
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col
|
280 |
else:
|
281 |
cost_matrix = np.zeros((30, len(unassigned)))
|
282 |
for i in range(30):
|
@@ -287,17 +271,16 @@ def mapping_stage2(app_state, channel_info):
|
|
287 |
# use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
|
288 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
289 |
|
290 |
-
matches = []
|
291 |
new_idx = [[]]*30
|
292 |
for i in range(30):
|
293 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
294 |
-
|
295 |
|
296 |
tpl_channel = template_order[row_idx[i]]
|
297 |
in_channel = unassigned[col_idx[i]]
|
298 |
template_dict[tpl_channel]["matched"] = True
|
299 |
input_dict[in_channel]["assigned"] = True
|
300 |
-
new_idx[i] = [input_dict[in_channel]["index"]]
|
301 |
|
302 |
print(template_order[row_idx[i]], '<-', unassigned[col_idx[i]])
|
303 |
|
@@ -312,7 +295,7 @@ def mapping_stage2(app_state, channel_info):
|
|
312 |
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
313 |
})
|
314 |
|
315 |
-
# fill the missing_channels
|
316 |
app_state = find_neighbors(app_state, channel_info)
|
317 |
|
318 |
second2 = time.time()
|
|
|
57 |
utils.save_data(new_data, new_filename)
|
58 |
return
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def read_montage_data(loc_file):
|
61 |
|
62 |
template_montage = read_custom_montage("./template_chanlocs.loc")
|
|
|
64 |
template_dict = {}
|
65 |
input_dict = {}
|
66 |
|
67 |
+
for i in range(30):
|
68 |
+
channel = template_montage.ch_names[i]
|
69 |
+
template_montage.rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
|
70 |
+
|
71 |
+
channel = str.upper(channel)
|
72 |
+
template_dict[channel] = {
|
73 |
+
"index" : i,
|
74 |
+
"coord_3d" : template_montage.get_positions()['ch_pos'][channel],
|
75 |
+
"matched" : False
|
76 |
+
}
|
77 |
+
for i in range(len(input_montage.ch_names)):
|
78 |
+
channel = input_montage.ch_names[i]
|
79 |
+
input_montage.rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
|
80 |
+
|
81 |
+
channel = str.upper(channel)
|
82 |
+
input_dict[channel] = {
|
83 |
+
"index" : i,
|
84 |
+
"coord_3d" : input_montage.get_positions()['ch_pos'][channel],
|
85 |
+
"assigned" : False
|
86 |
+
}
|
87 |
|
88 |
return template_montage, input_montage, template_dict, input_dict
|
89 |
|
|
|
93 |
input_dict = channel_info["inputDict"]
|
94 |
template_order = channel_info["templateOrder"]
|
95 |
input_order = channel_info["inputOrder"]
|
96 |
+
matched = [channel for channel in template_order if template_dict[channel]["matched"]==True]
|
97 |
+
|
98 |
|
99 |
+
# --------------------------------2-D------------------------------------
|
100 |
+
# (for the indication of missing template channel's position when fill_mode:'mean')
|
|
|
|
|
101 |
|
102 |
+
fig = [template_montage.plot(), input_montage.plot()]
|
103 |
ax = [fig[0].axes[0], fig[1].axes[0]]
|
|
|
|
|
|
|
|
|
104 |
|
105 |
# get the original coords
|
106 |
+
#all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # displayed coords (px)
|
107 |
+
#all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
|
108 |
+
all_tpl = ax[0].collections[0].get_offsets().data
|
109 |
+
all_in= ax[1].collections[0].get_offsets().data
|
110 |
matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
|
111 |
matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
|
112 |
|
113 |
+
# transform the xy axis (input's -> template's)
|
114 |
+
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
|
115 |
+
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
|
116 |
|
117 |
+
# apply to all input channels
|
118 |
+
transformed_in_x = rbf_x(all_in[:,0], all_in[:,1])
|
119 |
+
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
120 |
+
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
121 |
|
122 |
+
# save template's and input's 2d position
|
123 |
for i, channel in enumerate(template_order):
|
124 |
+
template_dict[channel]["coord_2d"] = all_tpl[i]
|
|
|
|
|
125 |
for i, channel in enumerate(input_order):
|
126 |
+
input_dict[channel]["coord_2d"] = transformed_in[i].tolist()
|
|
|
|
|
127 |
|
128 |
|
129 |
# 3-d (to use KNN)
|
130 |
# get the original coords
|
131 |
+
all_tpl = np.array([template_dict[channel]["coord_3d"].tolist() for channel in template_order])
|
132 |
+
all_in = np.array([input_dict[channel]["coord_3d"].tolist() for channel in input_order])
|
133 |
matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
|
134 |
matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
|
135 |
|
|
|
144 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
145 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
146 |
|
147 |
+
# update input's 3d position
|
148 |
for i, channel in enumerate(input_order):
|
149 |
+
input_dict[channel]["coord_3d"] = transformed_in[i].tolist()
|
150 |
|
151 |
channel_info.update({
|
152 |
"templateDict" : template_dict,
|
153 |
+
"inputDict" : input_dict
|
154 |
})
|
155 |
return channel_info
|
156 |
|
|
|
160 |
input_dict = channel_info["inputDict"]
|
161 |
template_order = channel_info["templateOrder"]
|
162 |
input_order = channel_info["inputOrder"]
|
|
|
163 |
missing_channels = app_state["missingTemplates"]
|
164 |
if missing_channels == []:
|
165 |
return app_state # change nothing
|
166 |
|
167 |
|
168 |
+
in_coords = [input_dict[channel]["coord_3d"] for channel in input_order]
|
169 |
in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
|
170 |
|
171 |
# use KNN to choose k nearest channels
|
|
|
174 |
knn.fit(in_coords)
|
175 |
|
176 |
for channel in missing_channels:
|
177 |
+
distances, indices = knn.kneighbors(np.array(template_dict[channel]["coord_3d"]).reshape(1,-1))
|
178 |
selected = [input_order[i] for i in indices[0]]
|
179 |
#print(channel, ':', selected)
|
180 |
|
|
|
196 |
template_order = template_montage.ch_names
|
197 |
input_order = input_montage.ch_names
|
198 |
new_idx = [[]]*30
|
199 |
+
alias_dict = {
|
200 |
'T3': 'T7',
|
201 |
'T4': 'T8',
|
202 |
'T5': 'P7',
|
203 |
+
'T6': 'P8'
|
|
|
|
|
204 |
}
|
205 |
|
206 |
# match the names of input channels -> template channels
|
207 |
for i, channel in enumerate(template_order):
|
208 |
+
if channel in alias_dict and alias_dict[channel] in input_dict:
|
209 |
+
template_montage.rename_channels({channel: alias_dict[channel]})
|
210 |
+
template_dict[alias_dict[channel]] = template_dict.pop(channel)
|
211 |
+
channel = alias_dict[channel]
|
|
|
212 |
|
213 |
if channel in input_dict:
|
214 |
+
new_idx[i] = [input_dict[channel]["index"]]
|
215 |
+
template_dict[channel]["matched"] = True
|
216 |
+
input_dict[channel]["assigned"] = True
|
|
|
|
|
217 |
|
218 |
# update names
|
219 |
template_order = template_montage.ch_names
|
220 |
input_order = input_montage.ch_names
|
221 |
|
222 |
channel_info.update({
|
223 |
+
"templateDict" : template_dict,
|
224 |
+
"inputDict" : input_dict,
|
225 |
"templateOrder" : template_order,
|
226 |
"inputOrder" : input_order
|
227 |
})
|
228 |
app_state.update({
|
229 |
"stage1NewOrder" : new_idx,
|
230 |
"runningState" : "stage1",
|
231 |
+
"stage1UnassignedInputs" : [channel for channel in input_order if input_dict[channel]["assigned"]==False],
|
232 |
+
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
233 |
})
|
234 |
|
235 |
# align input, template's coordinates
|
236 |
channel_info = align_coords(channel_info, template_montage, input_montage)
|
|
|
|
|
237 |
|
238 |
second2 = time.time()
|
239 |
print('Mapping (stage1) finished in',second2 - second1,'s.')
|
|
|
251 |
app_state["runningState"] = "finished"
|
252 |
return app_state, channel_info
|
253 |
|
254 |
+
tpl_coords = np.array([template_dict[channel]["coord_3d"] for channel in template_order])
|
255 |
+
unassigned_coords = np.array([input_dict[channel]["coord_3d"] for channel in unassigned])
|
256 |
|
257 |
# reset all tpl.matched to False
|
258 |
for channel in template_dict:
|
|
|
260 |
|
261 |
# initialize the cost matrix
|
262 |
if len(unassigned) < 30:
|
263 |
+
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
264 |
else:
|
265 |
cost_matrix = np.zeros((30, len(unassigned)))
|
266 |
for i in range(30):
|
|
|
271 |
# use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
|
272 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
273 |
|
|
|
274 |
new_idx = [[]]*30
|
275 |
for i in range(30):
|
276 |
if col_idx[i] < len(unassigned): # filter out dummy channels
|
277 |
+
print(f'({row_idx[i]}, {col_idx[i]})')
|
278 |
|
279 |
tpl_channel = template_order[row_idx[i]]
|
280 |
in_channel = unassigned[col_idx[i]]
|
281 |
template_dict[tpl_channel]["matched"] = True
|
282 |
input_dict[in_channel]["assigned"] = True
|
283 |
+
new_idx[row_idx[i]] = [input_dict[in_channel]["index"]]
|
284 |
|
285 |
print(template_order[row_idx[i]], '<-', unassigned[col_idx[i]])
|
286 |
|
|
|
295 |
"missingTemplates" : [channel for channel in template_order if template_dict[channel]["matched"]==False]
|
296 |
})
|
297 |
|
298 |
+
# fill the missing_channels
|
299 |
app_state = find_neighbors(app_state, channel_info)
|
300 |
|
301 |
second2 = time.time()
|
template_montage.png
CHANGED
![]() |
![]() |