# Origin: commit d30db4c (original file size: 578 bytes).
from transformers import PretrainedConfig
from typing import List
class ViTNQSConfig(PretrainedConfig):
    """Configuration container for the ``vit_nqs`` model.

    Each hyperparameter below is stored unchanged as an attribute of the same
    name. Every other keyword argument is forwarded to
    ``PretrainedConfig.__init__``. No validation is done here.

    Args:
        L_eff: Defaults to 25. NOTE(review): presumably the effective sequence
            length (number of patches) fed to the transformer — confirm
            against the model code.
        num_layers: Defaults to 8. Presumably the number of transformer
            layers.
        d_model: Defaults to 72. Presumably the embedding width.
        heads: Defaults to 12. Presumably the number of attention heads. With
            the defaults, 72 / 12 = 6 dimensions per head.
        b: Defaults to 2. NOTE(review): presumably the patch size — confirm.
        tras_inv: Defaults to True. NOTE(review): the name suggests a
            translation-invariance switch. The spelling is kept as-is because
            it is part of the public/serialized interface — confirm meaning.
        two_dim: Defaults to True. NOTE(review): presumably selects 2-D rather
            than 1-D handling of the input — confirm.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier that ``PretrainedConfig`` subclasses declare.
    # transformers' Auto* registration machinery keys on it.
    model_type = "vit_nqs"

    def __init__(
        self,
        L_eff: int = 25,
        num_layers: int = 8,
        d_model: int = 72,
        heads: int = 12,
        b: int = 2,
        tras_inv: bool = True,
        two_dim: bool = True,
        **kwargs,
    ) -> None:
        # Custom hyperparameters are set as plain attributes. The base class
        # serializes the instance's attributes, so these round-trip through
        # save_pretrained / from_pretrained.
        self.L_eff = L_eff
        self.num_layers = num_layers
        self.d_model = d_model
        self.heads = heads
        self.b = b
        self.tras_inv = tras_inv
        self.two_dim = two_dim
        # The base class consumes whatever standard options remain in kwargs.
        super().__init__(**kwargs)