import torch
from torch import nn
from typing import Optional, Tuple

from .multimodal_config import MultiModalConfig
from .multimodal_components import CausalLM, MultiModalProjector
from ..vision.siglip_model import SigLip
from ..utils.kv_cache import KVCache


class PaliGemmaForConditionalGeneration(nn.Module):
    """Vision-language model: a SigLIP vision tower, a multimodal projector, and a causal language model."""

    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.config = config
        self.vision_tower = SigLip(config.vision_config)
        self.multi_modal_projector = MultiModalProjector(config)
        self.vocab_size = config.vocab_size
        self.language_model = CausalLM(config.text_config)
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

    def tie_weights(self):
        # Share the input embedding matrix with the language model's output head.
        return self.language_model.tie_weights()

    def _merge_input_ids_with_image_features(
        self,
        image_features: torch.Tensor,
        inputs_embeds: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ):
        """Splice the projected image embeddings into the text embeddings and build the attention mask and position ids."""
        _, _, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        dtype, device = inputs_embeds.dtype, inputs_embeds.device

        # Scale the image features by 1/sqrt(hidden_size).
        scaled_image_features = image_features / (self.config.hidden_size**0.5)

        # [batch_size, sequence_length, embed_dim]: will hold the combined text, image, and padding embeddings.
        final_embedding = torch.zeros(batch_size, sequence_length, embed_dim, dtype=dtype, device=device)

        # Boolean masks over token positions: text tokens, <image> placeholder tokens, and padding tokens.
        text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
        image_mask = input_ids == self.config.image_token_index
        pad_mask = input_ids == self.pad_token_id

        # Expand to [batch_size, sequence_length, embed_dim] so they can select into the embedding tensor.
        text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask_expanded = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)

        # Copy the text embeddings into their positions.
        final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
        # Copy the image embeddings into the <image> token positions. masked_scatter is used because
        # scaled_image_features does not have the same sequence length as final_embedding.
        final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
        # Zero out the padding positions.
        final_embedding = torch.where(pad_mask_expanded, torch.zeros_like(final_embedding), final_embedding)

        # Build the attention mask. The input is never padded, so no position is masked out and the mask is
        # all zeros; min_dtype is the additive value that would mask a position if padding were supported.
        min_dtype = torch.finfo(dtype).min
        q_len = inputs_embeds.shape[1]

        if kv_cache is None or kv_cache.num_items() == 0:
            # Prefill phase: every query token attends to every key token (the prompt is not causally masked).
            causal_mask = torch.full((batch_size, q_len, q_len), fill_value=0, dtype=dtype, device=device)
        else:
            # Generation phase: a single new query token attends to everything already in the cache plus itself.
            assert q_len == 1
            kv_len = kv_cache.num_items() + q_len
            causal_mask = torch.full((batch_size, q_len, kv_len), fill_value=0, dtype=dtype, device=device)

        # Add the head dimension: [batch_size, q_len, kv_len] -> [batch_size, 1, q_len, kv_len].
        causal_mask = causal_mask.unsqueeze(1)

        if kv_cache is not None and kv_cache.num_items() > 0:
            # The position of the new token is the number of tokens attended to so far.
            position_ids = attention_mask.cumsum(-1)[:, -1]
            if position_ids.dim() == 1:
                position_ids = position_ids.unsqueeze(0)
        else:
            # Prefill: positions follow the cumulative attention mask; masked positions are set to 1.
            position_ids = (attention_mask.cumsum(-1)).masked_fill_((attention_mask == 0), 1).to(device)

        return final_embedding, causal_mask, position_ids

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple:
        assert torch.all(attention_mask == 1), "The input cannot be padded"

        # [batch_size, sequence_length, hidden_size]: embeddings for all tokens, including <image> placeholders.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # [batch_size, num_patches, vision_embed_dim]: patch embeddings from the vision tower.
        selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
        # [batch_size, num_patches, hidden_size]: project the patch embeddings into the language model's space.
        image_features = self.multi_modal_projector(selected_image_feature)

        # Replace the <image> placeholder embeddings with the projected image features.
        inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, kv_cache
        )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        return outputs
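

# A minimal usage sketch, kept as a comment so the module stays importable. It is not part of the model
# definition and makes assumptions not shown in this file: that MultiModalConfig exposes vision_config,
# text_config, vocab_size, hidden_size, image_token_index and pad_token_id; that KVCache() can be built
# with no arguments; and that a processor elsewhere produces input_ids containing <image> placeholder
# tokens followed by the text prompt. Field names and shapes below are illustrative only.
#
#     config = MultiModalConfig(...)                       # hypothetical construction; see multimodal_config.py
#     model = PaliGemmaForConditionalGeneration(config)
#     model.tie_weights()
#
#     kv_cache = KVCache()                                 # assumed no-argument constructor
#     outputs = model(
#         input_ids=input_ids,                             # [1, seq_len]; image tokens first, then the prompt
#         pixel_values=pixel_values,                       # [1, 3, H, W] preprocessed image
#         attention_mask=torch.ones_like(input_ids),       # padding is not supported
#         kv_cache=kv_cache,
#     )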