Spaces:
Sleeping
Sleeping
Commit
Β·
6af0f90
1
Parent(s):
5bfc499
update
Browse files- app.py +27 -31
- app_utils.py +11 -62
- model/{EEGART β ART}/modelsave/checkpoint.pth.tar +0 -0
- model/{EEGART β ART}/modelsave/model_trainValLog.txt +0 -0
- model/{UNetpp β ICUNet++}/modelsave/BEST_checkpoint.pth.tar +0 -0
- model/{UNetpp β ICUNet++}/modelsave/checkpoint.pth.tar +0 -0
- model/{UNetpp β ICUNet++}/modelsave/model_trainValLog.txt +0 -0
- model/{AttUnet β ICUNet_attn}/modelsave/BEST_checkpoint.pth.tar +0 -0
- model/{AttUnet β ICUNet_attn}/modelsave/checkpoint.pth.tar +0 -0
- model/{AttUnet β ICUNet_attn}/modelsave/model_trainValLog.txt +0 -0
- model/__pycache__/UNet_attention.cpython-310.pyc +1 -1
- model/__pycache__/tf_data.cpython-310.pyc +1 -1
- model/__pycache__/tf_model.cpython-310.pyc +1 -1
- utils.py +69 -75
app.py
CHANGED
@@ -45,7 +45,7 @@ Once all template channels are filled, you will be directed to **Mapping Result*
|
|
45 |
### Mapping Result
|
46 |
After completing the previous steps, your channels will be aligned with the template channels required by our models.
|
47 |
- In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final result.
|
48 |
-
- Once the mapping process is completed, a JSON file containing the mapping result will be generated. This file is necessary only if you plan to run the models using the <a href="">source code</a>; otherwise, you can ignore it.
|
49 |
|
50 |
## 2. Decode data
|
51 |
After clicking on ``Run`` button, we will process your EEG data based on the mapping result. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
|
@@ -278,11 +278,11 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
278 |
with gr.Column():
|
279 |
in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
|
280 |
in_modelname = gr.Dropdown(choices=[
|
281 |
-
("ART", "
|
282 |
("IC-U-Net", "ICUNet"),
|
283 |
-
("IC-U-Net++", "
|
284 |
-
("IC-U-Net-Attn", "
|
285 |
-
value="
|
286 |
label="Model")
|
287 |
run_btn = gr.Button("Run", interactive=False)
|
288 |
cancel_btn = gr.Button("Cancel", visible=False)
|
@@ -304,7 +304,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
304 |
gr.Markdown()
|
305 |
|
306 |
def create_dir(req: gr.Request):
|
307 |
-
os.mkdir(gradio_temp_dir+'/'+req.session_hash)
|
308 |
return gradio_temp_dir+'/'+req.session_hash+'/'
|
309 |
demo.load(create_dir, inputs=[], outputs=session_dir)
|
310 |
|
@@ -326,8 +326,8 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
326 |
stage1_dir = uuid.uuid4().hex + '_stage1/'
|
327 |
os.mkdir(rootpath + stage1_dir)
|
328 |
|
329 |
-
|
330 |
-
outputname =
|
331 |
|
332 |
stage1_info = {
|
333 |
"filePath" : rootpath + stage1_dir,
|
@@ -349,7 +349,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
349 |
},
|
350 |
"unassignedInput" : None,
|
351 |
"emptyTemplate" : None,
|
352 |
-
"
|
353 |
"mappingResult" : [
|
354 |
{
|
355 |
"index" : None,
|
@@ -436,7 +436,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
436 |
md = """
|
437 |
### Mapping Result
|
438 |
The mapping process has been finished.
|
439 |
-
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
440 |
"""
|
441 |
# finalize and save the mapping result
|
442 |
outputname = stage1_info["fileNames"]["outputData"]
|
@@ -523,7 +523,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
523 |
md = """
|
524 |
### Mapping Result
|
525 |
The mapping process has been finished.
|
526 |
-
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
527 |
"""
|
528 |
outputname = stage1_info["fileNames"]["outputData"]
|
529 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
@@ -560,7 +560,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
560 |
md = """
|
561 |
### Mapping Result
|
562 |
The mapping process has been finished.
|
563 |
-
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
564 |
"""
|
565 |
outputname = stage1_info["fileNames"]["outputData"]
|
566 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
@@ -628,7 +628,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
628 |
md = """
|
629 |
### Mapping Result
|
630 |
The mapping process has been finished.
|
631 |
-
Download the file below if you plan to run the models using the <a href="">source code</a>.
|
632 |
"""
|
633 |
outputname = stage1_info["fileNames"]["outputData"]
|
634 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
@@ -814,8 +814,8 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
814 |
stage2_dir = uuid.uuid4().hex + '_stage2/'
|
815 |
os.mkdir(rootpath + stage2_dir)
|
816 |
|
817 |
-
|
818 |
-
outputname = modelname + '_'+
|
819 |
|
820 |
stage2_info = {
|
821 |
"filePath" : rootpath + stage2_dir,
|
@@ -832,31 +832,27 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
832 |
batch_md : gr.Markdown("", visible=True),
|
833 |
out_data_file : gr.File(value=None, visible=False)}
|
834 |
|
835 |
-
def
|
836 |
if stage2_info["errorFlag"] == True:
|
837 |
stage2_info["errorFlag"] = False
|
838 |
yield {stage2_json : stage2_info}
|
839 |
|
840 |
else:
|
|
|
841 |
inputname = stage2_info["fileNames"]["inputData"]
|
842 |
outputname = stage2_info["fileNames"]["outputData"]
|
843 |
-
|
844 |
-
basename = os.path.splitext(basename)[0]
|
845 |
break_flag = False
|
846 |
|
847 |
-
for i in range(stage1_info["
|
848 |
-
yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["
|
849 |
try:
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
samplerate = int(samplerate),
|
857 |
-
batch_cnt = i,
|
858 |
-
old_idx = stage1_info["mappingResult"][i]["index"],
|
859 |
-
orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
|
860 |
except FileNotFoundError:
|
861 |
print('stop!!')
|
862 |
break_flag = True
|
@@ -877,7 +873,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
|
|
877 |
inputs = [session_dir, stage2_json, in_data_file, in_samplerate, in_modelname],
|
878 |
outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
|
879 |
).success(
|
880 |
-
fn =
|
881 |
inputs = [stage1_json, stage2_json, in_samplerate, in_modelname],
|
882 |
outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
|
883 |
)
|
|
|
45 |
### Mapping Result
|
46 |
After completing the previous steps, your channels will be aligned with the template channels required by our models.
|
47 |
- In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final result.
|
48 |
+
- Once the mapping process is completed, a JSON file containing the mapping result will be generated. This file is necessary only if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>; otherwise, you can ignore it.
|
49 |
|
50 |
## 2. Decode data
|
51 |
After clicking on ``Run`` button, we will process your EEG data based on the mapping result. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
|
|
|
278 |
with gr.Column():
|
279 |
in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
|
280 |
in_modelname = gr.Dropdown(choices=[
|
281 |
+
("ART", "ART"),
|
282 |
("IC-U-Net", "ICUNet"),
|
283 |
+
("IC-U-Net++", "ICUNet++"),
|
284 |
+
("IC-U-Net-Attn", "ICUnet_attn")],
|
285 |
+
value="ART",
|
286 |
label="Model")
|
287 |
run_btn = gr.Button("Run", interactive=False)
|
288 |
cancel_btn = gr.Button("Cancel", visible=False)
|
|
|
304 |
gr.Markdown()
|
305 |
|
306 |
def create_dir(req: gr.Request):
|
307 |
+
os.mkdir(gradio_temp_dir+'/'+req.session_hash+'/')
|
308 |
return gradio_temp_dir+'/'+req.session_hash+'/'
|
309 |
demo.load(create_dir, inputs=[], outputs=session_dir)
|
310 |
|
|
|
326 |
stage1_dir = uuid.uuid4().hex + '_stage1/'
|
327 |
os.mkdir(rootpath + stage1_dir)
|
328 |
|
329 |
+
inputname = os.path.basename(str(in_loc))
|
330 |
+
outputname = inputname[:-4] + '_mapping_result.json'
|
331 |
|
332 |
stage1_info = {
|
333 |
"filePath" : rootpath + stage1_dir,
|
|
|
349 |
},
|
350 |
"unassignedInput" : None,
|
351 |
"emptyTemplate" : None,
|
352 |
+
"batch" : None,
|
353 |
"mappingResult" : [
|
354 |
{
|
355 |
"index" : None,
|
|
|
436 |
md = """
|
437 |
### Mapping Result
|
438 |
The mapping process has been finished.
|
439 |
+
Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
|
440 |
"""
|
441 |
# finalize and save the mapping result
|
442 |
outputname = stage1_info["fileNames"]["outputData"]
|
|
|
523 |
md = """
|
524 |
### Mapping Result
|
525 |
The mapping process has been finished.
|
526 |
+
Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
|
527 |
"""
|
528 |
outputname = stage1_info["fileNames"]["outputData"]
|
529 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
|
|
560 |
md = """
|
561 |
### Mapping Result
|
562 |
The mapping process has been finished.
|
563 |
+
Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
|
564 |
"""
|
565 |
outputname = stage1_info["fileNames"]["outputData"]
|
566 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
|
|
628 |
md = """
|
629 |
### Mapping Result
|
630 |
The mapping process has been finished.
|
631 |
+
Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
|
632 |
"""
|
633 |
outputname = stage1_info["fileNames"]["outputData"]
|
634 |
stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
|
|
|
814 |
stage2_dir = uuid.uuid4().hex + '_stage2/'
|
815 |
os.mkdir(rootpath + stage2_dir)
|
816 |
|
817 |
+
inputname = os.path.basename(str(in_data))
|
818 |
+
outputname = modelname + '_'+inputname[:-4] + '.csv'
|
819 |
|
820 |
stage2_info = {
|
821 |
"filePath" : rootpath + stage2_dir,
|
|
|
832 |
batch_md : gr.Markdown("", visible=True),
|
833 |
out_data_file : gr.File(value=None, visible=False)}
|
834 |
|
835 |
+
def run_model(stage1_info, stage2_info, samplerate, modelname):
|
836 |
if stage2_info["errorFlag"] == True:
|
837 |
stage2_info["errorFlag"] = False
|
838 |
yield {stage2_json : stage2_info}
|
839 |
|
840 |
else:
|
841 |
+
filepath = stage2_info["filePath"]
|
842 |
inputname = stage2_info["fileNames"]["inputData"]
|
843 |
outputname = stage2_info["fileNames"]["outputData"]
|
844 |
+
mapping_result = stage1_info["mappingResult"]
|
|
|
845 |
break_flag = False
|
846 |
|
847 |
+
for i in range(stage1_info["batch"]):
|
848 |
+
yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["batch"]))}
|
849 |
try:
|
850 |
+
# step1: Data preprocessing
|
851 |
+
preprocess_data, channel_num = utils.preprocessing(filepath, inputname, int(samplerate), mapping_result[i])
|
852 |
+
# step2: Signal reconstruction
|
853 |
+
reconstructed_data = utils.reconstruct(modelname, preprocess_data, filepath, i)
|
854 |
+
# step3: Data postprocessing
|
855 |
+
utils.postprocessing(reconstructed_data, int(samplerate), outputname, mapping_result[i], i, channel_num)
|
|
|
|
|
|
|
|
|
856 |
except FileNotFoundError:
|
857 |
print('stop!!')
|
858 |
break_flag = True
|
|
|
873 |
inputs = [session_dir, stage2_json, in_data_file, in_samplerate, in_modelname],
|
874 |
outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
|
875 |
).success(
|
876 |
+
fn = run_model,
|
877 |
inputs = [stage1_json, stage2_json, in_samplerate, in_modelname],
|
878 |
outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
|
879 |
)
|
app_utils.py
CHANGED
@@ -54,7 +54,7 @@ def match_name(stage1_info):
|
|
54 |
tpl_names = tpl_montage.ch_names
|
55 |
in_names = in_montage.ch_names
|
56 |
old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
57 |
-
|
58 |
|
59 |
alias_dict = {
|
60 |
'T3': 'T7',
|
@@ -70,7 +70,7 @@ def match_name(stage1_info):
|
|
70 |
|
71 |
if name in in_dict:
|
72 |
old_idx[i] = [in_dict[name]["index"]]
|
73 |
-
|
74 |
tpl_dict[name]["matched"] = True
|
75 |
in_dict[name]["assigned"] = True
|
76 |
|
@@ -83,7 +83,7 @@ def match_name(stage1_info):
|
|
83 |
"mappingResult" : [
|
84 |
{
|
85 |
"index" : old_idx,
|
86 |
-
"isOriginalData" :
|
87 |
}
|
88 |
]
|
89 |
})
|
@@ -270,14 +270,14 @@ def optimal_mapping(channel_info):
|
|
270 |
|
271 |
# store the mapping result
|
272 |
old_idx = [[None]]*30
|
273 |
-
|
274 |
for i, j in zip(row_idx, col_idx):
|
275 |
if j < len(unass_in_names): # filter out dummy channels
|
276 |
tpl_name = tpl_names[i]
|
277 |
in_name = unass_in_names[j]
|
278 |
|
279 |
old_idx[i] = [in_dict[in_name]["index"]]
|
280 |
-
|
281 |
tpl_dict[tpl_name]["matched"] = True
|
282 |
in_dict[in_name]["assigned"] = True
|
283 |
|
@@ -288,23 +288,23 @@ def optimal_mapping(channel_info):
|
|
288 |
|
289 |
result = {
|
290 |
"index" : old_idx,
|
291 |
-
"isOriginalData" :
|
292 |
}
|
293 |
channel_info["inputDict"] = in_dict
|
294 |
return result, channel_info
|
295 |
|
296 |
def mapping_result(stage1_info, channel_info, filename):
|
297 |
unassigned_num = len(stage1_info["unassignedInput"])
|
298 |
-
|
299 |
|
300 |
# map the remaining in_channels
|
301 |
results = stage1_info["mappingResult"]
|
302 |
-
for i in range(1,
|
303 |
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
304 |
result, channel_info = optimal_mapping(channel_info)
|
305 |
results += [result]
|
306 |
'''
|
307 |
-
for i in range(
|
308 |
results[i]["name"] = {}
|
309 |
for j, indices in enumerate(results[i]["index"]):
|
310 |
names = [channel_info["inputNames"][idx] for idx in indices] if indices!=[None] else ["zero"]
|
@@ -313,7 +313,7 @@ def mapping_result(stage1_info, channel_info, filename):
|
|
313 |
data = {
|
314 |
#"templateNames" : channel_info["templateNames"],
|
315 |
#"inputNames" : channel_info["inputNames"],
|
316 |
-
"
|
317 |
"mappingResult" : results
|
318 |
}
|
319 |
options = jsbeautifier.default_options()
|
@@ -323,59 +323,8 @@ def mapping_result(stage1_info, channel_info, filename):
|
|
323 |
jsonfile.write(json_data)
|
324 |
|
325 |
stage1_info.update({
|
326 |
-
"
|
327 |
"mappingResult" : results
|
328 |
})
|
329 |
return stage1_info, channel_info
|
330 |
|
331 |
-
|
332 |
-
def reorder_data(old_idx, orig_flags, inputname, filename):
|
333 |
-
# read the input data
|
334 |
-
raw_data = utils.read_train_data(inputname)
|
335 |
-
#print(raw_data.shape)
|
336 |
-
new_data = np.zeros((30, raw_data.shape[1]))
|
337 |
-
|
338 |
-
zero_arr = np.zeros((1, raw_data.shape[1]))
|
339 |
-
for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
|
340 |
-
if flag == True:
|
341 |
-
new_data[i, :] = raw_data[indices[0], :]
|
342 |
-
elif indices == [None]:
|
343 |
-
new_data[i, :] = zero_arr
|
344 |
-
else:
|
345 |
-
tmp_data = [raw_data[idx, :] for idx in indices]
|
346 |
-
new_data[i, :] = np.mean(tmp_data, axis=0)
|
347 |
-
|
348 |
-
utils.save_data(new_data, filename)
|
349 |
-
return raw_data.shape
|
350 |
-
|
351 |
-
def restore_order(batch_cnt, raw_data_shape, old_idx, orig_flags, filename, outputname):
|
352 |
-
# read the denoised data
|
353 |
-
d_data = utils.read_train_data(filename)
|
354 |
-
if batch_cnt == 0:
|
355 |
-
new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
|
356 |
-
#print(new_data.shape)
|
357 |
-
else:
|
358 |
-
new_data = utils.read_train_data(outputname)
|
359 |
-
|
360 |
-
for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
|
361 |
-
if flag == True:
|
362 |
-
new_data[indices[0], :] = d_data[i, :]
|
363 |
-
|
364 |
-
utils.save_data(new_data, outputname)
|
365 |
-
return
|
366 |
-
|
367 |
-
def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt, old_idx, orig_flags):
|
368 |
-
# establish temp folder
|
369 |
-
os.mkdir(filepath+'temp_data/')
|
370 |
-
|
371 |
-
# step1: Reorder data
|
372 |
-
data_shape = reorder_data(old_idx, orig_flags, inputname, filepath+'temp_data/'+m_filename)
|
373 |
-
# step2: Data preprocessing
|
374 |
-
total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
|
375 |
-
# step3: Signal reconstruction
|
376 |
-
utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
|
377 |
-
# step4: Restore original order
|
378 |
-
restore_order(batch_cnt, data_shape, old_idx, orig_flags, filepath+'temp_data/'+d_filename, outputname)
|
379 |
-
|
380 |
-
utils.dataDelete(filepath+'temp_data/')
|
381 |
-
|
|
|
54 |
tpl_names = tpl_montage.ch_names
|
55 |
in_names = in_montage.ch_names
|
56 |
old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
57 |
+
is_orig_data = [False]*30
|
58 |
|
59 |
alias_dict = {
|
60 |
'T3': 'T7',
|
|
|
70 |
|
71 |
if name in in_dict:
|
72 |
old_idx[i] = [in_dict[name]["index"]]
|
73 |
+
is_orig_data[i] = True
|
74 |
tpl_dict[name]["matched"] = True
|
75 |
in_dict[name]["assigned"] = True
|
76 |
|
|
|
83 |
"mappingResult" : [
|
84 |
{
|
85 |
"index" : old_idx,
|
86 |
+
"isOriginalData" : is_orig_data
|
87 |
}
|
88 |
]
|
89 |
})
|
|
|
270 |
|
271 |
# store the mapping result
|
272 |
old_idx = [[None]]*30
|
273 |
+
is_orig_data = [False]*30
|
274 |
for i, j in zip(row_idx, col_idx):
|
275 |
if j < len(unass_in_names): # filter out dummy channels
|
276 |
tpl_name = tpl_names[i]
|
277 |
in_name = unass_in_names[j]
|
278 |
|
279 |
old_idx[i] = [in_dict[in_name]["index"]]
|
280 |
+
is_orig_data[i] = True
|
281 |
tpl_dict[tpl_name]["matched"] = True
|
282 |
in_dict[in_name]["assigned"] = True
|
283 |
|
|
|
288 |
|
289 |
result = {
|
290 |
"index" : old_idx,
|
291 |
+
"isOriginalData" : is_orig_data
|
292 |
}
|
293 |
channel_info["inputDict"] = in_dict
|
294 |
return result, channel_info
|
295 |
|
296 |
def mapping_result(stage1_info, channel_info, filename):
|
297 |
unassigned_num = len(stage1_info["unassignedInput"])
|
298 |
+
batch = math.ceil(unassigned_num/30) + 1
|
299 |
|
300 |
# map the remaining in_channels
|
301 |
results = stage1_info["mappingResult"]
|
302 |
+
for i in range(1, batch):
|
303 |
# optimally select 30 in_channels to map to the tpl_channels based on proximity
|
304 |
result, channel_info = optimal_mapping(channel_info)
|
305 |
results += [result]
|
306 |
'''
|
307 |
+
for i in range(batch):
|
308 |
results[i]["name"] = {}
|
309 |
for j, indices in enumerate(results[i]["index"]):
|
310 |
names = [channel_info["inputNames"][idx] for idx in indices] if indices!=[None] else ["zero"]
|
|
|
313 |
data = {
|
314 |
#"templateNames" : channel_info["templateNames"],
|
315 |
#"inputNames" : channel_info["inputNames"],
|
316 |
+
"batch" : batch,
|
317 |
"mappingResult" : results
|
318 |
}
|
319 |
options = jsbeautifier.default_options()
|
|
|
323 |
jsonfile.write(json_data)
|
324 |
|
325 |
stage1_info.update({
|
326 |
+
"batch" : batch,
|
327 |
"mappingResult" : results
|
328 |
})
|
329 |
return stage1_info, channel_info
|
330 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model/{EEGART β ART}/modelsave/checkpoint.pth.tar
RENAMED
File without changes
|
model/{EEGART β ART}/modelsave/model_trainValLog.txt
RENAMED
File without changes
|
model/{UNetpp β ICUNet++}/modelsave/BEST_checkpoint.pth.tar
RENAMED
File without changes
|
model/{UNetpp β ICUNet++}/modelsave/checkpoint.pth.tar
RENAMED
File without changes
|
model/{UNetpp β ICUNet++}/modelsave/model_trainValLog.txt
RENAMED
File without changes
|
model/{AttUnet β ICUNet_attn}/modelsave/BEST_checkpoint.pth.tar
RENAMED
File without changes
|
model/{AttUnet β ICUNet_attn}/modelsave/checkpoint.pth.tar
RENAMED
File without changes
|
model/{AttUnet β ICUNet_attn}/modelsave/model_trainValLog.txt
RENAMED
File without changes
|
model/__pycache__/UNet_attention.cpython-310.pyc
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 13043
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c22a35e64a5b9c7ebb09e866206e4e2ebea59bc1c8253ac2795bb0b64c84df16
|
3 |
size 13043
|
model/__pycache__/tf_data.cpython-310.pyc
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 5981
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a57237e3737f55f90c3ee13409cc90034d55ef94610c1fbcdd4768956757341
|
3 |
size 5981
|
model/__pycache__/tf_model.cpython-310.pyc
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 12231
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d12b18b8c5cd4950c70c754e10f70272dd438c980c25a5b712bb1881ece86480
|
3 |
size 12231
|
utils.py
CHANGED
@@ -15,38 +15,12 @@ from scipy.signal import decimate, resample_poly, firwin, lfilter
|
|
15 |
|
16 |
|
17 |
os.environ["CUDA_VISIBLE_DEVICES"]="0"
|
|
|
18 |
|
19 |
-
def resample(signal, fs):
|
20 |
-
# downsample the signal to a sample rate of 256 Hz
|
21 |
-
if fs>256:
|
22 |
-
fs_down = 256 # Desired sample rate
|
23 |
-
q = int(fs / fs_down) # Downsampling factor
|
24 |
-
signal_new = []
|
25 |
-
for ch in signal:
|
26 |
-
x_down = decimate(ch, q)
|
27 |
-
signal_new.append(x_down)
|
28 |
-
|
29 |
-
# upsample the signal to a sample rate of 256 Hz
|
30 |
-
elif fs<256:
|
31 |
-
fs_up = 256 # Desired sample rate
|
32 |
-
p = int(fs_up / fs) # Upsampling factor
|
33 |
-
signal_new = []
|
34 |
-
for ch in signal:
|
35 |
-
x_up = resample_poly(ch, p, 1)
|
36 |
-
signal_new.append(x_up)
|
37 |
-
|
38 |
-
else:
|
39 |
-
signal_new = signal
|
40 |
-
|
41 |
-
signal_new = np.array(signal_new).astype(np.float64)
|
42 |
-
|
43 |
-
return signal_new
|
44 |
-
|
45 |
-
def resample_(signal, current_fs, target_fs):
|
46 |
-
fs = current_fs
|
47 |
# downsample the signal to the target sample rate
|
48 |
-
if fs>
|
49 |
-
fs_down =
|
50 |
q = int(fs / fs_down) # Downsampling factor
|
51 |
signal_new = []
|
52 |
for ch in signal:
|
@@ -54,8 +28,8 @@ def resample_(signal, current_fs, target_fs):
|
|
54 |
signal_new.append(x_down)
|
55 |
|
56 |
# upsample the signal to the target sample rate
|
57 |
-
elif fs<
|
58 |
-
fs_up =
|
59 |
p = int(fs_up / fs) # Upsampling factor
|
60 |
signal_new = []
|
61 |
for ch in signal:
|
@@ -104,7 +78,7 @@ def cut_data(filepath, raw_data):
|
|
104 |
return total
|
105 |
|
106 |
|
107 |
-
def glue_data(file_name, total
|
108 |
gluedata = 0
|
109 |
for i in range(total):
|
110 |
file_name1 = file_name + 'output{}.csv'.format(str(i))
|
@@ -123,13 +97,6 @@ def glue_data(file_name, total, output):
|
|
123 |
raw_data[:, 1] = smooth
|
124 |
gluedata = np.append(gluedata, raw_data, axis=1)
|
125 |
#print(gluedata.shape)
|
126 |
-
'''
|
127 |
-
filename2 = output
|
128 |
-
with open(filename2, 'w', newline='') as csvfile:
|
129 |
-
writer = csv.writer(csvfile)
|
130 |
-
writer.writerows(gluedata)
|
131 |
-
#print("GLUE DONE!" + filename2)
|
132 |
-
'''
|
133 |
return gluedata
|
134 |
|
135 |
|
@@ -142,8 +109,7 @@ def dataDelete(path):
|
|
142 |
try:
|
143 |
shutil.rmtree(path)
|
144 |
except OSError as e:
|
145 |
-
|
146 |
-
#print(e)
|
147 |
else:
|
148 |
pass
|
149 |
#print("The directory is deleted successfully")
|
@@ -153,64 +119,78 @@ def decode_data(data, std_num, mode=5):
|
|
153 |
|
154 |
if mode == "ICUNet":
|
155 |
# 1. read name
|
156 |
-
model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
|
157 |
resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
|
158 |
# 2. load model
|
159 |
-
checkpoint = torch.load(resumeLoc, map_location=
|
160 |
model.load_state_dict(checkpoint['state_dict'], False)
|
161 |
model.eval()
|
162 |
# 3. decode strategy
|
163 |
with torch.no_grad():
|
164 |
data = data[np.newaxis, :, :]
|
165 |
-
data = torch.Tensor(data)
|
166 |
decode = model(data)
|
167 |
|
168 |
|
169 |
-
elif mode == "
|
170 |
# 1. read name
|
171 |
-
if mode == "
|
172 |
-
model = UNet_family.NestedUNet3(num_classes=30)
|
173 |
-
elif mode == "
|
174 |
-
model = UNet_attention.UNetpp3_Transformer(num_classes=30)
|
175 |
-
resumeLoc = './model/'+ mode + '/modelsave' + '/checkpoint.pth.tar'
|
176 |
# 2. load model
|
177 |
-
checkpoint = torch.load(resumeLoc, map_location=
|
178 |
model.load_state_dict(checkpoint['state_dict'], False)
|
179 |
model.eval()
|
180 |
# 3. decode strategy
|
181 |
with torch.no_grad():
|
182 |
data = data[np.newaxis, :, :]
|
183 |
-
data = torch.Tensor(data)
|
184 |
decode1, decode2, decode = model(data)
|
185 |
|
186 |
|
187 |
-
elif mode == "
|
188 |
# 1. read name
|
189 |
resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
|
190 |
# 2. load model
|
191 |
-
checkpoint = torch.load(resumeLoc, map_location=
|
192 |
-
model = tf_model.make_model(30, 30, N=2)
|
193 |
model.load_state_dict(checkpoint['state_dict'])
|
194 |
model.eval()
|
195 |
# 3. decode strategy
|
196 |
with torch.no_grad():
|
197 |
-
data = torch.FloatTensor(data)
|
198 |
data = data.unsqueeze(0)
|
199 |
src = data
|
200 |
-
tgt = data
|
201 |
batch = tf_data.Batch(src, tgt, 0)
|
202 |
out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
|
203 |
decode = model.generator(out)
|
204 |
decode = decode.permute(0, 2, 1)
|
205 |
-
|
206 |
-
|
207 |
|
208 |
# 4. numpy
|
209 |
#print(decode.shape)
|
210 |
decode = np.array(decode.cpu()).astype(np.float64)
|
211 |
return decode
|
212 |
|
213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
# establish temp folder
|
215 |
try:
|
216 |
os.mkdir(filepath+"temp2/")
|
@@ -220,10 +200,13 @@ def preprocessing(filepath, filename, samplerate):
|
|
220 |
print(e)
|
221 |
|
222 |
# read data
|
223 |
-
signal = read_train_data(
|
|
|
|
|
|
|
224 |
#print(signal.shape)
|
225 |
# resample
|
226 |
-
signal = resample(signal, samplerate
|
227 |
#print(signal.shape)
|
228 |
# FIR_filter
|
229 |
signal = FIR_filter(signal, 1, 50)
|
@@ -231,11 +214,27 @@ def preprocessing(filepath, filename, samplerate):
|
|
231 |
# cutting data
|
232 |
total_file_num = cut_data(filepath, signal)
|
233 |
|
234 |
-
return total_file_num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
|
237 |
# model = tf.keras.models.load_model('./denoise_model/')
|
238 |
-
def reconstruct(model_name, total, filepath,
|
239 |
# -------------------decode_data---------------------------
|
240 |
second1 = time.time()
|
241 |
for i in range(total):
|
@@ -255,16 +254,11 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
|
|
255 |
save_data(d_data, outputname)
|
256 |
|
257 |
# --------------------glue_data----------------------------
|
258 |
-
|
259 |
-
#print(signal.shape)
|
260 |
# -------------------delete_data---------------------------
|
261 |
dataDelete(filepath+"temp2/")
|
262 |
-
# --------------------resample-----------------------------
|
263 |
-
signal = resample_(signal, 256, samplerate)
|
264 |
-
#print(signal.shape)
|
265 |
-
# --------------------save_data----------------------------
|
266 |
-
save_data(signal, filepath+outputfile)
|
267 |
second2 = time.time()
|
268 |
-
|
269 |
-
print("Using
|
270 |
-
|
|
|
|
15 |
|
16 |
|
17 |
os.environ["CUDA_VISIBLE_DEVICES"]="0"
|
18 |
+
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
19 |
|
20 |
+
def resample(signal, fs, tgt_fs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# downsample the signal to the target sample rate
|
22 |
+
if fs>tgt_fs:
|
23 |
+
fs_down = tgt_fs # Desired sample rate
|
24 |
q = int(fs / fs_down) # Downsampling factor
|
25 |
signal_new = []
|
26 |
for ch in signal:
|
|
|
28 |
signal_new.append(x_down)
|
29 |
|
30 |
# upsample the signal to the target sample rate
|
31 |
+
elif fs<tgt_fs:
|
32 |
+
fs_up = tgt_fs # Desired sample rate
|
33 |
p = int(fs_up / fs) # Upsampling factor
|
34 |
signal_new = []
|
35 |
for ch in signal:
|
|
|
78 |
return total
|
79 |
|
80 |
|
81 |
+
def glue_data(file_name, total):
|
82 |
gluedata = 0
|
83 |
for i in range(total):
|
84 |
file_name1 = file_name + 'output{}.csv'.format(str(i))
|
|
|
97 |
raw_data[:, 1] = smooth
|
98 |
gluedata = np.append(gluedata, raw_data, axis=1)
|
99 |
#print(gluedata.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
return gluedata
|
101 |
|
102 |
|
|
|
109 |
try:
|
110 |
shutil.rmtree(path)
|
111 |
except OSError as e:
|
112 |
+
print('dataDelete:', e)
|
|
|
113 |
else:
|
114 |
pass
|
115 |
#print("The directory is deleted successfully")
|
|
|
119 |
|
120 |
if mode == "ICUNet":
|
121 |
# 1. read name
|
122 |
+
model = cumbersome_model2.UNet1(n_channels=30, n_classes=30).to(device)
|
123 |
resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
|
124 |
# 2. load model
|
125 |
+
checkpoint = torch.load(resumeLoc, map_location=device)
|
126 |
model.load_state_dict(checkpoint['state_dict'], False)
|
127 |
model.eval()
|
128 |
# 3. decode strategy
|
129 |
with torch.no_grad():
|
130 |
data = data[np.newaxis, :, :]
|
131 |
+
data = torch.Tensor(data).to(device)
|
132 |
decode = model(data)
|
133 |
|
134 |
|
135 |
+
elif mode == "ICUNet++" or mode == "ICUnet_attn":
|
136 |
# 1. read name
|
137 |
+
if mode == "ICUNet++":
|
138 |
+
model = UNet_family.NestedUNet3(num_classes=30).to(device)
|
139 |
+
elif mode == "ICUnet_attn":
|
140 |
+
model = UNet_attention.UNetpp3_Transformer(num_classes=30).to(device)
|
141 |
+
resumeLoc = './model/' + mode + '/modelsave' + '/checkpoint.pth.tar'
|
142 |
# 2. load model
|
143 |
+
checkpoint = torch.load(resumeLoc, map_location=device)
|
144 |
model.load_state_dict(checkpoint['state_dict'], False)
|
145 |
model.eval()
|
146 |
# 3. decode strategy
|
147 |
with torch.no_grad():
|
148 |
data = data[np.newaxis, :, :]
|
149 |
+
data = torch.Tensor(data).to(device)
|
150 |
decode1, decode2, decode = model(data)
|
151 |
|
152 |
|
153 |
+
elif mode == "ART":
|
154 |
# 1. read name
|
155 |
resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
|
156 |
# 2. load model
|
157 |
+
checkpoint = torch.load(resumeLoc, map_location=device)
|
158 |
+
model = tf_model.make_model(30, 30, N=2).to(device)
|
159 |
model.load_state_dict(checkpoint['state_dict'])
|
160 |
model.eval()
|
161 |
# 3. decode strategy
|
162 |
with torch.no_grad():
|
163 |
+
data = torch.FloatTensor(data).to(device)
|
164 |
data = data.unsqueeze(0)
|
165 |
src = data
|
166 |
+
tgt = data # you can modify to randomize data
|
167 |
batch = tf_data.Batch(src, tgt, 0)
|
168 |
out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
|
169 |
decode = model.generator(out)
|
170 |
decode = decode.permute(0, 2, 1)
|
171 |
+
add_tensor = torch.zeros(1, 30, 1).to(device)
|
172 |
+
decode = torch.cat((decode, add_tensor), dim=2)
|
173 |
|
174 |
# 4. numpy
|
175 |
#print(decode.shape)
|
176 |
decode = np.array(decode.cpu()).astype(np.float64)
|
177 |
return decode
|
178 |
|
179 |
+
|
180 |
+
def reorder_data(raw_data, mapping_result):
|
181 |
+
new_data = np.zeros((30, raw_data.shape[1]))
|
182 |
+
zero_arr = np.zeros((1, raw_data.shape[1]))
|
183 |
+
for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
|
184 |
+
if flag == True:
|
185 |
+
new_data[i, :] = raw_data[indices[0], :]
|
186 |
+
elif indices[0] == None:
|
187 |
+
new_data[i, :] = zero_arr
|
188 |
+
else:
|
189 |
+
data = [raw_data[idx, :] for idx in indices]
|
190 |
+
new_data[i, :] = np.mean(data, axis=0)
|
191 |
+
return new_data
|
192 |
+
|
193 |
+
def preprocessing(filepath, inputfile, samplerate, mapping_result):
|
194 |
# establish temp folder
|
195 |
try:
|
196 |
os.mkdir(filepath+"temp2/")
|
|
|
200 |
print(e)
|
201 |
|
202 |
# read data
|
203 |
+
signal = read_train_data(inputfile)
|
204 |
+
channel_num = signal.shape[0]
|
205 |
+
# reorder data
|
206 |
+
signal = reorder_data(signal, mapping_result)
|
207 |
#print(signal.shape)
|
208 |
# resample
|
209 |
+
signal = resample(signal, samplerate, 256)
|
210 |
#print(signal.shape)
|
211 |
# FIR_filter
|
212 |
signal = FIR_filter(signal, 1, 50)
|
|
|
214 |
# cutting data
|
215 |
total_file_num = cut_data(filepath, signal)
|
216 |
|
217 |
+
return total_file_num, channel_num
|
218 |
+
|
219 |
+
def restore_order(data, all_data, mapping_result):
|
220 |
+
for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
|
221 |
+
if flag == True:
|
222 |
+
all_data[indices[0], :] = data[i, :]
|
223 |
+
return all_data
|
224 |
+
|
225 |
+
def postprocessing(data, samplerate, outputfile, mapping_result, batch_cnt, channel_num):
|
226 |
+
|
227 |
+
# resample to original sampling rate
|
228 |
+
data = resample(data, 256, samplerate)
|
229 |
+
# restore original order
|
230 |
+
all_data = np.zeros((channel_num, data.shape[1])) if batch_cnt==0 else read_train_data(outputfile)
|
231 |
+
all_data = restore_order(data, all_data, mapping_result)
|
232 |
+
# save data
|
233 |
+
save_data(all_data, outputfile)
|
234 |
|
235 |
|
236 |
# model = tf.keras.models.load_model('./denoise_model/')
|
237 |
+
def reconstruct(model_name, total, filepath, batch_cnt):
|
238 |
# -------------------decode_data---------------------------
|
239 |
second1 = time.time()
|
240 |
for i in range(total):
|
|
|
254 |
save_data(d_data, outputname)
|
255 |
|
256 |
# --------------------glue_data----------------------------
|
257 |
+
data = glue_data(filepath+"temp2/", total)
|
|
|
258 |
# -------------------delete_data---------------------------
|
259 |
dataDelete(filepath+"temp2/")
|
|
|
|
|
|
|
|
|
|
|
260 |
second2 = time.time()
|
261 |
+
|
262 |
+
print(f"Using {model_name} model to reconstruct batch-{batch_cnt+1} has been success in {second2 - second1} sec(s)")
|
263 |
+
return data
|
264 |
+
|