# M3Site/esm/layers/geom_attention.py
from math import sqrt
import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F


class GeometricReasoningOriginalImpl(nn.Module):
    def __init__(
        self,
        c_s: int,
        v_heads: int,
        num_vector_messages: int = 1,
        mask_and_zero_frameless: bool = True,
        divide_residual_by_depth: bool = False,
        bias: bool = False,
    ):
"""Approximate implementation:
ATTN(A, v) := (softmax_j A_ij) v_j
make_rot_vectors(x) := R(i->g) Linear(x).reshape(..., 3)
make_vectors(x) := T(i->g) Linear(x).reshape(..., 3)
v <- make_rot_vectors(x)
q_dir, k_dir <- make_rot_vectors(x)
q_dist, k_dist <- make_vectors(x)
A_ij <- dot(q_dir_i, k_dir_j) -||q_dist_i - k_dist_j||^2
x <- x + Linear(T(g->i) ATTN(A, v))
"""
        super().__init__()
        self.c_s = c_s
        self.v_heads = v_heads
        self.num_vector_messages = num_vector_messages
        self.mask_and_zero_frameless = mask_and_zero_frameless

        self.s_norm = nn.LayerNorm(c_s, bias=bias)
        dim_proj = (
            4 * self.v_heads * 3 + self.v_heads * 3 * self.num_vector_messages
        )  # 2 x (q, k) * number of heads * (x, y, z) + number of heads * number of vector messages * (x, y, z)
        self.proj = nn.Linear(c_s, dim_proj, bias=bias)
        channels_out = self.v_heads * 3 * self.num_vector_messages
        self.out_proj = nn.Linear(channels_out, c_s, bias=bias)

        # The basic idea is for some attention heads to pay more or less attention to rotation versus distance,
        # as well as to control the sharpness of the softmax (i.e., should this head only attend to those residues
        # very nearby, or should there be a shallower dropoff in attention weight?)
        self.distance_scale_per_head = nn.Parameter(torch.zeros(self.v_heads))
        self.rotation_scale_per_head = nn.Parameter(torch.zeros(self.v_heads))
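        # NOTE: both scales are zero-initialized, so after the softplus in forward()
        # each term starts with the same weight softplus(0) = ln 2 ~= 0.69 on every head.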

    def forward(self, s, affine, affine_mask, sequence_id, chain_id):
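        """Geometric attention over per-residue rigid frames.

        Shapes below are inferred from how the arguments are used in this method;
        `affine` is assumed to be an Affine3D-style frame object exposing
        `.rot`, `.apply`, and `.invert`.

            s: (B, L, c_s) residue embeddings.
            affine: per-residue rigid frames (rotation + translation).
            affine_mask: (B, L) bool, True where a residue has a valid frame.
            sequence_id: (B, L) integer ids used to build the attention bias.
            chain_id: (B, L) integer ids; attention across chains is masked out.
        """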
        attn_bias = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)
        attn_bias = attn_bias.unsqueeze(1).float()
        attn_bias = attn_bias.masked_fill(
            ~affine_mask[:, None, None, :], torch.finfo(attn_bias.dtype).min
        )
        chain_id_mask = chain_id.unsqueeze(1) != chain_id.unsqueeze(2)
        attn_bias = attn_bias.masked_fill(
            chain_id_mask.unsqueeze(1), torch.finfo(s.dtype).min
        )
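        # attn_bias: (B, 1, L, L); +1 for same-sequence pairs, 0 for cross-sequence
        # pairs, and a large negative value where the key residue has no frame or
        # the pair crosses chains.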

        ns = self.s_norm(s)
        vec_rot, vec_dist = self.proj(ns).split(
            [
                self.v_heads * 2 * 3 + self.v_heads * 3 * self.num_vector_messages,
                self.v_heads * 2 * 3,
            ],
            dim=-1,
        )
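        # vec_rot holds the rotation-term queries/keys plus the value vectors
        # (H * 2 * 3 + H * M * 3 channels); vec_dist holds the distance-term
        # queries/keys (H * 2 * 3 channels).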
        # Rotate the queries and keys for the rotation term. We also rotate the values.
        # NOTE(zeming, thayes): Values are only rotated, not translated. We may wish to change
        # this in the future.
        query_rot, key_rot, value = (
            affine.rot[..., None]
            .apply(rearrange(vec_rot, "... (h c) -> ... h c", c=3))
            .split(
                [
                    self.v_heads,
                    self.v_heads,
                    self.v_heads * self.num_vector_messages,
                ],
                dim=-2,
            )
        )
        # Rotate and translate the queries and keys for the distance term
        # NOTE(thayes): a simple speedup would be to apply all rotations together, then
        # separately apply the translations.
        query_dist, key_dist = (
            affine[..., None]
            .apply(rearrange(vec_dist, "... (h c) -> ... h c", c=3))
            .chunk(2, dim=-2)
        )
        query_dist = rearrange(query_dist, "b s h d -> b h s 1 d")
        key_dist = rearrange(key_dist, "b s h d -> b h 1 s d")
        query_rot = rearrange(query_rot, "b s h d -> b h s d")
        key_rot = rearrange(key_rot, "b s h d -> b h d s")
        value = rearrange(
            value, "b s (h m) d -> b h s (m d)", m=self.num_vector_messages
        )
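        # query_dist: (B, H, L, 1, 3), key_dist: (B, H, 1, L, 3)
        # query_rot:  (B, H, L, 3),    key_rot:  (B, H, 3, L)
        # value:      (B, H, L, M * 3), with M = num_vector_messages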
        distance_term = (query_dist - key_dist).norm(dim=-1) / sqrt(3)
        rotation_term = query_rot.matmul(key_rot) / sqrt(3)
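        # The 1/sqrt(3) factor scales both terms by the square root of the vector
        # dimension (3), in the spirit of the 1/sqrt(d_k) scaling in standard
        # dot-product attention.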
        distance_term_weight = rearrange(
            F.softplus(self.distance_scale_per_head), "h -> h 1 1"
        )
        rotation_term_weight = rearrange(
            F.softplus(self.rotation_scale_per_head), "h -> h 1 1"
        )
        attn_weight = (
            rotation_term * rotation_term_weight - distance_term * distance_term_weight
        )

        if attn_bias is not None:
            # we can re-use the attention bias from the transformer layers
            # NOTE(thayes): This attention bias is expected to handle two things:
            # 1. Masking attention on padding tokens
            # 2. Masking cross sequence attention in the case of bin packing
            s_q = attn_weight.size(2)
            s_k = attn_weight.size(3)
            _s_q = max(0, attn_bias.size(2) - s_q)
            _s_k = max(0, attn_bias.size(3) - s_k)
            attn_bias = attn_bias[:, :, _s_q:, _s_k:]
            attn_weight = attn_weight + attn_bias

        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_out = attn_weight.matmul(value)

        attn_out = (
            affine.rot[..., None]
            .invert()
            .apply(
                rearrange(
                    attn_out, "b h s (m d) -> b s (h m) d", m=self.num_vector_messages
                )
            )
        )
        attn_out = rearrange(
            attn_out, "b s (h m) d -> b s (h m d)", m=self.num_vector_messages
        )

        if self.mask_and_zero_frameless:
            attn_out = attn_out.masked_fill(~affine_mask[..., None], 0.0)

        s = self.out_proj(attn_out)
        return s
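

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). The real `affine` argument is a
# rigid-frame object (Affine3D-style) exposing `.rot`, `[...]` indexing,
# `.apply`, and `.invert`; `_IdentityAffine` below is a hypothetical stand-in
# that implements just that interface as the identity transform, so the layer
# can be exercised without the rest of the repo. Shapes follow the usage in
# `forward` above.
# ---------------------------------------------------------------------------
class _IdentityAffine:
    """Hypothetical identity frame: rotation = I, translation = 0."""

    def __init__(self):
        self.rot = self  # the rotation part reuses the same identity behavior

    def __getitem__(self, idx):
        return self  # broadcasting indices such as [..., None] are no-ops here

    def apply(self, vecs):
        return vecs  # the identity transform leaves the vectors unchanged

    def invert(self):
        return self  # the inverse of the identity is the identity


if __name__ == "__main__":
    B, L, c_s, heads = 2, 8, 64, 4
    layer = GeometricReasoningOriginalImpl(c_s=c_s, v_heads=heads)
    s = torch.randn(B, L, c_s)
    affine = _IdentityAffine()
    affine_mask = torch.ones(B, L, dtype=torch.bool)
    sequence_id = torch.zeros(B, L, dtype=torch.long)
    chain_id = torch.zeros(B, L, dtype=torch.long)
    out = layer(s, affine, affine_mask, sequence_id, chain_id)
    print(out.shape)  # expected: torch.Size([2, 8, 64])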