audrey06100 commited on
Commit
f3fbfd6
·
1 Parent(s): 7129427

update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -4,15 +4,16 @@ import os
4
  import utils
5
  import channel_mapping
6
 
7
- def test(model_name, fill_mode, csv_file, chanloc_file):
8
  tmp_path = os.path.dirname(str(csv_file))
9
 
10
  input_name = os.path.basename(str(csv_file))
11
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
12
 
13
  # Channel mapping
14
- #input_loc_name = os.path.basename(str(chanloc_file))
15
- channel_mapping.mapping(csv_file, chanloc_file, fill_mode)
 
16
 
17
  # step1: Data preprocessing
18
  total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)
@@ -20,18 +21,18 @@ def test(model_name, fill_mode, csv_file, chanloc_file):
20
  # step2: Signal reconstruction
21
  utils.reconstruct(model_name, total_file_num, tmp_path, output_name)
22
 
23
- return tmp_path+'/'+output_data_name
24
 
25
- input_data = gr.File(label="CSV file", file_types=[".csv"])
26
- input_loc = gr.File(label="Channel location", file_types=[".loc", "locs"])
27
  input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
28
  value="ICUNet",
29
  label="Model")
30
- input_fill_mode = gr.Dropdown(choices=["zero", "adjacent"],
31
  value="zero",
32
- label="Fill empty channels with:")
33
 
34
- inputs = [input_model_name, input_fill_mode, input_data, input_loc]
35
 
36
  app = gr.Interface(fn=test,
37
  inputs=inputs,
@@ -39,4 +40,4 @@ app = gr.Interface(fn=test,
39
  allow_flagging="never")
40
 
41
  if __name__ == "__main__":
42
- app.launch()
 
4
  import utils
5
  import channel_mapping
6
 
7
+ def test(model_name, fill_mode, csv_file, chanlocs_file):
8
  tmp_path = os.path.dirname(str(csv_file))
9
 
10
  input_name = os.path.basename(str(csv_file))
11
  output_name = os.path.splitext(input_name)[0]+'_'+model_name+'.csv'
12
 
13
  # Channel mapping
14
+ #input_loc_name = os.path.basename(str(chanlocs_file))
15
+ fill_mode = "adjacent" if fill_mode == "adjacent channel's value" else fill_mode
16
+ channel_mapping.mapping(csv_file, chanlocs_file, fill_mode)
17
 
18
  # step1: Data preprocessing
19
  total_file_num = utils.preprocessing(tmp_path, 'mapped.csv', 256)
 
21
  # step2: Signal reconstruction
22
  utils.reconstruct(model_name, total_file_num, tmp_path, output_name)
23
 
24
+ return tmp_path+'/'+output_name
25
 
26
+ input_data = gr.File(label="Raw data (.csv)", file_types=[".csv"])
27
+ input_loc = gr.File(label="Channel location (.loc, .locs)", file_types=[".loc", "locs"])
28
  input_model_name = gr.Dropdown(choices=["ICUNet", "UNetpp", "AttUnet", "EEGART"],
29
  value="ICUNet",
30
  label="Model")
31
+ input_fill_value = gr.Dropdown(choices=["zero", "adjacent channel's value"],
32
  value="zero",
33
+ label="Imputation")
34
 
35
+ inputs = [input_model_name, input_fill_value, input_data, input_loc]
36
 
37
  app = gr.Interface(fn=test,
38
  inputs=inputs,
 
40
  allow_flagging="never")
41
 
42
  if __name__ == "__main__":
43
+ app.launch(server_name="0.0.0.0", server_port=7860, share=False)