AIEEG / app.py
audrey06100's picture
update
995c1d0
raw
history blame
18.5 kB
import gradio as gr
import numpy as np
import os
import random
import math
import utils
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
import mne
from mne.channels import read_custom_montage
# Markdown shown in the "QuickStart" tab.
# NOTE(review): the "here" link below has an empty href — fill in the URL of the
# standard montage file.
quickstart = """
# Quickstart
### Raw data
1. The data must be a two-dimensional array of shape (channel, timepoint).
2. Upload your EEG data in `.csv` format.
### Channel locations
Upload your data's channel locations in `.loc` (or `.locs`) format, which can be obtained using **EEGLAB**.
>If you cannot obtain it, we recommend downloading the standard montage <a href="">here</a>. If the channels in that file don't match yours, you can use **EEGLAB** to modify it to the montage you need.
### Imputation
The models were trained using the EEG signals of 30 channels, namely: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well.
If your data doesn't contain all of the mentioned channels, you can choose from 3 imputation methods:
- **zero**: fill the missing channels with zeros.
- **mean (auto)**: for each missing channel, select the 4 nearest channels and average their values.
- **mean (manual)**: for each missing channel, select the channels you wish to use to impute it, and we will average their values. If you select nothing, zeros will be imputed. For example, if your data lacks **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign the result to **FCZ**.
### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage** (with the names of the matched channels displayed) will be shown.
### Model
Select the model you want to use.
Detailed descriptions of the models can be found in the other tabs.
"""
# Markdown shown in the "IC-U-Net" tab (paper abstract).
icunet = """
# IC-U-Net
### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""
# Client-side JS spliced into both handlers below (inside their arrow functions,
# so it can see their `channel_info` argument). `showIndicator(channel)` draws a
# red dot on the input-montage figure at the template position of `channel`.
# The CSS rule lives in a dedicated <style> element that is simply overwritten on
# each call. The previous code searched `document.styleSheets[0].cssRules` by
# `selectorText`, but browsers serialize the selector with spaces around ">"
# ("#chkbox-group > div..."), so the lookup never matched and a new rule was
# appended on every call; reading `cssRules` can also throw if the first
# stylesheet is cross-origin.
_INDICATOR_JS = """
    const showIndicator = (channel) => {
        const pos = channel_info.templateByName[channel].css_position;
        let styleEl = document.getElementById("missing-channel-indicator");
        if(!styleEl){
            styleEl = document.createElement("style");
            styleEl.id = "missing-channel-indicator";
            document.head.appendChild(styleEl);
        }
        styleEl.textContent = `
            #chkbox-group> div:nth-of-type(2)::after{
                content: '';
                position: absolute;
                background-color: red;
                width: 10px;
                height: 10px;
                border-radius: 50%;
                left: ${pos[0]};
                bottom: ${pos[1]};
            }
        `;
    };
"""

# Runs after the checkbox group is first shown (manual imputation): puts the
# input-montage PNG behind the checkboxes, marks the first missing channel and
# moves every checkbox onto its channel's position on the figure.
chkbox_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));
    if(app_state.state == "finished") return;
""" + _INDICATOR_JS + """
    // use the input-montage figure as the background of the checkbox container
    document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
        position: relative;
        width: 560px;
        height: 560px;
        background: url("file=${app_state.filenames.raw_montage}");
    `;
    // mark the first missing channel
    showIndicator(channel_info.templateByIndex[channel_info.missingChannelsIndex[0]]);
    // move every checkbox onto its channel's position and hide its text label
    const all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
    Array.from(all_chkbox).forEach((item, index) => {
        const pos = channel_info.inputByName[channel_info.inputByIndex[index]].css_position;
        item.style.cssText = `
            position: absolute;
            left: ${pos[0]};
            bottom: ${pos[1]};
        `;
        item.className = "";
        item.querySelector(":scope> span").innerText = "";
    });
}
"""

# Runs after each "Next" click: moves the red dot to the missing channel that is
# now being filled (fillingCount is 1-based).
indication_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));
    if(app_state.state == "finished") return;
""" + _INDICATOR_JS + """
    showIndicator(channel_info.templateByIndex[channel_info.missingChannelsIndex[app_state["fillingCount"]-1]]);
}
"""
with gr.Blocks() as demo:
    # Hidden JSON components that carry the app state and the channel-mapping
    # info between the event handlers and into the client-side JS callbacks.
    app_state_json = gr.JSON(visible=False)
    channel_info_json = gr.JSON(visible=False)
    with gr.Row():
        gr.Markdown(
            """
