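# Custom handler for a Hugging Face Inference Endpoint serving microsoft/Magma-8B
# (the EndpointHandler convention). Expected request payload:
#   {"inputs": {"image": "<base64-encoded image>", "convs": [<chat messages>]}}
# Response: {"response": "<generated text>"}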
import base64
from io import BytesIO
from typing import Any, Dict

import requests  # used only in the example shown in __call__'s docstring
import torch
from PIL import Image

from transformers import AutoModelForCausalLM, AutoProcessor, BitsAndBytesConfig

dtype = torch.bfloat16
DEVICE = torch.device('cuda')

class EndpointHandler:
    def __init__(self, path=""):
        print(f'Start init, the path is {path}')
        try:
            # The checkpoint shards are too large to load in full precision
            # without running out of memory, so quantize to 4-bit at load time.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True
            )

            self.model = AutoModelForCausalLM.from_pretrained(
                "microsoft/Magma-8B",
                trust_remote_code=True,
                torch_dtype=dtype,
                quantization_config=quantization_config,
            )
            print("Model is loaded")
            self.processor = AutoProcessor.from_pretrained("microsoft/Magma-8B", trust_remote_code=True)
            print("Processor is loaded")
            # Note: bitsandbytes places the quantized weights on the GPU during
            # loading; calling .to(DEVICE) on a 4-bit model raises a ValueError,
            # so no explicit device move is needed here.
        except Exception as e:
            print(f"An error occurred during initialization: {e}")
            raise  # fail fast rather than serve a half-initialized handler
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:

        inputs = data.pop("inputs", None)
        if not isinstance(inputs, dict):
            return {"error": "No 'inputs' dict provided"}
        image_base64 = inputs.get("image", None)
        convs = inputs.get("convs", [])

        if not image_base64:
            return {"error": "No base64-encoded image provided"}
        try:
            image = Image.open(BytesIO(base64.b64decode(image_base64)))
            image = image.convert("RGB")
        except Exception as e:
            return {"error": f"Invalid image data: {e}"}
        if not isinstance(convs, list):
            return {"error": "Invalid conversation format"}

        """
        Example:

        url = "https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png"
        image = Image.open(BytesIO(requests.get(url, stream=True).content))
        image = image.convert("RGB")
        convs = [
            {"role": "system", "content": "You are agent that can see, talk and act."},
            {"role": "user", "content": "<image_start><image><image_end>\nWhat is in this image?"},
        ]
        """
        prompt = self.processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(images=[image], texts=prompt, return_tensors="pt")
        # The model expects a batch dimension on the vision tensors.
        inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
        inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
        # .to(dtype) casts the floating-point tensors to bfloat16 (input_ids stay
        # integral), then everything is moved to the GPU.
        inputs = inputs.to(dtype).to(DEVICE)

        self.model.generation_config.pad_token_id = self.processor.tokenizer.pad_token_id

        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                do_sample=False,  # greedy decoding; a temperature setting would be ignored
                num_beams=1,
                max_new_tokens=128,
                use_cache=True,
            )

        # generate() returns the prompt followed by the completion, so decode
        # both and strip the echoed prompt from the response.
        prompt_decoded = self.processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
        response = self.processor.batch_decode(output_ids, skip_special_tokens=True)[0]
        response = response.replace(prompt_decoded, '').strip()
        return {"response": response}