# Spaces: Sleeping
import gradio as gr | |
import numpy as np | |
import os | |
import utils | |
import channel_mapping | |
def test(model_name, fill_mode, csv_file, chanloc_file):
    """Map channels, preprocess, and reconstruct an uploaded CSV file.

    Args:
        model_name: Name of the reconstruction model; it is forwarded to
            ``utils.reconstruct`` and appended to the output file name.
        fill_mode: Forwarded to ``channel_mapping.mapping`` (the UI labels
            it "Fill empty channels with:").
        csv_file: Uploaded CSV file; ``str(csv_file)`` is used as its path.
        chanloc_file: Uploaded channel-location file, forwarded to
            ``channel_mapping.mapping``.

    Returns:
        Path of the output file, ``<input stem>_<model_name>.csv``, located
        in the same directory as the uploaded CSV.
    """
    csv_path = str(csv_file)
    tmp_path = os.path.dirname(csv_path)
    input_stem = os.path.splitext(os.path.basename(csv_path))[0]
    output_name = f"{input_stem}_{model_name}.csv"

    # Channel mapping.
    # NOTE(review): presumably this writes 'mapped.csv' next to the input,
    # since the next step reads that name from tmp_path -- confirm.
    channel_mapping.mapping(csv_file, chanloc_file, fill_mode)

    # Step 1: data preprocessing.
    # NOTE(review): the meaning of 256 is not visible here (possibly a
    # sampling rate) -- confirm against utils.preprocessing.
    total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)

    # Step 2: signal reconstruction.
    utils.reconstruct(model_name, total_file_num, tmp_path, output_name)

    # The original returned the undefined name `output_data_name`, which
    # raised NameError on every call; the file name built above is the
    # one handed to utils.reconstruct.
    return os.path.join(tmp_path, output_name)
# ---------------------------------------------------------------------------
# Gradio UI: two file uploads and two dropdowns feed `test`, which returns the
# path of the file offered for download.
# ---------------------------------------------------------------------------
input_data = gr.File(label="CSV file", file_types=[".csv"])

# Extension filters need a leading dot; a bare "locs" is interpreted as a
# file-type category rather than the ".locs" extension.
input_loc = gr.File(label="Channel location", file_types=[".loc", ".locs"])

input_model_name = gr.Dropdown(
    choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
    value="ICUNet",
    label="Model",
)
input_fill_mode = gr.Dropdown(
    choices=["zero", "adjacent"],
    value="zero",
    label="Fill empty channels with:",
)

# Order must match the positional parameters of `test`:
# (model_name, fill_mode, csv_file, chanloc_file).
inputs = [input_model_name, input_fill_mode, input_data, input_loc]

app = gr.Interface(
    fn=test,
    inputs=inputs,
    outputs=gr.File(),
    allow_flagging="never",
)

if __name__ == "__main__":
    app.launch()