audrey06100 commited on
Commit
995c1d0
·
1 Parent(s): 884c10b
Files changed (2) hide show
  1. app.py +72 -32
  2. channel_mapping.py +21 -12
app.py CHANGED
@@ -55,7 +55,7 @@ chkbox_js = """
55
  position: relative;
56
  width: 560px;
57
  height: 560px;
58
- background: url("file=${app_state.files.raw_montage}");
59
  `;
60
 
61
 
@@ -191,7 +191,6 @@ with gr.Blocks() as demo:
191
  label="Imputation")
192
  map_btn = gr.Button("Mapping")
193
 
194
- #indic
195
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
196
  next_btn = gr.Button("Next", interactive=False, visible=False)
197
 
@@ -228,7 +227,7 @@ with gr.Blocks() as demo:
228
  scale=2)
229
  run_btn = gr.Button(scale=1, interactive=False)
230
  batch_md = gr.Markdown(visible=False)
231
- out_denoised_data = gr.File(label="Denoised data")
232
 
233
 
234
  with gr.Row():
@@ -245,7 +244,7 @@ with gr.Blocks() as demo:
245
 
246
  #demo.load(js=js)
247
 
248
- def reset_layout(raw_data, samplerate):
249
  # establish temp folder
250
  filepath = os.path.dirname(str(raw_data))
251
  try:
@@ -259,7 +258,7 @@ with gr.Blocks() as demo:
259
  data = utils.read_train_data(raw_data)
260
  app_state = {
261
  "filepath": filepath+"/temp_data/",
262
- "files": {},
263
  "sampleRate": int(samplerate),
264
 
265
  }
@@ -275,7 +274,8 @@ with gr.Blocks() as demo:
275
  tpl_montage : gr.Image(visible=False),
276
  map_montage : gr.Image(value=None, visible=False),
277
  res_md : gr.Markdown(visible=False),
278
- batch_md : gr.Markdown(visible=False)}
 
279
 
280
  def mapping_result(app_state, channel_info, fill_mode):
281
 
@@ -283,7 +283,6 @@ with gr.Blocks() as demo:
283
  matched_num = 30 - len(channel_info["missingChannelsIndex"])
284
  batch_num = math.ceil((in_num-matched_num)/30) + 1
285
  app_state.update({
286
- "runnigState" : "stage1",
287
  "batchCount" : 1,
288
  "totalBatchNum" : batch_num
289
  })
@@ -295,11 +294,11 @@ with gr.Blocks() as demo:
295
  })
296
  #print("Missing channels:", channel_info["missingChannelsIndex"])
297
  return {app_state_json : app_state,
298
- #chkbox_group : gr.CheckboxGroup(visible=True),
299
  next_btn : gr.Button(visible=True)}
300
  else:
301
- app_state["state"] = "finished"
302
-
 
303
  return {app_state_json : app_state,
304
  res_md : gr.Markdown(visible=True),
305
  run_btn : gr.Button(interactive=True)}
@@ -318,7 +317,7 @@ with gr.Blocks() as demo:
318
 
319
  if app_state["state"] == "initializing":
320
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
321
- app_state["files"]["raw_montage"] = filename
322
  raw_fig = raw_montage.plot()
323
  raw_fig.set_size_inches(5.6, 5.6)
324
  raw_fig.savefig(filename, pad_inches=0)
@@ -327,7 +326,7 @@ with gr.Blocks() as demo:
327
 
328
  elif app_state["state"] == "finished":
329
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
330
- app_state["files"]["map_montage"] = filename
331
 
332
  show_names= []
333
  for channel in channel_info["inputByName"]:
@@ -361,9 +360,10 @@ with gr.Blocks() as demo:
361
 
362
 
363
  map_btn.click(
364
- fn = reset_layout,
365
  inputs = [in_raw_data, in_sample_rate],
366
- outputs = [app_state_json, channel_info_json, chkbox_group, next_btn, run_btn, tpl_montage, map_montage, res_md, batch_md]
 
367
 
368
  ).success(
369
  fn = mapping_stage1,
@@ -401,7 +401,7 @@ with gr.Blocks() as demo:
401
  prev_target_name = channel_info["templateByIndex"][prev_target_idx]
402
 
403
  selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
404
- app_state["newOrder"][prev_target_idx] = selected_idx
405
 
406
  #if len(selected)==1 and channel_info["inputByName"][selected[0]]["used"]==False:
407
  #channel_info["inputByName"][selected[0]]["used"] = True
@@ -450,22 +450,47 @@ with gr.Blocks() as demo:
450
  outputs = []
451
  )
452
 
453
- @run_btn.click(inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
454
- outputs = [batch_md, out_denoised_data])
455
- def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
 
 
 
 
456
  filepath = app_state["filepath"]
457
- samplerate = app_state["sampleRate"]
458
-
459
  input_name = os.path.basename(str(raw_data))
460
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
461
 
462
- while(app_state["runnigState"] != "finished"):
463
- if app_state["batchCount"] > app_state["totalBatchNum"]:
464
- app_state["runnigState"] = "finished"
465
- break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  if app_state["batchCount"] > 1:
467
- app_state["runnigState"] = "stage2"
468
  app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
 
 
469
  app_state["batchCount"] += 1
470
 
471
  reorder_to_template(app_state, raw_data)
@@ -473,15 +498,30 @@ with gr.Blocks() as demo:
473
  total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
474
  # step2: Signal reconstruction
475
  utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
476
- reorder_to_origin(app_state, channel_info, filepath+'denoised.csv', filepath+output_name)
 
 
 
 
 
477
 
478
- if model_name == "(mapped data)":
479
- return {out_denoised_data : filepath + 'mapped.csv'}
480
- elif model_name == "(denoised data)":
481
- return {out_denoised_data : filepath + 'denoised.csv'}
482
 
483
- return {out_denoised_data : filepath + output_name}
 
 
 
 
 
 
 
484
 
 
 
 
 
 
485
 
486
  if __name__ == "__main__":
487
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
55
  position: relative;
56
  width: 560px;
57
  height: 560px;
58
+ background: url("file=${app_state.filenames.raw_montage}");
59
  `;
60
 
61
 
 
191
  label="Imputation")
192
  map_btn = gr.Button("Mapping")
193
 
 
194
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
195
  next_btn = gr.Button("Next", interactive=False, visible=False)
196
 
 
227
  scale=2)
228
  run_btn = gr.Button(scale=1, interactive=False)
229
  batch_md = gr.Markdown(visible=False)
230
+ out_denoised_data = gr.File(label="Denoised data", visible=False)
231
 
232
 
233
  with gr.Row():
 
244
 
245
  #demo.load(js=js)
246
 
247
+ def reset1(raw_data, samplerate):
248
  # establish temp folder
249
  filepath = os.path.dirname(str(raw_data))
250
  try:
 
258
  data = utils.read_train_data(raw_data)
259
  app_state = {
260
  "filepath": filepath+"/temp_data/",
261
+ "filenames": {},
262
  "sampleRate": int(samplerate),
263
 
264
  }
 
274
  tpl_montage : gr.Image(visible=False),
275
  map_montage : gr.Image(value=None, visible=False),
276
  res_md : gr.Markdown(visible=False),
277
+ batch_md : gr.Markdown(visible=False),
278
+ out_denoised_data : gr.File(visible=False)}
279
 
280
  def mapping_result(app_state, channel_info, fill_mode):
281
 
 
283
  matched_num = 30 - len(channel_info["missingChannelsIndex"])
284
  batch_num = math.ceil((in_num-matched_num)/30) + 1
285
  app_state.update({
 
286
  "batchCount" : 1,
287
  "totalBatchNum" : batch_num
288
  })
 
294
  })
295
  #print("Missing channels:", channel_info["missingChannelsIndex"])
296
  return {app_state_json : app_state,
 
297
  next_btn : gr.Button(visible=True)}
298
  else:
299
+ app_state.update({
300
+ "state" : "finished"
301
+ })
302
  return {app_state_json : app_state,
303
  res_md : gr.Markdown(visible=True),
304
  run_btn : gr.Button(interactive=True)}
 
317
 
318
  if app_state["state"] == "initializing":
319
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
320
+ app_state["filenames"]["raw_montage"] = filename
321
  raw_fig = raw_montage.plot()
322
  raw_fig.set_size_inches(5.6, 5.6)
323
  raw_fig.savefig(filename, pad_inches=0)
 
326
 
327
  elif app_state["state"] == "finished":
328
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
329
+ app_state["filenames"]["map_montage"] = filename
330
 
331
  show_names= []
332
  for channel in channel_info["inputByName"]:
 
360
 
361
 
362
  map_btn.click(
363
+ fn = reset1,
364
  inputs = [in_raw_data, in_sample_rate],
365
+ outputs = [app_state_json, channel_info_json, chkbox_group, next_btn, run_btn,
366
+ tpl_montage, map_montage, res_md, batch_md, out_denoised_data]
367
 
368
  ).success(
369
  fn = mapping_stage1,
 
401
  prev_target_name = channel_info["templateByIndex"][prev_target_idx]
402
 
403
  selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
404
+ app_state["stage1NewOrder"][prev_target_idx] = selected_idx
405
 
406
  #if len(selected)==1 and channel_info["inputByName"][selected[0]]["used"]==False:
407
  #channel_info["inputByName"][selected[0]]["used"] = True
 
450
  outputs = []
451
  )
452
 
453
+ def delete_file(filename):
454
+ try:
455
+ os.remove(filename)
456
+ except OSError as e:
457
+ print(e)
458
+
459
+ def reset2(app_state, raw_data, model_name):
460
  filepath = app_state["filepath"]
 
 
461
  input_name = os.path.basename(str(raw_data))
462
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
463
 
464
+ app_state["filenames"]["denoised"] = filepath + output_name
465
+ app_state.update({
466
+ "runnigState" : "stage1",
467
+ "batchCount" : 1,
468
+ "stage2NewOrder" : [[]]*30
469
+ })
470
+
471
+ delete_file(filepath+'mapped.csv')
472
+ delete_file(filepath+'denoised.csv')
473
+ return {app_state_json : app_state,
474
+ run_btn : gr.Button(interactive=False),
475
+ batch_md : gr.Markdown(visible=False),
476
+ out_denoised_data : gr.File(visible=False)}
477
+
478
+ def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
479
+ filepath = app_state["filepath"]
480
+ samplerate = app_state["sampleRate"]
481
+ new_filename = app_state["filenames"]["denoised"]
482
+
483
+ while app_state["runnigState"] != "finished":
484
+ #if app_state["batchCount"] > app_state["totalBatchNum"]:
485
+ #app_state["runnigState"] = "finished"
486
+ #break
487
+ md = 'Running model('+str(app_state["batchCount"])+'/'+str(app_state["totalBatchNum"])+')...'
488
+ yield {batch_md : gr.Markdown(md, visible=True)}
489
+
490
  if app_state["batchCount"] > 1:
 
491
  app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
492
+ if app_state["runnigState"] == "finished":
493
+ break
494
  app_state["batchCount"] += 1
495
 
496
  reorder_to_template(app_state, raw_data)
 
498
  total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
499
  # step2: Signal reconstruction
500
  utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
501
+ reorder_to_origin(app_state, channel_info, filepath+'denoised.csv', new_filename)
502
+
503
+ #if model_name == "(mapped data)":
504
+ #return {out_denoised_data : filepath + 'mapped.csv'}
505
+ #elif model_name == "(denoised data)":
506
+ #return {out_denoised_data : filepath + 'denoised.csv'}
507
 
508
+ delete_file(filepath+'mapped.csv')
509
+ delete_file(filepath+'denoised.csv')
 
 
510
 
511
+ yield {run_btn : gr.Button(interactive=True),
512
+ batch_md : gr.Markdown(visible=False),
513
+ out_denoised_data : gr.File(new_filename, visible=True)}
514
+
515
+ run_btn.click(
516
+ fn = reset2,
517
+ inputs = [app_state_json, in_raw_data, in_model_name],
518
+ outputs = [app_state_json, run_btn, batch_md, out_denoised_data]
519
 
520
+ ).success(
521
+ fn = run_model,
522
+ inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
523
+ outputs = [run_btn, batch_md, out_denoised_data]
524
+ )
525
 
526
  if __name__ == "__main__":
527
+ demo.launch()
channel_mapping.py CHANGED
@@ -11,10 +11,12 @@ from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
  def reorder_to_template(app_state, filename):
14
- old_idx = app_state["newOrder"]
15
  old_data = utils.read_train_data(filename) # original raw data
16
  new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
  new_filename = app_state["filepath"]+'mapped.csv'
 
 
18
 
19
  zero_arr = np.zeros((1, old_data.shape[1]))
20
  old_data = np.concatenate((old_data, zero_arr), axis=0)
@@ -34,7 +36,7 @@ def reorder_to_template(app_state, filename):
34
  return
35
 
36
  def reorder_to_origin(app_state, channel_info, filename, new_filename):
37
- old_idx = app_state["newOrder"]
38
  old_data = utils.read_train_data(filename) # denoised data
39
  template_order = channel_info["templateByIndex"]
40
 
@@ -161,7 +163,7 @@ def align_coords(channel_info, template_montage, input_montage):
161
 
162
  def fill_channels(app_state, channel_info, fill_mode):
163
 
164
- new_idx = app_state["newOrder"]
165
  template_dict = channel_info["templateByName"]
166
  input_dict = channel_info["inputByName"]
167
  template_order = channel_info["templateByIndex"]
@@ -186,16 +188,18 @@ def fill_channels(app_state, channel_info, fill_mode):
186
  knn.fit(in_coords)
187
 
188
  for channel in unmatched:
189
- distances, indices = knn.kneighbors(template_dict[channel]["coord"].reshape(1,-1))
190
  selected = [input_order[i] for i in indices[0]]
191
  print(channel, ':', selected)
192
 
193
  idx = template_dict[channel]["index"]
194
  new_idx[idx] = indices[0].tolist()
195
-
196
- app_state.update({
197
- "newOrder" : new_idx
198
- })
 
 
199
  return app_state
200
 
201
  def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
@@ -239,7 +243,8 @@ def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
239
  "inputByIndex" : input_montage.ch_names
240
  })
241
  app_state.update({
242
- "newOrder" : new_idx
 
243
  })
244
 
245
  # align input, template's coordinates
@@ -271,12 +276,13 @@ def mapping_stage2(app_state, channel_info, fill_mode):
271
 
272
  # initialize the cost matrix
273
  if len(unassigned) < 30:
274
- cost_matrix = np.full((30, 30), 10000) # add dummy channels to ensure num_col > num_row
275
  else:
276
  cost_matrix = np.zeros((30, len(unassigned)))
277
  for i in range(30):
278
  for j in range(len(unassigned)):
279
- cost_matrix[i][j] = np.linalg.norm(tpl_coords[i] - unassigned_coords[j]) # Euclidean distance
 
280
 
281
  # use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
282
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
@@ -292,13 +298,16 @@ def mapping_stage2(app_state, channel_info, fill_mode):
292
  template_dict[tpl_channel]["matched"] = True
293
  input_dict[in_channel]["assigned"] = True
294
  new_idx[i] = [input_dict[in_channel]["index"]]
 
 
295
 
296
  channel_info.update({
297
  "templateByName" : template_dict,
298
  "inputByName" : input_dict
299
  })
300
  app_state.update({
301
- "newOrder" : new_idx
 
302
  })
303
 
304
  # fill the unmatched channels
 
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
  def reorder_to_template(app_state, filename):
14
+ old_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
15
  old_data = utils.read_train_data(filename) # original raw data
16
  new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
  new_filename = app_state["filepath"]+'mapped.csv'
18
+ #print('new order 1:', app_state["stage1NewOrder"])
19
+ #print('new order 2:', app_state["stage2NewOrder"])
20
 
21
  zero_arr = np.zeros((1, old_data.shape[1]))
22
  old_data = np.concatenate((old_data, zero_arr), axis=0)
 
36
  return
37
 
38
  def reorder_to_origin(app_state, channel_info, filename, new_filename):
39
+ old_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
40
  old_data = utils.read_train_data(filename) # denoised data
41
  template_order = channel_info["templateByIndex"]
42
 
 
163
 
164
  def fill_channels(app_state, channel_info, fill_mode):
165
 
166
+ new_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
167
  template_dict = channel_info["templateByName"]
168
  input_dict = channel_info["inputByName"]
169
  template_order = channel_info["templateByIndex"]
 
188
  knn.fit(in_coords)
189
 
190
  for channel in unmatched:
191
+ distances, indices = knn.kneighbors(np.array(template_dict[channel]["coord"]).reshape(1,-1))
192
  selected = [input_order[i] for i in indices[0]]
193
  print(channel, ':', selected)
194
 
195
  idx = template_dict[channel]["index"]
196
  new_idx[idx] = indices[0].tolist()
197
+
198
+ if app_state["runnigState"] == "stage1":
199
+ app_state["stage1NewOrder"] = new_idx
200
+ else:
201
+ app_state["stage2NewOrder"] = new_idx
202
+
203
  return app_state
204
 
205
  def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
 
243
  "inputByIndex" : input_montage.ch_names
244
  })
245
  app_state.update({
246
+ "stage1NewOrder" : new_idx,
247
+ "runnigState" : "stage1"
248
  })
249
 
250
  # align input, template's coordinates
 
276
 
277
  # initialize the cost matrix
278
  if len(unassigned) < 30:
279
+ cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col > num_row
280
  else:
281
  cost_matrix = np.zeros((30, len(unassigned)))
282
  for i in range(30):
283
  for j in range(len(unassigned)):
284
+ cost_matrix[i][j] = np.linalg.norm((tpl_coords[i]-unassigned_coords[j])*1000) # Euclidean distance
285
+ #print(cost_matrix[i][j], tpl_coords[i] - unassigned_coords[j])
286
 
287
  # use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
288
  row_idx, col_idx = linear_sum_assignment(cost_matrix)
 
298
  template_dict[tpl_channel]["matched"] = True
299
  input_dict[in_channel]["assigned"] = True
300
  new_idx[i] = [input_dict[in_channel]["index"]]
301
+
302
+ print(template_order[row_idx[i]], '<-', unassigned[col_idx[i]])
303
 
304
  channel_info.update({
305
  "templateByName" : template_dict,
306
  "inputByName" : input_dict
307
  })
308
  app_state.update({
309
+ "stage2NewOrder" : new_idx,
310
+ "runnigState" : "stage2"
311
  })
312
 
313
  # fill the unmatched channels