audrey06100 committed on
Commit
6af0f90
·
1 Parent(s): 5bfc499
app.py CHANGED
@@ -45,7 +45,7 @@ Once all template channels are filled, you will be directed to **Mapping Result*
45
  ### Mapping Result
46
  After completing the previous steps, your channels will be aligned with the template channels required by our models.
47
  - In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final result.
48
- - Once the mapping process is completed, a JSON file containing the mapping result will be generated. This file is necessary only if you plan to run the models using the <a href="">source code</a>; otherwise, you can ignore it.
49
 
50
  ## 2. Decode data
51
  After clicking on ``Run`` button, we will process your EEG data based on the mapping result. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
@@ -278,11 +278,11 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
278
  with gr.Column():
279
  in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
280
  in_modelname = gr.Dropdown(choices=[
281
- ("ART", "EEGART"),
282
  ("IC-U-Net", "ICUNet"),
283
- ("IC-U-Net++", "UNetpp"),
284
- ("IC-U-Net-Attn", "AttUnet")],
285
- value="EEGART",
286
  label="Model")
287
  run_btn = gr.Button("Run", interactive=False)
288
  cancel_btn = gr.Button("Cancel", visible=False)
@@ -304,7 +304,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
304
  gr.Markdown()
305
 
306
  def create_dir(req: gr.Request):
307
- os.mkdir(gradio_temp_dir+'/'+req.session_hash)
308
  return gradio_temp_dir+'/'+req.session_hash+'/'
309
  demo.load(create_dir, inputs=[], outputs=session_dir)
310
 
@@ -326,8 +326,8 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
326
  stage1_dir = uuid.uuid4().hex + '_stage1/'
327
  os.mkdir(rootpath + stage1_dir)
328
 
329
- basename = os.path.basename(str(in_loc))
330
- outputname = os.path.splitext(basename)[0] + '_mapping_result.json'
331
 
332
  stage1_info = {
333
  "filePath" : rootpath + stage1_dir,
@@ -349,7 +349,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
349
  },
350
  "unassignedInput" : None,
351
  "emptyTemplate" : None,
352
- "batchNum" : None,
353
  "mappingResult" : [
354
  {
355
  "index" : None,
@@ -436,7 +436,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
436
  md = """
437
  ### Mapping Result
438
  The mapping process has been finished.
439
- Download the file below if you plan to run the models using the <a href="">source code</a>.
440
  """
441
  # finalize and save the mapping result
442
  outputname = stage1_info["fileNames"]["outputData"]
@@ -523,7 +523,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
523
  md = """
524
  ### Mapping Result
525
  The mapping process has been finished.
526
- Download the file below if you plan to run the models using the <a href="">source code</a>.
527
  """
528
  outputname = stage1_info["fileNames"]["outputData"]
529
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
@@ -560,7 +560,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
560
  md = """
561
  ### Mapping Result
562
  The mapping process has been finished.
563
- Download the file below if you plan to run the models using the <a href="">source code</a>.
564
  """
565
  outputname = stage1_info["fileNames"]["outputData"]
566
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
@@ -628,7 +628,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
628
  md = """
629
  ### Mapping Result
630
  The mapping process has been finished.
631
- Download the file below if you plan to run the models using the <a href="">source code</a>.
632
  """
633
  outputname = stage1_info["fileNames"]["outputData"]
634
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
@@ -814,8 +814,8 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
814
  stage2_dir = uuid.uuid4().hex + '_stage2/'
815
  os.mkdir(rootpath + stage2_dir)
816
 
817
- basename = os.path.basename(str(in_data))
818
- outputname = modelname + '_'+os.path.splitext(basename)[0] + '.csv'
819
 
820
  stage2_info = {
821
  "filePath" : rootpath + stage2_dir,
@@ -832,31 +832,27 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
832
  batch_md : gr.Markdown("", visible=True),
833
  out_data_file : gr.File(value=None, visible=False)}
834
 
835
- def run_stage2(stage1_info, stage2_info, samplerate, modelname):
836
  if stage2_info["errorFlag"] == True:
837
  stage2_info["errorFlag"] = False
838
  yield {stage2_json : stage2_info}
839
 
840
  else:
 
841
  inputname = stage2_info["fileNames"]["inputData"]
842
  outputname = stage2_info["fileNames"]["outputData"]
843
- basename = os.path.basename(str(outputname))
844
- basename = os.path.splitext(basename)[0]
845
  break_flag = False
846
 
847
- for i in range(stage1_info["batchNum"]):
848
- yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["batchNum"]))}
849
  try:
