File size: 1,473 Bytes
b9eaea7
 
4728dd2
 
94ae045
b9eaea7
441b340
94ae045
 
441b340
 
94ae045
 
441b340
 
4728dd2
3ab4789
441b340
b9eaea7
3ab4789
441b340
3ab4789
94ae045
 
 
 
 
 
441b340
 
 
 
94ae045
441b340
b9eaea7
 
94ae045
4728dd2
 
b9eaea7
 
77546d5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gradio as gr
import numpy as np
import os
import utils
import channel_mapping

def test(model_name, fill_mode, csv_file, chanloc_file):
    """Map channels, preprocess and reconstruct an uploaded CSV recording.

    Args:
        model_name: Reconstruction model name chosen in the UI (one of the
            dropdown choices, e.g. "ICUNet"); also used as an output suffix.
        fill_mode: Strategy forwarded to ``channel_mapping.mapping`` for
            filling empty channels ("zero" or "adjacent" in the UI).
        csv_file: Uploaded CSV file; its directory is used as the working
            directory for intermediate and output files.
        chanloc_file: Uploaded channel-location file.

    Returns:
        Path of the reconstructed CSV, ``<input stem>_<model_name>.csv``,
        located in the same directory as ``csv_file``.
    """
    csv_path = str(csv_file)
    tmp_path = os.path.dirname(csv_path)

    input_name = os.path.basename(csv_path)
    output_name = os.path.splitext(input_name)[0] + '_' + model_name + '.csv'

    # Channel mapping. Presumably this writes 'mapped.csv' into tmp_path,
    # which step 1 reads back by that name (TODO confirm in channel_mapping).
    channel_mapping.mapping(csv_file, chanloc_file, fill_mode)

    # Step 1: data preprocessing (256 is passed through as-is; its meaning
    # is defined by utils.preprocessing).
    total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)

    # Step 2: signal reconstruction; presumably writes output_name under
    # tmp_path, which is the path handed back to the UI below.
    utils.reconstruct(model_name, total_file_num, tmp_path, output_name)

    # os.path.join avoids a bogus leading '/' when tmp_path is empty.
    return os.path.join(tmp_path, output_name)

# Gradio input components for the reconstruction UI.
input_data = gr.File(label="CSV file", file_types=[".csv"])
# File extensions must carry a leading dot; a bare word such as "locs" is
# interpreted as a file category, not an extension.
input_loc = gr.File(label="Channel location", file_types=[".loc", ".locs"])
input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
                               value="ICUNet",
                               label="Model")
input_fill_mode = gr.Dropdown(choices=["zero", "adjacent"],
                              value="zero",
                              label="Fill empty channels with:")

# Order must match test(model_name, fill_mode, csv_file, chanloc_file),
# since Interface passes component values positionally.
inputs = [input_model_name, input_fill_mode, input_data, input_loc]

app = gr.Interface(fn=test,
                   inputs=inputs,
                   outputs=gr.File(),
                   allow_flagging="never")

if __name__ == "__main__":
    app.launch()