Spaces:
Sleeping
Sleeping
Commit
·
995c1d0
1
Parent(s):
884c10b
update
Browse files- app.py +72 -32
- channel_mapping.py +21 -12
app.py
CHANGED
@@ -55,7 +55,7 @@ chkbox_js = """
|
|
55 |
position: relative;
|
56 |
width: 560px;
|
57 |
height: 560px;
|
58 |
-
background: url("file=${app_state.
|
59 |
`;
|
60 |
|
61 |
|
@@ -191,7 +191,6 @@ with gr.Blocks() as demo:
|
|
191 |
label="Imputation")
|
192 |
map_btn = gr.Button("Mapping")
|
193 |
|
194 |
-
#indic
|
195 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
|
196 |
next_btn = gr.Button("Next", interactive=False, visible=False)
|
197 |
|
@@ -228,7 +227,7 @@ with gr.Blocks() as demo:
|
|
228 |
scale=2)
|
229 |
run_btn = gr.Button(scale=1, interactive=False)
|
230 |
batch_md = gr.Markdown(visible=False)
|
231 |
-
out_denoised_data = gr.File(label="Denoised data")
|
232 |
|
233 |
|
234 |
with gr.Row():
|
@@ -245,7 +244,7 @@ with gr.Blocks() as demo:
|
|
245 |
|
246 |
#demo.load(js=js)
|
247 |
|
248 |
-
def
|
249 |
# establish temp folder
|
250 |
filepath = os.path.dirname(str(raw_data))
|
251 |
try:
|
@@ -259,7 +258,7 @@ with gr.Blocks() as demo:
|
|
259 |
data = utils.read_train_data(raw_data)
|
260 |
app_state = {
|
261 |
"filepath": filepath+"/temp_data/",
|
262 |
-
"
|
263 |
"sampleRate": int(samplerate),
|
264 |
|
265 |
}
|
@@ -275,7 +274,8 @@ with gr.Blocks() as demo:
|
|
275 |
tpl_montage : gr.Image(visible=False),
|
276 |
map_montage : gr.Image(value=None, visible=False),
|
277 |
res_md : gr.Markdown(visible=False),
|
278 |
-
batch_md : gr.Markdown(visible=False)
|
|
|
279 |
|
280 |
def mapping_result(app_state, channel_info, fill_mode):
|
281 |
|
@@ -283,7 +283,6 @@ with gr.Blocks() as demo:
|
|
283 |
matched_num = 30 - len(channel_info["missingChannelsIndex"])
|
284 |
batch_num = math.ceil((in_num-matched_num)/30) + 1
|
285 |
app_state.update({
|
286 |
-
"runnigState" : "stage1",
|
287 |
"batchCount" : 1,
|
288 |
"totalBatchNum" : batch_num
|
289 |
})
|
@@ -295,11 +294,11 @@ with gr.Blocks() as demo:
|
|
295 |
})
|
296 |
#print("Missing channels:", channel_info["missingChannelsIndex"])
|
297 |
return {app_state_json : app_state,
|
298 |
-
#chkbox_group : gr.CheckboxGroup(visible=True),
|
299 |
next_btn : gr.Button(visible=True)}
|
300 |
else:
|
301 |
-
app_state
|
302 |
-
|
|
|
303 |
return {app_state_json : app_state,
|
304 |
res_md : gr.Markdown(visible=True),
|
305 |
run_btn : gr.Button(interactive=True)}
|
@@ -318,7 +317,7 @@ with gr.Blocks() as demo:
|
|
318 |
|
319 |
if app_state["state"] == "initializing":
|
320 |
filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
|
321 |
-
app_state["
|
322 |
raw_fig = raw_montage.plot()
|
323 |
raw_fig.set_size_inches(5.6, 5.6)
|
324 |
raw_fig.savefig(filename, pad_inches=0)
|
@@ -327,7 +326,7 @@ with gr.Blocks() as demo:
|
|
327 |
|
328 |
elif app_state["state"] == "finished":
|
329 |
filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
330 |
-
app_state["
|
331 |
|
332 |
show_names= []
|
333 |
for channel in channel_info["inputByName"]:
|
@@ -361,9 +360,10 @@ with gr.Blocks() as demo:
|
|
361 |
|
362 |
|
363 |
map_btn.click(
|
364 |
-
fn =
|
365 |
inputs = [in_raw_data, in_sample_rate],
|
366 |
-
outputs = [app_state_json, channel_info_json, chkbox_group, next_btn, run_btn,
|
|
|
367 |
|
368 |
).success(
|
369 |
fn = mapping_stage1,
|
@@ -401,7 +401,7 @@ with gr.Blocks() as demo:
|
|
401 |
prev_target_name = channel_info["templateByIndex"][prev_target_idx]
|
402 |
|
403 |
selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
|
404 |
-
app_state["
|
405 |
|
406 |
#if len(selected)==1 and channel_info["inputByName"][selected[0]]["used"]==False:
|
407 |
#channel_info["inputByName"][selected[0]]["used"] = True
|
@@ -450,22 +450,47 @@ with gr.Blocks() as demo:
|
|
450 |
outputs = []
|
451 |
)
|
452 |
|
453 |
-
|
454 |
-
|
455 |
-
|
|
|
|
|
|
|
|
|
456 |
filepath = app_state["filepath"]
|
457 |
-
samplerate = app_state["sampleRate"]
|
458 |
-
|
459 |
input_name = os.path.basename(str(raw_data))
|
460 |
output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
|
461 |
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
466 |
if app_state["batchCount"] > 1:
|
467 |
-
app_state["runnigState"] = "stage2"
|
468 |
app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
|
|
|
|
|
469 |
app_state["batchCount"] += 1
|
470 |
|
471 |
reorder_to_template(app_state, raw_data)
|
@@ -473,15 +498,30 @@ with gr.Blocks() as demo:
|
|
473 |
total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
|
474 |
# step2: Signal reconstruction
|
475 |
utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
|
476 |
-
reorder_to_origin(app_state, channel_info, filepath+'denoised.csv',
|
|
|
|
|
|
|
|
|
|
|
477 |
|
478 |
-
|
479 |
-
|
480 |
-
elif model_name == "(denoised data)":
|
481 |
-
return {out_denoised_data : filepath + 'denoised.csv'}
|
482 |
|
483 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
484 |
|
|
|
|
|
|
|
|
|
|
|
485 |
|
486 |
if __name__ == "__main__":
|
487 |
-
demo.launch(
|
|
|
55 |
position: relative;
|
56 |
width: 560px;
|
57 |
height: 560px;
|
58 |
+
background: url("file=${app_state.filenames.raw_montage}");
|
59 |
`;
|
60 |
|
61 |
|
|
|
191 |
label="Imputation")
|
192 |
map_btn = gr.Button("Mapping")
|
193 |
|
|
|
194 |
chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
|
195 |
next_btn = gr.Button("Next", interactive=False, visible=False)
|
196 |
|
|
|
227 |
scale=2)
|
228 |
run_btn = gr.Button(scale=1, interactive=False)
|
229 |
batch_md = gr.Markdown(visible=False)
|
230 |
+
out_denoised_data = gr.File(label="Denoised data", visible=False)
|
231 |
|
232 |
|
233 |
with gr.Row():
|
|
|
244 |
|
245 |
#demo.load(js=js)
|
246 |
|
247 |
+
def reset1(raw_data, samplerate):
|
248 |
# establish temp folder
|
249 |
filepath = os.path.dirname(str(raw_data))
|
250 |
try:
|
|
|
258 |
data = utils.read_train_data(raw_data)
|
259 |
app_state = {
|
260 |
"filepath": filepath+"/temp_data/",
|
261 |
+
"filenames": {},
|
262 |
"sampleRate": int(samplerate),
|
263 |
|
264 |
}
|
|
|
274 |
tpl_montage : gr.Image(visible=False),
|
275 |
map_montage : gr.Image(value=None, visible=False),
|
276 |
res_md : gr.Markdown(visible=False),
|
277 |
+
batch_md : gr.Markdown(visible=False),
|
278 |
+
out_denoised_data : gr.File(visible=False)}
|
279 |
|
280 |
def mapping_result(app_state, channel_info, fill_mode):
|
281 |
|
|
|
283 |
matched_num = 30 - len(channel_info["missingChannelsIndex"])
|
284 |
batch_num = math.ceil((in_num-matched_num)/30) + 1
|
285 |
app_state.update({
|
|
|
286 |
"batchCount" : 1,
|
287 |
"totalBatchNum" : batch_num
|
288 |
})
|
|
|
294 |
})
|
295 |
#print("Missing channels:", channel_info["missingChannelsIndex"])
|
296 |
return {app_state_json : app_state,
|
|
|
297 |
next_btn : gr.Button(visible=True)}
|
298 |
else:
|
299 |
+
app_state.update({
|
300 |
+
"state" : "finished"
|
301 |
+
})
|
302 |
return {app_state_json : app_state,
|
303 |
res_md : gr.Markdown(visible=True),
|
304 |
run_btn : gr.Button(interactive=True)}
|
|
|
317 |
|
318 |
if app_state["state"] == "initializing":
|
319 |
filename = filepath+"raw_montage_"+str(random.randint(1,10000))+".png"
|
320 |
+
app_state["filenames"]["raw_montage"] = filename
|
321 |
raw_fig = raw_montage.plot()
|
322 |
raw_fig.set_size_inches(5.6, 5.6)
|
323 |
raw_fig.savefig(filename, pad_inches=0)
|
|
|
326 |
|
327 |
elif app_state["state"] == "finished":
|
328 |
filename = filepath+"mapped_montage_"+str(random.randint(1,10000))+".png"
|
329 |
+
app_state["filenames"]["map_montage"] = filename
|
330 |
|
331 |
show_names= []
|
332 |
for channel in channel_info["inputByName"]:
|
|
|
360 |
|
361 |
|
362 |
map_btn.click(
|
363 |
+
fn = reset1,
|
364 |
inputs = [in_raw_data, in_sample_rate],
|
365 |
+
outputs = [app_state_json, channel_info_json, chkbox_group, next_btn, run_btn,
|
366 |
+
tpl_montage, map_montage, res_md, batch_md, out_denoised_data]
|
367 |
|
368 |
).success(
|
369 |
fn = mapping_stage1,
|
|
|
401 |
prev_target_name = channel_info["templateByIndex"][prev_target_idx]
|
402 |
|
403 |
selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
|
404 |
+
app_state["stage1NewOrder"][prev_target_idx] = selected_idx
|
405 |
|
406 |
#if len(selected)==1 and channel_info["inputByName"][selected[0]]["used"]==False:
|
407 |
#channel_info["inputByName"][selected[0]]["used"] = True
|
|
|
450 |
outputs = []
|
451 |
)
|
452 |
|
453 |
+
def delete_file(filename):
|
454 |
+
try:
|
455 |
+
os.remove(filename)
|
456 |
+
except OSError as e:
|
457 |
+
print(e)
|
458 |
+
|
459 |
+
def reset2(app_state, raw_data, model_name):
|
460 |
filepath = app_state["filepath"]
|
|
|
|
|
461 |
input_name = os.path.basename(str(raw_data))
|
462 |
output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
|
463 |
|
464 |
+
app_state["filenames"]["denoised"] = filepath + output_name
|
465 |
+
app_state.update({
|
466 |
+
"runnigState" : "stage1",
|
467 |
+
"batchCount" : 1,
|
468 |
+
"stage2NewOrder" : [[]]*30
|
469 |
+
})
|
470 |
+
|
471 |
+
delete_file(filepath+'mapped.csv')
|
472 |
+
delete_file(filepath+'denoised.csv')
|
473 |
+
return {app_state_json : app_state,
|
474 |
+
run_btn : gr.Button(interactive=False),
|
475 |
+
batch_md : gr.Markdown(visible=False),
|
476 |
+
out_denoised_data : gr.File(visible=False)}
|
477 |
+
|
478 |
+
def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
|
479 |
+
filepath = app_state["filepath"]
|
480 |
+
samplerate = app_state["sampleRate"]
|
481 |
+
new_filename = app_state["filenames"]["denoised"]
|
482 |
+
|
483 |
+
while app_state["runnigState"] != "finished":
|
484 |
+
#if app_state["batchCount"] > app_state["totalBatchNum"]:
|
485 |
+
#app_state["runnigState"] = "finished"
|
486 |
+
#break
|
487 |
+
md = 'Running model('+str(app_state["batchCount"])+'/'+str(app_state["totalBatchNum"])+')...'
|
488 |
+
yield {batch_md : gr.Markdown(md, visible=True)}
|
489 |
+
|
490 |
if app_state["batchCount"] > 1:
|
|
|
491 |
app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
|
492 |
+
if app_state["runnigState"] == "finished":
|
493 |
+
break
|
494 |
app_state["batchCount"] += 1
|
495 |
|
496 |
reorder_to_template(app_state, raw_data)
|
|
|
498 |
total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
|
499 |
# step2: Signal reconstruction
|
500 |
utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
|
501 |
+
reorder_to_origin(app_state, channel_info, filepath+'denoised.csv', new_filename)
|
502 |
+
|
503 |
+
#if model_name == "(mapped data)":
|
504 |
+
#return {out_denoised_data : filepath + 'mapped.csv'}
|
505 |
+
#elif model_name == "(denoised data)":
|
506 |
+
#return {out_denoised_data : filepath + 'denoised.csv'}
|
507 |
|
508 |
+
delete_file(filepath+'mapped.csv')
|
509 |
+
delete_file(filepath+'denoised.csv')
|
|
|
|
|
510 |
|
511 |
+
yield {run_btn : gr.Button(interactive=True),
|
512 |
+
batch_md : gr.Markdown(visible=False),
|
513 |
+
out_denoised_data : gr.File(new_filename, visible=True)}
|
514 |
+
|
515 |
+
run_btn.click(
|
516 |
+
fn = reset2,
|
517 |
+
inputs = [app_state_json, in_raw_data, in_model_name],
|
518 |
+
outputs = [app_state_json, run_btn, batch_md, out_denoised_data]
|
519 |
|
520 |
+
).success(
|
521 |
+
fn = run_model,
|
522 |
+
inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
|
523 |
+
outputs = [run_btn, batch_md, out_denoised_data]
|
524 |
+
)
|
525 |
|
526 |
if __name__ == "__main__":
|
527 |
+
demo.launch()
|
channel_mapping.py
CHANGED
@@ -11,10 +11,12 @@ from scipy.optimize import linear_sum_assignment
|
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
|
13 |
def reorder_to_template(app_state, filename):
|
14 |
-
old_idx = app_state["
|
15 |
old_data = utils.read_train_data(filename) # original raw data
|
16 |
new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
|
17 |
new_filename = app_state["filepath"]+'mapped.csv'
|
|
|
|
|
18 |
|
19 |
zero_arr = np.zeros((1, old_data.shape[1]))
|
20 |
old_data = np.concatenate((old_data, zero_arr), axis=0)
|
@@ -34,7 +36,7 @@ def reorder_to_template(app_state, filename):
|
|
34 |
return
|
35 |
|
36 |
def reorder_to_origin(app_state, channel_info, filename, new_filename):
|
37 |
-
old_idx = app_state["
|
38 |
old_data = utils.read_train_data(filename) # denoised data
|
39 |
template_order = channel_info["templateByIndex"]
|
40 |
|
@@ -161,7 +163,7 @@ def align_coords(channel_info, template_montage, input_montage):
|
|
161 |
|
162 |
def fill_channels(app_state, channel_info, fill_mode):
|
163 |
|
164 |
-
new_idx = app_state["
|
165 |
template_dict = channel_info["templateByName"]
|
166 |
input_dict = channel_info["inputByName"]
|
167 |
template_order = channel_info["templateByIndex"]
|
@@ -186,16 +188,18 @@ def fill_channels(app_state, channel_info, fill_mode):
|
|
186 |
knn.fit(in_coords)
|
187 |
|
188 |
for channel in unmatched:
|
189 |
-
distances, indices = knn.kneighbors(template_dict[channel]["coord"].reshape(1,-1))
|
190 |
selected = [input_order[i] for i in indices[0]]
|
191 |
print(channel, ':', selected)
|
192 |
|
193 |
idx = template_dict[channel]["index"]
|
194 |
new_idx[idx] = indices[0].tolist()
|
195 |
-
|
196 |
-
app_state
|
197 |
-
|
198 |
-
|
|
|
|
|
199 |
return app_state
|
200 |
|
201 |
def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
|
@@ -239,7 +243,8 @@ def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
|
|
239 |
"inputByIndex" : input_montage.ch_names
|
240 |
})
|
241 |
app_state.update({
|
242 |
-
"
|
|
|
243 |
})
|
244 |
|
245 |
# align input, template's coordinates
|
@@ -271,12 +276,13 @@ def mapping_stage2(app_state, channel_info, fill_mode):
|
|
271 |
|
272 |
# initialize the cost matrix
|
273 |
if len(unassigned) < 30:
|
274 |
-
cost_matrix = np.full((30, 30),
|
275 |
else:
|
276 |
cost_matrix = np.zeros((30, len(unassigned)))
|
277 |
for i in range(30):
|
278 |
for j in range(len(unassigned)):
|
279 |
-
cost_matrix[i][j] = np.linalg.norm(tpl_coords[i]
|
|
|
280 |
|
281 |
# use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
|
282 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
@@ -292,13 +298,16 @@ def mapping_stage2(app_state, channel_info, fill_mode):
|
|
292 |
template_dict[tpl_channel]["matched"] = True
|
293 |
input_dict[in_channel]["assigned"] = True
|
294 |
new_idx[i] = [input_dict[in_channel]["index"]]
|
|
|
|
|
295 |
|
296 |
channel_info.update({
|
297 |
"templateByName" : template_dict,
|
298 |
"inputByName" : input_dict
|
299 |
})
|
300 |
app_state.update({
|
301 |
-
"
|
|
|
302 |
})
|
303 |
|
304 |
# fill the unmatched channels
|
|
|
11 |
from sklearn.neighbors import NearestNeighbors
|
12 |
|
13 |
def reorder_to_template(app_state, filename):
|
14 |
+
old_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
|
15 |
old_data = utils.read_train_data(filename) # original raw data
|
16 |
new_data = np.zeros((30, old_data.shape[1])) # reordered raw data
|
17 |
new_filename = app_state["filepath"]+'mapped.csv'
|
18 |
+
#print('new order 1:', app_state["stage1NewOrder"])
|
19 |
+
#print('new order 2:', app_state["stage2NewOrder"])
|
20 |
|
21 |
zero_arr = np.zeros((1, old_data.shape[1]))
|
22 |
old_data = np.concatenate((old_data, zero_arr), axis=0)
|
|
|
36 |
return
|
37 |
|
38 |
def reorder_to_origin(app_state, channel_info, filename, new_filename):
|
39 |
+
old_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
|
40 |
old_data = utils.read_train_data(filename) # denoised data
|
41 |
template_order = channel_info["templateByIndex"]
|
42 |
|
|
|
163 |
|
164 |
def fill_channels(app_state, channel_info, fill_mode):
|
165 |
|
166 |
+
new_idx = app_state["stage1NewOrder"] if app_state["runnigState"]=="stage1" else app_state["stage2NewOrder"]
|
167 |
template_dict = channel_info["templateByName"]
|
168 |
input_dict = channel_info["inputByName"]
|
169 |
template_order = channel_info["templateByIndex"]
|
|
|
188 |
knn.fit(in_coords)
|
189 |
|
190 |
for channel in unmatched:
|
191 |
+
distances, indices = knn.kneighbors(np.array(template_dict[channel]["coord"]).reshape(1,-1))
|
192 |
selected = [input_order[i] for i in indices[0]]
|
193 |
print(channel, ':', selected)
|
194 |
|
195 |
idx = template_dict[channel]["index"]
|
196 |
new_idx[idx] = indices[0].tolist()
|
197 |
+
|
198 |
+
if app_state["runnigState"] == "stage1":
|
199 |
+
app_state["stage1NewOrder"] = new_idx
|
200 |
+
else:
|
201 |
+
app_state["stage2NewOrder"] = new_idx
|
202 |
+
|
203 |
return app_state
|
204 |
|
205 |
def mapping_stage1(app_state, channel_info, data_file, loc_file, fill_mode):
|
|
|
243 |
"inputByIndex" : input_montage.ch_names
|
244 |
})
|
245 |
app_state.update({
|
246 |
+
"stage1NewOrder" : new_idx,
|
247 |
+
"runnigState" : "stage1"
|
248 |
})
|
249 |
|
250 |
# align input, template's coordinates
|
|
|
276 |
|
277 |
# initialize the cost matrix
|
278 |
if len(unassigned) < 30:
|
279 |
+
cost_matrix = np.full((30, 30), 1e6) # add dummy channels to ensure num_col > num_row
|
280 |
else:
|
281 |
cost_matrix = np.zeros((30, len(unassigned)))
|
282 |
for i in range(30):
|
283 |
for j in range(len(unassigned)):
|
284 |
+
cost_matrix[i][j] = np.linalg.norm((tpl_coords[i]-unassigned_coords[j])*1000) # Euclidean distance
|
285 |
+
#print(cost_matrix[i][j], tpl_coords[i] - unassigned_coords[j])
|
286 |
|
287 |
# use Hungarian Algorithm to find the minimum sum of distance of (input's coord to template's coord)...?
|
288 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
|
|
298 |
template_dict[tpl_channel]["matched"] = True
|
299 |
input_dict[in_channel]["assigned"] = True
|
300 |
new_idx[i] = [input_dict[in_channel]["index"]]
|
301 |
+
|
302 |
+
print(template_order[row_idx[i]], '<-', unassigned[col_idx[i]])
|
303 |
|
304 |
channel_info.update({
|
305 |
"templateByName" : template_dict,
|
306 |
"inputByName" : input_dict
|
307 |
})
|
308 |
app_state.update({
|
309 |
+
"stage2NewOrder" : new_idx,
|
310 |
+
"runnigState" : "stage2"
|
311 |
})
|
312 |
|
313 |
# fill the unmatched channels
|