# AIEEG / app.py
# Source-page header residue: author audrey06100, revision b62670c ("update"), 28.9 kB.
import gradio as gr
import numpy as np
import os
import random
import math
import utils
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin, find_neighbors
import mne
from mne.channels import read_custom_montage
# Markdown shown in the "QuickStart" tab.
# TODO: the "here" link below has an empty href; point it at the standard montage files.
quickstart = """
# Quickstart
### Raw data
1. The data must be a two-dimensional array (channel, timepoint).
2. Upload your EEG data in `.csv` format.
### Channel locations
Upload your data's channel locations in `.loc` or `.locs` format, which can be obtained using **EEGLAB**.
>If you cannot obtain them, we recommend downloading the standard montage <a href="">here</a>. If the channels in that file don't match yours, you can use **EEGLAB** to modify it to your montage.
### Imputation
The models were trained on EEG signals from 30 channels: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well.
If your data doesn't contain all of the channels listed above, there are 3 imputation methods you can choose from:
- **zero**: fill the missing channels with zeros.
- **mean(auto)**: select the 4 nearest channels for each missing channel, and we will average their values.
- **mean(manual)**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, if you don't have **FCz** and you choose **FC1, FC2, Fz, Cz** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCz**.

In the **Imputation** menu both mean variants are listed as **mean**: the automatically selected channels are pre-checked, and you can change the selection before continuing.
### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage** (with the names of the matched channels displayed) will be shown.
### Model
Select the model you want to use.
Detailed descriptions of the models can be found in the other tabs.
"""
# Markdown shown in the "IC-U-Net" tab.
icunet = """
# IC-U-Net
### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""
# Browser-side callback, passed as `js=` to the `.success()` steps that follow
# next_btn and fillmode_btn. It receives the app_state and channel_info JSON values.
# It does nothing unless the state is "step2-selecting" (radio, options keyed by
# their `value`) or "step3-selecting" (checkbox group, options keyed by `name`).
# In those states it:
#   * uses the saved input-montage PNG (app_state.filenames.raw_montage) as a
#     560x560 background of the option container;
#   * absolutely positions every option label at its channel's css_position and
#     blanks the label text, so each option sits on its electrode;
#   * adds (or replaces) a red "::after" dot marking the first missing
#     template channel (app_state.missingTemplates[0]).
# The selectors rely on the elem_ids "radio" and "chkbox-group" set in the layout.
# NOTE(review): reading document.styleSheets[0].cssRules throws if that sheet is
# cross-origin -- confirm the first stylesheet is always same-origin.
init_js = """
(app_state, channel_info) => {
app_state = JSON.parse(JSON.stringify(app_state));
channel_info = JSON.parse(JSON.stringify(channel_info));
let selector, attribute;
let channel, left, bottom;
if(app_state.state == "step2-selecting"){
selector = "#radio> div:nth-of-type(2)";
attribute = "value";
}else if(app_state.state == "step3-selecting"){
selector = "#chkbox-group> div:nth-of-type(2)";
attribute = "name";
}else return;
// add figure of in_montage
document.querySelector(selector).style.cssText = `
position: relative;
width: 560px;
height: 560px;
background: url("file=${app_state.filenames.raw_montage}");
`;
// move the radios/checkboxes
let all_elem = document.querySelectorAll(selector+"> label");
Array.from(all_elem).forEach((item) => {
channel = item.querySelector("input").getAttribute(attribute);
left = channel_info.inputByName[channel].css_position[0];
bottom = channel_info.inputByName[channel].css_position[1];
//console.log(`channel: ${channel}, left: ${left}, bottom: ${bottom}`);
item.style.cssText = `
position: absolute;
left: ${left};
bottom: ${bottom};
`;
item.className = "";
item.querySelector(":scope> span").innerText = "";
});
// add indication for the missing channels
channel = app_state.missingTemplates[0]
left = channel_info.templateByName[channel].css_position[0];
bottom = channel_info.templateByName[channel].css_position[1];
let rule = `
${selector}::after{
content: '';
position: absolute;
background-color: red;
width: 10px;
height: 10px;
border-radius: 50%;
left: ${left};
bottom: ${bottom};
}
`;
// check if indicator already exist
let exist = 0;
const styleSheet = document.styleSheets[0];
for(let i=0; i<styleSheet.cssRules.length; i++){
if(styleSheet.cssRules[i].selectorText == selector+"::after"){
exist = 1;
//console.log('exist!');
styleSheet.deleteRule(i);
styleSheet.insertRule(rule, styleSheet.cssRules.length);
break;
}
}
if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);
}
"""
# Browser-side callback, passed as `js=` to the `.success()` steps that follow
# step2_btn and step3_btn (one round of a selecting step).
# In "step2-selecting" it re-positions every radio option at its channel's
# css_position, because the radio choices are rebuilt each round. In
# "step3-selecting" only the indicator is moved. In both states it then moves
# the red "::after" dot to the template channel of the current round:
# app_state.missingTemplates[fillingCount - 1].
# Any other state: returns immediately.
# NOTE(review): same document.styleSheets[0] same-origin assumption as init_js.
update_js = """
(app_state, channel_info) => {
app_state = JSON.parse(JSON.stringify(app_state));
channel_info = JSON.parse(JSON.stringify(channel_info));
let selector;
let channel, left, bottom;
if(app_state.state == "step2-selecting"){
selector = "#radio> div:nth-of-type(2)";
// update the radios
let all_elem = document.querySelectorAll(selector+"> label");
Array.from(all_elem).forEach((item) => {
channel = item.querySelector("input").value;
left = channel_info.inputByName[channel].css_position[0];
bottom = channel_info.inputByName[channel].css_position[1];
//console.log(`channel: ${channel}, left: ${left}, bottom: ${bottom}`);
item.style.cssText = `
position: absolute;
left: ${left};
bottom: ${bottom};
`;
item.className = "";
item.querySelector(":scope> span").innerText = "";
});
}else if(app_state.state == "step3-selecting"){
selector = "#chkbox-group> div:nth-of-type(2)";
}else return;
// update indication
channel = app_state.missingTemplates[app_state["fillingCount"]-1]
left = channel_info.templateByName[channel].css_position[0];
bottom = channel_info.templateByName[channel].css_position[1];
let rule = `
${selector}::after{
content: '';
position: absolute;
background-color: red;
width: 10px;
height: 10px;
border-radius: 50%;
left: ${left};
bottom: ${bottom};
}
`;
// check if indicator already exist
let exist = 0;
const styleSheet = document.styleSheets[0];
for(let i=0; i<styleSheet.cssRules.length; i++){
if(styleSheet.cssRules[i].selectorText == selector+"::after"){
exist = 1;
//console.log('exist!');
styleSheet.deleteRule(i);
styleSheet.insertRule(rule, styleSheet.cssRules.length);
break;
}
}
if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);
}
"""
# Page layout. Every component is created inside the Blocks context; the event
# handlers reference them by these variable names.
with gr.Blocks() as demo:
    # Hidden JSON components that carry the wizard state between events.
    app_state_json = gr.JSON(visible=False)
    channel_info_json = gr.JSON(visible=False)

    with gr.Row():
        gr.Markdown(
            """
