File size: 3,680 Bytes
690f890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import torch
import numpy as np
import torchvision
from .utils import convert_to_numpy


class GDINOAnnotator:
    """Open-vocabulary box detector wrapping the GroundingDINO ``Model`` API.

    Detections are produced either from a list of class names
    (``Model.predict_with_classes``) or from a free-text caption
    (``Model.predict_with_caption``), optionally de-duplicated with
    ``torchvision.ops.nms``.
    """

    def __init__(self, cfg, device=None):
        """Load the GroundingDINO model described by ``cfg``.

        Args:
            cfg: Mapping with required keys ``CONFIG_PATH`` and
                ``PRETRAINED_MODEL``; optional keys ``TOKENIZER_PATH``
                (stored but not yet used), ``BOX_THRESHOLD`` (default 0.25),
                ``TEXT_THRESHOLD`` (default 0.2), ``IOU_THRESHOLD``
                (default 0.5) and ``USE_NMS`` (default True).
            device: Device passed to ``Model``; defaults to CUDA when
                available, otherwise CPU.

        Raises:
            ImportError: if the ``groundingdino`` package cannot be imported.
            KeyError: if a required ``cfg`` key is missing.
        """
        try:
            from groundingdino.util.inference import Model
        except ImportError as err:
            # Fail here with an actionable message instead of warning and then
            # crashing later with a NameError on ``Model``.
            raise ImportError(
                "please pip install groundingdino package, or you can refer to "
                "models/VACE-Annotators/gdino/groundingdino-0.1.0-cp310-cp310-linux_x86_64.whl"
            ) from err

        grounding_dino_config_path = cfg['CONFIG_PATH']
        grounding_dino_checkpoint_path = cfg['PRETRAINED_MODEL']
        # TODO: the tokenizer path is not forwarded to ``Model`` yet, so it is
        # optional rather than a required key.
        self.tokenizer_path = cfg.get('TOKENIZER_PATH')
        self.box_threshold = cfg.get('BOX_THRESHOLD', 0.25)
        self.text_threshold = cfg.get('TEXT_THRESHOLD', 0.2)
        self.iou_threshold = cfg.get('IOU_THRESHOLD', 0.5)
        self.use_nms = cfg.get('USE_NMS', True)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = Model(model_config_path=grounding_dino_config_path,
                           model_checkpoint_path=grounding_dino_checkpoint_path,
                           device=self.device)

    def forward(self, image, classes=None, caption=None):
        """Detect objects in ``image`` prompted by class names or a caption.

        Args:
            image: Input accepted by ``convert_to_numpy``. The last axis is
                reversed to produce BGR, so this presumably expects RGB
                channel order — confirm against callers.
            classes: A class name or a sequence of class names. Takes
                precedence over ``caption`` when both are given.
            caption: Free-text prompt, used only when ``classes`` is None.

        Returns:
            dict with keys ``boxes`` (list of [x1, y1, x2, y2]),
            ``confidences``, ``class_ids`` and ``class_names``. Array-backed
            entries are None when the detector returned None for them. In
            class mode a None class id maps to a None class name; in caption
            mode ``class_names`` holds the predicted phrases, kept aligned
            with ``boxes`` after NMS.

        Raises:
            NotImplementedError: if neither ``classes`` nor ``caption`` is given.
        """
        # Reverse the channel axis to BGR; the reversed view has a negative
        # stride, so materialize a contiguous copy for downstream consumers.
        image_bgr = np.ascontiguousarray(convert_to_numpy(image)[..., ::-1])

        phrases = None
        if classes is not None:
            classes = [classes] if isinstance(classes, str) else list(classes)
            detections = self.model.predict_with_classes(
                image=image_bgr,
                classes=classes,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold
            )
        elif caption is not None:
            detections, phrases = self.model.predict_with_caption(
                image=image_bgr,
                caption=caption,
                box_threshold=self.box_threshold,
                text_threshold=self.text_threshold
            )
        else:
            raise NotImplementedError("either `classes` or `caption` must be provided")

        if self.use_nms:
            detections, phrases = self._apply_nms(detections, phrases)

        boxes = detections.xyxy
        confidences = detections.confidence
        class_ids = detections.class_id
        class_names = self._resolve_class_names(class_ids, classes, phrases)

        return {
            "boxes": boxes.tolist() if boxes is not None else None,
            "confidences": confidences.tolist() if confidences is not None else None,
            "class_ids": class_ids.tolist() if class_ids is not None else None,
            "class_names": class_names,
        }

    def _apply_nms(self, detections, phrases):
        """Drop boxes overlapping a higher-scoring box above ``self.iou_threshold``.

        Filters ``detections.xyxy``, ``detections.confidence`` and (when not
        None) ``detections.class_id`` in place — no other fields are touched —
        and returns ``(detections, phrases)`` with ``phrases`` filtered to the
        same kept indices so it stays aligned with the boxes.
        """
        # nms needs boxes and scores of the same floating dtype.
        keep = torchvision.ops.nms(
            torch.from_numpy(detections.xyxy).float(),
            torch.from_numpy(detections.confidence).float(),
            self.iou_threshold,
        ).numpy().tolist()
        detections.xyxy = detections.xyxy[keep]
        detections.confidence = detections.confidence[keep]
        if detections.class_id is not None:
            detections.class_id = detections.class_id[keep]
        if phrases is not None:
            phrases = [phrases[i] for i in keep]
        return detections, phrases

    @staticmethod
    def _resolve_class_names(class_ids, classes, phrases):
        """Map class ids to names in class mode, or return ``phrases`` in caption mode.

        When ``classes`` is None, ``phrases`` is returned unchanged. Otherwise
        each id indexes into ``classes``; a None id maps to None, and a None
        ``class_ids`` yields None.
        """
        if classes is None:
            return phrases
        if class_ids is None:
            return None
        return [classes[_id] if _id is not None else None for _id in class_ids]


class GDINORAMAnnotator:
    """Two-stage annotator: RAM proposes tags, GroundingDINO localizes them.

    The tags produced by ``RAMAnnotator`` are passed as the ``classes``
    prompt of ``GDINOAnnotator``.
    """

    def __init__(self, cfg, device=None):
        """Construct the RAM tagger followed by the GroundingDINO detector.

        Args:
            cfg: Mapping holding a ``RAM`` sub-config and a ``GDINO``
                sub-config, each handed to the matching annotator.
            device: Forwarded unchanged to both annotators.
        """
        from .ram import RAMAnnotator
        from .gdino import GDINOAnnotator
        ram_cfg = cfg['RAM']
        self.ram_model = RAMAnnotator(ram_cfg, device=device)
        gdino_cfg = cfg['GDINO']
        self.gdino_model = GDINOAnnotator(gdino_cfg, device=device)

    def forward(self, image):
        """Tag ``image`` with RAM, then detect those tags with GroundingDINO.

        If RAM returns a dict, its ``'tag_e'`` entry is used as the class
        list; otherwise the RAM result itself is used.

        Returns:
            Whatever ``GDINOAnnotator.forward`` returns for those classes.
        """
        tags = self.ram_model.forward(image)
        if isinstance(tags, dict):
            # NOTE(review): presumably 'tag_e' holds the English tag list — confirm.
            tags = tags['tag_e']
        return self.gdino_model.forward(image, classes=tags)