amaye15 committed
Commit 2c8e3a0 · Parent: e165930

Optimised Handler

Files changed (1):
  1. handler.py +306 -66
handler.py CHANGED
@@ -1,9 +1,207 @@
+# import torch
+# from typing import Dict, Any, List
+# from PIL import Image
+# import base64
+# from io import BytesIO
+# import logging
+
+
+# class EndpointHandler:
+#     """
+#     A handler class for processing image and text data, generating embeddings using a specified model and processor.
+
+#     Attributes:
+#         model: The pre-trained model used for generating embeddings.
+#         processor: The pre-trained processor used to process images and text before model inference.
+#         device: The device (CPU or CUDA) used to run model inference.
+#         default_batch_size: The default batch size for processing images and text in batches.
+#     """
+
+#     def __init__(self, path: str = "", default_batch_size: int = 4):
+#         """
+#         Initializes the EndpointHandler with a specified model path and default batch size.
+
+#         Args:
+#             path (str): Path to the pre-trained model and processor.
+#             default_batch_size (int): Default batch size for processing images and text data.
+#         """
+#         # Initialize logging
+#         logging.basicConfig(level=logging.INFO)
+#         self.logger = logging.getLogger(__name__)
+
+#         from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+#         self.logger.info("Initializing model and processor.")
+#         try:
+#             self.model = ColQwen2.from_pretrained(
+#                 path,
+#                 torch_dtype=torch.bfloat16,
+#                 device_map="auto",
+#             ).eval()
+#             self.processor = ColQwen2Processor.from_pretrained(path)
+#             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+#             self.model.to(self.device)
+#             self.default_batch_size = default_batch_size
+#             self.logger.info("Initialization complete.")
+#         except Exception as e:
+#             self.logger.error(f"Failed to initialize model or processor: {e}")
+#             raise
+
+#     def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
+#         """
+#         Processes a batch of images and generates embeddings.
+
+#         Args:
+#             images (List[Image.Image]): List of images to process.
+
+#         Returns:
+#             List[List[float]]: List of embeddings for each image.
+#         """
+#         self.logger.debug(f"Processing batch of {len(images)} images.")
+#         try:
+#             batch_images = self.processor.process_images(images).to(self.device)
+#             with torch.no_grad(), torch.amp.autocast():
+#                 image_embeddings = self.model(**batch_images)
+#             self.logger.debug("Image batch processing complete.")
+#             return image_embeddings.cpu().tolist()
+#         except Exception as e:
+#             self.logger.error(f"Error processing image batch: {e}")
+#             raise
+
+#     def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
+#         """
+#         Processes a batch of text queries and generates embeddings.
+
+#         Args:
+#             texts (List[str]): List of text queries to process.
+
+#         Returns:
+#             List[List[float]]: List of embeddings for each text query.
+#         """
+#         self.logger.debug(f"Processing batch of {len(texts)} text queries.")
+#         try:
+#             batch_queries = self.processor.process_queries(texts).to(self.device)
+#             with torch.no_grad(), torch.amp.autocast():
+#                 query_embeddings = self.model(**batch_queries)
+#             self.logger.debug("Text batch processing complete.")
+#             return query_embeddings.cpu().tolist()
+#         except Exception as e:
+#             self.logger.error(f"Error processing text batch: {e}")
+#             raise
+
+#     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+#         """
+#         Processes input data containing base64-encoded images and text queries, decodes them, and generates embeddings.
+
+#         Args:
+#             data (Dict[str, Any]): Dictionary containing input images, text queries, and optional batch size.
+
+#         Returns:
+#             Dict[str, Any]: Dictionary containing generated embeddings for images and text or error messages.
+#         """
+#         images_data = data.get("image", [])
+#         text_data = data.get("text", [])
+#         batch_size = data.get("batch_size", self.default_batch_size)
+
+#         # Decode and process images
+#         images = []
+#         if images_data:
+#             self.logger.info("Decoding images from base64.")
+#             for img_data in images_data:
+#                 if isinstance(img_data, str):
+#                     try:
+#                         image_bytes = base64.b64decode(img_data)
+#                         image = Image.open(BytesIO(image_bytes)).convert("RGB")
+#                         images.append(image)
+#                     except Exception as e:
+#                         self.logger.error(f"Invalid image data: {e}")
+#                         return {"error": f"Invalid image data: {e}"}
+#                 else:
+#                     self.logger.error("Images should be base64-encoded strings.")
+#                     return {"error": "Images should be base64-encoded strings."}
+
+#         image_embeddings = []
+#         if images:
+#             self.logger.info("Processing image embeddings.")
+#             try:
+#                 for i in range(0, len(images), batch_size):
+#                     batch_images = images[i : i + batch_size]
+#                     batch_embeddings = self._process_image_batch(batch_images)
+#                     image_embeddings.extend(batch_embeddings)
+#             except Exception as e:
+#                 self.logger.error(f"Error generating image embeddings: {e}")
+#                 return {"error": f"Error generating image embeddings: {e}"}
+
+#         # Process text data
+#         text_embeddings = []
+#         if text_data:
+#             self.logger.info("Processing text embeddings.")
+#             try:
+#                 for i in range(0, len(text_data), batch_size):
+#                     batch_texts = text_data[i : i + batch_size]
+#                     batch_text_embeddings = self._process_text_batch(batch_texts)
+#                     text_embeddings.extend(batch_text_embeddings)
+#             except Exception as e:
+#                 self.logger.error(f"Error generating text embeddings: {e}")
+#                 return {"error": f"Error generating text embeddings: {e}"}
+
+#         # Compute similarity scores if both image and text embeddings are available
+#         scores = []
+#         if image_embeddings and text_embeddings:
+#             self.logger.info("Computing similarity scores.")
+#             try:
+#                 image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
+#                 text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
+#                 with torch.no_grad(), torch.amp.autocast():
+#                     scores = (
+#                         self.processor.score_multi_vector(
+#                             text_embeddings_tensor, image_embeddings_tensor
+#                         )
+#                         .cpu()
+#                         .tolist()
+#                     )
+#                 self.logger.info("Similarity scoring complete.")
+#             except Exception as e:
+#                 self.logger.error(f"Error computing similarity scores: {e}")
+#                 return {"error": f"Error computing similarity scores: {e}"}
+
+#         return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
+
+
 import torch
 from typing import Dict, Any, List
 from PIL import Image
 import base64
 from io import BytesIO
 import logging
+from torch.utils.data import DataLoader, Dataset
+import threading
+
+
+class ImageDataset(Dataset):
+    def __init__(self, images: List[Image.Image], processor):
+        self.images = images
+        self.processor = processor
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.processor.process_images([self.images[idx]])
+        return image
+
+
+class TextDataset(Dataset):
+    def __init__(self, texts: List[str], processor):
+        self.texts = texts
+        self.processor = processor
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = self.processor.process_queries([self.texts[idx]])
+        return text
 
 
 class EndpointHandler:
@@ -20,10 +218,6 @@ class EndpointHandler:
     def __init__(self, path: str = "", default_batch_size: int = 4):
         """
         Initializes the EndpointHandler with a specified model path and default batch size.
-
-        Args:
-            path (str): Path to the pre-trained model and processor.
-            default_batch_size (int): Default batch size for processing images and text data.
         """
         # Initialize logging
         logging.basicConfig(level=logging.INFO)
@@ -33,60 +227,91 @@ class EndpointHandler:
 
         self.logger.info("Initializing model and processor.")
         try:
-            self.model = ColQwen2.from_pretrained(
-                path,
-                torch_dtype=torch.bfloat16,
-                device_map=("cuda:0" if torch.cuda.is_available() else "cpu"),
-            ).eval()
-            self.processor = ColQwen2Processor.from_pretrained(path)
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-            self.model.to(self.device)
+
+            self.model = (
+                ColQwen2.from_pretrained(
+                    path,
+                    torch_dtype=torch.bfloat16,
+                    device_map="auto",
+                )
+                .to(self.device)
+                .eval()
+            )
+
+            self.processor = ColQwen2Processor.from_pretrained(path)
             self.default_batch_size = default_batch_size
             self.logger.info("Initialization complete.")
         except Exception as e:
             self.logger.error(f"Failed to initialize model or processor: {e}")
             raise
 
-    def _process_image_batch(self, images: List[Image.Image]) -> List[List[float]]:
+    def _process_image_embeddings(
+        self, images: List[Image.Image], batch_size: int
+    ) -> torch.Tensor:
         """
-        Processes a batch of images and generates embeddings.
+        Processes images and generates embeddings.
 
         Args:
             images (List[Image.Image]): List of images to process.
+            batch_size (int): Batch size for processing images.
 
         Returns:
-            List[List[float]]: List of embeddings for each image.
+            torch.Tensor: Tensor containing embeddings for each image.
         """
-        self.logger.debug(f"Processing batch of {len(images)} images.")
+        self.logger.debug(f"Processing {len(images)} images.")
         try:
-            batch_images = self.processor.process_images(images).to(self.device)
+            image_dataset = ImageDataset(images, self.processor)
+            image_loader = DataLoader(
+                image_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+            )
+
+            all_embeddings = []
             with torch.no_grad():
-                image_embeddings = self.model(**batch_images)
-            self.logger.debug("Image batch processing complete.")
-            return image_embeddings.cpu().tolist()
+                for batch in image_loader:
+                    batch_images = batch[0].to(self.device, non_blocking=True)
+                    with torch.cuda.amp.autocast():
+                        embeddings = self.model(**batch_images)
+                    all_embeddings.append(embeddings)
+            image_embeddings = torch.cat(all_embeddings, dim=0)
+            self.logger.debug("Image processing complete.")
+            return image_embeddings
         except Exception as e:
-            self.logger.error(f"Error processing image batch: {e}")
+            self.logger.error(f"Error processing images: {e}")
            raise
 
-    def _process_text_batch(self, texts: List[str]) -> List[List[float]]:
+    def _process_text_embeddings(
+        self, texts: List[str], batch_size: int
+    ) -> torch.Tensor:
         """
-        Processes a batch of text queries and generates embeddings.
+        Processes text queries and generates embeddings.
 
         Args:
            texts (List[str]): List of text queries to process.
+            batch_size (int): Batch size for processing texts.
 
         Returns:
-            List[List[float]]: List of embeddings for each text query.
+            torch.Tensor: Tensor containing embeddings for each text query.
         """
-        self.logger.debug(f"Processing batch of {len(texts)} text queries.")
+        self.logger.debug(f"Processing {len(texts)} text queries.")
         try:
-            batch_queries = self.processor.process_queries(texts).to(self.device)
+            text_dataset = TextDataset(texts, self.processor)
+            text_loader = DataLoader(
+                text_dataset, batch_size=batch_size, num_workers=4, pin_memory=True
+            )
+
+            all_embeddings = []
             with torch.no_grad():
-                query_embeddings = self.model(**batch_queries)
-            self.logger.debug("Text batch processing complete.")
-            return query_embeddings.cpu().tolist()
+                for batch in text_loader:
+                    batch_texts = batch[0].to(self.device, non_blocking=True)
+                    with torch.cuda.amp.autocast():
+                        embeddings = self.model(**batch_texts)
+                    all_embeddings.append(embeddings)
+            text_embeddings = torch.cat(all_embeddings, dim=0)
+            self.logger.debug("Text processing complete.")
+            return text_embeddings
         except Exception as e:
-            self.logger.error(f"Error processing text batch: {e}")
+            self.logger.error(f"Error processing texts: {e}")
             raise
 
     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
@@ -103,7 +328,6 @@ class EndpointHandler:
         text_data = data.get("text", [])
         batch_size = data.get("batch_size", self.default_batch_size)
 
-        # Decode and process images
         images = []
         if images_data:
             self.logger.info("Decoding images from base64.")
@@ -120,49 +344,65 @@ class EndpointHandler:
                 self.logger.error("Images should be base64-encoded strings.")
                 return {"error": "Images should be base64-encoded strings."}
 
-        image_embeddings = []
-        if images:
-            self.logger.info("Processing image embeddings.")
-            try:
-                for i in range(0, len(images), batch_size):
-                    batch_images = images[i : i + batch_size]
-                    batch_embeddings = self._process_image_batch(batch_images)
-                    image_embeddings.extend(batch_embeddings)
-            except Exception as e:
-                self.logger.error(f"Error generating image embeddings: {e}")
-                return {"error": f"Error generating image embeddings: {e}"}
+        image_embeddings = None
+        text_embeddings = None
+        scores = None
 
-        # Process text data
-        text_embeddings = []
-        if text_data:
-            self.logger.info("Processing text embeddings.")
-            try:
-                for i in range(0, len(text_data), batch_size):
-                    batch_texts = text_data[i : i + batch_size]
-                    batch_text_embeddings = self._process_text_batch(batch_texts)
-                    text_embeddings.extend(batch_text_embeddings)
-            except Exception as e:
-                self.logger.error(f"Error generating text embeddings: {e}")
-                return {"error": f"Error generating text embeddings: {e}"}
+        def process_images():
+            nonlocal image_embeddings
+            if images:
+                self.logger.info("Processing image embeddings.")
+                try:
+                    image_embeddings = self._process_image_embeddings(
+                        images, batch_size
+                    )
+                except Exception as e:
+                    self.logger.error(f"Error generating image embeddings: {e}")
+
+        def process_texts():
+            nonlocal text_embeddings
+            if text_data:
+                self.logger.info("Processing text embeddings.")
+                try:
+                    text_embeddings = self._process_text_embeddings(
+                        text_data, batch_size
+                    )
+                except Exception as e:
+                    self.logger.error(f"Error generating text embeddings: {e}")
+
+        # Process images and texts in parallel if both are present
+        threads = []
+        if images_data and text_data:
+            image_thread = threading.Thread(target=process_images)
+            text_thread = threading.Thread(target=process_texts)
+            threads.extend([image_thread, text_thread])
+            image_thread.start()
+            text_thread.start()
+            for thread in threads:
+                thread.join()
+        else:
+            process_images()
+            process_texts()
 
-        # Compute similarity scores if both image and text embeddings are available
-        scores = []
-        if image_embeddings and text_embeddings:
+        # Compute similarity scores if both embeddings are available
+        if image_embeddings is not None and text_embeddings is not None:
             self.logger.info("Computing similarity scores.")
             try:
-                image_embeddings_tensor = torch.tensor(image_embeddings).to(self.device)
-                text_embeddings_tensor = torch.tensor(text_embeddings).to(self.device)
-                with torch.no_grad():
-                    scores = (
-                        self.processor.score_multi_vector(
-                            text_embeddings_tensor, image_embeddings_tensor
-                        )
-                        .cpu()
-                        .tolist()
+                with torch.no_grad(), torch.cuda.amp.autocast():
+                    scores = self.processor.score_multi_vector(
+                        text_embeddings, image_embeddings
                    )
                 self.logger.info("Similarity scoring complete.")
             except Exception as e:
                 self.logger.error(f"Error computing similarity scores: {e}")
                 return {"error": f"Error computing similarity scores: {e}"}
 
-        return {"image": image_embeddings, "text": text_embeddings, "scores": scores}
+        result = {}
+        if image_embeddings is not None:
+            result["image"] = image_embeddings.cpu().tolist()
+        if text_embeddings is not None:
+            result["text"] = text_embeddings.cpu().tolist()
+        if scores is not None:
+            result["scores"] = scores.cpu().tolist()
+
+        return result
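
For context, a minimal smoke test of the optimised handler might look like the sketch below. The model path and image file are placeholder assumptions, not values taken from this commit; the payload keys ("image", "text", "batch_size") and result keys ("image", "text", "scores") mirror the new __call__ implementation.

    import base64

    from handler import EndpointHandler

    # Placeholder model path; any ColQwen2 checkpoint directory would work here.
    handler = EndpointHandler(path="vidore/colqwen2-v1.0", default_batch_size=4)

    # sample.jpg is an assumed local file used purely for illustration.
    with open("sample.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    result = handler(
        {
            "image": [image_b64],                   # base64-encoded image strings
            "text": ["What does the chart show?"],  # text queries
            "batch_size": 2,                        # optional; overrides default_batch_size
        }
    )

    # Keys are present only when the corresponding inputs were supplied.
    print(len(result.get("image", [])))  # one multi-vector embedding per image
    print(result.get("scores"))          # query-vs-image similarity matrix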
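
One subtlety in the new DataLoader path: __getitem__ returns the processor output for a single item, so what reaches the inference loop (and the batch[0] indexing) depends on how the default collate function stacks those outputs. A sketch of an alternative that makes batching explicit with a custom collate_fn; RawImageDataset and make_collate_fn are hypothetical names, and the sketch assumes process_images pads across the list it receives, as it does when the handler calls it directly.

    from typing import List

    from PIL import Image
    from torch.utils.data import DataLoader, Dataset


    class RawImageDataset(Dataset):
        """Hypothetical variant: yield raw PIL images, defer processing to collate."""

        def __init__(self, images: List[Image.Image]):
            self.images = images

        def __len__(self):
            return len(self.images)

        def __getitem__(self, idx):
            return self.images[idx]


    def make_collate_fn(processor):
        # Process the whole batch at once so padding happens across the batch,
        # mirroring what processor.process_images(images) does in the old path.
        def collate(batch: List[Image.Image]):
            return processor.process_images(batch)

        return collate


    # Usage sketch inside _process_image_embeddings (processor, model, and
    # device as in the handler above):
    # loader = DataLoader(RawImageDataset(images), batch_size=batch_size,
    #                     num_workers=4, collate_fn=make_collate_fn(processor))
    # for batch in loader:
    #     batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    #     embeddings = model(**batch)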
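
A small version note: torch.cuda.amp.autocast() as used in the new methods still works but is deprecated in recent PyTorch releases in favour of the device-qualified form. A drop-in equivalent, assuming a PyTorch version that ships the torch.amp namespace:

    import torch

    # Device-qualified autocast; behaves like torch.cuda.amp.autocast() on CUDA.
    # The CUDA autocast dtype defaults to float16 unless overridden via dtype=...
    with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
        pass  # the model forward pass would run here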