Visual Document Retrieval
ColPali
Safetensors
English
vidore-experimental
vidore
manuel committed on
Commit
f47f1d6
·
1 Parent(s): 98543a7
Files changed (1) hide show
  1. handler.py +12 -19
handler.py CHANGED
@@ -20,7 +20,7 @@ class EndpointHandler():
20
 
21
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
22
  """
23
- Expects data in one of the following formats:
24
  {
25
  "images": [
26
  "base64_encoded_image1",
@@ -28,16 +28,9 @@ class EndpointHandler():
28
  ...
29
  ]
30
  }
31
- or
32
  {
33
- "processed_images": [
34
- [...], # preprocessed image tensors
35
- [...]
36
- ]
37
- }
38
- or
39
- {
40
- "text": [
41
  "text1",
42
  "text2",
43
  ...
@@ -48,9 +41,9 @@ class EndpointHandler():
48
  """
49
  # Input validation
50
  data = data.get("inputs", [])
51
- input_keys = [key for key in ["images", "processed_images", "text"] if key in data]
52
  if len(input_keys) != 1:
53
- return {"error": "Exactly one of 'images', 'processed_images', or 'text' must be provided"}
54
 
55
  input_type = input_keys[0]
56
  inputs = data[input_type]
@@ -76,18 +69,18 @@ class EndpointHandler():
76
  # Process the images using the processor
77
  batch = self.processor.process_images(decoded_images).to(self.model.device)
78
 
79
- elif input_type == "processed_images":
80
- try:
81
- buffer = io.BytesIO(base64.b64decode(inputs))
82
- batch = torch.load(buffer, map_location=self.model.device)
83
- except Exception as e:
84
- return {"error": f"Error processing preprocessed images: {str(e)}"}
85
 
86
  else: # text
87
  if not isinstance(inputs, list):
88
  inputs = [inputs]
89
  try:
90
- batch = self.processor.process_text(inputs).to(self.model.device)
91
  except Exception as e:
92
  return {"error": f"Error processing text: {str(e)}"}
93
 
 
20
 
21
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
22
  """
23
+ Expects data in one of the following formats in the "inputs" key:
24
  {
25
  "images": [
26
  "base64_encoded_image1",
 
28
  ...
29
  ]
30
  }
31
+ xor
32
  {
33
+ "queries": [
 
 
 
 
 
 
 
34
  "text1",
35
  "text2",
36
  ...
 
41
  """
42
  # Input validation
43
  data = data.get("inputs", [])
44
+ input_keys = [key for key in ["images", "queries"] if key in data]
45
  if len(input_keys) != 1:
46
+ return {"error": "Exactly one of 'images', 'queries' must be provided"}
47
 
48
  input_type = input_keys[0]
49
  inputs = data[input_type]
 
69
  # Process the images using the processor
70
  batch = self.processor.process_images(decoded_images).to(self.model.device)
71
 
72
+ # elif input_type == "processed_images":
73
+ # try:
74
+ # buffer = io.BytesIO(base64.b64decode(inputs))
75
+ # batch = torch.load(buffer, map_location=self.model.device)
76
+ # except Exception as e:
77
+ # return {"error": f"Error processing preprocessed images: {str(e)}"}
78
 
79
  else: # text
80
  if not isinstance(inputs, list):
81
  inputs = [inputs]
82
  try:
83
+ batch = self.processor.process_queries(inputs).to(self.model.device)
84
  except Exception as e:
85
  return {"error": f"Error processing text: {str(e)}"}
86