Uddipan Basu Bir commited on
Commit
a63ba49
·
1 Parent(s): 4617aca

Download checkpoint from HF hub in OcrReorderPipeline

Browse files
Files changed (1) hide show
  1. inference.py +6 -2
inference.py CHANGED
@@ -29,7 +29,11 @@ class OcrReorderPipeline(Pipeline):
29
  self.projection.to(self.device)
30
 
31
  def _sanitize_parameters(self, **kwargs):
32
- return {}, {}, {}
 
 
 
 
33
 
34
  def preprocess(self, image, words, boxes):
35
  data = base64.b64decode(image)
@@ -42,7 +46,7 @@ class OcrReorderPipeline(Pipeline):
42
  def _forward(self, model_inputs):
43
  pv, ids, mask, bbox = (
44
  model_inputs[k].to(self.device)
45
- for k in ("pixel_values","input_ids","attention_mask","bbox")
46
  )
47
  vision_out = self.model.vision_model(
48
  pixel_values=pv,
 
29
  self.projection.to(self.device)
30
 
31
  def _sanitize_parameters(self, **kwargs):
32
+ # extract the pipeline 'inputs' (base64 image) and custom args
33
+ image = kwargs.pop("inputs", None)
34
+ words = kwargs.pop("words", None)
35
+ boxes = kwargs.pop("boxes", None)
36
+ return {"image": image, "words": words, "boxes": boxes}, {}, {}
37
 
38
  def preprocess(self, image, words, boxes):
39
  data = base64.b64decode(image)
 
46
  def _forward(self, model_inputs):
47
  pv, ids, mask, bbox = (
48
  model_inputs[k].to(self.device)
49
+ for k in ("pixel_values", "input_ids", "attention_mask", "bbox")
50
  )
51
  vision_out = self.model.vision_model(
52
  pixel_values=pv,