Visual Document Retrieval
ColPali
Safetensors
English
vidore-experimental
vidore
manu committed on
Commit
72c1e75
·
verified ·
1 Parent(s): b0c2940

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -0
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from colpali_engine.models import ColQwen2, ColQwen2Processor
3
+ import torch
4
+
5
class EndpointHandler():
    """Custom inference-endpoint handler that embeds images with ColQwen2.

    The model and its processor are loaded once at construction time; each
    call then turns a batch of images into multi-vector embeddings.
    """

    def __init__(self, path=""):
        """Load the ColQwen2 model and processor from ``path``.

        Args:
            path: Directory or hub id passed to ``from_pretrained`` for both
                the model and the processor.
        """
        # Prefer the first GPU, but do not crash at start-up on a machine
        # without CUDA (use "mps" instead on Apple Silicon).
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map=device,
            # NOTE: attn_implementation="flash_attention_2" should work on A100.
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Embed the images found in the request payload.

        Args:
            data: Request payload. The images are read from the ``"inputs"``
                key (which is removed from ``data``); when the key is absent
                the payload itself is handed to the processor.
                NOTE(review): the processor presumably expects PIL images --
                confirm how the endpoint decodes the request body.

        Returns:
            ``{"embeddings": [...]}`` with one entry per input image, each
            entry being the nested-list form of that image's embedding
            tensor so the response can be serialized.
        """
        # Process input.
        images = data.pop("inputs", data)
        batch_images = self.processor.process_images(images).to(self.model.device)

        # Forward pass; gradients are never needed at inference time.
        with torch.no_grad():
            image_embeddings = self.model(**batch_images)

        # bfloat16 tensors are not serializable: move to CPU, upcast to
        # float32 and convert each per-image tensor to plain Python lists.
        image_embeddings = image_embeddings.to(device="cpu", dtype=torch.float32)
        return {
            "embeddings": [
                embedding.tolist() for embedding in torch.unbind(image_embeddings)
            ]
        }
30
+