<p style="text-align: center;">(...)</p>
"""
        )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
# 1.Channel Mapping
"""
            )
            # upload files, choose the imputation method
            with gr.Row():
                in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                # Extensions need their leading dot; entries without one are read
                # as file categories (e.g. "image"), so "locs" matched nothing.
                in_raw_loc = gr.File(label="Channel locations (.loc, .locs)", file_types=[".loc", ".locs"])
                with gr.Column(min_width=100):
                    in_sample_rate = gr.Textbox(label="Sampling rate (Hz)")
                    in_fill_mode = gr.Dropdown(
                        choices=[
                            ("mean (auto)", "mean_auto"),
                            ("mean (manual)", "mean_manual"),
                            # NOTE(review): empty option — a visual spacer? It lets the
                            # user pick "" as the imputation mode; confirm it is wanted.
                            ("", ""),
                            "zero",
                        ],
                        value="mean_auto",
                        label="Imputation",
                    )
            map_btn = gr.Button("Mapping")
            # Checkboxes used for manual imputation; positioned over the montage by chkbox_js.
            chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
            next_btn = gr.Button("Next", interactive=False, visible=False)
            # mapping result
            res_md = gr.Markdown(
                """
### Mapping result:
""",
                visible=False,
            )
            with gr.Row():
                tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
                map_montage = gr.Image(label="Matched channels", visible=False)
        with gr.Column():
            gr.Markdown(
                """
# 2.Decode Data
"""
            )
            with gr.Row():
                in_model_name = gr.Dropdown(
                    choices=[
                        ("ART", "EEGART"),
                        ("IC-U-Net", "ICUNet"),
                        ("IC-U-Net++", "UNetpp"),
                        ("IC-U-Net-Attn", "AttUnet"),
                        "(mapped data)",
                        "(denoised data)",
                    ],
                    value="EEGART",
                    label="Model",
                    scale=2,
                )
                run_btn = gr.Button(scale=1, interactive=False)
            batch_md = gr.Markdown(visible=False)
            out_denoised_data = gr.File(label="Denoised data", visible=False)
    with gr.Row():
        with gr.Tab("ART"):
            gr.Markdown()
        with gr.Tab("IC-U-Net"):
            gr.Markdown(icunet)
        with gr.Tab("IC-U-Net++"):
            gr.Markdown()
        with gr.Tab("IC-U-Net-Attn"):
            gr.Markdown()
        with gr.Tab("QuickStart"):
            gr.Markdown(quickstart)
def reset1(raw_data, samplerate):
    """Validate the mapping inputs and reset every per-run widget and state.

    Creates an empty ``temp_data/`` folder next to the uploaded CSV (wiping a
    folder left over from a previous run), reads the raw data to record its
    shape, and hides/disables all widgets that belong to an earlier mapping or
    decoding run.

    Args:
        raw_data: the uploaded ``.csv`` as passed by ``gr.File`` (presumably a
            file path), or ``None`` if nothing was uploaded.
        samplerate: text from the "Sampling rate (Hz)" textbox.

    Returns:
        dict mapping output components to their reset values.

    Raises:
        gr.Error: if no data file was uploaded or the sampling rate is not an
            integer. Raising stops the ``.success`` chain before any mapping runs.
    """
    # Validate before touching the filesystem: with no upload, dirname("None")
    # is "" and the temp folder would be created at the filesystem root.
    if raw_data is None:
        raise gr.Error("Please upload the raw data (.csv) first.")
    try:
        sample_rate = int(samplerate)
    except (TypeError, ValueError):
        raise gr.Error("The sampling rate must be an integer (Hz).") from None

    # establish a fresh temp folder next to the uploaded file
    filepath = os.path.dirname(str(raw_data))
    temp_dir = os.path.join(filepath, "temp_data", "")  # trailing separator kept: paths are built by concatenation
    try:
        os.mkdir(temp_dir)
    except FileExistsError:
        # left over from a previous run: wipe it and start clean
        utils.dataDelete(temp_dir)
        os.mkdir(temp_dir)

    # initialize app_state, channel_info
    data = utils.read_train_data(raw_data)
    app_state = {
        "filepath": temp_dir,
        "filenames": {},
        "sampleRate": sample_rate,
    }
    channel_info = {
        "dataShape": data.shape
    }
    return {app_state_json: app_state,
            channel_info_json: channel_info,
            chkbox_group: gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
            next_btn: gr.Button("Next", interactive=False, visible=False),
            run_btn: gr.Button(interactive=False),
            tpl_montage: gr.Image(visible=False),
            map_montage: gr.Image(value=None, visible=False),
            res_md: gr.Markdown(visible=False),
            batch_md: gr.Markdown(visible=False),
            out_denoised_data: gr.File(visible=False)}
