|
import jax.numpy as jnp
from transformers import FlaxPreTrainedModel

from .transformer import ViT
from .vitnqs_config import ViTNQSConfig


class ViTNQSModel(FlaxPreTrainedModel):
    """Hugging Face wrapper that exposes the ViT neural-quantum-state network."""

    config_class = ViTNQSConfig
|
    def __init__(
        self,
        config: ViTNQSConfig,
        input_shape=jnp.zeros((1, 100)),  # dummy spin batch used to initialize parameters
        seed: int = 0,
        dtype: jnp.dtype = jnp.float64,  # float64 requires jax.config.update("jax_enable_x64", True)
        _do_init: bool = True,
        **kwargs,
    ):
        # Build the underlying Flax ViT module from the configuration.
        self.model = ViT(
            L_eff=config.L_eff,
            num_layers=config.num_layers,
            d_model=config.d_model,
            heads=config.heads,
            b=config.b,
            transl_invariant=config.tras_inv,
            two_dimensional=config.two_dim,
        )
|
if not "return_z" in kwargs: |
|
self.return_z = False |
|
else: |
|
self.return_z = kwargs["return_z"] |
|
|
|
        # FlaxPreTrainedModel expects a module instance, not the module class.
        super().__init__(config, self.model, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
    def __call__(self, params, spins):
        # Evaluate the network on a batch of spin configurations.
        return self.model.apply(params, spins, self.return_z)
|
    def init_weights(self, rng, input_shape, params=None):
        # `params` is accepted for compatibility with newer FlaxPreTrainedModel
        # interfaces; existing parameters are returned unchanged.
        if params is not None:
            return params
        return self.model.init(rng, input_shape)
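

# A minimal usage sketch, assuming ViTNQSConfig exposes the fields read above
# (L_eff, num_layers, d_model, heads, b, tras_inv, two_dim) and that ViT takes
# a (batch, L_eff) spin array; the concrete values here are hypothetical:
#
#     config = ViTNQSConfig(L_eff=100, num_layers=4, d_model=32, heads=4,
#                           b=10, tras_inv=True, two_dim=False)
#     model = ViTNQSModel(config, return_z=False)
#     spins = jnp.ones((16, 100))           # batch of 16 spin configurations
#     outputs = model(model.params, spins)  # wave-function amplitudes (and z if requested)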