import base64
import io
from PIL import Image
from typing import Dict, List, Any
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2, ColQwen2Processor
import torch

class EndpointHandler():
    def __init__(self, path=""):
        self.model = ColQwen2.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",  # or "mps" if on Apple Silicon
            attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
        ).eval()
        self.processor = ColQwen2Processor.from_pretrained(path)  # , max_num_visual_tokens=8192)  # temporary
        # self.model = torch.compile(self.model)
        print(f"Model and processor loaded {'with' if is_flash_attn_2_available() else 'without'} FA2")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Expects data in one of the following formats in the "inputs" key:
        {
            "images": [
                "base64_encoded_image1",
                "base64_encoded_image2",
                ...
            ]
        }
        xor
        {
            "queries": [
                "text1",
                "text2",
                ...
            ]
        }
        Returns embeddings for the provided input type.
        """
        # Input validation
        data = data.get("inputs", {})
        input_keys = [key for key in ["images", "queries"] if key in data]
        if len(input_keys) != 1:
            return {"error": "Exactly one of 'images' or 'queries' must be provided"}

        input_type = input_keys[0]
        inputs = data[input_type]
        if input_type == "images":
            if not isinstance(inputs, list):
                inputs = [inputs]
            if len(inputs) > 8:
                return {"message": "Send a maximum of 8 images at once. We recommend sending them one by one to improve load balancing."}
            # Decode each image from base64 and convert to a PIL Image
            decoded_images = []
            for img_str in inputs:
                try:
                    img_data = base64.b64decode(img_str)
                    image = Image.open(io.BytesIO(img_data)).convert("RGB")
                    decoded_images.append(image)
                except Exception as e:
                    return {"error": f"Error decoding image: {str(e)}"}
            # Process the images using the processor
            batch = self.processor.process_images(decoded_images).to(self.model.device)
        # elif input_type == "processed_images":
        #     try:
        #         buffer = io.BytesIO(base64.b64decode(inputs))
        #         batch = torch.load(buffer, map_location=self.model.device)
        #     except Exception as e:
        #         return {"error": f"Error processing preprocessed images: {str(e)}"}
        else:  # text queries
            if not isinstance(inputs, list):
                inputs = [inputs]
            try:
                batch = self.processor.process_queries(inputs).to(self.model.device)
            except Exception as e:
                return {"error": f"Error processing text: {str(e)}"}

        # Forward pass through the model
        with torch.inference_mode():
            embeddings = self.model(**batch).tolist()

        return {"embeddings": embeddings}