from transformers import PretrainedConfig
from typing import List


class ViTNQSConfig(PretrainedConfig):
    """Hugging Face configuration object for the ``vit_nqs`` model type.

    Each model-specific hyperparameter below is stored as an attribute of the
    same name. Every other keyword argument is forwarded unchanged to
    ``transformers.PretrainedConfig.__init__``.

    The meanings given for the arguments are inferred from their names only.
    The model code that consumes this config is not visible here, so confirm
    them against it.

    Args:
        L_eff: Presumably an effective length / number of tokens or patches
            (TODO confirm). Defaults to 25.
        num_layers: Presumably the number of transformer layers. Defaults to 8.
        d_model: Presumably the model (embedding) width. Defaults to 72.
        heads: Presumably the number of attention heads. Defaults to 12.
        b: Meaning not visible here; possibly a patch/block size
            (TODO confirm). Defaults to 2.
        tras_inv: Boolean flag; the name suggests translation invariance
            (spelled ``tras`` in the public API, kept as-is). Defaults to True.
        two_dim: Boolean flag; the name suggests a two-dimensional geometry.
            Defaults to True.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "vit_nqs"

    def __init__(
        self,
        L_eff: int = 25,
        num_layers: int = 8,
        d_model: int = 72,
        heads: int = 12,
        b: int = 2,
        tras_inv: bool = True,
        two_dim: bool = True,
        **kwargs,
    ):
        # Model-specific settings are recorded before the base initializer
        # runs; this mirrors the original assignment order exactly.
        self.L_eff = L_eff
        self.num_layers = num_layers
        self.d_model = d_model
        self.heads = heads
        self.b = b
        self.tras_inv = tras_inv
        self.two_dim = two_dim

        # Remaining generic options (e.g. ones PretrainedConfig understands)
        # are handled by the base class.
        super().__init__(**kwargs)