audrey06100 commited on
Commit
b1a3cc1
·
1 Parent(s): 23fe20b
Files changed (2) hide show
  1. app.py +64 -63
  2. app_utils.py +2 -4
app.py CHANGED
@@ -18,8 +18,8 @@ This tool is designed to assist you with two main tasks:
18
  ## 1. Channel Mapping
19
  The following steps will guide you through the process of mapping your EEG channels to our template channels.
20
 
21
- ### Step1: Initial Matching and Rescaling
22
- After clicking on ``Mapping`` button, we will first match your channels to our template channels by their names. Using the matched channels as reference points, we will apply Thin Plate Spline (TPS) transformation to rescale your montage to align with our template's scale. The template montage and your rescaled montage will be displayed side by side for comparison. Channels that do not have a match in our template will be **highlighted in red**.
23
  - If your data includes all the 30 template channels, you will see the **Mapping Results**.
24
  - If your data doesn't include all the 30 template channels and you have some channels that do not match the template, you will be directed to **Step2**.
25
  - If all your channels are included in our template but you have fewer than 30 channels, you will be directed to **Step3**.
@@ -227,66 +227,65 @@ with gr.Blocks() as demo:
227
  with gr.Row():
228
 
229
  with gr.Column():
230
- # ------------------------input--------------------------
231
  with gr.Row(variant='panel'):
232
  with gr.Column():
233
  gr.Markdown("# 1.Channel Mapping")
 
234
  in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
235
  file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
 
 
 
 
236
  with gr.Row():
237
- in_samplerate = gr.Textbox(label="Sampling rate (Hz)", scale=2)
238
- map_btn = gr.Button("Mapping", scale=1)
239
- # ------------------------mapping------------------------
240
- desc_md = gr.Markdown(visible=False)
241
- # step1 : initial matching and rescaling
242
- with gr.Row():
243
- tpl_img = gr.Image("./template_montage.png", label="Template channels", visible=False)
244
- mapped_img = gr.Image(label="Input channels", visible=False)
245
- # step2 : forward unmatched input channels to empty template channels
246
- radio_group = gr.Radio(elem_id="radio-group", visible=False)
247
- # step3 : fill the remaining template channels
248
- with gr.Row():
249
- in_fillmode = gr.Dropdown(choices=["mean", "zero"],
250
- value="mean",
251
- label="Filling method",
252
- visible=False,
253
- scale=2)
254
- fillmode_btn = gr.Button("OK", visible=False, scale=1)
255
- chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
256
- # step4 : mapping result
257
- out_json_file = gr.File(label="Mapping result", visible=False)
258
- res_md = gr.Markdown("""
259
- (Download this file if you plan to run the models using the source code.)
260
- """, visible=False)
261
-
262
- with gr.Row():
263
- clear_btn = gr.Button("Clear", visible=False)
264
- step2_btn = gr.Button("Next", visible=False)
265
- step3_btn = gr.Button("Next", visible=False)
266
- next_btn = gr.Button("Next step", visible=False)
267
- # -------------------------------------------------------
268
 
269
  with gr.Column():
270
- # ------------------------input--------------------------
271
  with gr.Row(variant='panel'):
272
  with gr.Column():
273
  gr.Markdown("# 2.Decode Data")
274
- in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
275
  with gr.Row():
276
- in_modelname = gr.Dropdown(choices=[
277
- ("ART", "EEGART"),
278
- ("IC-U-Net", "ICUNet"),
279
- ("IC-U-Net++", "UNetpp"),
280
- ("IC-U-Net-Attn", "AttUnet")],
281
- #"(mapped data)"],
282
- value="EEGART",
283
- label="Model",
284
- scale=2)
285
- run_btn = gr.Button(interactive=False, scale=1)
286
- # ------------------------output-------------------------
287
- batch_md = gr.Markdown(visible=False)
288
- out_data_file = gr.File(label="Denoised data", visible=False)
289
- # -------------------------------------------------------
 
 
290
 
