|
import torch |
|
from torch import nn |
|
import numpy as np |
|
from typing import Optional, Tuple, List, Union |
|
from transformers import Qwen2VLForConditionalGeneration |
|
import logging |
|
import warnings |
|
from PIL import Image |
|
from transformers.image_utils import load_image |
|
|
|
# Module-level logger following the standard per-module convention.
logger = logging.getLogger(__name__)



# Calibration offset subtracted from the raw relevance logit before the
# sigmoid in `compute_score` (score = sigmoid(logit - LOGIT_BIAS)), shifting
# the operating point of the normalized score. NOTE(review): the value 2.65
# is presumably tuned on held-out data during training — confirm.
LOGIT_BIAS = 2.65
|
|
|
def load_images(images, lazy_load: bool = True):
    """Resolve a mixed sequence of images into a list of PIL images.

    Args:
        images: Iterable whose items are either ``PIL.Image.Image`` instances
            (kept as-is) or anything accepted by transformers'
            ``load_image`` (e.g. a path or URL — see its docs).
        lazy_load: If True, append loaded images directly (underlying file
            handles may remain open). If False, copy the pixel data into
            memory and close the original image immediately.

    Returns:
        List of ``PIL.Image.Image`` objects, in input order.
    """
    # Temporarily lift PIL's decompression-bomb limit so very large document
    # scans are not rejected. Restore it in a `finally` so a failed load
    # cannot leave the process-wide safety check disabled (the original
    # only restored it on the success path).
    pil_max_px = Image.MAX_IMAGE_PIXELS
    Image.MAX_IMAGE_PIXELS = None
    try:
        images_batch = []
        for image in images:
            if isinstance(image, Image.Image):
                images_batch.append(image)
                continue
            pil_image = load_image(image)
            if lazy_load:
                images_batch.append(pil_image)
            else:
                # Fully materialize the pixels, then release the file handle.
                images_batch.append(pil_image.copy())
                pil_image.close()
    finally:
        Image.MAX_IMAGE_PIXELS = pil_max_px

    return images_batch
|
|
|
|
|
def formatting_prompts_func(
    query: str,
    doc: str,
    query_type: str = 'text',
    doc_type: str = 'text',
    prefix_str: str = '',
) -> str:
    """
    Build the ranking prompt for a (query, document) pair.

    The document section always precedes the query section; an image-typed
    side is represented by the Qwen2-VL vision placeholder tokens instead of
    its text.

    Args:
        query: Query text (or an image reference when ``query_type='image'``)
        doc: Document text (or an image reference when ``doc_type='image'``)
        query_type: 'image' to insert a vision placeholder for the query,
            anything else to inline the query text
        doc_type: 'image' to insert a vision placeholder for the document,
            anything else to inline the document text
        prefix_str: Optional prefix prepended on its own line
    """
    vision_placeholder = "<|vision_start|><|image_pad|><|vision_end|>"

    query_body = vision_placeholder if query_type == 'image' else query
    doc_body = vision_placeholder if doc_type == 'image' else doc

    sections = []
    if prefix_str:
        sections.append(prefix_str)
    sections.append(f"**Document**:\n{doc_body}")
    sections.append(f"**Query**:\n{query_body}")

    return '\n'.join(sections)
