import os

# Allocator config must be set before the CUDA caching allocator is first
# initialized (i.e. before the first CUDA allocation) to take effect.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

import cv2
import gc
import logging

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Make logging.info messages visible (the root logger defaults to WARNING).
logging.basicConfig(level=logging.INFO)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()
gc.collect()
class SMOLVLM2:
    def __init__(self, model_name="HuggingFaceTB/SmolVLM2-500M-Video-Instruct", memory_efficient=True):
        self.half = True
        self.processor = AutoProcessor.from_pretrained(model_name)
        if self.supports_flash_attention(device_id=0):
            # Flash Attention 2 is only usable on GPUs with compute capability >= 8.0.
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                _attn_implementation="flash_attention_2"
            ).to(device)
        else:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
            ).to(device)
        logging.info("Model loaded")
        self.print_gpu_memory()
    @staticmethod
    def print_gpu_memory():
        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    ## check for Flash Attention support
    @staticmethod
    def supports_flash_attention(device_id):
        """Check if the GPU supports Flash Attention (compute capability >= 8.0)."""
        if not torch.cuda.is_available():
            return False
        support = False
        major, minor = torch.cuda.get_device_capability(device_id)
        if major < 8:
            logging.info("GPU does not support Flash Attention")
        else:
            support = True
        return support
    def run_inference_on_image(self, image_path, query):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": query}
                ]
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors='pt'
        )
        if self.half:
            # Cast floating-point inputs to fp16 to match the model weights.
            inputs = inputs.to(torch.half).to(device)
        else:
            inputs = inputs.to(device)
        generated_ids = self.model.generate(**inputs, do_sample=False, max_new_tokens=1024)
        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        del inputs
        torch.cuda.empty_cache()
        # The decoded output still contains the chat template; return only the final line (the answer).
        return generated_texts[0].split('\n')[-1]
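
# Usage sketch (not part of the original file): shows how the class above might
# be invoked. "sample.jpg" and the prompt are hypothetical placeholders.
if __name__ == "__main__":
    vlm = SMOLVLM2()
    answer = vlm.run_inference_on_image("sample.jpg", "Describe this image.")
    print(answer)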