from transformers import FlaxPreTrainedModel
import jax.numpy as jnp
from .transformer_fnqs import ViTFNQS
from .vit_fnqs_config import ViTFNQSConfig

class ViTFNQSModel(FlaxPreTrainedModel):
    """Hugging Face FlaxPreTrainedModel wrapper around the ViTFNQS Flax module."""

    config_class = ViTFNQSConfig

    def __init__(
            self,
            config: ViTFNQSConfig,
            # Despite its name (kept for FlaxPreTrainedModel compatibility),
            # `input_shape` holds sample inputs: a batch of spin
            # configurations and a batch of couplings.
            input_shape=(jnp.zeros((1, 100)), jnp.zeros((1, 1))),
            seed: int = 0,
            # jnp.float64 only takes effect when JAX runs in x64 mode,
            # e.g. jax.config.update("jax_enable_x64", True).
            dtype: jnp.dtype = jnp.float64,
            _do_init: bool = True,
            **kwargs,
    ):
        self.model = ViTFNQS(L_eff=config.L_eff,
                             num_layers=config.num_layers,
                             d_model=config.d_model,
                             heads=config.heads,
                             b=config.b,
                             complex=config.complex,
                             disorder=config.disorder,
                             transl_invariant=config.tras_inv,
                             two_dimensional=config.two_dim,
        )
        if not "return_z" in kwargs:
            self.return_z = False
        else:
            self.return_z = kwargs["return_z"]

        # Pass the module instance (not the ViTFNQS class) so that
        # FlaxPreTrainedModel wraps the configured network.
        super().__init__(config, self.model, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def __call__(self, params, spins, coups):
        # `params` is the full variables dict as returned by init_weights.
        return self.model.apply(params, spins, coups, return_z=self.return_z)

    def init_weights(self, rng, input_shape, params=None):
        # Recent versions of `transformers` pass a `params` argument here;
        # reuse those parameters when provided, otherwise initialize the
        # module from the sample inputs.
        if params is not None:
            return params
        return self.model.init(rng, *input_shape)
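

# Minimal usage sketch (an assumption, not part of the original file): the
# ViTFNQSConfig field values are hypothetical placeholders, and x64 mode is
# assumed so that jnp.float64 takes effect. Run from within the package:
#
#     import jax
#     jax.config.update("jax_enable_x64", True)
#
#     config = ViTFNQSConfig(...)      # real field values depend on the experiment
#     model = ViTFNQSModel(config, return_z=False)
#     spins = jnp.zeros((1, 100))      # batch of spin configurations
#     coups = jnp.zeros((1, 1))        # batch of couplings
#     log_psi = model(model.params, spins, coups)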