291
  with gr.Row():
292
  with gr.Tab("ART"):
@@ -303,7 +302,7 @@ with gr.Blocks() as demo:
303
  # +========================================================================================+
304
  # | Stage1: channel mapping |
305
  # +========================================================================================+
306
- def reset_all(in_loc, samplerate):
307
  # establish a new folder for the current session
308
  rootpath = os.path.dirname(str(in_loc))
309
  try:
@@ -319,7 +318,6 @@ with gr.Blocks() as demo:
319
  channel_info = {}
320
  app_info = {
321
  "rootPath" : rootpath+"/session_data/",
322
- "sampleRate" : int(samplerate),
323
  "stage1" : {
324
  "filePath" : rootpath+"/session_data/stage1/",
325
  "fileNames" : {
@@ -346,7 +344,8 @@ with gr.Blocks() as demo:
346
  "input_data" : "",
347
  "output_data" : ""
348
  },
349
- "totalBatchNum" : None
 
350
  }
351
  }
352
  return {app_info_json : app_info,
@@ -368,6 +367,7 @@ with gr.Blocks() as demo:
368
  res_md : gr.Markdown(visible=False),
369
  # --------------------Stage2-------------------------
370
  in_data_file : gr.File(value=None),
 
371
  run_btn : gr.Button(interactive=False),
372
  batch_md : gr.Markdown(visible=False),
373
  out_data_file : gr.File(visible=False)}
@@ -388,7 +388,7 @@ with gr.Blocks() as demo:
388
 
389
  # match the names of in_channels and tpl_channels
390
  stage1_info, channel_info, tpl_montage, in_montage = app_utils.match_names(stage1_info, channel_info)
391
- # rescale coordinates
392
  channel_info = app_utils.align_coords(channel_info, tpl_montage, in_montage)
393
  # generate and save figures of the montages
394
  filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
@@ -402,12 +402,12 @@ with gr.Blocks() as demo:
402
  unassigned_num = len(stage1_info["unassignedInputs"])
403
  if unassigned_num == 0:
404
  md = """
405
- ### Step1: Initial Matching and Rescaling
406
  Below is the result of mapping your channels to our template channels based on their names.
407
  """
408
  else:
409
  md = """
410
- ### Step1: Initial Matching and Rescaling
411
  Below is the result of mapping your channels to our template channels based on their names.
412
  - channels highlighted in red are those that do not match any template channels.
413
  """
@@ -682,10 +682,10 @@ with gr.Blocks() as demo:
682
  # +========================================================================================+
683
  map_btn.click(
684
  fn = reset_all,
685
- inputs = [in_loc_file, in_samplerate],
686
  outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
687
  radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
688
- out_json_file, res_md, in_data_file, run_btn, batch_md, out_data_file]
689
  ).success(
690
  fn = init_next_step,
691
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
@@ -833,7 +833,7 @@ with gr.Blocks() as demo:
833
  # +========================================================================================+
834
  # | Stage2: decode data |
835
  # +========================================================================================+
836
- def reset_run(app_info, in_data, modelname):
837
  stage1_info = app_info["stage1"]
838
  stage2_info = app_info["stage2"]
839
 
@@ -852,7 +852,8 @@ with gr.Blocks() as demo:
852
  "fileNames" : {
853
  "input_data" : in_data,
854
  "output_data" : new_filepath + new_filename
855
- }
 
856
  })
857
  app_info["stage2"] = stage2_info
