manu committed · Commit f0a3223 · verified · 1 Parent(s): 4988933

Update handler.py

Files changed (1):
  1. handler.py +35 -12
handler.py CHANGED
@@ -1,3 +1,6 @@
+import base64
+import io
+from PIL import Image
 from typing import Dict, List, Any
 from transformers.utils.import_utils import is_flash_attn_2_available
 from colpali_engine.models import ColQwen2, ColQwen2Processor
@@ -9,23 +12,43 @@ class EndpointHandler():
             path,
             torch_dtype=torch.bfloat16,
             device_map="cuda:0",  # or "mps" if on Apple Silicon
-            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,  # should work on A100
-        ).eval()
+            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+        ).eval()
         self.processor = ColQwen2Processor.from_pretrained(path)
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-        data args:
-            inputs (:obj: `str`)
-        Return:
-            A :obj:`list` | `dict`: will be serialized and returned
+        Expects data in the following format:
+        {
+            "images": [
+                "base64_encoded_image1",
+                "base64_encoded_image2",
+                ...
+            ]
+        }
+
+        Decodes each Base64 image into a PIL Image, processes them, and returns the embeddings.
         """
-        # process input
-        images = data.pop("inputs", data)
-        batch_images = self.processor.process_images([images]).to(self.model.device)
-        # Forward pass
+        # Retrieve the list of base64 encoded images
+        base64_images = data.get("images", [])
+        if not isinstance(base64_images, list):
+            base64_images = [base64_images]
+
+        # Decode each image from base64 and convert to a PIL Image
+        decoded_images = []
+        for img_str in base64_images:
+            try:
+                img_data = base64.b64decode(img_str)
+                image = Image.open(io.BytesIO(img_data)).convert("RGB")
+                decoded_images.append(image)
+            except Exception as e:
+                print(f"Error decoding an image: {e}")
+
+        # Process the images using the processor
+        batch_images = self.processor.process_images(decoded_images).to(self.model.device)
+
+        # Forward pass through the model
         with torch.no_grad():
             image_embeddings = self.model(**batch_images).tolist()
 
         return {"embeddings": image_embeddings}
-
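After this commit, the handler expects base64-encoded images under an "images" key rather than a raw "inputs" field, so clients must encode images before posting them. A minimal client-side sketch follows; the endpoint URL, token, and file name are placeholders, not values from this repository:

import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "<your-token>"  # placeholder

# Encode a local image to base64, matching the {"images": [...]} payload
# that __call__ reads via data.get("images", [])
with open("page.png", "rb") as f:  # placeholder file name
    b64_image = base64.b64encode(f.read()).decode("utf-8")

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json={"images": [b64_image]},
)
response.raise_for_status()

# The handler returns {"embeddings": [...]}, one entry per decoded image
embeddings = response.json()["embeddings"]
print(f"{len(embeddings)} image embedding(s) returned")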
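Because decode failures are only printed and skipped, the response can contain fewer embeddings than images sent; a quick local smoke test before deploying catches malformed payloads. A sketch, assuming the standard Inference Endpoints convention that EndpointHandler is constructed with a model path (the model id below is a placeholder):

import base64
from handler import EndpointHandler

handler = EndpointHandler(path="vidore/<model-id>")  # placeholder model id

with open("page.png", "rb") as f:  # placeholder image
    payload = {"images": [base64.b64encode(f.read()).decode("utf-8")]}

out = handler(payload)
# Each entry is a multi-vector embedding: one vector per image patch token
print(len(out["embeddings"]), "embedding(s);",
      len(out["embeddings"][0]), "vectors in the first")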
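Downstream, these multi-vector embeddings are scored against multi-vector query embeddings with late interaction (MaxSim), the ColPali-family retrieval recipe, rather than a single dot product. A self-contained sketch under stated assumptions: the 128-dim projection and token counts are illustrative only, and the vectors are assumed L2-normalized as ColPali-style models emit them; for real use, prefer the batched scoring utilities shipped with colpali_engine:

import torch

def maxsim_score(query_emb: torch.Tensor, image_emb: torch.Tensor) -> torch.Tensor:
    # query_emb: (n_query_tokens, dim); image_emb: (n_image_tokens, dim)
    sim = query_emb @ image_emb.T       # token-to-token similarities
    return sim.max(dim=1).values.sum()  # best image token per query token, summed

# Stand-in tensors; with real data: image_emb = torch.tensor(embeddings[0])
image_emb = torch.nn.functional.normalize(torch.randn(700, 128), dim=-1)
query_emb = torch.nn.functional.normalize(torch.randn(20, 128), dim=-1)
print(maxsim_score(query_emb, image_emb).item())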