850
- app_utils.run_model(modelname = modelname,
851
- filepath = stage2_info["filePath"],
852
- inputname = inputname,
853
- m_filename = 'mapped_{:02d}.csv'.format(i+1),
854
- d_filename = '{}_{:02d}.csv'.format(basename, i+1),
855
- outputname = outputname,
856
- samplerate = int(samplerate),
857
- batch_cnt = i,
858
- old_idx = stage1_info["mappingResult"][i]["index"],
859
- orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
860
  except FileNotFoundError:
861
  print('stop!!')
862
  break_flag = True
@@ -877,7 +873,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
877
  inputs = [session_dir, stage2_json, in_data_file, in_samplerate, in_modelname],
878
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
879
  ).success(
880
- fn = run_stage2,
881
  inputs = [stage1_json, stage2_json, in_samplerate, in_modelname],
882
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
883
  )
 
45
  ### Mapping Result
46
  After completing the previous steps, your channels will be aligned with the template channels required by our models.
47
  - In case there are still some channels that haven't been mapped, we will automatically batch and optimally assign them to the template. This ensures that even channels not initially mapped will still be included in the final result.
48
+ - Once the mapping process is completed, a JSON file containing the mapping result will be generated. This file is necessary only if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>; otherwise, you can ignore it.
49
 
50
  ## 2. Decode data
51
  After clicking on ``Run`` button, we will process your EEG data based on the mapping result. If necessary, your data will be divided into batches and run the models on each batch sequentially, ensuring that all channels are properly processed.
 
278
  with gr.Column():
279
  in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
280
  in_modelname = gr.Dropdown(choices=[
281
+ ("ART", "ART"),
282
  ("IC-U-Net", "ICUNet"),
283
+ ("IC-U-Net++", "ICUNet++"),
284
+ ("IC-U-Net-Attn", "ICUnet_attn")],
285
+ value="ART",
286
  label="Model")
287
  run_btn = gr.Button("Run", interactive=False)
288
  cancel_btn = gr.Button("Cancel", visible=False)
 
304
  gr.Markdown()
305
 
306
  def create_dir(req: gr.Request):
307
+ os.mkdir(gradio_temp_dir+'/'+req.session_hash+'/')
308
  return gradio_temp_dir+'/'+req.session_hash+'/'
309
  demo.load(create_dir, inputs=[], outputs=session_dir)
310
 
 
326
  stage1_dir = uuid.uuid4().hex + '_stage1/'
327
  os.mkdir(rootpath + stage1_dir)
328
 
329
+ inputname = os.path.basename(str(in_loc))
330
+ outputname = inputname[:-4] + '_mapping_result.json'
331
 
332
  stage1_info = {
333
  "filePath" : rootpath + stage1_dir,
 
349
  },
350
  "unassignedInput" : None,
351
  "emptyTemplate" : None,
352
+ "batch" : None,
353
  "mappingResult" : [
354
  {
355
  "index" : None,
 
436
  md = """
437
  ### Mapping Result
438
  The mapping process has been finished.
439
+ Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
440
  """
441
  # finalize and save the mapping result
442
  outputname = stage1_info["fileNames"]["outputData"]
 
523
  md = """
524
  ### Mapping Result
525
  The mapping process has been finished.
526
+ Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
527
  """
528
  outputname = stage1_info["fileNames"]["outputData"]
529
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
 
560
  md = """
561
  ### Mapping Result
562
  The mapping process has been finished.
563
+ Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
564
  """
565
  outputname = stage1_info["fileNames"]["outputData"]
566
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
 
628
  md = """
629
  ### Mapping Result
630
  The mapping process has been finished.
631
+ Download the file below if you plan to run the models using the <a href="https://github.com/CNElab-Plus/ArtifactRemovalTransformer">source code</a>.
632
  """
633
  outputname = stage1_info["fileNames"]["outputData"]
634
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, outputname)
 
814
  stage2_dir = uuid.uuid4().hex + '_stage2/'
815
  os.mkdir(rootpath + stage2_dir)
816
 
817
+ inputname = os.path.basename(str(in_data))
818
+ outputname = modelname + '_'+inputname[:-4] + '.csv'
819
 
820
  stage2_info = {
821
  "filePath" : rootpath + stage2_dir,
 
832
  batch_md : gr.Markdown("", visible=True),
833
  out_data_file : gr.File(value=None, visible=False)}
834
 
835
+ def run_model(stage1_info, stage2_info, samplerate, modelname):
836
  if stage2_info["errorFlag"] == True:
837
  stage2_info["errorFlag"] = False
838
  yield {stage2_json : stage2_info}
839
 
840
  else:
841
+ filepath = stage2_info["filePath"]
842
  inputname = stage2_info["fileNames"]["inputData"]
843
  outputname = stage2_info["fileNames"]["outputData"]
844
+ mapping_result = stage1_info["mappingResult"]
 
845
  break_flag = False
846
 
847
+ for i in range(stage1_info["batch"]):
848
+ yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["batch"]))}
849
  try:
