audrey06100 commited on
Commit
2691303
·
1 Parent(s): 145f4ad
Files changed (3) hide show
  1. app.py +5 -4
  2. app_utils.py +1 -0
  3. utils.py +2 -2
app.py CHANGED
@@ -832,7 +832,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
832
  batch_md : gr.Markdown("", visible=True),
833
  out_data_file : gr.File(value=None, visible=False)}
834
 
835
- def run_model(stage1_info, stage2_info, samplerate, modelname):
836
  if stage2_info["errorFlag"] == True:
837
  stage2_info["errorFlag"] = False
838
  yield {stage2_json : stage2_info}
@@ -841,14 +841,15 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
841
  filepath = stage2_info["filePath"]
842
  inputname = stage2_info["fileNames"]["inputData"]
843
  outputname = stage2_info["fileNames"]["outputData"]
 
844
  mapping_result = stage1_info["mappingResult"]
845
- break_flag = False
846
 
 
847
  for i in range(stage1_info["batch"]):
848
  yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["batch"]))}
849
  try:
850
  # step1: Data preprocessing
851
- preprocess_data, channel_num = utils.preprocessing(filepath, inputname, int(samplerate), mapping_result[i])
852
  # step2: Signal reconstruction
853
  reconstructed_data = utils.reconstruct(modelname, preprocess_data, filepath, i)
854
  # step3: Data postprocessing
@@ -874,7 +875,7 @@ with gr.Blocks(js=js, delete_cache=(3600, 3600)) as demo:
874
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
875
  ).success(
876
  fn = run_model,
877
- inputs = [stage1_json, stage2_json, in_samplerate, in_modelname],
878
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
879
  )
880
 
 
832
  batch_md : gr.Markdown("", visible=True),
833
  out_data_file : gr.File(value=None, visible=False)}
834
 
835
+ def run_model(stage1_info, stage2_info, channel_info, samplerate, modelname):
836
  if stage2_info["errorFlag"] == True:
837
  stage2_info["errorFlag"] = False
838
  yield {stage2_json : stage2_info}
 
841
  filepath = stage2_info["filePath"]
842
  inputname = stage2_info["fileNames"]["inputData"]
843
  outputname = stage2_info["fileNames"]["outputData"]
844
+ channel_num = len(channel_info["inputNames"])
845
  mapping_result = stage1_info["mappingResult"]
 
846
 
847
+ break_flag = False
848
  for i in range(stage1_info["batch"]):
849
  yield {batch_md : gr.Markdown('Running model({}/{})...'.format(i+1, stage1_info["batch"]))}
850
  try:
851
  # step1: Data preprocessing
852
+ preprocess_data = utils.preprocessing(filepath, inputname, int(samplerate), mapping_result[i])
853
  # step2: Signal reconstruction
854
  reconstructed_data = utils.reconstruct(modelname, preprocess_data, filepath, i)
855
  # step3: Data postprocessing
 
875
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
876
  ).success(
877
  fn = run_model,
878
+ inputs = [stage1_json, stage2_json, channel_json, in_samplerate, in_modelname],
879
  outputs = [stage2_json, run_btn, cancel_btn, batch_md, out_data_file]
880
  )
881
 
app_utils.py CHANGED
@@ -313,6 +313,7 @@ def mapping_result(stage1_info, channel_info, filename):
313
  data = {
314
  #"templateNames" : channel_info["templateNames"],
315
  #"inputNames" : channel_info["inputNames"],
 
316
  "batch" : batch,
317
  "mappingResult" : results
318
  }
 
313
  data = {
314
  #"templateNames" : channel_info["templateNames"],
315
  #"inputNames" : channel_info["inputNames"],
316
+ "channelNum" : len(channel_info["inputNames"]),
317
  "batch" : batch,
318
  "mappingResult" : results
319
  }
utils.py CHANGED
@@ -202,7 +202,7 @@ def preprocessing(filepath, inputfile, samplerate, mapping_result):
202
 
203
  # read data
204
  signal = read_train_data(inputfile)
205
- channel_num = signal.shape[0]
206
  # channel mapping
207
  signal = reorder_data(signal, mapping_result)
208
  #print(signal.shape)
@@ -215,7 +215,7 @@ def preprocessing(filepath, inputfile, samplerate, mapping_result):
215
  # cutting data
216
  total_file_num = cut_data(filepath, signal)
217
 
218
- return total_file_num, channel_num
219
 
220
  def restore_order(data, all_data, mapping_result):
221
  for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):
 
202
 
203
  # read data
204
  signal = read_train_data(inputfile)
205
+ #print(signal.shape)
206
  # channel mapping
207
  signal = reorder_data(signal, mapping_result)
208
  #print(signal.shape)
 
215
  # cutting data
216
  total_file_num = cut_data(filepath, signal)
217
 
218
+ return total_file_num
219
 
220
  def restore_order(data, all_data, mapping_result):
221
  for i, (indices, flag) in enumerate(zip(mapping_result["index"], mapping_result["isOriginalData"])):