audrey06100 commited on
Commit
9fb4a28
·
1 Parent(s): db94983
Files changed (2) hide show
  1. app.py +20 -37
  2. app_utils.py +5 -9
app.py CHANGED
@@ -230,13 +230,13 @@ with gr.Blocks() as demo:
230
  desc_md = gr.Markdown(visible=False)
231
  out_json_file = gr.File(visible=False)
232
  # --------------------mapping--------------------
233
- # step1 : initial matching and scaling
234
  with gr.Row():
235
  tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
236
  mapped_img = gr.Image(label="Matching results", visible=False)
237
- # step2 : forward unmatched input channels to empty template channels
238
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
239
- # step3 : fill the remaining template channels
240
  with gr.Row():
241
  in_fillmode = gr.Dropdown(choices=["mean", "zero"],
242
  value="mean",
@@ -303,7 +303,6 @@ with gr.Blocks() as demo:
303
  stage1_id = uuid.uuid4().hex
304
  os.mkdir(rootpath+'/'+stage1_id+'/')
305
 
306
- # initialize stage1_info, stage2_info, channel_info
307
  stage1_info = {
308
  "id" : stage1_id,
309
  "filePath" : rootpath+'/'+stage1_id+'/',
@@ -322,7 +321,7 @@ with gr.Blocks() as demo:
322
  "mappingData" : [
323
  {
324
  "newOrder" : None,
325
- "fillFlags" : None,
326
  #"channelUsageNum" : None
327
  }
328
  ]
@@ -401,7 +400,6 @@ with gr.Blocks() as demo:
401
  matched_num = 30 - len(stage1_info["missingTemplates"])
402
 
403
  # step1 to step4
404
- # the in_channels has all the 30 tpl_channels (in_num>=30)
405
  if matched_num == 30:
406
  md = """
407
  ### Mapping Results
@@ -421,9 +419,8 @@ with gr.Blocks() as demo:
421
  next_btn : gr.Button(visible=False),
422
  out_json_file : gr.File(filename, visible=True)}
423
 
424
- # step1 to step2
425
- # matched_num < 30, and there're still some unmatched in_channels
426
- elif in_num > matched_num:
427
  md = """
428
  ### Step2: Forwarding Unmatched Channels
429
  Select one of your unmatched channels to forward its data to the empty template channel
@@ -456,8 +453,7 @@ with gr.Blocks() as demo:
456
  step2_btn : gr.Button(visible=True),
457
  next_btn : gr.Button(visible=False)}
458
 
459
- # step1 to step3-1
460
- # in_num < 30, but all of them can match to some tpl_channels
461
  elif in_num == matched_num:
462
  md = """
463
  ### Step3: Filling Remaining Template Channels
@@ -489,22 +485,20 @@ with gr.Blocks() as demo:
489
  #print(prev_target_name, '<-', selected_radio)
490
 
491
  # -----------------------update information for the next step-----------------------
492
- # update the list of unassignedInputs to exclude the selected in_channel of the previous round
493
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
494
  channel_info["inputDict"])
495
- # update the list of missingTemplates to exclude those filled in step2
496
  stage1_info["missingTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
497
  channel_info["templateDict"])
498
  # -----------------------------determine the next step------------------------------
499
  # step2 to step4
500
- # all the unmatched tpl_channels were filled
501
  if len(stage1_info["missingTemplates"]) == 0:
502
  md = """
503
  ### Mapping Results
504
  The mapping process has been finished.
505
  Download the file below if you plan to run the models using the <a href="">source code</a>.
506
  """
507
- # finalize and save the mapping results
508
  filename = stage1_info["fileNames"]["output_json"]
509
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
510
 
@@ -542,7 +536,6 @@ with gr.Blocks() as demo:
542
  The mapping process has been finished.
543
  Download the file below if you plan to run the models using the <a href="">source code</a>.
544
  """
545
- # finalize and save the mapping results
546
  filename = stage1_info["fileNames"]["output_json"]
547
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
548
 
@@ -612,7 +605,6 @@ with gr.Blocks() as demo:
612
  The mapping process has been finished.
613
  Download the file below if you plan to run the models using the <a href="">source code</a>.
614
  """
615
- # finalize and save the mapping results
616
  filename = stage1_info["fileNames"]["output_json"]
617
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
618
 
@@ -663,7 +655,7 @@ with gr.Blocks() as demo:
663
  return {step2_btn : gr.Button(visible=False),
664
  next_btn : gr.Button(visible=True)}
665
  else:
666
- return {step2_btn : gr.Button()} # change nothing
667
  # clear the selected value and reset the buttons
668
  @clear_btn.click(inputs = stage1_json, outputs = [radio_group, step2_btn, next_btn])
669
  def clear_value(stage1_info):
@@ -691,7 +683,7 @@ with gr.Blocks() as demo:
691
  # ------------------------update information for the new round--------------------------
692
  stage1_info["fillingCount"] += 1
693
 
694
- # update the list of unassignedInputs to exclude the selected in_channel of the previous round
695
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
696
 
697
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
@@ -789,7 +781,7 @@ with gr.Blocks() as demo:
789
  return gr.Button(interactive=True)
790
  else:
791
  return gr.Button(interactive=False)
792
-
793
  @cancel_btn.click(inputs = stage2_json, outputs = [stage2_json, cancel_btn, batch_md])
794
  def stop_processing(stage2_info):
795
  utils.dataDelete(stage2_info["filePath"])
@@ -828,32 +820,23 @@ with gr.Blocks() as demo:
828
  def run_model(stage1_info, stage2_info, modelname):
829
  batch_num = stage1_info["totalBatchNum"]
830
  mapping_data = stage1_info["mappingData"]
831
-
832
  samplerate = stage2_info["sampleRate"]
833
  filepath = stage2_info["filePath"]
834
  filename = stage2_info["fileNames"]["input_data"]
835
  new_filename = stage2_info["fileNames"]["output_data"]
836
 
837
- break_flag = False # indicate if the process has been interrupted by the user
838
  for i in range(batch_num):
839
- md = 'Running model({}/{})...'.format(i+1, batch_num)
840
- yield {batch_md : gr.Markdown(md)}
841
 
842
- # establish a temp folder
843
- try:
844
- os.mkdir(filepath+'temp_data/')
845
- except FileNotFoundError:
846
- print('break1')
847
- break_flag = True
848
- break
849
-
850
- # get the mapped index order and the filled status for each tpl_channels
851
  new_idx = mapping_data[i]["newOrder"]
852
  fill_flags = mapping_data[i]["fillFlags"]
853
- # ----------------------------------------------------------------------------------
854
  m_filename = 'mapped_{:02d}.csv'.format(i+1)
855
  d_filename = 'denoised_{:02d}.csv'.format(i+1)
856
  try:
 
 
 
857
  # step1: Reorder input data
858
  data_shape = app_utils.reorder_data(new_idx, fill_flags, filename, filepath+'temp_data/'+m_filename)
859
  # step2: Data preprocessing
@@ -863,11 +846,11 @@ with gr.Blocks() as demo:
863
  # step4: Restore original order
864
  app_utils.restore_order(i, data_shape, new_idx, fill_flags, filepath+'temp_data/'+d_filename, new_filename)
865
  except FileNotFoundError:
866
- print('break2')
867
  break_flag = True
868
  break
869
- # ----------------------------------------------------------------------------------
870
- utils.dataDelete(filepath+'temp_data/')
871
 
872
  if break_flag == True:
873
  yield {run_btn : gr.Button(visible=True),
 
230
  desc_md = gr.Markdown(visible=False)
231
  out_json_file = gr.File(visible=False)
232
  # --------------------mapping--------------------
233
+ # step1
234
  with gr.Row():
235
  tpl_img = gr.Image("./template_montage.png", label="Template montage", visible=False)
236
  mapped_img = gr.Image(label="Matching results", visible=False)
237
+ # step2
238
  radio_group = gr.Radio(elem_id="radio-group", visible=False)
239
+ # step3
240
  with gr.Row():
241
  in_fillmode = gr.Dropdown(choices=["mean", "zero"],
242
  value="mean",
 
303
  stage1_id = uuid.uuid4().hex
304
  os.mkdir(rootpath+'/'+stage1_id+'/')
305
 
 
306
  stage1_info = {
307
  "id" : stage1_id,
308
  "filePath" : rootpath+'/'+stage1_id+'/',
 
321
  "mappingData" : [
322
  {
323
  "newOrder" : None,
324
+ "fillFlags" : None
325
  #"channelUsageNum" : None
326
  }
327
  ]
 
400
  matched_num = 30 - len(stage1_info["missingTemplates"])
401
 
402
  # step1 to step4
 
403
  if matched_num == 30:
404
  md = """
405
  ### Mapping Results
 
419
  next_btn : gr.Button(visible=False),
420
  out_json_file : gr.File(filename, visible=True)}
