import base64
from io import BytesIO
from typing import Any, Dict

import requests  # used in the docstring example in __call__ below
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

DTYPE = torch.bfloat16
DEVICE = torch.device("cuda")


class EndpointHandler:
    def __init__(self, path: str = ""):
        print(f"Initializing handler, path: {path}")
        try:
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)

            self.model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Magma-8B",
                trust_remote_code=True,
                torch_dtype=DTYPE,
                quantization_config=quantization_config,
                device_map="auto",  # bitsandbytes places the quantized weights on the GPU at load time
            )
            print("Model loaded")
            self.processor = AutoProcessor.from_pretrained("microsoft/Magma-8B", trust_remote_code=True)
            print("Processor loaded")
            # Note: no self.model.to(DEVICE) here -- transformers raises a ValueError
            # when .to() is called on a 4-bit/8-bit bitsandbytes-quantized model.
        except Exception as e:
            print(f"Handler initialization failed: {e}")
            raise  # fail fast instead of serving a half-initialized handler

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        r"""Run Magma-8B over a base64-encoded image and a conversation.

        Example of building the image and conversation (from the model card):

            url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
            image = Image.open(BytesIO(requests.get(url, stream=True).content))
            image = image.convert("RGB")
            convs = [
                {"role": "system", "content": "You are agent that can see, talk and act."},
                {"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
            ]
        """
        inputs = data.pop("inputs", None)
        if not isinstance(inputs, dict):
            return {"error": "No 'inputs' object provided"}

        image_base64 = inputs.get("image")
        convs = inputs.get("convs", [])

        if not image_base64:
            return {"error": "No base64-encoded image provided"}
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64)))
            image = image.convert("RGB")
        except Exception as e:
            return {"error": f"Invalid image data: {e}"}
        if not isinstance(convs, list):
            return {"error": "Invalid conversation format; expected a list of messages"}

        prompt = self.processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt")
        # Magma expects a batch dimension on the vision tensors.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        # BatchFeature.to(dtype) casts only floating-point tensors, so input_ids stay integral.
        inputs = inputs.to(DTYPE).to(DEVICE)

        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id

        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                temperature=0.0,
                do_sample=False,  # greedy decoding
                num_beams=1,
                max_new_tokens=128,
                use_cache=True,
            )

        # Decode only the newly generated tokens, i.e. everything after the prompt,
        # rather than string-stripping the decoded prompt from the decoded output.
        generated_ids = output_ids[:, inputs["input_ids"].shape[-1]:]
        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return {"response": response}