850
+ # step1: Data preprocessing
851
+ preprocess_data, channel_num = utils.preprocessing(filepath, inputname, int(samplerate), mapping_result[i])
852
+ # step2: Signal reconstruction
853
+ reconstructed_data = utils.reconstruct(modelname, preprocess_data, filepath, i)
854
+ # step3: Data postprocessing
855
+ utils.postprocessing(reconstructed_data, int(samplerate), outputname, mapping_result[i], i, channel_num)
 
 
 
 
856
  except FileNotFoundError:
857
  print('stop!!')
858
  break_flag = True
 
873
  inputs = [session_dir, stage2_json, in_data_file, in_samplerate, in_modelname],
874
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
875
  ).success(
876
+ fn = run_model,
877
  inputs = [stage1_json, stage2_json, in_samplerate, in_modelname],
878
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
879
  )
app_utils.py CHANGED
@@ -54,7 +54,7 @@ def match_name(stage1_info):
54
  tpl_names = tpl_montage.ch_names
55
  in_names = in_montage.ch_names
56
  old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
57
- orig_flags = [False]*30
58
 
59
  alias_dict = {
60
  'T3': 'T7',
@@ -70,7 +70,7 @@ def match_name(stage1_info):
70
 
71
  if name in in_dict:
72
  old_idx[i] = [in_dict[name]["index"]]
73
- orig_flags[i] = True
74
  tpl_dict[name]["matched"] = True
75
  in_dict[name]["assigned"] = True
76
 
@@ -83,7 +83,7 @@ def match_name(stage1_info):
83
  "mappingResult" : [
84
  {
85
  "index" : old_idx,
86
- "isOriginalData" : orig_flags
87
  }
88
  ]
89
  })
@@ -270,14 +270,14 @@ def optimal_mapping(channel_info):
270
 
271
  # store the mapping result
272
  old_idx = [[None]]*30
273
- orig_flags = [False]*30
274
  for i, j in zip(row_idx, col_idx):
275
  if j < len(unass_in_names): # filter out dummy channels
276
  tpl_name = tpl_names[i]
277
  in_name = unass_in_names[j]
278
 
279
  old_idx[i] = [in_dict[in_name]["index"]]
280
- orig_flags[i] = True
281
  tpl_dict[tpl_name]["matched"] = True
282
  in_dict[in_name]["assigned"] = True
283
 
@@ -288,23 +288,23 @@ def optimal_mapping(channel_info):
288
 
289
  result = {
290
  "index" : old_idx,
291
- "isOriginalData" : orig_flags
292
  }
293
  channel_info["inputDict"] = in_dict
294
  return result, channel_info
295
 
296
  def mapping_result(stage1_info, channel_info, filename):
297
  unassigned_num = len(stage1_info["unassignedInput"])
298
- batch_num = math.ceil(unassigned_num/30) + 1
299
 
300
  # map the remaining in_channels
301
  results = stage1_info["mappingResult"]
302
- for i in range(1, batch_num):
303
  # optimally select 30 in_channels to map to the tpl_channels based on proximity
304
  result, channel_info = optimal_mapping(channel_info)
305
  results += [result]
