audrey06100 commited on
Commit
a22369d
·
1 Parent(s): 7a54f74
Files changed (3) hide show
  1. app.py +220 -183
  2. channel_mapping.py +237 -241
  3. utils.py +40 -3
app.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
2
  import numpy as np
3
  import os
4
  import random
 
 
5
  import utils
6
- from channel_mapping import mapping, reorder_data
7
 
8
  import mne
9
  from mne.channels import read_custom_montage
@@ -11,12 +13,9 @@ from mne.channels import read_custom_montage
11
  quickstart = """
12
  # Quickstart
13
 
14
- ## 1. Channel mapping
15
-
16
  ### Raw data
17
  1. The data need to be a two-dimensional array (channel, timepoint).
18
- 2. Make sure you have **resampled** your data to **256 Hz**.
19
- 3. Upload your EEG data in `.csv` format.
20
 
21
  ### Channel locations
22
  Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
@@ -27,29 +26,12 @@ The models was trained using the EEG signals of 30 channels, including: `Fp1, Fp
27
  We expect your input data to include these channels as well.
28
  If your data doesn't contain all of the mentioned channels, there are 3 imputation ways you can choose from:
29
 
30
- <u>Manually</u>:
31
- - **mean**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, you didn't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it(depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.
32
-
33
- <u>Automatically</u>:
34
- Firstly, we will attempt to find neighboring channel to use as alternative. For instance, if the required channel is **FC3** but you only have **FC1**, we will use it as a replacement for **FC3**.
35
- Then, depending on the **Imputation** way you chose, we will:
36
  - **zero**: fill the missing channels with zeros.
37
- - **adjacent**: fill the missing channels using neighboring channels which are located closer to the center. For example, if the required channel is **FC3** but you only have **F3, C3**, then we will choose **C3** as the imputing value for **FC3**.
38
- >Note: The imputed channels **need to be removed** after the data being reconstructed.
39
 
40
  ### Mapping result
41
- Once the mapping process is finished, the **template montage** and the **input montage**(with the channels choosen by the mapping function displaying their names) will be shown.
42
-
43
- ### Missing channels
44
- The channels displayed here are those for which the template didn't find suitable channels to use, and utilized **Imputation** to fill the missing values.
45
- Therefore, you need to
46
- <span style="color:red">**remove these channels**</span>
47
- after you download the denoised data.
48
-
49
- ### Template location file
50
- You need to use this as the **new location file** for the denoised data.
51
-
52
- ## 2. Decode data
53
 
54
  ### Model
55
  Select the model you want to use.
@@ -68,20 +50,43 @@ chkbox_js = """
68
  state_json = JSON.parse(JSON.stringify(state_json));
69
  if(state_json.state == "finished") return;
70
 
71
- document.querySelector("#chs-chkbox>div:nth-of-type(2)").style.cssText = `
 
72
  position: relative;
73
  width: 560px;
74
  height: 560px;
75
  background: url("file=${state_json.files.raw_montage}");
76
  `;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- let all_chkbox = document.querySelectorAll("#chs-chkbox> div:nth-of-type(2)> label");
 
79
  all_chkbox = Array.apply(null, all_chkbox);
80
 
81
  all_chkbox.forEach((item, index) => {
82
- let channel = state_json.inputByIndex[index];
83
- let left = state_json.inputByName[channel].css_position[0];
84
- let bottom = state_json.inputByName[channel].css_position[1];
85
  //console.log(`left: ${left}, bottom: ${bottom}`);
86
 
87
  item.style.cssText = `
@@ -90,42 +95,73 @@ chkbox_js = """
90
  bottom: ${bottom};
91
  `;
92
  item.className = "";
93
- item.querySelector("span").innerText = "";
94
- });
 
 
 
 
 
 
 
 
 
 
 
 
95
 
 
 
 
 
 
 
 
 
 
96
  }
97
  """
98
 
99
 
100
  with gr.Blocks() as demo:
101
 
102
- state_json = gr.JSON(elem_id="state", visible=False)
103
 
104
  with gr.Row():
105
  gr.Markdown(
106
  """
107
-
108
  """
109
  )
110
  with gr.Row():
 
111
  with gr.Column():
112
  gr.Markdown(
113
  """
114
  # 1.Channel Mapping
115
  """
116
  )
 
 
117
  with gr.Row():
118
  in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
119
  in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
120
- with gr.Row():
121
- in_fill_mode = gr.Dropdown(choices=["zero",
122
- ("adjacent channel", "adjacent"),
123
- ("mean (manually select channels)", "mean")],
124
- value="zero",
125
- label="Imputation",
126
- scale=2)
127
- map_btn = gr.Button("Mapping", scale=1)
128
- channels_json = gr.JSON(visible=False)
 
 
 
 
 
 
 
129
  res_md = gr.Markdown(
130
  """
131
  ### Mapping result:
@@ -134,11 +170,8 @@ with gr.Blocks() as demo:
134
  )
135
  with gr.Row():
136
  tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
137
- map_montage = gr.Image(label="Choosen channels", visible=False)
138
- chs_chkbox = gr.CheckboxGroup(elem_id="chs-chkbox", label="", visible=False)
139
- next_btn = gr.Button("Next", interactive=False, visible=False)
140
- miss_txtbox = gr.Textbox(label="Missing channels", visible=False)
141
- tpl_loc_file = gr.File("./template_chanlocs.loc", show_label=False, visible=False)
142
  with gr.Column():
143
  gr.Markdown(
144
  """
@@ -146,29 +179,36 @@ with gr.Blocks() as demo:
146
  """
147
  )
148
  with gr.Row():
149
- in_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART", "(mapped data)"],
150
- value="ICUNet",
 
 
 
 
 
 
151
  label="Model",
152
  scale=2)
153
  run_btn = gr.Button(scale=1, interactive=False)
 
154
  out_denoised_data = gr.File(label="Denoised data")
155
 
156
 
157
  with gr.Row():
158
- with gr.Tab("EEGART"):
159
  gr.Markdown()
160
  with gr.Tab("IC-U-Net"):
161
  gr.Markdown(icunet)
162
  with gr.Tab("IC-U-Net++"):
163
  gr.Markdown()
164
- with gr.Tab("IC-U-Net-Att"):
165
  gr.Markdown()
166
  with gr.Tab("QuickStart"):
167
  gr.Markdown(quickstart)
168
 
169
  #demo.load(js=js)
170
 
171
- def reset_layout(raw_data):
172
  # establish temp folder
173
  filepath = os.path.dirname(str(raw_data))
174
  try:
@@ -177,51 +217,53 @@ with gr.Blocks() as demo:
177
  utils.dataDelete(filepath+"/temp_data/")
178
  os.mkdir(filepath+"/temp_data/")
179
  #print(e)
180
- state_obj = {
 
 
181
  "filepath": filepath+"/temp_data/",
182
- "files": {}
 
 
183
  }
184
- return {state_json : state_obj,
185
- chs_chkbox : gr.CheckboxGroup(choices=[], value=[], label="", visible=False), # choices, value ???
186
  next_btn : gr.Button("Next", interactive=False, visible=False),
187
  run_btn : gr.Button(interactive=False),
188
  tpl_montage : gr.Image(visible=False),
189
  map_montage : gr.Image(value=None, visible=False),
190
- miss_txtbox : gr.Textbox(visible=False),
191
  res_md : gr.Markdown(visible=False),
192
- tpl_loc_file : gr.File(visible=False)}
193
 
194
- def mapping_result(state_obj, channels_obj, raw_data, fill_mode):
195
- state_obj.update(channels_obj)
 
 
 
 
 
 
 
 
196
 
197
- if fill_mode=="mean" and channels_obj["missingChannelsIndex"]!=[]:
198
- state_obj.update({
199
  "state" : "initializing",
200
  "fillingCount" : 0,
201
- "totalFillingNum" : len(channels_obj["missingChannelsIndex"])-1
202
  })
