import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange

from .attentions import FMHA


def log_cosh(x):
    # Numerically stable log(cosh(x)): flip the sign so the real part is
    # non-negative (cosh is even), then use log(cosh(x)) = x + log1p(e^{-2x}) - log(2).
    sgn_x = -2 * jnp.signbit(x.real) + 1
    x = x * sgn_x
    return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)
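
# Sanity-check sketch (not part of the original file; values assumed): for an
# input with non-negative real part the result matches the direct formula,
#   log_cosh(jnp.array(3.0 + 0.0j)) == jnp.log(jnp.cosh(3.0)),
# up to floating-point rounding, while remaining stable for large |x|.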


def extract_patches1d(x, b):
    # Split a 1D spin configuration into contiguous patches of b sites:
    # (batch, L) -> (batch, L // b, b).
    return rearrange(x, 'batch (L_eff b) -> batch L_eff b', b=b)


def extract_patches2d(x, b):
    # Split a flattened 2D spin configuration into non-overlapping b x b patches:
    # (batch, (L_eff*b)**2) -> (batch, L_eff**2, b**2).
    batch = x.shape[0]
    L_eff = int((x.shape[1] // b**2)**0.5)
    x = x.reshape(batch, L_eff, b, L_eff, b)  # [batch, L_eff, b, L_eff, b]
    x = x.transpose(0, 1, 3, 2, 4)            # [batch, L_eff, L_eff, b, b]
    # flatten the patches
    x = x.reshape(batch, L_eff, L_eff, -1)    # [batch, L_eff, L_eff, b*b]
    x = x.reshape(batch, L_eff*L_eff, -1)     # [batch, L_eff*L_eff, b*b]
    return x
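
# Example sketch (not part of the original file; shapes assumed): for a 4x4
# lattice flattened to length 16 with patch size b = 2, L_eff = 2 and
#   extract_patches2d(jnp.zeros((1, 16)), 2).shape == (1, 4, 4),
# i.e. four 2x2 patches, each flattened to b*b = 4 sites.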


class Embed(nn.Module):
    # Patch extraction followed by a linear embedding into d_model dimensions.
    d_model: int
    b: int
    two_dimensional: bool = False

    def setup(self):
        if self.two_dimensional:
            self.extract_patches = extract_patches2d
        else:
            self.extract_patches = extract_patches1d
        self.embed = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

    def __call__(self, x):
        x = self.extract_patches(x, self.b)
        x = self.embed(x)
        return x


class EncoderBlock(nn.Module):
    # Pre-LayerNorm transformer block: FMHA attention followed by a two-layer
    # feed-forward network, each with a residual connection.
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.attn = FMHA(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.ff = nn.Sequential([
            nn.Dense(4*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
            nn.gelu,
            nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
        ])

    def __call__(self, x):
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ff(self.layer_norm_2(x))
        return x


class Encoder(nn.Module):
    # Stack of num_layers identical EncoderBlocks.
    num_layers: int
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.layers = [EncoderBlock(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional) for _ in range(self.num_layers)]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x


class OuputHead(nn.Module):
    # Output head: sums the encoder features over the patch dimension, then two
    # independent Dense + LayerNorm branches form the real and imaginary parts of
    # a complex vector, which is reduced to a single complex number via log_cosh.
    d_model: int

    def setup(self):
        self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm2 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm3 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
        self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)

    def __call__(self, x, return_z=False):
        z = self.out_layer_norm(x.sum(axis=1))
        if return_z:
            return z
        amp = self.norm2(self.output_layer0(z))
        sign = self.norm3(self.output_layer1(z))
        out = amp + 1j*sign
        return jnp.sum(log_cosh(out), axis=-1)


class ViT(nn.Module):
    # Vision-Transformer wavefunction: patch embedding, transformer encoder, and
    # an output head returning one complex number (the log-amplitude) per input
    # spin configuration.
    num_layers: int
    d_model: int
    heads: int
    L_eff: int
    b: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.patches_and_embed = Embed(self.d_model, self.b, two_dimensional=self.two_dimensional)
        self.encoder = Encoder(num_layers=self.num_layers, d_model=self.d_model, h=self.heads, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.output = OuputHead(self.d_model)

    def __call__(self, spins, return_z=False):
        x = jnp.atleast_2d(spins)
        x = self.patches_and_embed(x)
        x = self.encoder(x)
        output = self.output(x, return_z=return_z)
        return output
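
# Minimal usage sketch (not part of the original file; hyperparameters and shapes
# are assumed, L_eff is taken here as the number of patches, and FMHA must be
# importable from the sibling `attentions` module). float64 parameters require
# JAX x64 mode: jax.config.update("jax_enable_x64", True).
#
#   model = ViT(num_layers=2, d_model=16, heads=4, L_eff=8, b=2,
#               transl_invariant=True, two_dimensional=False)
#   spins = jnp.ones((4, 16))                           # 4 chains of 16 spins
#   params = model.init(jax.random.PRNGKey(0), spins)
#   log_psi = model.apply(params, spins)                # complex log-amplitudes, shape (4,)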