306
  '''
307
- for i in range(batch_num):
308
  results[i]["name"] = {}
309
  for j, indices in enumerate(results[i]["index"]):
310
  names = [channel_info["inputNames"][idx] for idx in indices] if indices!=[None] else ["zero"]
@@ -313,7 +313,7 @@ def mapping_result(stage1_info, channel_info, filename):
313
  data = {
314
  #"templateNames" : channel_info["templateNames"],
315
  #"inputNames" : channel_info["inputNames"],
316
- "batchNum" : batch_num,
317
  "mappingResult" : results
318
  }
319
  options = jsbeautifier.default_options()
@@ -323,59 +323,8 @@ def mapping_result(stage1_info, channel_info, filename):
323
  jsonfile.write(json_data)
324
 
325
  stage1_info.update({
326
- "batchNum" : batch_num,
327
  "mappingResult" : results
328
  })
329
  return stage1_info, channel_info
330
 
331
-
332
- def reorder_data(old_idx, orig_flags, inputname, filename):
333
- # read the input data
334
- raw_data = utils.read_train_data(inputname)
335
- #print(raw_data.shape)
336
- new_data = np.zeros((30, raw_data.shape[1]))
337
-
338
- zero_arr = np.zeros((1, raw_data.shape[1]))
339
- for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
340
- if flag == True:
341
- new_data[i, :] = raw_data[indices[0], :]
342
- elif indices == [None]:
343
- new_data[i, :] = zero_arr
344
- else:
345
- tmp_data = [raw_data[idx, :] for idx in indices]
346
- new_data[i, :] = np.mean(tmp_data, axis=0)
347
-
348
- utils.save_data(new_data, filename)
349
- return raw_data.shape
350
-
351
- def restore_order(batch_cnt, raw_data_shape, old_idx, orig_flags, filename, outputname):
352
- # read the denoised data
353
- d_data = utils.read_train_data(filename)
354
- if batch_cnt == 0:
355
- new_data = np.zeros((raw_data_shape[0], d_data.shape[1]))
356
- #print(new_data.shape)
357
- else:
358
- new_data = utils.read_train_data(outputname)
359
-
360
- for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
361
- if flag == True:
362
- new_data[indices[0], :] = d_data[i, :]
363
-
364
- utils.save_data(new_data, outputname)
365
- return
366
-
367
- def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt, old_idx, orig_flags):
368
- # establish temp folder
369
- os.mkdir(filepath+'temp_data/')
370
-
371
- # step1: Reorder data
372
- data_shape = reorder_data(old_idx, orig_flags, inputname, filepath+'temp_data/'+m_filename)
373
- # step2: Data preprocessing
374
- total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
375
- # step3: Signal reconstruction
376
- utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
377
- # step4: Restore original order
378
- restore_order(batch_cnt, data_shape, old_idx, orig_flags, filepath+'temp_data/'+d_filename, outputname)
379
-
380
- utils.dataDelete(filepath+'temp_data/')
381
-
 
54
  tpl_names = tpl_montage.ch_names
55
  in_names = in_montage.ch_names
56
  old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
57
+ is_orig_data = [False]*30
58
 
59
  alias_dict = {
60
  'T3': 'T7',
 
70
 
71
  if name in in_dict:
72
  old_idx[i] = [in_dict[name]["index"]]
73
+ is_orig_data[i] = True
74
  tpl_dict[name]["matched"] = True
75
  in_dict[name]["assigned"] = True
76
 
 
83
  "mappingResult" : [
84
  {
85
  "index" : old_idx,
86
+ "isOriginalData" : is_orig_data
87
  }
88
  ]
89
  })
 
270
 
271
  # store the mapping result
272
  old_idx = [[None]]*30
273
+ is_orig_data = [False]*30
274
  for i, j in zip(row_idx, col_idx):
275
  if j < len(unass_in_names): # filter out dummy channels
276
  tpl_name = tpl_names[i]
277
  in_name = unass_in_names[j]
278
 
279
  old_idx[i] = [in_dict[in_name]["index"]]
280
+ is_orig_data[i] = True
281
  tpl_dict[tpl_name]["matched"] = True
282
  in_dict[in_name]["assigned"] = True
283
 
 
288
 
289
  result = {
290
  "index" : old_idx,
291
+ "isOriginalData" : is_orig_data
292
  }
293
  channel_info["inputDict"] = in_dict
294
  return result, channel_info
295
 
296
  def mapping_result(stage1_info, channel_info, filename):
297
  unassigned_num = len(stage1_info["unassignedInput"])
298
+ batch = math.ceil(unassigned_num/30) + 1
299
 
300
  # map the remaining in_channels
301
  results = stage1_info["mappingResult"]
302
+ for i in range(1, batch):
303
  # optimally select 30 in_channels to map to the tpl_channels based on proximity
304
  result, channel_info = optimal_mapping(channel_info)
305
  results += [result]
306
  '''
