Update handler.py
Browse files- handler.py +4 -3
handler.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from typing import Dict, List, Any
|
|
|
2 |
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
3 |
import torch
|
4 |
|
@@ -8,7 +9,7 @@ class EndpointHandler():
|
|
8 |
path,
|
9 |
torch_dtype=torch.bfloat16,
|
10 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
11 |
-
|
12 |
).eval()
|
13 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
14 |
|
@@ -24,7 +25,7 @@ class EndpointHandler():
|
|
24 |
batch_images = self.processor.process_images([images]).to(self.model.device)
|
25 |
# Forward pass
|
26 |
with torch.no_grad():
|
27 |
-
image_embeddings = self.model(**batch_images)
|
28 |
|
29 |
-
return {"embeddings": image_embeddings
|
30 |
|
|
|
1 |
from typing import Dict, List, Any
|
2 |
+
from transformers.utils.import_utils import is_flash_attn_2_available
|
3 |
from colpali_engine.models import ColQwen2, ColQwen2Processor
|
4 |
import torch
|
5 |
|
|
|
9 |
path,
|
10 |
torch_dtype=torch.bfloat16,
|
11 |
device_map="cuda:0", # or "mps" if on Apple Silicon
|
12 |
+
attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, # should work on A100
|
13 |
).eval()
|
14 |
self.processor = ColQwen2Processor.from_pretrained(path)
|
15 |
|
|
|
25 |
batch_images = self.processor.process_images([images]).to(self.model.device)
|
26 |
# Forward pass
|
27 |
with torch.no_grad():
|
28 |
+
image_embeddings = self.model(**batch_images).tolist()
|
29 |
|
30 |
+
return {"embeddings": image_embeddings}
|
31 |
|