|
|
|
|
|
class JinaVLForRanking(Qwen2VLForConditionalGeneration):
    """Qwen2-VL backbone repurposed as a cross-encoder relevance ranker.

    The vocabulary head is replaced by an identity, and a small MLP
    (``self.score``) maps the final hidden state of an appended scoring
    token to a single relevance logit.
    """

    def __init__(self, config):
        super().__init__(config)

        # Left padding keeps the appended score token (see compute_score)
        # at the last position, directly after real content, for every row.
        self.padding_side = "left"
        self.num_labels = 1

        # The vocab projection is not needed for scoring; pass hidden
        # states through unchanged.
        self.lm_head = nn.Identity()

        # Two-layer MLP scoring head: hidden_size -> hidden_size -> 1.
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, self.num_labels),
        )

        self.post_init()

        # Token id appended to every sequence; its last-layer hidden state
        # is what the score head reads. NOTE(review): 100 presumably maps
        # to a reserved/unused token in the tokenizer vocab — confirm.
        self.score_token_id = 100

    def forward(self, *args, **kwargs) -> torch.Tensor:
        """Run the backbone and return one relevance logit per sequence.

        Returns:
            Tensor of shape (batch,) with raw (unnormalized) logits.
        """
        # These two are forced below; drop caller-supplied values to avoid
        # duplicate-keyword errors.
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)
        assert kwargs.pop("labels", None) is None, "labels should not be passed to forward()"

        outputs = super().forward(
            *args,
            use_cache=False,
            output_hidden_states=True,
            **kwargs,
        )

        # Last transformer layer: (batch, seq_len, hidden_size).
        hidden_states = outputs.hidden_states[-1]

        # With left padding, position -1 is the appended score token.
        pooled_logits = self.score(hidden_states[:, -1])

        return pooled_logits.squeeze(-1)

    @torch.no_grad()
    def compute_score(
        self,
        pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
        batch_size: int = 8,
        max_length: int = 10240,
        max_query_length: int = 512,
        max_doc_length: Optional[int] = None,
        query_type: str = 'text',
        doc_type: str = 'text',
        normalize_scores: bool = True,
        show_progress: bool = False,
    ) -> List[float]:
        """Score (query, document) pairs in mini-batches.

        Args:
            pairs: List of (query, doc) tuples, or a single (query, doc)
                tuple. Each side is text or an image reference depending on
                query_type / doc_type.
            batch_size: Number of pairs processed per forward pass.
            max_length: Total token budget per sequence (one position is
                reserved for the score token).
            max_query_length: Token budget assumed for the query side.
            max_doc_length: Token budget for the document side; defaults to
                max(max_length - max_query_length, max_query_length).
            query_type: 'text' or 'image'.
            doc_type: 'text' or 'image'.
            normalize_scores: If True, map logits through a bias-shifted
                sigmoid into (0, 1); if False, return raw logits. (The
                original ignored this flag and always normalized.)
            show_progress: If True, wrap the batch loop in a tqdm bar.

        Returns:
            List of float scores in input order; a bare float when exactly
            one pair was given, [] for an empty input list.
        """
        # Processor is loaded lazily on first use and cached on the instance.
        if not hasattr(self, "_processor"):
            from transformers import AutoProcessor

            self._processor = AutoProcessor.from_pretrained(
                self.name_or_path, max_pixels=602112, min_pixels=3136, trust_remote_code=True
            )

        assert isinstance(pairs, list)

        # Empty input: nothing to score (the original raised IndexError here).
        if not pairs:
            return []

        # Accept a single (query, doc) pair passed as a flat list/tuple.
        if isinstance(pairs[0], str):
            pairs = [pairs]

        max_length = max_length or self.config.max_length

        if max_doc_length is None:
            max_doc_length = max(max_length - max_query_length, max_query_length)

        if max_doc_length < max_query_length:
            warnings.warn(
                f"max_doc_length={max_doc_length} should be greater than max_query_length={max_query_length}"
            )

        assert (
            max_doc_length + max_query_length <= max_length
        ), f"max_doc_length ({max_doc_length}) + max_query_length ({max_query_length}) should be less than max_length ({max_length})"

        # Reserve one position for the appended score token.
        max_length = max_length - 1

        all_scores = []

        device = next(self.parameters()).device

        batch_iter = range(0, len(pairs), batch_size)
        if show_progress:
            from tqdm import trange

            batch_iter = trange(0, len(pairs), batch_size, desc="Computing scores")

        for start_index in batch_iter:
            mini_batch = pairs[start_index : start_index + batch_size]

            batch_inputs = []
            for q, d in mini_batch:
                # Pre-truncate long text documents so the document cannot
                # crowd the query out of the joint token budget.
                if doc_type == 'text':
                    tokens = self._processor.tokenizer(d, truncation=True, max_length=max_doc_length)
                    if len(tokens['input_ids']) >= max_doc_length:
                        d = self._processor.tokenizer.decode(tokens['input_ids'])

                batch_inputs.append(formatting_prompts_func(q, d, query_type=query_type, doc_type=doc_type))

            batch_images = None
            doc_images = []
            query_images = []
            if doc_type == 'image':
                doc_images = load_images([d for (q, d) in mini_batch])
            if query_type == 'image':
                query_images = load_images([q for (q, d) in mini_batch])

            # Image order must match the prompt layout: document first,
            # then query (see formatting_prompts_func).
            if len(doc_images) == len(query_images) and len(doc_images) > 0:
                batch_images = [[d, q] for q, d in zip(query_images, doc_images)]
            elif len(doc_images) > 0:
                batch_images = doc_images
            elif len(query_images) > 0:
                batch_images = query_images

            batch = self._processor(
                text=batch_inputs,
                images=batch_images,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )

            # Append the score token (plus a matching attention bit) to every
            # row. Use a distinct local: the original rebound `batch_size`
            # here, clobbering the caller's parameter still used to slice
            # `pairs` above.
            num_rows = batch["input_ids"].size(0)
            batch["input_ids"] = torch.cat(
                [
                    batch["input_ids"],
                    torch.full(
                        (num_rows, 1),
                        self.score_token_id,
                        dtype=batch["input_ids"].dtype,
                        device=batch["input_ids"].device,
                    ),
                ],
                dim=1,
            )
            # Pin the dtype: a default-float `ones` would make torch.cat
            # promote the integer attention mask to float32.
            batch["attention_mask"] = torch.cat(
                [
                    batch["attention_mask"],
                    torch.ones(
                        (num_rows, 1),
                        dtype=batch["attention_mask"].dtype,
                        device=batch["attention_mask"].device,
                    ),
                ],
                dim=1,
            )

            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            scores = self.forward(**batch).view(-1).cpu().float().numpy()

            if normalize_scores:
                # Bias-shifted sigmoid: score = sigmoid(logit - LOGIT_BIAS).
                scores = 1.0 / (1.0 + np.exp(-(scores - LOGIT_BIAS)))

            all_scores.extend(scores.tolist())

        # Convenience: a single pair yields a bare float, not a list.
        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores
|
|