alessandro trinca tornidor commited on
Commit
910b241
·
1 Parent(s): 8c11d44

[bug] inference(): return 422 jsonresponse in case of invalid input

Browse files
Files changed (1) hide show
  1. lisa_on_cuda/utils/app_helpers.py +11 -3
lisa_on_cuda/utils/app_helpers.py CHANGED
@@ -203,7 +203,7 @@ def get_inference_model_by_args(args_to_parse):
203
  logging.info(f"args_to_parse:{args_to_parse}, creating model...")
204
  model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
205
  logging.info("created model, preparing inference function")
206
- no_seg_out, error_happened = placeholders["no_seg_out"], placeholders["error_happened"]
207
 
208
  @session_logger.set_uuid_logging
209
  def inference(input_str: str, input_image: str | np.ndarray):
@@ -214,8 +214,16 @@ def get_inference_model_by_args(args_to_parse):
214
 
215
  ## input valid check
216
  if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
217
- output_str = "[Error] Invalid input: ", input_str
218
- return error_happened, output_str
 
 
 
 
 
 
 
 
219
 
220
  # Model Inference
221
  conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()
 
203
  logging.info(f"args_to_parse:{args_to_parse}, creating model...")
204
  model, clip_image_processor, tokenizer, transform = get_model(args_to_parse)
205
  logging.info("created model, preparing inference function")
206
+ no_seg_out = placeholders["no_seg_out"]
207
 
208
  @session_logger.set_uuid_logging
209
  def inference(input_str: str, input_image: str | np.ndarray):
 
214
 
215
  ## input valid check
216
  if not re.match(r"^[A-Za-z ,.!?\'\"]+$", input_str) or len(input_str) < 1:
217
+ output_str = f"[Error] Unprocessable Entity input: {input_str}."
218
+ logging.error(output_str)
219
+
220
+ from fastapi import status
221
+ from fastapi.responses import JSONResponse
222
+
223
+ return JSONResponse(
224
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
225
+ content={"msg": "Error - Unprocessable Entity"}
226
+ )
227
 
228
  # Model Inference
229
  conv = conversation_lib.conv_templates[args_to_parse.conv_type].copy()