421
 
422
+ # step1 to step2 (matched_num<30, and there're still some unmatched in_channels)
423
+ elif in_num > matched_num:
 
424
  md = """
425
  ### Step2: Forwarding Unmatched Channels
426
  Select one of your unmatched channels to forward its data to the empty template channel
 
453
  step2_btn : gr.Button(visible=True),
454
  next_btn : gr.Button(visible=False)}
455
 
456
+ # step1 to step3-1 (in_num<30, but all of them can match to some tpl_channels)
 
457
  elif in_num == matched_num:
458
  md = """
459
  ### Step3: Filling Remaining Template Channels
 
485
  #print(prev_target_name, '<-', selected_radio)
486
 
487
  # -----------------------update information for the next step-----------------------
488
+ # exclude the selected in_channel of the previous round
489
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"],
490
  channel_info["inputDict"])
491
+ # exclude the tpl_channels filled in step2
492
  stage1_info["missingTemplates"] = app_utils.get_empty_templates(channel_info["templateOrder"],
493
  channel_info["templateDict"])
494
  # -----------------------------determine the next step------------------------------
495
  # step2 to step4
 
496
  if len(stage1_info["missingTemplates"]) == 0:
497
  md = """
498
  ### Mapping Results
499
  The mapping process has been finished.
500
  Download the file below if you plan to run the models using the <a href="">source code</a>.
501
  """
 
