import torch
from torch import nn
from typing import Optional, Tuple

from .multimodal_config import MultiModalConfig
from ..utils.kv_cache import KVCache
from ..language.language_model import LanguageModel


class CausalLM(nn.Module):
    """Causal language model head over a `LanguageModel` backbone.

    Wraps the backbone with an untied-by-default linear LM head projecting
    hidden states to vocabulary logits; `tie_weights` shares the head's
    weight matrix with the input embedding table.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = LanguageModel(config)
        self.vocab_size = config.vocab_size
        # No bias: the head is meant to be tied to the embedding matrix,
        # which has no bias term.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def tie_weights(self):
        """Share the LM head weight with the input embedding weight."""
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> dict:
        """Run the backbone and project hidden states to vocab logits.

        Args:
            attention_mask: Optional attention mask forwarded to the backbone.
            position_ids: Optional position ids forwarded to the backbone.
            inputs_embeds: Pre-computed input embeddings (token ids are not
                accepted here; embedding happens upstream).
            kv_cache: Optional key/value cache for incremental decoding;
                mutated in place by the backbone.

        Returns:
            A dict with key ``"logits"`` (float32 tensor), plus key
            ``"kv_cache"`` echoing the cache when one was supplied.
        """
        # NOTE: the original annotated this as `-> Tuple`, but the function
        # has always returned a dict; the annotation is corrected here.
        hidden_states = self.model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            kv_cache=kv_cache,
        )
        logits = self.lm_head(hidden_states)
        # Upcast to float32 for numerically stable loss/softmax downstream.
        logits = logits.float()

        return_data = {
            "logits": logits,
        }
        if kv_cache is not None:
            # Echo the (in-place updated) cache back to the caller.
            return_data["kv_cache"] = kv_cache
        return return_data


class MultiModalProjector(nn.Module):
    """Linear projection from vision encoder hidden size to the
    multimodal projection dimension, bridging vision features into the
    language model's embedding space.
    """

    def __init__(self, config: MultiModalConfig):
        super().__init__()
        self.linear = nn.Linear(
            config.vision_config.hidden_size,
            config.vision_config.projection_dim,
            bias=True,
        )

    def forward(self, image_features):
        """Project image features; shape (..., hidden_size) -> (..., projection_dim)."""
        return self.linear(image_features)