Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
import os | |
import random | |
import math | |
import utils | |
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin, find_neighbors | |
import mne | |
from mne.channels import read_custom_montage | |
# Markdown for the "QuickStart" tab: how to prepare the inputs and what the
# mapping / imputation steps do.
# TODO(review): the "here" link below has an empty href; fill in the montage URL.
quickstart = """
# Quickstart
### Raw data
1. The data need to be a two-dimensional array (channel, timepoint).
2. Upload your EEG data in `.csv` format.
### Channel locations
Upload your data's channel locations in `.loc` (or `.locs`) format, which can be obtained using **EEGLAB**.
>If you cannot obtain it, we recommend downloading the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to modify them to the montage you need.
### Imputation
The models were trained using the EEG signals of 30 channels, including: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well.
If your data doesn't contain all of the mentioned channels, there are 3 imputation methods you can choose from:
- **zero**: fill the missing channels with zeros.
- **mean(auto)**: select the 4 nearest channels for each missing channel, and we will average their values.
- **mean(manual)**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, if you don't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.
### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage** (with the matched channels displaying their names) will be shown.
### Model
Select the model you want to use.
Detailed descriptions of the models can be found in the other tabs.
"""
# Markdown for the "IC-U-Net" tab: the abstract of the IC-U-Net paper.
icunet = """
# IC-U-Net
### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""
# JavaScript fragments shared by init_js / update_js. They are spliced into
# the body of an arrow function that has already declared `selector`,
# `attribute`, `channel`, `left` and `bottom`.

# Move every radio/checkbox label under `selector` onto its electrode
# position on the montage picture and blank its text. `attribute` names the
# <input> attribute that holds the channel name ("value" for radios, "name"
# for checkboxes).
_PLACE_INPUTS_JS = """
    let all_elem = document.querySelectorAll(selector+"> label");
    Array.from(all_elem).forEach((item) => {
        channel = item.querySelector("input").getAttribute(attribute);
        left = channel_info.inputByName[channel].css_position[0];
        bottom = channel_info.inputByName[channel].css_position[1];
        item.style.cssText = `
            position: absolute;
            left: ${left};
            bottom: ${bottom};
        `;
        item.className = "";
        item.querySelector(":scope> span").innerText = "";
    });
"""

# Draw a red dot at the template position of `channel` (the missing template
# channel being filled). Each option group gets its own <style> element, and
# its text is replaced on every call. The previous version scanned
# document.styleSheets[0] for a rule whose selectorText equalled
# `selector + "::after"`. Browsers serialize "#radio> div" as "#radio > div",
# so that check never matched and a new rule was appended on every step.
_INDICATOR_JS = """
    left = channel_info.templateByName[channel].css_position[0];
    bottom = channel_info.templateByName[channel].css_position[1];
    const rule = `
        ${selector}::after{
            content: '';
            position: absolute;
            background-color: red;
            width: 10px;
            height: 10px;
            border-radius: 50%;
            left: ${left};
            bottom: ${bottom};
        }
    `;
    const style_id = "montage-indicator-" + selector.split(">")[0].slice(1);
    let style_elem = document.getElementById(style_id);
    if(style_elem === null){
        style_elem = document.createElement("style");
        style_elem.id = style_id;
        document.head.appendChild(style_elem);
    }
    style_elem.textContent = rule;
"""

# Run after step 2 / step 3 starts. It turns the radio (step 2) or checkbox
# group (step 3) into a clickable montage: it sets the uploaded montage figure
# as the background, places each option on its electrode and marks the first
# missing template channel. It does nothing in any other state.
init_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));
    let selector, attribute;
    let channel, left, bottom;
    if(app_state.state == "step2-selecting"){
        selector = "#radio> div:nth-of-type(2)";
        attribute = "value";
    }else if(app_state.state == "step3-selecting"){
        selector = "#chkbox-group> div:nth-of-type(2)";
        attribute = "name";
    }else return;
    // use the figure of the input montage as the background
    document.querySelector(selector).style.cssText = `
        position: relative;
        width: 560px;
        height: 560px;
        background: url("file=${app_state.filenames.raw_montage}");
    `;
    // move the radios/checkboxes
""" + _PLACE_INPUTS_JS + """
    // mark the first missing template channel
    channel = app_state.missingTemplates[0];
""" + _INDICATOR_JS + """
}
"""