203
- #print("Missing channels:", state_obj["missingChannelsIndex"])
204
- return {state_json : state_obj,
 
205
  next_btn : gr.Button(visible=True)}
206
  else:
207
- reorder_data(raw_data, channels_obj["newOrder"], fill_mode, state_obj)
208
-
209
- missing_channels = [state_obj["templateByIndex"][idx] for idx in state_obj["missingChannelsIndex"]]
210
- missing_channels = ', '.join(missing_channels)
211
-
212
- state_obj.update({
213
- "state" : "finished",
214
- #"fillingCount" : -1,
215
- #"totalFillingNum" : -1
216
- })
217
- return {state_json : state_obj,
218
  res_md : gr.Markdown(visible=True),
219
- miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
220
- tpl_loc_file : gr.File(visible=True),
221
  run_btn : gr.Button(interactive=True)}
222
 
223
- def show_montage(state_obj, raw_loc):
224
- filepath = state_obj["filepath"]
225
  raw_montage = read_custom_montage(raw_loc)
226
 
227
  # convert all channel names to uppercase
@@ -229,73 +271,63 @@ with gr.Blocks() as demo:
229
  channel = raw_montage.ch_names[i]
230
  raw_montage.rename_channels({channel: str.upper(channel)})
231
 
232
- if state_obj["state"] == "initializing":
233
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
234
- state_obj["files"]["raw_montage"] = filename
235
  raw_fig = raw_montage.plot()
236
  raw_fig.set_size_inches(5.6, 5.6)
237
  raw_fig.savefig(filename, pad_inches=0)
238
 
239
- return {state_json : state_obj}#,
240
- #tpl_montage : gr.Image(visible=True),
241
- #in_montage : gr.Image(value=filename, visible=True),
242
- #map_montage : gr.Image(visible=False)}
243
 
244
- elif state_obj["state"] == "finished":
245
- # didn't find any way to hide the dark points...
246
- # tmp
247
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
248
- state_obj["files"]["map_montage"] = filename
249
 
250
  show_names= []
251
- for channel in state_obj["inputByName"]:
252
- if state_obj["inputByName"][channel]["used"]:
253
- if channel=='CZ' and state_obj["CZImputed"]:
254
- continue
255
  show_names.append(channel)
256
  mapped_fig = raw_montage.plot(show_names=show_names)
257
  mapped_fig.set_size_inches(5.6, 5.6)
258
  mapped_fig.savefig(filename, pad_inches=0)
259
 
260
- return {state_json : state_obj,
261
  tpl_montage : gr.Image(visible=True),
262
  map_montage : gr.Image(value=filename, visible=True)}
263
 
264
- elif state_obj["state"] == "selecting":
265
- # update in_montage here ?
266
- #return {in_montage : gr.Image()}
267
- return {state_json : state_obj}
268
 
269
- def generate_chkbox(state_obj):
270
- if state_obj["state"] == "initializing":
271
- in_channels = [channel for channel in state_obj["inputByName"]]
272
- state_obj["state"] = "selecting"
273
 
274
- first_idx = state_obj["missingChannelsIndex"][0]
275
- first_name = state_obj["templateByIndex"][first_idx]
276
- chkbox_label = first_name+' (1/'+str(state_obj["totalFillingNum"]+1)+')'
277
- return {state_json : state_obj,
278
- chs_chkbox : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
279
  next_btn : gr.Button(interactive=True)}
280
  else:
281
- return {state_json : state_obj}
282
 
283
 
284
  map_btn.click(
285
  fn = reset_layout,
286
- inputs = in_raw_data,
287
- outputs = [state_json, chs_chkbox, next_btn, run_btn, tpl_montage, map_montage, miss_txtbox,
288
- res_md, tpl_loc_file]
289
-
290
  ).success(
291
- fn = mapping,
292
- inputs = [in_raw_data, in_raw_loc, in_fill_mode],
293
- outputs = channels_json
294
 
295
  ).success(
296
  fn = mapping_result,
297
- inputs = [state_json, channels_json, in_raw_data, in_fill_mode],
298
- outputs = [state_json, chs_chkbox, next_btn, miss_txtbox, res_md, tpl_loc_file, run_btn]
299
 
300
  ).success(
301
  fn = show_montage,
@@ -305,7 +337,8 @@ with gr.Blocks() as demo:
305
  ).success(
306
  fn = generate_chkbox,
307
  inputs = state_json,
308
- outputs = [state_json, chs_chkbox, next_btn]
 
309
  ).success(
310
  fn = None,
311
  js = chkbox_js,
@@ -314,81 +347,85 @@ with gr.Blocks() as demo:
314
  )
315
 
316
 
317
- def check_next(state_obj, selected, raw_data, fill_mode):
318
- if state_obj["state"] == "selecting":
319
 
320
- # save info before clicking on next_btn
321
- prev_target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
322
- prev_target_name = state_obj["templateByIndex"][prev_target_idx]
323
-
324
- selected_idx = [state_obj["inputByName"][channel]["index"] for channel in selected]
325
- state_obj["newOrder"][prev_target_idx] = selected_idx
326
-
327
- if len(selected)==1 and state_obj["inputByName"][selected[0]]["used"]==False:
328
- state_obj["inputByName"][selected[0]]["used"] = True
329
- state_obj["missingChannelsIndex"][state_obj["fillingCount"]] = -1
330
-
331
- print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
 
 
 
 
 
 
 
 
332
 
333
- # update next round
334
- state_obj["fillingCount"] += 1
335
- if state_obj["fillingCount"] <= state_obj["totalFillingNum"]:
336
- target_idx = state_obj["missingChannelsIndex"][state_obj["fillingCount"]]
337
- target_name = state_obj["templateByIndex"][target_idx]
338
- chkbox_label = target_name+' ('+str(state_obj["fillingCount"]+1)+'/'+str(state_obj["totalFillingNum"]+1)+')'
339
- btn_label = "Submit" if state_obj["fillingCount"]==state_obj["totalFillingNum"] else "Next"
340
-
341
- return {state_json : state_obj,
342
- chs_chkbox : gr.CheckboxGroup(value=[], label=chkbox_label),
343
- next_btn : gr.Button(btn_label)}
344
- else:
345
- state_obj["state"] = "finished"
346
- reorder_data(raw_data, state_obj["newOrder"], fill_mode, state_obj)
347
-
348
- missing_channels = []
349
- for idx in state_obj["missingChannelsIndex"]:
350
- if idx != -1:
351
- missing_channels.append(state_obj["templateByIndex"][idx])
352
- missing_channels = ', '.join(missing_channels)
353
-
354
- return {state_json : state_obj,
355
- chs_chkbox : gr.CheckboxGroup(visible=False),
356
- next_btn : gr.Button(visible=False),
357
- res_md : gr.Markdown(visible=True),
358
- miss_txtbox : gr.Textbox(value=missing_channels, visible=True),
359
- tpl_loc_file : gr.File(visible=True),
360
- run_btn : gr.Button(interactive=True)}
361
 
362
  next_btn.click(
363
- fn = check_next,
364
- inputs = [state_json, chs_chkbox, in_raw_data, in_fill_mode],
365
- outputs = [state_json, chs_chkbox, next_btn, run_btn, res_md, miss_txtbox, tpl_loc_file]
366
-
367
  ).success(
368
  fn = show_montage,
369
  inputs = [state_json, in_raw_loc],
370
  outputs = [state_json, tpl_montage, map_montage]
371
  )
372
-
373
-
374
- @run_btn.click(inputs=[state_json, in_raw_data, in_model_name], outputs=out_denoised_data)
375
- def run_model(state_obj, raw_file, model_name):
376
- filepath = state_obj["filepath"]
377
-
378
- input_name = os.path.basename(str(raw_file))
 
 
 