<p style="text-align: center;">(...)</p>
"""
        )

    with gr.Row():
        # ======================== Stage 1: channel mapping ========================
        with gr.Column():
            gr.Markdown("# 1.Channel Mapping")
            # ------------------------input--------------------------
            with gr.Row():
                in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                # File extensions need a leading dot. The former bare "locs" was
                # not an extension, so .locs files were likely not accepted.
                in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", ".locs"])
                with gr.Column():
                    in_samplerate = gr.Textbox(label="Sampling rate (Hz)")
                    map_btn = gr.Button("Mapping")
            # ------------------------mapping------------------------
            # Heading shared by steps 1-3; the handlers replace its text.
            desc_md = gr.Markdown("### Mapping result:", visible=False)
            # step1: mapping result
            with gr.Row():
                tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
                map_montage = gr.Image(label="Input channels", visible=False)
            # step2: assign unmatched input channels to empty template channels.
            # elem_id "radio" is the DOM hook used by init_js / update_js.
            radio = gr.Radio(elem_id="radio", visible=False)
            step2_btn = gr.Button("Next", visible=False)
            # step3: select a way to fill the empty template channels
            with gr.Row():
                in_fill_mode = gr.Dropdown(
                    choices=["mean", "zero"],
                    value="mean",
                    label="Imputation",
                    visible=False,
                    scale=2,
                )
                fillmode_btn = gr.Button("OK", visible=False, scale=1)
            # elem_id "chkbox-group" is the DOM hook used by init_js / update_js.
            chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
            step3_btn = gr.Button("Next", visible=False)
            next_btn = gr.Button("Next step", visible=False)

        # ======================== Stage 2: decode data ========================
        with gr.Column():
            gr.Markdown("# 2.Decode Data")
            # ------------------------input--------------------------
            with gr.Row():
                in_model_name = gr.Dropdown(
                    choices=[
                        ("ART", "EEGART"),
                        ("IC-U-Net", "ICUNet"),
                        ("IC-U-Net++", "UNetpp"),
                        ("IC-U-Net-Attn", "AttUnet"),
                        "(mapped data)",
                        "(denoised data)",
                    ],
                    value="EEGART",
                    label="Model",
                    scale=2,
                )
                # Stays disabled until a mapping handler reports "finished".
                run_btn = gr.Button(scale=1, interactive=False)
            # ------------------------output-------------------------
            batch_md = gr.Markdown(visible=False)
            out_denoised_data = gr.File(label="Denoised data", visible=False)

    # Model descriptions and the usage guide.
    with gr.Row():
        with gr.Tab("ART"):
            gr.Markdown()
        with gr.Tab("IC-U-Net"):
            gr.Markdown(icunet)
        with gr.Tab("IC-U-Net++"):
            gr.Markdown()
        with gr.Tab("IC-U-Net-Attn"):
            gr.Markdown()
        with gr.Tab("QuickStart"):
            gr.Markdown(quickstart)
# click on mapping button
def reset_all(raw_data, raw_loc, samplerate):
    """Validate the Stage-1 inputs and reset the whole UI for a new mapping run.

    Creates a fresh ``temp_data/`` folder next to the uploaded raw-data file,
    builds a new ``app_state`` / ``channel_info`` pair and hides every widget of
    the mapping steps and of Stage 2.

    Args:
        raw_data: uploaded raw-data file (value of ``in_raw_data``).
        raw_loc: uploaded channel-location file (value of ``in_raw_loc``).
        samplerate: sampling rate typed by the user, as text.

    Returns:
        dict mapping output components to their new values / updates.

    Raises:
        gr.Error: when a file is missing or the sampling rate is not a positive
            integer. Raising (rather than warning and returning nothing) shows
            the message and prevents the chained ``.success()`` steps from running.
    """
    # verify that all required inputs have been provided
    if raw_data is None or raw_loc is None:
        raise gr.Error('Please upload both the raw data and the channel location files.')
    samplerate = "" if samplerate is None else str(samplerate).strip()
    if not samplerate:
        raise gr.Error('Please enter the sampling rate.')
    try:
        sample_rate = int(samplerate)
    except ValueError:
        sample_rate = 0  # reported just below
    if sample_rate <= 0:
        raise gr.Error('The sampling rate must be a positive integer (Hz).')

    # establish a fresh temp folder next to the uploaded data
    temp_dir = os.path.dirname(str(raw_data)) + "/temp_data/"
    try:
        os.mkdir(temp_dir)
    except FileExistsError:
        # Left over from a previous run. utils.dataDelete is expected to remove
        # it so that it can be recreated empty.
        utils.dataDelete(temp_dir)
        os.mkdir(temp_dir)

    # initialize app_state, channel_info
    app_state = {
        "filepath": temp_dir,
        "filenames": {},
        "sampleRate": sample_rate,
        "state": "step1",
    }
    channel_info = {}

    # reset layout
    return {
        app_state_json: app_state,
        channel_info_json: channel_info,
        # ------------------Stage1-----------------------
        desc_md: gr.Markdown(visible=False),
        tpl_montage: gr.Image(visible=False),
        map_montage: gr.Image(value=None, visible=False),
        radio: gr.Radio(choices=[], value=[], label="", visible=False),
        in_fill_mode: gr.Dropdown(visible=False),
        chkbox_group: gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
        fillmode_btn: gr.Button("OK", visible=False),
        step2_btn: gr.Button("Next", visible=False),
        step3_btn: gr.Button("Next", visible=False),
        next_btn: gr.Button("Next step", visible=False),
        # ------------------Stage2-----------------------
        run_btn: gr.Button(interactive=False),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(visible=False),
    }
# step1
def mapping_result(app_state, channel_info, raw_loc):
    """Show the step-1 mapping result and choose the next wizard state.

    Plots the uploaded montage (``raw_loc``), saves it as a PNG in the session
    temp folder and records that path in ``app_state["filenames"]["raw_montage"]``.
    It then compares how many of the 30 template channels were matched:

    * all 30 matched          -> state "finished", Run button enabled;
    * more inputs than matches -> state "step2-initializing";
    * inputs == matches        -> state "step3-initializing".

    In the last two cases the "Next step" button is shown.

    Returns:
        dict of component updates.
    """
    png_path = f'{app_state["filepath"]}raw_montage_{random.randint(1, 10000)}.png'
    app_state["filenames"]["raw_montage"] = png_path

    figure = read_custom_montage(raw_loc).plot()
    figure.set_size_inches(5.6, 5.6)
    figure.savefig(png_path, pad_inches=0)

    # ------------------determine the next step-----------------------
    n_inputs = len(channel_info["inputByIndex"])
    n_matched = 30 - len(app_state["missingTemplates"])

    if n_matched == 30:
        # every template channel is covered -> Stage2 (decode data)
        app_state["state"] = "finished"
        gr.Info('The mapping process is finished!')
        extra = {run_btn: gr.Button(interactive=True)}
    else:
        if n_inputs > n_matched:
            # spare input channels can be assigned to the empty template channels
            app_state["state"] = "step2-initializing"
        elif n_inputs == n_matched:
            # nothing left to assign -> go straight to imputation
            app_state["state"] = "step3-initializing"
        extra = {next_btn: gr.Button("Next step", visible=True)}

    return {
        app_state_json: app_state,
        desc_md: gr.Markdown("### Mapping result", visible=True),
        tpl_montage: gr.Image(visible=True),
        map_montage: gr.Image(value=png_path, visible=True),
        **extra,
    }
# "Mapping" button. Three chained steps, each gated by `.success()` so a step
# runs only if the previous one finished without an error:
#   1. reset_all       - validate inputs, create the temp folder, hide old widgets;
#   2. mapping_stage1  - stage-1 channel matching (from channel_mapping);
#   3. mapping_result  - plot the montage and pick the next wizard state.
map_btn.click(
    fn = reset_all,
    inputs = [in_raw_data, in_raw_loc, in_samplerate],
    outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, map_montage, radio, in_fill_mode,
        chkbox_group, fillmode_btn, step2_btn, step3_btn, next_btn, run_btn, batch_md, out_denoised_data]
).success(
    fn = mapping_stage1,
    inputs = [app_state_json, channel_info_json, in_raw_loc],
    outputs = [app_state_json, channel_info_json, desc_md]
).success(
    fn = mapping_result,
    inputs = [app_state_json, channel_info_json, in_raw_loc],
    outputs = [app_state_json, desc_md, tpl_montage, map_montage, next_btn, run_btn]
)
def init_next_step(app_state, channel_info, selected_radio, selected_chkbox):
    """Handle ``next_btn``: close the current mapping step and open the next one.

    Behaviour depends on ``app_state["state"]``:

    * ``"step2-initializing"``: show the radio used to assign leftover input
      channels to missing template channels, one template channel per round.
    * ``"step3-initializing"``: show the imputation-mode dropdown.
    * ``"step2-selecting"``: store the last radio choice, then finish if no
      template channel is left unmatched; otherwise go on to step 3.
    * ``"step3-selecting"``: store the last checkbox choice and finish.

    Args:
        app_state: session state dict (value of ``app_state_json``).
        channel_info: channel metadata dict (value of ``channel_info_json``).
        selected_radio: current ``radio`` value. This is an input-channel name,
            or an empty value (``[]`` or ``None``) when nothing was chosen.
        selected_chkbox: current ``chkbox_group`` value (list of input-channel names).

    Returns:
        A dict of component updates, or None for any other state.
    """
    templates = channel_info["templateByName"]
    inputs = channel_info["inputByName"]

    def unmatched_templates():
        # template channels, in template order, that still have no input channel
        return [name for name in channel_info["templateByIndex"] if not templates[name]["matched"]]

    state = app_state["state"]

    # step1 -> step2
    if state == "step2-initializing":
        print('step1 -> step2')
        missing = unmatched_templates()
        app_state.update({
            "missingTemplates": missing,
            "state": "step2-selecting",
            "fillingCount": 1,
            "totalFillingNum": len(missing),
        })
        unassigned = app_state["stage1UnassignedInputs"]
        updates = {
            app_state_json: app_state,
            channel_info_json: channel_info,
            desc_md: gr.Markdown("### step2"),
            tpl_montage: gr.Image(visible=False),
            map_montage: gr.Image(visible=False),
            radio: gr.Radio(choices=unassigned, value=[],
                            label=f"{missing[0]} (1/{len(missing)})", visible=True),
        }
        if len(unassigned) == 1 or len(missing) == 1:
            # only one round: next_btn (still visible) closes step 2 directly
            updates[next_btn] = gr.Button("Next step")
        else:
            updates[step2_btn] = gr.Button(visible=True)
            updates[next_btn] = gr.Button(visible=False)
        return updates

    # step1 -> step3
    if state == "step3-initializing":
        print('step1 -> step3')
        missing = unmatched_templates()
        app_state.update({
            "missingTemplates": missing,
            "state": "step3-initializing",
            "fillingCount": 1,
            "totalFillingNum": len(missing),
        })
        return {
            app_state_json: app_state,
            channel_info_json: channel_info,
            desc_md: gr.Markdown("### step3"),
            tpl_montage: gr.Image(visible=False),
            map_montage: gr.Image(visible=False),
            in_fill_mode: gr.Dropdown(visible=True),
            fillmode_btn: gr.Button(visible=True),
            next_btn: gr.Button(visible=False),
        }

    # step2 -> step3 / Stage2.decode data
    if state == "step2-selecting":
        # save the choice made for the current (last) round
        target = app_state["missingTemplates"][app_state["fillingCount"] - 1]
        target_idx = templates[target]["index"]
        if not selected_radio:
            # nothing chosen (Gradio may send [] or None): the channel stays empty
            app_state["stage1NewOrder"][target_idx] = []
        else:
            app_state["stage1NewOrder"][target_idx] = [inputs[selected_radio]["index"]]
            templates[target]["matched"] = True
            inputs[selected_radio]["assigned"] = True
            print(target, '<-', selected_radio)
        app_state["stage1UnassignedInputs"] = [
            name for name in channel_info["inputByIndex"] if not inputs[name]["assigned"]
        ]
        app_state["missingTemplates"] = unmatched_templates()

        # every empty template channel got an input channel -> Stage2
        if not app_state["missingTemplates"]:
            print('step2 -> Stage2')
            gr.Info('The mapping process is finished!')
            app_state["state"] = "finished"
            return {
                app_state_json: app_state,
                channel_info_json: channel_info,
                desc_md: gr.Markdown(visible=False),
                radio: gr.Radio(visible=False),
                next_btn: gr.Button(visible=False),
                run_btn: gr.Button(interactive=True),
            }

        # some template channels are still empty -> step3 (imputation)
        print('step2 -> step3')
        app_state.update({
            "state": "step3-initializing",
            "fillingCount": 1,
            "totalFillingNum": len(app_state["missingTemplates"]),
        })
        return {
            app_state_json: app_state,
            channel_info_json: channel_info,
            desc_md: gr.Markdown("### step3"),
            radio: gr.Radio(visible=False),
            in_fill_mode: gr.Dropdown(visible=True),
            fillmode_btn: gr.Button(visible=True),
            next_btn: gr.Button(visible=False),
        }

    # step3 -> Stage2.decode data
    if state == "step3-selecting":
        # save the neighbours chosen for the current (last) round; an empty
        # selection means the channel will be zero-filled
        target = app_state["missingTemplates"][app_state["fillingCount"] - 1]
        target_idx = templates[target]["index"]
        app_state["stage1NewOrder"][target_idx] = [
            inputs[name]["index"] for name in (selected_chkbox or [])
        ]
        gr.Info('The mapping process is finished!')
        app_state["state"] = "finished"
        print('step3 -> Stage2')
        app_state["missingTemplates"] = unmatched_templates()
        return {
            app_state_json: app_state,
            desc_md: gr.Markdown(visible=False),
            chkbox_group: gr.CheckboxGroup(visible=False),
            next_btn: gr.Button(visible=False),
            run_btn: gr.Button(interactive=True),
        }

    # any other state: nothing to do
    return None
# "Next step" / "Submit" button: advance the mapping wizard (init_next_step).
# If that succeeds, run init_js in the browser to lay the radio/checkbox options
# out on the montage image and mark the first missing channel.
next_btn.click(
    fn = init_next_step,
    inputs = [app_state_json, channel_info_json, radio, chkbox_group],
    outputs = [app_state_json, channel_info_json, desc_md, tpl_montage, map_montage, radio, in_fill_mode,
        chkbox_group, fillmode_btn, step2_btn, next_btn, run_btn]
).success(
    fn = None,
    js = init_js,
    inputs = [app_state_json, channel_info_json],
    outputs = []
)
# step2
def update_radio(app_state, channel_info, selected):
    """Handle ``step2_btn``: store this round's radio choice and show the next round.

    The chosen input channel (if any) is assigned to the current missing template
    channel and removed from the choices. The radio is then rebuilt for the next
    template channel. On the last round (one unassigned input left, or the final
    template channel) ``step2_btn`` is hidden and ``next_btn`` shown instead.

    Args:
        app_state: session state dict (value of ``app_state_json``).
        channel_info: channel metadata dict (value of ``channel_info_json``).
        selected: current ``radio`` value. This is an input-channel name, or an
            empty value (``[]`` or ``None``) when nothing was chosen.

    Returns:
        dict of component updates.
    """
    templates = channel_info["templateByName"]
    inputs = channel_info["inputByName"]

    # save the choice made for the current round
    prev_target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
    prev_target_idx = templates[prev_target_name]["index"]
    if not selected:
        # nothing chosen (Gradio may send [] or None): the channel stays empty
        app_state["stage1NewOrder"][prev_target_idx] = []
    else:
        app_state["stage1NewOrder"][prev_target_idx] = [inputs[selected]["index"]]
        templates[prev_target_name]["matched"] = True
        inputs[selected]["assigned"] = True
        print(prev_target_name, '<-', selected)

    # update the current round
    app_state["fillingCount"] += 1
    app_state["stage1UnassignedInputs"] = [
        name for name in channel_info["inputByIndex"] if not inputs[name]["assigned"]
    ]
    target_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
    radio_label = f'{target_name} ({app_state["fillingCount"]}/{app_state["totalFillingNum"]})'

    updates = {
        app_state_json: app_state,
        channel_info_json: channel_info,
        radio: gr.Radio(choices=app_state["stage1UnassignedInputs"], value=[], label=radio_label),
    }
    if (len(app_state["stage1UnassignedInputs"]) == 1
            or app_state["fillingCount"] == app_state["totalFillingNum"]):
        # last round: next_btn closes step 2
        updates[step2_btn] = gr.Button(visible=False)
        updates[next_btn] = gr.Button("Next step", visible=True)
    return updates
# Step-2 "Next" button: store this round's radio choice and advance
# (update_radio). Then update_js re-positions the rebuilt radio options and
# moves the red marker to the next missing template channel.
step2_btn.click(
    fn = update_radio,
    inputs = [app_state_json, channel_info_json, radio],
    outputs = [app_state_json, channel_info_json, radio, step2_btn, next_btn]
).success(
    fn = None,
    js = update_js,
    inputs = [app_state_json, channel_info_json],
    outputs = []
)
# step3
def fill_value(app_state, channel_info, fill_mode):
    """Handle ``fillmode_btn``: apply the chosen imputation mode.

    * ``"zero"``: nothing to choose, so the mapping is finished right away and
      the Run button is enabled.
    * ``"mean"``: ``find_neighbors`` pre-fills ``stage1NewOrder`` for the missing
      template channels. The checkbox group then shows the first one with its
      preselected inputs. With a single missing channel, ``next_btn`` finishes
      the step; otherwise ``step3_btn`` walks through the remaining rounds.

    Returns:
        dict of component updates, or None for an unknown mode.
    """
    if fill_mode == 'zero':
        app_state["state"] = "finished"
        gr.Info('The mapping process is finished!')
        return {
            app_state_json: app_state,
            desc_md: gr.Markdown(visible=False),
            in_fill_mode: gr.Dropdown(visible=False),
            fillmode_btn: gr.Button(visible=False),
            run_btn: gr.Button(interactive=True),
        }
    if fill_mode != 'mean':
        return None

    app_state["state"] = "step3-selecting"
    app_state = find_neighbors(app_state, channel_info, fill_mode)

    first = app_state["missingTemplates"][0]
    neighbour_ids = app_state["stage1NewOrder"][channel_info["templateByName"][first]["index"]]
    neighbour_names = [channel_info["inputByIndex"][i] for i in neighbour_ids]
    label = f'{first} (1/{app_state["totalFillingNum"]})'

    updates = {
        app_state_json: app_state,
        in_fill_mode: gr.Dropdown(visible=False),
        fillmode_btn: gr.Button(visible=False),
        chkbox_group: gr.CheckboxGroup(choices=channel_info["inputByIndex"],
                                       value=neighbour_names, label=label, visible=True),
    }
    # a single round is closed by next_btn, several rounds start with step3_btn
    advance_btn = next_btn if app_state["totalFillingNum"] == 1 else step3_btn
    updates[advance_btn] = gr.Button(visible=True)
    return updates
# Imputation "OK" button: apply the chosen fill mode (fill_value), then let
# init_js lay the checkbox options out on the montage image.
# desc_md must be listed in outputs: fill_value's "zero" branch returns an update
# for it, and Gradio rejects a returned component that is not a declared output
# ("Returned component ... not specified as output of function").
fillmode_btn.click(
    fn=fill_value,
    inputs=[app_state_json, channel_info_json, in_fill_mode],
    outputs=[app_state_json, desc_md, in_fill_mode, fillmode_btn, chkbox_group,
             step3_btn, next_btn, run_btn],
).success(
    fn=None,
    js=init_js,
    inputs=[app_state_json, channel_info_json],
    outputs=[],
)
def update_chkbox(app_state, channel_info, selected):
    """Handle ``step3_btn``: save this round's neighbour selection and advance.

    The checked input channels become the averaging sources of the current
    missing template channel. An empty selection stores ``[]``, i.e. the channel
    is zero-filled. The checkbox group then shows the next missing channel with
    the neighbours ``find_neighbors`` preselected for it. On the final round
    ``step3_btn`` is hidden and ``next_btn`` becomes "Submit".

    Returns:
        dict of component updates.
    """
    templates = channel_info["templateByName"]
    new_order = app_state["stage1NewOrder"]

    done_name = app_state["missingTemplates"][app_state["fillingCount"] - 1]
    new_order[templates[done_name]["index"]] = [
        channel_info["inputByName"][ch]["index"] for ch in selected
    ]

    # move on to the next round
    app_state["fillingCount"] += 1
    current = app_state["fillingCount"]
    total = app_state["totalFillingNum"]
    next_name = app_state["missingTemplates"][current - 1]
    preselected = [channel_info["inputByIndex"][i] for i in new_order[templates[next_name]["index"]]]

    updates = {
        app_state_json: app_state,
        chkbox_group: gr.CheckboxGroup(value=preselected, label=f"{next_name} ({current}/{total})"),
    }
    if current == total:
        updates[step3_btn] = gr.Button(visible=False)
        updates[next_btn] = gr.Button("Submit", visible=True)
    return updates
# Step-3 "Next" button: save this round's neighbour selection and advance
# (update_chkbox). Then update_js moves the red marker to the next missing
# template channel.
step3_btn.click(
    fn = update_chkbox,
    inputs = [app_state_json, channel_info_json, chkbox_group],
    outputs = [app_state_json, chkbox_group, step3_btn, next_btn]
).success(
    fn = None,
    js = update_js,
    inputs = [app_state_json, channel_info_json],
    outputs = []
)
def delete_file(filename):
    """Best-effort removal of ``filename``.

    A missing file is the normal case (nothing left over from a previous run)
    and is ignored silently. Any other ``OSError`` (e.g. permissions, or a
    directory path) is printed rather than raised, so cleanup never aborts the
    caller.
    """
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(e)
def reset_run(app_state, channel_info, raw_data, model_name):
    """Prepare ``app_state`` for a Stage-2 model run and lock the Run button.

    * Removes intermediate files a previous run may have left in the temp folder.
    * Chooses the output path ``<temp>/<raw-data stem>_<model_name>.csv``.
    * Estimates the number of batches from the input channels not matched in
      Stage 1: ``ceil(unmatched / 30) + 1``.
    * Resets the per-run bookkeeping and the ``assigned`` flags of the inputs
      that were still unassigned after Stage 1.

    Returns:
        dict of component updates.
    """
    filepath = app_state["filepath"]
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')

    input_stem = os.path.splitext(os.path.basename(str(raw_data)))[0]
    app_state["filenames"]["denoised"] = f"{filepath}{input_stem}_{model_name}.csv"

    in_num = len(channel_info["inputByIndex"])
    matched_num = sum(1 for channel in channel_info["inputByIndex"]
                      if channel_info["inputByName"][channel]["matched"])
    batch_num = math.ceil((in_num - matched_num) / 30) + 1

    app_state.update({
        # key spelling (sic) is what run_model reads; keep it unchanged
        "runnigState": "stage1",
        "batchCount": 1,
        "totalBatchNum": batch_num,
        "stage2UnassignedInputs": list(app_state["stage1UnassignedInputs"]),
        # one independent list per template channel; [[]] * 30 would make all
        # 30 entries the same list object
        "stage2NewOrder": [[] for _ in range(30)],
    })

    # reset in.assigned back to the state after Stage1
    for channel in app_state["stage1UnassignedInputs"]:
        channel_info["inputByName"][channel]["assigned"] = False

    return {
        app_state_json: app_state,
        channel_info_json: channel_info,
        run_btn: gr.Button(interactive=False),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(visible=False),
    }
def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
    """Run the selected model batch by batch, streaming progress to the UI (generator).

    Each batch:
      1. (batches after the first) ``mapping_stage2`` chooses the next inputs.
         It presumably sets ``runnigState`` to "finished" when none are left,
         which ends the loop.
      2. ``reorder_to_template`` writes ``mapped.csv``.
      3. ``utils.preprocessing`` and ``utils.reconstruct`` produce ``denoised.csv``.
      4. ``reorder_to_origin`` merges the result into the output file.
      5. The intermediate CSVs are deleted.

    Yields:
        dicts of component updates. First a "Running model(k/N)..." message per
        batch, then the downloadable result with the Run button re-enabled.

    Raises:
        Whatever a processing step raises. Before re-raising, the Run button is
        re-enabled so the user is not locked out after a failure.
    """
    filepath = app_state["filepath"]
    samplerate = app_state["sampleRate"]
    new_filename = app_state["filenames"]["denoised"]

    try:
        while app_state["runnigState"] != "finished":
            if app_state["batchCount"] > 1:
                app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
                if app_state["runnigState"] == "finished":
                    break
            # Announce a batch only once we know it will actually run, so the
            # counter never shows more batches than will be processed.
            md = f'Running model({app_state["batchCount"]}/{app_state["totalBatchNum"]})...'
            yield {batch_md: gr.Markdown(md, visible=True)}

            reorder_to_template(app_state, raw_data)
            # step1: Data preprocessing
            total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
            # step2: Signal reconstruction
            utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
            reorder_to_origin(app_state, channel_info, filepath + 'denoised.csv', new_filename)

            delete_file(filepath + 'mapped.csv')
            delete_file(filepath + 'denoised.csv')
            app_state["batchCount"] += 1
    except Exception:
        # reset_run disabled the Run button; give it back before reporting the error
        yield {run_btn: gr.Button(interactive=True), batch_md: gr.Markdown(visible=False)}
        raise

    yield {
        run_btn: gr.Button(interactive=True),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(new_filename, visible=True),
    }
# Run button: prepare the run state and disable the button (reset_run). If that
# succeeds, stream the batch-by-batch model run (run_model is a generator), which
# ends by offering the denoised file for download.
run_btn.click(
    fn = reset_run,
    inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name],
    outputs = [app_state_json, channel_info_json, run_btn, batch_md, out_denoised_data]
).success(
    fn = run_model,
    inputs = [app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
    outputs = [run_btn, batch_md, out_denoised_data]
)
# Start the Gradio server when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()