Visual Document Retrieval
ColPali
Safetensors
English
vidore-experimental
vidore
manu committed on
Commit
4988933
·
verified ·
1 Parent(s): 7db3305

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +4 -3
handler.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import Dict, List, Any
 
2
  from colpali_engine.models import ColQwen2, ColQwen2Processor
3
  import torch
4
 
@@ -8,7 +9,7 @@ class EndpointHandler():
8
  path,
9
  torch_dtype=torch.bfloat16,
10
  device_map="cuda:0", # or "mps" if on Apple Silicon
11
- # attn_implementation="flash_attention_2", # should work on A100
12
  ).eval()
13
  self.processor = ColQwen2Processor.from_pretrained(path)
14
 
@@ -24,7 +25,7 @@ class EndpointHandler():
24
  batch_images = self.processor.process_images([images]).to(self.model.device)
25
  # Forward pass
26
  with torch.no_grad():
27
- image_embeddings = self.model(**batch_images)
28
 
29
- return {"embeddings": image_embeddings.tolist()}
30
 
 
1
  from typing import Dict, List, Any
2
+ from transformers.utils.import_utils import is_flash_attn_2_available
3
  from colpali_engine.models import ColQwen2, ColQwen2Processor
4
  import torch
5
 
 
9
  path,
10
  torch_dtype=torch.bfloat16,
11
  device_map="cuda:0", # or "mps" if on Apple Silicon
12
+ attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None, # should work on A100
13
  ).eval()
14
  self.processor = ColQwen2Processor.from_pretrained(path)
15
 
 
25
  batch_images = self.processor.process_images([images]).to(self.model.device)
26
  # Forward pass
27
  with torch.no_grad():
28
+ image_embeddings = self.model(**batch_images).tolist()
29
 
30
+ return {"embeddings": image_embeddings}
31