def mapping_result(app_state, channel_info, fill_mode):
    """Plan the model batches and decide whether manual imputation is needed.

    The first batch covers the 30 template channels. The input channels that
    were not matched to the template are counted, and one extra batch is
    planned per (started) group of 30 of them. Batching is presumably done by
    mapping_stage2 — confirm.

    If the user chose "mean_manual" and some template channels are missing,
    the state becomes "initializing" and the "Next" button is shown. Otherwise
    the mapping is "finished": the result heading is shown and the run button
    is enabled.
    """
    missing = channel_info["missingChannelsIndex"]
    unmatched_inputs = len(channel_info["inputByName"]) - (30 - len(missing))

    app_state["batchCount"] = 1
    app_state["totalBatchNum"] = math.ceil(unmatched_inputs / 30) + 1

    needs_manual_filling = fill_mode == "mean_manual" and len(missing) > 0
    if not needs_manual_filling:
        app_state["state"] = "finished"
        return {
            app_state_json: app_state,
            res_md: gr.Markdown(visible=True),
            run_btn: gr.Button(interactive=True),
        }

    app_state["state"] = "initializing"
    app_state["totalFillingNum"] = len(missing)
    return {
        app_state_json: app_state,
        next_btn: gr.Button(visible=True),
    }
def _save_montage_figure(fig, filename):
    """Resize a montage figure to 5.6 x 5.6 inches and save it as ``filename``."""
    # Presumably sized to fill the 560px x 560px container that chkbox_js styles.
    # NOTE(review): the figure is never closed; consider
    # matplotlib.pyplot.close(fig) so open figures don't accumulate per request.
    fig.set_size_inches(5.6, 5.6)
    fig.savefig(filename, pad_inches=0)


def show_montage(app_state, channel_info, raw_loc):
    """Render the input montage as a PNG for the current mapping state.

    * ``"initializing"``: save the plain input montage. Its path is stored in
      ``app_state["filenames"]["raw_montage"]``, where chkbox_js picks it up as
      the checkbox background.
    * ``"finished"``: save the montage with only the channels marked
      ``matched`` labelled, and show it next to the template montage.
    * any other state (e.g. ``"selecting"``): change nothing.

    Args:
        app_state: app-state dict from ``app_state_json``.
        channel_info: channel-mapping dict from ``channel_info_json``.
        raw_loc: uploaded channel-location file, read by ``read_custom_montage``.

    Returns:
        dict of component updates; always contains ``app_state_json``.
    """
    state = app_state["state"]
    if state not in ("initializing", "finished"):
        return {app_state_json: app_state}  # change nothing

    filepath = app_state["filepath"]
    raw_montage = read_custom_montage(raw_loc)
    # convert all channel names to uppercase in a single rename
    raw_montage.rename_channels({name: name.upper() for name in raw_montage.ch_names})

    if state == "initializing":
        # random suffix presumably gives each image a fresh URL (avoids stale cache)
        filename = filepath + "raw_montage_" + str(random.randint(1, 10000)) + ".png"
        app_state["filenames"]["raw_montage"] = filename
        _save_montage_figure(raw_montage.plot(), filename)
        return {app_state_json: app_state}

    # state == "finished": label only the channels matched to the template
    filename = filepath + "mapped_montage_" + str(random.randint(1, 10000)) + ".png"
    app_state["filenames"]["map_montage"] = filename
    show_names = [name for name, info in channel_info["inputByName"].items() if info["matched"]]
    _save_montage_figure(raw_montage.plot(show_names=show_names), filename)
    return {app_state_json: app_state,
            tpl_montage: gr.Image(visible=True),
            map_montage: gr.Image(value=filename, visible=True)}