858
  return {app_info_json : app_info,
@@ -865,7 +866,7 @@ with gr.Blocks() as demo:
865
  stage2_info = app_info["stage2"]
866
 
867
  filepath = stage2_info["filePath"]
868
- samplerate = app_info["sampleRate"]
869
  filename = stage2_info["fileNames"]["input_data"]
870
  new_filename = stage2_info["fileNames"]["output_data"]
871
 
@@ -917,7 +918,7 @@ with gr.Blocks() as demo:
917
 
918
  run_btn.click(
919
  fn = reset_run,
920
- inputs = [app_info_json, in_data_file, in_modelname],
921
  outputs = [app_info_json, run_btn, batch_md, out_data_file]
922
 
923
  ).success(
@@ -927,5 +928,5 @@ with gr.Blocks() as demo:
927
  )
928
 
929
  if __name__ == "__main__":
930
- demo.launch()
931
 
 
18
  ## 1. Channel Mapping
19
  The following steps will guide you through the process of mapping your EEG channels to our template channels.
20
 
21
+ ### Step1: Initial Matching and Scaling
22
+ After clicking on ``Mapping`` button, we will first match your channels to our template channels by their names. Using the matched channels as reference points, we will apply Thin Plate Spline (TPS) transformation to scale your montage to align with our template's dimensions. The template montage and your scaled montage will be displayed side by side for comparison. Channels that do not have a match in our template will be **highlighted in red**.
23
  - If your data includes all the 30 template channels, you will see the **Mapping Results**.
24
  - If your data doesn't include all the 30 template channels and you have some channels that do not match the template, you will be directed to **Step2**.
25
  - If all your channels are included in our template but you have fewer than 30 channels, you will be directed to **Step3**.
 
227
  with gr.Row():
228
 
229
  with gr.Column():
 
230
  with gr.Row(variant='panel'):
231
  with gr.Column():
232
  gr.Markdown("# 1.Channel Mapping")
233
+ # --------------------input----------------------
234
  in_loc_file = gr.File(label="Channel locations (.loc, .locs, .xyz, .sfp, .txt)",
235
  file_types=[".loc", "locs", ".xyz", ".sfp", ".txt"])
236
+ map_btn = gr.Button("Mapping")
237
+ # -------------------mapping---------------------
238
+ desc_md = gr.Markdown(visible=False)
239
+ # step1 : initial matching and scaling
240
  with gr.Row():
241
+ tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
242
+ mapped_img = gr.Image(label="Matching results", visible=False)
243
+ # step2 : forward unmatched input channels to empty template channels
244
+ radio_group = gr.Radio(elem_id="radio-group", visible=False)
245
+ # step3 : fill the remaining template channels
246
+ with gr.Row():
247
+ in_fillmode = gr.Dropdown(choices=["mean", "zero"],
248
+ value="mean",
249
+ label="Filling method",
250
+ visible=False,
251
+ scale=2)
252
+ fillmode_btn = gr.Button("OK", visible=False, scale=1)
253
+ chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", visible=False)
254
+ # step4 : mapping result
255
+ out_json_file = gr.File(label="Mapping result", visible=False)
256
+ res_md = gr.Markdown("""
257
+ (Download this file if you plan to run the models using the source code.)
258
+ """, visible=False)
259
+
260
+ with gr.Row():
261
+ clear_btn = gr.Button("Clear", visible=False)
262
+ step2_btn = gr.Button("Next", visible=False)
263
+ step3_btn = gr.Button("Next", visible=False)
264
+ next_btn = gr.Button("Next step", visible=False)
265
+ # -----------------------------------------------
 
 
 
 
 
 
266
 
267
  with gr.Column():
 
268
  with gr.Row(variant='panel'):
269
  with gr.Column():
270
  gr.Markdown("# 2.Decode Data")
271
+ # --------------------input----------------------
272
  with gr.Row():
273
+ in_data_file = gr.File(label="Raw data (.csv)", file_types=[".csv"])
274
+ with gr.Column():
275
+ in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
276
+ in_modelname = gr.Dropdown(choices=[
277
+ ("ART", "EEGART"),
278
+ ("IC-U-Net", "ICUNet"),
279
+ ("IC-U-Net++", "UNetpp"),
280
+ ("IC-U-Net-Attn", "AttUnet")],
281
+ #"(mapped data)"],
282
+ value="EEGART",
283
+ label="Model")
284
+ run_btn = gr.Button(interactive=False)
285
+ # --------------------output---------------------
286
+ batch_md = gr.Markdown(visible=False)
287
+ out_data_file = gr.File(label="Denoised data", visible=False)
288
+ # -----------------------------------------------
289
 
290
  with gr.Row():
291
  with gr.Tab("ART"):
 
302
  # +========================================================================================+
303
  # | Stage1: channel mapping |
304
  # +========================================================================================+
305
+ def reset_all(in_loc):
306
  # establish a new folder for the current session
307
  rootpath = os.path.dirname(str(in_loc))
308
  try:
 
318
  channel_info = {}
319
  app_info = {
320
  "rootPath" : rootpath+"/session_data/",
 
321
  "stage1" : {
322
  "filePath" : rootpath+"/session_data/stage1/",
323
  "fileNames" : {
 
344
  "input_data" : "",
345
  "output_data" : ""
346
  },
347
+ "totalBatchNum" : None,
348
+ "sampleRate" : None
349
  }
350
  }
351
  return {app_info_json : app_info,
 
367
  res_md : gr.Markdown(visible=False),
368
  # --------------------Stage2-------------------------
369
  in_data_file : gr.File(value=None),
370
+ in_samplerate : gr.Textbox(visible=""),
371
  run_btn : gr.Button(interactive=False),
372
  batch_md : gr.Markdown(visible=False),
373
  out_data_file : gr.File(visible=False)}
 
388
 
389
  # match the names of in_channels and tpl_channels
390
  stage1_info, channel_info, tpl_montage, in_montage = app_utils.match_names(stage1_info, channel_info)
391
+ # scale the coordinates
392
  channel_info = app_utils.align_coords(channel_info, tpl_montage, in_montage)
393
  # generate and save figures of the montages
394
  filename1 = filepath+"input_montage_"+str(random.randint(1,10000))+".png"
 
402
  unassigned_num = len(stage1_info["unassignedInputs"])
403
  if unassigned_num == 0:
404
  md = """
405
+ ### Step1: Initial Matching and Scaling
406
  Below is the result of mapping your channels to our template channels based on their names.
407
  """
408
  else:
409
  md = """
410
+ ### Step1: Initial Matching and Scaling
411
  Below is the result of mapping your channels to our template channels based on their names.
412
  - channels highlighted in red are those that do not match any template channels.
413
  """
 
682
  # +========================================================================================+
683
  map_btn.click(
684
  fn = reset_all,
685
+ inputs = in_loc_file,
686
  outputs = [app_info_json, channel_info_json, map_btn, desc_md, next_btn, tpl_img, mapped_img,
687
  radio_group, clear_btn, step2_btn, in_fillmode, fillmode_btn, chkbox_group, step3_btn,
688
+ out_json_file, res_md, in_data_file, in_samplerate, run_btn, batch_md, out_data_file]
689
  ).success(
690
  fn = init_next_step,
691
  inputs = [app_info_json, channel_info_json, in_fillmode, radio_group, chkbox_group],
 
833
  # +========================================================================================+
834
  # | Stage2: decode data |
835
  # +========================================================================================+
836
+ def reset_run(app_info, in_data, samplerate, modelname):
837
  stage1_info = app_info["stage1"]
838
  stage2_info = app_info["stage2"]
839
 
 
852
  "fileNames" : {
853
  "input_data" : in_data,
854
  "output_data" : new_filepath + new_filename
855
+ },
856
+ "sampleRate" : int(samplerate)
857
  })
