audrey06100 commited on
Commit
c1cd6f3
·
1 Parent(s): d92d76f
Files changed (2) hide show
  1. app.py +10 -10
  2. app_utils.py +17 -17
app.py CHANGED
@@ -331,7 +331,7 @@ with gr.Blocks() as demo:
331
  "batchNum" : None,
332
  "mappingResult" : [
333
  {
334
- "newOrder" : None,
335
  "isOriginalData" : None
336
  #"channelUsageNum" : None
337
  }
@@ -488,7 +488,7 @@ with gr.Blocks() as demo:
488
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
489
  sel_idx = channel_info["inputDict"][sel_radio]["index"]
490
 
491
- stage1_info["mappingResult"][0]["newOrder"][prev_tpl_idx] = [sel_idx]
492
  stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
493
  channel_info["templateDict"][prev_tpl_name]["matched"] = True
494
  channel_info["inputDict"][sel_radio]["assigned"] = True
@@ -565,10 +565,10 @@ with gr.Blocks() as demo:
565
  value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
566
  """
567
  # find the 4 nearest in_channels for each unmatched tpl_channels
568
- stage1_info["mappingResult"][0]["newOrder"] = app_utils.find_neighbors(
569
  channel_info,
570
  stage1_info["emptyTemplates"],
571
- stage1_info["mappingResult"][0]["newOrder"])
572
  # initialize the progress indication label
573
  stage1_info["step3"] = {
574
  "count" : 1,
@@ -578,7 +578,7 @@ with gr.Blocks() as demo:
578
  label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
579
 
580
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
581
- value = stage1_info["mappingResult"][0]["newOrder"][tpl_idx]
582
  value = [channel_info["inputNames"][i] for i in value]
583
 
584
  stage1_info["state"] = "step3-2-selecting"
@@ -607,7 +607,7 @@ with gr.Blocks() as demo:
607
  prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
608
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
609
  sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
610
- stage1_info["mappingResult"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
611
  #print(prev_tpl_name, '<-', sel_chkbox)
612
  # ----------------------------------------------------------------------------------
613
  md = """
@@ -686,7 +686,7 @@ with gr.Blocks() as demo:
686
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
687
  sel_idx = channel_info["inputDict"][sel_name]["index"]
688
 
689
- stage1_info["mappingResult"][0]["newOrder"][prev_tpl_idx] = [sel_idx]
690
  stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
691
  channel_info["templateDict"][prev_tpl_name]["matched"] = True
692
  channel_info["inputDict"][sel_name]["assigned"] = True
@@ -736,7 +736,7 @@ with gr.Blocks() as demo:
736
  prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
737
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
738
  sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
739
- stage1_info["mappingResult"][0]["newOrder"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
740
  #print(prev_tpl_name, '<-', sel_name)
741
 
742
  # ---------------------------------update the new round---------------------------------
@@ -746,7 +746,7 @@ with gr.Blocks() as demo:
746
  label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
747
 
748
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
749
- value = stage1_info["mappingResult"][0]["newOrder"][tpl_idx]
750
  value = [channel_info["inputNames"][i] for i in value]
751
 
752
  stage1_info["step3"] = step3
@@ -849,7 +849,7 @@ with gr.Blocks() as demo:
849
  outputname = outputname,
850
  samplerate = int(samplerate),
851
  batch_cnt = i,
852
- new_idx = stage1_info["mappingResult"][i]["newOrder"],
853
  orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
854
  except FileNotFoundError:
855
  print('stop!!')
 
331
  "batchNum" : None,
332
  "mappingResult" : [
333
  {
334
+ "indices" : None,
335
  "isOriginalData" : None
336
  #"channelUsageNum" : None
337
  }
 
488
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
489
  sel_idx = channel_info["inputDict"][sel_radio]["index"]
490
 
491
+ stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = [sel_idx]
492
  stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
493
  channel_info["templateDict"][prev_tpl_name]["matched"] = True
494
  channel_info["inputDict"][sel_radio]["assigned"] = True
 
565
  value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
566
  """
567
  # find the 4 nearest in_channels for each unmatched tpl_channels
568
+ stage1_info["mappingResult"][0]["indices"] = app_utils.find_neighbors(
569
  channel_info,
570
  stage1_info["emptyTemplates"],
571
+ stage1_info["mappingResult"][0]["indices"])
572
  # initialize the progress indication label
573
  stage1_info["step3"] = {
574
  "count" : 1,
 
578
  label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
579
 
580
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
581
+ value = stage1_info["mappingResult"][0]["indices"][tpl_idx]
582
  value = [channel_info["inputNames"][i] for i in value]
583
 
584
  stage1_info["state"] = "step3-2-selecting"
 
607
  prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
608
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
609
  sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
610
+ stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
611
  #print(prev_tpl_name, '<-', sel_chkbox)
612
  # ----------------------------------------------------------------------------------
613
  md = """
 
686
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
687
  sel_idx = channel_info["inputDict"][sel_name]["index"]
688
 
689
+ stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = [sel_idx]
690
  stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
691
  channel_info["templateDict"][prev_tpl_name]["matched"] = True
692
  channel_info["inputDict"][sel_name]["assigned"] = True
 
736
  prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
737
  prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
738
  sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
739
+ stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
740
  #print(prev_tpl_name, '<-', sel_name)
741
 
742
  # ---------------------------------update the new round---------------------------------
 
746
  label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
747
 
748
  tpl_idx = channel_info["templateDict"][tpl_name]["index"]
749
+ value = stage1_info["mappingResult"][0]["indices"][tpl_idx]
750
  value = [channel_info["inputNames"][i] for i in value]
751
 
752
  stage1_info["step3"] = step3
 
849
  outputname = outputname,
850
  samplerate = int(samplerate),
851
  batch_cnt = i,
852
+ old_idx = stage1_info["mappingResult"][i]["indices"],
853
  orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
854
  except FileNotFoundError:
855
  print('stop!!')
app_utils.py CHANGED
@@ -174,7 +174,7 @@ def align_coords(channel_info, tpl_montage, in_montage):
174
  })
175
  return channel_info
176
 
177
- def find_neighbors(channel_info, empty_tpl_names, new_idx):
178
  in_names = channel_info["inputNames"]
179
  tpl_dict = channel_info["templateDict"]
180
  in_dict = channel_info["inputDict"]
@@ -189,9 +189,9 @@ def find_neighbors(channel_info, empty_tpl_names, new_idx):
189
  for i, name in enumerate(empty_tpl_names):
190
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
191
  idx = tpl_dict[name]["index"]
192
- new_idx[idx] = indices[0].tolist()
193
 
194
- return new_idx
195
 
196
  def match_names(stage1_info):
197
  # read the location file
@@ -199,7 +199,7 @@ def match_names(stage1_info):
199
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
200
  tpl_names = tpl_montage.ch_names
201
  in_names = in_montage.ch_names
202
- new_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
203
  orig_flags = [False]*30
204
 
205
  alias_dict = {
@@ -215,7 +215,7 @@ def match_names(stage1_info):
215
  name = alias_dict[name]
216
 
217
  if name in in_dict:
218
- new_idx[i] = [in_dict[name]["index"]]
219
  orig_flags[i] = True
220
  tpl_dict[name]["matched"] = True
221
  in_dict[name]["assigned"] = True
@@ -228,7 +228,7 @@ def match_names(stage1_info):
228
  "emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
229
  "mappingResult" : [
230
  {
231
- "newOrder" : new_idx,
232
  "isOriginalData" : orig_flags
233
  }
234
  ]
@@ -269,14 +269,14 @@ def optimal_mapping(channel_info):
269
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
270
 
271
  # store the mapping result
272
- new_idx = [[None]]*30
273
  orig_flags = [False]*30
274
  for i, j in zip(row_idx, col_idx):
275
  if j < len(unass_in_names): # filter out dummy channels
276
  tpl_name = tpl_names[i]
277
  in_name = unass_in_names[j]
278
 
279
- new_idx[i] = [in_dict[in_name]["index"]]
280
  orig_flags[i] = True
281
  tpl_dict[tpl_name]["matched"] = True
282
  in_dict[in_name]["assigned"] = True
@@ -285,10 +285,10 @@ def optimal_mapping(channel_info):
285
  # fill the remaining empty tpl_channels
286
  empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
287
  if empty_tpl_names != []:
288
- new_idx = find_neighbors(channel_info, empty_tpl_names, new_idx)
289
 
290
  result = {
291
- "newOrder" : new_idx,
292
  "isOriginalData" : orig_flags
293
  }
294
  channel_info.update({
@@ -327,14 +327,14 @@ def mapping_result(stage1_info, channel_info, filename):
327
  return stage1_info, channel_info
328
 
329
 
330
- def reorder_data(idx_order, orig_flags, inputname, filename):
331
  # read the input data
332
  raw_data = utils.read_train_data(inputname)
333
  #print(raw_data.shape)
334
  new_data = np.zeros((30, raw_data.shape[1]))
335
 
336
  zero_arr = np.zeros((1, raw_data.shape[1]))
337
- for i, (indices, flag) in enumerate(zip(idx_order, orig_flags)):
338
  if flag == True:
339
  new_data[i, :] = raw_data[indices[0], :]
340
  elif indices == [None]:
@@ -346,7 +346,7 @@ def reorder_data(idx_order, orig_flags, inputname, filename):
346
  utils.save_data(new_data, filename)
347
  return raw_data.shape
348
 
349
- def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, outputname):
350
  # read the denoised data
351
  d_data = utils.read_train_data(filename)
352
  if batch_cnt == 0:
@@ -355,25 +355,25 @@ def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, ou
355
  else:
356
  new_data = utils.read_train_data(outputname)
357
 
358
- for i, (indices, flag) in enumerate(zip(idx_order, orig_flags)):
359
  if flag == True:
360
  new_data[indices[0], :] = d_data[i, :]
361
 
362
  utils.save_data(new_data, outputname)
363
  return
364
 
365
- def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt, new_idx, orig_flags):
366
  # establish temp folder
367
  os.mkdir(filepath+'temp_data/')
368
 
369
  # step1: Reorder input data
370
- data_shape = reorder_data(new_idx, orig_flags, inputname, filepath+'temp_data/'+m_filename)
371
  # step2: Data preprocessing
372
  total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
373
  # step3: Signal reconstruction
374
  utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
375
  # step4: Restore original order
376
- restore_order(batch_cnt, data_shape, new_idx, orig_flags, filepath+'temp_data/'+d_filename, outputname)
377
 
378
  utils.dataDelete(filepath+'temp_data/')
379
 
 
174
  })