307
+ for i in range(batch):
308
  results[i]["name"] = {}
309
  for j, indices in enumerate(results[i]["index"]):
310
  names = [channel_info["inputNames"][idx] for idx in indices] if indices!=[None] else ["zero"]
 
313
  data = {
314
  #"templateNames" : channel_info["templateNames"],
315
  #"inputNames" : channel_info["inputNames"],
316
+ "batch" : batch,
317
  "mappingResult" : results
318
  }
319
  options = jsbeautifier.default_options()
 
323
  jsonfile.write(json_data)
324
 
325
  stage1_info.update({
326
+ "batch" : batch,
327
  "mappingResult" : results
328
  })
329
  return stage1_info, channel_info
330
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
model/{EEGART → ART}/modelsave/checkpoint.pth.tar RENAMED
File without changes
model/{EEGART → ART}/modelsave/model_trainValLog.txt RENAMED
File without changes
model/{UNetpp → ICUNet++}/modelsave/BEST_checkpoint.pth.tar RENAMED
File without changes
model/{UNetpp → ICUNet++}/modelsave/checkpoint.pth.tar RENAMED
File without changes
model/{UNetpp → ICUNet++}/modelsave/model_trainValLog.txt RENAMED
File without changes
model/{AttUnet → ICUNet_attn}/modelsave/BEST_checkpoint.pth.tar RENAMED
File without changes
model/{AttUnet → ICUNet_attn}/modelsave/checkpoint.pth.tar RENAMED
File without changes
model/{AttUnet → ICUNet_attn}/modelsave/model_trainValLog.txt RENAMED
File without changes
model/__pycache__/UNet_attention.cpython-310.pyc CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:85dda6aba36abfc30a358e00499a45a17a2c9b2f16647b8adf91b6d0e6882517
3
  size 13043
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c22a35e64a5b9c7ebb09e866206e4e2ebea59bc1c8253ac2795bb0b64c84df16
3
  size 13043
model/__pycache__/tf_data.cpython-310.pyc CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a0fd072d1c6a351ce360a08e0d267605b1e294b2ee828fe9e04d967eb0546768
3
  size 5981
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a57237e3737f55f90c3ee13409cc90034d55ef94610c1fbcdd4768956757341
3
  size 5981
model/__pycache__/tf_model.cpython-310.pyc CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dbfd285f27e373a49cf4cca80978bb68a566ec349b220663e232868fce11b464
3
  size 12231
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d12b18b8c5cd4950c70c754e10f70272dd438c980c25a5b712bb1881ece86480
3
  size 12231
utils.py CHANGED
@@ -15,38 +15,12 @@ from scipy.signal import decimate, resample_poly, firwin, lfilter
15
 
16
 
17
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
 
18
 
19
- def resample(signal, fs):
20
- # downsample the signal to a sample rate of 256 Hz
21
- if fs>256:
22
- fs_down = 256 # Desired sample rate
23
- q = int(fs / fs_down) # Downsampling factor
24
- signal_new = []
25
- for ch in signal:
26
- x_down = decimate(ch, q)
27
- signal_new.append(x_down)
28
-
29
- # upsample the signal to a sample rate of 256 Hz
30
- elif fs<256:
31
- fs_up = 256 # Desired sample rate
32
- p = int(fs_up / fs) # Upsampling factor
33
- signal_new = []
34
- for ch in signal:
35
- x_up = resample_poly(ch, p, 1)
36
- signal_new.append(x_up)
37
-
38
- else:
39
- signal_new = signal
40
-
41
- signal_new = np.array(signal_new).astype(np.float64)
42
-
43
- return signal_new
44
-
45
- def resample_(signal, current_fs, target_fs):
46
- fs = current_fs
47
  # downsample the signal to the target sample rate
