import gc
import logging
import os

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Limit CUDA allocator block sizes; set this before the CUDA context is initialized.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

class SMOLVLM2:
    def __init__(self, model_name="HuggingFaceTB/SmolVLM2-500M-Video-Instruct", memory_efficient=True):
        # Run inference in float16 to keep GPU memory usage low.
        self.half = True
        self.processor = AutoProcessor.from_pretrained(model_name)
        # Use FlashAttention 2 only on GPUs that support it (compute capability >= 8.0).
        if device == 'cuda' and self.supports_flash_attention(device_id=0):
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                _attn_implementation="flash_attention_2",
            ).to(device)
        else:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
            ).to(device)
        logging.info("Model loaded")
        if device == 'cuda':
            self.print_gpu_memory()

    @staticmethod
    def print_gpu_memory():
        """Log the currently allocated and reserved GPU memory."""
        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    @staticmethod
    def supports_flash_attention(device_id):
        """Check whether the GPU supports FlashAttention (requires compute capability >= 8.0)."""
        major, _ = torch.cuda.get_device_capability(device_id)
        if major < 8:
            logging.info("GPU does not support FlashAttention")
            return False
        return True
    
    def run_inference_on_image(self, image_path, query):
        """Ask the model a question about a single image and return the answer text."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": query},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors='pt',
        )
        # Cast floating-point inputs (e.g. pixel values) to float16 to match the model weights.
        if self.half:
            inputs = inputs.to(device, dtype=torch.float16)
        else:
            inputs = inputs.to(device)
        generated_ids = self.model.generate(**inputs, do_sample=False, max_new_tokens=1024)
        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        del inputs
        if device == 'cuda':
            torch.cuda.empty_cache()
        # The decoded output contains the full chat transcript; the answer is its last line.
        return generated_texts[0].split('\n')[-1]
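

# Example usage: a minimal sketch of how the class above can be driven.
# "sample.jpg" and the query string are placeholders, not part of the original file.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    vlm = SMOLVLM2()
    answer = vlm.run_inference_on_image("sample.jpg", "What is shown in this image?")
    print(answer)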