audrey06100 commited on
Commit
94ae045
·
1 Parent(s): 6c43c31

update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -8
app.py CHANGED
@@ -2,23 +2,42 @@ import gradio as gr
2
  import numpy as np
3
  import os
4
  import utils
5
- from pathlib import Path
6
 
7
- def test(csv_file):
8
- tmp_filepath = str(Path(str(csv_file)).parent)
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # step1: Data preprocessing
11
- total_file_num = utils.preprocessing(csv_file, 256)
12
 
13
  # step2: Signal reconstruction
14
- utils.reconstruct("EEGART", total_file_num, tmp_filepath, "EEGART_test.csv")
15
 
16
- return tmp_filepath+"/EEGART_test.csv"
 
 
 
 
 
 
 
 
17
 
18
  app = gr.Interface(fn=test,
19
- inputs=gr.File(label="CSV file", file_types=[".csv"]),
20
  outputs=gr.File(),
21
  allow_flagging="never")
22
 
23
  if __name__ == "__main__":
24
- app.launch()
 
2
  import numpy as np
3
  import os
4
  import utils
5
+ import channel_mapping
6
 
7
+ def test(model_name, csv_file, chanloc_file):
8
+ #print('model:[', model_name, ']')
9
+ tmp_path = os.path.dirname(str(csv_file))
10
+
11
+ input_data_name = os.path.basename(str(csv_file))
12
+ output_data_name = os.path.splitext(input_data_name)[0]+'_'+model_name+'.csv'
13
+ #print('in_data_name:[', input_data_name, ']')
14
+ #print('out_data_name:[', output_data_name, ']')
15
+
16
+ # Channel mapping
17
+ input_loc_name = os.path.basename(str(chanloc_file))
18
+ #print('in_loc_name:[', input_loc_name, ']')
19
+
20
 
21
  # step1: Data preprocessing
22
+ total_file_num = utils.preprocessing(tmp_path, input_data_name, 256)
23
 
24
  # step2: Signal reconstruction
25
+ utils.reconstruct(model_name, total_file_num, tmp_path, output_data_name)
26
 
27
+ return tmp_path+'/'+output_data_name
28
+
29
+ input_data = gr.File(label="CSV file", file_types=[".csv"])
30
+ input_loc = gr.File(label="Channel location", file_types=[".loc", "locs"])
31
+ input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
32
+ value="ICUNet",
33
+ label="model")
34
+
35
+ inputs = [input_model_name, input_data, input_loc]
36
 
37
  app = gr.Interface(fn=test,
38
+ inputs=inputs,
39
  outputs=gr.File(),
40
  allow_flagging="never")
41
 
42
  if __name__ == "__main__":
43
+ app.launch(server_name="0.0.0.0", server_port=7859, share=False)