# Run after each "Next" round of step 2 / step 3. In step 2 the radio choices
# change every round, so they are placed again. In both steps the red dot
# moves to the missing template channel of the current round (fillingCount).
update_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));
    let selector, attribute;
    let channel, left, bottom;
    if(app_state.state == "step2-selecting"){
        selector = "#radio> div:nth-of-type(2)";
        attribute = "value";
        // the radio choices change every round: place them again
""" + _PLACE_INPUTS_JS + """
    }else if(app_state.state == "step3-selecting"){
        selector = "#chkbox-group> div:nth-of-type(2)";
    }else return;
    // mark the missing template channel of the current round
    channel = app_state.missingTemplates[app_state["fillingCount"]-1];
""" + _INDICATOR_JS + """
}
"""
with gr.Blocks() as demo:
    # Per-session state shared by every event handler. It is kept in hidden
    # JSON components, so each handler receives it as input and returns it updated.
    app_state_json = gr.JSON(visible=False)
    channel_info_json = gr.JSON(visible=False)
    with gr.Row():
        gr.Markdown(
            """
            <p style="text-align: center;">(...)</p>
            """
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown("# 1.Channel Mapping")
            # ------------------------input--------------------------
            with gr.Row():
                in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                # Every extension needs its leading dot. Gradio treats dot-less
                # entries as file categories (like "image"), so the former
                # "locs" entry never accepted .locs files.
                in_raw_loc = gr.File(label="Channel locations (.loc, .locs)",
                                     file_types=[".loc", ".locs"])
                with gr.Column():
                    in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
                    map_btn = gr.Button("Mapping")
            # ------------------------mapping------------------------
            # heading shared by steps 1-3; the handlers replace its text
            desc_md = gr.Markdown("### Mapping result:", visible=False)
            # step1: template montage next to the uploaded montage
            with gr.Row():
                tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
                map_montage = gr.Image(label="Input channels", visible=False)
            # step2: assign unmatched input channels to empty template channels
            # (elem_id "radio" is the hook used by the init_js/update_js selectors)
            radio = gr.Radio(elem_id="radio", visible=False)
            step2_btn = gr.Button("Next", visible=False)
            # step3: choose a way to fill the remaining empty template channels
            with gr.Row():
                in_fill_mode = gr.Dropdown(choices=["mean", "zero"],
                                           value="mean",
                                           label="Imputation",
                                           visible=False,
                                           scale=2)
                fillmode_btn = gr.Button("OK", visible=False, scale=1)
            # elem_id "chkbox-group" is the hook used by the init_js/update_js selectors
            chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
            step3_btn = gr.Button("Next", visible=False)
            next_btn = gr.Button("Next step", visible=False)
        # -------------------------------------------------------
        with gr.Column():
            gr.Markdown("# 2.Decode Data")
            # ------------------------input--------------------------
            with gr.Row():
                # (label, value) pairs: the label is shown, the value is passed to the handlers
                in_model_name = gr.Dropdown(choices=[
                                                ("ART", "EEGART"),
                                                ("IC-U-Net", "ICUNet"),
                                                ("IC-U-Net++", "UNetpp"),
                                                ("IC-U-Net-Attn", "AttUnet"),
                                                "(mapped data)",
                                                "(denoised data)"],
                                            value="EEGART",
                                            label="Model",
                                            scale=2)
                # enabled only once the channel mapping is finished
                run_btn = gr.Button(scale=1, interactive=False)
            # ------------------------output-------------------------
            batch_md = gr.Markdown(visible=False)
            out_denoised_data = gr.File(label="Denoised data", visible=False)
    # -------------------------------------------------------
    with gr.Row():
        with gr.Tab("ART"):
            gr.Markdown()
        with gr.Tab("IC-U-Net"):
            gr.Markdown(icunet)
        with gr.Tab("IC-U-Net++"):
            gr.Markdown()
        with gr.Tab("IC-U-Net-Attn"):
            gr.Markdown()
        with gr.Tab("QuickStart"):
            gr.Markdown(quickstart)
# click on mapping button
def reset_all(raw_data, raw_loc, samplerate):
    """Validate the "Mapping" inputs and reset the UI for a new mapping run.

    Args:
        raw_data: path of the uploaded raw-data ``.csv`` (``None`` if missing).
        raw_loc: path of the uploaded channel-location file (``None`` if missing).
        samplerate: text typed into the "Sampling rate (Hz)" box.

    Returns:
        A gradio update dict with a fresh ``app_state`` / ``channel_info``.
        Every step-1/2/3 widget and decoding output is hidden, and "Run" is
        disabled.

    Raises:
        gr.Error: when a file or the sampling rate is missing, or the
            sampling rate is not a positive integer. Raising, instead of
            showing a warning and returning ``None`` for all 15 outputs,
            displays the message and stops the ``.success()`` chain.
            ``mapping_stage1`` / ``mapping_result`` then never run on
            invalid input.
    """
    if raw_data is None or raw_loc is None:
        raise gr.Error('Please upload both the raw data and the channel location files.')
    samplerate = (samplerate or "").strip()
    if samplerate == "":
        raise gr.Error('Please enter the sampling rate.')
    try:
        sample_rate = int(samplerate)
    except ValueError:
        raise gr.Error('The sampling rate must be a positive integer (in Hz).') from None
    if sample_rate <= 0:
        raise gr.Error('The sampling rate must be a positive integer (in Hz).')

    # Working folder next to the uploaded data file. If one is left over from
    # a previous run, clear it with utils.dataDelete and create it again
    # (dataDelete presumably removes the folder itself, since it is re-created
    # right after).
    filepath = os.path.dirname(str(raw_data))
    temp_dir = filepath + "/temp_data/"
    try:
        os.mkdir(temp_dir)
    except OSError:
        utils.dataDelete(temp_dir)
        os.mkdir(temp_dir)

    app_state = {
        "filepath": temp_dir,
        "filenames": {},
        "sampleRate": sample_rate,
        "state": "step1",
    }
    channel_info = {}

    # reset layout
    return {app_state_json: app_state,
            channel_info_json: channel_info,
            # ------------------Stage1-----------------------
            desc_md: gr.Markdown(visible=False),
            tpl_montage: gr.Image(visible=False),
            map_montage: gr.Image(value=None, visible=False),
            radio: gr.Radio(choices=[], value=[], label="", visible=False),
            in_fill_mode: gr.Dropdown(visible=False),
            chkbox_group: gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
            fillmode_btn: gr.Button("OK", visible=False),
            step2_btn: gr.Button("Next", visible=False),
            step3_btn: gr.Button("Next", visible=False),
            next_btn: gr.Button("Next step", visible=False),
            # ------------------Stage2-----------------------
            run_btn: gr.Button(interactive=False),
            batch_md: gr.Markdown(visible=False),
            out_denoised_data: gr.File(visible=False)}
# step1
def mapping_result(app_state, channel_info, raw_loc):
    """Render the uploaded montage and decide which mapping step comes next.

    The montage from *raw_loc* is plotted and saved as a PNG in the session's
    temp folder. Its path is stored in ``app_state["filenames"]["raw_montage"]``,
    where init_js reads it as the background image.

    State decision (30 = number of template channels):
      * all 30 template channels matched -> "finished", "Run" enabled;
      * more input channels than matches -> "step2-initializing" (the leftover
        inputs can be assigned to empty template slots);
      * as many inputs as matches        -> "step3-initializing" (go straight
        to imputation).
    In the last two cases the "Next step" button is shown.
    """
    temp_dir = app_state["filepath"]
    # random suffix in the file name; presumably so the browser does not show a
    # cached figure from a previous run
    montage_png = temp_dir + "raw_montage_" + str(random.randint(1, 10000)) + ".png"
    app_state["filenames"]["raw_montage"] = montage_png

    montage = read_custom_montage(raw_loc)
    figure = montage.plot()
    figure.set_size_inches(5.6, 5.6)
    figure.savefig(montage_png, pad_inches=0)

    # ------------------determine the next step-----------------------
    n_inputs = len(channel_info["inputByIndex"])
    n_matched = 30 - len(app_state["missingTemplates"])
    all_present = n_matched == 30

    if all_present:
        app_state["state"] = "finished"
        gr.Info('The mapping process is finished!')
    elif n_inputs > n_matched:
        app_state["state"] = "step2-initializing"
    elif n_inputs == n_matched:
        app_state["state"] = "step3-initializing"

    response = {
        app_state_json: app_state,
        desc_md: gr.Markdown("### Mapping result", visible=True),
        tpl_montage: gr.Image(visible=True),
        map_montage: gr.Image(value=montage_png, visible=True),
    }
    if all_present:
        response[run_btn] = gr.Button(interactive=True)
    else:
        response[next_btn] = gr.Button("Next step", visible=True)
    return response


# "Mapping" button: reset the UI -> match channels by name -> show the result.
map_btn.click(
    fn=reset_all,
    inputs=[in_raw_data, in_raw_loc, in_samplerate],
    outputs=[app_state_json, channel_info_json, desc_md, tpl_montage, map_montage,
             radio, in_fill_mode, chkbox_group, fillmode_btn, step2_btn, step3_btn,
             next_btn, run_btn, batch_md, out_denoised_data],
).success(
    fn=mapping_stage1,
    inputs=[app_state_json, channel_info_json, in_raw_loc],
    outputs=[app_state_json, channel_info_json, desc_md],
).success(
    fn=mapping_result,
    inputs=[app_state_json, channel_info_json, in_raw_loc],
    outputs=[app_state_json, desc_md, tpl_montage, map_montage, next_btn, run_btn],
)
def _unmatched_templates(channel_info):
    """Return template channel names, in template order, not yet marked as matched."""
    return [name for name in channel_info["templateByIndex"]
            if not channel_info["templateByName"][name]["matched"]]


def _unassigned_inputs(channel_info):
    """Return input channel names, in input order, not yet assigned to a template."""
    return [name for name in channel_info["inputByIndex"]
            if not channel_info["inputByName"][name]["assigned"]]


def _start_filling_rounds(app_state, channel_info, state):
    """Enter *state* and restart the one-missing-template-per-round counter at round 1."""
    app_state["missingTemplates"] = _unmatched_templates(channel_info)
    app_state.update({
        "state": state,
        "fillingCount": 1,
        "totalFillingNum": len(app_state["missingTemplates"]),
    })


def init_next_step(app_state, channel_info, selected_radio, selected_chkbox):
    """Advance the channel-mapping wizard when "Next step"/"Submit" is clicked.

    Transitions, keyed on ``app_state["state"]``:
      * "step2-initializing": start step 2. One radio round per missing
        template channel lets the user assign a leftover input channel.
      * "step3-initializing": start step 3 by showing the imputation dropdown.
      * "step2-selecting": save the last radio choice, then either finish
        (every template channel filled) or move on to step 3.
      * "step3-selecting": save the last checkbox choice and finish.

    Args:
        selected_radio: chosen input channel name. It is empty ([] as set by
            the handlers, or None) when the user picked nothing.
        selected_chkbox: list of chosen input channel names (may be empty).
    """
    state = app_state["state"]

    # step1 -> step2
    if state == "step2-initializing":
        print('step1 -> step2')
        _start_filling_rounds(app_state, channel_info, "step2-selecting")
        name = app_state["missingTemplates"][0]
        label = name + ' (1/' + str(app_state["totalFillingNum"]) + ')'
        updates = {app_state_json: app_state,
                   channel_info_json: channel_info,
                   desc_md: gr.Markdown("### step2"),
                   tpl_montage: gr.Image(visible=False),
                   map_montage: gr.Image(visible=False),
                   radio: gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[],
                                   label=label, visible=True)}
        if len(app_state["stage1UnassignedInputs"]) == 1 or app_state["totalFillingNum"] == 1:
            # a single round is enough: keep "Next step" as the only button
            updates[next_btn] = gr.Button("Next step")
        else:
            updates[step2_btn] = gr.Button(visible=True)
            updates[next_btn] = gr.Button(visible=False)
        return updates

    # step1 -> step3
    if state == "step3-initializing":
        print('step1 -> step3')
        _start_filling_rounds(app_state, channel_info, "step3-initializing")
        return {app_state_json: app_state,
                channel_info_json: channel_info,
                desc_md: gr.Markdown("### step3"),
                tpl_montage: gr.Image(visible=False),
                map_montage: gr.Image(visible=False),
                in_fill_mode: gr.Dropdown(visible=True),
                fillmode_btn: gr.Button(visible=True),
                next_btn: gr.Button(visible=False)}

    # step2 -> step3/Stage2.decode data
    if state == "step2-selecting":
        # save the choice of the last round before moving on
        prev_target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
        prev_target_idx = channel_info["templateByName"][prev_target_name]["index"]
        if not selected_radio:
            app_state["stage1NewOrder"][prev_target_idx] = []
        else:
            selected_idx = channel_info["inputByName"][selected_radio]["index"]
            app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
            channel_info["templateByName"][prev_target_name]["matched"] = True
            channel_info["inputByName"][selected_radio]["assigned"] = True
            print(prev_target_name, '<-', selected_radio)
        app_state["stage1UnassignedInputs"] = _unassigned_inputs(channel_info)
        app_state["missingTemplates"] = _unmatched_templates(channel_info)

        # every empty template channel received an input channel -> Stage2
        if not app_state["missingTemplates"]:
            print('step2 -> Stage2')
            gr.Info('The mapping process is finished!')
            app_state["state"] = "finished"
            return {app_state_json: app_state,
                    channel_info_json: channel_info,
                    desc_md: gr.Markdown(visible=False),
                    radio: gr.Radio(visible=False),
                    next_btn: gr.Button(visible=False),
                    run_btn: gr.Button(interactive=True)}

        # some template channels are still empty -> step3 (imputation)
        print('step2 -> step3')
        _start_filling_rounds(app_state, channel_info, "step3-initializing")
        return {app_state_json: app_state,
                channel_info_json: channel_info,
                desc_md: gr.Markdown("### step3"),
                radio: gr.Radio(visible=False),
                in_fill_mode: gr.Dropdown(visible=True),
                fillmode_btn: gr.Button(visible=True),
                next_btn: gr.Button(visible=False)}

    # step3 -> Stage2.decode data
    if state == "step3-selecting":
        # save the choice of the last round before finishing
        prev_target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
        prev_target_idx = channel_info["templateByName"][prev_target_name]["index"]
        if not selected_chkbox:
            app_state["stage1NewOrder"][prev_target_idx] = []
        else:
            app_state["stage1NewOrder"][prev_target_idx] = [
                channel_info["inputByName"][channel]["index"] for channel in selected_chkbox]
        gr.Info('The mapping process is finished!')
        app_state["state"] = "finished"
        print('step3 -> Stage2')
        app_state["missingTemplates"] = _unmatched_templates(channel_info)
        return {app_state_json: app_state,
                desc_md: gr.Markdown(visible=False),
                chkbox_group: gr.CheckboxGroup(visible=False),
                next_btn: gr.Button(visible=False),
                run_btn: gr.Button(interactive=True)}

    # Any other state (the button should be hidden then): keep the UI as is
    # instead of returning None for all 12 outputs.
    return {app_state_json: app_state}


next_btn.click(
    fn=init_next_step,
    inputs=[app_state_json, channel_info_json, radio, chkbox_group],
    outputs=[app_state_json, channel_info_json, desc_md, tpl_montage, map_montage, radio, in_fill_mode,
             chkbox_group, fillmode_btn, step2_btn, next_btn, run_btn],
).success(
    fn=None,
    js=init_js,
    inputs=[app_state_json, channel_info_json],
    outputs=[],
)
# step2
def update_radio(app_state, channel_info, selected):
    """Save the step-2 choice for the current missing template and show the next one.

    Triggered by the step-2 "Next" button. A chosen input channel is mapped
    onto the current missing template channel and both are flagged as used.
    The radio then offers the remaining unassigned inputs for the next missing
    template. On the last useful round (one input left, or the last template),
    "Next" is swapped for "Next step".

    Args:
        selected: chosen input channel name. It is empty ([] as set by the
            handlers, or None) when the user skipped this template.
    """
    prev_target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
    prev_target_idx = channel_info["templateByName"][prev_target_name]["index"]
    if not selected:
        # skipped: the template stays unmatched, so step 3 will impute it
        app_state["stage1NewOrder"][prev_target_idx] = []
    else:
        selected_idx = channel_info["inputByName"][selected]["index"]
        app_state["stage1NewOrder"][prev_target_idx] = [selected_idx]
        channel_info["templateByName"][prev_target_name]["matched"] = True
        channel_info["inputByName"][selected]["assigned"] = True
        print(prev_target_name, '<-', selected)

    # update the current round
    app_state["fillingCount"] += 1
    app_state["stage1UnassignedInputs"] = [channel for channel in channel_info["inputByIndex"]
                                           if not channel_info["inputByName"][channel]["assigned"]]
    target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
    radio_label = f'{target_name} ({app_state["fillingCount"]}/{app_state["totalFillingNum"]})'
    updates = {app_state_json: app_state,
               channel_info_json: channel_info,
               radio: gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=radio_label)}
    if (len(app_state["stage1UnassignedInputs"]) == 1
            or app_state["fillingCount"] == app_state["totalFillingNum"]):
        updates[step2_btn] = gr.Button(visible=False)
        updates[next_btn] = gr.Button("Next step", visible=True)
    return updates


step2_btn.click(
    fn=update_radio,
    inputs=[app_state_json, channel_info_json, radio],
    outputs=[app_state_json, channel_info_json, radio, step2_btn, next_btn],
).success(
    fn=None,
    js=update_js,
    inputs=[app_state_json, channel_info_json],
    outputs=[],
)
# step3
def fill_value(app_state, channel_info, fill_mode):
    """Apply the imputation mode chosen in step 3.

    * "zero": nothing more to choose. Mapping is finished and "Run" is enabled.
    * "mean": ``find_neighbors`` is expected to fill ``stage1NewOrder`` with
      neighbouring input channels for each missing template. These are read
      back below as the checkbox pre-selection, which the user reviews one
      missing template per round (round 1 shown here).

    Raises:
        gr.Error: for any other mode, instead of returning ``None`` for all outputs.
    """
    if fill_mode == 'zero':
        app_state["state"] = "finished"
        gr.Info('The mapping process is finished!')
        return {app_state_json: app_state,
                desc_md: gr.Markdown(visible=False),
                in_fill_mode: gr.Dropdown(visible=False),
                fillmode_btn: gr.Button(visible=False),
                run_btn: gr.Button(interactive=True)}
    if fill_mode != 'mean':
        raise gr.Error('Please choose an imputation method.')

    app_state["state"] = "step3-selecting"
    app_state = find_neighbors(app_state, channel_info, fill_mode)
    name = app_state["missingTemplates"][0]
    idx = channel_info["templateByName"][name]["index"]
    value = [channel_info["inputByIndex"][i] for i in app_state["stage1NewOrder"][idx]]
    label = name + ' (1/' + str(app_state["totalFillingNum"]) + ')'
    updates = {app_state_json: app_state,
               in_fill_mode: gr.Dropdown(visible=False),
               fillmode_btn: gr.Button(visible=False),
               chkbox_group: gr.CheckboxGroup(choices=channel_info["inputByIndex"],
                                              value=value, label=label, visible=True)}
    if app_state["totalFillingNum"] == 1:
        # only one channel to review: go straight to the final button
        updates[next_btn] = gr.Button(visible=True)
    else:
        updates[step3_btn] = gr.Button(visible=True)
    return updates


fillmode_btn.click(
    fn=fill_value,
    inputs=[app_state_json, channel_info_json, in_fill_mode],
    # desc_md must be listed here: the "zero" branch hides it, and gradio
    # rejects returned components that are not declared outputs.
    outputs=[app_state_json, desc_md, in_fill_mode, fillmode_btn, chkbox_group,
             step3_btn, next_btn, run_btn],
).success(
    fn=None,
    js=init_js,
    inputs=[app_state_json, channel_info_json],
    outputs=[],
)
def update_chkbox(app_state, channel_info, selected):
    """Save the step-3 checkbox choice and move to the next missing template channel.

    Triggered by the step-3 "Next" button. The input channels ticked for the
    current missing template are stored (as input indices) in
    ``stage1NewOrder``. The checkbox group is then re-labelled for the next
    missing template and pre-filled with whatever ``stage1NewOrder`` already
    holds for it. On the last round, "Next" is swapped for "Submit".
    """
    round_no = app_state["fillingCount"]
    missing = app_state["missingTemplates"]
    new_order = app_state["stage1NewOrder"]

    # store the picks for the template of the round that just ended
    # (an empty selection naturally yields an empty list)
    done_idx = channel_info["templateByName"][missing[round_no - 1]]["index"]
    new_order[done_idx] = [channel_info["inputByName"][ch]["index"] for ch in selected]

    # advance to the next missing template
    round_no += 1
    app_state["fillingCount"] = round_no
    next_name = missing[round_no - 1]
    next_idx = channel_info["templateByName"][next_name]["index"]
    preselected = [channel_info["inputByIndex"][i] for i in new_order[next_idx]]
    label = f'{next_name} ({round_no}/{app_state["totalFillingNum"]})'

    reply = {app_state_json: app_state,
             chkbox_group: gr.CheckboxGroup(value=preselected, label=label)}
    if round_no == app_state["totalFillingNum"]:
        reply[step3_btn] = gr.Button(visible=False)
        reply[next_btn] = gr.Button("Submit", visible=True)
    return reply


step3_btn.click(
    fn=update_chkbox,
    inputs=[app_state_json, channel_info_json, chkbox_group],
    outputs=[app_state_json, chkbox_group, step3_btn, next_btn],
).success(
    fn=None,
    js=update_js,
    inputs=[app_state_json, channel_info_json],
    outputs=[],
)
def delete_file(filename):
    """Remove *filename* on a best-effort basis.

    Callers use this on intermediate files that may not exist yet. Any
    ``OSError`` (including a missing file) is printed and swallowed, never raised.
    """
    try:
        os.unlink(filename)
    except OSError as err:
        print(err)
def reset_run(app_state, channel_info, raw_data, model_name):
    """Prepare ``app_state`` for a new decoding run and reset the run widgets.

    * Deletes intermediate ``mapped.csv`` / ``denoised.csv`` left in the temp folder.
    * Sets the output path to ``<raw data name>_<model_name>.csv`` in the temp folder.
    * Computes the number of model batches: one, plus one per (started) group
      of 30 input channels not matched by name.
    * Resets the stage-2 bookkeeping and clears the ``assigned`` flag of every
      input channel left unassigned after stage 1.
    * Disables "Run" and hides the previous progress text / output file.
    """
    filepath = app_state["filepath"]
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')

    input_name = os.path.basename(str(raw_data))
    output_name = os.path.splitext(input_name)[0] + '_' + model_name + '.csv'

    in_num = len(channel_info["inputByIndex"])
    matched_num = sum(1 for channel in channel_info["inputByIndex"]
                      if channel_info["inputByName"][channel]["matched"])
    batch_num = math.ceil((in_num - matched_num) / 30) + 1

    app_state["filenames"]["denoised"] = filepath + output_name
    app_state.update({
        # "runnigState" (sic) is the key run_model reads; keep the spelling
        "runnigState": "stage1",
        "batchCount": 1,
        "totalBatchNum": batch_num,
        # Independent copies: `[[]] * 30` would alias one list 30 times, and
        # sharing the stage-1 list would let in-place stage-2 edits leak back.
        "stage2UnassignedInputs": list(app_state["stage1UnassignedInputs"]),
        "stage2NewOrder": [[] for _ in range(30)],
    })
    # reset in.assigned back to the state after Stage1
    for channel in app_state["stage1UnassignedInputs"]:
        channel_info["inputByName"][channel]["assigned"] = False

    return {app_state_json: app_state,
            channel_info_json: channel_info,
            run_btn: gr.Button(interactive=False),
            batch_md: gr.Markdown(visible=False),
            out_denoised_data: gr.File(visible=False)}
def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
    """Run the selected model batch by batch, streaming progress to the UI.

    Each loop iteration is one batch. From the second batch on,
    ``mapping_stage2`` is called first and may end the run by setting
    ``runnigState`` to "finished". A batch then:
      1. reorders the raw data to the template layout (``mapped.csv``),
      2. preprocesses it and reconstructs it with *model_name*
         (``denoised.csv``),
      3. writes the result back in the original channel order to the output
         file chosen by ``reset_run``.

    This is a generator: it yields a progress message per batch and finally
    the output file with "Run" re-enabled. If any step raises, "Run" is
    re-enabled and the progress text hidden before the error propagates, so
    the UI is not left stuck with a disabled button.
    """
    filepath = app_state["filepath"]
    samplerate = app_state["sampleRate"]
    new_filename = app_state["filenames"]["denoised"]
    try:
        while app_state["runnigState"] != "finished":
            md = 'Running model(' + str(app_state["batchCount"]) + '/' + str(app_state["totalBatchNum"]) + ')...'
            yield {batch_md: gr.Markdown(md, visible=True)}
            if app_state["batchCount"] > 1:
                app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
                if app_state["runnigState"] == "finished":
                    break
            reorder_to_template(app_state, raw_data)
            # step1: Data preprocessing
            total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
            # step2: Signal reconstruction
            utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
            reorder_to_origin(app_state, channel_info, filepath + 'denoised.csv', new_filename)
            delete_file(filepath + 'mapped.csv')
            delete_file(filepath + 'denoised.csv')
            app_state["batchCount"] += 1
    except Exception:
        # reset_run disabled "Run"; give it back before reporting the failure
        yield {run_btn: gr.Button(interactive=True),
               batch_md: gr.Markdown(visible=False)}
        raise
    yield {run_btn: gr.Button(interactive=True),
           batch_md: gr.Markdown(visible=False),
           out_denoised_data: gr.File(new_filename, visible=True)}


run_btn.click(
    fn=reset_run,
    inputs=[app_state_json, channel_info_json, in_raw_data, in_model_name],
    outputs=[app_state_json, channel_info_json, run_btn, batch_md, out_denoised_data],
).success(
    fn=run_model,
    inputs=[app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
    outputs=[run_btn, batch_md, out_denoised_data],
)

if __name__ == "__main__":
    demo.launch()