48
- if fs>target_fs:
49
- fs_down = target_fs # Desired sample rate
50
  q = int(fs / fs_down) # Downsampling factor
51
  signal_new = []
52
  for ch in signal:
@@ -54,8 +28,8 @@ def resample_(signal, current_fs, target_fs):
54
  signal_new.append(x_down)
55
 
56
  # upsample the signal to the target sample rate
57
- elif fs<target_fs:
58
- fs_up = target_fs # Desired sample rate
59
  p = int(fs_up / fs) # Upsampling factor
60
  signal_new = []
61
  for ch in signal:
@@ -104,7 +78,7 @@ def cut_data(filepath, raw_data):
104
  return total
105
 
106
 
107
- def glue_data(file_name, total, output):
108
  gluedata = 0
109
  for i in range(total):
110
  file_name1 = file_name + 'output{}.csv'.format(str(i))
@@ -123,13 +97,6 @@ def glue_data(file_name, total, output):
123
  raw_data[:, 1] = smooth
124
  gluedata = np.append(gluedata, raw_data, axis=1)
125
  #print(gluedata.shape)
126
- '''
127
- filename2 = output
128
- with open(filename2, 'w', newline='') as csvfile:
129
- writer = csv.writer(csvfile)
130
- writer.writerows(gluedata)
131
- #print("GLUE DONE!" + filename2)
132
- '''
133
  return gluedata
134
 
135
 
@@ -142,8 +109,7 @@ def dataDelete(path):
142
  try:
143
  shutil.rmtree(path)
144
  except OSError as e:
145
- pass
146
- #print(e)
147
  else:
148
  pass
149
  #print("The directory is deleted successfully")
@@ -153,64 +119,78 @@ def decode_data(data, std_num, mode=5):
153
 
154
  if mode == "ICUNet":
155
  # 1. read name
156
- model = cumbersome_model2.UNet1(n_channels=30, n_classes=30)
157
  resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
158
  # 2. load model
159
- checkpoint = torch.load(resumeLoc, map_location='cpu')
160
  model.load_state_dict(checkpoint['state_dict'], False)
161
  model.eval()
162
  # 3. decode strategy
163
  with torch.no_grad():
164
  data = data[np.newaxis, :, :]
165
- data = torch.Tensor(data)
166
  decode = model(data)
167
 
168
 
169
- elif mode == "UNetpp" or mode == "AttUnet":
170
  # 1. read name
171
- if mode == "UNetpp":
172
- model = UNet_family.NestedUNet3(num_classes=30)
173
- elif mode == "AttUnet":
174
- model = UNet_attention.UNetpp3_Transformer(num_classes=30)
175
- resumeLoc = './model/'+ mode + '/modelsave' + '/checkpoint.pth.tar'
176
  # 2. load model
177
- checkpoint = torch.load(resumeLoc, map_location='cpu')
178
  model.load_state_dict(checkpoint['state_dict'], False)
179
  model.eval()
180
  # 3. decode strategy
181
  with torch.no_grad():
182
  data = data[np.newaxis, :, :]
183
- data = torch.Tensor(data)
184
  decode1, decode2, decode = model(data)
185
 
186
 
187
- elif mode == "EEGART":
188
  # 1. read name
189
  resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
190
  # 2. load model
191
- checkpoint = torch.load(resumeLoc, map_location='cpu')
192
- model = tf_model.make_model(30, 30, N=2)
193
  model.load_state_dict(checkpoint['state_dict'])
194
  model.eval()
195
  # 3. decode strategy
196
  with torch.no_grad():
197
- data = torch.FloatTensor(data)
198
  data = data.unsqueeze(0)
199
  src = data
200
- tgt = data
201
  batch = tf_data.Batch(src, tgt, 0)
202
  out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
203
  decode = model.generator(out)
204
  decode = decode.permute(0, 2, 1)
205
- #add_tensor = torch.zeros(1, 30, 1)
206
- #decode = torch.cat((decode, add_tensor), dim=2)
207
 
208
  # 4. numpy
209
  #print(decode.shape)
210
  decode = np.array(decode.cpu()).astype(np.float64)
211
  return decode
212
 
213
- def preprocessing(filepath, filename, samplerate):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
  # establish temp folder
215
  try:
216
  os.mkdir(filepath+"temp2/")
@@ -220,10 +200,13 @@ def preprocessing(filepath, filename, samplerate):
220
  print(e)
