# Apollo-LMMs-Apollo-7B-t32 / mm_connector/configuration_connector.py
# Sri-Vigneshwar-DJ's picture
# Upload folder using huggingface_hub
# 864bc3e verified
# raw
# history blame contribute delete
# 1.31 kB
import torch
import torch.nn as nn
from typing import Dict, List, Union
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel
import torch.nn.functional as F
import json, os
class ConnectorConfig(PretrainedConfig):
    """Configuration object for the ``mm_connector`` model type.

    The constructor stores its named arguments as attributes of the same
    name and forwards every remaining keyword argument to
    ``PretrainedConfig.__init__``.

    Args:
        vision_hidden_size: List of ints. Presumably the hidden size(s) of
            the vision encoder(s) feeding the connector -- TODO confirm
            against the connector model. ``None`` (the default) yields a
            fresh empty list.
        text_hidden_size: Int, default ``0``. Presumably the hidden size of
            the language model -- TODO confirm.
        num_patches: Int, default ``24``.
        rms_norm_eps: Float, default ``1e-4``. Presumably the epsilon of an
            RMS normalisation layer -- TODO confirm.
        token_input_shape: List of ints. ``None`` (the default) yields a
            fresh empty list.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mm_connector"

    def __init__(
        self,
        vision_hidden_size: Union[List[int], None] = None,
        text_hidden_size: int = 0,
        num_patches: int = 24,
        rms_norm_eps: float = 1e-4,
        token_input_shape: Union[List[int], None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # A literal ``[]`` default would be created once and then shared by
        # every config built without the argument, so mutating one config's
        # list would silently change all the others. Build a new list per
        # instance instead.
        self.vision_hidden_size = [] if vision_hidden_size is None else vision_hidden_size
        self.text_hidden_size = text_hidden_size
        self.num_patches = num_patches
        self.rms_norm_eps = rms_norm_eps
        self.token_input_shape = [] if token_input_shape is None else token_input_shape

    @classmethod
    def load_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "ConnectorConfig":
        """Build a config from a JSON file on disk.

        Despite its name, ``pretrained_model_name_or_path`` is handed
        straight to ``get_config_from_json`` (which ``open()``s it), so it
        must be the path of a JSON file, not a directory or a hub id.

        Args:
            pretrained_model_name_or_path: Path of the JSON config file.
            **kwargs: Passed through to ``cls.from_dict`` after the optional
                token normalisation below.

        Returns:
            The result of ``cls.from_dict`` on the parsed JSON.
        """
        # ``_set_token_in_kwargs`` is a private helper inherited from the
        # base class; look it up defensively so a base class that does not
        # provide it does not make loading fail with AttributeError.
        # NOTE(review): confirm which transformers versions still ship it.
        set_token = getattr(cls, "_set_token_in_kwargs", None)
        if set_token is not None:
            set_token(kwargs)
        config_dict, kwargs = cls.get_config_from_json(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_from_json(cls, config_file, **kwargs):
        """Parse ``config_file`` as JSON.

        Args:
            config_file: Path of the JSON file to read.
            **kwargs: Returned untouched alongside the parsed data.

        Returns:
            A ``(config_data, kwargs)`` tuple where ``config_data`` is
            whatever ``json.load`` produced.
        """
        # Explicit UTF-8: without it ``open`` falls back to the
        # locale-dependent platform encoding.
        with open(config_file, "r", encoding="utf-8") as file:
            config_data = json.load(file)
        return config_data, kwargs