# j1j2_square_10x10_05/transformer.py
import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange

from .attentions import FMHA

# NOTE: the float64 dtypes used throughout require x64 mode, e.g.
# jax.config.update("jax_enable_x64", True), which is off by default in JAX.

def log_cosh(x):
    # Numerically stable log(cosh(x)) for real or complex x: fold x onto
    # Re(x) >= 0 (cosh is even), then use log cosh x = x + log1p(e^(-2x)) - log 2.
    sgn_x = -2 * jnp.signbit(x.real) + 1
    x = x * sgn_x
    return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)
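
# Illustrative sanity check (not part of the uploaded model): the stable form
# above should match the naive log(cosh(x)) where the latter is well behaved,
# and stay finite where a naive cosh(x) would overflow.
def _log_cosh_sanity_check():
    x = jnp.array([0.0, 1.0, 10.0])
    assert jnp.allclose(log_cosh(x), jnp.log(jnp.cosh(x)))
    assert bool(jnp.isfinite(log_cosh(jnp.array(1e4))))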

def extract_patches1d(x, b):
    # Split a spin chain of length L into L/b contiguous patches of size b.
    return rearrange(x, 'batch (L_eff b) -> batch L_eff b', b=b)

def extract_patches2d(x, b):
    # Split a flattened 2D lattice into non-overlapping b x b patches.
    batch = x.shape[0]
    L_eff = int((x.shape[1] // b**2) ** 0.5)  # patches per lattice side
    x = x.reshape(batch, L_eff, b, L_eff, b)  # [batch, L_eff, b, L_eff, b]
    x = x.transpose(0, 1, 3, 2, 4)            # [batch, L_eff, L_eff, b, b]
    x = x.reshape(batch, L_eff, L_eff, -1)    # flatten each patch: [batch, L_eff, L_eff, b*b]
    x = x.reshape(batch, L_eff * L_eff, -1)   # flatten the grid:   [batch, L_eff*L_eff, b*b]
    return x
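
# Shape sketch (illustrative; a 10x10 lattice with b=2 patches is an
# assumption suggested by the checkpoint name, not read from it): a batch of
# flattened configurations (batch, 100) becomes a 5x5 grid of flattened
# 2x2 patches, i.e. (batch, 25, 4).
def _extract_patches2d_shape_check():
    spins = jnp.ones((3, 100))
    patches = extract_patches2d(spins, b=2)
    assert patches.shape == (3, 25, 4)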

# Patch extraction followed by a learned linear embedding into d_model.
class Embed(nn.Module):
    d_model: int
    b: int
    two_dimensional: bool = False

    def setup(self):
        if self.two_dimensional:
            self.extract_patches = extract_patches2d
        else:
            self.extract_patches = extract_patches1d
        self.embed = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

    def __call__(self, x):
        x = self.extract_patches(x, self.b)
        x = self.embed(x)
        return x
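
# Hypothetical usage sketch (d_model=16 chosen arbitrarily): Embed turns raw
# spin configurations into a sequence of patch embeddings.
def _embed_shape_check():
    embed = Embed(d_model=16, b=2, two_dimensional=True)
    spins = jnp.ones((3, 100))
    params = embed.init(jax.random.PRNGKey(0), spins)
    tokens = embed.apply(params, spins)
    assert tokens.shape == (3, 25, 16)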

# Pre-LayerNorm transformer block: attention (FMHA) and a GELU feed-forward
# MLP, each wrapped in a residual connection.
class EncoderBlock(nn.Module):
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.attn = FMHA(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.ff = nn.Sequential([
            nn.Dense(4 * self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
            nn.gelu,
            nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
        ])

    def __call__(self, x):
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ff(self.layer_norm_2(x))
        return x

# Stack of num_layers identical EncoderBlocks.
class Encoder(nn.Module):
    num_layers: int
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.layers = [EncoderBlock(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional) for _ in range(self.num_layers)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# Map pooled encoder features to a complex log-amplitude log(psi): the real
# part sets the amplitude, the imaginary part the phase.
class OutputHead(nn.Module):
    d_model: int

    def setup(self):
        self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm2 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm3 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
        self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)

    def __call__(self, x, return_z=False):
        # Sum-pool over the patch axis, then normalize.
        z = self.out_layer_norm(x.sum(axis=1))
        if return_z:
            return z
        # Two parallel heads: amplitude (real part) and phase (imaginary part).
        amp = self.norm2(self.output_layer0(z))
        sign = self.norm3(self.output_layer1(z))
        out = amp + 1j * sign
        return jnp.sum(log_cosh(out), axis=-1)
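
# Hypothetical usage sketch: the head pools over the patch axis and returns
# one complex number per sample, interpreted as log(psi).
def _output_head_shape_check():
    head = OutputHead(d_model=8)
    x = jnp.ones((3, 25, 8))
    params = head.init(jax.random.PRNGKey(0), x)
    log_psi = head.apply(params, x)
    assert log_psi.shape == (3,) and jnp.iscomplexobj(log_psi)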

# Vision-transformer ansatz: patch embedding -> encoder stack -> output head.
class ViT(nn.Module):
    num_layers: int
    d_model: int
    heads: int
    L_eff: int
    b: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.patches_and_embed = Embed(self.d_model, self.b, two_dimensional=self.two_dimensional)
        self.encoder = Encoder(num_layers=self.num_layers, d_model=self.d_model, h=self.heads, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.output = OutputHead(self.d_model)

    def __call__(self, spins, return_z=False):
        x = jnp.atleast_2d(spins)  # promote a single configuration to a batch of one
        x = self.patches_and_embed(x)
        x = self.encoder(x)
        output = self.output(x, return_z=return_z)
        return output
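
# Minimal end-to-end sketch. The hyperparameters below are illustrative
# guesses, not the values stored with this checkpoint; in particular the
# meaning of L_eff is defined by FMHA in the companion attentions.py, and
# 25 (the total patch count for a 10x10 lattice with b=2) is an assumption.
# Running this also requires importing the module as part of its package,
# since FMHA comes from a relative import.
if __name__ == "__main__":
    jax.config.update("jax_enable_x64", True)  # float64 params need x64 mode
    model = ViT(num_layers=2, d_model=16, heads=4, L_eff=25, b=2,
                two_dimensional=True)
    spins = jnp.sign(jax.random.normal(jax.random.PRNGKey(0), (4, 100)))
    variables = model.init(jax.random.PRNGKey(1), spins)
    log_psi = model.apply(variables, spins)
    print(log_psi.shape)  # -> (4,)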