Commit a22369d · update v2
Parent(s): 7a54f74

Files changed:
- app.py +220 -183
- channel_mapping.py +237 -241
- utils.py +40 -3
app.py
CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
 import numpy as np
 import os
 import random
+import math
+import json
 import utils
+from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
 
 import mne
 from mne.channels import read_custom_montage
@@ -11,12 +13,9 @@ from mne.channels import read_custom_montage
 quickstart = """
 # Quickstart
 
 ### Raw data
 1. The data need to be a two-dimensional array (channel, timepoint).
+2. Upload your EEG data in `.csv` format.
 
 ### Channel locations
 Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
@@ -27,29 +26,12 @@ The models was trained using the EEG signals of 30 channels, including: `Fp1, Fp
 We expect your input data to include these channels as well.
 If your data doesn't contain all of the mentioned channels, there are 3 imputation ways you can choose from:
 
 - **zero**: fill the missing channels with zeros.
+- **mean(auto)**: select the 4 nearest channels for each missing channel, and we will average their values.
+- **mean(manual)**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, if you don't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.
 
 ### Mapping result
+Once the mapping process is finished, the **template montage** and the **input montage** (with the matched channels displaying their names) will be shown.
 
 ### Model
 Select the model you want to use.
@@ -68,20 +50,43 @@ chkbox_js = """
     state_json = JSON.parse(JSON.stringify(state_json));
     if(state_json.state == "finished") return;
 
+    // add the figure of in_montage
+    document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
         position: relative;
         width: 560px;
         height: 560px;
         background: url("file=${state_json.files.raw_montage}");
     `;
+
+    // add an indication for the missing channel
+    /*
+    let indicator = document.getElementById("indicator")
+    if(!indicator) document.querySelector("#chkbox-group> div:nth-of-type(2)").innerHTML += '<div id="indicator"></div>'
+
+    let channel = state_json.missingChannelsIndex[0]
+    channel = state_json.templateByIndex[channel]
+    let left = state_json.templateByName[channel].css_position[0];
+    let bottom = state_json.templateByName[channel].css_position[1];
+
+    document.getElementById("red-dot").style.cssText = `
+        position: absolute;
+        background-color: red;
+        width: 10px;
+        height: 10px;
+        border-radius: 50%;
+        left: ${left};
+        bottom: ${bottom};
+    `;
+    */
 
+    // move the checkboxes
+    let all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
     all_chkbox = Array.apply(null, all_chkbox);
 
     all_chkbox.forEach((item, index) => {
+        channel = state_json.inputByIndex[index];
+        left = state_json.inputByName[channel].css_position[0];
+        bottom = state_json.inputByName[channel].css_position[1];
         //console.log(`left: ${left}, bottom: ${bottom}`);
 
         item.style.cssText = `
@@ -90,42 +95,73 @@ chkbox_js = """
             bottom: ${bottom};
         `;
         item.className = "";
+        item.querySelector(":scope> span").innerText = "";
+    });
+}
+"""
+
+dot_js = """
+(state_json) => {
+    state_json = JSON.parse(JSON.stringify(state_json));
+    if(state_json.state == "finished") return;
+
+    let channel = state_json.missingChannelsIndex[state_json["fillingCount"]]
+    channel = state_json.templateByIndex[channel]
+    let left = state_json.templateByName[channel].css_position[0];
+    let bottom = state_json.templateByName[channel].css_position[1];
 
+    document.getElementById("indicator").style.cssText = `
+        position: absolute;
+        background-color: red;
+        width: 10px;
+        height: 10px;
+        border-radius: 50%;
+        left: ${left};
+        bottom: ${bottom};
+    `;
 }
 """
 
 
 with gr.Blocks() as demo:
 
+    state_json = gr.JSON(visible=False)
 
     with gr.Row():
         gr.Markdown(
             """
+            <p style="text-align: center;">(...)</p>
             """
         )
     with gr.Row():
+
         with gr.Column():
             gr.Markdown(
                 """
                 # 1.Channel Mapping
                 """
             )
+
+            # upload files, choose the imputation way (???
             with gr.Row():
                 in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                 in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
+                with gr.Column(min_width=100):
+                    in_sample_rate = gr.Textbox(label="Sampling rate (Hz)")
+                    in_fill_mode = gr.Dropdown(choices=[
+                        #("adjacent channel", "adjacent"),
+                        ("mean (auto)", "mean_auto"),
+                        ("mean (manual)", "mean_manual"),
+                        ("",""),
+                        "zero"],
+                        value="mean_auto",
+                        label="Imputation")
+                    map_btn = gr.Button("Mapping")
+
+            chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
+            next_btn = gr.Button("Next", interactive=False, visible=False)
+
+            # mapping result
             res_md = gr.Markdown(
                 """
                 ### Mapping result:
@@ -134,11 +170,8 @@ with gr.Blocks() as demo:
             )
             with gr.Row():
                 tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
+                map_montage = gr.Image(label="Matched channels", visible=False)
+
         with gr.Column():
             gr.Markdown(
                 """
@@ -146,29 +179,36 @@ with gr.Blocks() as demo:
                 """
             )
             with gr.Row():
+                in_model_name = gr.Dropdown(choices=[
+                    ("ART", "EEGART"),
+                    ("IC-U-Net", "ICUNet"),
+                    ("IC-U-Net++", "UNetpp"),
+                    ("IC-U-Net-Attn", "AttUnet"),
+                    "(mapped data)",
+                    "(denoised data)"],
+                    value="EEGART",
                     label="Model",
                     scale=2)
                 run_btn = gr.Button(scale=1, interactive=False)
+            batch_md = gr.Markdown(visible=False)
             out_denoised_data = gr.File(label="Denoised data")
 
 
     with gr.Row():
+        with gr.Tab("ART"):
             gr.Markdown()
         with gr.Tab("IC-U-Net"):
             gr.Markdown(icunet)
         with gr.Tab("IC-U-Net++"):
             gr.Markdown()
+        with gr.Tab("IC-U-Net-Attn"):
             gr.Markdown()
         with gr.Tab("QuickStart"):
             gr.Markdown(quickstart)
 
     #demo.load(js=js)
 
+    def reset_layout(raw_data, samplerate):
         # establish temp folder
         filepath = os.path.dirname(str(raw_data))
         try:
@@ -177,51 +217,53 @@ with gr.Blocks() as demo:
             utils.dataDelete(filepath+"/temp_data/")
             os.mkdir(filepath+"/temp_data/")
             #print(e)
+
+        data = utils.read_train_data(raw_data)
+        state = {
             "filepath": filepath+"/temp_data/",
+            "files": {},
+            "sampleRate": int(samplerate),
+            "dataShape" : data.shape
         }
+        return {state_json : state,
+                chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
                 next_btn : gr.Button("Next", interactive=False, visible=False),
                 run_btn : gr.Button(interactive=False),
                 tpl_montage : gr.Image(visible=False),
                 map_montage : gr.Image(value=None, visible=False),
                 res_md : gr.Markdown(visible=False),
+                batch_md : gr.Markdown(visible=False)}
 
+    def mapping_result(state, fill_mode):
+
+        in_num = len(state["inputByName"])
+        matched_num = 30 - len(state["missingChannelsIndex"])
+        batch_num = math.ceil((in_num-matched_num)/30) + 1
+        state.update({
+            "runnigState" : "stage1",
+            "batchCount" : 1,
+            "totalBatchNum" : batch_num
+        })
 
+        if fill_mode=="mean_manual" and state["missingChannelsIndex"]!=[]:
+            state.update({
                 "state" : "initializing",
                 "fillingCount" : 0,
+                "totalFillingNum" : len(state["missingChannelsIndex"])-1
             })
+            #print("Missing channels:", state["missingChannelsIndex"])
+            return {state_json : state,
+                    #chkbox_group : gr.CheckboxGroup(visible=True),
                     next_btn : gr.Button(visible=True)}
         else:
+            state["state"] = "finished"
+
+            return {state_json : state,
                     res_md : gr.Markdown(visible=True),
                     run_btn : gr.Button(interactive=True)}
 
+    def show_montage(state, raw_loc):
+        filepath = state["filepath"]
         raw_montage = read_custom_montage(raw_loc)
 
         # convert all channel names to uppercase
@@ -229,73 +271,63 @@ with gr.Blocks() as demo:
             channel = raw_montage.ch_names[i]
             raw_montage.rename_channels({channel: str.upper(channel)})
 
+        if state["state"] == "initializing":
             filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
+            state["files"]["raw_montage"] = filename
             raw_fig = raw_montage.plot()
             raw_fig.set_size_inches(5.6, 5.6)
             raw_fig.savefig(filename, pad_inches=0)
 
+            return {state_json : state}
 
+        elif state["state"] == "finished":
             filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
+            state["files"]["map_montage"] = filename
 
             show_names= []
+            for channel in state["inputByName"]:
+                if state["inputByName"][channel]["matched"]:
                     show_names.append(channel)
             mapped_fig = raw_montage.plot(show_names=show_names)
             mapped_fig.set_size_inches(5.6, 5.6)
             mapped_fig.savefig(filename, pad_inches=0)
 
+            return {state_json : state,
                     tpl_montage : gr.Image(visible=True),
                     map_montage : gr.Image(value=filename, visible=True)}
 
+        else:
+            return {state_json : state} # change nothing
 
+    def generate_chkbox(state):
+        if state["state"] == "initializing":
+            in_channels = [channel for channel in state["inputByName"]]
+            state["state"] = "selecting"
 
+            first_idx = state["missingChannelsIndex"][0]
+            first_name = state["templateByIndex"][first_idx]
+            chkbox_label = first_name+' (1/'+str(state["totalFillingNum"]+1)+')'
+            return {state_json : state,
+                    chkbox_group : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
                     next_btn : gr.Button(interactive=True)}
         else:
+            return {state_json : state} # change nothing
 
 
     map_btn.click(
         fn = reset_layout,
+        inputs = [in_raw_data, in_sample_rate],
+        outputs = [state_json, chkbox_group, next_btn, run_btn, tpl_montage, map_montage, res_md, batch_md]
+
     ).success(
+        fn = mapping_stage1,
+        inputs = [state_json, in_raw_data, in_raw_loc, in_fill_mode],
+        outputs = state_json
 
     ).success(
         fn = mapping_result,
+        inputs = [state_json, in_fill_mode],
+        outputs = [state_json, next_btn, res_md, run_btn]
 
     ).success(
         fn = show_montage,
@@ -305,7 +337,8 @@ with gr.Blocks() as demo:
     ).success(
         fn = generate_chkbox,
         inputs = state_json,
+        outputs = [state_json, chkbox_group, next_btn]
+
     ).success(
         fn = None,
         js = chkbox_js,
@@ -314,81 +347,85 @@ with gr.Blocks() as demo:
     )
 
 
+    def check_next(state, selected, raw_data, fill_mode):
+        #if state["state"] == "selecting":
 
+        # save info before clicking on next_btn
+        prev_target_idx = state["missingChannelsIndex"][state["fillingCount"]]
+        prev_target_name = state["templateByIndex"][prev_target_idx]
+
+        selected_idx = [state["inputByName"][channel]["index"] for channel in selected]
+        state["newOrder"][prev_target_idx] = selected_idx
+
+        #if len(selected)==1 and state["inputByName"][selected[0]]["used"]==False:
+            #state["inputByName"][selected[0]]["used"] = True
+            #state["missingChannelsIndex"][state["fillingCount"]] = -1
+
+        print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
+
+        # update the next round
+        state["fillingCount"] += 1
+        if state["fillingCount"] <= state["totalFillingNum"]:
+            target_idx = state["missingChannelsIndex"][state["fillingCount"]]
+            target_name = state["templateByIndex"][target_idx]
+            chkbox_label = target_name+' ('+str(state["fillingCount"]+1)+'/'+str(state["totalFillingNum"]+1)+')'
+            btn_label = "Submit" if state["fillingCount"]==state["totalFillingNum"] else "Next"
 
+            return {state_json : state,
+                    chkbox_group : gr.CheckboxGroup(value=[], label=chkbox_label),
+                    next_btn : gr.Button(btn_label)}
+        else:
+            state["state"] = "finished"
+            return {state_json : state,
+                    chkbox_group : gr.CheckboxGroup(visible=False),
+                    next_btn : gr.Button(visible=False),
+                    res_md : gr.Markdown(visible=True),
+                    run_btn : gr.Button(interactive=True)}
 
     next_btn.click(
+        fn = check_next,
+        inputs = [state_json, chkbox_group, in_raw_data, in_fill_mode],
+        outputs = [state_json, chkbox_group, next_btn, run_btn, res_md]
+
    ).success(
         fn = show_montage,
         inputs = [state_json, in_raw_loc],
         outputs = [state_json, tpl_montage, map_montage]
     )
+
+    @run_btn.click(inputs = [state_json, in_raw_data, in_model_name, in_fill_mode], outputs = out_denoised_data)
+    def run_model(state, raw_data, model_name, fill_mode):
+        filepath = state["filepath"]
+        samplerate = state["sampleRate"]
+
+        #if batch > total_batch:
+            #return {batch_md : gr.Markdown("error", visible=True)}
+
+        input_name = os.path.basename(str(raw_data))
         output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
 
+        while(state["runnigState"] != "finished"):
+            if state["batchCount"] > state["totalBatchNum"]:
+                break
+            if state["batchCount"] > 1:
+                state["runnigState"] = "stage2"
+                state = mapping_stage2(state, fill_mode)
+            state["batchCount"] += 1
 
+            reorder_to_template(state, raw_data)
+            # step1: Data preprocessing
+            total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
+            # step2: Signal reconstruction
+            utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
+            reorder_to_origin(state, filepath+'denoised.csv', filepath+output_name)
 
+        if model_name == "(mapped data)":
+            return {out_denoised_data : filepath + 'mapped.csv'}
+        elif model_name == "(denoised data)":
+            return {out_denoised_data : filepath + 'denoised.csv'}
 
+        return {out_denoised_data : filepath + output_name}
 
+
 if __name__ == "__main__":
     demo.launch()
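The `mean (manual)` imputation described in the Quickstart above boils down to averaging the rows of the channels the user selects (and writing zeros when nothing is selected). A minimal sketch of that behaviour, with made-up data and channel indices:

```python
import numpy as np

# hypothetical (channels, timepoints) array; indices below are made up
data = np.random.randn(4, 1000)       # rows: FC1, FC2, FZ, CZ
selected_rows = [0, 1, 2, 3]           # channels picked to impute the missing FCZ

fcz = np.mean(data[selected_rows, :], axis=0)   # imputed FCZ = mean of the selected channels
zero_fill = np.zeros(data.shape[1])             # "zero" mode / empty selection -> all zeros
```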
channel_mapping.py
CHANGED
@@ -2,15 +2,19 @@ import utils
 import time
 import os
 import numpy as np
+import gradio as gr
 
 import mne
 from mne.channels import read_custom_montage
+from scipy.interpolate import Rbf
+from scipy.optimize import linear_sum_assignment
+from sklearn.neighbors import NearestNeighbors
+
+def reorder_to_template(state, filename):
+    old_idx = state["newOrder"]
+    old_data = utils.read_train_data(filename)   # original raw data
+    new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
+    new_filename = state["filepath"]+'mapped.csv'
 
     zero_arr = np.zeros((1, old_data.shape[1]))
     old_data = np.concatenate((old_data, zero_arr), axis=0)
@@ -24,286 +28,278 @@ def reorder_data(filename, old_idx, fill_mode, state_obj):
         else:
             tmp_data = [old_data[j, :] for j in curr_idx_set]
             new_data[i, :] = np.mean(tmp_data, axis=0)
+
+    print('old.shape, new.shape: ', old_data.shape, new_data.shape)
     utils.save_data(new_data, new_filename)
     return
 
+def reorder_to_origin(state, filename, new_filename):
+    old_idx = state["newOrder"]
+    old_data = utils.read_train_data(filename) # denoised data
+    template_order = state["templateByIndex"]
+
+    if state["runnigState"] == "stage1":
+        new_data = np.zeros((len(state["inputByName"]), old_data.shape[1]))
+    else:
+        new_data = utils.read_train_data(new_filename)
+
+    for i, channel in enumerate(template_order):
+        idx_set = old_idx[i]
+
+        # ignore if this channel doesn't exist
+        if len(idx_set)==1 and state["templateByName"][channel]["matched"]==True:
+            new_data[idx_set[0], :] = old_data[i, :]
+
+    print('old.shape, new.shape: ', old_data.shape, new_data.shape)
+    utils.save_data(new_data, new_filename)
+    return
 
 class Channel:
 
+    def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
         self.name = name
         self.index = index
+        self.matched = matched
+        self.assigned = assigned # for input channels
         self.coord = coord
         self.css_position = css_position
 
 
+def read_montage_data(loc_file):
+
     template_montage = read_custom_montage("./template_chanlocs.loc")
     input_montage = read_custom_montage(loc_file)
+    template_dict = {}
+    input_dict = {}
 
     montages = [template_montage, input_montage]
     dicts = [template_dict, input_dict]
     num = [30, len(input_montage.ch_names)]
 
     for i in range(2):
         for j in range(num[i]):
             channel = montages[i].ch_names[j]
+            montages[i].rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
             channel = str.upper(channel)
+            dicts[i][channel] = Channel(index=j, name=channel, coord=montages[i].get_positions()['ch_pos'][channel])
+
+    return template_montage, input_montage, template_dict, input_dict
+
+def align_coords(state, template_montage, input_montage):
+
+    template_dict = state["templateByName"]
+    input_dict = state["inputByName"]
+    template_order = state["templateByIndex"]
+    input_order = state["inputByIndex"]
+    matched = [channel for channel in input_dict if input_dict[channel]["matched"]==True]
 
+    # 2-d (for the indication of the missing template channel's position when fill_mode is 'mean_manual')
+    fig = [template_montage.plot(), input_montage.plot()]
+    fig[0].set_size_inches(5.6, 5.6)
+    fig[1].set_size_inches(5.6, 5.6)
 
+    ax = [fig[0].axes[0], fig[1].axes[0]]
+    ax[0].set_aspect('equal')
+    ax[1].set_aspect('equal')
+    ax[0].figure.canvas.draw() # update the figure
+    ax[1].figure.canvas.draw()
+
+    # get the original coords
+    all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # display coords (px)
+    all_in = ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
+    matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
+    matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
+
+    # transform the xy axis (template's -> input's)
+    rbf_x = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,0], function='thin_plate')
+    rbf_y = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,1], function='thin_plate')
+
+    # apply to all template channels
+    transformed_tpl_x = rbf_x(all_tpl[:,0], all_tpl[:,1])
+    transformed_tpl_y = rbf_y(all_tpl[:,0], all_tpl[:,1])
+    #transformed_tpl = np.vstack((transformed_tpl_x, transformed_tpl_y)).T
+
+    # update input, template's position
+    for i, channel in enumerate(template_order):
+        css_left = (transformed_tpl_x[i]-11)/560
+        css_bottom = (transformed_tpl_y[i]-7)/560
+        template_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
+    for i, channel in enumerate(input_order):
+        css_left = (all_in[i][0]-11)/560
+        css_bottom = (all_in[i][1]-7)/560
+        input_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
+
+
+    # 3-d (to use KNN)
+    # get the original coords
+    all_tpl = np.array([template_dict[channel]["coord"].tolist() for channel in template_order])
+    all_in = np.array([input_dict[channel]["coord"].tolist() for channel in input_order])
+    matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
+    matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
 
+    # transform the xyz axis (input's -> template's)
+    rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
+    rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
+    rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
 
+    # apply to all input channels
+    transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
+    transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
+    transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
+    transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
+
+    # update input's position
+    for i, channel in enumerate(input_order):
+        input_dict[channel]["coord"] = transformed_in[i].tolist()
+
+    state.update({
+        "templateByName" : template_dict,
+        "inputByName" : input_dict,
+    })
+
+    return state
+
+def fill_channels(state, fill_mode):
+
+    new_idx = state["newOrder"]
+    template_dict = state["templateByName"]
+    input_dict = state["inputByName"]
+    template_order = state["templateByIndex"]
+    input_order = state["inputByIndex"]
+    z_row_idx = state["dataShape"][0]
+    unmatched = [channel for channel in template_dict if template_dict[channel]["matched"]==False]
+    if unmatched == []:
+        return state
+
+    if fill_mode == 'zero':
+        for channel in unmatched:
+            idx = template_dict[channel]["index"]
+            new_idx[idx] = [z_row_idx]
+
+    elif fill_mode == 'mean_auto':
+        # use KNN to choose the k nearest channels
+        in_coords = [input_dict[channel]["coord"] for channel in input_order]
+        in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
+
+        k = 4 if len(input_dict)>4 else len(input_dict)
+        knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
+        knn.fit(in_coords)
+
+        for channel in unmatched:
+            distances, indices = knn.kneighbors(template_dict[channel]["coord"].reshape(1,-1))
+            selected = [input_order[i] for i in indices[0]]
+            print(channel, ':', selected)
+
+            idx = template_dict[channel]["index"]
+            new_idx[idx] = indices[0].tolist()
+
+    state["newOrder"] = new_idx
+    return state
+
+def mapping_stage1(state, data_file, loc_file, fill_mode):
+    second1 = time.time()
 
+    template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
+    template_order = template_montage.ch_names
+    new_idx = [[]]*30
+    missing_channels = []
     alias = {
         'T3': 'T7',
         'T4': 'T8',
         'T5': 'P7',
         'T6': 'P8',
+        #'TP7': 'T5\'',
+        #'TP8': 'T6\'',
     }
 
+    # match the names of input channels -> template channels
+    for i, channel in enumerate(template_order):
+        if channel in alias and alias[channel] in input_dict:
+            template_montage.rename_channels({channel: alias[channel]})
+            template_dict[alias[channel]] = template_dict.pop(channel)
+            template_dict[alias[channel]].name = alias[channel]
+            channel = alias[channel]
+
+        if channel in input_dict:
+            new_idx[i] = [input_dict[channel].index]
+
+            template_dict[channel].matched = True
+            input_dict[channel].matched = True
+            input_dict[channel].assigned = True
+        else:
+            missing_channels.append(i)
+
+    state.update({
+        "newOrder" : new_idx,
+        "missingChannelsIndex" : missing_channels,
+        "templateByName" : {k : v.__dict__ for k,v in template_dict.items()},
+        "inputByName" : {k : v.__dict__ for k,v in input_dict.items()},
+        "templateByIndex" : template_montage.ch_names,
+        "inputByIndex" : input_montage.ch_names
+    })
+
+    # align input, template's coordinates
+    state = align_coords(state, template_montage, input_montage)
+    # fill the unmatched channels
+    state = fill_channels(state, fill_mode)
+
+    second2 = time.time()
+    print('Mapping (stage1) finished in', second2 - second1, 's.')
+    return state
 
+def mapping_stage2(state, fill_mode):
+    second1 = time.time()
 
+    template_dict = state["templateByName"]
+    input_dict = state["inputByName"]
+    template_order = state["templateByIndex"]
+    unassigned = [channel for channel in input_dict if input_dict[channel]["assigned"]==False]
+    if unassigned == []:
+        state["runnigState"] = "finished"
+        return state
 
+    tpl_coords = np.array([template_dict[channel]["coord"] for channel in template_order])
+    unassigned_coords = np.array([input_dict[channel]["coord"] for channel in unassigned])
 
+    # set all tpl.matched to False
+    for channel in template_dict:
+        template_dict[channel]["matched"] = False
 
+    # initialize the cost matrix
+    if len(unassigned) < 30:
+        cost_matrix = np.full((30, 30), 10000) # add dummy channels to ensure num_col > num_row
+    else:
+        cost_matrix = np.zeros((30, len(unassigned)))
+    for i in range(30):
+        for j in range(len(unassigned)):
+            cost_matrix[i][j] = np.linalg.norm(tpl_coords[i] - unassigned_coords[j]) # Euclidean distance
 
+    # use the Hungarian algorithm to find the assignment that minimizes the total distance between input and template coords
+    row_idx, col_idx = linear_sum_assignment(cost_matrix)
 
+    matches = []
+    new_idx = [[]]*30
     for i in range(30):
+        if col_idx[i] < len(unassigned): # filter out dummy channels
+            matches.append([row_idx[i], col_idx[i]])
 
+            tpl_channel = template_order[row_idx[i]]
+            in_channel = unassigned[col_idx[i]]
+            template_dict[tpl_channel]["matched"] = True
+            input_dict[in_channel]["assigned"] = True
+            new_idx[i] = [input_dict[in_channel]["index"]]
 
+    state.update({
+        "newOrder" : new_idx,
+        "templateByName" : template_dict,
+        "inputByName" : input_dict
+    })
 
+    # fill the unmatched channels
+    state = fill_channels(state, fill_mode)
 
+    second2 = time.time()
+    print(f'Mapping (stage2-{state["batchCount"]-1}) finished in {second2 - second1}s.')
+    return state
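`mapping_stage2` above pairs the still-unassigned input channels with the 30 template positions by minimizing the total Euclidean distance via `scipy.optimize.linear_sum_assignment`. A standalone toy run of the same call, with invented coordinates and one dummy column playing the role of the padding used above:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

tpl = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])   # 3 template positions (made up)
inp = np.array([[0.1, 0.1], [0.9, 0.1]])                # 2 leftover input channels (made up)

cost = np.full((3, 3), 10000.0)                          # dummy column keeps the matrix square
cost[:, :2] = np.linalg.norm(tpl[:, None, :] - inp[None, :, :], axis=2)

row_idx, col_idx = linear_sum_assignment(cost)           # Hungarian algorithm: minimum total distance
pairs = [(r, c) for r, c in zip(row_idx, col_idx) if c < len(inp)]
print(pairs)                                             # [(0, 0), (1, 1)]: template 0 <- input 0, template 1 <- input 1
```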
utils.py
CHANGED
@@ -42,6 +42,34 @@ def resample(signal, fs):
 
     return signal_new
 
+# original -> 256Hz or 256Hz -> original
+def resample_(signal, current_fs, target_fs):
+    fs = current_fs
+    # downsample the signal to the target sample rate
+    if fs>target_fs:
+        fs_down = target_fs # Desired sample rate
+        q = int(fs / fs_down) # Downsampling factor
+        signal_new = []
+        for ch in signal:
+            x_down = decimate(ch, q)
+            signal_new.append(x_down)
+
+    # upsample the signal to the target sample rate
+    elif fs<target_fs:
+        fs_up = target_fs # Desired sample rate
+        p = int(fs_up / fs) # Upsampling factor
+        signal_new = []
+        for ch in signal:
+            x_up = resample_poly(ch, p, 1)
+            signal_new.append(x_up)
+
+    else:
+        signal_new = signal
+
+    signal_new = np.array(signal_new).astype(np.float64)
+
+    return signal_new
+
 def FIR_filter(signal, lowcut, highcut):
     fs = 256.0
     # Number of FIR filter taps
@@ -96,11 +124,14 @@ def glue_data(file_name, total, output):
         raw_data[:, 1] = smooth
         gluedata = np.append(gluedata, raw_data, axis=1)
     #print(gluedata.shape)
+    '''
     filename2 = output
     with open(filename2, 'w', newline='') as csvfile:
         writer = csv.writer(csvfile)
         writer.writerows(gluedata)
     #print("GLUE DONE!" + filename2)
+    '''
+    return gluedata
 
 
 def save_data(data, filename):
@@ -189,10 +220,11 @@ def preprocessing(filepath, filename, samplerate):
         print(e)
 
     # read data
+    signal = read_train_data(filepath+filename)
     #print(signal.shape)
     # resample
     signal = resample(signal, samplerate)
+    #signal = resample_(signal, samplerate, 256)
     #print(signal.shape)
     # FIR_filter
     signal = FIR_filter(signal, 1, 50)
@@ -204,7 +236,7 @@ def preprocessing(filepath, filename, samplerate):
 
 
 # model = tf.keras.models.load_model('./denoise_model/')
+def reconstruct(model_name, total, filepath, outputfile, samplerate):
     # -------------------decode_data---------------------------
     second1 = time.time()
     for i in range(total):
@@ -224,9 +256,14 @@ def reconstruct(model_name, total, filepath, outputfile):
         save_data(d_data, outputname)
 
     # --------------------glue_data----------------------------
+    signal = glue_data(filepath+"/temp2/", total, filepath+outputfile)
+    #print(signal.shape)
     # -------------------delete_data---------------------------
     dataDelete(filepath+"/temp2/")
+    # --------------------resample-----------------------------
+    signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
+    #print(signal.shape)
+    save_data(signal, filepath+outputfile)
     second2 = time.time()
 
     print("Using", model_name, "model to reconstruct", outputfile, "has been a success in", second2 - second1, "sec(s)")
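The new `resample_` helper moves between the original sampling rate and the 256 Hz rate the models work at using `scipy.signal.decimate` / `resample_poly`. A minimal round-trip sketch under the assumption of an integer rate ratio (the 512 Hz rate and the single sine channel are made up):

```python
import numpy as np
from scipy.signal import decimate, resample_poly

fs_orig, fs_model = 512, 256
t = np.arange(0, 2, 1/fs_orig)
eeg = np.sin(2*np.pi*10*t)[None, :]      # one fake channel: a 10 Hz sine, 2 s long

down = np.array([decimate(ch, fs_orig // fs_model) for ch in eeg])             # 512 Hz -> 256 Hz
back = np.array([resample_poly(ch, fs_orig // fs_model, 1) for ch in down])    # 256 Hz -> 512 Hz
print(eeg.shape, down.shape, back.shape)  # (1, 1024) (1, 512) (1, 1024)
```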