Spaces:
Sleeping
Sleeping
Commit
·
94ae045
1
Parent(s):
6c43c31
update app.py
Browse files
app.py
CHANGED
@@ -2,23 +2,42 @@ import gradio as gr
|
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import utils
|
5 |
-
|
6 |
|
7 |
-
def test(csv_file):
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# step1: Data preprocessing
|
11 |
-
total_file_num = utils.preprocessing(
|
12 |
|
13 |
# step2: Signal reconstruction
|
14 |
-
utils.reconstruct(
|
15 |
|
16 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
app = gr.Interface(fn=test,
|
19 |
-
inputs=
|
20 |
outputs=gr.File(),
|
21 |
allow_flagging="never")
|
22 |
|
23 |
if __name__ == "__main__":
|
24 |
-
app.launch()
|
|
|
2 |
import numpy as np
|
3 |
import os
|
4 |
import utils
|
5 |
+
import channel_mapping
|
6 |
|
7 |
+
def test(model_name, csv_file, chanloc_file):
|
8 |
+
#print('model:[', model_name, ']')
|
9 |
+
tmp_path = os.path.dirname(str(csv_file))
|
10 |
+
|
11 |
+
input_data_name = os.path.basename(str(csv_file))
|
12 |
+
output_data_name = os.path.splitext(input_data_name)[0]+'_'+model_name+'.csv'
|
13 |
+
#print('in_data_name:[', input_data_name, ']')
|
14 |
+
#print('out_data_name:[', output_data_name, ']')
|
15 |
+
|
16 |
+
# Channel mapping
|
17 |
+
input_loc_name = os.path.basename(str(chanloc_file))
|
18 |
+
#print('in_loc_name:[', input_loc_name, ']')
|
19 |
+
|
20 |
|
21 |
# step1: Data preprocessing
|
22 |
+
total_file_num = utils.preprocessing(tmp_path, input_data_name, 256)
|
23 |
|
24 |
# step2: Signal reconstruction
|
25 |
+
utils.reconstruct(model_name, total_file_num, tmp_path, output_data_name)
|
26 |
|
27 |
+
return tmp_path+'/'+output_data_name
|
28 |
+
|
29 |
+
input_data = gr.File(label="CSV file", file_types=[".csv"])
|
30 |
+
input_loc = gr.File(label="Channel location", file_types=[".loc", "locs"])
|
31 |
+
input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
|
32 |
+
value="ICUNet",
|
33 |
+
label="model")
|
34 |
+
|
35 |
+
inputs = [input_model_name, input_data, input_loc]
|
36 |
|
37 |
app = gr.Interface(fn=test,
|
38 |
+
inputs=inputs,
|
39 |
outputs=gr.File(),
|
40 |
allow_flagging="never")
|
41 |
|
42 |
if __name__ == "__main__":
|
43 |
+
app.launch(server_name="0.0.0.0", server_port=7859, share=False)
|