audrey06100 commited on
Commit
884c10b
·
1 Parent(s): 4e5fc37
Files changed (3) hide show
  1. app.py +186 -130
  2. channel_mapping.py +49 -44
  3. utils.py +1 -2
app.py CHANGED
@@ -3,7 +3,6 @@ import numpy as np
3
  import os
4
  import random
5
  import math
6
- import json
7
  import utils
8
  from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
9
 
@@ -46,48 +45,63 @@ Electroencephalography (EEG) signals are often contaminated with artifacts. It i
46
  """
47
 
48
  chkbox_js = """
49
- (state_json) => {
50
- state_json = JSON.parse(JSON.stringify(state_json));
51
- if(state_json.state == "finished") return;
 
52
 
53
  // add figure of in_montage
54
  document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
55
  position: relative;
56
  width: 560px;
57
  height: 560px;
58
- background: url("file=${state_json.files.raw_montage}");
59
  `;
60
 
61
- // add indication for the missing channel
62
- /*
63
- let indicator = document.getElementById("indicator")
64
- if(!indicator) document.querySelector("#chkbox-group> div:nth-of-type(2)").innerHTML += '<div id="indicator"></div>'
65
 
66
- let channel = state_json.missingChannelsIndex[0]
67
- channel = state_json.templateByIndex[channel]
68
- let left = state_json.templateByName[channel].css_position[0];
69
- let bottom = state_json.templateByName[channel].css_position[1];
 
70
 
71
- document.getElementById("red-dot").style.cssText = `
72
- position: absolute;
73
- background-color: red;
74
- width: 10px;
75
- height: 10px;
76
- border-radius: 50%;
77
- left: ${left};
78
- bottom: ${bottom};
 
 
 
79
  `;
80
- */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  // move the checkboxes
83
  let all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
84
- all_chkbox = Array.apply(null, all_chkbox);
85
 
86
- all_chkbox.forEach((item, index) => {
87
- channel = state_json.inputByIndex[index];
88
- left = state_json.inputByName[channel].css_position[0];
89
- bottom = state_json.inputByName[channel].css_position[1];
90
- //console.log(`left: ${left}, bottom: ${bottom}`);
91
 
92
  item.style.cssText = `
93
  position: absolute;
@@ -96,36 +110,55 @@ chkbox_js = """
96
  `;
97
  item.className = "";
98
  item.querySelector(":scope> span").innerText = "";
99
- });
100
  }
101
  """
102
 
103
- dot_js = """
104
- (state_json) => {
105
- state_json = JSON.parse(JSON.stringify(state_json));
106
- if(state_json.state == "finished") return;
 
107
 
108
- let channel = state_json.missingChannelsIndex[state_json["fillingCount"]]
109
- channel = state_json.templateByIndex[channel]
110
- let left = state_json.templateByName[channel].css_position[0];
111
- let bottom = state_json.templateByName[channel].css_position[1];
112
 
113
- document.getElementById("indicator").style.cssText = `
114
- position: absolute;
115
- background-color: red;
116
- width: 10px;
117
- height: 10px;
118
- border-radius: 50%;
119
- left: ${left};
120
- bottom: ${bottom};
 
 
 
121
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  }
123
  """
124
 
125
 
126
  with gr.Blocks() as demo:
127
 
128
- state_json = gr.JSON(visible=False)
 
129
 
130
  with gr.Row():
131
  gr.Markdown(
@@ -158,6 +191,7 @@ with gr.Blocks() as demo:
158
  label="Imputation")
159
  map_btn = gr.Button("Mapping")
160
 
 
161
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
162
  next_btn = gr.Button("Next", interactive=False, visible=False)
163
 
@@ -171,6 +205,9 @@ with gr.Blocks() as demo:
171
  with gr.Row():
172
  tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
173
  map_montage = gr.Image(label="Matched channels", visible=False)
 
 
 
174
 
175
  with gr.Column():
176
  gr.Markdown(
@@ -218,14 +255,20 @@ with gr.Blocks() as demo:
218
  os.mkdir(filepath+"/temp_data/")
219
  #print(e)
220
 
 
221
  data = utils.read_train_data(raw_data)
222
- state = {
223
  "filepath": filepath+"/temp_data/",
224
  "files": {},
225
  "sampleRate": int(samplerate),
 
 
 
226
  "dataShape" : data.shape
227
  }
228
- return {state_json : state,
 
 
229
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
230
  next_btn : gr.Button("Next", interactive=False, visible=False),
231
  run_btn : gr.Button(interactive=False),
@@ -234,36 +277,38 @@ with gr.Blocks() as demo:
234
  res_md : gr.Markdown(visible=False),
235
  batch_md : gr.Markdown(visible=False)}
236
 
237
- def mapping_result(state, fill_mode):
238
 
239
- in_num = len(state["inputByName"])
240
- matched_num = 30 - len(state["missingChannelsIndex"])
241
  batch_num = math.ceil((in_num-matched_num)/30) + 1
242
- state.update({
243
  "runnigState" : "stage1",
244
  "batchCount" : 1,
245
  "totalBatchNum" : batch_num
246
  })
247
 
248
- if fill_mode=="mean_manual" and state["missingChannelsIndex"]!=[]:
249
- state.update({
250
  "state" : "initializing",
251
- "fillingCount" : 0,
252
- "totalFillingNum" : len(state["missingChannelsIndex"])-1
253
  })
254
- #print("Missing channels:", state["missingChannelsIndex"])
255
- return {state_json : state,
256
  #chkbox_group : gr.CheckboxGroup(visible=True),
257
  next_btn : gr.Button(visible=True)}
258
  else:
259
- state["state"] = "finished"
260
 
261
- return {state_json : state,
262
  res_md : gr.Markdown(visible=True),
263
  run_btn : gr.Button(interactive=True)}
264
 
265
- def show_montage(state, raw_loc):
266
- filepath = state["filepath"]
 
 
 
267
  raw_montage = read_custom_montage(raw_loc)
268
 
269
  # convert all channel names to uppercase
@@ -271,112 +316,118 @@ with gr.Blocks() as demo:
271
  channel = raw_montage.ch_names[i]
272
  raw_montage.rename_channels({channel: str.upper(channel)})
273
 
274
- if state["state"] == "initializing":
275
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
276
- state["files"]["raw_montage"] = filename
277
  raw_fig = raw_montage.plot()
278
  raw_fig.set_size_inches(5.6, 5.6)
279
  raw_fig.savefig(filename, pad_inches=0)
280
 
281
- return {state_json : state}
282
 
283
- elif state["state"] == "finished":
284
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
285
- state["files"]["map_montage"] = filename
286
 
287
  show_names= []
288
- for channel in state["inputByName"]:
289
- if state["inputByName"][channel]["matched"]:
290
  show_names.append(channel)
291
  mapped_fig = raw_montage.plot(show_names=show_names)
292
  mapped_fig.set_size_inches(5.6, 5.6)
293
  mapped_fig.savefig(filename, pad_inches=0)
294
 
295
- return {state_json : state,
296
  tpl_montage : gr.Image(visible=True),
297
  map_montage : gr.Image(value=filename, visible=True)}
298
 
299
- else:
300
- return {state_json : state} # change nothing
301
 
302
- def generate_chkbox(state):
303
- if state["state"] == "initializing":
304
- in_channels = [channel for channel in state["inputByName"]]
305
- state["state"] = "selecting"
 
306
 
307
- first_idx = state["missingChannelsIndex"][0]
308
- first_name = state["templateByIndex"][first_idx]
309
- chkbox_label = first_name+' (1/'+str(state["totalFillingNum"]+1)+')'
310
- return {state_json : state,
311
  chkbox_group : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
312
  next_btn : gr.Button(interactive=True)}
313
  else:
314
- return {state_json : state} # change nothing
315
 
316
 
317
  map_btn.click(
318
  fn = reset_layout,
319
  inputs = [in_raw_data, in_sample_rate],
320
- outputs = [state_json, chkbox_group, next_btn, run_btn, tpl_montage, map_montage, res_md, batch_md]
321
 
322
  ).success(
323
  fn = mapping_stage1,
324
- inputs = [state_json, in_raw_data, in_raw_loc, in_fill_mode],
325
- outputs = state_json
326
 
327
  ).success(
328
  fn = mapping_result,
329
- inputs = [state_json, in_fill_mode],
330
- outputs = [state_json, next_btn, res_md, run_btn]
331
 
332
  ).success(
333
  fn = show_montage,
334
- inputs = [state_json, in_raw_loc],
335
- outputs = [state_json, tpl_montage, map_montage]
336
 
337
  ).success(
338
  fn = generate_chkbox,
339
- inputs = state_json,
340
- outputs = [state_json, chkbox_group, next_btn]
341
 
342
  ).success(
343
  fn = None,
344
  js = chkbox_js,
345
- inputs = state_json,
346
  outputs = []
347
  )
348
 
349
 
350
- def check_next(state, selected, raw_data, fill_mode):
351
  #if state["state"] == "selecting":
352
 
353
  # save info before clicking on next_btn
354
- prev_target_idx = state["missingChannelsIndex"][state["fillingCount"]]
355
- prev_target_name = state["templateByIndex"][prev_target_idx]
356
 
357
- selected_idx = [state["inputByName"][channel]["index"] for channel in selected]
358
- state["newOrder"][prev_target_idx] = selected_idx
359
 
360
- #if len(selected)==1 and state["inputByName"][selected[0]]["used"]==False:
361
- #state["inputByName"][selected[0]]["used"] = True
362
- #state["missingChannelsIndex"][state["fillingCount"]] = -1
363
 
364
  print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
365
 
366
  # update next round
367
- state["fillingCount"] += 1
368
- if state["fillingCount"] <= state["totalFillingNum"]:
369
- target_idx = state["missingChannelsIndex"][state["fillingCount"]]
370
- target_name = state["templateByIndex"][target_idx]
371
- chkbox_label = target_name+' ('+str(state["fillingCount"]+1)+'/'+str(state["totalFillingNum"]+1)+')'
372
- btn_label = "Submit" if state["fillingCount"]==state["totalFillingNum"] else "Next"
 
 
373
 
374
- return {state_json : state,
 
375
  chkbox_group : gr.CheckboxGroup(value=[], label=chkbox_label),
376
  next_btn : gr.Button(btn_label)}
377
  else:
378
- state["state"] = "finished"
379
- return {state_json : state,
 
 
380
  chkbox_group : gr.CheckboxGroup(visible=False),
381
  next_btn : gr.Button(visible=False),
382
  res_md : gr.Markdown(visible=True),
@@ -384,48 +435,53 @@ with gr.Blocks() as demo:
384
 
385
  next_btn.click(
386
  fn = check_next,
387
- inputs = [state_json, chkbox_group, in_raw_data, in_fill_mode],
388
- outputs = [state_json, chkbox_group, next_btn, run_btn, res_md]
389
 
390
  ).success(
391
- fn = show_montage,
392
- inputs = [state_json, in_raw_loc],
393
- outputs = [state_json, tpl_montage, map_montage]
 
 
 
 
 
 
394
  )
395
 
396
- @run_btn.click(inputs = [state_json, in_raw_data, in_model_name, in_fill_mode], outputs = out_denoised_data)
397
- def run_model(state, raw_data, model_name, fill_mode):
398
- filepath = state["filepath"]
399
- samplerate = state["sampleRate"]
400
-
401
- #if batch > total_batch:
402
- #return {batch_md : gr.Markdown("error", visible=True)}
403
 
404
  input_name = os.path.basename(str(raw_data))
405
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
406
 
407
- while(state["runnigState"] != "finished"):
408
- if state["batchCount"] > state["totalBatchNum"]:
 
409
  break
410
- if state["batchCount"] > 1:
411
- state["runnigState"] = "stage2"
412
- state = mapping_stage2(state, fill_mode)
413
- state["batchCount"] += 1
414
 
415
- reorder_to_template(state, raw_data)
416
  # step1: Data preprocessing
417
  total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
418
  # step2: Signal reconstruction
419
  utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
420
- reorder_to_origin(state, filepath+'denoised.csv', filepath+output_name)
421
 
422
  if model_name == "(mapped data)":
423
  return {out_denoised_data : filepath + 'mapped.csv'}
424
  elif model_name == "(denoised data)":
425
  return {out_denoised_data : filepath + 'denoised.csv'}
426
-
427
  return {out_denoised_data : filepath + output_name}
428
 
429
 
430
  if __name__ == "__main__":
431
- demo.launch()
 
3
  import os
4
  import random
5
  import math
 
6
  import utils
7
  from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
8
 
 
45
  """
46
 
47
  chkbox_js = """
48
+ (app_state, channel_info) => {
49
+ app_state = JSON.parse(JSON.stringify(app_state));
50
+ channel_info = JSON.parse(JSON.stringify(channel_info));
51
+ if(app_state.state == "finished") return;
52
 
53
  // add figure of in_montage
54
  document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
55
  position: relative;
56
  width: 560px;
57
  height: 560px;
58
+ background: url("file=${app_state.files.raw_montage}");
59
  `;
60
 
 
 
 
 
61
 
62
+ // add indication for the missing channels
63
+ let channel = channel_info.missingChannelsIndex[0]
64
+ channel = channel_info.templateByIndex[channel]
65
+ let left = channel_info.templateByName[channel].css_position[0];
66
+ let bottom = channel_info.templateByName[channel].css_position[1];
67
 
68
+ let rule = `
69
+ #chkbox-group> div:nth-of-type(2)::after{
70
+ content: '';
71
+ position: absolute;
72
+ background-color: red;
73
+ width: 10px;
74
+ height: 10px;
75
+ border-radius: 50%;
76
+ left: ${left};
77
+ bottom: ${bottom};
78
+ }
79
  `;
80
+
81
+ // check if indicator already exist
82
+ let exist = 0;
83
+ const styleSheet = document.styleSheets[0];
84
+ for(let i=0; i<styleSheet.cssRules.length; i++){
85
+ if(styleSheet.cssRules[i].selectorText == "#chkbox-group> div:nth-of-type(2)::after"){
86
+ exist = 1;
87
+ console.log('exist!');
88
+ styleSheet.deleteRule(i);
89
+ styleSheet.insertRule(rule, styleSheet.cssRules.length);
90
+ break;
91
+ }
92
+ }
93
+ if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);
94
+
95
 
96
  // move the checkboxes
97
  let all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
98
+ //all_chkbox = Array.apply(null, all_chkbox);
99
 
100
+ Array.from(all_chkbox).forEach((item, index) => {
101
+ channel = channel_info.inputByIndex[index];
102
+ left = channel_info.inputByName[channel].css_position[0];
103
+ bottom = channel_info.inputByName[channel].css_position[1];
104
+ console.log(`left: ${left}, bottom: ${bottom}`);
105
 
106
  item.style.cssText = `
107
  position: absolute;
 
110
  `;
111
  item.className = "";
112
  item.querySelector(":scope> span").innerText = "";
113
+ });
114
  }
115
  """
116
 
117
+ indication_js = """
118
+ (app_state, channel_info) => {
119
+ app_state = JSON.parse(JSON.stringify(app_state));
120
+ channel_info = JSON.parse(JSON.stringify(channel_info));
121
+ if(app_state.state == "finished") return;
122
 
123
+ let channel = channel_info.missingChannelsIndex[app_state["fillingCount"]-1]
124
+ channel = channel_info.templateByIndex[channel]
125
+ let left = channel_info.templateByName[channel].css_position[0];
126
+ let bottom = channel_info.templateByName[channel].css_position[1];
127
 
128
+ let rule = `
129
+ #chkbox-group> div:nth-of-type(2)::after{
130
+ content: '';
131
+ position: absolute;
132
+ background-color: red;
133
+ width: 10px;
134
+ height: 10px;
135
+ border-radius: 50%;
136
+ left: ${left};
137
+ bottom: ${bottom};
138
+ }
139
  `;
140
+
141
+ // check if indicator already exist
142
+ let exist = 0;
143
+ const styleSheet = document.styleSheets[0];
144
+ for(let i=0; i<styleSheet.cssRules.length; i++){
145
+ if(styleSheet.cssRules[i].selectorText == "#chkbox-group> div:nth-of-type(2)::after"){
146
+ exist = 1;
147
+ console.log('exist!');
148
+ styleSheet.deleteRule(i);
149
+ styleSheet.insertRule(rule, styleSheet.cssRules.length);
150
+ break;
151
+ }
152
+ }
153
+ if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);
154
  }
155
  """
156
 
157
 
158
  with gr.Blocks() as demo:
159
 
160
+ app_state_json = gr.JSON(visible=False)
161
+ channel_info_json = gr.JSON(visible=False)
162
 
163
  with gr.Row():
164
  gr.Markdown(
 
191
  label="Imputation")
192
  map_btn = gr.Button("Mapping")
193
 
194
+ #indic
195
  chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
196
  next_btn = gr.Button("Next", interactive=False, visible=False)
197
 
 
205
  with gr.Row():
206
  tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
207
  map_montage = gr.Image(label="Matched channels", visible=False)
208
+
209
+ #miss_txtbox = gr.Textbox(label="Missing channels", visible=False)
210
+ #tpl_loc_file = gr.File("./template_chanlocs.loc", show_label=False, visible=False)
211
 
212
  with gr.Column():
213
  gr.Markdown(
 
255
  os.mkdir(filepath+"/temp_data/")
256
  #print(e)
257
 
258
+ # initialize app_state, channel_info
259
  data = utils.read_train_data(raw_data)
260
+ app_state = {
261
  "filepath": filepath+"/temp_data/",
262
  "files": {},
263
  "sampleRate": int(samplerate),
264
+
265
+ }
266
+ channel_info = {
267
  "dataShape" : data.shape
268
  }
269
+
270
+ return {app_state_json : app_state,
271
+ channel_info_json : channel_info,
272
  chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
273
  next_btn : gr.Button("Next", interactive=False, visible=False),
274
  run_btn : gr.Button(interactive=False),
 
277
  res_md : gr.Markdown(visible=False),
278
  batch_md : gr.Markdown(visible=False)}
279
 
280
+ def mapping_result(app_state, channel_info, fill_mode):
281
 
282
+ in_num = len(channel_info["inputByName"])
283
+ matched_num = 30 - len(channel_info["missingChannelsIndex"])
284
  batch_num = math.ceil((in_num-matched_num)/30) + 1
285
+ app_state.update({
286
  "runnigState" : "stage1",
287
  "batchCount" : 1,
288
  "totalBatchNum" : batch_num
289
  })
290
 
291
+ if fill_mode=="mean_manual" and channel_info["missingChannelsIndex"]!=[]:
292
+ app_state.update({
293
  "state" : "initializing",
294
+ "totalFillingNum" : len(channel_info["missingChannelsIndex"])
 
295
  })
296
+ #print("Missing channels:", channel_info["missingChannelsIndex"])
297
+ return {app_state_json : app_state,
298
  #chkbox_group : gr.CheckboxGroup(visible=True),
299
  next_btn : gr.Button(visible=True)}
300
  else:
301
+ app_state["state"] = "finished"
302
 
303
+ return {app_state_json : app_state,
304
  res_md : gr.Markdown(visible=True),
305
  run_btn : gr.Button(interactive=True)}
306
 
307
+ def show_montage(app_state, channel_info, raw_loc):
308
+ if app_state["state"] == "selecting":
309
+ return {app_state_json : app_state} # change nothing
310
+
311
+ filepath = app_state["filepath"]
312
  raw_montage = read_custom_montage(raw_loc)
313
 
314
  # convert all channel names to uppercase
 
316
  channel = raw_montage.ch_names[i]
317
  raw_montage.rename_channels({channel: str.upper(channel)})
318
 
319
+ if app_state["state"] == "initializing":
320
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
321
+ app_state["files"]["raw_montage"] = filename
322
  raw_fig = raw_montage.plot()
323
  raw_fig.set_size_inches(5.6, 5.6)
324
  raw_fig.savefig(filename, pad_inches=0)
325
 
326
+ return {app_state_json : app_state}
327
 
328
+ elif app_state["state"] == "finished":
329
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
330
+ app_state["files"]["map_montage"] = filename
331
 
332
  show_names= []
333
+ for channel in channel_info["inputByName"]:
334
+ if channel_info["inputByName"][channel]["matched"]:
335
  show_names.append(channel)
336
  mapped_fig = raw_montage.plot(show_names=show_names)
337
  mapped_fig.set_size_inches(5.6, 5.6)
338
  mapped_fig.savefig(filename, pad_inches=0)
339
 
340
+ return {app_state_json : app_state,
341
  tpl_montage : gr.Image(visible=True),
342
  map_montage : gr.Image(value=filename, visible=True)}
343
 
344
+ #else:
345
+ #return {app_state_json : app_state} # change nothing
346
 
347
+ def generate_chkbox(app_state, channel_info):
348
+ if app_state["state"] == "initializing":
349
+ in_channels = [channel for channel in channel_info["inputByName"]]
350
+ app_state["state"] = "selecting"
351
+ app_state["fillingCount"] = 1
352
 
353
+ idx = channel_info["missingChannelsIndex"][0]
354
+ name = channel_info["templateByIndex"][idx]
355
+ chkbox_label = name+' (1/'+str(app_state["totalFillingNum"])+')'
356
+ return {app_state_json : app_state,
357
  chkbox_group : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
358
  next_btn : gr.Button(interactive=True)}
359
  else:
360
+ return {app_state_json : app_state} # change nothing
361
 
362
 
363
  map_btn.click(
364
  fn = reset_layout,
365
  inputs = [in_raw_data, in_sample_rate],
366
+ outputs = [app_state_json, channel_info_json, chkbox_group, next_btn, run_btn, tpl_montage, map_montage, res_md, batch_md]
367
 
368
  ).success(
369
  fn = mapping_stage1,
370
+ inputs = [app_state_json, channel_info_json, in_raw_data, in_raw_loc, in_fill_mode],
371
+ outputs = [app_state_json, channel_info_json]
372
 
373
  ).success(
374
  fn = mapping_result,
375
+ inputs = [app_state_json, channel_info_json, in_fill_mode],
376
+ outputs = [app_state_json, next_btn, res_md, run_btn]
377
 
378
  ).success(
379
  fn = show_montage,
380
+ inputs = [app_state_json, channel_info_json, in_raw_loc],
381
+ outputs = [app_state_json, tpl_montage, map_montage]
382
 
383
  ).success(
384
  fn = generate_chkbox,
385
+ inputs = [app_state_json, channel_info_json],
386
+ outputs = [app_state_json, chkbox_group, next_btn]
387
 
388
  ).success(
389
  fn = None,
390
  js = chkbox_js,
391
+ inputs = [app_state_json, channel_info_json],
392
  outputs = []
393
  )
394
 
395
 
396
+ def check_next(app_state, channel_info, selected, raw_data, fill_mode):
397
  #if state["state"] == "selecting":
398
 
399
  # save info before clicking on next_btn
400
+ prev_target_idx = channel_info["missingChannelsIndex"][app_state["fillingCount"]-1]
401
+ prev_target_name = channel_info["templateByIndex"][prev_target_idx]
402
 
403
+ selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
404
+ app_state["newOrder"][prev_target_idx] = selected_idx
405
 
406
+ #if len(selected)==1 and channel_info["inputByName"][selected[0]]["used"]==False:
407
+ #channel_info["inputByName"][selected[0]]["used"] = True
408
+ #channel_info["missingChannelsIndex"][state["fillingCount"]-1] = -1
409
 
410
  print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
411
 
412
  # update next round
413
+ app_state["fillingCount"] += 1
414
+
415
+ if app_state["fillingCount"] <= app_state["totalFillingNum"]:
416
+ target_idx = channel_info["missingChannelsIndex"][app_state["fillingCount"]-1]
417
+ target_name = channel_info["templateByIndex"][target_idx]
418
+
419
+ chkbox_label = target_name+' ('+str(app_state["fillingCount"])+'/'+str(app_state["totalFillingNum"])+')'
420
+ btn_label = "Submit" if app_state["fillingCount"]==app_state["totalFillingNum"] else "Next"
421
 
422
+ return {app_state_json : app_state,
423
+ #channel_info_json : channel_info,
424
  chkbox_group : gr.CheckboxGroup(value=[], label=chkbox_label),
425
  next_btn : gr.Button(btn_label)}
426
  else:
427
+ app_state["state"] = "finished"
428
+
429
+ return {app_state_json : app_state,
430
+ #channel_info_json : channel_info,
431
  chkbox_group : gr.CheckboxGroup(visible=False),
432
  next_btn : gr.Button(visible=False),
433
  res_md : gr.Markdown(visible=True),
 
435
 
436
  next_btn.click(
437
  fn = check_next,
438
+ inputs = [app_state_json, channel_info_json, chkbox_group, in_raw_data, in_fill_mode],
439
+ outputs = [app_state_json, chkbox_group, next_btn, run_btn, res_md]
440
 
441
  ).success(
442
+ fn = show_montage,
443
+ inputs = [app_state_json, channel_info_json, in_raw_loc],
444
+ outputs = [app_state_json, tpl_montage, map_montage]
445
+
446
+ ).success(
447
+ fn = None,
448
+ js = indication_js,
449
+ inputs = [app_state_json, channel_info_json],
450
+ outputs = []
451
  )
452
 
453
+ @run_btn.click(inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
454
+ outputs = [batch_md, out_denoised_data])
455
+ def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
456
+ filepath = app_state["filepath"]
457
+ samplerate = app_state["sampleRate"]
 
 
458
 
459
  input_name = os.path.basename(str(raw_data))
460
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
461
 
462
+ while(app_state["runnigState"] != "finished"):
463
+ if app_state["batchCount"] > app_state["totalBatchNum"]:
464
+ app_state["runnigState"] = "finished"
465
  break
466
+ if app_state["batchCount"] > 1:
467
+ app_state["runnigState"] = "stage2"
468
+ app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
469
+ app_state["batchCount"] += 1
470
 
471
+ reorder_to_template(app_state, raw_data)
472
  # step1: Data preprocessing
473
  total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
474
  # step2: Signal reconstruction
475
  utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
476
+ reorder_to_origin(app_state, channel_info, filepath+'denoised.csv', filepath+output_name)
477
 
478
  if model_name == "(mapped data)":
479
  return {out_denoised_data : filepath + 'mapped.csv'}
480
  elif model_name == "(denoised data)":
481
  return {out_denoised_data : filepath + 'denoised.csv'}
482
+
483
  return {out_denoised_data : filepath + output_name}
484
 
485
 
486
  if __name__ == "__main__":
487
+ demo.launch(server_name="0.0.0.0", server_port=7860)
channel_mapping.py CHANGED
@@ -10,11 +10,11 @@ from scipy.interpolate import Rbf
10
  from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
- def reorder_to_template(state, filename):
14
- old_idx = state["newOrder"]
15
  old_data = utils.read_train_data(filename) # original raw data
16
  new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
- new_filename = state["filepath"]+'mapped.csv'
18
 
19
  zero_arr = np.zeros((1, old_data.shape[1]))
20
  old_data = np.concatenate((old_data, zero_arr), axis=0)
@@ -33,13 +33,13 @@ def reorder_to_template(state, filename):
33
  utils.save_data(new_data, new_filename)
34
  return
35
 
36
- def reorder_to_origin(state, filename, new_filename):
37
- old_idx = state["newOrder"]
38
  old_data = utils.read_train_data(filename) # denoised data
39
- template_order = state["templateByIndex"]
40
 
41
- if state["runnigState"] == "stage1":
42
- new_data = np.zeros((len(state["inputByName"]), old_data.shape[1]))
43
  else:
44
  new_data = utils.read_train_data(new_filename)
45
 
@@ -47,7 +47,7 @@ def reorder_to_origin(state, filename, new_filename):
47
  idx_set = old_idx[i]
48
 
49
  # ignore if this channel doesn't exist
50
- if len(idx_set)==1 and state["templateByName"][channel]["matched"]==True:
51
  new_data[idx_set[0], :] = old_data[i, :]
52
 
53
  print('old.shape, new.shape: ', old_data.shape, new_data.shape)
@@ -86,12 +86,12 @@ def read_montage_data(loc_file):
86
 
87
  return template_montage, input_montage, template_dict, input_dict
88
 
89
- def align_coords(state, template_montage, input_montage):
90
 
91
- template_dict = state["templateByName"]
92
- input_dict = state["inputByName"]
93
- template_order = state["templateByIndex"]
94
- input_order = state["inputByIndex"]
95
  matched = [channel for channel in input_dict if input_dict[channel]["matched"]==True]
96
 
97
  # 2-d (fot the indication of missing template channel's position when fill_mode:'mean_manual')
@@ -153,24 +153,23 @@ def align_coords(state, template_montage, input_montage):
153
  for i, channel in enumerate(input_order):
154
  input_dict[channel]["coord"] = transformed_in[i].tolist()
155
 
156
- state.update({
157
  "templateByName" : template_dict,
158
  "inputByName" : input_dict,
159
  })
160
-
161
- return state
162
 
163
- def fill_channels(state, fill_mode):
164
 
165
- new_idx = state["newOrder"]
166
- template_dict = state["templateByName"]
167
- input_dict = state["inputByName"]
168
- template_order = state["templateByIndex"]
169
- input_order = state["inputByIndex"]
170
- z_row_idx = state["dataShape"][0]
171
  unmatched = [channel for channel in template_dict if template_dict[channel]["matched"]==False]
172
  if unmatched == []:
173
- return state
174
 
175
  if fill_mode == 'zero':
176
  for channel in unmatched:
@@ -194,10 +193,12 @@ def fill_channels(state, fill_mode):
194
  idx = template_dict[channel]["index"]
195
  new_idx[idx] = indices[0].tolist()
196
 
197
- state["newOrder"] = new_idx
198
- return state
 
 
199
 
200
- def mapping_stage1(state, data_file, loc_file, fill_mode):
201
  second1 = time.time()
202
 
203
  template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
@@ -230,34 +231,36 @@ def mapping_stage1(state, data_file, loc_file, fill_mode):
230
  else:
231
  missing_channels.append(i)
232
 
233
- state.update({
234
- "newOrder" : new_idx,
235
  "missingChannelsIndex" : missing_channels,
236
  "templateByName" : {k : v.__dict__ for k,v in template_dict.items()},
237
  "inputByName" : {k : v.__dict__ for k,v in input_dict.items()},
238
  "templateByIndex" : template_montage.ch_names,
239
  "inputByIndex" : input_montage.ch_names
240
  })
 
 
 
241
 
242
  # align input, template's coordinates
243
- state = align_coords(state, template_montage, input_montage)
244
  # fill the unmatched channels
245
- state = fill_channels(state, fill_mode)
246
 
247
  second2 = time.time()
248
  print('Mapping (stage1) finished in',second2 - second1,'s.')
249
- return state
250
 
251
- def mapping_stage2(state, fill_mode):
252
  second1 = time.time()
253
 
254
- template_dict = state["templateByName"]
255
- input_dict = state["inputByName"]
256
- template_order = state["templateByIndex"]
257
  unassigned = [channel for channel in input_dict if input_dict[channel]["assigned"]==False]
258
  if unassigned == []:
259
- state["runnigState"] = "finished"
260
- return state
261
 
262
  tpl_coords = np.array([template_dict[channel]["coord"] for channel in template_order])
263
  unassigned_coords = np.array([input_dict[channel]["coord"] for channel in unassigned])
@@ -290,16 +293,18 @@ def mapping_stage2(state, fill_mode):
290
  input_dict[in_channel]["assigned"] = True
291
  new_idx[i] = [input_dict[in_channel]["index"]]
292
 
293
- state.update({
294
- "newOrder" : new_idx,
295
  "templateByName" : template_dict,
296
  "inputByName" : input_dict
297
  })
 
 
 
298
 
299
  # fill the unmatched channels
300
- state = fill_channels(state, fill_mode)
301
 
302
  second2 = time.time()
303
- print(f'Mapping (stage2-{state["batchCount"]-1}) finished in {second2 - second1}s.')
304
- return state
305
 
 
10
  from scipy.optimize import linear_sum_assignment
11
  from sklearn.neighbors import NearestNeighbors
12
 
13
+ def reorder_to_template(app_state, filename):
14
+ old_idx = app_state["newOrder"]
15
  old_data = utils.read_train_data(filename) # original raw data
16
  new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
+ new_filename = app_state["filepath"]+'mapped.csv'
18
 
19
  zero_arr = np.zeros((1, old_data.shape[1]))
20
  old_data = np.concatenate((old_data, zero_arr), axis=0)
 
33
  utils.save_data(new_data, new_filename)
34
  return
35
 
36
+ def reorder_to_origin(app_state, channel_info, filename, new_filename):
37
+ old_idx = app_state["newOrder"]
38
  old_data = utils.read_train_data(filename) # denoised data
39
+ template_order = channel_info["templateByIndex"]
40
 
41
+ if app_state["runnigState"] == "stage1":
42
+ new_data = np.zeros((len(channel_info["inputByName"]), old_data.shape[1]))
43
  else:
44
  new_data = utils.read_train_data(new_filename)
45
 
 
47
  idx_set = old_idx[i]
48
 
49
  # ignore if this channel doesn't exist
50
+ if len(idx_set)==1 and channel_info["templateByName"][channel]["matched"]==True:
51
  new_data[idx_set[0], :] = old_data[i, :]
52
 
53
  print('old.shape, new.shape: ', old_data.shape, new_data.shape)
 
86
 
87
  return template_montage, input_montage, template_dict, input_dict
88
 
89
+ def align_coords(channel_info, template_montage, input_montage):
90
 
91
+ template_dict = channel_info["templateByName"]
92
+ input_dict = channel_info["inputByName"]
93
+ template_order = channel_info["templateByIndex"]
94
+ input_order = channel_info["inputByIndex"]
95
  matched = [channel for channel in input_dict if input_dict[channel]["matched"]==True]
96
 
97
  # 2-d (fot the indication of missing template channel's position when fill_mode:'mean_manual')
 
153
  for i, channel in enumerate(input_order):
154
  input_dict[channel]["coord"] = transformed_in[i].tolist()
155
 
156
+ channel_info.update({
157
  "templateByName" : template_dict,
158
  "inputByName" : input_dict,
159
  })
160
+ return channel_info
 
161
 
162
+ def fill_channels(app_state, channel_info, fill_mode):
163
 
164
+ new_idx = app_state["newOrder"]
165
+ template_dict = channel_info["templateByName"]
166
+ input_dict = channel_info["inputByName"]
167
+ template_order = channel_info["templateByIndex"]
168
+ input_order = channel_info["inputByIndex"]
169
+ z_row_idx = channel_info["dataShape"][0]
170
  unmatched = [channel for channel in template_dict if template_dict[channel]["matched"]==False]
171
  if unmatched == []:
172
+ return app_state
173
 
174
  if fill_mode == 'zero':
175
  for channel in unmatched:
 
193
  idx = template_dict[channel]["index"]
194
  new_idx[idx] = indices[0].tolist()
195
 
196
+ app_state.update({
197
+ "newOrder" : new_idx
198
+ })
199
+ return app_state
200
 
201
+ def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
202
  second1 = time.time()
203
 
204
  template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
 
231
  else:
232
  missing_channels.append(i)
233
 
234
+ channel_info.update({
 
235
  "missingChannelsIndex" : missing_channels,
236
  "templateByName" : {k : v.__dict__ for k,v in template_dict.items()},
237
  "inputByName" : {k : v.__dict__ for k,v in input_dict.items()},
238
  "templateByIndex" : template_montage.ch_names,
239
  "inputByIndex" : input_montage.ch_names
240
  })
241
+ app_state.update({
242
+ "newOrder" : new_idx
243
+ })
244
 
245
  # align input, template's coordinates
246
+ channel_info = align_coords(channel_info, template_montage, input_montage)
247
  # fill the unmatched channels
248
+ app_state = fill_channels(app_state, channel_info, fill_mode)
249
 
250
  second2 = time.time()
251
  print('Mapping (stage1) finished in',second2 - second1,'s.')
252
+ return app_state, channel_info
253
 
254
+ def mapping_stage2(app_state, channel_info, fill_mode):
255
  second1 = time.time()
256
 
257
+ template_dict = channel_info["templateByName"]
258
+ input_dict = channel_info["inputByName"]
259
+ template_order = channel_info["templateByIndex"]
260
  unassigned = [channel for channel in input_dict if input_dict[channel]["assigned"]==False]
261
  if unassigned == []:
262
+ app_state["runnigState"] = "finished"
263
+ return app_state, channel_info
264
 
265
  tpl_coords = np.array([template_dict[channel]["coord"] for channel in template_order])
266
  unassigned_coords = np.array([input_dict[channel]["coord"] for channel in unassigned])
 
293
  input_dict[in_channel]["assigned"] = True
294
  new_idx[i] = [input_dict[in_channel]["index"]]
295
 
296
+ channel_info.update({
 
297
  "templateByName" : template_dict,
298
  "inputByName" : input_dict
299
  })
300
+ app_state.update({
301
+ "newOrder" : new_idx
302
+ })
303
 
304
  # fill the unmatched channels
305
+ app_state = fill_channels(app_state, channel_info, fill_mode)
306
 
307
  second2 = time.time()
308
+ print(f'Mapping (stage2-{app_state["batchCount"]-1}) finished in {second2 - second1}s.')
309
+ return app_state, channel_info
310
 
utils.py CHANGED
@@ -223,8 +223,7 @@ def preprocessing(filepath, filename, samplerate):
223
  signal = read_train_data(filepath+filename)
224
  #print(signal.shape)
225
  # resample
226
- signal = resample(signal, samplerate)
227
- #signal = resample_(signal, samplerate, 256)
228
  #print(signal.shape)
229
  # FIR_filter
230
  signal = FIR_filter(signal, 1, 50)
 
223
  signal = read_train_data(filepath+filename)
224
  #print(signal.shape)
225
  # resample
226
+ signal = resample(signal, samplerate) #signal = resample_(signal, samplerate, 256)
 
227
  #print(signal.shape)
228
  # FIR_filter
229
  signal = FIR_filter(signal, 1, 50)