# Magma-8B / handler.py
import base64
from PIL import Image
from io import BytesIO
from typing import Dict, List, Any
import torch
import requests
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig
dtype = torch.bfloat16
DEVICE = torch.device('cuda')
class EndpointHandler:
    def __init__(self, path=""):
        print(f"Start init, the path is {path}")
        try:
            # Load in 4-bit: the full-precision shards are too large for the GPU
            # and would OOM otherwise.
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Magma-8B",
                trust_remote_code=True,
                torch_dtype=dtype,
                quantization_config=quantization_config,
                device_map=DEVICE,  # bitsandbytes-quantized models cannot be moved with .to()
            )
            print("Model is loaded")
            self.processor = AutoProcessor.from_pretrained("microsoft/Magma-8B", trust_remote_code=True)
            print("Processor is loaded")
        except Exception as e:
            print(f"An error occurred during init: {e}")
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.pop("inputs", None)
        if not isinstance(inputs, dict):
            return {"error": "No inputs provided"}
        image_base64 = inputs.get("image", None)
        convs = inputs.get("convs", [])
        if not image_base64:
            return {"error": "No base64-encoded image provided"}
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64)))
            image = image.convert("RGB")
        except Exception as e:
            return {"error": f"Invalid image data: {str(e)}"}
        if not isinstance(convs, list):
            return {"error": "Invalid conversation format"}
"""
Example:
url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
image = Image.open(BytesIO(requests.get(url, stream=True).content))
image = image.convert("RGB")
convs = [
{"role": "system", "content": "You are agent that can see, talk and act."},
{"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
]
"""
        prompt = self.processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt")
        # Magma expects batched pixel_values/image_sizes: add the batch dimension.
        inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
        inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
        inputs = inputs.to(dtype).to(DEVICE)
        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                temperature=0.0,
                do_sample=False,
                num_beams=1,
                max_new_tokens=128,
                use_cache=True,
            )
        # Decode only the newly generated tokens, skipping the prompt portion,
        # instead of decoding everything and string-stripping the prompt back out.
        generated_ids = output_ids[:, inputs['input_ids'].shape[-1]:]
        response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
        return {"response": response}