import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange
from .attentions import FMHA
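
# Building blocks of the ViTFNQS network: patch-extraction helpers for 1D and
# 2D lattices, a numerically stable log-cosh, a patch-embedding layer,
# Transformer encoder blocks built around the FMHA attention imported above,
# an output head, and the full model, which maps spin configurations (and,
# optionally, disorder couplings) to a single scalar, optionally complex,
# output per configuration.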

def extract_patches1d(x, b):
    # Split each 1D configuration of length L into L_eff = L // b patches of b sites.
    return rearrange(x, 'batch (L_eff b) -> batch L_eff b', b=b)

def extract_patches2d(x, b):
    # Split each flattened 2D configuration into non-overlapping b x b patches.
    batch = x.shape[0]
    L_eff = int((x.shape[1] // b**2)**0.5)
    x = x.reshape(batch, L_eff, b, L_eff, b)  # [batch, L_eff, b, L_eff, b]
    x = x.transpose(0, 1, 3, 2, 4)            # [batch, L_eff, L_eff, b, b]
    # flatten the patches
    x = x.reshape(batch, L_eff, L_eff, -1)    # [batch, L_eff, L_eff, b*b]
    x = x.reshape(batch, L_eff*L_eff, -1)     # [batch, L_eff*L_eff, b*b]
    return x
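
# Numerically stable log(cosh(x)): cosh is even, so the argument is first
# flipped to have non-negative real part and then
# log cosh(x) = x + log(1 + exp(-2x)) - log(2)
# is evaluated, which keeps the exponential bounded for large |x|.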
def log_cosh(x):
    sgn_x = -2 * jnp.signbit(x.real) + 1
    x = x * sgn_x
    return x + jnp.log1p(jnp.exp(-2.0 * x)) - jnp.log(2.0)

class Embed(nn.Module):
    d_model: int
    b: int
    two_dimensional: bool = False

    def setup(self):
        if self.two_dimensional:
            self.extract_patches = extract_patches2d
        else:
            self.extract_patches = extract_patches1d
        self.embed = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

    def __call__(self, x):
        x = self.extract_patches(x, self.b)
        x = self.embed(x)
        return x
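
# Pre-LayerNorm Transformer encoder block: the normalised input goes through
# FMHA attention and is added back as a residual; the same pattern is repeated
# with a two-layer GELU feed-forward network that expands to 4 * d_model and
# projects back to d_model.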
class EncoderBlock(nn.Module):
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.attn = FMHA(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.layer_norm_1 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.ff = nn.Sequential([
            nn.Dense(4*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
            nn.gelu,
            nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
        ])

    def __call__(self, x):
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ff(self.layer_norm_2(x))
        return x

class Encoder(nn.Module):
    num_layers: int
    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        self.layers = [EncoderBlock(d_model=self.d_model, h=self.h, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional) for _ in range(self.num_layers)]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x
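
# Output head: patch features are summed over the sequence axis and normalised;
# two independent Dense + LayerNorm branches then produce the real part (`amp`)
# and, when `complex` is set, the imaginary part (`sign`), and a final sum of
# log_cosh over the feature axis reduces the result to one value per
# configuration.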
class OutputHead(nn.Module):
    d_model: int
    complex: bool = False

    def setup(self):
        self.out_layer_norm = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm0 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.norm1 = nn.LayerNorm(use_scale=True, use_bias=True, dtype=jnp.float64, param_dtype=jnp.float64)
        self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
        self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)

    def __call__(self, x, return_z=False):
        z = self.out_layer_norm(x.sum(axis=1))
        if return_z:
            return z
        amp = self.norm0(self.output_layer0(z))
        if self.complex:
            sign = self.norm1(self.output_layer1(z))
            out = amp + 1j*sign
        else:
            out = amp
        return jnp.sum(log_cosh(out), axis=-1)
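
# Full model.  Without disorder, the spins are split into patches, concatenated
# with the broadcast couplings and embedded by a single Dense layer; with
# disorder, spins and couplings are embedded separately into d_model // 2
# features each and concatenated.  The Transformer encoder and the output head
# then produce the final value for each configuration.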
class ViTFNQS(nn.Module):
    num_layers: int
    d_model: int
    heads: int
    L_eff: int
    b: int
    complex: bool = False
    disorder: bool = False
    transl_invariant: bool = True
    two_dimensional: bool = False

    def setup(self):
        if self.disorder:
            self.patches_and_embed = Embed(self.d_model//2, self.b, two_dimensional=self.two_dimensional)
            self.patches_and_embed_coup = Embed(self.d_model//2, self.b, two_dimensional=self.two_dimensional)
        else:
            self.embed = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)
        self.encoder = Encoder(num_layers=self.num_layers, d_model=self.d_model, h=self.heads, L_eff=self.L_eff, transl_invariant=self.transl_invariant, two_dimensional=self.two_dimensional)
        self.output = OutputHead(self.d_model, complex=self.complex)

    def __call__(self, spins, coups, return_z=False):
        x = jnp.atleast_2d(spins)
        if self.disorder:
            x_spins = self.patches_and_embed(x)
            # Embed the couplings with their own patch embedding.
            x_coups = self.patches_and_embed_coup(coups)
            x = jnp.concatenate((x_spins, x_coups), axis=-1)
        else:
            if self.two_dimensional:
                x = extract_patches2d(x, self.b)
            else:
                x = extract_patches1d(x, self.b)
            coups = jnp.broadcast_to(coups, (x.shape[0], x.shape[1], 2))
            x = jnp.concatenate((x, coups), axis=-1)
            x = self.embed(x)
        x = self.encoder(x)
        out = self.output(x, return_z=return_z)
        return out
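
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes this file lives in a package whose `attentions` module provides
# FMHA, so it has to be run as a module (e.g. `python -m <package>.<module>`),
# and every hyperparameter and shape below is a placeholder.
if __name__ == "__main__":
    # The model uses float64 parameters, so enable x64 support in JAX.
    jax.config.update("jax_enable_x64", True)

    L, b = 16, 4  # hypothetical 1D chain of 16 spins, patches of 4 sites
    model = ViTFNQS(num_layers=2, d_model=16, heads=4, L_eff=L // b, b=b,
                    complex=True, disorder=False)

    key_spins, key_params = jax.random.split(jax.random.PRNGKey(0))
    spins = jax.random.choice(key_spins, jnp.array([-1.0, 1.0]), shape=(8, L))
    coups = jnp.ones(2)  # broadcast to (batch, L_eff, 2) inside __call__

    params = model.init(key_params, spins, coups)
    log_psi = model.apply(params, spins, coups)  # one complex value per sample
    print(log_psi.shape, log_psi.dtype)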