import base64
from io import BytesIO
from typing import Any, Dict

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

dtype = torch.bfloat16
DEVICE = torch.device("cuda")
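
# Custom handler for a Hugging Face Inference Endpoint serving microsoft/Magma-8B.
# The endpoint invokes __call__ with a payload of the form
#     {"inputs": {"image": "<base64-encoded image>", "convs": [<chat messages>]}}
# and returns {"response": "<generated text>"} on success or {"error": ...} on failure.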
class EndpointHandler:
    def __init__(self, path=""):
        print(f"Start init, the path is {path}")
        try:
            # Load the weights in 4-bit: the full-precision shards are too big
            # for this GPU and loading them directly would OOM.
            quantization_config = BitsAndBytesConfig(load_in_4bit=True)
            self.model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Magma-8B",
                trust_remote_code=True,
                torch_dtype=dtype,
                quantization_config=quantization_config,
            )
            print("Model is loaded")
            # bitsandbytes already places the quantized weights on the GPU at
            # load time; calling .to(DEVICE) on a 4-bit model raises a ValueError,
            # so the model is not moved explicitly.
            self.processor = AutoProcessor.from_pretrained(
                "microsoft/Magma-8B", trust_remote_code=True
            )
            print("Processor is loaded")
        except Exception as e:
            print(f"An error occurred during init: {e}")
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.pop("inputs", None)
        if not isinstance(inputs, dict):
            return {"error": "No inputs provided"}
        image_base64 = inputs.get("image", None)
        convs = inputs.get("convs", [])
        if not image_base64:
            return {"error": "No base64-encoded image provided"}
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64)))
            image = image.convert("RGB")
        except Exception as e:
            return {"error": f"Invalid image data: {e}"}
        if not isinstance(convs, list):
            return {"error": "Invalid conversation format"}
"""
Example:
url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
image = Image.open(BytesIO(requests.get(url, stream=True).content))
image = image.convert("RGB")
convs = [
{"role": "system", "content": "You are agent that can see, talk and act."},
{"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
]
"""
        prompt = self.processor.tokenizer.apply_chat_template(
            convs, tokenize=False, add_generation_prompt=True
        )
        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt")
        # The model expects an explicit batch dimension on the vision tensors.
        inputs["pixel_values"] = inputs["pixel_values"].unsqueeze(0)
        inputs["image_sizes"] = inputs["image_sizes"].unsqueeze(0)
        # BatchFeature.to(dtype) casts only floating-point tensors, so the
        # integer input_ids are left untouched.
        inputs = inputs.to(dtype).to(DEVICE)

        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id
        with torch.inference_mode():
            # Greedy decoding; temperature is irrelevant when do_sample=False.
            output_ids = self.model.generate(
                **inputs,
                do_sample=False,
                num_beams=1,
                max_new_tokens=128,
                use_cache=True,
            )
        # Decode only the newly generated tokens (everything after the prompt);
        # slicing by token position is more robust than string-replacing the
        # decoded prompt out of the decoded output.
        output_ids = output_ids[:, inputs["input_ids"].shape[-1]:]
        response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
        return {"response": response}
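

# Minimal local smoke test: a sketch, not part of the endpoint contract. It
# fetches the sample image referenced in the docstring above, base64-encodes
# it, and invokes the handler the way the endpoint would. Assumes a CUDA GPU
# with enough memory for the 4-bit model and network access to the sample URL.
if __name__ == "__main__":
    url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
    image_b64 = base64.b64encode(requests.get(url, stream=True).content).decode("utf-8")
    payload = {
        "inputs": {
            "image": image_b64,
            "convs": [
                {"role": "system", "content": "You are agent that can see, talk and act."},
                {"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
            ],
        }
    }
    handler = EndpointHandler()
    print(handler(payload))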