221
 
222
  # read data
223
- signal = read_train_data(filepath+filename)
 
 
 
224
  #print(signal.shape)
225
  # resample
226
- signal = resample(signal, samplerate) #signal = resample_(signal, samplerate, 256)
227
  #print(signal.shape)
228
  # FIR_filter
229
  signal = FIR_filter(signal, 1, 50)
@@ -231,11 +214,27 @@ def preprocessing(filepath, filename, samplerate):
231
  # cutting data
232
  total_file_num = cut_data(filepath, signal)
233
 
234
- return total_file_num
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
 
237
  # model = tf.keras.models.load_model('./denoise_model/')
238
- def reconstruct(model_name, total, filepath, outputfile, samplerate):
239
  # -------------------decode_data---------------------------
240
  second1 = time.time()
241
  for i in range(total):
@@ -255,16 +254,11 @@ def reconstruct(model_name, total, filepath, outputfile, samplerate):
255
  save_data(d_data, outputname)
256
 
257
  # --------------------glue_data----------------------------
258
- signal = glue_data(filepath+"temp2/", total, filepath+outputfile)
259
- #print(signal.shape)
260
  # -------------------delete_data---------------------------
261
  dataDelete(filepath+"temp2/")
262
- # --------------------resample-----------------------------
263
- signal = resample_(signal, 256, samplerate)
264
- #print(signal.shape)
265
- # --------------------save_data----------------------------
266
- save_data(signal, filepath+outputfile)
267
  second2 = time.time()
268
-
269
- print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")
270
-
 
 
15
 
16
 
17
  os.environ["CUDA_VISIBLE_DEVICES"]="0"
18
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
19
 