858
  app_info["stage2"] = stage2_info
859
  return {app_info_json : app_info,
 
866
  stage2_info = app_info["stage2"]
867
 
868
  filepath = stage2_info["filePath"]
869
+ samplerate = stage2_info["sampleRate"]
870
  filename = stage2_info["fileNames"]["input_data"]
871
  new_filename = stage2_info["fileNames"]["output_data"]
872
 
 
918
 
919
  run_btn.click(
920
  fn = reset_run,
921
+ inputs = [app_info_json, in_data_file, in_samplerate, in_modelname],
922
  outputs = [app_info_json, run_btn, batch_md, out_data_file]
923
 
924
  ).success(
 
928
  )
929
 
930
  if __name__ == "__main__":
931
+ demo.launch(server_name="0.0.0.0", server_port=7860)
932
 
app_utils.py CHANGED
@@ -89,7 +89,6 @@ def save_figures(channel_info, tpl_montage, filename1, filename2):
89
  tpl_dict = channel_info["templateDict"]
90
  in_dict = channel_info["inputDict"]
91
 
92
- # get the 2D coordinates
93
  tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
94
  tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
95
  in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
@@ -217,7 +216,6 @@ def find_neighbors(channel_info, missing_channels, new_idx):
217
  tpl_dict = channel_info["templateDict"]
218
  in_dict = channel_info["inputDict"]
219
 
220
- # get the 3D coordinates
221
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
222
  empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
223
 
@@ -290,7 +288,6 @@ def optimal_mapping(channel_info):
290
  for channel in tpl_dict:
291
  tpl_dict[channel]["matched"] = False
292
 
293
- # get the 3D coordinates
294
  all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
295
  unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
296
 
@@ -338,11 +335,12 @@ def optimal_mapping(channel_info):
338
  return mapping_data, channel_info
339
 
340
  def mapping_result(stage1_info, stage2_info, channel_info, filename):
341
- # calculate how many times the model needs to be run
342
  unassigned_num = len(stage1_info["unassignedInputs"])
343
  batch_num = math.ceil(unassigned_num/30) + 1
344
 
345
  # map the remaining in_channels
 
 
346
  for i in range(1, batch_num):
347
  # optimally select 30 in_channels to map to the tpl_channels based on proximity
348
  new_mapping_data, channel_info = optimal_mapping(channel_info)
 
89
  tpl_dict = channel_info["templateDict"]
90
  in_dict = channel_info["inputDict"]
91
 
 
92
  tpl_x = [tpl_dict[channel]["coord_2d"][0] for channel in tpl_order]
93
  tpl_y = [tpl_dict[channel]["coord_2d"][1] for channel in tpl_order]
94
  in_x = [in_dict[channel]["coord_2d"][0] for channel in in_order]
 
216
  tpl_dict = channel_info["templateDict"]
217
  in_dict = channel_info["inputDict"]
218
 
 
219
  all_in = [np.array(in_dict[channel]["coord_3d"]) for channel in in_order]
220
  empty_tpl = [np.array(tpl_dict[channel]["coord_3d"]) for channel in missing_channels]
221
 
 
288
  for channel in tpl_dict:
289
  tpl_dict[channel]["matched"] = False
290
 
 
291
  all_tpl = np.array([tpl_dict[channel]["coord_3d"] for channel in tpl_order])
292
  unassigned_in = np.array([in_dict[channel]["coord_3d"] for channel in unassigned])
293
 
 
335
  return mapping_data, channel_info
336
 
337
  def mapping_result(stage1_info, stage2_info, channel_info, filename):
 
338
  unassigned_num = len(stage1_info["unassignedInputs"])
339
  batch_num = math.ceil(unassigned_num/30) + 1
340
 
341
  # map the remaining in_channels
342
+ unassigned_num = len(stage1_info["unassignedInputs"])
343
+ batch_num = math.ceil(unassigned_num/30) + 1
344
  for i in range(1, batch_num):
345
  # optimally select 30 in_channels to map to the tpl_channels based on proximity
346
  new_mapping_data, channel_info = optimal_mapping(channel_info)