audrey06100 commited on
Commit
fe4e8b5
·
1 Parent(s): 0ab020b

update app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -71
app.py CHANGED
@@ -2,60 +2,12 @@ import gradio as gr
2
  import numpy as np
3
  import os
4
  import utils
5
- import channel_mapping
6
 
7
- def test(model_name, fill_mode, csv_file, chanlocs_file):
8
- tmp_path = os.path.dirname(str(csv_file))
9
-
10
- input_name = os.path.basename(str(csv_file))
11
- output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
12
-
13
- # Channel mapping
14
- #input_loc_name = os.path.basename(str(chanlocs_file))
15
- fill_mode = "adjacent" if fill_mode == "adjacent channel's value" else fill_mode
16
- channel_mapping.mapping(csv_file, chanlocs_file, fill_mode)
17
-
18
- # step1: Data preprocessing
19
- total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)
20
-
21
- # step2: Signal reconstruction
22
- utils.reconstruct(model_name, total_file_num, tmp_path, output_name)
23
-
24
- return tmp_path+'/'+output_name
25
 
26
-
27
- with gr.Blocks() as app:
28
- with gr.Row():
29
- gr.Markdown(
30
- """
31
- # Introduction
32
- (...)
33
- """
34
- )
35
- with gr.Row():
36
- with gr.Column():
37
- input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
38
- value="ICUNet",
39
- label="Model")
40
- input_fill_mode = gr.Dropdown(choices=["zero", "adjacent channel's value"],
41
- value="zero",
42
- label="Imputation")
43
- input_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
44
- input_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
45
- btn = gr.Button()
46
- with gr.Column():
47
- output_data = gr.File(label="Denoised data")
48
- # put template_loc file
49
- gr.Markdown(
50
- """
51
- (Missing channels)
52
- (New channel locations)
53
- """
54
- )
55
- with gr.Row():
56
- with gr.Tab("README"):
57
- gr.Markdown(
58
- """
59
  # Quickstart
60
 
61
  ### Raw data
@@ -86,28 +38,292 @@ Therefore, you need to
86
  <b><font color=#FF0000>remove these channels</font></b>
87
  after you download the denoised data.
88
 
89
- """
90
- # ### Denoised data: Once the reconstructing process finished, the denoised data will be downloadable here.
91
- # ### New channel locations: The template channel locations is downloadable here.
92
- )
93
- with gr.Tab("IC-U-Net"):
94
- gr.Markdown(
95
- """
96
  # IC-U-Net
97
  ### Abstract
98
- Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  """
101
- )
102
- with gr.Tab("IC-U-Net++"):
103
- gr.Markdown(
104
- """
105
- # Test
106
- """
107
- )
108
-
109
- inputs = [input_model_name, input_fill_mode, input_data, input_loc]
110
- btn.click(fn=test, inputs=inputs, outputs=output_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
  if __name__ == "__main__":
113
- app.launch()
 
2
  import numpy as np
3
  import os
4
  import utils
5
+ from channel_mapping import mapping, reorder_data
6
 
7
+ import mne
8
+ from mne.channels import read_custom_montage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ readme = """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Quickstart
12
 
13
  ### Raw data
 
38
  <b><font color=#FF0000>remove these channels</font></b>
39
  after you download the denoised data.
40
 
41
+ """
42
+ # ### Denoised data: Once the reconstructing process finished, the denoised data will be downloadable here.
43
+ # ### New channel locations: The template channel locations is downloadable here.
44
+
45
+ icunet = """
 
 
46
  # IC-U-Net
47
  ### Abstract
48
+ Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
49
+ """
50
 
51
+ chk_html = """
52
+ <form name="test" id="chs-form">
53
+ <input type="checkbox" />
54
+ </form>
55
+ """
56
+
57
+ chk_script = """
58
+
59
+ let channels = document.getElementById("chs-checkbox");
60
+
61
+ // init generate checkboxgroup
62
+ let obj = document.getElementById("map-result").value; // emmm......
63
+ let channels = obj.channels;
64
+ let num = channels.length;
65
+
66
+ for(i=0; i<num; i++){
67
+ document.getElementById("gen-checkbox").
68
+ innerHTML += '<input type="checkbox" class="channels" name="channel" value=channels[i].name />'
69
+ }
70
+
71
+ // check if mapping just finished
72
+ const result = document.getElementById("map-result")
73
+ result.addEventListener("change", function() {
74
+ const res_obj = this.value;
75
+
76
+ if(res_obj.fill_mode=="mean" && res_obj.missing_channels.length!=0 && !res_obj.start){
77
+ gen_chkbox(res_obj.missing_channels);
78
+ res_obj.start = True;
79
+ }
80
+ })
81
+
82
+ function gen_chkbox(channels){
83
+ let num = channels.length;
84
+ let chs_form = document.getElementById("chs-form");
85
+
86
+ chs_form.innerHTML = "";
87
+ for(i=0; i<num; i++){
88
+ chs_form.innerHTML += '<input type="checkbox" class="channels" name="channel" value=channels[i].name />'
89
+ }
90
+ }
91
+
92
+
93
+ """
94
+
95
+ with gr.Blocks() as demo:
96
+
97
+ state_json = gr.JSON(elem_id="state", visible=False)
98
+
99
+ with gr.Row():
100
+ gr.Markdown(
101
+ """
102
+ # Introduction
103
+ (...)
104
+ """
105
+ )
106
+ with gr.Row():
107
+ with gr.Column():
108
+ gr.Markdown(
109
+ """
110
+ # 1.
111
  """
112
+ )
113
+ with gr.Row():
114
+ input_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
115
+ input_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", "locs"])
116
+ with gr.Row():
117
+ input_fill_mode = gr.Dropdown(choices=["zero",
118
+ ("adjacent channel's value", "adjacent"),
119
+ ("mean(manual)", "mean")],
120
+ value="zero",
121
+ label="Imputation",
122
+ scale=2)
123
+ map_btn = gr.Button("Mapping", scale=1)
124
+ channels_json = gr.JSON(visible=False)
125
+ with gr.Row():
126
+ tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
127
+ map_montage = gr.Image(label="Mapping result", visible=False)
128
+ input_montage = gr.Image(label="Input montage", visible=False)
129
+ with gr.Accordion(visible=False) as accordion: #???
130
+ with gr.Column():
131
+ #chk_block = gr.HTML(chk_html)
132
+ #input_montage = gr.Image(label="Input montage")
133
+ chs_chkbox = gr.CheckboxGroup(elem_classes="chs-chkbox", label="")
134
+ next_btn = gr.Button("Submmit/Next", interactive=False)
135
+ tpl_loc_file = gr.File("./template_chanlocs.loc", label="Template channel locations file")
136
+ with gr.Column():
137
+ gr.Markdown(
138
+ """
139
+ # 2.
140
+ """
141
+ )
142
+ with gr.Row():
143
+ input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
144
+ value="ICUNet",
145
+ label="Model",
146
+ scale=2)
147
+ run_btn = gr.Button(scale=1, interactive=False)
148
+ ouput_denoised_data = gr.File(label="Denoised data")
149
+
150
+ #with gr.Row():
151
+ # mapResult_montage_img = gr.HTML()
152
+ # tpl_montage_img = gr.HTML()
153
+ gr.Markdown(
154
+ """
155
+ (Missing channels)
156
+ """
157
+ )
158
+
159
+ with gr.Row():
160
+ with gr.Tab("EEGART"):
161
+ gr.Markdown()
162
+ with gr.Tab("IC-U-Net"):
163
+ gr.Markdown(icunet)
164
+ with gr.Tab("IC-U-Net++"):
165
+ gr.Markdown()
166
+ with gr.Tab("IC-U-Net-Att"):
167
+ gr.Markdown()
168
+ with gr.Tab("README"):
169
+ gr.Markdown(readme)
170
+
171
+ #demo.load(js=chk_script)
172
+
173
+ def reset_layout(raw_data):
174
+ # establish temp folders
175
+ filepath = os.path.dirname(str(raw_data))
176
+
177
+ #filename = filepath+"/temp_data/img/input_montage.png"
178
+ #if os.path.exists(filename):
179
+ # os.remove(filename)
180
+ # print("Previous input image has been deleted!")
181
+
182
+ try:
183
+ os.mkdir(filepath+"/temp_data/")
184
+ except OSError as e:
185
+ utils.dataDelete(filepath+"/temp_data/")
186
+ os.mkdir(filepath+"/temp_data/")
187
+ print(e)
188
+ try:
189
+ os.mkdir(filepath+"/temp_data/img/")
190
+ except OSError as e:
191
+ utils.dataDelete(filepath+"/temp_data/img/")
192
+ os.mkdir(filepath+"/temp_data/img/")
193
+ print(e)
194
+
195
+ return {state_json : {},
196
+ accordion : gr.Accordion(visible=False),
197
+ chs_chkbox : gr.CheckboxGroup(choices=[], value=[], label=""), # choices, value ???
198
+ next_btn : gr.Button(interactive=False),
199
+ run_btn : gr.Button(interactive=False),
200
+ tpl_montage : gr.Image(visible=False),
201
+ map_montage : gr.Image(value=None, visible=False),
202
+ input_montage : gr.Image(value=None, visible=False),}
203
+
204
+ def mapping_result(channels_obj, raw_data, fill_mode):
205
+ state_obj = channels_obj.copy()
206
+ filepath = os.path.dirname(str(raw_data))
207
+ if fill_mode=="mean" and channels_obj["missingIndex"]!=[]:
208
+ state_obj.update({
209
+ "state" : "initializing",
210
+ "currentIndex" : 0,
211
+ "maxIndex" : len(channels_obj["missingIndex"])-1
212
+ })
213
+ #print("Missing channels:", state_obj["missingIndex"])
214
+ return {state_json : state_obj,
215
+ accordion : gr.Accordion(visible=True)}
216
+ else:
217
+ state_obj.update({
218
+ "state" : "finished",
219
+ "currentIndex" : -1,
220
+ "maxIndex" : -1
221
+ })
222
+ reorder_data(raw_data, channels_obj["newOrder"], fill_mode)
223
+ return {state_json : state_obj,
224
+ run_btn : gr.Button(interactive=True)}
225
+
226
+
227
+ def generate_chkbox(state_obj):
228
+ if state_obj["state"] == "finished":
229
+ return {state_json : {}}
230
+ if state_obj["state"] == "initializing":
231
+ in_channels = [channel for channel in state_obj["inputByName"]]
232
+ state_obj["state"] = "selecting"
233
+ # and img....
234
+
235
+ first_idx = state_obj["missingIndex"][0]
236
+ first_name = state_obj["templateByIndex"][first_idx]
237
+ chkbox_label = first_name+' (1/'+str(state_obj["maxIndex"]+1)+')'
238
+ return {state_json : state_obj,
239
+ chs_chkbox : gr.CheckboxGroup(choices=in_channels, label=chkbox_label),
240
+ next_btn : gr.Button(interactive=True)}
241
+
242
+ def show_montage(state_obj, raw_data, raw_loc):
243
+ filepath = os.path.dirname(str(raw_data))
244
+
245
+ raw_montage = read_custom_montage(raw_loc)
246
+ if state_obj["state"] == "initializing":
247
+ filename = filepath+"/temp_data/img/input_montage.png"
248
+ input_fig = raw_montage.plot(show=False)
249
+ input_fig.savefig(filename)
250
+ return {tpl_montage : gr.Image(visible=True),
251
+ input_montage : gr.Image(value=filename, visible=True)}
252
+
253
+ elif state_obj["state"] == "finished":
254
+ # didn't find any way to remove channels from DigMontage...
255
+ # tmp
256
+ filename = filepath+"/temp_data/img/input_montage.png"
257
+ input_fig = raw_montage.plot(show=False)
258
+ input_fig.savefig(filename)
259
+ return {input_montage : gr.Image(visible=False),
260
+ tpl_montage : gr.Image(visible=True),
261
+ map_montage : gr.Image(value=filename, visible=True)} # value=,
262
+
263
+ elif state_obj["state"] == "selecting":
264
+ # need to update input_montage here
265
+ return {input_montage : gr.Image()}
266
+
267
+
268
+ map_btn.click(reset_layout, input_raw_data,
269
+ [state_json, accordion, chs_chkbox, next_btn, run_btn, tpl_montage, map_montage, input_montage]).success(
270
+ mapping, [input_raw_data, input_raw_loc, input_fill_mode], channels_json).success(
271
+ mapping_result, [channels_json, input_raw_data, input_fill_mode], [state_json, accordion, run_btn]).success(
272
+ show_montage, [state_json, input_raw_data, input_raw_loc], [input_montage, tpl_montage, map_montage]).success(
273
+ generate_chkbox, state_json, [state_json, chs_chkbox, next_btn])
274
+
275
+
276
+ def check_next(state_obj, selected, raw_data, fill_mode):
277
+ if state_obj["state"] == "selecting":
278
+ if state_obj["currentIndex"] <= state_obj["maxIndex"]:
279
+
280
+ # info before clicking on next_btn
281
+ prev_target_idx = state_obj["missingIndex"][state_obj["currentIndex"]]
282
+ prev_target_name = state_obj["templateByIndex"][prev_target_idx]
283
+
284
+ selected_idx = [state_obj["inputByName"][channel]["index"] for channel in selected]
285
+ state_obj["newOrder"][prev_target_idx] = selected_idx
286
+
287
+ print('Selection for missing channel "{}"({}): {}'.format(prev_target_name, prev_target_idx, selected))
288
+
289
+ # current info
290
+ state_obj["currentIndex"] += 1
291
+ if state_obj["currentIndex"] <= state_obj["maxIndex"]:
292
+ target_idx = state_obj["missingIndex"][state_obj["currentIndex"]]
293
+ target_name = state_obj["templateByIndex"][target_idx]
294
+ chkbox_label = target_name+' ('+str(state_obj["currentIndex"]+1)+'/'+str(state_obj["maxIndex"]+1)+')'
295
+
296
+ #return {chk_block : chk_html, state_json : state_obj}
297
+ return {state_json : state_obj,
298
+ chs_chkbox : gr.CheckboxGroup(value=[], label=chkbox_label)}
299
+ else:
300
+ state_obj["state"] = "finished"
301
+ reorder_data(raw_data, state_obj["newOrder"], fill_mode)
302
+
303
+ return {state_json : state_obj,
304
+ accordion : gr.Accordion(visible=False),
305
+ run_btn : gr.Button(interactive=True)}
306
+
307
+ next_btn.click(check_next, [state_json, chs_chkbox, input_raw_data, input_fill_mode],
308
+ [state_json, accordion, chs_chkbox, run_btn]).success(
309
+ show_montage, [state_json, input_raw_data, input_raw_loc], [input_montage, tpl_montage, map_montage])
310
+
311
+
312
+ @run_btn.click(inputs=[input_model_name, input_raw_data], outputs=ouput_denoised_data)
313
+ def run_model(model_name, raw_file):
314
+ file_path = os.path.dirname(str(raw_file))
315
+
316
+ input_name = os.path.basename(str(raw_file))
317
+ output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
318
+
319
+ # step1: Data preprocessing
320
+ total_file_num = utils.preprocessing(file_path+'/temp_data', 'mapped.csv', 256)
321
+
322
+ # step2: Signal reconstruction
323
+ utils.reconstruct(model_name, total_file_num, file_path+'/temp_data/', output_name)
324
+
325
+ return file_path+'/'+output_name
326
+
327
 
328
  if __name__ == "__main__":
329
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)