Tags: Visual Document Retrieval · ColPali · Safetensors · English · vidore-experimental · vidore
manuel committed · Commit c458479 · 1 Parent(s): f1b2913

multiroute

Files changed (1): handler.py +61 -21
handler.py CHANGED
@@ -20,40 +20,80 @@ class EndpointHandler():
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        Expects data in the following format:
+        Expects data in one of the following formats:
         {
-            "inputs": [
+            "images": [
                 "base64_encoded_image1",
                 "base64_encoded_image2",
                 ...
             ]
         }
+        or
+        {
+            "processed_images": [
+                [...],  # preprocessed image tensors
+                [...]
+            ]
+        }
+        or
+        {
+            "text": [
+                "text1",
+                "text2",
+                ...
+            ]
+        }
 
-        Decodes each Base64 image into a PIL Image, processes them, and returns the embeddings.
+        Returns embeddings for the provided input type.
         """
-        # Retrieve the list of base64 encoded images
-        base64_images = data.get("inputs", [])
-        if not isinstance(base64_images, list):
-            base64_images = [base64_images]
-        else:
-            if len(base64_images) > 8:
+        # Input validation
+        data = data.get("inputs", [])
+        input_keys = [key for key in ["images", "processed_images", "text"] if key in data]
+        if len(input_keys) != 1:
+            return {"error": "Exactly one of 'images', 'processed_images', or 'text' must be provided"}
+
+        input_type = input_keys[0]
+        inputs = data[input_type]
+
+
+        if input_type == "images":
+            if not isinstance(inputs, list):
+                inputs = [inputs]
+
+            if len(inputs) > 8:
                 return {"message": "Send a maximum of 8 images at once. We recommend sending one by one to improve load balancing."}
 
-        # Decode each image from base64 and convert to a PIL Image
-        decoded_images = []
-        for img_str in base64_images:
+            # Decode each image from base64 and convert to a PIL Image
+            decoded_images = []
+            for img_str in inputs:
+                try:
+                    img_data = base64.b64decode(img_str)
+                    image = Image.open(io.BytesIO(img_data)).convert("RGB")
+                    decoded_images.append(image)
+                except Exception as e:
+                    return {"error": f"Error decoding image: {str(e)}"}
+
+            # Process the images using the processor
+            batch = self.processor.process_images(decoded_images).to(self.model.device)
+
+        elif input_type == "processed_images":
             try:
-                img_data = base64.b64decode(img_str)
-                image = Image.open(io.BytesIO(img_data)).convert("RGB")
-                decoded_images.append(image)
+                print(inputs)
+                batch = torch.load(io.BytesIO(inputs), map_location=self.model.device)
+                print(batch)
             except Exception as e:
-                print(f"Error decoding an image: {e}")
+                return {"error": f"Error processing preprocessed images: {str(e)}"}
 
-        # Process the images using the processor
-        batch_images = self.processor.process_images(decoded_images).to(self.model.device)
+        else:  # text
+            if not isinstance(inputs, list):
+                inputs = [inputs]
+            try:
+                batch = self.processor.process_text(inputs).to(self.model.device)
+            except Exception as e:
+                return {"error": f"Error processing text: {str(e)}"}
 
         # Forward pass through the model
-        with torch.no_grad():
-            image_embeddings = self.model(**batch_images).tolist()
+        with torch.inference_mode():
+            embeddings = self.model(**batch).tolist()
 
-        return {"embeddings": image_embeddings}
+        return {"embeddings": embeddings}