LDanielBlueway commited on
Commit
f021aff
·
verified ·
1 Parent(s): 54fbbbf

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -5
handler.py CHANGED
@@ -1,14 +1,15 @@
1
  from typing import Dict, List, Any
2
  from PIL import Image
3
  from io import BytesIO
4
- from transformers import pipeline
5
  import base64
6
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
- self.pipeline=pipeline("zero-shot-object-detection",model=path)
11
-
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
  data args:
@@ -23,5 +24,16 @@ class EndpointHandler():
23
  image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
24
 
25
  # run prediction one image wit provided candiates
26
- prediction = self.pipeline(image=[image], candidate_labels=inputs["candidates"])
27
- return prediction[0]
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
  from PIL import Image
3
  from io import BytesIO
4
+ from transformers import AutoProcessor, OmDetTurboForObjectDetection
5
  import base64
6
 
7
 
8
  class EndpointHandler():
9
  def __init__(self, path=""):
10
+ self.processor = AutoProcessor.from_pretrained("Blueway/inference-endpoint-zero-shot-image-classification")
11
+ self.model = OmDetTurboForObjectDetection.from_pretrained("Blueway/inference-endpoint-zero-shot-image-classification")
12
+
13
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
14
  """
15
  data args:
 
24
  image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
25
 
26
  # run prediction one image wit provided candiates
27
+ inputs = processor(image, text=inputs["candidates"], return_tensors="pt")
28
+ outputs = model(**inputs)
29
+ results = processor.post_process_grounded_object_detection(
30
+ outputs,
31
+ classes=classes,
32
+ target_sizes=[image.size[::-1]],
33
+ score_threshold=0.3,
34
+ nms_threshold=0.3,
35
+ )[0]
36
+ return results
37
+ #prediction = self.pipeline(image=[image], candidate_labels=inputs["candidates"])
38
+ #return prediction[0]
39
+