# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import torch
import numpy as np
import torchvision

from .utils import convert_to_numpy


class GDINOAnnotator:
    """Open-vocabulary box detector built on the GroundingDINO ``Model`` wrapper.

    Detection is driven either by a list of class names or by a free-form
    caption; results are optionally de-duplicated with NMS.
    """

    def __init__(self, cfg, device=None):
        """Build the GroundingDINO model from a config mapping.

        Args:
            cfg: Mapping with required keys ``CONFIG_PATH``,
                ``PRETRAINED_MODEL`` and ``TOKENIZER_PATH``, and optional keys
                ``BOX_THRESHOLD`` (default 0.25), ``TEXT_THRESHOLD``
                (default 0.2), ``IOU_THRESHOLD`` (default 0.5) and
                ``USE_NMS`` (default True).
            device: Device handed to the model. When ``None``, CUDA is used
                if available, otherwise CPU.

        Raises:
            ImportError: If the ``groundingdino`` package cannot be imported.
            KeyError: If a required ``cfg`` key is missing.
        """
        try:
            from groundingdino.util.inference import Model
        except ImportError as err:
            # Fail right here with the install hint: continuing would only
            # crash a few lines later with an opaque unbound-name error.
            raise ImportError(
                "please pip install groundingdino package, or you can refer to models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl"
            ) from err

        grounding_dino_config_path = cfg['CONFIG_PATH']
        grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
        # TODO: the tokenizer path is read (so the key stays required) but is
        # not yet passed to the model.
        grounding_dino_tokenizer_path = cfg['TOKENIZER_PATH']

        self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
        self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
        self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
        self.use_nms = cfg.get('USE_NMS', True)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = Model(
            model_config_path=grounding_dino_config_path,
            model_checkpoint_path=grounding_dino_checkpoint_path,
            device=self.device,
        )

    def forward(self, image, classes=None, caption=None):
        """Detect objects in ``image`` described by ``classes`` or ``caption``.

        Args:
            image: Any input accepted by ``convert_to_numpy``. Its last axis
                is reversed before inference (assumed RGB -> BGR; confirm
                against ``convert_to_numpy``).
            classes: A class name or a list of class names. Takes precedence
                over ``caption`` when both are given.
            caption: Free-form text prompt, used only when ``classes`` is
                ``None``.

        Returns:
            dict with keys ``boxes``, ``confidences``, ``class_ids`` and
            ``class_names``. Array values are converted to lists; a value is
            ``None`` when the underlying detection field is ``None``
            (e.g. ``class_ids`` in caption mode).

        Raises:
            NotImplementedError: If both ``classes`` and ``caption`` are
                ``None``.
        """
        image_bgr = convert_to_numpy(image)[..., ::-1]  # bgr

        phrases = None
        if classes is not None:
            if isinstance(classes, str):
                classes = [classes]
            detections = self.model.predict_with_classes(
                image=image_bgr,
                classes=classes,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold,
            )
        elif caption is not None:
            detections, phrases = self.model.predict_with_caption(
                image=image_bgr,
                caption=caption,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold,
            )
        else:
            raise NotImplementedError(
                "either `classes` or `caption` must be provided"
            )

        if self.use_nms:
            keep = torchvision.ops.nms(
                torch.from_numpy(detections.xyxy),
                torch.from_numpy(detections.confidence),
                self.iou_threshold,
            ).numpy().tolist()
            detections.xyxy = detections.xyxy[keep]
            detections.confidence = detections.confidence[keep]
            if detections.class_id is not None:
                detections.class_id = detections.class_id[keep]
            if phrases is not None:
                # NMS both drops and reorders detections; apply the same
                # selection to the phrases so names stay aligned with boxes.
                phrases = [phrases[i] for i in keep]

        boxes = detections.xyxy
        confidences = detections.confidence
        class_ids = detections.class_id

        if classes is not None:
            # NOTE(review): upstream GroundingDINO appears to emit a None
            # class id for phrases matching no requested class; map those to
            # a None name instead of failing on `classes[None]` -- confirm.
            class_names = [
                classes[_id] if _id is not None else None for _id in class_ids
            ]
        else:
            class_names = phrases

        ret_data = {
            "boxes": boxes.tolist() if boxes is not None else None,
            "confidences": confidences.tolist() if confidences is not None else None,
            "class_ids": class_ids.tolist() if class_ids is not None else None,
            "class_names": class_names if class_names is not None else None,
        }
        return ret_data


class GDINORAMAnnotator:
    """Pipeline that feeds RAM tagging output into GroundingDINO as classes."""

    def __init__(self, cfg, device=None):
        """Build both sub-annotators.

        Args:
            cfg: Mapping with a ``RAM`` entry (config for ``RAMAnnotator``)
                and a ``GDINO`` entry (config for ``GDINOAnnotator``).
            device: Device forwarded unchanged to both sub-annotators.
        """
        from .ram import RAMAnnotator
        from .gdino import GDINOAnnotator
        self.ram_model = RAMAnnotator(cfg['RAM'], device=device)
        self.gdino_model = GDINOAnnotator(cfg['GDINO'], device=device)

    def forward(self, image):
        """Tag ``image`` with RAM, then detect boxes for those tags.

        Args:
            image: Input passed to both the RAM and the GroundingDINO
                annotators.

        Returns:
            The result dict produced by ``GDINOAnnotator.forward``.
        """
        ram_res = self.ram_model.forward(image)
        # RAM may return a dict (tags under 'tag_e') or the tags directly.
        classes = ram_res['tag_e'] if isinstance(ram_res, dict) else ram_res
        gdino_res = self.gdino_model.forward(image, classes=classes)
        return gdino_res