|
import jax |
|
import jax.numpy as jnp |
|
|
|
from flax import linen as nn |
|
import jax.numpy as jnp |
|
|
|
from einops import rearrange |
|
|
|
from .attentions import FMHA |
|
|
|
def log_cosh(x):
    """Return ``log(cosh(x))`` element-wise, computed stably for real or complex ``x``.

    Uses the identity ``log(cosh(x)) = x + log1p(exp(-2x)) - log(2)``. Before
    applying it, ``x`` is reflected into the half-plane ``Re(x) >= 0``; ``cosh``
    is even, so the result is unchanged. After the reflection,
    ``|exp(-2x)| <= 1`` and the exponential cannot overflow.
    """
    # Multiply by -1 where the real part has its sign bit set, +1 otherwise.
    reflected = x * (1 - 2 * jnp.signbit(x.real))
    log_two = jnp.log(2.0)
    return reflected + jnp.log1p(jnp.exp(-2.0 * reflected)) - log_two
|
|
|
def extract_patches1d(x, b):
    """Split each 1D configuration into consecutive, non-overlapping patches of length ``b``.

    Args:
        x: 2D array ``(batch, L_eff * b)``; the einops pattern requires exactly
            two axes, and the second must be divisible by ``b``.
        b: patch length.

    Returns:
        Array of shape ``(batch, L_eff, b)``. Slice ``[:, i]`` holds sites
        ``i*b`` through ``i*b + b - 1`` of the input.
    """
    pattern = 'batch (L_eff b) -> batch L_eff b'
    patches = rearrange(x, pattern, b=b)
    return patches
