import warnings
from typing import Any, List, Optional, Tuple, Union

import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, logging
import torch.nn.functional as F

from .configuration_llavaqw import LlavaQwConfig
from .conversation import get_conv_template
from .modeling_intern_vit import InternVisionModel, has_flash_attn

logger = logging.get_logger(__name__)

# Text marker for where an image goes inside a question. The chat helpers replace
# occurrences left to right, one per image, with that image's context tokens.
# It must be non-empty: `'' in s` is always True and `s.replace('', x, 1)` inserts
# at index 0, i.e. in front of the system prompt.
_IMAGE_PLACEHOLDER = '<image>'


def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a function from :mod:`operator`.

    Args:
        v1: Left-hand version string.
        v2: Right-hand version string.
        op: Name of an ``operator`` function, e.g. ``'eq'`` or ``'ge'``.

    Returns:
        ``operator.<op>(version.parse(v1), version.parse(v2))``.
    """
    import operator

    from packaging import version

    op_func = getattr(operator, op)
    return op_func(version.parse(v1), version.parse(v2))


class LlavaQwModel(PreTrainedModel):
    """Vision encoder + MLP projector + Qwen2 causal language model.

    Features from ``vision_model`` are pixel-shuffled, projected by ``mlp1`` into
    the LLM embedding width, and written over the embeddings of the prompt
    positions whose token id equals ``self.img_context_token_id``.
    """

    config_class = LlavaQwConfig
    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    _no_split_modules = ['InternVisionModel', 'Qwen2DecoderLayer']

    def __init__(self, config: LlavaQwConfig, vision_model=None, language_model=None, use_flash_attn=True):
        """Build the vision tower, projector and language model.

        Args:
            config: Model configuration (vision, LLM and projector settings).
            vision_model: Optional pre-built vision tower; otherwise an
                ``InternVisionModel`` is created from ``config.vision_config``.
            language_model: Optional pre-built LLM; otherwise a
                ``Qwen2ForCausalLM`` is created from ``config.llm_config``.
            use_flash_attn: Request flash attention; honoured only when
                ``has_flash_attn`` is true.

        Raises:
            NotImplementedError: if no ``language_model`` is given and the LLM
                architecture is not ``Qwen2ForCausalLM``.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.44.2', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.llm_arch_name = config.llm_config.architectures[0]
        self.template = config.template
        # Tokens per image tile: patches per side squared, scaled by the pixel-shuffle ratio squared.
        self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
        self.downsample_ratio = config.downsample_ratio
        self.ps_version = config.ps_version

        # Flash attention only when requested AND importable.
        use_flash_attn = bool(use_flash_attn and has_flash_attn)
        config.vision_config.use_flash_attn = use_flash_attn
        config.llm_config._attn_implementation = 'flash_attention_2' if use_flash_attn else 'eager'

        logger.info(f'num_image_token: {self.num_image_token}')
        logger.info(f'ps_version: {self.ps_version}')

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            self.vision_model = InternVisionModel(config.vision_config)

        if language_model is not None:
            self.language_model = language_model
        elif self.llm_arch_name == 'Qwen2ForCausalLM':
            self.language_model = Qwen2ForCausalLM(config.llm_config)
        else:
            raise NotImplementedError(f'{self.llm_arch_name} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_intermediate_size = config.llm_config.intermediate_size
        llm_hidden_size = config.llm_config.hidden_size
        # pixel_shuffle folds a (1 / ratio)^2 neighbourhood into the channel axis,
        # so the projector input is that many times wider than the ViT hidden size.
        shuffled_size = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(shuffled_size),
            nn.Linear(shuffled_size, llm_intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
        )

        # Default image-context token id; presumably '<|vision_pad|>' in the Qwen2
        # vocabulary (TODO confirm). The chat* methods overwrite it with the id the
        # tokenizer reports for their IMG_CONTEXT_TOKEN.
        self.img_context_token_id = 151654
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Forward pass with image features spliced into the prompt embeddings.

        Args:
            pixel_values: Image tiles passed to ``extract_feature``.
            input_ids: Token ids of shape (B, N); positions equal to
                ``self.img_context_token_id`` receive the projected image features.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded to the LLM.
            image_flags: Accepted for interface compatibility; not used here.
            labels: Optional targets aligned with ``input_ids`` for next-token
                cross-entropy (shifted by one inside).
            return_dict: Return ``CausalLMOutputWithPast`` (default from config)
                instead of a tuple.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple ``((loss,) logits, *rest)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        selected = input_ids.reshape(B * N) == self.img_context_token_id
        try:
            # `* 0.0 +` presumably keeps the replaced embedding rows in the autograd
            # graph (zero gradient) rather than detaching them — TODO confirm intent.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
        except Exception as e:
            # Deliberate best effort when the number of image-context tokens and image
            # features disagree: fill the first `n_token` feature rows and continue.
            vit_embeds = vit_embeds.reshape(-1, C)
            logger.warning(f'{e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                           f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Labels may live on another device under model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        Args:
            x: 4-D tensor unpacked as (N, W, H, C).
            scale_factor: Spatial scale; channels grow by ``1 / scale_factor ** 2``.

        Returns:
            (N, W * s, H * s, C / s^2), or (N, H * s, W * s, C / s^2) when
            ``self.ps_version == 'v1'`` (a warning is emitted in that case).
        """
        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        if self.ps_version == 'v1':
            warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                          'which results in a transposed image.')
        else:
            x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode image tiles and project them to the LLM embedding width.

        Uses the vision tower's last hidden state when ``select_layer == -1``,
        otherwise ``hidden_states[select_layer]``. The first token of each tile is
        dropped (presumably a CLS token — TODO confirm), the rest are assumed to
        form a square grid, pixel-shuffled, flattened and passed through ``mlp1``.

        Returns:
            Tensor of shape (num_tiles, tokens_per_tile, llm_hidden_size).
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        return vit_embeds

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _prepend_image_placeholders(question, num_patches_list):
        """Prefix *question* with one placeholder line per image unless it already has one."""
        if _IMAGE_PLACEHOLDER not in question:
            question = f'{_IMAGE_PLACEHOLDER}\n' * len(num_patches_list) + question
        return question

    def _expand_image_placeholders(self, query, num_patches_list, IMG_START_TOKEN,
                                   IMG_END_TOKEN, IMG_CONTEXT_TOKEN):
        """Replace image placeholders in *query*, one per entry of *num_patches_list*.

        Each image becomes ``IMG_START_TOKEN`` + ``IMG_CONTEXT_TOKEN`` repeated
        ``num_image_token`` times per tile + ``IMG_END_TOKEN``.
        """
        for num_patches in num_patches_list:
            # Per-tile identifiers are empty strings, so tile context tokens are
            # simply concatenated. NOTE(review): if textual tile markers (e.g.
            # '<tile_1>') were intended, insert them here — confirm with training data.
            # An image always gets at least one tile's worth of context tokens.
            n_context = self.num_image_token * max(num_patches, 1)
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * n_context + IMG_END_TOKEN
            query = query.replace(_IMAGE_PLACEHOLDER, image_tokens, 1)
        return query

    @staticmethod
    def _query_for_display(query, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN):
        """Collapse each image's token run in *query* back to the placeholder for printing."""
        shown = query.replace(IMG_CONTEXT_TOKEN, '')
        return shown.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', _IMAGE_PLACEHOLDER)

    def _render_prompt(self, question, history, system_message):
        """Render the conversation template for *history* followed by *question*.

        Returns:
            ``(template, prompt)``; the last message is ``roles[1]`` with a
            ``None`` reply, i.e. the slot the model should fill.
        """
        template = get_conv_template(self.template)
        template.system_message = system_message
        for old_question, old_answer in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        return template, template.get_prompt()

    def _generate_text(self, tokenizer, query, pixel_values, visual_features,
                       generation_config, eos_token_id, sep):
        """Tokenize *query*, run ``generate`` and return the first decoded reply up to *sep*."""
        model_inputs = tokenizer(query, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        # Copy so the caller's generation_config is not mutated.
        generation_kwargs = dict(generation_config, eos_token_id=eos_token_id)
        generation_output = self.generate(
            pixel_values=pixel_values,
            visual_features=visual_features,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_kwargs,
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        return response.split(sep)[0].strip()

    def _embed_prompt(self, input_ids, pixel_values=None, visual_features=None):
        """Embed *input_ids* and overwrite image-context positions with image features.

        Args:
            input_ids: Token ids of shape (B, N).
            pixel_values: Image tiles; when ``None`` the plain token embeddings are
                returned and *visual_features* is ignored.
            visual_features: Optional precomputed vision features; when given they
                are moved to CUDA and projected by ``mlp1`` instead of running
                ``extract_feature(pixel_values)``.

        Returns:
            Embeddings of shape (B, N, C).

        Raises:
            AssertionError: if images are supplied but the prompt contains no
                image-context token.
        """
        input_embeds = self.language_model.get_input_embeddings()(input_ids)
        if pixel_values is None:
            return input_embeds

        if visual_features is not None:
            vit_embeds = self.mlp1(visual_features.cuda())
        else:
            vit_embeds = self.extract_feature(pixel_values)

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)
        selected = input_ids.reshape(B * N) == self.img_context_token_id
        assert selected.sum() != 0, "No valid image context token IDs found."
        input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
        return input_embeds.reshape(B, N, C)

    # ------------------------------------------------------------------ chat API

    def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<|vision_start|>', IMG_END_TOKEN='<|vision_end|>',
             IMG_CONTEXT_TOKEN='<|vision_pad|>', verbose=False, visual_features=None):
        """Answer *question* (optionally about images) using the model's system prompt.

        Args:
            tokenizer: Tokenizer matching the language model.
            pixel_values: Image tiles, or ``None`` for a text-only turn.
            question: User text. On the first turn, if it has no ``<image>``
                placeholder, one placeholder line per image is prepended.
            generation_config: Mapping of keyword arguments for ``generate``; not
                mutated (``eos_token_id`` is set on a copy).
            history: List of ``(question, answer)`` pairs; the new pair is appended
                in place.
            return_history: Also return the updated history.
            num_patches_list: Tiles per image; defaults to all tiles as one image.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Image token strings.
            verbose: Print the ViT batch size and the (collapsed) prompt + reply.
            visual_features: Optional precomputed vision features (see ``generate``).

        Returns:
            The reply string, or ``(reply, history)`` if *return_history*.
        """
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
        if not history and pixel_values is not None:
            question = self._prepend_image_placeholders(question, num_patches_list)

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        history = [] if history is None else history
        template, query = self._render_prompt(question, history, self.system_message)
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        if verbose and pixel_values is not None:
            print(f'dynamic ViT batch size: {pixel_values.shape[0]}')

        query = self._expand_image_placeholders(
            query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_text(tokenizer, query, pixel_values, visual_features,
                                       generation_config, eos_token_id, template.sep)
        history.append((question, response))

        if verbose:
            print(self._query_for_display(query, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN), response)
        if return_history:
            return response, history
        return response

    def chat_without_sys_prompt(self, tokenizer, pixel_values, question, generation_config, history=None,
                                return_history=False, num_patches_list=None, IMG_START_TOKEN='<|vision_start|>',
                                IMG_END_TOKEN='<|vision_end|>', IMG_CONTEXT_TOKEN='<|vision_pad|>',
                                verbose=False, visual_features=None):
        """Like ``chat``, but the rendered prompt has its system prompt cut off.

        The template is rendered with a fixed system prompt, which is then sliced
        off the front of the prompt. Arguments and return value are as in ``chat``.
        """
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
        if not history and pixel_values is not None:
            question = self._prepend_image_placeholders(question, num_patches_list)

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        # Override the dummy system prompt, then strip it from the rendered prompt.
        system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
        history = [] if history is None else history
        template, query = self._render_prompt(question, history, system_prompt)
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)

        if verbose and pixel_values is not None:
            print(f'dynamic ViT batch size: {pixel_values.shape[0]}')

        # NOTE(review): assumes get_prompt() begins with system_prompt verbatim — confirm.
        query = query[len(system_prompt):]
        query = self._expand_image_placeholders(
            query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_text(tokenizer, query, pixel_values, visual_features,
                                       generation_config, eos_token_id, template.sep)
        history.append((question, response))

        if verbose:
            print(self._query_for_display(query, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN), response)
        if return_history:
            return response, history
        return response

    def chat_without_chat_prompt(self, tokenizer, pixel_values, question, generation_config,
                                 num_patches_list=None, IMG_START_TOKEN='<|vision_start|>',
                                 IMG_END_TOKEN='<|vision_end|>', IMG_CONTEXT_TOKEN='<|vision_pad|>',
                                 verbose=False, visual_features=None):
        """Generate from the raw *question* text with no conversation template.

        The template is consulted only for its separator (used as EOS and to cut
        the reply). Other arguments are as in ``chat``.

        Returns:
            The reply string.
        """
        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)
        if pixel_values is not None:
            question = self._prepend_image_placeholders(question, num_patches_list)

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        sep = get_conv_template(self.template).sep
        eos_token_id = tokenizer.convert_tokens_to_ids(sep)

        if verbose and pixel_values is not None:
            print(f'dynamic ViT batch size: {pixel_values.shape[0]}')

        query = self._expand_image_placeholders(
            question, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
        response = self._generate_text(tokenizer, query, pixel_values, visual_features,
                                       generation_config, eos_token_id, sep)

        if verbose:
            print(self._query_for_display(query, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN), response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate from a prompt whose image-context tokens carry image features.

        Args:
            pixel_values: Image tiles, or ``None`` for text-only generation.
            input_ids: Prompt token ids of shape (B, N).
            attention_mask: Prompt attention mask.
            visual_features: Optional precomputed vision features, projected by
                ``mlp1`` instead of running the vision tower (used only when
                *pixel_values* is not ``None``).
            generation_config, output_hidden_states, return_dict, **generate_kwargs:
                Forwarded to ``self.language_model.generate``.

        Returns:
            Whatever ``self.language_model.generate`` returns for ``inputs_embeds``.
        """
        input_embeds = self._embed_prompt(input_ids, pixel_values, visual_features)
        return self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )

    def chat_batch(
            self,
            tokenizer,
            pixel_values_list,
            questions,
            generation_config,
            histories=None,
            return_histories=False,
            num_patches_lists=None,
            IMG_START_TOKEN='<|vision_start|>',
            IMG_END_TOKEN='<|vision_end|>',
            IMG_CONTEXT_TOKEN='<|vision_pad|>',
            verbose=False,
            visual_features_list=None
    ):
        """Answer several questions in one batched ``generate_batch`` call.

        Args:
            tokenizer: Tokenizer matching the language model.
            pixel_values_list: Per-sample image tiles (entries may be ``None``).
            questions: Per-sample user text.
            generation_config: Mapping of keyword arguments for generation; not
                mutated (``eos_token_id`` is set on a copy).
            histories: Per-sample ``(question, answer)`` lists, appended in place.
            return_histories: Also return the updated histories.
            num_patches_lists: Optional per-sample tiles-per-image lists; by default
                each sample's tiles are treated as one image.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Image token strings.
            verbose: Print each (collapsed) prompt with its reply.
            visual_features_list: Optional per-sample precomputed vision features,
                forwarded to ``generate_batch``.

        Returns:
            List of replies, or ``(replies, histories)`` if *return_histories*.
        """
        if histories is None:
            histories = [[] for _ in questions]

        self.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        sep = get_conv_template(self.template).sep
        eos_token_id = tokenizer.convert_tokens_to_ids(sep)

        queries = []
        prompted_questions = []
        input_ids_list = []
        attention_mask_list = []
        for idx, question in enumerate(questions):
            history = histories[idx]
            pixel_values = pixel_values_list[idx]
            if num_patches_lists is not None and num_patches_lists[idx] is not None:
                num_patches_list = num_patches_lists[idx]
            else:
                num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []

            if not history and pixel_values is not None:
                question = self._prepend_image_placeholders(question, num_patches_list)
            # Keep the placeholder-bearing question so later turns can still locate the image.
            prompted_questions.append(question)

            _, query = self._render_prompt(question, history, self.system_message)
            if pixel_values is not None:
                query = self._expand_image_placeholders(
                    query, num_patches_list, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN)
            queries.append(query)

            model_inputs = tokenizer(query, return_tensors='pt', padding=True, truncation=True)
            input_ids_list.append(model_inputs['input_ids'].cuda())
            attention_mask_list.append(model_inputs['attention_mask'].cuda())

        generation_output = self.generate_batch(
            pixel_values_list=pixel_values_list,
            input_ids_list=input_ids_list,
            attention_mask_list=attention_mask_list,
            visual_features=visual_features_list,
            **dict(generation_config, eos_token_id=eos_token_id),
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        outputs = [response.split(sep)[0].strip() for response in responses]
        for idx, response in enumerate(outputs):
            histories[idx].append((prompted_questions[idx], response))

        if verbose:
            for query, response in zip(queries, outputs):
                print(self._query_for_display(query, IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN),
                      response)
        if return_histories:
            return outputs, histories
        return outputs

    @torch.no_grad()
    def generate_batch(
            self,
            pixel_values_list: Optional[List[torch.FloatTensor]] = None,
            input_ids_list: Optional[List[torch.LongTensor]] = None,
            attention_mask_list: Optional[List[torch.LongTensor]] = None,
            visual_features: Optional[Union[torch.FloatTensor, List[Optional[torch.FloatTensor]]]] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate for several prompts of different lengths in one call.

        Each prompt is embedded (with its images) separately, LEFT-padded with zero
        embeddings / zero mask to the longest prompt, and the batch is passed to
        ``self.language_model.generate``. Left padding keeps every prompt's last
        real token at the final position, where decoding continues.

        Args:
            pixel_values_list: Per-sample image tiles (entries may be ``None``).
            input_ids_list: Per-sample token ids, each of shape (1, N_i).
            attention_mask_list: Per-sample attention masks matching the ids.
            visual_features: Either one tensor applied to every sample, or a
                list/tuple with one entry (or ``None``) per sample.
            generation_config, output_hidden_states, return_dict, **generate_kwargs:
                Forwarded to ``self.language_model.generate``.

        Returns:
            Whatever ``self.language_model.generate`` returns for ``inputs_embeds``.
        """
        input_embeds_list = []
        attention_mask_padded_list = []
        max_seq_length = max(input_ids.shape[1] for input_ids in input_ids_list)
        per_sample_features = isinstance(visual_features, (list, tuple))

        for idx, (pixel_values, input_ids, attention_mask) in enumerate(
                zip(pixel_values_list, input_ids_list, attention_mask_list)):
            sample_features = visual_features[idx] if per_sample_features else visual_features
            input_embeds = self._embed_prompt(input_ids, pixel_values, sample_features)

            pad_size = max_seq_length - input_embeds.shape[1]
            if pad_size > 0:
                # (0, 0) leaves the hidden dim alone; (pad_size, 0) pads the sequence on the left.
                input_embeds = F.pad(input_embeds, (0, 0, pad_size, 0))
                attention_mask = F.pad(attention_mask, (pad_size, 0))

            input_embeds_list.append(input_embeds)
            attention_mask_padded_list.append(attention_mask)

        input_embeds = torch.cat(input_embeds_list, dim=0)
        attention_mask = torch.cat(attention_mask_padded_list, dim=0)

        return self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=True,
            **generate_kwargs,
        )