import os
import gc
import logging

import cv2
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# Configure the CUDA caching allocator before any CUDA memory is allocated.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:1024'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
gc.collect()
torch.cuda.empty_cache()
class SMOLVLM2:
    def __init__(self, model_name="HuggingFaceTB/SmolVLM2-500M-Video-Instruct", memory_efficient=True):
        self.half = True
        self.processor = AutoProcessor.from_pretrained(model_name)
        if self.support_flash_attention(device_id=0):
            # GPU supports FlashAttention-2: load in fp16 with the faster attention kernel.
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                _attn_implementation="flash_attention_2"
            ).to(device)
        else:
            self.model = AutoModelForImageTextToText.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
            ).to(device)
        logging.info("Model loaded")
        self.print_gpu_memory()
    def print_gpu_memory(self):
        logging.info(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logging.info(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
    # Check whether the GPU can use FlashAttention-2.
    def support_flash_attention(self, device_id):
        """Check if the GPU supports FlashAttention-2 (requires compute capability >= 8.0, i.e. Ampere or newer)."""
        if not torch.cuda.is_available():
            return False
        major, minor = torch.cuda.get_device_capability(device_id)
        if major < 8:
            print("GPU does not support FlashAttention")
            return False
        return True
    def run_inference_on_image(self, image_path, query):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": query}
                ]
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors='pt'
        )
        # Cast floating-point inputs to fp16 to match the model, then move everything to the target device.
        if self.half:
            inputs = inputs.to(torch.half).to(device)
        else:
            inputs = inputs.to(device)
        generated_ids = self.model.generate(**inputs, do_sample=False, max_new_tokens=1024)
        generated_texts = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        del inputs
        torch.cuda.empty_cache()
        # The decoded output contains the full chat transcript; keep only the final (assistant) line.
        return generated_texts[0].split('\n')[-1]
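

# Minimal usage sketch: instantiate the wrapper once and query it with an image and a prompt.
# "sample.jpg" and the question below are placeholder values, not part of the original code.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    vlm = SMOLVLM2()
    answer = vlm.run_inference_on_image("sample.jpg", "Describe this image in one sentence.")
    print(answer)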