|
|
|
def extract_patches2d(x, b):
    """Tile square 2D configurations into flattened ``b x b`` patches.

    Each row of ``x`` is read as a square lattice with side ``L_eff * b``,
    stored row-major. The lattice is cut into an ``L_eff x L_eff`` grid of
    non-overlapping ``b x b`` patches.

    Args:
        x: array of shape ``(batch, (L_eff * b) ** 2)``.
        b: side length of one square patch.

    Returns:
        Array of shape ``(batch, L_eff ** 2, b ** 2)``. Patches are ordered
        row-major over the patch grid, and each patch is flattened row-major.

    Raises:
        ValueError: if ``x.shape[1]`` is not ``(L_eff * b) ** 2`` for any
            integer ``L_eff``.
    """
    batch, n_sites = x.shape[0], x.shape[1]
    # round() rather than int(): floating-point error in the square root must
    # not truncate a perfect square down to the wrong side length.
    L_eff = round((n_sites // (b * b)) ** 0.5)
    if L_eff * L_eff * b * b != n_sites:
        raise ValueError(
            f"extract_patches2d expects x.shape[1] == (L_eff * b)**2 for an "
            f"integer L_eff; got x.shape[1]={n_sites} with b={b}"
        )

    # Axes: (batch, patch_row, row_in_patch, patch_col, col_in_patch).
    x = x.reshape(batch, L_eff, b, L_eff, b)
    # Group the two in-patch axes together:
    # (batch, patch_row, patch_col, row_in_patch, col_in_patch).
    x = x.transpose(0, 1, 3, 2, 4)
    # Explicit sizes (no -1) so an empty batch reshapes cleanly.
    return x.reshape(batch, L_eff * L_eff, b * b)
|
|
|
class Embed(nn.Module):
    """Cut spin configurations into patches and project each patch linearly.

    Attributes:
        d_model: output width of the linear patch embedding.
        b: patch size forwarded to the patch extractor (patch length in 1D,
            patch side length in 2D).
        two_dimensional: if True use ``extract_patches2d``, otherwise
            ``extract_patches1d``.
    """

    d_model: int
    b: int
    two_dimensional: bool = False

    def setup(self):
        self.extract_patches = (
            extract_patches2d if self.two_dimensional else extract_patches1d
        )
        self.embed = nn.Dense(
            self.d_model,
            kernel_init=nn.initializers.xavier_uniform(),
            param_dtype=jnp.float64,
            dtype=jnp.float64,
        )

    def __call__(self, x):
        """Return patch embeddings with ``d_model`` features on the last axis."""
        patches = self.extract_patches(x, self.b)
        return self.embed(patches)
|
|
|
class EncoderBlock(nn.Module):
    """Transformer encoder block with pre-LayerNorm residual branches.

    Computes ``x + attn(LN1(x))``, then ``y + ff(LN2(y))``. The attention
    module is ``FMHA`` from ``.attentions``. The feed-forward branch is
    Dense(4*d_model) -> GELU -> Dense(d_model).

    Attributes:
        d_model: feature width of the residual stream.
        h: number of attention heads, forwarded to ``FMHA``.
        L_eff: forwarded to ``FMHA`` (presumably the number of patches per
            dimension; confirm against ``FMHA``).
        transl_invariant: forwarded to ``FMHA``.
        two_dimensional: forwarded to ``FMHA``.
    """

    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.attn = FMHA(
            d_model=self.d_model,
            h=self.h,
            L_eff=self.L_eff,
            transl_invariant=self.transl_invariant,
            two_dimensional=self.two_dimensional,
        )

        self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)

        def _dense(features):
            # Every Dense in the MLP shares the same init and float64 dtypes.
            return nn.Dense(
                features,
                kernel_init=nn.initializers.xavier_uniform(),
                param_dtype=jnp.float64,
                dtype=jnp.float64,
            )

        self.ff = nn.Sequential([
            _dense(4 * self.d_model),
            nn.gelu,
            _dense(self.d_model),
        ])

    def __call__(self, x):
        """Apply the attention and feed-forward residual branches in sequence."""
        attn_out = self.attn(self.layer_norm_1(x))
        x = x + attn_out
        ff_out = self.ff(self.layer_norm_2(x))
        return x + ff_out
|
|
|
class Encoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``EncoderBlock``s applied in order.

    Attributes:
        num_layers: number of encoder blocks.
        d_model, h, L_eff, transl_invariant, two_dimensional: forwarded
            unchanged to every ``EncoderBlock``.
    """

    num_layers: int
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        block_config = dict(
            d_model=self.d_model,
            h=self.h,
            L_eff=self.L_eff,
            transl_invariant=self.transl_invariant,
            two_dimensional=self.two_dimensional,
        )
        self.layers = [EncoderBlock(**block_config) for _ in range(self.num_layers)]

    def __call__(self, x):
        """Thread ``x`` through every block; with zero layers ``x`` is returned as-is."""
        for block in self.layers:
            x = block(x)
        return x
|
|
|
class OuputHead(nn.Module):
    """Pool encoder features and map them to one complex scalar per sample.

    The encoder output is summed over axis 1 and layer-normalised, giving
    ``z``. Two separate Dense + LayerNorm branches produce the real and
    imaginary parts of a complex feature vector. ``log_cosh`` is applied to
    that vector and the result is summed over the last axis.

    NOTE: the class name is misspelled ("Ouput"). It is kept as-is because
    other code may import it under this name.

    Attributes:
        d_model: width of both Dense branches.
    """

    d_model: int

    def setup(self):
        def _norm():
            # LayerNorm with learnable scale and bias in float64.
            return nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)

        def _dense():
            # Xavier-uniform kernel, zero bias, float64.
            return nn.Dense(
                self.d_model,
                param_dtype=jnp.float64,
                dtype=jnp.float64,
                kernel_init=nn.initializers.xavier_uniform(),
                bias_init=jax.nn.initializers.zeros,
            )

        self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm2 = _norm()
        self.norm3 = _norm()
        self.output_layer0 = _dense()
        self.output_layer1 = _dense()

    def __call__(self, x, return_z=False):
        """Return the pooled features ``z`` if ``return_z``, else ``sum(log_cosh(re + 1j*im), -1)``.

        Presumably ``x`` is ``(batch, n_patches, d_model)``; confirm with callers.
        """
        pooled = self.out_layer_norm(x.sum(axis=1))
        if return_z:
            return pooled

        real_part = self.norm2(self.output_layer0(pooled))
        imag_part = self.norm3(self.output_layer1(pooled))
        return jnp.sum(log_cosh(real_part + 1j * imag_part), axis=-1)
|
|
|
class ViT(nn.Module):
    """Vision-Transformer-style model: patch embedding -> encoder stack -> output head.

    Attributes:
        num_layers: number of encoder blocks.
        d_model: embedding / residual width.
        heads: number of attention heads, passed to the encoder as ``h``.
        L_eff: forwarded to the encoder's attention layers.
        b: patch size used by ``Embed``.
        transl_invariant: forwarded to the encoder's attention layers.
        two_dimensional: selects 2D patch extraction and is also forwarded to
            the encoder.
    """

    num_layers: int
    d_model: int
    heads: int
    L_eff: int
    b: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.patches_and_embed = Embed(self.d_model, self.b, two_dimensional=self.two_dimensional)
        self.encoder = Encoder(
            num_layers=self.num_layers,
            d_model=self.d_model,
            h=self.heads,
            L_eff=self.L_eff,
            transl_invariant=self.transl_invariant,
            two_dimensional=self.two_dimensional,
        )
        self.output = OuputHead(self.d_model)

    def __call__(self, spins, return_z=False):
        """Evaluate the model on ``spins`` (one configuration or a batch).

        A 1D input is promoted to a batch of one, so the output always keeps a
        leading batch axis. If ``return_z`` is set, the pooled features from
        the head are returned instead of the final scalar per sample.
        """
        batched = jnp.atleast_2d(spins)
        encoded = self.encoder(self.patches_and_embed(batched))
        return self.output(encoded, return_z=return_z)