Spaces:
Sleeping
Sleeping
Commit
·
c1cd6f3
1
Parent(s):
d92d76f
update
Browse files- app.py +10 -10
- app_utils.py +17 -17
app.py
CHANGED
@@ -331,7 +331,7 @@ with gr.Blocks() as demo:
|
|
331 |
"batchNum" : None,
|
332 |
"mappingResult" : [
|
333 |
{
|
334 |
-
"
|
335 |
"isOriginalData" : None
|
336 |
#"channelUsageNum" : None
|
337 |
}
|
@@ -488,7 +488,7 @@ with gr.Blocks() as demo:
|
|
488 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
489 |
sel_idx = channel_info["inputDict"][sel_radio]["index"]
|
490 |
|
491 |
-
stage1_info["mappingResult"][0]["
|
492 |
stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
|
493 |
channel_info["templateDict"][prev_tpl_name]["matched"] = True
|
494 |
channel_info["inputDict"][sel_radio]["assigned"] = True
|
@@ -565,10 +565,10 @@ with gr.Blocks() as demo:
|
|
565 |
value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
|
566 |
"""
|
567 |
# find the 4 nearest in_channels for each unmatched tpl_channels
|
568 |
-
stage1_info["mappingResult"][0]["
|
569 |
channel_info,
|
570 |
stage1_info["emptyTemplates"],
|
571 |
-
stage1_info["mappingResult"][0]["
|
572 |
# initialize the progress indication label
|
573 |
stage1_info["step3"] = {
|
574 |
"count" : 1,
|
@@ -578,7 +578,7 @@ with gr.Blocks() as demo:
|
|
578 |
label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
|
579 |
|
580 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
581 |
-
value = stage1_info["mappingResult"][0]["
|
582 |
value = [channel_info["inputNames"][i] for i in value]
|
583 |
|
584 |
stage1_info["state"] = "step3-2-selecting"
|
@@ -607,7 +607,7 @@ with gr.Blocks() as demo:
|
|
607 |
prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
|
608 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
609 |
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
|
610 |
-
stage1_info["mappingResult"][0]["
|
611 |
#print(prev_tpl_name, '<-', sel_chkbox)
|
612 |
# ----------------------------------------------------------------------------------
|
613 |
md = """
|
@@ -686,7 +686,7 @@ with gr.Blocks() as demo:
|
|
686 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
687 |
sel_idx = channel_info["inputDict"][sel_name]["index"]
|
688 |
|
689 |
-
stage1_info["mappingResult"][0]["
|
690 |
stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
|
691 |
channel_info["templateDict"][prev_tpl_name]["matched"] = True
|
692 |
channel_info["inputDict"][sel_name]["assigned"] = True
|
@@ -736,7 +736,7 @@ with gr.Blocks() as demo:
|
|
736 |
prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
|
737 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
738 |
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
|
739 |
-
stage1_info["mappingResult"][0]["
|
740 |
#print(prev_tpl_name, '<-', sel_name)
|
741 |
|
742 |
# ---------------------------------update the new round---------------------------------
|
@@ -746,7 +746,7 @@ with gr.Blocks() as demo:
|
|
746 |
label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
|
747 |
|
748 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
749 |
-
value = stage1_info["mappingResult"][0]["
|
750 |
value = [channel_info["inputNames"][i] for i in value]
|
751 |
|
752 |
stage1_info["step3"] = step3
|
@@ -849,7 +849,7 @@ with gr.Blocks() as demo:
|
|
849 |
outputname = outputname,
|
850 |
samplerate = int(samplerate),
|
851 |
batch_cnt = i,
|
852 |
-
|
853 |
orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
|
854 |
except FileNotFoundError:
|
855 |
print('stop!!')
|
|
|
331 |
"batchNum" : None,
|
332 |
"mappingResult" : [
|
333 |
{
|
334 |
+
"indices" : None,
|
335 |
"isOriginalData" : None
|
336 |
#"channelUsageNum" : None
|
337 |
}
|
|
|
488 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
489 |
sel_idx = channel_info["inputDict"][sel_radio]["index"]
|
490 |
|
491 |
+
stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = [sel_idx]
|
492 |
stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
|
493 |
channel_info["templateDict"][prev_tpl_name]["matched"] = True
|
494 |
channel_info["inputDict"][sel_radio]["assigned"] = True
|
|
|
565 |
value of the data from the selected channels. (By default, the 4 nearest channels are pre-selected.)
|
566 |
"""
|
567 |
# find the 4 nearest in_channels for each unmatched tpl_channels
|
568 |
+
stage1_info["mappingResult"][0]["indices"] = app_utils.find_neighbors(
|
569 |
channel_info,
|
570 |
stage1_info["emptyTemplates"],
|
571 |
+
stage1_info["mappingResult"][0]["indices"])
|
572 |
# initialize the progress indication label
|
573 |
stage1_info["step3"] = {
|
574 |
"count" : 1,
|
|
|
578 |
label = '{} (1/{})'.format(tpl_name, stage1_info["step3"]["totalNum"])
|
579 |
|
580 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
581 |
+
value = stage1_info["mappingResult"][0]["indices"][tpl_idx]
|
582 |
value = [channel_info["inputNames"][i] for i in value]
|
583 |
|
584 |
stage1_info["state"] = "step3-2-selecting"
|
|
|
607 |
prev_tpl_name = stage1_info["emptyTemplates"][stage1_info["step3"]["count"]-1]
|
608 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
609 |
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_chkbox]
|
610 |
+
stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
611 |
#print(prev_tpl_name, '<-', sel_chkbox)
|
612 |
# ----------------------------------------------------------------------------------
|
613 |
md = """
|
|
|
686 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
687 |
sel_idx = channel_info["inputDict"][sel_name]["index"]
|
688 |
|
689 |
+
stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = [sel_idx]
|
690 |
stage1_info["mappingResult"][0]["isOriginalData"][prev_tpl_idx] = True
|
691 |
channel_info["templateDict"][prev_tpl_name]["matched"] = True
|
692 |
channel_info["inputDict"][sel_name]["assigned"] = True
|
|
|
736 |
prev_tpl_name = stage1_info["emptyTemplates"][step3["count"]-1]
|
737 |
prev_tpl_idx = channel_info["templateDict"][prev_tpl_name]["index"]
|
738 |
sel_idx = [channel_info["inputDict"][name]["index"] for name in sel_name]
|
739 |
+
stage1_info["mappingResult"][0]["indices"][prev_tpl_idx] = sel_idx if sel_idx!=[] else [None]
|
740 |
#print(prev_tpl_name, '<-', sel_name)
|
741 |
|
742 |
# ---------------------------------update the new round---------------------------------
|
|
|
746 |
label = '{} ({}/{})'.format(tpl_name, step3["count"], step3["totalNum"])
|
747 |
|
748 |
tpl_idx = channel_info["templateDict"][tpl_name]["index"]
|
749 |
+
value = stage1_info["mappingResult"][0]["indices"][tpl_idx]
|
750 |
value = [channel_info["inputNames"][i] for i in value]
|
751 |
|
752 |
stage1_info["step3"] = step3
|
|
|
849 |
outputname = outputname,
|
850 |
samplerate = int(samplerate),
|
851 |
batch_cnt = i,
|
852 |
+
old_idx = stage1_info["mappingResult"][i]["indices"],
|
853 |
orig_flags = stage1_info["mappingResult"][i]["isOriginalData"])
|
854 |
except FileNotFoundError:
|
855 |
print('stop!!')
|
app_utils.py
CHANGED
@@ -174,7 +174,7 @@ def align_coords(channel_info, tpl_montage, in_montage):
|
|
174 |
})
|
175 |
return channel_info
|
176 |
|
177 |
-
def find_neighbors(channel_info, empty_tpl_names,
|
178 |
in_names = channel_info["inputNames"]
|
179 |
tpl_dict = channel_info["templateDict"]
|
180 |
in_dict = channel_info["inputDict"]
|
@@ -189,9 +189,9 @@ def find_neighbors(channel_info, empty_tpl_names, new_idx):
|
|
189 |
for i, name in enumerate(empty_tpl_names):
|
190 |
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
191 |
idx = tpl_dict[name]["index"]
|
192 |
-
|
193 |
|
194 |
-
return
|
195 |
|
196 |
def match_names(stage1_info):
|
197 |
# read the location file
|
@@ -199,7 +199,7 @@ def match_names(stage1_info):
|
|
199 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
200 |
tpl_names = tpl_montage.ch_names
|
201 |
in_names = in_montage.ch_names
|
202 |
-
|
203 |
orig_flags = [False]*30
|
204 |
|
205 |
alias_dict = {
|
@@ -215,7 +215,7 @@ def match_names(stage1_info):
|
|
215 |
name = alias_dict[name]
|
216 |
|
217 |
if name in in_dict:
|
218 |
-
|
219 |
orig_flags[i] = True
|
220 |
tpl_dict[name]["matched"] = True
|
221 |
in_dict[name]["assigned"] = True
|
@@ -228,7 +228,7 @@ def match_names(stage1_info):
|
|
228 |
"emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
|
229 |
"mappingResult" : [
|
230 |
{
|
231 |
-
"
|
232 |
"isOriginalData" : orig_flags
|
233 |
}
|
234 |
]
|
@@ -269,14 +269,14 @@ def optimal_mapping(channel_info):
|
|
269 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
270 |
|
271 |
# store the mapping result
|
272 |
-
|
273 |
orig_flags = [False]*30
|
274 |
for i, j in zip(row_idx, col_idx):
|
275 |
if j < len(unass_in_names): # filter out dummy channels
|
276 |
tpl_name = tpl_names[i]
|
277 |
in_name = unass_in_names[j]
|
278 |
|
279 |
-
|
280 |
orig_flags[i] = True
|
281 |
tpl_dict[tpl_name]["matched"] = True
|
282 |
in_dict[in_name]["assigned"] = True
|
@@ -285,10 +285,10 @@ def optimal_mapping(channel_info):
|
|
285 |
# fill the remaining empty tpl_channels
|
286 |
empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
|
287 |
if empty_tpl_names != []:
|
288 |
-
|
289 |
|
290 |
result = {
|
291 |
-
"
|
292 |
"isOriginalData" : orig_flags
|
293 |
}
|
294 |
channel_info.update({
|
@@ -327,14 +327,14 @@ def mapping_result(stage1_info, channel_info, filename):
|
|
327 |
return stage1_info, channel_info
|
328 |
|
329 |
|
330 |
-
def reorder_data(
|
331 |
# read the input data
|
332 |
raw_data = utils.read_train_data(inputname)
|
333 |
#print(raw_data.shape)
|
334 |
new_data = np.zeros((30, raw_data.shape[1]))
|
335 |
|
336 |
zero_arr = np.zeros((1, raw_data.shape[1]))
|
337 |
-
for i, (indices, flag) in enumerate(zip(
|
338 |
if flag == True:
|
339 |
new_data[i, :] = raw_data[indices[0], :]
|
340 |
elif indices == [None]:
|
@@ -346,7 +346,7 @@ def reorder_data(idx_order, orig_flags, inputname, filename):
|
|
346 |
utils.save_data(new_data, filename)
|
347 |
return raw_data.shape
|
348 |
|
349 |
-
def restore_order(batch_cnt, raw_data_shape,
|
350 |
# read the denoised data
|
351 |
d_data = utils.read_train_data(filename)
|
352 |
if batch_cnt == 0:
|
@@ -355,25 +355,25 @@ def restore_order(batch_cnt, raw_data_shape, idx_order, orig_flags, filename, ou
|
|
355 |
else:
|
356 |
new_data = utils.read_train_data(outputname)
|
357 |
|
358 |
-
for i, (indices, flag) in enumerate(zip(
|
359 |
if flag == True:
|
360 |
new_data[indices[0], :] = d_data[i, :]
|
361 |
|
362 |
utils.save_data(new_data, outputname)
|
363 |
return
|
364 |
|
365 |
-
def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt,
|
366 |
# establish temp folder
|
367 |
os.mkdir(filepath+'temp_data/')
|
368 |
|
369 |
# step1: Reorder input data
|
370 |
-
data_shape = reorder_data(
|
371 |
# step2: Data preprocessing
|
372 |
total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
|
373 |
# step3: Signal reconstruction
|
374 |
utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
|
375 |
# step4: Restore original order
|
376 |
-
restore_order(batch_cnt, data_shape,
|
377 |
|
378 |
utils.dataDelete(filepath+'temp_data/')
|
379 |
|
|
|
174 |
})
|
175 |
return channel_info
|
176 |
|
177 |
+
def find_neighbors(channel_info, empty_tpl_names, old_idx):
|
178 |
in_names = channel_info["inputNames"]
|
179 |
tpl_dict = channel_info["templateDict"]
|
180 |
in_dict = channel_info["inputDict"]
|
|
|
189 |
for i, name in enumerate(empty_tpl_names):
|
190 |
distances, indices = knn.kneighbors(empty_tpl[i].reshape(1,-1))
|
191 |
idx = tpl_dict[name]["index"]
|
192 |
+
old_idx[idx] = indices[0].tolist()
|
193 |
|
194 |
+
return old_idx
|
195 |
|
196 |
def match_names(stage1_info):
|
197 |
# read the location file
|
|
|
199 |
tpl_montage, in_montage, tpl_dict, in_dict = read_montage_data(loc_file)
|
200 |
tpl_names = tpl_montage.ch_names
|
201 |
in_names = in_montage.ch_names
|
202 |
+
old_idx = [[None]]*30 # store the indices of the in_channels in the order of tpl_channels
|
203 |
orig_flags = [False]*30
|
204 |
|
205 |
alias_dict = {
|
|
|
215 |
name = alias_dict[name]
|
216 |
|
217 |
if name in in_dict:
|
218 |
+
old_idx[i] = [in_dict[name]["index"]]
|
219 |
orig_flags[i] = True
|
220 |
tpl_dict[name]["matched"] = True
|
221 |
in_dict[name]["assigned"] = True
|
|
|
228 |
"emptyTemplates" : get_empty_templates(tpl_names, tpl_dict),
|
229 |
"mappingResult" : [
|
230 |
{
|
231 |
+
"indices" : old_idx,
|
232 |
"isOriginalData" : orig_flags
|
233 |
}
|
234 |
]
|
|
|
269 |
row_idx, col_idx = linear_sum_assignment(cost_matrix)
|
270 |
|
271 |
# store the mapping result
|
272 |
+
old_idx = [[None]]*30
|
273 |
orig_flags = [False]*30
|
274 |
for i, j in zip(row_idx, col_idx):
|
275 |
if j < len(unass_in_names): # filter out dummy channels
|
276 |
tpl_name = tpl_names[i]
|
277 |
in_name = unass_in_names[j]
|
278 |
|
279 |
+
old_idx[i] = [in_dict[in_name]["index"]]
|
280 |
orig_flags[i] = True
|
281 |
tpl_dict[tpl_name]["matched"] = True
|
282 |
in_dict[in_name]["assigned"] = True
|
|
|
285 |
# fill the remaining empty tpl_channels
|
286 |
empty_tpl_names = get_empty_templates(tpl_names, tpl_dict)
|
287 |
if empty_tpl_names != []:
|
288 |
+
old_idx = find_neighbors(channel_info, empty_tpl_names, old_idx)
|
289 |
|
290 |
result = {
|
291 |
+
"indices" : old_idx,
|
292 |
"isOriginalData" : orig_flags
|
293 |
}
|
294 |
channel_info.update({
|
|
|
327 |
return stage1_info, channel_info
|
328 |
|
329 |
|
330 |
+
def reorder_data(old_idx, orig_flags, inputname, filename):
|
331 |
# read the input data
|
332 |
raw_data = utils.read_train_data(inputname)
|
333 |
#print(raw_data.shape)
|
334 |
new_data = np.zeros((30, raw_data.shape[1]))
|
335 |
|
336 |
zero_arr = np.zeros((1, raw_data.shape[1]))
|
337 |
+
for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
|
338 |
if flag == True:
|
339 |
new_data[i, :] = raw_data[indices[0], :]
|
340 |
elif indices == [None]:
|
|
|
346 |
utils.save_data(new_data, filename)
|
347 |
return raw_data.shape
|
348 |
|
349 |
+
def restore_order(batch_cnt, raw_data_shape, old_idx, orig_flags, filename, outputname):
|
350 |
# read the denoised data
|
351 |
d_data = utils.read_train_data(filename)
|
352 |
if batch_cnt == 0:
|
|
|
355 |
else:
|
356 |
new_data = utils.read_train_data(outputname)
|
357 |
|
358 |
+
for i, (indices, flag) in enumerate(zip(old_idx, orig_flags)):
|
359 |
if flag == True:
|
360 |
new_data[indices[0], :] = d_data[i, :]
|
361 |
|
362 |
utils.save_data(new_data, outputname)
|
363 |
return
|
364 |
|
365 |
+
def run_model(modelname, filepath, inputname, m_filename, d_filename, outputname, samplerate, batch_cnt, old_idx, orig_flags):
|
366 |
# establish temp folder
|
367 |
os.mkdir(filepath+'temp_data/')
|
368 |
|
369 |
# step1: Reorder input data
|
370 |
+
data_shape = reorder_data(old_idx, orig_flags, inputname, filepath+'temp_data/'+m_filename)
|
371 |
# step2: Data preprocessing
|
372 |
total_file_num = utils.preprocessing(filepath+'temp_data/', m_filename, samplerate)
|
373 |
# step3: Signal reconstruction
|
374 |
utils.reconstruct(modelname, total_file_num, filepath+'temp_data/', d_filename, samplerate)
|
375 |
# step4: Restore original order
|
376 |
+
restore_order(batch_cnt, data_shape, old_idx, orig_flags, filepath+'temp_data/'+d_filename, outputname)
|
377 |
|
378 |
utils.dataDelete(filepath+'temp_data/')
|
379 |
|