379
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
380
 
381
- if model_name == "(mapped data)":
382
- return filepath + 'mapped.csv'
383
-
384
- # step1: Data preprocessing
385
- total_file_num = utils.preprocessing(filepath, 'mapped.csv', 256)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
- # step2: Signal reconstruction
388
- utils.reconstruct(model_name, total_file_num, filepath, output_name)
389
 
390
- return filepath + output_name
391
-
392
-
393
  if __name__ == "__main__":
394
  demo.launch()
 
2
  import numpy as np
3
  import os
4
  import random
5
+ import math
6
+ import json
7
  import utils
8
+ from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
9
 
10
  import mne
11
  from mne.channels import read_custom_montage
 
13
  quickstart = """
14
  # Quickstart
15
 
 
 
16
  ### Raw data
17
  1. The data need to be a two-dimensional array (channel, timepoint).
18
+ 2. Upload your EEG data in `.csv` format.
 
19
 
20
  ### Channel locations
21
  Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
 
26
  We expect your input data to include these channels as well.
27
  If your data doesn't contain all of the mentioned channels, there are 3 imputation ways you can choose from:
28
 
 
 
 
 
 
 
29
  - **zero**: fill the missing channels with zeros.
30
+ - **mean(auto)**: select 4 neareat channels for each missing channels, and we will average their values.
31
+ - **mean(manual)**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, you didn't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it(depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.
32
 
33
  ### Mapping result
34
+ Once the mapping process is finished, the **template montage** and the **input montage**(with the matched channels displaying their names) will be shown.
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ### Model
37
  Select the model you want to use.
 
50
  state_json = JSON.parse(JSON.stringify(state_json));
51
  if(state_json.state == "finished") return;
52
 
53
+ // add figure of in_montage
54
+ document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
55
  position: relative;
56
  width: 560px;
57
  height: 560px;
58
  background: url("file=${state_json.files.raw_montage}");
59
  `;
60
+
61
+ // add indication for the missing channel
62
+ /*
63
+ let indicator = document.getElementById("indicator")
64
+ if(!indicator) document.querySelector("#chkbox-group> div:nth-of-type(2)").innerHTML += '<div id="indicator"></div>'
65
+
66
+ let channel = state_json.missingChannelsIndex[0]
67
+ channel = state_json.templateByIndex[channel]
68
+ let left = state_json.templateByName[channel].css_position[0];
69
+ let bottom = state_json.templateByName[channel].css_position[1];
70
+
71
+ document.getElementById("red-dot").style.cssText = `
72
+ position: absolute;
73
+ background-color: red;
74
+ width: 10px;
75
+ height: 10px;
76
+ border-radius: 50%;
77
+ left: ${left};
78
+ bottom: ${bottom};
79
+ `;
80
+ */
81
 
82
+ // move the checkboxes
83
+ let all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
84
  all_chkbox = Array.apply(null, all_chkbox);
85
 
86
  all_chkbox.forEach((item, index) => {
87
+ channel = state_json.inputByIndex[index];
88
+ left = state_json.inputByName[channel].css_position[0];
89
+ bottom = state_json.inputByName[channel].css_position[1];
90
  //console.log(`left: ${left}, bottom: ${bottom}`);
91
 
92
  item.style.cssText = `
 
95
  bottom: ${bottom};
96
  `;
97
  item.className = "";
98
+ item.querySelector(":scope> span").innerText = "";
99
+ });
100
+ }
101
+ """
102
+
103
+ dot_js = """
104
+ (state_json) => {
105
+ state_json = JSON.parse(JSON.stringify(state_json));
106
+ if(state_json.state == "finished") return;
107
+
108
+ let channel = state_json.missingChannelsIndex[state_json["fillingCount"]]
109
+ channel = state_json.templateByIndex[channel]
110
+ let left = state_json.templateByName[channel].css_position[0];
111
+ let bottom = state_json.templateByName[channel].css_position[1];
112
 
113
+ document.getElementById("indicator").style.cssText = `
114
+ position: absolute;
115
+ background-color: red;
116
+ width: 10px;
117
+ height: 10px;
118
+ border-radius: 50%;
119
+ left: ${left};
120
+ bottom: ${bottom};
121
+ `;
122
  }
123
  """
124
 
125
 
126
  with gr.Blocks() as demo:
127
 
128
+ state_json = gr.JSON(visible=False)
129
 
130
  with gr.Row():
131
  gr.Markdown(
132
  """
133
+ <p style="text-align: center;">(...)</p>
134
  """
135
  )
136
  with gr.Row():
137
+
138
  with gr.Column():
139
  gr.Markdown(
140
  """
141
  # 1.Channel Mapping
142
  """
143
  )
144
+
145
+ # upload files, chose imputation way (???
146
  with gr.Row():
147
  in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
148
  in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