175
  return channel_info
176
 
177
+ def find_neighbors(channel_info, empty_tpl_names, old_idx):
178
  in_names = channel_info["inputNames"]
179
  tpl_dict = channel_info["templateDict"]
180
  in_dict = channel_info["inputDict"]
 
189
  for i, name in enumerate(empty_tpl_names):
190
  distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
191
  idx = tpl_dict[name]["index"]
192
+ old_idx[idx] = indices[0].tolist()
193
 
194
+ return old_idx
195
 
196
  def match_names(stage1_info):
197
  # read the location file
 
199
  tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
200
  tpl_names = tpl_montage.ch_names
201
  in_names = in_montage.ch_names
202
+ old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
203
  orig_flags = [False]*30
204
 
205
  alias_dict = {
 
215
  name = alias_dict[name]
216
 
217
  if name in in_dict:
218
+ old_idx[i] = [in_dict[name]["index"]]
219
  orig_flags[i] = True
220
  tpl_dict[name]["matched"] = True
221
  in_dict[name]["assigned"] = True
 
228
  "emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
229
  "mappingResult" : [
230
  {
231
+ "indices" : old_idx,
232
  "isOriginalData" : orig_flags
233
  }
234
  ]
 
269
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
270
 
271
  # store the mapping result
272
+ old_idx = [[None]]*30
273
  orig_flags = [False]*30
274
  for i, j in zip(row_idx, col_idx):
275
  if j < len(unass_in_names): # filter out dummy channels
276
  tpl_name = tpl_names[i]
277
  in_name = unass_in_names[j]
278
 
279
+ old_idx[i] = [in_dict[in_name]["index"]]
280
  orig_flags[i] = True
281
  tpl_dict[tpl_name]["matched"] = True
282
  in_dict[in_name]["assigned"] = True
 
285
  # fill the remaining empty tpl_channels
286
  empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
287
  if empty_tpl_names != []:
288
+ old_idx = find_neighbors(channel_info, empty_tpl_names, old_idx)
289
 
290
  result = {
291
+ "indices" : old_idx,
292
  "isOriginalData" : orig_flags
293
  }
294
  channel_info.update({
 
327
  return stage1_info, channel_info
328
 
329
 
330
+ def reorder_data(old_idx, orig_flags, inputname, filename):
331
  # read the input data
332
  raw_data = utils.read_train_data(inputname)
333
  #print(raw_data.shape)
334
  new_data = np.zeros((30, raw_data.shape[1]))
335
 
336
  zero_arr = np.zeros((1, raw_data.shape[1]))
337
+ for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
338
  if flag == True:
339
  new_data[i, :] = raw_data[indices[0], :]
340
  elif indices == [None]:
 
346
  utils.save_data(new_data, filename)
347
  return raw_data.shape
348
 
349
+ def restore_order(batch_cnt, raw_data_shape, old_idx, orig_flags, filename, outputname):
350
  # read the denoised data
351
  d_data = utils.read_train_data(filename)
352
  if batch_cnt == 0:
 
355
  else:
356
  new_data = utils.read_train_data(outputname)
357
 
358
+ for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
359
  if flag == True:
360
  new_data[indices[0], :] = d_data[i, :]
361
 
362
  utils.save_data(new_data, outputname)
363
  return
364
 
365
+ def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt, old_idx, orig_flags):
366
  # establish temp folder
367
  os.mkdir(filepath+'temp_data/')
368
 
369
  # step1: Reorder input data
370
+ data_shape = reorder_data(old_idx, orig_flags, inputname, filepath+'temp_data/'+m_filename)
371
  # step2: Data preprocessing
372
  total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
373
  # step3: Signal reconstruction
374
  utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
375
  # step4: Restore original order
376
+ restore_order(batch_cnt, data_shape, old_idx, orig_flags, filepath+'temp_data/'+d_filename, outputname)
377
 
378
  utils.dataDelete(filepath+'temp_data/')
379