File size: 2,079 Bytes
d30db4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import jax
import jax.numpy as jnp

from flax import linen as nn
import jax.numpy as jnp

from einops import rearrange

def roll(J, shift, axis=-1):
    """Cyclically shift ``J`` by ``shift`` positions along ``axis``.

    Thin wrapper around ``jnp.roll`` whose only difference is that ``axis``
    defaults to the last axis instead of ``None`` (flattened roll).
    """
    rolled = jnp.roll(J, shift=shift, axis=axis)
    return rolled

from functools import partial
@partial(jax.vmap, in_axes=(None, 0, None), out_axes=1)
@partial(jax.vmap, in_axes=(None, None, 0), out_axes=1)
def roll2d(spins, i, j):
    """Return every 2D cyclic shift of a flattened square lattice.

    The last axis of ``spins`` is interpreted as a flattened ``side x side``
    grid. For a single pair of shifts the grid is rolled by ``i`` along its
    row axis and by ``j`` along its column axis, then flattened again.

    Because of the two ``vmap`` decorators, callers pass *arrays* of shifts:
    ``i`` and ``j`` are each mapped over their axis 0 while ``spins`` is
    shared, and both mapped axes are placed at output position 1. The
    result therefore has shape ``(spins.shape[0], len(i), len(j), side * side)``.
    """
    n_rows, n_sites = spins.shape[0], spins.shape[-1]
    side = int(n_sites**0.5)
    # Unflatten the site axis so the two lattice directions can be shifted
    # independently.
    grid = spins.reshape(n_rows, side, side)
    shifted_rows = jnp.roll(grid, i, axis=-2)
    shifted = jnp.roll(shifted_rows, j, axis=-1)
    return shifted.reshape(n_rows, -1)
    
class FMHA(nn.Module):
    """Multi-head mixing layer whose attention weights are learned parameters.

    The input is projected to values by a dense layer, split into ``h``
    heads, mixed along the sequence axis by a per-head ``(L_eff, L_eff)``
    matrix ``J`` that does not depend on the input, merged back, and passed
    through an output dense layer.
    """

    # Width of the value and output projections; the head split in
    # ``__call__`` requires it to be divisible by ``h``.
    d_model : int
    # Number of heads.
    h: int
    # Sequence length the mixing matrix is built for.
    L_eff: int
    # If True, ``J`` is generated from one length-``L_eff`` kernel per head by
    # cyclic shifts; otherwise a full ``(h, L_eff, L_eff)`` tensor is learned.
    transl_invariant: bool = True
    # Only used when ``transl_invariant`` is True: treat the sequence as a
    # flattened square lattice and shift the kernel along both directions.
    two_dimensional: bool = False

    def setup(self):
        """Create the value/output projections and build the mixing matrix ``J``.

        Raises:
            ValueError: if ``transl_invariant`` and ``two_dimensional`` are
                both set and ``L_eff`` is not a perfect square.
        """
        import math  # local import keeps this block self-contained

        # NOTE(review): float64 parameters presumably require JAX's x64 mode
        # to be enabled — confirm in the training entry point.
        self.v = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

        if not self.transl_invariant:
            # Unconstrained case: learn every entry of the mixing matrix.
            mixing = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff, self.L_eff), jnp.float64)
        else:
            # One kernel of length L_eff per head; the rows of the mixing
            # matrix are cyclic shifts of this kernel.
            kernel = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff), jnp.float64)
            if self.two_dimensional:
                side = math.isqrt(self.L_eff)
                # Explicit check instead of ``assert`` so it still runs
                # under ``python -O``.
                if side * side != self.L_eff:
                    raise ValueError(
                        f"L_eff must be a perfect square when two_dimensional=True, got {self.L_eff}"
                    )
                shifts = jnp.arange(side)
                # roll2d yields (h, side, side, L_eff); merge the two shift
                # axes into a single row axis of length L_eff.
                mixing = roll2d(kernel, shifts, shifts).reshape(self.h, -1, self.L_eff)
            else:
                # Row i of each head's matrix is the kernel rolled by i.
                mixing = jax.vmap(roll, (None, 0), out_axes=1)(kernel, jnp.arange(self.L_eff))
        self.J = mixing

        self.W = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64)

    def __call__(self, x):
        """Apply the layer.

        Args:
            x: array whose value projection has shape
                ``(batch, L_eff, d_model)``; the dense layer acts on the
                last axis.

        Returns:
            Array of shape ``(batch, L_eff, d_model)``.
        """
        v = self.v(x)
        # Split the feature axis into heads and move the head axis forward so
        # that J (h, L_eff, L_eff) broadcasts over the batch in the matmul.
        v = rearrange(v, 'batch L_eff (h d_eff) -> batch h L_eff d_eff', h=self.h)
        x = jnp.matmul(self.J, v)
        # Move heads back next to the features and merge them.
        x = rearrange(x, 'batch h L_eff d_eff -> batch L_eff (h d_eff)')

        return self.W(x)