# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
# not tested yet
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from transformers import CLIPImageProcessor

from .eva_clip import create_model_and_transforms, get_model_config

# from llava.utils import print


class EvaViTWrapper(nn.Module):
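    """Wraps an EVA-CLIP visual encoder so it can serve as a frozen vision tower.

    The wrapper loads the model via `create_model_and_transforms`, exposes a
    `CLIPImageProcessor` configured to match the EVA preprocessing (resize size,
    mean, std), and returns patch (or cls+patch) features from the encoder.
    """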
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.pretrained = args.vision_tower_pretrained
        self.args = args

        self.select_layer = args.mm_vision_select_layer
        if self.select_layer < -1:
            self.select_layer += 1
        self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

        self.model_config = get_model_config(self.vision_tower_name)

        if not delay_load:
            print(f"Loading vision tower: {vision_tower}")
            self.load_model()
        elif getattr(args, "unfreeze_mm_vision_tower", False):
            # TODO: better detector is needed.
            print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
            self.load_model()
        elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
            print(f"The checkpoint seems to contain `vision_tower` weights: `mm_tunable_parts` contains `mm_vision_tower`.")
            self.load_model()

    def load_model(self):
        print(f"Loading: {self.vision_tower_name}")
        print(f"Pretrained: {self.pretrained}")
        time_start = time.time()
        model, _, image_processor = create_model_and_transforms(self.vision_tower_name, self.pretrained, force_custom_clip=True, precision="fp16")
        time_end = time.time()
        print(f"Loaded: {self.vision_tower_name} in {time_end - time_start:.2f}s")
        self.dtype = next(model.parameters()).dtype
        # Move to GPU unless the model was created on a meta device (deferred init),
        # and record the device only after the move so it reflects where the weights live.
        if next(model.parameters()).device.type != "meta":
            model = model.to("cuda")
        self.device = next(model.parameters()).device
        self.vision_tower = model.visual
        resize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Resize)][0]
        normalize_transform = [t for t in image_processor.transforms if isinstance(t, torchvision.transforms.Normalize)][0]
        self.resize_transform_size = resize_transform.size
        self.image_processor = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-large-patch14",
            crop_size=resize_transform.size,
            size={"shortest_edge": resize_transform.size},
            image_mean=list(normalize_transform.mean),
            image_std=list(normalize_transform.std),
        )
        print(f"Loaded image processor: {self.image_processor}")
        self.vision_tower.requires_grad_(False)
        self.is_loaded = True

    def feature_select(self, image_features):
        select_feature_type = self.select_feature

        # if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
        #     select_every_k_layer = len(image_features) // 4
        #     image_features = torch.cat([image_features[i] for i in range(select_every_k_layer + self.select_layer, len(image_features), select_every_k_layer)], dim=-1)
        #     select_feature_type = select_feature_type.replace("slicefour_", "")
        # elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
        #     select_layers = [-1, -4, -7, -10, 6]
        #     image_features = torch.cat([image_features[i] for i in select_layers], dim=-1)
        #     select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
        # else:
        #     image_features = image_features[self.select_layer]

        if select_feature_type == "patch":
            image_features = image_features[:, 1:]
        elif select_feature_type == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {select_feature_type}")
        return image_features

    def train(self, mode=True):
        # Keep the frozen vision tower in eval mode even when the parent module is set to train.
        self.training = mode

        if self.is_loaded:
            self.vision_tower.eval()

    def forward(self, images):
        if isinstance(images, list):
            image_features = []
            for image in images:
                # Use a per-image variable so the accumulator list is not shadowed.
                image_feature = self.vision_tower.forward_features(image.to(self.dtype), return_all_features=True)
                image_feature = self.feature_select(image_feature).to(self.dtype)
                image_features.append(image_feature)
        else:
            image_features = self.vision_tower.forward_features(images.to(self.dtype), return_all_features=True)
            image_features = self.feature_select(image_features).to(self.dtype)

        return image_features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def hidden_size(self):
        return self.model_config["vision_cfg"]["width"]

    @property
    def num_patches(self):
        return (self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]) ** 2

    @property
    def num_patches_per_side(self):
        return self.model_config["vision_cfg"]["image_size"] // self.model_config["vision_cfg"]["patch_size"]

    @property
    def config(self):
        return self.model_config

    @property
    def image_size(self):
        return self.model_config["vision_cfg"]["image_size"]
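

# A minimal usage sketch, not part of the original module: it assumes an args-like
# namespace carrying the fields the wrapper reads (`vision_tower_pretrained`,
# `mm_vision_select_layer`, `mm_vision_select_feature`) and an EVA-CLIP model name
# known to `get_model_config`; the model name and checkpoint path below are placeholders.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        vision_tower_pretrained="eva_clip_weights.pt",  # placeholder checkpoint path
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    tower = EvaViTWrapper("EVA02-CLIP-L-14-336", args)  # model name is an assumption

    # Run a dummy batch through the tower; expected shape: [1, num_patches, hidden_size].
    dummy = torch.zeros(1, 3, tower.image_size, tower.image_size, dtype=tower.dtype, device=tower.device)
    with torch.no_grad():
        feats = tower(dummy)
    print(feats.shape)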