20
+ def resample(signal, fs, tgt_fs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # downsample the signal to the target sample rate
22
+ if fs>tgt_fs:
23
+ fs_down = tgt_fs # Desired sample rate
24
  q = int(fs / fs_down) # Downsampling factor
25
  signal_new = []
26
  for ch in signal:
 
28
  signal_new.append(x_down)
29
 
30
  # upsample the signal to the target sample rate
31
+ elif fs<tgt_fs:
32
+ fs_up = tgt_fs # Desired sample rate
33
  p = int(fs_up / fs) # Upsampling factor
34
  signal_new = []
35
  for ch in signal:
 
78
  return total
79
 
80
 
81
+ def glue_data(file_name, total):
82
  gluedata = 0
83
  for i in range(total):
84
  file_name1 = file_name + 'output{}.csv'.format(str(i))
 
97
  raw_data[:, 1] = smooth
98
  gluedata = np.append(gluedata, raw_data, axis=1)
99
  #print(gluedata.shape)
 
 
 
 
 
 
 
100
  return gluedata
101
 
102
 
 
109
  try:
110
  shutil.rmtree(path)
111
  except OSError as e:
112
+ print('dataDelete:', e)
 
113
  else:
114
  pass
115
  #print("The directory is deleted successfully")
 
119
 
120
  if mode == "ICUNet":
121
  # 1. read name
122
+ model = cumbersome_model2.UNet1(n_channels=30, n_classes=30).to(device)
123
  resumeLoc = './model/ICUNet/modelsave' + '/checkpoint.pth.tar'
124
  # 2. load model
125
+ checkpoint = torch.load(resumeLoc, map_location=device)
126
  model.load_state_dict(checkpoint['state_dict'], False)
127
  model.eval()
128
  # 3. decode strategy
129
  with torch.no_grad():
130
  data = data[np.newaxis, :, :]
131
+ data = torch.Tensor(data).to(device)
132
  decode = model(data)
133
 
134
 
135
+ elif mode == "ICUNet++" or mode == "ICUnet_attn":
136
  # 1. read name
137
+ if mode == "ICUNet++":
138
+ model = UNet_family.NestedUNet3(num_classes=30).to(device)
139
+ elif mode == "ICUnet_attn":
140
+ model = UNet_attention.UNetpp3_Transformer(num_classes=30).to(device)
141
+ resumeLoc = './model/' + mode + '/modelsave' + '/checkpoint.pth.tar'
142
  # 2. load model
143
+ checkpoint = torch.load(resumeLoc, map_location=device)
144
  model.load_state_dict(checkpoint['state_dict'], False)
145
  model.eval()
146
  # 3. decode strategy
147
  with torch.no_grad():
148
  data = data[np.newaxis, :, :]
149
+ data = torch.Tensor(data).to(device)
150
  decode1, decode2, decode = model(data)
151
 
152
 
153
+ elif mode == "ART":
154
  # 1. read name
155
  resumeLoc = './model/' + mode + '/modelsave/checkpoint.pth.tar'
156
  # 2. load model
157
+ checkpoint = torch.load(resumeLoc, map_location=device)
158
+ model = tf_model.make_model(30, 30, N=2).to(device)
159
  model.load_state_dict(checkpoint['state_dict'])
160
  model.eval()
161
  # 3. decode strategy
162
  with torch.no_grad():
163
+ data = torch.FloatTensor(data).to(device)
164
  data = data.unsqueeze(0)
165
  src = data
166
+ tgt = data # you can modify to randomize data
167
  batch = tf_data.Batch(src, tgt, 0)
168
  out = model.forward(batch.src, batch.src[:,:,1:], batch.src_mask, batch.trg_mask)
169
  decode = model.generator(out)
170
  decode = decode.permute(0, 2, 1)
171
+ add_tensor = torch.zeros(1, 30, 1).to(device)
172
+ decode = torch.cat((decode, add_tensor), dim=2)
173
 
174
  # 4. numpy
175
  #print(decode.shape)
176
  decode = np.array(decode.cpu()).astype(np.float64)
177
  return decode
178
 
179
+
180
+ def reorder_data(raw_data, mapping_result):
181
+ new_data = np.zeros((30, raw_data.shape[1]))
182
+ zero_arr = np.zeros((1, raw_data.shape[1]))
183
+ for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
184
+ if flag == True:
185
+ new_data[i, :] = raw_data[indices[0], :]
186
+ elif indices[0] == None:
187
+ new_data[i, :] = zero_arr
188
+ else:
189
+ data = [raw_data[idx, :] for idx in indices]
190
+ new_data[i, :] = np.mean(data, axis=0)
191
+ return new_data
192
+
193
+ def preprocessing(filepath, inputfile, samplerate, mapping_result):
194
  # establish temp folder
195
  try:
196
  os.mkdir(filepath+"temp2/")
 
200
  print(e)
201
 
202
  # read data
203
+ signal = read_train_data(inputfile)
204
+ channel_num = signal.shape[0]
205
+ # reorder data
206
+ signal = reorder_data(signal, mapping_result)
207
  #print(signal.shape)
208
  # resample
209
+ signal = resample(signal, samplerate, 256)
210
  #print(signal.shape)
211
  # FIR_filter
212
  signal = FIR_filter(signal, 1, 50)
 
214
  # cutting data
215
  total_file_num = cut_data(filepath, signal)
216
 
217
+ return total_file_num, channel_num
218
+
219
+ def restore_order(data, all_data, mapping_result):
220
+ for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
221
+ if flag == True:
222
+ all_data[indices[0], :] = data[i, :]
223
+ return all_data
224
+
225
+ def postprocessing(data, samplerate, outputfile, mapping_result, batch_cnt, channel_num):
226
+
227
+ # resample to original sampling rate
228
+ data = resample(data, 256, samplerate)
229
+ # restore original order
230
+ all_data = np.zeros((channel_num, data.shape[1])) if batch_cnt==0 else read_train_data(outputfile)
231
+ all_data = restore_order(data, all_data, mapping_result)
232
+ # save data
233
+ save_data(all_data, outputfile)
234
 
235
 
236
  # model = tf.keras.models.load_model('./denoise_model/')
237
+ def reconstruct(model_name, total, filepath, batch_cnt):
238
  # -------------------decode_data---------------------------
239
  second1 = time.time()
240
  for i in range(total):
 
254
  save_data(d_data, outputname)
255
 
256
  # --------------------glue_data----------------------------
257
+ data = glue_data(filepath+"temp2/", total)
 
258
  # -------------------delete_data---------------------------
259
  dataDelete(filepath+"temp2/")
 
 
 
 
 
260
  second2 = time.time()
261
+
262
+ print(f"Using {model_name} model to reconstruct batch-{batch_cnt+1} has been success in {second2 - second1} sec(s)")
263
+ return data
264
+