149
+ with gr.Column(min_width=100):
150
+ in_sample_rate = gr.Textbox(label="Sampling rate (Hz)")
151
+ in_fill_mode = gr.Dropdown(choices=[
152
+ #("adjacent channel", "adjacent"),
153
+ ("mean (auto)", "mean_auto"),
154
+ ("mean (manual)", "mean_manual"),
155
+ ("",""),
156
+ "zero"],
157
+ value="mean_auto",
158
+ label="Imputation")
159
+ map_btn = gr.Button("Mapping")
160
+
161
+ chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
162
+ next_btn = gr.Button("Next", interactive=False, visible=False)
163
+
164
+ # mapping result
165
  res_md = gr.Markdown(
166
  """
167
  ### Mapping result:
 
170
  )
171
  with gr.Row():
172
  tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
173
+ map_montage = gr.Image(label="Matched channels", visible=False)
174
+
 
 
 
175
  with gr.Column():
176
  gr.Markdown(
177
  """
 
179
  """
180
  )
181
  with gr.Row():
182
+ in_model_name = gr.Dropdown(choices=[
183
+ ("ART", "EEGART"),
184
+ ("IC-U-Net", "ICUNet"),
185
+ ("IC-U-Net++", "UNetpp"),
186
+ ("IC-U-Net-Attn", "AttUnet"),
187
+ "(mapped data)",
188
+ "(denoised data)"],
189
+ value="EEGART",
190
  label="Model",
191
  scale=2)
192
  run_btn = gr.Button(scale=1, interactive=False)
193
+ batch_md = gr.Markdown(visible=False)
194
  out_denoised_data = gr.File(label="Denoised data")
195
 
196
 
197
  with gr.Row():
198
+ with gr.Tab("ART"):
199
  gr.Markdown()
200
  with gr.Tab("IC-U-Net"):
201
  gr.Markdown(icunet)
202
  with gr.Tab("IC-U-Net++"):
203
  gr.Markdown()
204
+ with gr.Tab("IC-U-Net-Attn"):
205
  gr.Markdown()
206
  with gr.Tab("QuickStart"):
207
  gr.Markdown(quickstart)
208
 
209
  #demo.load(js=js)
210
 
211
+ def reset_layout(raw_data, samplerate):
212
  # establish temp folder
213
  filepath = os.path.dirname(str(raw_data))
214
  try:
 
217
  utils.dataDelete(filepath+"/temp_data/")
218
  os.mkdir(filepath+"/temp_data/")
219
  #print(e)
220
+
221
+ data = utils.read_train_data(raw_data)
222
+ state = {
223
  "filepath": filepath+"/temp_data/",
224
+ "files": {},
225
+ "sampleRate": int(samplerate),
226
+ "dataShape" : data.shape
227
  }
228
+ return {state_json : state,
229
+ chkbox_group : gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
230
  next_btn : gr.Button("Next", interactive=False, visible=False),
231
  run_btn : gr.Button(interactive=False),
232
  tpl_montage : gr.Image(visible=False),
233
  map_montage : gr.Image(value=None, visible=False),
 
234
  res_md : gr.Markdown(visible=False),
235
+ batch_md : gr.Markdown(visible=False)}
236
 
237
+ def mapping_result(state, fill_mode):
238
+
239
+ in_num = len(state["inputByName"])
240
+ matched_num = 30 - len(state["missingChannelsIndex"])
241
+ batch_num = math.ceil((in_num-matched_num)/30) + 1
242
+ state.update({
243
+ "runnigState" : "stage1",
244
+ "batchCount" : 1,
245
+ "totalBatchNum" : batch_num
246
+ })
247
 
248
+ if fill_mode=="mean_manual" and state["missingChannelsIndex"]!=[]:
249
+ state.update({
250
  "state" : "initializing",
251
  "fillingCount" : 0,
252
+ "totalFillingNum" : len(state["missingChannelsIndex"])-1
253
  })
254
+ #print("Missing channels:", state["missingChannelsIndex"])
255
+ return {state_json : state,
256
+ #chkbox_group : gr.CheckboxGroup(visible=True),
257
  next_btn : gr.Button(visible=True)}
258
  else:
259
+ state["state"] = "finished"
260
+
261
+ return {state_json : state,
 
 
 
 
 
 
 
 
262
  res_md : gr.Markdown(visible=True),
 
 
263
  run_btn : gr.Button(interactive=True)}
264
 
265
+ def show_montage(state, raw_loc):
266
+ filepath = state["filepath"]
267
  raw_montage = read_custom_montage(raw_loc)
268
 
269
  # convert all channel names to uppercase
 
271
  channel = raw_montage.ch_names[i]
272
  raw_montage.rename_channels({channel: str.upper(channel)})
273
 
274
+ if state["state"] == "initializing":
275
  filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
276
+ state["files"]["raw_montage"] = filename
277
  raw_fig = raw_montage.plot()
278
  raw_fig.set_size_inches(5.6, 5.6)
279
  raw_fig.savefig(filename, pad_inches=0)
280
 
281
+ return {state_json : state}
 
 
 
282
 
283
+ elif state["state"] == "finished":
 
 
284
  filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
285
+ state["files"]["map_montage"] = filename
286
 
287
  show_names= []
288
+ for channel in state["inputByName"]:
289
+ if state["inputByName"][channel]["matched"]:
 
 
290
  show_names.append(channel)
291
  mapped_fig = raw_montage.plot(show_names=show_names)
292
  mapped_fig.set_size_inches(5.6, 5.6)
293
  mapped_fig.savefig(filename, pad_inches=0)
294
 
295
+ return {state_json : state,
296
  tpl_montage : gr.Image(visible=True),
297
  map_montage : gr.Image(value=filename, visible=True)}
298
 
299
+ else:
300
+ return {state_json : state} # change nothing
 
 
301
 
302
+ def generate_chkbox(state):
303
+ if state["state"] == "initializing":
304
+ in_channels = [channel for channel in state["inputByName"]]
305
+ state["state"] = "selecting"
306
 
307
+ first_idx = state["missingChannelsIndex"][0]
308
+ first_name = state["templateByIndex"][first_idx]
309
+ chkbox_label = first_name+' (1/'+str(state["totalFillingNum"]+1)+')'
310
+ return {state_json : state,
311
+ chkbox_group : gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
312
  next_btn : gr.Button(interactive=True)}
313
  else:
314
+ return {state_json : state} # change nothing
315
 
316
 
317
  map_btn.click(
318
  fn = reset_layout,
319
+ inputs = [in_raw_data, in_sample_rate],
320
+ outputs = [state_json, chkbox_group, next_btn, run_btn, tpl_montage, map_montage, res_md, batch_md]
321
+
 
322
  ).success(
323
+ fn = mapping_stage1,
324
+ inputs = [state_json, in_raw_data, in_raw_loc, in_fill_mode],
325
+ outputs = state_json
326
 
327
  ).success(
328
  fn = mapping_result,
329
+ inputs = [state_json, in_fill_mode],
330
+ outputs = [state_json, next_btn, res_md, run_btn]
331
 
332
  ).success(
333
  fn = show_montage,
 
337
  ).success(
338
  fn = generate_chkbox,
339
  inputs = state_json,
340
+ outputs = [state_json, chkbox_group, next_btn]
341
+
342
  ).success(
343
  fn = None,
344
  js = chkbox_js,
 
347
  )
348
 
349
 
350
+ def check_next(state, selected, raw_data, fill_mode):
351
+ #if state["state"] == "selecting":
352
 
353
+ # save info before clicking on next_btn
354
+ prev_target_idx = state["missingChannelsIndex"][state["fillingCount"]]
355
+ prev_target_name = state["templateByIndex"][prev_target_idx]
356
+
357
+ selected_idx = [state["inputByName"][channel]["index"] for channel in selected]
358
+ state["newOrder"][prev_target_idx] = selected_idx
359
+
360
+ #if len(selected)==1 and state["inputByName"][selected[0]]["used"]==False:
361
+ #state["inputByName"][selected[0]]["used"] = True
362
+ #state["missingChannelsIndex"][state["fillingCount"]] = -1
363
+
364
+ print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
365
+
366
+ # update next round
367
+ state["fillingCount"] += 1
368
+ if state["fillingCount"] <= state["totalFillingNum"]:
369
+ target_idx = state["missingChannelsIndex"][state["fillingCount"]]
370
+ target_name = state["templateByIndex"][target_idx]
371
+ chkbox_label = target_name+' ('+str(state["fillingCount"]+1)+'/'+str(state["totalFillingNum"]+1)+')'
372
+ btn_label = "Submit" if state["fillingCount"]==state["totalFillingNum"] else "Next"
373
 
374
+ return {state_json : state,
375
+ chkbox_group : gr.CheckboxGroup(value=[], label=chkbox_label),
376
+ next_btn : gr.Button(btn_label)}
377
+ else:
378
+ state["state"] = "finished"
379
+ return {state_json : state,
380
+ chkbox_group : gr.CheckboxGroup(visible=False),
381
+ next_btn : gr.Button(visible=False),
382
+ res_md : gr.Markdown(visible=True),
383
+ run_btn : gr.Button(interactive=True)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
385
  next_btn.click(
386
+ fn = check_next,
387
+ inputs = [state_json, chkbox_group, in_raw_data, in_fill_mode],
388
+ outputs = [state_json, chkbox_group, next_btn, run_btn, res_md]
389
+
390
  ).success(
391
  fn = show_montage,
392
  inputs = [state_json, in_raw_loc],
393
  outputs = [state_json, tpl_montage, map_montage]
394
  )
395
+
396
+ @run_btn.click(inputs = [state_json, in_raw_data, in_model_name, in_fill_mode], outputs = out_denoised_data)
397
+ def run_model(state, raw_data, model_name, fill_mode):
398
+ filepath = state["filepath"]
399
+ samplerate = state["sampleRate"]
400
+
401
+ #if batch > total_batch:
402
+ #return {batch_md : gr.Markdown("error", visible=True)}
403
+
404
+ input_name = os.path.basename(str(raw_data))
405
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
406
 
407
+ while(state["runnigState"] != "finished"):
408
+ if state["batchCount"] > state["totalBatchNum"]:
409
+ break
410
+ if state["batchCount"] > 1:
411
+ state["runnigState"] = "stage2"
412
+ state = mapping_stage2(state, fill_mode)
413
+ state["batchCount"] += 1
414
+
415
+ reorder_to_template(state, raw_data)
416
+ # step1: Data preprocessing
417
+ total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
418
+ # step2: Signal reconstruction
419
+ utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
420
+ reorder_to_origin(state, filepath+'denoised.csv', filepath+output_name)
421
+
422
+ if model_name == "(mapped data)":
423
+ return {out_denoised_data : filepath + 'mapped.csv'}
424
+ elif model_name == "(denoised data)":
425
+ return {out_denoised_data : filepath + 'denoised.csv'}
426
 
427
+ return {out_denoised_data : filepath + output_name}
 
428
 
429
+
 
 
430
  if __name__ == "__main__":
431
  demo.launch()
channel_mapping.py CHANGED
@@ -2,15 +2,19 @@ import utils
2
  import time
3
  import os
4
  import numpy as np
 
5
 
6
  import mne
7
  from mne.channels import read_custom_montage
8
-
9
- def reorder_data(filename, old_idx, fill_mode, state_obj):
10
- old_data = utils.read_train_data(filename)
11
- new_data = np.zeros((30, old_data.shape[1]))
12
- new_filename = state_obj["filepath"]+'mapped.csv'
13
- #print('old data shape: ', old_data.shape)
 
 
 
14
 
15
  zero_arr = np.zeros((1, old_data.shape[1]))
16
  old_data = np.concatenate((old_data, zero_arr), axis=0)
@@ -24,286 +28,278 @@ def reorder_data(filename, old_idx, fill_mode, state_obj):
24
  else:
25
  tmp_data = [old_data[j, :] for j in curr_idx_set]
26
  new_data[i, :] = np.mean(tmp_data, axis=0)
27
-
28
- #print('new data shape: ', new_data.shape)
29
  utils.save_data(new_data, new_filename)
30
  return
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  class Channel:
34
 
35
- def __init__(self, index, name=None, used=False, coord=None, css_position=None, topo_index=None, topo_position=None):
36
-
37
  self.name = name
38
  self.index = index
39
- self.used = used
 
40
  self.coord = coord
41
  self.css_position = css_position
42
- self.topo_index = topo_index
43
- self.topo_position = topo_position
44
-
45
- def prefix(self):
46
- ret = ''.join(filter(str.isalpha, self.name))
47
- return ret[:len(ret) - 1] if ret[-1] == 'Z' else ret
48
 
49
- def suffix(self):
50
- return -1 if self.name[-1] == 'Z' else int(''.join(filter(str.isdigit, self.name)))
51
 
52
-
53
- def pack_data(new_idx, missing_channels, tpl_dict, in_dict, tpl_ordered_name, in_ordered_name):
54
-
55
- return {
56
- "newOrder" : [([i] if i!=-1 else []) for i in new_idx],
57
- "missingChannelsIndex" : missing_channels,
58
- "templateByName" : {k : v.__dict__ for k,v in tpl_dict.items()}, # dict, {name:object}
59
- "templateByIndex" : tpl_ordered_name, # list
60
- "inputByName" : {k : v.__dict__ for k,v in in_dict.items()},
61
- "inputByIndex" : in_ordered_name
62
- }
63
-
64
- def mapping(data_file, loc_file, fill_mode):
65
- second1 = time.time()
66
-
67
- data = utils.read_train_data(data_file)
68
-
69
- template_dict = {}
70
- input_dict = {}
71
  template_montage = read_custom_montage("./template_chanlocs.loc")
72
  input_montage = read_custom_montage(loc_file)
 
 
73
 
74
  montages = [template_montage, input_montage]
75
  dicts = [template_dict, input_dict]
76
  num = [30, len(input_montage.ch_names)]
77
 
78
  for i in range(2):
79
- fig = montages[i].plot()
80
- fig.set_size_inches(5.6, 5.6)
81
- ax = fig.axes[0]
82
- ax.set_aspect('equal')
83
- ax.figure.canvas.draw() #update the figure
84
- coords = ax.collections[0].get_offsets().data
85
- abs_coords = ax.transData.transform(coords)
86
- #print("abs_coords)
87
  for j in range(num[i]):
88
  channel = montages[i].ch_names[j]
 
89
 
90
- # convert all channel names to uppercase
91
- montages[i].rename_channels({channel: str.upper(channel)})
92
-
93
- css_left = (abs_coords[j][0]-11)/560
94
- css_bottom = (abs_coords[j][1]-7)/560
95
  channel = str.upper(channel)
96
- dicts[i][channel] = Channel(index=j,
97
- name=channel,
98
- coord=montages[i].get_positions()['ch_pos'][channel],
99
- css_position=[str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
100
- )
 
 
 
 
 
 
101
 
 
 
 
 
102
 
103
- new_idx = [-1]*30
104
- missing_channels = []
105
- exact_missing_channels = []
106
- z_row_idx = data.shape[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
 
 
 
 
108
 
109
- # STAGE_1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- # match the template's channel names with the input ones
112
- finish_flag = 1
 
 
113
  alias = {
114
  'T3': 'T7',
115
  'T4': 'T8',
116
  'T5': 'P7',
117
  'T6': 'P8',
118
- 'TP7': 'T5\'',
119
- 'TP8': 'T6\'',
120
  }
121
 
122
- for i in range(30):
123
- channel = template_montage.ch_names[i]
124
- if channel not in input_dict.keys() | alias.keys():
125
- exact_missing_channels.append(i)
126
- finish_flag = 0
127
- continue
128
-
129
- if channel not in input_dict and channel in alias:
130
- if alias[channel] in input_dict:
131
- template_montage.rename_channels({channel: alias[channel]})
132
- template_dict[alias[channel]] = template_dict.pop(channel)
133
- template_dict[alias[channel]].name = alias[channel]
134
- channel = alias[channel]
135
- else:
136
- exact_missing_channels.append(i)
137
- finish_flag = 0
138
- continue
139
-
140
- new_idx[i] = input_dict[channel].index
141
- input_dict[channel].used = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
- if finish_flag == 1:
144
- second2 = time.time()
145
- print('Finish at stage 1 ! (',second2 - second1,'s)')
146
- #print('new idx order:', new_idx)
147
 
148
- channels_obj = pack_data(new_idx, [],
149
- template_dict, input_dict,
150
- template_montage.ch_names, input_montage.ch_names)
151
- channels_obj.update({"CZImputed" : False})
152
- return channels_obj
153
-
154
- elif fill_mode == "mean":
155
- channels_obj = pack_data(new_idx, exact_missing_channels,
156
- template_dict, input_dict,
157
- template_montage.ch_names, input_montage.ch_names)
158
- channels_obj.update({"CZImputed" : False})
159
- return channels_obj
160
 
 
 
161
 
 
 
 
162
 
 
 
 
 
 
 
 
 
163
 
164
- # STAGE_2
 
165
 
166
- # store channel positions in a 2-d array
167
- template_topo_pos = []
168
- temporal_channels = []
169
- temporal_row_prefix = ['FC', 'C', 'CP', 'P']
170
-
171
- cnt = 0
172
- for i in range(7):
173
- tmp = []
174
- for j in range(5):
175
- if [i,j] in [[0,0],[0,2],[0,4],[6,0],[6,4]]:
176
- tmp.append('')
177
- else:
178
- channel = template_montage.ch_names[cnt]
179
- tmp.append(channel)
180
-
181
- ver = 'front' if i<3 else 'center' if i==3 else 'back'
182
- hor = 'left' if j<2 else 'center' if j==2 else 'right'
183
- template_dict[channel].topo_index = [i, j]
184
- template_dict[channel].topo_position = [ver, hor]
185
-
186
- if i > 1 and j in [0, 4]:
187
- temporal_channels.append(channel)
188
- cnt += 1
189
- template_topo_pos.append(tmp)
190
-
191
-
192
- # ensure that CZ is found or imputed by another channel
193
- CZ_impute_flag = False
194
- if 'CZ' not in input_dict and fill_mode=='adjacent':
195
- CZ_impute_flag = True
196
- min_dist = 1e5
197
- for channel in input_montage.ch_names:
198
- curr_x, curr_y, curr_z = input_dict[channel].coord.round(6)
199
- if curr_x**2 + curr_y**2 < min_dist:
200
- nearest_channel = channel
201
- min_dist = curr_x**2 + curr_y**2
202
-
203
- if input_dict[nearest_channel].used == True:
204
- missing_channels.append(template_dict['CZ'].index)
205
- input_dict[nearest_channel].used = True
206
- input_dict['CZ'] = input_dict[nearest_channel]
207
- print("CZ's nearest neighbor:", nearest_channel)
208
-
209
-
210
  for i in range(30):
211
- if new_idx[i] != -1:
212
- continue
213
-
214
- channel = template_montage.ch_names[i]
215
-
216
- curr_prefix = template_dict[channel].prefix()
217
- curr_suffix = template_dict[channel].suffix()
218
-
219
- curr_row = template_dict[channel].topo_index[0]
220
- curr_col = template_dict[channel].topo_index[1]
221
- curr_ver = template_dict[channel].topo_position[0]
222
- curr_hor = template_dict[channel].topo_position[1]
223
-
224
- impute_channel = ''
225
-
226
- # if the current channel is a temporal channel
227
- if channel in temporal_channels:
228
- curr_prefix = temporal_row_prefix[temporal_channels.index(channel)//2]
229
- curr_suffix = 7 if curr_hor=='left' else 8
230
-
231
- if fill_mode == 'zero':
232
-
233
- impute_channel = curr_prefix+str(1) if curr_hor=='center' else curr_prefix+str(curr_suffix-2)
234
- if impute_channel not in input_dict or input_dict[impute_channel].used==True:
235
- impute_channel = ''
236
- new_idx[i] = z_row_idx
237
- missing_channels.append(i)
238
- continue
239
-
240
- elif fill_mode == 'adjacent':
241
 
242
- ver_dir = 1 if curr_ver == 'front' else -1
243
-
244
- if curr_hor == 'center': # FZ, FPZ...
245
-
246
- if curr_prefix+str(1) in input_dict: # ex: FZ<-F1
247
- impute_channel = curr_prefix + str(1)
248
-
249
- elif template_topo_pos[curr_row+ver_dir][curr_col] in input_dict: # ex: front:FZ<-FCZ,
250
- impute_channel = template_topo_pos[curr_row+ver_dir][curr_col]
251
-
252
- elif curr_prefix+str(3) in input_dict: # ex: FZ<-F3
253
- impute_channel = curr_prefix + str(3)
254
-
255
- else:
256
- impute_channel = 'CZ'
257
-
258
- elif curr_hor == 'left' or curr_hor == 'right':
259
-
260
- ver_ctrl = 1 if curr_ver=='front' else 2 if curr_ver=='back' else 3 # bit0: row+1, bit1: row-1
261
-
262
- # search horizontally
263
- cnt = 0
264
- tmp_suffix = curr_suffix
265
- while tmp_suffix > 0: # ex: F7<-F5/F3/F1
266
- tmp_suffix = curr_suffix - 2*cnt
267
- if curr_prefix+str(tmp_suffix) in input_dict:
268
- impute_channel = curr_prefix + str(tmp_suffix)
269
- break
270
-
271
- if cnt == 2:
272
- # check row+1/row-1
273
- if ver_ctrl&1 and template_topo_pos[curr_row+1][curr_col] in input_dict:
274
- impute_channel = template_topo_pos[curr_row+1][curr_col]
275
- break
276
- if ver_ctrl&2 and template_topo_pos[curr_row-1][curr_col] in input_dict:
277
- impute_channel = template_topo_pos[curr_row-1][curr_col]
278
- break
279
- cnt += 1
280
-
281
- # search vertically
282
- if impute_channel == '':
283
- cnt = 0
284
- tmp_row = curr_row + ver_dir
285
- while tmp_row-ver_dir != 3: # terminate if the last channel is a middle one
286
- if template_topo_pos[tmp_row][curr_col] in input_dict:
287
- impute_channel = template_topo_pos[tmp_row][curr_col]
288
- break
289
- tmp_row += ver_dir
290
-
291
- # if still cannot find available channel...
292
- if impute_channel == '':
293
- impute_channel = 'CZ'
294
-
295
- new_idx[i] = input_dict[impute_channel].index
296
- if input_dict[impute_channel].used == True: # this channel is shared with others
297
- missing_channels.append(i)
298
- input_dict[impute_channel].used = True
299
 
300
- second2 = time.time()
301
- print('Finish at stage 2 ! (',second2 - second1,'s)')
302
- #print('new_idx:', new_idx)
 
 
303
 
304
- channels_obj = pack_data(new_idx, missing_channels,
305
- template_dict, input_dict,
306
- template_montage.ch_names, input_montage.ch_names)
307
- channels_obj.update({"CZImputed" : CZ_impute_flag})
308
- return channels_obj
309
 
 
 
 
 
 
2
  import time
3
  import os
4
  import numpy as np
5
+ import gradio as gr
6
 
7
  import mne
8
  from mne.channels import read_custom_montage
9
+ from scipy.interpolate import Rbf
10
+ from scipy.optimize import linear_sum_assignment
11
+ from sklearn.neighbors import NearestNeighbors
12
+
13
+ def reorder_to_template(state, filename):
14
+ old_idx = state["newOrder"]
15
+ old_data = utils.read_train_data(filename) # original raw data
16
+ new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
17
+ new_filename = state["filepath"]+'mapped.csv'
18
 
19
  zero_arr = np.zeros((1, old_data.shape[1]))
20
  old_data = np.concatenate((old_data, zero_arr), axis=0)
 
28
  else:
29
  tmp_data = [old_data[j, :] for j in curr_idx_set]
30
  new_data[i, :] = np.mean(tmp_data, axis=0)
31
+
32
+ print('old.shape, new.shape: ', old_data.shape, new_data.shape)
33
  utils.save_data(new_data, new_filename)
34
  return
35
 
36
+ def reorder_to_origin(state, filename, new_filename):
37
+ old_idx = state["newOrder"]
38
+ old_data = utils.read_train_data(filename) # denoised data
39
+ template_order = state["templateByIndex"]
40
+
41
+ if state["runnigState"] == "stage1":
42
+ new_data = np.zeros((len(state["inputByName"]), old_data.shape[1]))
43
+ else:
44
+ new_data = utils.read_train_data(new_filename)
45
+
46
+ for i, channel in enumerate(template_order):
47
+ idx_set = old_idx[i]
48
+
49
+ # ignore if this channel doesn't exist
50
+ if len(idx_set)==1 and state["templateByName"][channel]["matched"]==True:
51
+ new_data[idx_set[0], :] = old_data[i, :]
52
+
53
+ print('old.shape, new.shape: ', old_data.shape, new_data.shape)
54
+ utils.save_data(new_data, new_filename)
55
+ return
56
 
57
  class Channel:
58
 
59
+ def __init__(self, index, name=None, matched=False, assigned=False, coord=None, css_position=None):
 
60
  self.name = name
61
  self.index = index
62
+ self.matched = matched
63
+ self.assigned = assigned # for input channels
64
  self.coord = coord
65
  self.css_position = css_position
 
 
 
 
 
 
66
 
 
 
67
 
68
+ def read_montage_data(loc_file):
69
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  template_montage = read_custom_montage("./template_chanlocs.loc")
71
  input_montage = read_custom_montage(loc_file)
72
+ template_dict = {}
73
+ input_dict = {}
74
 
75
  montages = [template_montage, input_montage]
76
  dicts = [template_dict, input_dict]
77
  num = [30, len(input_montage.ch_names)]
78
 
79
  for i in range(2):
 
 
 
 
 
 
 
 
80
  for j in range(num[i]):
81
  channel = montages[i].ch_names[j]
82
+ montages[i].rename_channels({channel: str.upper(channel)}) # convert all channel names to uppercase
83
 
 
 
 
 
 
84
  channel = str.upper(channel)
85
+ dicts[i][channel] = Channel(index=j, name=channel, coord=montages[i].get_positions()['ch_pos'][channel])
86
+
87
+ return template_montage, input_montage, template_dict, input_dict
88
+
89
+ def align_coords(state, template_montage, input_montage):
90
+
91
+ template_dict = state["templateByName"]
92
+ input_dict = state["inputByName"]
93
+ template_order = state["templateByIndex"]
94
+ input_order = state["inputByIndex"]
95
+ matched = [channel for channel in input_dict if input_dict[channel]["matched"]==True]
96
 
97
+ # 2-d (fot the indication of missing template channel's position when fill_mode:'mean_manual')
98
+ fig = [template_montage.plot(), input_montage.plot()]
99
+ fig[0].set_size_inches(5.6, 5.6)
100
+ fig[1].set_size_inches(5.6, 5.6)
101
 
102
+ ax = [fig[0].axes[0], fig[1].axes[0]]
103
+ ax[0].set_aspect('equal')
104
+ ax[1].set_aspect('equal')
105
+ ax[0].figure.canvas.draw() #update the figure
106
+ ax[1].figure.canvas.draw()
107
+
108
+ # get the original coords
109
+ all_tpl = ax[0].transData.transform(ax[0].collections[0].get_offsets().data) # display coords (px)
110
+ all_in= ax[1].transData.transform(ax[1].collections[0].get_offsets().data)
111
+ matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
112
+ matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
113
+
114
+ # transform the xy axis (template's -> input's)
115
+ rbf_x = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,0], function='thin_plate')
116
+ rbf_y = Rbf(matched_tpl[:,0], matched_tpl[:,1], matched_in[:,1], function='thin_plate')
117
+
118
+ # apply to all template channels
119
+ transformed_tpl_x = rbf_x(all_tpl[:,0], all_tpl[:,1])
120
+ transformed_tpl_y = rbf_y(all_tpl[:,0], all_tpl[:,1])
121
+ #transformed_tpl = np.vstack((transformed_tpl_x, transformed_tpl_y)).T
122
+
123
+ # update input, template's position
124
+ for i, channel in enumerate(template_order):
125
+ css_left = (transformed_tpl_x[i]-11)/560
126
+ css_bottom = (transformed_tpl_y[i]-7)/560
127
+ template_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
128
+ for i, channel in enumerate(input_order):
129
+ css_left = (all_in[i][0]-11)/560
130
+ css_bottom = (all_in[i][1]-7)/560
131
+ input_dict[channel]["css_position"] = [str(round(css_left*100, 2))+"%", str(round(css_bottom*100, 2))+"%"]
132
+
133
+
134
+ # 3-d (to use KNN)
135
+ # get the original coords
136
+ all_tpl = np.array([template_dict[channel]["coord"].tolist() for channel in template_order])
137
+ all_in = np.array([input_dict[channel]["coord"].tolist() for channel in input_order])
138
+ matched_tpl = np.array([all_tpl[template_dict[channel]["index"]] for channel in matched])
139
+ matched_in = np.array([all_in[input_dict[channel]["index"]] for channel in matched])
140
 
141
+ # transform the xyz axis (input's -> template's)
142
+ rbf_x = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,0], function='thin_plate')
143
+ rbf_y = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,1], function='thin_plate')
144
+ rbf_z = Rbf(matched_in[:,0], matched_in[:,1], matched_in[:,2], matched_tpl[:,2], function='thin_plate')
145
 
146
+ # apply to all input channels
147
+ transformed_in_x = rbf_x(all_in[:,0], all_in[:,1], all_in[:,2])
148
+ transformed_in_y = rbf_y(all_in[:,0], all_in[:,1], all_in[:,2])
149
+ transformed_in_z = rbf_z(all_in[:,0], all_in[:,1], all_in[:,2])
150
+ transformed_in = np.vstack((transformed_in_x, transformed_in_y, transformed_in_z)).T
151
+
152
+ # update input's position
153
+ for i, channel in enumerate(input_order):
154
+ input_dict[channel]["coord"] = transformed_in[i].tolist()
155
+
156
+ state.update({
157
+ "templateByName" : template_dict,
158
+ "inputByName" : input_dict,
159
+ })
160
+
161
+ return state
162
+
163
+ def fill_channels(state, fill_mode):
164
+
165
+ new_idx = state["newOrder"]
166
+ template_dict = state["templateByName"]
167
+ input_dict = state["inputByName"]
168
+ template_order = state["templateByIndex"]
169
+ input_order = state["inputByIndex"]
170
+ z_row_idx = state["dataShape"][0]
171
+ unmatched = [channel for channel in template_dict if template_dict[channel]["matched"]==False]
172
+ if unmatched == []:
173
+ return state
174
+
175
+ if fill_mode == 'zero':
176
+ for channel in unmatched:
177
+ idx = template_dict[channel]["index"]
178
+ new_idx[idx] = [z_row_idx]
179
+
180
+ elif fill_mode == 'mean_auto':
181
+ # use KNN to choose k nearest channels
182
+ in_coords = [input_dict[channel]["coord"] for channel in input_order]
183
+ in_coords = np.array([in_coords[i] for i in range(len(in_coords))])
184
+
185
+ k = 4 if len(input_dict)>4 else len(input_dict)
186
+ knn = NearestNeighbors(n_neighbors=k, metric='euclidean')
187
+ knn.fit(in_coords)
188
+
189
+ for channel in unmatched:
190
+ distances, indices = knn.kneighbors(template_dict[channel]["coord"].reshape(1,-1))
191
+ selected = [input_order[i] for i in indices[0]]
192
+ print(channel, ':', selected)
193
+
194
+ idx = template_dict[channel]["index"]
195
+ new_idx[idx] = indices[0].tolist()
196
+
197
+ state["newOrder"] = new_idx
198
+ return state
199
+
200
+ def mapping_stage1(state, data_file, loc_file, fill_mode):
201
+ second1 = time.time()
202
 
203
+ template_montage, input_montage, template_dict, input_dict = read_montage_data(loc_file)
204
+ template_order = template_montage.ch_names
205
+ new_idx = [[]]*30
206
+ missing_channels = []
207
  alias = {
208
  'T3': 'T7',
209
  'T4': 'T8',
210
  'T5': 'P7',
211
  'T6': 'P8',
212
+ #'TP7': 'T5\'',
213
+ #'TP8': 'T6\'',
214
  }
215
 
216
+ # match the names of input channels -> template channels
217
+ for i, channel in enumerate(template_order):
218
+ if channel in alias and alias[channel] in input_dict:
219
+ template_montage.rename_channels({channel: alias[channel]})
220
+ template_dict[alias[channel]] = template_dict.pop(channel)
221
+ template_dict[alias[channel]].name = alias[channel]
222
+ channel = alias[channel]
223
+
224
+ if channel in input_dict:
225
+ new_idx[i] = [input_dict[channel].index]
226
+
227
+ template_dict[channel].matched = True
228
+ input_dict[channel].matched = True
229
+ input_dict[channel].assigned = True
230
+ else:
231
+ missing_channels.append(i)
232
+
233
+ state.update({
234
+ "newOrder" : new_idx,
235
+ "missingChannelsIndex" : missing_channels,
236
+ "templateByName" : {k : v.__dict__ for k,v in template_dict.items()},
237
+ "inputByName" : {k : v.__dict__ for k,v in input_dict.items()},
238
+ "templateByIndex" : template_montage.ch_names,
239
+ "inputByIndex" : input_montage.ch_names
240
+ })
241
+
242
+ # align input, template's coordinates
243
+ state = align_coords(state, template_montage, input_montage)
244
+ # fill the unmatched channels
245
+ state = fill_channels(state, fill_mode)
246
+
247
+ second2 = time.time()
248
+ print('Mapping (stage1) finished in',second2 - second1,'s.')
249
+ return state
250
 
251
+ def mapping_stage2(state, fill_mode):
252
+ second1 = time.time()
 
 
253
 
254
+ template_dict = state["templateByName"]
255
+ input_dict = state["inputByName"]
256
+ template_order = state["templateByIndex"]
257
+ unassigned = [channel for channel in input_dict if input_dict[channel]["assigned"]==False]
258
+ if unassigned == []:
259
+ state["runnigState"] = "finished"
260
+ return state
 
 
 
 
 
261
 
262
+ tpl_coords = np.array([template_dict[channel]["coord"] for channel in template_order])
263
+ unassigned_coords = np.array([input_dict[channel]["coord"] for channel in unassigned])
264
 
265
+ # set all tpl.matched to False
266
+ for channel in template_dict:
267
+ template_dict[channel]["matched"] = False
268
 
269
+ # initialize the cost matrix
270
+ if len(unassigned) < 30:
271
+ cost_matrix = np.full((30, 30), 10000) # add dummy channels to ensure num_col > num_row
272
+ else:
273
+ cost_matrix = np.zeros((30, len(unassigned)))
274
+ for i in range(30):
275
+ for j in range(len(unassigned)):
276
+ cost_matrix[i][j] = np.linalg.norm(tpl_coords[i] - unassigned_coords[j]) # Euclidean distance
277
 
278
+ # use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
279
+ row_idx, col_idx = linear_sum_assignment(cost_matrix)
280
 
281
+ matches = []
282
+ new_idx = [[]]*30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
  for i in range(30):
284
+ if col_idx[i] < len(unassigned): # filter out dummy channels
285
+ matches.append([row_idx[i], col_idx[i]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
+ tpl_channel = template_order[row_idx[i]]
288
+ in_channel = unassigned[col_idx[i]]
289
+ template_dict[tpl_channel]["matched"] = True
290
+ input_dict[in_channel]["assigned"] = True
291
+ new_idx[i] = [input_dict[in_channel]["index"]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ state.update({
294
+ "newOrder" : new_idx,
295
+ "templateByName" : template_dict,
296
+ "inputByName" : input_dict
297
+ })
298
 
299
+ # fill the unmatched channels
300
+ state = fill_channels(state, fill_mode)
 
 
 
301
 
302
+ second2 = time.time()
303
+ print(f'Mapping (stage2-{state["batchCount"]-1}) finished in {second2 - second1}s.')
304
+ return state
305
+
utils.py CHANGED
@@ -42,6 +42,34 @@ def resample(signal, fs):
42
 
43
  return signal_new
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def FIR_filter(signal, lowcut, highcut):
46
  fs = 256.0
47
  # Number of FIR filter taps
@@ -96,11 +124,14 @@ def glue_data(file_name, total, output):
96
  raw_data[:, 1] = smooth
97
  gluedata = np.append(gluedata, raw_data, axis=1)
98
  #print(gluedata.shape)
 
99
  filename2 = output
100
  with open(filename2, 'w', newline='') as csvfile:
101
  writer = csv.writer(csvfile)
102
  writer.writerows(gluedata)
103
  #print("GLUE DONE!" + filename2)
 
 
104
 
105
 
106
  def save_data(data, filename):
@@ -189,10 +220,11 @@ def preprocessing(filepath, filename, samplerate):
189
  print(e)
190
 
191
  # read data
192
- signal = read_train_data(filepath+'/'+filename)
193
  #print(signal.shape)
194
  # resample
195
  signal = resample(signal, samplerate)
 
196
  #print(signal.shape)
197
  # FIR_filter
198
  signal = FIR_filter(signal, 1, 50)
@@ -204,7 +236,7 @@ def preprocessing(filepath, filename, samplerate):
204
 
205
 
206
  # model = tf.keras.models.load_model('./denoise_model/')
207
- def reconstruct(model_name, total, filepath, outputfile):
208
  # -------------------decode_data---------------------------
209
  second1 = time.time()
210
  for i in range(total):
@@ -224,9 +256,14 @@ def reconstruct(model_name, total, filepath, outputfile):
224
  save_data(d_data, outputname)
225
 
226
  # --------------------glue_data----------------------------
227
- glue_data(filepath+"/temp2/", total, filepath+'/'+outputfile)
 
228
  # -------------------delete_data---------------------------
229
  dataDelete(filepath+"/temp2/")
 
 
 
 
230
  second2 = time.time()
231
 
232
  print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")
 
42
 
43
  return signal_new
44
 
45
+ # original -> 256Hz or 256Hz -> original
46
+ def resample_(signal, current_fs, target_fs):
47
+ fs = current_fs
48
+ # downsample the signal to the target sample rate
49
+ if fs>target_fs:
50
+ fs_down = target_fs # Desired sample rate
51
+ q = int(fs / fs_down) # Downsampling factor
52
+ signal_new = []
53
+ for ch in signal:
54
+ x_down = decimate(ch, q)
55
+ signal_new.append(x_down)
56
+
57
+ # upsample the signal to the target sample rate
58
+ elif fs<target_fs:
59
+ fs_up = target_fs # Desired sample rate
60
+ p = int(fs_up / fs) # Upsampling factor
61
+ signal_new = []
62
+ for ch in signal:
63
+ x_up = resample_poly(ch, p, 1)
64
+ signal_new.append(x_up)
65
+
66
+ else:
67
+ signal_new = signal
68
+
69
+ signal_new = np.array(signal_new).astype(np.float64)
70
+
71
+ return signal_new
72
+
73
  def FIR_filter(signal, lowcut, highcut):
74
  fs = 256.0
75
  # Number of FIR filter taps
 
124
  raw_data[:, 1] = smooth
125
  gluedata = np.append(gluedata, raw_data, axis=1)
126
  #print(gluedata.shape)
127
+ '''
128
  filename2 = output
129
  with open(filename2, 'w', newline='') as csvfile:
130
  writer = csv.writer(csvfile)
131
  writer.writerows(gluedata)
132
  #print("GLUE DONE!" + filename2)
133
+ '''
134
+ return gluedata
135
 
136
 
137
  def save_data(data, filename):
 
220
  print(e)
221
 
222
  # read data
223
+ signal = read_train_data(filepath+filename)
224
  #print(signal.shape)
225
  # resample
226
  signal = resample(signal, samplerate)
227
+ #signal = resample_(signal, samplerate, 256)
228
  #print(signal.shape)
229
  # FIR_filter
230
  signal = FIR_filter(signal, 1, 50)
 
236
 
237
 
238
  # model = tf.keras.models.load_model('./denoise_model/')
239
+ def reconstruct(model_name, total, filepath, outputfile, samplerate):
240
  # -------------------decode_data---------------------------
241
  second1 = time.time()
242
  for i in range(total):
 
256
  save_data(d_data, outputname)
257
 
258
  # --------------------glue_data----------------------------
259
+ signal = glue_data(filepath+"/temp2/", total, filepath+outputfile)
260
+ #print(signal.shape)
261
  # -------------------delete_data---------------------------
262
  dataDelete(filepath+"/temp2/")
263
+ # --------------------resample-----------------------------
264
+ signal = resample_(signal, 256, samplerate) # 256Hz -> original sampling rate
265
+ #print(signal.shape)
266
+ save_data(signal, filepath+outputfile)
267
  second2 = time.time()
268
 
269
  print("Using", model_name,"model to reconstruct", outputfile, " has been success in", second2 - second1, "sec(s)")