Spaces:
Sleeping
Sleeping
Commit
·
df70562
1
Parent(s):
ed22689
update
Browse files- app.py +26 -26
- app_utils.py +92 -92
app.py
CHANGED
@@ -60,7 +60,7 @@ init_js = """
|
|
60 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
61 |
|
62 |
let selector, attribute;
|
63 |
-
let
|
64 |
|
65 |
if(stage1_info.state == "step2-selecting"){
|
66 |
selector = "#radio-group > div:nth-of-type(2)";
|
@@ -84,9 +84,9 @@ init_js = """
|
|
84 |
// move the radios/checkboxes
|
85 |
let all_elem = document.querySelectorAll(selector+" > label");
|
86 |
Array.from(all_elem).forEach(item => {
|
87 |
-
|
88 |
-
left = channel_info.inputDict[
|
89 |
-
bottom = channel_info.inputDict[
|
90 |
|
91 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
92 |
item.className = "";
|
@@ -94,9 +94,9 @@ init_js = """
|
|
94 |
});
|
95 |
|
96 |
// add indication for the empty tpl_channels
|
97 |
-
|
98 |
-
left = channel_info.templateDict[
|
99 |
-
bottom = channel_info.templateDict[
|
100 |
let dot_rule = `
|
101 |
${selector}::before {
|
102 |
content: "";
|
@@ -116,7 +116,7 @@ init_js = """
|
|
116 |
bottom = bottom.toString()+"%";
|
117 |
let txt_rule = `
|
118 |
${selector}::after {
|
119 |
-
content: "${
|
120 |
position: absolute;
|
121 |
color: red;
|
122 |
left: ${left};
|
@@ -144,7 +144,7 @@ update_js = """
|
|
144 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
145 |
|
146 |
let selector;
|
147 |
-
let cnt,
|
148 |
|
149 |
if(stage1_info.state == "step2-selecting"){
|
150 |
selector = "#radio-group > div:nth-of-type(2)";
|
@@ -153,9 +153,9 @@ update_js = """
|
|
153 |
// update the radios
|
154 |
let all_elem = document.querySelectorAll(selector+" > label");
|
155 |
Array.from(all_elem).forEach(item => {
|
156 |
-
|
157 |
-
left = channel_info.inputDict[
|
158 |
-
bottom = channel_info.inputDict[
|
159 |
|
160 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
161 |
item.className = "";
|
@@ -167,9 +167,9 @@ update_js = """
|
|
167 |
}else return;
|
168 |
|
169 |
// update the indication
|
170 |
-
|
171 |
-
left = channel_info.templateDict[
|
172 |
-
bottom = channel_info.templateDict[
|
173 |
let dot_rule = `
|
174 |
${selector}::before {
|
175 |
content: "";
|
@@ -189,7 +189,7 @@ update_js = """
|
|
189 |
bottom = bottom.toString()+"%";
|
190 |
let txt_rule = `
|
191 |
${selector}::after {
|
192 |
-
content: "${
|
193 |
position: absolute;
|
194 |
color: red;
|
195 |
left: ${left};
|
@@ -403,7 +403,7 @@ with gr.Blocks() as demo:
|
|
403 |
|
404 |
# ========================================step1=========================================
|
405 |
elif stage1_info["state"] == "step1-finished":
|
406 |
-
in_num = len(channel_info["
|
407 |
matched_num = 30 - len(stage1_info["emptyTemplates"])
|
408 |
|
409 |
# step1 to step4
|
@@ -490,10 +490,10 @@ with gr.Blocks() as demo:
|
|
490 |
|
491 |
# --------------------------------update information--------------------------------
|
492 |
# exclude the selected in_channel of the previous round
|
493 |
-
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["
|
494 |
channel_info["inputDict"])
|
495 |
# exclude the tpl_channels filled in step2
|
496 |
-
stage1_info["emptyTemplates"] = app_utils.get_empty_templates(channel_info["
|
497 |
channel_info["templateDict"])
|
498 |
# -----------------------------determine the next step------------------------------
|
499 |
# step2 to step4
|
@@ -571,7 +571,7 @@ with gr.Blocks() as demo:
|
|
571 |
|
572 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
573 |
value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
|
574 |
-
value = [channel_info["
|
575 |
|
576 |
stage1_info["state"] = "step3-2-selecting"
|
577 |
# determine which button to display
|
@@ -580,7 +580,7 @@ with gr.Blocks() as demo:
|
|
580 |
desc_md : gr.Markdown(md),
|
581 |
in_fillmode : gr.Dropdown(visible=False),
|
582 |
fillmode_btn : gr.Button(visible=False),
|
583 |
-
chkbox_group : gr.CheckboxGroup(choices=channel_info["
|
584 |
value=value, label=label, visible=True),
|
585 |
next_btn : gr.Button(visible=True)}
|
586 |
else:
|
@@ -588,7 +588,7 @@ with gr.Blocks() as demo:
|
|
588 |
desc_md : gr.Markdown(md),
|
589 |
in_fillmode : gr.Dropdown(visible=False),
|
590 |
fillmode_btn : gr.Button(visible=False),
|
591 |
-
chkbox_group : gr.CheckboxGroup(choices=channel_info["
|
592 |
value=value, label=label, visible=True),
|
593 |
step3_btn : gr.Button(visible=True)}
|
594 |
|
@@ -598,7 +598,7 @@ with gr.Blocks() as demo:
|
|
598 |
# --------------------------------store information---------------------------------
|
599 |
prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
|
600 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
601 |
-
sel_idx = [channel_info["inputDict"][
|
602 |
stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
603 |
#print(prev_tpl_name, '<-', sel_chkbox)
|
604 |
# ----------------------------------------------------------------------------------
|
@@ -687,7 +687,7 @@ with gr.Blocks() as demo:
|
|
687 |
step2["count"] += 1
|
688 |
|
689 |
# exclude the selected in_channel of the previous round
|
690 |
-
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["
|
691 |
|
692 |
tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
|
693 |
label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
|
@@ -726,7 +726,7 @@ with gr.Blocks() as demo:
|
|
726 |
# ----------------------------------store information-----------------------------------
|
727 |
prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
|
728 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
729 |
-
sel_idx = [channel_info["inputDict"][
|
730 |
stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
731 |
#print(prev_tpl_name, '<-', sel_name)
|
732 |
|
@@ -738,7 +738,7 @@ with gr.Blocks() as demo:
|
|
738 |
|
739 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
740 |
value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
|
741 |
-
value = [channel_info["
|
742 |
|
743 |
stage1_info["step3"] = step3
|
744 |
# determine which button to display
|
|
|
60 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
61 |
|
62 |
let selector, attribute;
|
63 |
+
let name, left, bottom;
|
64 |
|
65 |
if(stage1_info.state == "step2-selecting"){
|
66 |
selector = "#radio-group > div:nth-of-type(2)";
|
|
|
84 |
// move the radios/checkboxes
|
85 |
let all_elem = document.querySelectorAll(selector+" > label");
|
86 |
Array.from(all_elem).forEach(item => {
|
87 |
+
name = item.querySelector("input").getAttribute(attribute);
|
88 |
+
left = channel_info.inputDict[name].css_position[0];
|
89 |
+
bottom = channel_info.inputDict[name].css_position[1];
|
90 |
|
91 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
92 |
item.className = "";
|
|
|
94 |
});
|
95 |
|
96 |
// add indication for the empty tpl_channels
|
97 |
+
name = stage1_info.emptyTemplates[0];
|
98 |
+
left = channel_info.templateDict[name].css_position[0];
|
99 |
+
bottom = channel_info.templateDict[name].css_position[1];
|
100 |
let dot_rule = `
|
101 |
${selector}::before {
|
102 |
content: "";
|
|
|
116 |
bottom = bottom.toString()+"%";
|
117 |
let txt_rule = `
|
118 |
${selector}::after {
|
119 |
+
content: "${name}";
|
120 |
position: absolute;
|
121 |
color: red;
|
122 |
left: ${left};
|
|
|
144 |
channel_info = JSON.parse(JSON.stringify(channel_info));
|
145 |
|
146 |
let selector;
|
147 |
+
let cnt, name, left, bottom;
|
148 |
|
149 |
if(stage1_info.state == "step2-selecting"){
|
150 |
selector = "#radio-group > div:nth-of-type(2)";
|
|
|
153 |
// update the radios
|
154 |
let all_elem = document.querySelectorAll(selector+" > label");
|
155 |
Array.from(all_elem).forEach(item => {
|
156 |
+
name = item.querySelector("input").value;
|
157 |
+
left = channel_info.inputDict[name].css_position[0];
|
158 |
+
bottom = channel_info.inputDict[name].css_position[1];
|
159 |
|
160 |
item.style.cssText = `position: absolute; left: ${left}; bottom: ${bottom};`;
|
161 |
item.className = "";
|
|
|
167 |
}else return;
|
168 |
|
169 |
// update the indication
|
170 |
+
name = stage1_info.emptyTemplates[cnt-1];
|
171 |
+
left = channel_info.templateDict[name].css_position[0];
|
172 |
+
bottom = channel_info.templateDict[name].css_position[1];
|
173 |
let dot_rule = `
|
174 |
${selector}::before {
|
175 |
content: "";
|
|
|
189 |
bottom = bottom.toString()+"%";
|
190 |
let txt_rule = `
|
191 |
${selector}::after {
|
192 |
+
content: "${name}";
|
193 |
position: absolute;
|
194 |
color: red;
|
195 |
left: ${left};
|
|
|
403 |
|
404 |
# ========================================step1=========================================
|
405 |
elif stage1_info["state"] == "step1-finished":
|
406 |
+
in_num = len(channel_info["inputNames"])
|
407 |
matched_num = 30 - len(stage1_info["emptyTemplates"])
|
408 |
|
409 |
# step1 to step4
|
|
|
490 |
|
491 |
# --------------------------------update information--------------------------------
|
492 |
# exclude the selected in_channel of the previous round
|
493 |
+
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputNames"],
|
494 |
channel_info["inputDict"])
|
495 |
# exclude the tpl_channels filled in step2
|
496 |
+
stage1_info["emptyTemplates"] = app_utils.get_empty_templates(channel_info["templateNames"],
|
497 |
channel_info["templateDict"])
|
498 |
# -----------------------------determine the next step------------------------------
|
499 |
# step2 to step4
|
|
|
571 |
|
572 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
573 |
value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
|
574 |
+
value = [channel_info["inputNames"][i] for i in value]
|
575 |
|
576 |
stage1_info["state"] = "step3-2-selecting"
|
577 |
# determine which button to display
|
|
|
580 |
desc_md : gr.Markdown(md),
|
581 |
in_fillmode : gr.Dropdown(visible=False),
|
582 |
fillmode_btn : gr.Button(visible=False),
|
583 |
+
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputNames"],
|
584 |
value=value, label=label, visible=True),
|
585 |
next_btn : gr.Button(visible=True)}
|
586 |
else:
|
|
|
588 |
desc_md : gr.Markdown(md),
|
589 |
in_fillmode : gr.Dropdown(visible=False),
|
590 |
fillmode_btn : gr.Button(visible=False),
|
591 |
+
chkbox_group : gr.CheckboxGroup(choices=channel_info["inputNames"],
|
592 |
value=value, label=label, visible=True),
|
593 |
step3_btn : gr.Button(visible=True)}
|
594 |
|
|
|
598 |
# --------------------------------store information---------------------------------
|
599 |
prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
|
600 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
601 |
+
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
|
602 |
stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
603 |
#print(prev_tpl_name, '<-', sel_chkbox)
|
604 |
# ----------------------------------------------------------------------------------
|
|
|
687 |
step2["count"] += 1
|
688 |
|
689 |
# exclude the selected in_channel of the previous round
|
690 |
+
stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputNames"], channel_info["inputDict"])
|
691 |
|
692 |
tpl_name = stage1_info["emptyTemplates"][step2["count"]-1]
|
693 |
label = '{} ({}/{})'.format(tpl_name, step2["count"], step2["totalNum"])
|
|
|
726 |
# ----------------------------------store information-----------------------------------
|
727 |
prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
|
728 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
729 |
+
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
|
730 |
stage1_info["mappingResults"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
731 |
#print(prev_tpl_name, '<-', sel_name)
|
732 |
|
|
|
738 |
|
739 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
740 |
value = stage1_info["mappingResults"][0]["newOrder"][tpl_idx]
|
741 |
+
value = [channel_info["inputNames"][i] for i in value]
|
742 |
|
743 |
stage1_info["step3"] = step3
|
744 |
# determine which button to display
|
app_utils.py
CHANGED
@@ -46,52 +46,52 @@ def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, ou
|
|
46 |
utils.save_data(new_data, outputname)
|
47 |
return
|
48 |
|
49 |
-
def get_matched(
|
50 |
-
return [
|
51 |
|
52 |
-
def get_empty_templates(
|
53 |
-
return [
|
54 |
|
55 |
-
def get_unassigned_inputs(
|
56 |
-
return [
|
57 |
|
58 |
def read_montage_data(loc_file):
|
59 |
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
in_montage = read_custom_montage(loc_file)
|
61 |
-
|
62 |
-
|
63 |
tpl_dict = {}
|
64 |
in_dict = {}
|
65 |
|
66 |
# convert all channel names to uppercase and store their information
|
67 |
-
for i,
|
68 |
-
|
69 |
-
tpl_montage.rename_channels({
|
70 |
-
tpl_dict[
|
71 |
"index" : i,
|
72 |
-
"coord_3d" : tpl_montage.get_positions()['ch_pos'][
|
73 |
"matched" : False
|
74 |
}
|
75 |
-
for i,
|
76 |
-
|
77 |
-
in_montage.rename_channels({
|
78 |
-
in_dict[
|
79 |
"index" : i,
|
80 |
-
"coord_3d" : in_montage.get_positions()['ch_pos'][
|
81 |
"assigned" : False
|
82 |
}
|
83 |
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
|
85 |
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
-
|
87 |
-
|
88 |
tpl_dict = channel_info["templateDict"]
|
89 |
in_dict = channel_info["inputDict"]
|
90 |
|
91 |
-
tpl_x = [tpl_dict[
|
92 |
-
tpl_y = [tpl_dict[
|
93 |
-
in_x = [in_dict[
|
94 |
-
in_y = [in_dict[
|
95 |
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
in_coords = np.vstack((in_x, in_y)).T
|
97 |
|
@@ -116,18 +116,18 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
|
|
116 |
ax.plot(x, y, color='black', linewidth=1.0)
|
117 |
# plot in_channels on it
|
118 |
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
119 |
-
for i,
|
120 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1],
|
121 |
# save input_montage
|
122 |
fig.savefig(filename1)
|
123 |
|
124 |
# ---------------------------add indications-------------------------------
|
125 |
# plot unmatched input channels in red
|
126 |
-
indices = [in_dict[
|
127 |
if indices != []:
|
128 |
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
129 |
for i in indices:
|
130 |
-
ax.text(in_coords[i,0]+0.003, in_coords[i,1],
|
131 |
# save mapped_montage
|
132 |
fig.savefig(filename2)
|
133 |
|
@@ -137,14 +137,14 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
|
|
137 |
in_coords = ax.transData.transform(in_coords)
|
138 |
plt.close('all')
|
139 |
|
140 |
-
for i,
|
141 |
css_left = (tpl_coords[i,0]-11)/6.4
|
142 |
css_bottom = (tpl_coords[i,1]-7)/6.4
|
143 |
-
tpl_dict[
|
144 |
-
for i,
|
145 |
css_left = (in_coords[i,0]-11)/6.4
|
146 |
css_bottom = (in_coords[i,1]-7)/6.4
|
147 |
-
in_dict[
|
148 |
|
149 |
channel_info.update({
|
150 |
"templateDict" : tpl_dict,
|
@@ -153,11 +153,11 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
|
|
153 |
return channel_info
|
154 |
|
155 |
def align_coords(channel_info, tpl_montage, in_montage):
|
156 |
-
|
157 |
-
|
158 |
tpl_dict = channel_info["templateDict"]
|
159 |
in_dict = channel_info["inputDict"]
|
160 |
-
|
161 |
|
162 |
# 2D alignment (for visualization purposes)
|
163 |
fig = [tpl_montage.plot(), in_montage.plot()]
|
@@ -166,8 +166,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
166 |
# extract the displayed 2D coordinates
|
167 |
all_tpl = ax[0].collections[0].get_offsets().data
|
168 |
all_in= ax[1].collections[0].get_offsets().data
|
169 |
-
matched_tpl = np.array([all_tpl[tpl_dict[
|
170 |
-
matched_in = np.array([all_in[in_dict[
|
171 |
plt.close('all')
|
172 |
|
173 |
# apply TPS to transform in_channels to align with tpl_channels positions
|
@@ -179,17 +179,17 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
179 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
180 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
181 |
|
182 |
-
for i,
|
183 |
-
tpl_dict[
|
184 |
-
for i,
|
185 |
-
in_dict[
|
186 |
|
187 |
|
188 |
# 3D alignment
|
189 |
-
all_tpl = np.array([tpl_dict[
|
190 |
-
all_in = np.array([in_dict[
|
191 |
-
matched_tpl = np.array([all_tpl[tpl_dict[
|
192 |
-
matched_in = np.array([all_in[in_dict[
|
193 |
|
194 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
195 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
@@ -200,8 +200,8 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
200 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
202 |
|
203 |
-
for i,
|
204 |
-
in_dict[
|
205 |
|
206 |
channel_info.update({
|
207 |
"templateDict" : tpl_dict,
|
@@ -209,21 +209,21 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
209 |
})
|
210 |
return channel_info
|
211 |
|
212 |
-
def find_neighbors(channel_info,
|
213 |
-
|
214 |
tpl_dict = channel_info["templateDict"]
|
215 |
in_dict = channel_info["inputDict"]
|
216 |
|
217 |
-
all_in = [np.array(in_dict[
|
218 |
-
empty_tpl = [np.array(tpl_dict[
|
219 |
|
220 |
# use KNN to choose k nearest channels
|
221 |
-
k = 4 if len(
|
222 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
223 |
knn.fit(all_in)
|
224 |
-
for i,
|
225 |
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
226 |
-
idx = tpl_dict[
|
227 |
new_idx[idx] = indices[0].tolist()
|
228 |
|
229 |
return new_idx
|
@@ -232,8 +232,8 @@ def match_names(stage1_info):
|
|
232 |
# read the location file
|
233 |
loc_file = stage1_info["fileNames"]["inputLocation"]
|
234 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
235 |
-
|
236 |
-
|
237 |
new_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
238 |
orig_flags = [False]*30
|
239 |
|
@@ -243,24 +243,24 @@ def match_names(stage1_info):
|
|
243 |
'T5': 'P7',
|
244 |
'T6': 'P8'
|
245 |
}
|
246 |
-
for i,
|
247 |
-
if
|
248 |
-
tpl_montage.rename_channels({
|
249 |
-
tpl_dict[alias_dict[
|
250 |
-
|
251 |
|
252 |
-
if
|
253 |
-
new_idx[i] = [in_dict[
|
254 |
orig_flags[i] = True
|
255 |
-
tpl_dict[
|
256 |
-
in_dict[
|
257 |
|
258 |
# update the names
|
259 |
-
|
260 |
|
261 |
stage1_info.update({
|
262 |
-
"unassignedInputs" : get_unassigned_inputs(
|
263 |
-
"emptyTemplates" : get_empty_templates(
|
264 |
"mappingResults" : [
|
265 |
{
|
266 |
"newOrder" : new_idx,
|
@@ -269,34 +269,34 @@ def match_names(stage1_info):
|
|
269 |
]
|
270 |
})
|
271 |
channel_info = {
|
272 |
-
"
|
273 |
-
"
|
274 |
"templateDict" : tpl_dict,
|
275 |
"inputDict" : in_dict
|
276 |
}
|
277 |
return stage1_info, channel_info, tpl_montage, in_montage
|
278 |
|
279 |
def optimal_mapping(channel_info):
|
280 |
-
|
281 |
-
|
282 |
tpl_dict = channel_info["templateDict"]
|
283 |
in_dict = channel_info["inputDict"]
|
284 |
-
|
285 |
# reset all tpl.matched to False
|
286 |
-
for
|
287 |
-
tpl_dict[
|
288 |
|
289 |
-
all_tpl = np.array([tpl_dict[
|
290 |
-
unass_in = np.array([in_dict[
|
291 |
|
292 |
# initialize the cost matrix for the Hungarian algorithm
|
293 |
-
if len(
|
294 |
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
295 |
else:
|
296 |
-
cost_matrix = np.zeros((30, len(
|
297 |
# fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
|
298 |
for i in range(30):
|
299 |
-
for j in range(len(
|
300 |
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
|
301 |
|
302 |
# apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
|
@@ -307,20 +307,20 @@ def optimal_mapping(channel_info):
|
|
307 |
new_idx = [[None]]*30
|
308 |
orig_flags = [False]*30
|
309 |
for i, j in zip(row_idx, col_idx):
|
310 |
-
if j < len(
|
311 |
-
|
312 |
-
|
313 |
|
314 |
-
new_idx[i] = [in_dict[
|
315 |
orig_flags[i] = True
|
316 |
-
tpl_dict[
|
317 |
-
in_dict[
|
318 |
-
#print(f'{
|
319 |
|
320 |
# fill the remaining empty tpl_channels
|
321 |
-
|
322 |
-
if
|
323 |
-
new_idx = find_neighbors(channel_info,
|
324 |
|
325 |
result = {
|
326 |
"newOrder" : new_idx,
|
@@ -344,8 +344,8 @@ def mapping_result(stage1_info, channel_info, filename):
|
|
344 |
results += [result]
|
345 |
|
346 |
data = {
|
347 |
-
#"
|
348 |
-
#"
|
349 |
"batchNum" : batch_num,
|
350 |
"mappingResults" : results
|
351 |
}
|
|
|
46 |
utils.save_data(new_data, outputname)
|
47 |
return
|
48 |
|
49 |
+
def get_matched(tpl_names, tpl_dict):
|
50 |
+
return [name for name in tpl_names if tpl_dict[name]["matched"]==True]
|
51 |
|
52 |
+
def get_empty_templates(tpl_names, tpl_dict):
|
53 |
+
return [name for name in tpl_names if tpl_dict[name]["matched"]==False]
|
54 |
|
55 |
+
def get_unassigned_inputs(in_names, in_dict):
|
56 |
+
return [name for name in in_names if in_dict[name]["assigned"]==False]
|
57 |
|
58 |
def read_montage_data(loc_file):
|
59 |
tpl_montage = read_custom_montage("./template_chanlocs.loc")
|
60 |
in_montage = read_custom_montage(loc_file)
|
61 |
+
tpl_names = tpl_montage.ch_names
|
62 |
+
in_names = in_montage.ch_names
|
63 |
tpl_dict = {}
|
64 |
in_dict = {}
|
65 |
|
66 |
# convert all channel names to uppercase and store their information
|
67 |
+
for i, name in enumerate(tpl_names):
|
68 |
+
up_name = str.upper(name)
|
69 |
+
tpl_montage.rename_channels({name: up_name})
|
70 |
+
tpl_dict[up_name] = {
|
71 |
"index" : i,
|
72 |
+
"coord_3d" : tpl_montage.get_positions()['ch_pos'][up_name],
|
73 |
"matched" : False
|
74 |
}
|
75 |
+
for i, name in enumerate(in_names):
|
76 |
+
up_name = str.upper(name)
|
77 |
+
in_montage.rename_channels({name: up_name})
|
78 |
+
in_dict[up_name] = {
|
79 |
"index" : i,
|
80 |
+
"coord_3d" : in_montage.get_positions()['ch_pos'][up_name],
|
81 |
"assigned" : False
|
82 |
}
|
83 |
return tpl_montage, in_montage, tpl_dict, in_dict
|
84 |
|
85 |
def save_figures(channel_info, tpl_montage, filename1, filename2):
|
86 |
+
tpl_names = channel_info["templateNames"]
|
87 |
+
in_names = channel_info["inputNames"]
|
88 |
tpl_dict = channel_info["templateDict"]
|
89 |
in_dict = channel_info["inputDict"]
|
90 |
|
91 |
+
tpl_x = [tpl_dict[name]["coord_2d"][0] for name in tpl_names]
|
92 |
+
tpl_y = [tpl_dict[name]["coord_2d"][1] for name in tpl_names]
|
93 |
+
in_x = [in_dict[name]["coord_2d"][0] for name in in_names]
|
94 |
+
in_y = [in_dict[name]["coord_2d"][1] for name in in_names]
|
95 |
tpl_coords = np.vstack((tpl_x, tpl_y)).T
|
96 |
in_coords = np.vstack((in_x, in_y)).T
|
97 |
|
|
|
116 |
ax.plot(x, y, color='black', linewidth=1.0)
|
117 |
# plot in_channels on it
|
118 |
ax.scatter(in_coords[:,0], in_coords[:,1], s=35, color='black')
|
119 |
+
for i, name in enumerate(in_names):
|
120 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], name, color='black', fontsize=10.0, va='center')
|
121 |
# save input_montage
|
122 |
fig.savefig(filename1)
|
123 |
|
124 |
# ---------------------------add indications-------------------------------
|
125 |
# plot unmatched input channels in red
|
126 |
+
indices = [in_dict[name]["index"] for name in in_names if in_dict[name]["assigned"]==False]
|
127 |
if indices != []:
|
128 |
ax.scatter(in_coords[indices,0], in_coords[indices,1], s=35, color='red')
|
129 |
for i in indices:
|
130 |
+
ax.text(in_coords[i,0]+0.003, in_coords[i,1], in_names[i], color='red', fontsize=10.0, va='center')
|
131 |
# save mapped_montage
|
132 |
fig.savefig(filename2)
|
133 |
|
|
|
137 |
in_coords = ax.transData.transform(in_coords)
|
138 |
plt.close('all')
|
139 |
|
140 |
+
for i, name in enumerate(tpl_names):
|
141 |
css_left = (tpl_coords[i,0]-11)/6.4
|
142 |
css_bottom = (tpl_coords[i,1]-7)/6.4
|
143 |
+
tpl_dict[name]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
144 |
+
for i, name in enumerate(in_names):
|
145 |
css_left = (in_coords[i,0]-11)/6.4
|
146 |
css_bottom = (in_coords[i,1]-7)/6.4
|
147 |
+
in_dict[name]["css_position"] = [str(round(css_left, 2))+"%", str(round(css_bottom, 2))+"%"]
|
148 |
|
149 |
channel_info.update({
|
150 |
"templateDict" : tpl_dict,
|
|
|
153 |
return channel_info
|
154 |
|
155 |
def align_coords(channel_info, tpl_montage, in_montage):
|
156 |
+
tpl_names = channel_info["templateNames"]
|
157 |
+
in_names = channel_info["inputNames"]
|
158 |
tpl_dict = channel_info["templateDict"]
|
159 |
in_dict = channel_info["inputDict"]
|
160 |
+
matched_names = get_matched(tpl_names, tpl_dict)
|
161 |
|
162 |
# 2D alignment (for visualization purposes)
|
163 |
fig = [tpl_montage.plot(), in_montage.plot()]
|
|
|
166 |
# extract the displayed 2D coordinates
|
167 |
all_tpl = ax[0].collections[0].get_offsets().data
|
168 |
all_in= ax[1].collections[0].get_offsets().data
|
169 |
+
matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
|
170 |
+
matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
|
171 |
plt.close('all')
|
172 |
|
173 |
# apply TPS to transform in_channels to align with tpl_channels positions
|
|
|
179 |
transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
|
180 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
|
181 |
|
182 |
+
for i, name in enumerate(tpl_names):
|
183 |
+
tpl_dict[name]["coord_2d"] = all_tpl[i]
|
184 |
+
for i, name in enumerate(in_names):
|
185 |
+
in_dict[name]["coord_2d"] = transformed_in[i].tolist()
|
186 |
|
187 |
|
188 |
# 3D alignment
|
189 |
+
all_tpl = np.array([tpl_dict[name]["coord_3d"].tolist() for name in tpl_names])
|
190 |
+
all_in = np.array([in_dict[name]["coord_3d"].tolist() for name in in_names])
|
191 |
+
matched_tpl = np.array([all_tpl[tpl_dict[name]["index"]] for name in matched_names])
|
192 |
+
matched_in = np.array([all_in[in_dict[name]["index"]] for name in matched_names])
|
193 |
|
194 |
rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
|
195 |
rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
|
|
|
200 |
transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
|
201 |
transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
|
202 |
|
203 |
+
for i, name in enumerate(in_names):
|
204 |
+
in_dict[name]["coord_3d"] = transformed_in[i].tolist()
|
205 |
|
206 |
channel_info.update({
|
207 |
"templateDict" : tpl_dict,
|
|
|
209 |
})
|
210 |
return channel_info
|
211 |
|
212 |
+
def find_neighbors(channel_info, empty_tpl_names, new_idx):
|
213 |
+
in_names = channel_info["inputNames"]
|
214 |
tpl_dict = channel_info["templateDict"]
|
215 |
in_dict = channel_info["inputDict"]
|
216 |
|
217 |
+
all_in = [np.array(in_dict[name]["coord_3d"]) for name in in_names]
|
218 |
+
empty_tpl = [np.array(tpl_dict[name]["coord_3d"]) for name in empty_tpl_names]
|
219 |
|
220 |
# use KNN to choose k nearest channels
|
221 |
+
k = 4 if len(in_names)>4 else len(in_names)
|
222 |
knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
|
223 |
knn.fit(all_in)
|
224 |
+
for i, name in enumerate(empty_tpl_names):
|
225 |
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
226 |
+
idx = tpl_dict[name]["index"]
|
227 |
new_idx[idx] = indices[0].tolist()
|
228 |
|
229 |
return new_idx
|
|
|
232 |
# read the location file
|
233 |
loc_file = stage1_info["fileNames"]["inputLocation"]
|
234 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
235 |
+
tpl_names = tpl_montage.ch_names
|
236 |
+
in_names = in_montage.ch_names
|
237 |
new_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
238 |
orig_flags = [False]*30
|
239 |
|
|
|
243 |
'T5': 'P7',
|
244 |
'T6': 'P8'
|
245 |
}
|
246 |
+
for i, name in enumerate(tpl_names):
|
247 |
+
if name in alias_dict and alias_dict[name] in in_dict:
|
248 |
+
tpl_montage.rename_channels({name: alias_dict[name]})
|
249 |
+
tpl_dict[alias_dict[name]] = tpl_dict.pop(name)
|
250 |
+
name = alias_dict[name]
|
251 |
|
252 |
+
if name in in_dict:
|
253 |
+
new_idx[i] = [in_dict[name]["index"]]
|
254 |
orig_flags[i] = True
|
255 |
+
tpl_dict[name]["matched"] = True
|
256 |
+
in_dict[name]["assigned"] = True
|
257 |
|
258 |
# update the names
|
259 |
+
tpl_names = tpl_montage.ch_names
|
260 |
|
261 |
stage1_info.update({
|
262 |
+
"unassignedInputs" : get_unassigned_inputs(in_names, in_dict),
|
263 |
+
"emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
|
264 |
"mappingResults" : [
|
265 |
{
|
266 |
"newOrder" : new_idx,
|
|
|
269 |
]
|
270 |
})
|
271 |
channel_info = {
|
272 |
+
"templateNames" : tpl_names,
|
273 |
+
"inputNames" : in_names,
|
274 |
"templateDict" : tpl_dict,
|
275 |
"inputDict" : in_dict
|
276 |
}
|
277 |
return stage1_info, channel_info, tpl_montage, in_montage
|
278 |
|
279 |
def optimal_mapping(channel_info):
|
280 |
+
tpl_names = channel_info["templateNames"]
|
281 |
+
in_names = channel_info["inputNames"]
|
282 |
tpl_dict = channel_info["templateDict"]
|
283 |
in_dict = channel_info["inputDict"]
|
284 |
+
unass_in_names = get_unassigned_inputs(in_names, in_dict)
|
285 |
# reset all tpl.matched to False
|
286 |
+
for name in tpl_dict:
|
287 |
+
tpl_dict[name]["matched"] = False
|
288 |
|
289 |
+
all_tpl = np.array([tpl_dict[name]["coord_3d"] for name in tpl_names])
|
290 |
+
unass_in = np.array([in_dict[name]["coord_3d"] for name in unass_in_names])
|
291 |
|
292 |
# initialize the cost matrix for the Hungarian algorithm
|
293 |
+
if len(unass_in_names) < 30:
|
294 |
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
|
295 |
else:
|
296 |
+
cost_matrix = np.zeros((30, len(unass_in_names)))
|
297 |
# fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
|
298 |
for i in range(30):
|
299 |
+
for j in range(len(unass_in_names)):
|
300 |
cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unass_in[j])*1000)
|
301 |
|
302 |
# apply the Hungarian algorithm to optimally assign one in_channel to each tpl_channel
|
|
|
307 |
new_idx = [[None]]*30
|
308 |
orig_flags = [False]*30
|
309 |
for i, j in zip(row_idx, col_idx):
|
310 |
+
if j < len(unass_in_names): # filter out dummy channels
|
311 |
+
tpl_name = tpl_names[i]
|
312 |
+
in_name = unass_in_names[j]
|
313 |
|
314 |
+
new_idx[i] = [in_dict[in_name]["index"]]
|
315 |
orig_flags[i] = True
|
316 |
+
tpl_dict[tpl_name]["matched"] = True
|
317 |
+
in_dict[in_name]["assigned"] = True
|
318 |
+
#print(f'{tpl_name}({i}) <- {in_name}({j})')
|
319 |
|
320 |
# fill the remaining empty tpl_channels
|
321 |
+
empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
|
322 |
+
if empty_tpl_names != []:
|
323 |
+
new_idx = find_neighbors(channel_info, empty_tpl_names, new_idx)
|
324 |
|
325 |
result = {
|
326 |
"newOrder" : new_idx,
|
|
|
344 |
results += [result]
|
345 |
|
346 |
data = {
|
347 |
+
#"templateNames" : channel_info["templateNames"],
|
348 |
+
#"inputNames" : channel_info["inputNames"],
|
349 |
"batchNum" : batch_num,
|
350 |
"mappingResults" : results
|
351 |
}
|