Update handler.py
handler.py  (+35 -12)
@@ -1,3 +1,6 @@
+import base64
+import io
+from PIL import Image
 from typing import Dict, List, Any
 from transformers.utils.import_utils import is_flash_attn_2_available
 from colpali_engine.models import ColQwen2, ColQwen2Processor
@@ -9,23 +12,43 @@ class EndpointHandler():
             path,
             torch_dtype=torch.bfloat16,
             device_map="cuda:0",  # or "mps" if on Apple Silicon
-            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
-        )
+            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
+        ).eval()
         self.processor = ColQwen2Processor.from_pretrained(path)

-    def __call__(self, data: Dict[str, Any]) ->
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         """
-
-
-
-
+        Expects data in the following format:
+        {
+            "images": [
+                "base64_encoded_image1",
+                "base64_encoded_image2",
+                ...
+            ]
+        }
+
+        Decodes each Base64 image into a PIL Image, processes them, and returns the embeddings.
         """
-        #
-
-
-
+        # Retrieve the list of base64 encoded images
+        base64_images = data.get("images", [])
+        if not isinstance(base64_images, list):
+            base64_images = [base64_images]
+
+        # Decode each image from base64 and convert to a PIL Image
+        decoded_images = []
+        for img_str in base64_images:
+            try:
+                img_data = base64.b64decode(img_str)
+                image = Image.open(io.BytesIO(img_data)).convert("RGB")
+                decoded_images.append(image)
+            except Exception as e:
+                print(f"Error decoding an image: {e}")
+
+        # Process the images using the processor
+        batch_images = self.processor.process_images(decoded_images).to(self.model.device)
+
+        # Forward pass through the model
         with torch.no_grad():
             image_embeddings = self.model(**batch_images).tolist()

         return {"embeddings": image_embeddings}
-
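Once the handler is deployed as a Hugging Face Inference Endpoint, the same JSON payload goes over HTTP. A sketch with placeholder URL, token variable, and file name:

import base64
import os

import requests

API_URL = "https://YOUR-ENDPOINT.endpoints.huggingface.cloud"  # placeholder
HEADERS = {
    "Authorization": f"Bearer {os.environ['HF_TOKEN']}",  # your access token
    "Content-Type": "application/json",
}

with open("page.jpg", "rb") as f:  # placeholder image file
    payload = {"images": [base64.b64encode(f.read()).decode("utf-8")]}

response = requests.post(API_URL, headers=HEADERS, json=payload)
response.raise_for_status()
print(len(response.json()["embeddings"]), "image embedding(s) returned")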
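Because the handler calls .tolist(), the response stays JSON-serializable, and a client that wants to rank images against a text query has to rebuild tensors before scoring. A sketch of that step, assuming colpali_engine's score_multi_vector utility and ColQwen2's 128-dimensional embedding space; the random query embedding is a stand-in for one produced by the same model via processor.process_queries, and `result` is the response from the smoke test above.

import torch
from colpali_engine.models import ColQwen2Processor

# Assumed checkpoint id; substitute the repository this endpoint serves.
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# Rebuild tensors from the endpoint's JSON response.
image_embeddings = [torch.tensor(e) for e in result["embeddings"]]

# Stand-in query embedding (n_tokens x 128); a real one comes from the
# same ColQwen2 model fed through processor.process_queries(...).
query_embeddings = [torch.randn(16, 128)]

# Late-interaction (MaxSim) scoring: one score per (query, image) pair.
scores = processor.score_multi_vector(query_embeddings, image_embeddings)
print(scores)  # tensor of shape (1, n_images)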