def generate_chkbox(app_state, channel_info):
    """Start manual imputation by showing one checkbox per input channel.

    Acts only in the "initializing" state: it switches to "selecting", sets the
    1-based filling counter to 1, and labels the checkbox group with the first
    missing template channel plus the progress "(1/total)". In any other state
    nothing changes.
    """
    if app_state["state"] != "initializing":
        return {app_state_json: app_state}  # change nothing

    choices = list(channel_info["inputByName"])
    app_state["state"] = "selecting"
    app_state["fillingCount"] = 1

    first_missing = channel_info["templateByIndex"][channel_info["missingChannelsIndex"][0]]
    label = first_missing + " (1/" + str(app_state["totalFillingNum"]) + ")"
    return {
        app_state_json: app_state,
        chkbox_group: gr.CheckboxGroup(choices=choices, label=label, visible=True),
        next_btn: gr.Button(interactive=True),
    }
# Gradio only accepts event listeners inside a Blocks context; this code sits at
# module level, so re-enter `demo` before attaching them.
with demo:
    # "Mapping" pipeline; each step runs only if the previous one succeeded:
    # reset -> stage-1 channel mapping -> decide whether manual imputation is
    # needed -> draw montage figure(s) -> build the checkbox group -> place the
    # checkboxes over the montage figure (client-side JS, no Python fn).
    map_btn.click(
        fn=reset1,
        inputs=[in_raw_data, in_sample_rate],
        outputs=[app_state_json, channel_info_json, chkbox_group, next_btn, run_btn,
                 tpl_montage, map_montage, res_md, batch_md, out_denoised_data],
    ).success(
        fn=mapping_stage1,
        inputs=[app_state_json, channel_info_json, in_raw_data, in_raw_loc, in_fill_mode],
        outputs=[app_state_json, channel_info_json],
    ).success(
        fn=mapping_result,
        inputs=[app_state_json, channel_info_json, in_fill_mode],
        outputs=[app_state_json, next_btn, res_md, run_btn],
    ).success(
        fn=show_montage,
        inputs=[app_state_json, channel_info_json, in_raw_loc],
        outputs=[app_state_json, tpl_montage, map_montage],
    ).success(
        fn=generate_chkbox,
        inputs=[app_state_json, channel_info_json],
        outputs=[app_state_json, chkbox_group, next_btn],
    ).success(
        fn=None,
        js=chkbox_js,
        inputs=[app_state_json, channel_info_json],
        outputs=[],
    )
def check_next(app_state, channel_info, selected, raw_data, fill_mode):
    """Store the user's selection for the current missing channel, then advance.

    The indices of the selected input channels are saved in
    ``app_state["stage1NewOrder"]`` under the current missing template
    channel's index. The 1-based filling counter is then incremented. While
    missing channels remain, the checkbox group is relabelled for the next one
    and the button reads "Submit" on the last one. After the last channel the
    state becomes "finished", the selection widgets are hidden and the run
    button is enabled.

    ``raw_data`` and ``fill_mode`` are accepted but not used.
    """
    missing = channel_info["missingChannelsIndex"]
    template_names = channel_info["templateByIndex"]

    # save the selection made for the channel that was on screen
    current_idx = missing[app_state["fillingCount"] - 1]
    current_name = template_names[current_idx]
    app_state["stage1NewOrder"][current_idx] = [
        channel_info["inputByName"][name]["index"] for name in selected
    ]
    print('Selection for missing channel "{}"({}): {}'.format(current_name, current_idx, selected))

    # advance to the next missing channel (if any)
    app_state["fillingCount"] += 1
    step = app_state["fillingCount"]
    total = app_state["totalFillingNum"]

    if step > total:
        app_state["state"] = "finished"
        return {
            app_state_json: app_state,
            chkbox_group: gr.CheckboxGroup(visible=False),
            next_btn: gr.Button(visible=False),
            res_md: gr.Markdown(visible=True),
            run_btn: gr.Button(interactive=True),
        }

    next_name = template_names[missing[step - 1]]
    label = next_name + " (" + str(step) + "/" + str(total) + ")"
    return {
        app_state_json: app_state,
        chkbox_group: gr.CheckboxGroup(value=[], label=label),
        next_btn: gr.Button("Submit" if step == total else "Next"),
    }
# Gradio only accepts event listeners inside a Blocks context; this code sits at
# module level, so re-enter `demo` before attaching them.
with demo:
    # "Next"/"Submit": record the selection for the current missing channel,
    # draw the matched montage once every selection is made, then move the red
    # indicator to the next missing channel (client-side JS).
    next_btn.click(
        fn=check_next,
        inputs=[app_state_json, channel_info_json, chkbox_group, in_raw_data, in_fill_mode],
        outputs=[app_state_json, chkbox_group, next_btn, run_btn, res_md],
    ).success(
        fn=show_montage,
        inputs=[app_state_json, channel_info_json, in_raw_loc],
        outputs=[app_state_json, tpl_montage, map_montage],
    ).success(
        fn=None,
        js=indication_js,
        inputs=[app_state_json, channel_info_json],
        outputs=[],
    )
