# Adapted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py
import types

import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F

from egogpt.utils import rank0_print

from .model import ModelDimensions, Whisper


def load_zero_partitions(
    model,
    state_dict,
    is_deepspeed_zero3_enabled,
    pretrained_model_path,
    ignore_mismatched_sizes=False,
):
    """
    Load a state dict into a model whose parameters may be partitioned by
    DeepSpeed ZeRO-3. Adapted from PyTorch Lightning and transformers.

    Typical usage::

        with deepspeed.zero.Init():
            model = MyModel()
        state_dict = torch.load(model_path, map_location="cpu")
        load_zero_partitions(
            model,
            state_dict,
            is_deepspeed_zero3_enabled=True,
            pretrained_model_path=model_path,
        )
    """
    # because zero3 puts placeholders in model params, this context
    # manager gathers (unpartitions) the params of the current layer, then loads from
    # the state dict and then re-partitions them again
    model_state_dict = model.state_dict()
    expected_keys = list(model_state_dict.keys())
    loaded_keys = list(state_dict.keys())
    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))
    # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
    # matching the weights in the model.
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key
            if (
                model_key in model_state_dict
                and state_dict[checkpoint_key].shape
                != model_state_dict[model_key].shape
            ):
                mismatched_keys.append(
                    (
                        checkpoint_key,
                        state_dict[checkpoint_key].shape,
                        model_state_dict[model_key].shape,
                    )
                )
                del state_dict[checkpoint_key]
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        if is_deepspeed_zero3_enabled:
            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            with deepspeed.zero.GatheredParameters(
                list(module.parameters(recurse=False)), modifier_rank=0
            ):
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(*args)
        else:
            module._load_from_state_dict(*args)
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ""
    model_to_load = model
    load(model_to_load, prefix=start_prefix)
    del state_dict

    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
        raise RuntimeError(
            f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
        )
    if len(unexpected_keys) > 0:
        rank0_print(
            f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
            " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
            " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        rank0_print(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )
    if len(missing_keys) > 0:
        rank0_print(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        rank0_print(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        rank0_print(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )


class WhisperWrappedEncoder(nn.Module):
    def __init__(self, config, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.speech_encoder_name = config.speech_encoder

        if not delay_load:
            rank0_print(f"Loading speech encoder: {self.speech_encoder_name}")
            self.load_model(config)

    def load_model(self, model_config):
        if self.is_loaded:
            print(
                "{} is already loaded, `load_model` called again, skipping.".format(
                    self.speech_encoder_name
                )
            )
            return

        def replace_layer_norm(module):
            # Whisper's custom LayerNorm upcasts activations to float32; swap it
            # for a plain nn.LayerNorm (with the same weights) so the encoder
            # keeps running in the surrounding working dtype.
            from whisper.model import LayerNorm

            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        # import whisper
        # self.encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        checkpoint = torch.load(self.speech_encoder_name, map_location="cpu")
        dims = ModelDimensions(**checkpoint["dims"])
        model = Whisper(dims)
        # Assume ZeRO-3 partitioning is active when loading the checkpoint.
        deepspeed3_enabled = True
        # print(deepspeed3_enabled)
        load_zero_partitions(
            model,
            checkpoint["model_state_dict"],
            deepspeed3_enabled,
            self.speech_encoder_name,
        )
        # Keep only the audio encoder and freeze it.
        self.encoder = model.encoder
        replace_layer_norm(self.encoder)
        self.encoder.requires_grad_(False)
        self.is_loaded = True

    def forward(self, audio):
        return self.encoder(audio)
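

# --- Usage sketch (illustrative, not part of the adapted SLAM-LLM code) ---
# A minimal example of how the wrapper above might be driven inside a
# DeepSpeed ZeRO-3 run (load_zero_partitions calls torch.distributed.get_rank(),
# so a process group must already be initialized). The checkpoint path and the
# `config` object are assumptions: any object exposing a `speech_encoder`
# attribute that points at an openai-whisper style checkpoint (a dict with
# "dims" and "model_state_dict" entries) would work the same way.
#
#     import types
#     import torch
#
#     config = types.SimpleNamespace(speech_encoder="/path/to/whisper_checkpoint.pt")
#     speech_encoder = WhisperWrappedEncoder(config)
#     mel = torch.zeros(1, 80, 3000)  # (batch, n_mels, frames); 30 s of log-mel input
#     with torch.no_grad():
#         feats = speech_encoder(mel)  # (batch, frames // 2, d_model) encoder states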