File size: 1,261 Bytes
d30db4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5f03c4
 
 
 
d30db4c
 
 
 
f5f03c4
d30db4c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import FlaxPreTrainedModel
import jax.numpy as jnp
from .transformer import ViT
from .vitnqs_config import ViTNQSConfig


class ViTNQSModel(FlaxPreTrainedModel):
    """``FlaxPreTrainedModel`` wrapper around the project's ``ViT`` network.

    A single ``ViT`` instance is built from the fields of a ``ViTNQSConfig``.
    It is used both to create parameters (``init_weights``) and to evaluate
    the network (``__call__``).
    """

    config_class = ViTNQSConfig

    def __init__(
        self,
        config: ViTNQSConfig,
        input_shape = jnp.zeros((1, 100)),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float64,
        _do_init: bool = True,
        **kwargs,
    ):
        """Build the ``ViT`` module and hand it to ``FlaxPreTrainedModel``.

        Args:
            config: Model configuration. Its ``L_eff``, ``num_layers``,
                ``d_model``, ``heads``, ``b``, ``tras_inv`` and ``two_dim``
                fields are forwarded to ``ViT``.
            input_shape: Despite its name, a sample *input array* rather than
                a shape tuple. It is passed unchanged to ``init_weights`` and
                from there to ``ViT.init``. Defaults to a ``(1, 100)`` array
                of zeros.
            seed: Forwarded to ``FlaxPreTrainedModel.__init__``.
            dtype: Forwarded to ``FlaxPreTrainedModel.__init__`` only.
            _do_init: Forwarded to ``FlaxPreTrainedModel.__init__``.
            **kwargs: Only ``return_z`` is read (default ``False``). It is
                passed to ``ViT.apply`` on every call. Other keys are ignored.
        """
        # Must be assigned before super().__init__(), because the transformers
        # base class calls self.init_weights() from its constructor.
        # NOTE(review): ``dtype`` is not forwarded to ViT; confirm that ViT's
        # numeric precision is configured elsewhere.
        self.model = ViT(
            L_eff=config.L_eff,
            num_layers=config.num_layers,
            d_model=config.d_model,
            heads=config.heads,
            b=config.b,
            transl_invariant=config.tras_inv,
            two_dimensional=config.two_dim,
        )
        self.return_z = kwargs.get("return_z", False)

        # Pass the configured instance, not the bare ViT class. This way the
        # base class's ``module`` refers to the network actually used by
        # init_weights() and __call__().
        super().__init__(
            config,
            self.model,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def __call__(self, params, spins):
        """Evaluate the network on ``spins`` with the given parameters.

        Args:
            params: Parameter collection to pass to ``ViT.apply``.
            spins: Network input, passed straight to ``ViT.apply``.

        Returns:
            Whatever ``self.model.apply(params, spins, self.return_z)`` returns.
        """
        return self.model.apply(params, spins, self.return_z)

    def init_weights(self, rng, input_shape):
        """Initialise ``ViT`` parameters.

        Args:
            rng: PRNG key passed to ``ViT.init``.
            input_shape: A sample input array (not a shape tuple), passed
                directly to ``ViT.init``.

        Returns:
            The parameter collection returned by ``self.model.init``.
        """
        return self.model.init(rng, input_shape)