# AIEEG / app.py
# Uploaded by: audrey06100
# Commit message: update app.py
# Commit: 441b340
# (web viewer links: raw, history, blame)
# File size: 1.47 kB
import gradio as gr
import numpy as np
import os
import utils
import channel_mapping
def test(model_name, fill_mode, csv_file, chanloc_file):
    """Run channel mapping, preprocessing and reconstruction on an upload.

    The output file is named ``<input stem>_<model_name>.csv`` and is
    expected in the same directory as the uploaded CSV file.

    Args:
        model_name: Name of the reconstruction model; forwarded to
            ``utils.reconstruct`` and used as the output-file suffix.
        fill_mode: Forwarded unchanged to ``channel_mapping.mapping``.
        csv_file: The uploaded CSV file as delivered by the Gradio input;
            its string form is treated as the file path.
        chanloc_file: The uploaded channel-location file; forwarded
            unchanged to ``channel_mapping.mapping``.

    Returns:
        Path of the output CSV file, located next to the uploaded CSV.
    """
    csv_path = str(csv_file)
    tmp_path = os.path.dirname(csv_path)
    input_name = os.path.basename(csv_path)
    output_name = os.path.splitext(input_name)[0] + '_' + model_name + '.csv'

    # Channel mapping
    channel_mapping.mapping(csv_file, chanloc_file, fill_mode)

    # Step 1: data preprocessing.
    # NOTE(review): 'mapped.csv' is presumably the file produced by the
    # mapping step above, and 256 presumably a sampling rate -- confirm
    # against utils.preprocessing / channel_mapping.mapping.
    total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)

    # Step 2: signal reconstruction
    utils.reconstruct(model_name, total_file_num, tmp_path, output_name)

    # The original returned the undefined name `output_data_name`, which
    # raised NameError on every call; the intended local is `output_name`.
    return os.path.join(tmp_path, output_name)
# --- Gradio UI definition ---------------------------------------------------

# Uploaded data file (CSV only).
input_data = gr.File(label="CSV file", file_types=[".csv"])

# Uploaded channel-location file. Extension filters need a leading dot; the
# original listed "locs" without one, so it did not name an extension.
input_loc = gr.File(label="Channel location", file_types=[".loc", ".locs"])

# Model used for reconstruction.
input_model_name = gr.Dropdown(
    choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
    value="ICUNet",
    label="Model",
)

# Fill mode forwarded to the channel-mapping step.
input_fill_mode = gr.Dropdown(
    choices=["zero", "adjacent"],
    value="zero",
    label="Fill empty channels with:",
)

# Order must match the positional parameters of `test`.
inputs = [input_model_name, input_fill_mode, input_data, input_loc]

app = gr.Interface(
    fn=test,
    inputs=inputs,
    outputs=gr.File(),
    allow_flagging="never",
)

if __name__ == "__main__":
    app.launch()