# Hugging Face Space application script (Space status when captured: Sleeping).
import math
import os
import random

import gradio as gr
import mne
import numpy as np
from mne.channels import read_custom_montage

import utils
from channel_mapping import mapping_stage1, mapping_stage2, reorder_to_template, reorder_to_origin
# Markdown shown in the "QuickStart" tab.
quickstart = """
# Quickstart

### Raw data
1. The data must be a two-dimensional array (channel, timepoint).
2. Upload your EEG data in `.csv` format.

### Channel locations
Upload your data's channel locations in `.loc` format, which can be obtained using **EEGLAB**.
>If you cannot obtain it, we recommend downloading the standard montage <a href="">here</a>. If the channels in those files don't match yours, you can use **EEGLAB** to adjust them to the montage you need.

### Imputation
The models were trained using the EEG signals of 30 channels, including: `Fp1, Fp2, F7, F3, Fz, F4, F8, FT7, FC3, FCz, FC4, FT8, T7, C3, Cz, C4, T8, TP7, CP3, CPz, CP4, TP8, P7, P3, Pz, P4, P8, O1, Oz, O2`.
We expect your input data to include these channels as well.
If your data doesn't contain all of the mentioned channels, there are 3 imputation methods you can choose from:
- **zero**: fill the missing channels with zeros.
- **mean(auto)**: for each missing channel, the 4 nearest channels are selected and their values are averaged.
- **mean(manual)**: select the channels you wish to use for imputing the required one, and we will average their values. If you select nothing, zeros will be imputed. For example, if you don't have **FCZ** and you choose **FC1, FC2, FZ, CZ** to impute it (depending on the channels you have), we will compute the mean of these 4 channels and assign this new value to **FCZ**.

### Mapping result
Once the mapping process is finished, the **template montage** and the **input montage** (with the matched channels displaying their names) will be shown.

### Model
Select the model you want to use.
The detailed description of the models can be found in the other tabs.
"""
# Markdown shown in the "IC-U-Net" tab.
icunet = """
# IC-U-Net
### Abstract
Electroencephalography (EEG) signals are often contaminated with artifacts. It is imperative to develop a practical and reliable artifact removal method to prevent the misinterpretation of neural signals and the underperformance of brain–computer interfaces. Based on the U-Net architecture, we developed a new artifact removal model, IC-U-Net, for removing pervasive EEG artifacts and reconstructing brain signals. IC-U-Net was trained using mixtures of brain and non-brain components decomposed by independent component analysis. It uses an ensemble of loss functions to model complex signal fluctuations in EEG recordings. The effectiveness of the proposed method in recovering brain activities and removing various artifacts (e.g., eye blinks/movements, muscle activities, and line/channel noise) was demonstrated in a simulation study and four real-world EEG experiments. IC-U-Net can reconstruct a multi-channel EEG signal and is applicable to most artifact types, offering a promising end-to-end solution for automatically removing artifacts from EEG recordings. It also meets the increasing need to image natural brain dynamics in a mobile setting.
"""
# Client-side hook run once the manual-imputation checkbox group has been
# generated: it paints the input montage behind the checkbox group, marks the
# first missing template channel with a red dot, and moves every checkbox onto
# its channel position.
chkbox_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));

    if(app_state.state == "finished") return;

    // add figure of in_montage
    document.querySelector("#chkbox-group> div:nth-of-type(2)").style.cssText = `
        position: relative;
        width: 560px;
        height: 560px;
        background: url("file=${app_state.filenames.raw_montage}");
    `;

    // add indication for the missing channels
    let channel = channel_info.missingChannelsIndex[0];
    channel = channel_info.templateByIndex[channel];
    let left = channel_info.templateByName[channel].css_position[0];
    let bottom = channel_info.templateByName[channel].css_position[1];

    let rule = `
        #chkbox-group> div:nth-of-type(2)::after{
            content: '';
            position: absolute;
            background-color: red;
            width: 10px;
            height: 10px;
            border-radius: 50%;
            left: ${left};
            bottom: ${bottom};
        }
    `;

    // check if indicator already exist
    let exist = 0;
    const styleSheet = document.styleSheets[0];
    for(let i=0; i<styleSheet.cssRules.length; i++){
        if(styleSheet.cssRules[i].selectorText == "#chkbox-group> div:nth-of-type(2)::after"){
            exist = 1;
            console.log('exist!');
            styleSheet.deleteRule(i);
            styleSheet.insertRule(rule, styleSheet.cssRules.length);
            break;
        }
    }
    if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);

    // move the checkboxes
    let all_chkbox = document.querySelectorAll("#chkbox-group> div:nth-of-type(2)> label");
    Array.from(all_chkbox).forEach((item, index) => {
        channel = channel_info.inputByIndex[index];
        left = channel_info.inputByName[channel].css_position[0];
        bottom = channel_info.inputByName[channel].css_position[1];
        console.log(`left: ${left}, bottom: ${bottom}`);

        item.style.cssText = `
            position: absolute;
            left: ${left};
            bottom: ${bottom};
        `;
        item.className = "";
        item.querySelector(":scope> span").innerText = "";
    });
}
"""
# Client-side hook run after each "Next" click: it moves the red indicator dot
# to the template position of the missing channel currently being filled
# (selected through app_state["fillingCount"]).
indication_js = """
(app_state, channel_info) => {
    app_state = JSON.parse(JSON.stringify(app_state));
    channel_info = JSON.parse(JSON.stringify(channel_info));

    if(app_state.state == "finished") return;

    let channel = channel_info.missingChannelsIndex[app_state["fillingCount"]-1];
    channel = channel_info.templateByIndex[channel];
    let left = channel_info.templateByName[channel].css_position[0];
    let bottom = channel_info.templateByName[channel].css_position[1];

    let rule = `
        #chkbox-group> div:nth-of-type(2)::after{
            content: '';
            position: absolute;
            background-color: red;
            width: 10px;
            height: 10px;
            border-radius: 50%;
            left: ${left};
            bottom: ${bottom};
        }
    `;

    // check if indicator already exist
    let exist = 0;
    const styleSheet = document.styleSheets[0];
    for(let i=0; i<styleSheet.cssRules.length; i++){
        if(styleSheet.cssRules[i].selectorText == "#chkbox-group> div:nth-of-type(2)::after"){
            exist = 1;
            console.log('exist!');
            styleSheet.deleteRule(i);
            styleSheet.insertRule(rule, styleSheet.cssRules.length);
            break;
        }
    }
    if(exist == 0) styleSheet.insertRule(rule, styleSheet.cssRules.length);
}
"""
# Page layout.
# NOTE(review): the nesting of rows/columns below was reconstructed from a
# flattened copy of this file -- confirm it against the intended layout.
with gr.Blocks() as demo:
    # Hidden JSON components used to pass state between event handlers
    # (and to the client-side JS hooks).
    app_state_json = gr.JSON(visible=False)
    channel_info_json = gr.JSON(visible=False)

    with gr.Row():
        gr.Markdown('<p style="text-align: center;">(...)</p>')

    with gr.Row():
        # ---- Step 1: channel mapping ----
        with gr.Column():
            gr.Markdown("# 1.Channel Mapping")

            # upload files, choose the imputation method
            with gr.Row():
                in_raw_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
                # Extensions need a leading dot; a bare "locs" would be read as
                # a file-type category rather than an extension.
                in_raw_loc = gr.File(
                    label="Channel locations (.loc, .locs)",
                    file_types=[".loc", ".locs"],
                )
                with gr.Column(min_width=100):
                    in_sample_rate = gr.Textbox(label="Sampling rate (Hz)")
                    in_fill_mode = gr.Dropdown(
                        choices=[
                            ("mean (auto)", "mean_auto"),
                            ("mean (manual)", "mean_manual"),
                            ("", ""),
                            "zero",
                        ],
                        value="mean_auto",
                        label="Imputation",
                    )

            map_btn = gr.Button("Mapping")
            # Shown only while the user manually picks channels for imputation.
            chkbox_group = gr.CheckboxGroup(elem_id="chkbox-group", label="", visible=False)
            next_btn = gr.Button("Next", interactive=False, visible=False)

            # mapping result
            res_md = gr.Markdown("### Mapping result:", visible=False)
            with gr.Row():
                tpl_montage = gr.Image("./template_montage.png", label="Template montage", visible=False)
                map_montage = gr.Image(label="Matched channels", visible=False)

        # ---- Step 2: run the selected model ----
        with gr.Column():
            gr.Markdown("# 2.Decode Data")
            with gr.Row():
                in_model_name = gr.Dropdown(
                    choices=[
                        ("ART", "EEGART"),
                        ("IC-U-Net", "ICUNet"),
                        ("IC-U-Net++", "UNetpp"),
                        ("IC-U-Net-Attn", "AttUnet"),
                        "(mapped data)",
                        "(denoised data)",
                    ],
                    value="EEGART",
                    label="Model",
                    scale=2,
                )
                run_btn = gr.Button(scale=1, interactive=False)
            batch_md = gr.Markdown(visible=False)
            out_denoised_data = gr.File(label="Denoised data", visible=False)

    # ---- Documentation tabs ----
    with gr.Row():
        with gr.Tab("ART"):
            gr.Markdown()
        with gr.Tab("IC-U-Net"):
            gr.Markdown(icunet)
        with gr.Tab("IC-U-Net++"):
            gr.Markdown()
        with gr.Tab("IC-U-Net-Attn"):
            gr.Markdown()
        with gr.Tab("QuickStart"):
            gr.Markdown(quickstart)
def reset1(raw_data, samplerate):
    """Validate the inputs, prepare a scratch folder and reset the UI.

    Runs first when "Mapping" is clicked.

    Args:
        raw_data: Path of the uploaded raw-data file (the value of the
            "Raw data" file component).
        samplerate: Text typed into the "Sampling rate (Hz)" box.

    Returns:
        A dict of component updates: fresh ``app_state`` / ``channel_info``
        payloads plus every result widget reset to its hidden/disabled state.

    Raises:
        gr.Error: If no raw-data file was uploaded or the sampling rate is not
            a positive integer.
    """
    # Without this guard str(None) -> "None", dirname -> "" and the scratch
    # folder would be created at "/temp_data/" in the filesystem root.
    if raw_data is None:
        raise gr.Error("Please upload the raw data (.csv) first.")
    try:
        sample_rate = int(samplerate)
    except (TypeError, ValueError):
        raise gr.Error("Sampling rate (Hz) must be an integer.") from None
    if sample_rate <= 0:
        raise gr.Error("Sampling rate (Hz) must be a positive integer.")

    # establish temp folder next to the uploaded file
    filepath = os.path.dirname(str(raw_data))
    temp_dir = filepath + "/temp_data/"
    try:
        os.mkdir(temp_dir)
    except FileExistsError:
        # Left over from a previous run: wipe it and start clean.
        utils.dataDelete(temp_dir)
        os.mkdir(temp_dir)

    # initialize app_state, channel_info
    data = utils.read_train_data(raw_data)
    app_state = {
        "filepath": temp_dir,
        "filenames": {},
        "sampleRate": sample_rate,
    }
    channel_info = {
        "dataShape": data.shape,
    }

    return {
        app_state_json: app_state,
        channel_info_json: channel_info,
        chkbox_group: gr.CheckboxGroup(choices=[], value=[], label="", visible=False),
        next_btn: gr.Button("Next", interactive=False, visible=False),
        run_btn: gr.Button(interactive=False),
        tpl_montage: gr.Image(visible=False),
        map_montage: gr.Image(value=None, visible=False),
        res_md: gr.Markdown(visible=False),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(visible=False),
    }
def mapping_result(app_state, channel_info, fill_mode):
    """Record batch bookkeeping and decide what the UI shows after stage-1 mapping.

    Args:
        app_state: Application state dict; updated in place with
            ``batchCount``, ``totalBatchNum`` and ``state``.
        channel_info: Channel description dict; ``inputByName`` and
            ``missingChannelsIndex`` are read.
        fill_mode: Selected imputation mode.

    Returns:
        A dict of component updates. With manual imputation and at least one
        missing channel the "Next" button is revealed and the state becomes
        ``"initializing"``; otherwise the state becomes ``"finished"`` and the
        result header and run button are enabled.
    """
    template_size = 30  # the models work on a 30-channel template

    in_num = len(channel_info["inputByName"])
    missing = channel_info["missingChannelsIndex"]
    matched_num = template_size - len(missing)
    # One batch for the matched channels, plus one per (partial) group of
    # `template_size` unmatched input channels.
    batch_num = math.ceil((in_num - matched_num) / template_size) + 1
    app_state.update({
        "batchCount": 1,
        "totalBatchNum": batch_num,
    })

    if fill_mode == "mean_manual" and missing:
        app_state.update({
            "state": "initializing",
            "totalFillingNum": len(missing),
        })
        return {
            app_state_json: app_state,
            next_btn: gr.Button(visible=True),
        }

    app_state.update({
        "state": "finished",
    })
    return {
        app_state_json: app_state,
        res_md: gr.Markdown(visible=True),
        run_btn: gr.Button(interactive=True),
    }
def show_montage(app_state, channel_info, raw_loc):
    """Render the montage figure that matches the current mapping state.

    Args:
        app_state: Application state dict; ``state`` and ``filepath`` are read
            and ``filenames`` is updated with the figure written.
        channel_info: Channel description dict; ``inputByName[...]["matched"]``
            selects which channel names are drawn in the final figure.
        raw_loc: Path of the uploaded channel-location file.

    Returns:
        A dict of component updates. Nothing but the state is returned while
        the user is still selecting channels (or for an unrecognised state).
    """
    state = app_state["state"]
    if state == "selecting":
        return {app_state_json: app_state}  # change nothing

    filepath = app_state["filepath"]
    raw_montage = read_custom_montage(raw_loc)

    # convert all channel names to uppercase
    raw_montage.rename_channels({name: name.upper() for name in raw_montage.ch_names})

    if state == "initializing":
        # Background image for the manual-selection checkbox group; the random
        # suffix gives every run a distinct file name.
        filename = filepath + "raw_montage_" + str(random.randint(1, 10000)) + ".png"
        app_state["filenames"]["raw_montage"] = filename

        raw_fig = raw_montage.plot()
        raw_fig.set_size_inches(5.6, 5.6)
        raw_fig.savefig(filename, pad_inches=0)

        return {app_state_json: app_state}

    if state == "finished":
        filename = filepath + "mapped_montage_" + str(random.randint(1, 10000)) + ".png"
        app_state["filenames"]["map_montage"] = filename

        # Only label the input channels that were matched to the template.
        show_names = [
            channel
            for channel, info in channel_info["inputByName"].items()
            if info["matched"]
        ]
        mapped_fig = raw_montage.plot(show_names=show_names)
        mapped_fig.set_size_inches(5.6, 5.6)
        mapped_fig.savefig(filename, pad_inches=0)

        return {
            app_state_json: app_state,
            tpl_montage: gr.Image(visible=True),
            map_montage: gr.Image(value=filename, visible=True),
        }

    # Unknown state: leave the UI untouched instead of returning None.
    return {app_state_json: app_state}
def generate_chkbox(app_state, channel_info):
    """Show the checkbox group for the first missing channel (manual imputation).

    Args:
        app_state: Application state dict; when its ``state`` is
            ``"initializing"`` it is switched to ``"selecting"`` and
            ``fillingCount`` is set to 1.
        channel_info: Channel description dict providing the input channel
            names and the missing template channels.

    Returns:
        A dict of component updates; only the unchanged state when the app is
        not in the ``"initializing"`` state.
    """
    if app_state["state"] != "initializing":
        return {app_state_json: app_state}  # change nothing

    in_channels = list(channel_info["inputByName"])
    app_state["state"] = "selecting"
    app_state["fillingCount"] = 1

    first_idx = channel_info["missingChannelsIndex"][0]
    first_name = channel_info["templateByIndex"][first_idx]
    chkbox_label = first_name + ' (1/' + str(app_state["totalFillingNum"]) + ')'
    return {
        app_state_json: app_state,
        chkbox_group: gr.CheckboxGroup(choices=in_channels, label=chkbox_label, visible=True),
        next_btn: gr.Button(interactive=True),
    }
# "Mapping" pipeline: reset -> stage-1 mapping -> summarise -> draw montage ->
# (manual mode) build the checkbox group -> position it client-side.
# Each step runs only if the previous one succeeded.
# Event listeners have to be registered inside a Blocks context, so the
# context of `demo` is re-entered here.
with demo:
    map_btn.click(
        fn=reset1,
        inputs=[in_raw_data, in_sample_rate],
        outputs=[app_state_json, channel_info_json, chkbox_group, next_btn, run_btn,
                 tpl_montage, map_montage, res_md, batch_md, out_denoised_data],
    ).success(
        fn=mapping_stage1,
        inputs=[app_state_json, channel_info_json, in_raw_data, in_raw_loc, in_fill_mode],
        outputs=[app_state_json, channel_info_json],
    ).success(
        fn=mapping_result,
        inputs=[app_state_json, channel_info_json, in_fill_mode],
        outputs=[app_state_json, next_btn, res_md, run_btn],
    ).success(
        fn=show_montage,
        inputs=[app_state_json, channel_info_json, in_raw_loc],
        outputs=[app_state_json, tpl_montage, map_montage],
    ).success(
        fn=generate_chkbox,
        inputs=[app_state_json, channel_info_json],
        outputs=[app_state_json, chkbox_group, next_btn],
    ).success(
        fn=None,
        js=chkbox_js,
        inputs=[app_state_json, channel_info_json],
        outputs=[],
    )
def check_next(app_state, channel_info, selected, raw_data, fill_mode):
    """Store the channels chosen for the current missing channel and advance.

    Args:
        app_state: Application state dict; ``stage1NewOrder``,
            ``fillingCount`` and (on the last step) ``state`` are updated.
        channel_info: Channel description dict used to translate channel
            names and indices.
        selected: Input channel names ticked in the checkbox group.
        raw_data: Unused; kept so the event wiring stays unchanged.
        fill_mode: Unused; kept so the event wiring stays unchanged.

    Returns:
        A dict of component updates: either the checkbox group relabelled for
        the next missing channel, or -- after the last one -- the selection
        widgets hidden and the run button enabled.
    """
    # save info before moving on to the next missing channel
    prev_target_idx = channel_info["missingChannelsIndex"][app_state["fillingCount"] - 1]
    prev_target_name = channel_info["templateByIndex"][prev_target_idx]

    selected_idx = [channel_info["inputByName"][channel]["index"] for channel in selected]
    app_state["stage1NewOrder"][prev_target_idx] = selected_idx

    print(f'Selection for missing channel "{prev_target_name}"({prev_target_idx}): {selected}')

    # update next round
    app_state["fillingCount"] += 1
    count = app_state["fillingCount"]
    total = app_state["totalFillingNum"]

    if count <= total:
        target_idx = channel_info["missingChannelsIndex"][count - 1]
        target_name = channel_info["templateByIndex"][target_idx]
        chkbox_label = f"{target_name} ({count}/{total})"
        # The last missing channel gets a "Submit" button instead of "Next".
        btn_label = "Submit" if count == total else "Next"
        return {
            app_state_json: app_state,
            chkbox_group: gr.CheckboxGroup(value=[], label=chkbox_label),
            next_btn: gr.Button(btn_label),
        }

    app_state["state"] = "finished"
    return {
        app_state_json: app_state,
        chkbox_group: gr.CheckboxGroup(visible=False),
        next_btn: gr.Button(visible=False),
        res_md: gr.Markdown(visible=True),
        run_btn: gr.Button(interactive=True),
    }
# "Next"/"Submit" pipeline: store the selection -> redraw the montage once
# finished -> move the red indicator client-side.
# Event listeners have to be registered inside a Blocks context, so the
# context of `demo` is re-entered here.
with demo:
    next_btn.click(
        fn=check_next,
        inputs=[app_state_json, channel_info_json, chkbox_group, in_raw_data, in_fill_mode],
        outputs=[app_state_json, chkbox_group, next_btn, run_btn, res_md],
    ).success(
        fn=show_montage,
        inputs=[app_state_json, channel_info_json, in_raw_loc],
        outputs=[app_state_json, tpl_montage, map_montage],
    ).success(
        fn=None,
        js=indication_js,
        inputs=[app_state_json, channel_info_json],
        outputs=[],
    )
def delete_file(filename):
    """Remove *filename*, best effort.

    A file that does not exist is not an error (callers use this to clear
    intermediate files that may never have been written). Any other OS-level
    failure is printed and otherwise ignored.

    Args:
        filename: Path of the file to remove.
    """
    try:
        os.remove(filename)
    except FileNotFoundError:
        # Nothing to clean up -- expected on a first run.
        pass
    except OSError as e:
        print(e)
def reset2(app_state, raw_data, model_name):
    """Prepare the state and UI for a new model run.

    Args:
        app_state: Application state dict; the output file name and the run
            bookkeeping keys are (re)initialised in place.
        raw_data: Path of the uploaded raw-data file; its base name is used to
            build the output file name.
        model_name: Selected model identifier, appended to the output name.

    Returns:
        A dict of component updates that disables the run button and hides the
        progress text and the previous result.
    """
    filepath = app_state["filepath"]

    input_name = os.path.basename(str(raw_data))
    output_name = os.path.splitext(input_name)[0] + '_' + model_name + '.csv'
    app_state["filenames"]["denoised"] = filepath + output_name

    app_state.update({
        # NOTE: the key spelling "runnigState" is read elsewhere; keep it as is.
        "runnigState": "stage1",
        "batchCount": 1,
        # 30 independent lists ([[]] * 30 would repeat one shared list object).
        "stage2NewOrder": [[] for _ in range(30)],
    })

    # Clear intermediate files left by a previous run.
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')

    return {
        app_state_json: app_state,
        run_btn: gr.Button(interactive=False),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(visible=False),
    }
def run_model(app_state, channel_info, raw_data, model_name, fill_mode):
    """Run the selected model batch by batch, streaming progress to the UI.

    This is a generator: it yields a progress message before every batch and
    finally the component updates that expose the result file.

    Args:
        app_state: Application state dict prepared by the mapping steps and
            ``reset2``.
        channel_info: Channel description dict.
        raw_data: Path of the uploaded raw-data file.
        model_name: Selected model identifier.
        fill_mode: Selected imputation mode (used for batches after the first).

    Yields:
        Dicts of component updates.
    """
    filepath = app_state["filepath"]
    samplerate = app_state["sampleRate"]
    new_filename = app_state["filenames"]["denoised"]

    while app_state["runnigState"] != "finished":
        md = 'Running model(' + str(app_state["batchCount"]) + '/' + str(app_state["totalBatchNum"]) + ')...'
        yield {batch_md: gr.Markdown(md, visible=True)}

        # Batches after the first need a new mapping; that step may also
        # switch the run state to "finished".
        if app_state["batchCount"] > 1:
            app_state, channel_info = mapping_stage2(app_state, channel_info, fill_mode)
            if app_state["runnigState"] == "finished":
                break
        app_state["batchCount"] += 1

        reorder_to_template(app_state, raw_data)
        # step1: Data preprocessing
        total_file_num = utils.preprocessing(filepath, 'mapped.csv', samplerate)
        # step2: Signal reconstruction
        utils.reconstruct(model_name, total_file_num, filepath, 'denoised.csv', samplerate)
        reorder_to_origin(app_state, channel_info, filepath + 'denoised.csv', new_filename)

    # Intermediate files are no longer needed once the result is written.
    delete_file(filepath + 'mapped.csv')
    delete_file(filepath + 'denoised.csv')

    yield {
        run_btn: gr.Button(interactive=True),
        batch_md: gr.Markdown(visible=False),
        out_denoised_data: gr.File(new_filename, visible=True),
    }
# "Run" pipeline: reset the run state, then stream the model run.
# Event listeners have to be registered inside a Blocks context, so the
# context of `demo` is re-entered here.
with demo:
    run_btn.click(
        fn=reset2,
        inputs=[app_state_json, in_raw_data, in_model_name],
        outputs=[app_state_json, run_btn, batch_md, out_denoised_data],
    ).success(
        fn=run_model,
        inputs=[app_state_json, channel_info_json, in_raw_data, in_model_name, in_fill_mode],
        outputs=[run_btn, batch_md, out_denoised_data],
    )
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()