def delete_file(filename):
    """Best-effort removal of a temporary file; never raises.

    A missing file is expected (for example before the first run) and is
    ignored silently. Any other OS error is printed so that cleanup never
    aborts a run.

    Args:
        filename: path of the file to remove.
    """
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass  # nothing to clean up
    except OSError as e:
        print(e)
def reset2(app_state, raw_data, model_name):
    """Prepare ``app_state`` for a fresh model run and reset the output widgets.

    The output name is ``<input stem>_<model_name>.csv`` inside the temp folder.
    The run state is set to "stage1" and the batch counter to 1. 30 empty
    stage-2 slots are allocated, and leftover ``mapped.csv`` / ``denoised.csv``
    files from a previous run are removed. The run button is disabled while the
    model runs.

    Note: the key "runnigState" (sic) is kept as-is because run_model reads it
    (and presumably mapping_stage2 too).
    """
    filepath = app_state["filepath"]
    input_name = os.path.basename(str(raw_data))
    output_name = os.path.splitext(input_name)[0] + '_' + model_name + '.csv'
    app_state["filenames"]["denoised"] = filepath + output_name
    app_state.update({
        "runnigState": "stage1",
        "batchCount": 1,
        # one independent list per template channel; [[]]*30 would alias a single list 30 times
        "stage2NewOrder": [[] for _ in range(30)],
    })
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')
    return {app_state_json: app_state,
            run_btn: gr.Button(interactive=False),
            batch_md: gr.Markdown(visible=False),
            out_denoised_data: gr.File(visible=False)}
def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
    """Run ``model_name`` over the mapped data batch by batch (generator handler).

    For batches after the first, ``mapping_stage2`` is called first; it may mark
    the run "finished", which ends the loop. Each batch that runs then:

    1. reorders the raw data (presumably into the template layout, written to
       ``mapped.csv`` — that is the file preprocessing reads);
    2. preprocesses it;
    3. reconstructs ``denoised.csv`` with the model;
    4. writes the result to ``app_state["filenames"]["denoised"]`` via
       ``reorder_to_origin``.

    Temporary CSVs are deleted at the end.

    Yields:
        a progress Markdown update for every batch that actually runs, and
        finally the re-enabled run button plus the output file.

    NOTE(review): "(mapped data)" / "(denoised data)" are passed to
    ``utils.reconstruct`` like real model names — confirm they are handled there.
    """
    filepath = app_state["filepath"]
    samplerate = app_state["sampleRate"]
    new_filename = app_state["filenames"]["denoised"]
    while app_state["runnigState"] != "finished":
        if app_state["batchCount"] > 1:
            app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
            if app_state["runnigState"] == "finished":
                break
        # Announce a batch only once we know it will run. Announcing it before
        # mapping_stage2 could show a batch number past the total right before
        # the loop ended.
        md = 'Running model(' + str(app_state["batchCount"]) + '/' + str(app_state["totalBatchNum"]) + ')...'
        yield {batch_md: gr.Markdown(md, visible=True)}
        app_state["batchCount"] += 1
        reorder_to_template(app_state, raw_data)
        # step1: Data preprocessing
        total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
        # step2: Signal reconstruction
        utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
        reorder_to_origin(app_state, channel_info, filepath + 'denoised.csv', new_filename)
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')
    yield {run_btn: gr.Button(interactive=True),
           batch_md: gr.Markdown(visible=False),
           out_denoised_data: gr.File(new_filename, visible=True)}
# Gradio only accepts event listeners inside a Blocks context; this code sits at
# module level, so re-enter `demo` before attaching them.
with demo:
    # "Run": reset the decoding state, then stream the batched model run
    # (run_model is a generator, so progress updates appear as it yields).
    run_btn.click(
        fn=reset2,
        inputs=[app_state_json, in_raw_data, in_model_name],
        outputs=[app_state_json, run_btn, batch_md, out_denoised_data],
    ).success(
        fn=run_model,
        inputs=[app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
        outputs=[run_btn, batch_md, out_denoised_data],
    )

if __name__ == "__main__":
    demo.launch()