502
  filename = stage1_info["fileNames"]["output_json"]
503
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
504
 
 
536
  The mapping process has been finished.
537
  Download the file below if you plan to run the models using the <a href="">source code</a>.
538
  """
 
539
  filename = stage1_info["fileNames"]["output_json"]
540
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
541
 
 
605
  The mapping process has been finished.
606
  Download the file below if you plan to run the models using the <a href="">source code</a>.
607
  """
 
608
  filename = stage1_info["fileNames"]["output_json"]
609
  stage1_info, channel_info = app_utils.mapping_result(stage1_info, channel_info, filename)
610
 
 
655
  return {step2_btn : gr.Button(visible=False),
656
  next_btn : gr.Button(visible=True)}
657
  else:
658
+ return {step2_btn : gr.Button()}
659
  # clear the selected value and reset the buttons
660
  @clear_btn.click(inputs = stage1_json, outputs = [radio_group, step2_btn, next_btn])
661
  def clear_value(stage1_info):
 
683
  # ------------------------update information for the new round--------------------------
684
  stage1_info["fillingCount"] += 1
685
 
686
+ # exclude the selected in_channel of the previous round
687
  stage1_info["unassignedInputs"] = app_utils.get_unassigned_inputs(channel_info["inputOrder"], channel_info["inputDict"])
688
 
689
  target_name = stage1_info["missingTemplates"][stage1_info["fillingCount"]-1]
 
781
  return gr.Button(interactive=True)
782
  else:
783
  return gr.Button(interactive=False)
784
+ # interrupt Stage2
785
  @cancel_btn.click(inputs = stage2_json, outputs = [stage2_json, cancel_btn, batch_md])
786
  def stop_processing(stage2_info):
787
  utils.dataDelete(stage2_info["filePath"])
 
820
  def run_model(stage1_info, stage2_info, modelname):
821
  batch_num = stage1_info["totalBatchNum"]
822
  mapping_data = stage1_info["mappingData"]
 
823
  samplerate = stage2_info["sampleRate"]
824
  filepath = stage2_info["filePath"]
825
  filename = stage2_info["fileNames"]["input_data"]
826
  new_filename = stage2_info["fileNames"]["output_data"]
827
 
828
+ break_flag = False # indicate if the process has been interrupted
829
  for i in range(batch_num):
830
+ yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, batch_num))}
 
831
 
 
 
 
 
 
 
 
 
 
832
  new_idx = mapping_data[i]["newOrder"]
833
  fill_flags = mapping_data[i]["fillFlags"]
 
834
  m_filename = 'mapped_{:02d}.csv'.format(i+1)
835
  d_filename = 'denoised_{:02d}.csv'.format(i+1)
836
  try:
837
+ # establish a temp folder
838
+ os.mkdir(filepath+'temp_data/')
839
+
840
  # step1: Reorder input data
841
  data_shape = app_utils.reorder_data(new_idx, fill_flags, filename, filepath+'temp_data/'+m_filename)
842
  # step2: Data preprocessing
 
846
  # step4: Restore original order
847
  app_utils.restore_order(i, data_shape, new_idx, fill_flags, filepath+'temp_data/'+d_filename, new_filename)
848
  except FileNotFoundError:
849
+ print('break!!')
850
  break_flag = True
851
  break
852
+ else:
853
+ utils.dataDelete(filepath+'temp_data/')
854
 
855
  if break_flag == True:
856
  yield {run_btn : gr.Button(visible=True),
app_utils.py CHANGED
@@ -39,8 +39,7 @@ def restore_order(batch_cnt, raw_data_shape, idx_order, fill_flags, filename, ne
39
  new_data = utils.read_train_data(new_filename)
40
 
41
  for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
- # ignore if this channel was filled using "fillmode"
43
- if flag == False:
44
  new_data[idx_set[0], :] = d_data[i, :]
45
 
46
  utils.save_data(new_data, new_filename)
@@ -63,7 +62,7 @@ def read_montage_data(loc_file):
63
  tpl_dict = {}
64
  in_dict = {}
65
 
66
- # convert all channel names to uppercase and store the channel information
67
  for i, channel in enumerate(tpl_order):
68
  up_channel = str.upper(channel)
69
  tpl_montage.rename_channels({channel: up_channel})
@@ -163,14 +162,14 @@ def align_coords(channel_info, tpl_montage, in_montage):
163
  fig = [tpl_montage.plot(), in_montage.plot()]
164
  ax = [fig[0].axes[0], fig[1].axes[0]]
165
 
166
- # extract the displayed 2D coordinates from the plots
167
  all_tpl = ax[0].collections[0].get_offsets().data
168
  all_in= ax[1].collections[0].get_offsets().data
169
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
170
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
171
  plt.close('all')
172
 
173
- # apply TPS to transform in_channels positions to align with tpl_channels positions
174
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
175
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
176
 
@@ -179,7 +178,6 @@ def align_coords(channel_info, tpl_montage, in_montage):
179
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
180
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
181
 
182
- # store the 2D positions
183
  for i, channel in enumerate(tpl_order):
184
  tpl_dict[channel]["coord_2d"] = all_tpl[i]
185
  for i, channel in enumerate(in_order):
@@ -201,7 +199,6 @@ def align_coords(channel_info, tpl_montage, in_montage):
201
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
202
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
203
 
204
- # update in_channels' 3D positions
205
  for i, channel in enumerate(in_order):
206
  in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
207
 
@@ -296,7 +293,7 @@ def optimal_mapping(channel_info):
296
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
297
  else:
298
  cost_matrix = np.zeros((30, len(unassigned)))
299
- # fill the cost matrix with Euclidean distances between tpl_channels and unassigned in_channels
300
  for i in range(30):
301
  for j in range(len(unassigned)):
302
  cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
@@ -345,7 +342,6 @@ def mapping_result(stage1_info, channel_info, filename):
345
  new_data, channel_info = optimal_mapping(channel_info)
346
  all_mapping_data += [new_data]
347
 
348
- # save the mapping results
349
  new_dict = {
350
  #"templateOrder" : channel_info["templateOrder"],
351
  #"inputOrder" : channel_info["inputOrder"],
 
39
  new_data = utils.read_train_data(new_filename)
40
 
41
  for i, (idx_set, flag) in enumerate(zip(idx_order, fill_flags)):
42
+ if flag == False: # ignore if this channel was filled using "fillmode"
 
43
  new_data[idx_set[0], :] = d_data[i, :]
44
 
45
  utils.save_data(new_data, new_filename)
 
62
  tpl_dict = {}
63
  in_dict = {}
64
 
65
+ # convert all channel names to uppercase and store their information
66
  for i, channel in enumerate(tpl_order):
67
  up_channel = str.upper(channel)
68
  tpl_montage.rename_channels({channel: up_channel})
 
162
  fig = [tpl_montage.plot(), in_montage.plot()]
163
  ax = [fig[0].axes[0], fig[1].axes[0]]
164
 
165
+ # extract the displayed 2D coordinates
166
  all_tpl = ax[0].collections[0].get_offsets().data
167
  all_in= ax[1].collections[0].get_offsets().data
168
  matched_tpl = np.array([all_tpl[tpl_dict[channel]["index"]] for channel in matched])
169
  matched_in = np.array([all_in[in_dict[channel]["index"]] for channel in matched])
170
  plt.close('all')
171
 
172
+ # apply TPS to transform in_channels to align with tpl_channels positions
173
  rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,0], function='thin_plate')
174
  rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_tpl[:,1], function='thin_plate')
175
 
 
178
  transformed_in_y = rbf_y(all_in[:,0], all_in[:,1])
179
  transformed_in = np.vstack((transformed_in_x, transformed_in_y)).T
180
 
 
181
  for i, channel in enumerate(tpl_order):
182
  tpl_dict[channel]["coord_2d"] = all_tpl[i]
183
  for i, channel in enumerate(in_order):
 
199
  transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
200
  transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
201
 
 
202
  for i, channel in enumerate(in_order):
203
  in_dict[channel]["coord_3d"] = transformed_in[i].tolist()
204
 
 
293
  cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col >= num_row
294
  else:
295
  cost_matrix = np.zeros((30, len(unassigned)))
296
+ # fill the cost matrix with Euclidean distances between tpl and unassigned in_channels
297
  for i in range(30):
298
  for j in range(len(unassigned)):
299
  cost_matrix[i][j] = np.linalg.norm((all_tpl[i]-unassigned_in[j])*1000)
 
342
  new_data, channel_info = optimal_mapping(channel_info)
343
  all_mapping_data += [new_data]
344
 
 
345
  new_dict = {
346
  #"templateOrder" : channel_info["templateOrder"],
347
  #"inputOrder" : channel_info["inputOrder"],