import torch
from torch import nn
import numpy as np
from typing import Optional, Tuple, List, Union
from transformers import Qwen2VLForConditionalGeneration
import logging
import warnings
from PIL import Image
from transformers.image_utils import load_image

logger = logging.getLogger(__name__)

LOGIT_BIAS = 2.65  # subtracted from raw logits before the sigmoid when normalizing scores to [0, 1]

def load_images(images, lazy_load: bool = True):
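    """Open a mixed batch of PIL images, local paths, or URLs as PIL images.

    With lazy_load=False, each image is fully decoded and its file handle is
    closed immediately, avoiding "Too many open files" errors at the cost of
    holding the decoded pixels in memory.
    """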
    # Temporarily disable PIL's DecompressionBomb threshold so large images can be read.
    pil_max_px = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None

    images_batch = []
    for image in images:
        if isinstance(image, Image.Image):
            images_batch.append(image)
        else:
            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
            else:
                # avoid Too many open files error
                images_batch.append(pil_image.copy())
                pil_image.close()
    Image.MAX_IMAGE_PIXELS = pil_max_px

    return images_batch


def formatting_prompts_func(
    query: str,
    doc: str,
    query_type: str = 'text',
    doc_type: str = 'text',
    prefix_str: str = '',
) -> str:
    """
    Format prompts for different combinations of query and content types.

    Args:
        query: Query text, or a query image (path/URL) when query_type='image'
        doc: Document text, or a document image (path/URL) when doc_type='image'
        query_type: 'text' or 'image'
        doc_type: 'text' or 'image'
        prefix_str: Optional prefix prepended to the prompt
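
    Returns:
        The document block followed by the query block; in the default
        text/text case:

            **Document**:
            {doc}
            **Query**:
            {query}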
    """
    # Format query part
    if query_type == 'image':
        query_part = "**Query**:\n<|vision_start|><|image_pad|><|vision_end|>"
    else:
        query_part = f"**Query**:\n{query}"

    # Format content part
    if doc_type == 'image':
        doc_part = "**Document**:\n<|vision_start|><|image_pad|><|vision_end|>"
    else:
        doc_part = f"**Document**:\n{doc}"

    # Combine parts
    prompt = doc_part + '\n' + query_part

    # Add prefix if provided
    if prefix_str:
        prompt = prefix_str + '\n' + prompt

    return prompt


class JinaVLForRanking(Qwen2VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)

        self.padding_side = "left"
        self.num_labels = 1  # a single relevance score per input pair

        # replace the lm_head with an identity: only the hidden states are needed, not vocab logits
        self.lm_head = nn.Identity()

        # following `Qwen2ForRewardModel`, use a small MLP head to produce the final score
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )

        # Initialize weights and apply final processing
        self.post_init()

        # id of the token appended at the end of each sequence; its final
        # hidden state is what the score head reads (see compute_score)
        self.score_token_id = 100

    def forward(self, *args, **kwargs) -> torch.Tensor:
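        """Run the backbone and return one ranking logit per sequence,
        pooled from the hidden state at the last (score-token) position."""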
        # drop caller-supplied flags that are forced to fixed values below
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"

        outputs = super().forward(
            *args,
            use_cache=False,
            output_hidden_states=True,
            **kwargs,
        )

        # get the hidden states of the last layer
        hidden_states = outputs.hidden_states[-1]

        # IMPORTANT: the padding token must be on the left side
        # get the hidden states of the last token and apply the linear layer
        pooled_logits = self.score(hidden_states[:, -1])

        return pooled_logits.squeeze(-1)

    @torch.no_grad()
    def compute_score(
        self,
        pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 8,
        max_length: int = 10240,
        max_query_length: int = 512,
        max_doc_length: Optional[int] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        normalize_scores: bool = True,
        show_progress: bool = False,
    ) -> Union[List[float], float]:
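        """Score (query, doc) pairs and return one relevance score per pair.

        Text inputs are truncated to max_query_length / max_doc_length
        tokens. When normalize_scores is True, raw logits are mapped to
        [0, 1] with a sigmoid shifted by LOGIT_BIAS. A single input pair
        yields a bare float instead of a list.
        """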

        if not hasattr(self, "_processor"):
            from transformers import AutoProcessor

            self._processor = AutoProcessor.from_pretrained(
                self.name_or_path, max_pixels=602112, min_pixels=3136, trust_remote_code=True
            )

        assert isinstance(pairs, (list, tuple)) and len(pairs) > 0

        # a single flat (query, doc) pair is wrapped into a batch of one
        if isinstance(pairs[0], str):
            pairs = [pairs]

        max_length = max_length or self.config.max_length

        if max_doc_length is None:
            max_doc_length = max(max_length - max_query_length, max_query_length)

        if max_doc_length < max_query_length:
            warnings.warn(
                f"max_doc_length={max_doc_length} should be greater than max_query_length={max_query_length}"
            )

        assert (
            max_doc_length + max_query_length <= max_length
        ), f"max_doc_length ({max_doc_length}) + max_query_length ({max_query_length}) must not exceed max_length ({max_length})"

        # reserve one position for the score token appended after tokenization
        max_length = max_length - 1

        all_scores = []

        device = next(self.parameters()).device

        batch_iter = range(0, len(pairs), batch_size)
        if show_progress:
            from tqdm import trange

            batch_iter = trange(0, len(pairs), batch_size, desc="Computing scores")

        for start_index in batch_iter:
            mini_batch = pairs[start_index : start_index + batch_size]

            batch_inputs = []
            for q, d in mini_batch:
                # TEMP FIX: pre-truncate text documents to max_doc_length tokens,
                # since the processor's truncation only bounds the combined length
                if doc_type == 'text':
                    tokens = self._processor.tokenizer(d, truncation=True, max_length=max_doc_length)
                    if len(tokens['input_ids']) >= max_doc_length:
                        d = self._processor.tokenizer.decode(tokens['input_ids'])

                batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

            batch_images = None

            doc_images = []
            query_images = []
            if doc_type == 'image':
                doc_images = load_images([d for (q, d) in mini_batch])
            if query_type == 'image':
                query_images = load_images([q for (q, d) in mini_batch])

            # when both sides are images, the order must match the placeholders
            # in the prompt: document image first, then query image
            if len(doc_images) == len(query_images) and len(doc_images) > 0:
                batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
            elif len(doc_images) > 0:
                batch_images = doc_images
            elif len(query_images) > 0:
                batch_images = query_images

            batch = self._processor(
                text=batch_inputs,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )

            # append the score token to input_ids and extend the attention mask
            # (use a local name to avoid shadowing the batch_size argument)
            num_rows = batch["input_ids"].size(0)
            batch["input_ids"] = torch.cat(
                [
                    batch["input_ids"],
                    torch.full((num_rows, 1), self.score_token_id, device=batch["input_ids"].device),
                ],
                dim=1,
            )
            batch["attention_mask"] = torch.cat(
                [
                    batch["attention_mask"],
                    torch.ones(
                        (num_rows, 1),
                        dtype=batch["attention_mask"].dtype,
                        device=batch["attention_mask"].device,
                    ),
                ],
                dim=1,
            )
            # move the batch to the correct device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            scores = self.forward(**batch).view(-1).cpu().float().numpy()

            if normalize_scores:
                # map raw logits to [0, 1] with a sigmoid shifted by LOGIT_BIAS
                scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))

            all_scores.extend(scores.tolist())

        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores
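

if __name__ == "__main__":
    # Minimal usage sketch. "path/to/reranker-checkpoint" is a placeholder
    # (an assumption), standing in for a checkpoint trained with this class.
    model = JinaVLForRanking.from_pretrained(
        "path/to/reranker-checkpoint",
        torch_dtype=torch.bfloat16,
        device_map="auto",  # requires `accelerate`; alternatively .to("cuda")
    )
    model.eval()

    pairs = [
        ("What is the capital of France?", "Paris is the capital of France."),
        ("What is the capital of France?", "Berlin is the capital of Germany."),
    ]
    scores = model.compute_score(pairs, batch_size=2)
    print(scores)  # two floats in [0, 1]; the first pair should score higher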