# esm/layers/transformer_stack.py
import math
import torch
import torch.nn as nn
from esm.layers.blocks import UnifiedTransformerBlock
from esm.utils.structure.affine3d import Affine3D
class TransformerStack(nn.Module):
    """
    Stack of UnifiedTransformerBlocks as used in the ESM-3 model.

    The first ``n_layers_geom`` blocks use geometric attention (conditioned on
    the structure inputs); the remaining blocks use standard multi-head
    attention.

    Args:
        d_model (int): Dimensionality of the input/output feature vectors.
        n_heads (int): Number of attention heads.
        v_heads (int | None): Number of voting heads.
        n_layers (int): Total number of transformer blocks in the stack.
        n_layers_geom (int, optional): How many leading blocks use geometric
            attention.
        scale_residue (bool, optional): Whether to scale residue connections
            in each block.
        mask_and_zero_frameless (bool, optional): Whether to mask and zero
            frameless positions; only relevant to the geometric-attention
            blocks.
        bias (bool, optional): Whether linear layers carry bias terms.
        qk_layernorm (bool, optional): Whether to layernorm queries and keys.
        ffn_type (str, optional): Feed-forward variant, "swiglu" or "gelu".
        expansion_ratio (float, optional): FFN hidden-size expansion ratio.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        v_heads: int | None,
        n_layers: int,
        n_layers_geom: int = 1,
        scale_residue: bool = True,
        mask_and_zero_frameless: bool = False,
        bias: bool = False,
        qk_layernorm: bool = True,
        ffn_type: str = "swiglu",  # swiglu | gelu
        expansion_ratio: float = 8 / 3,
    ):
        super().__init__()
        # The residue scaling factor is identical for every block, so it is
        # computed once up front rather than inside the per-layer loop.
        residue_scaling = math.sqrt(n_layers / 36) if scale_residue else 1.0
        layers = []
        for layer_idx in range(n_layers):
            layers.append(
                UnifiedTransformerBlock(
                    d_model,
                    n_heads,
                    v_heads=v_heads,
                    # Only the leading n_layers_geom blocks get geometric attention.
                    use_geom_attn=layer_idx < n_layers_geom,
                    residue_scaling_factor=residue_scaling,
                    expansion_ratio=expansion_ratio,
                    mask_and_zero_frameless=mask_and_zero_frameless,
                    bias=bias,
                    qk_layernorm=qk_layernorm,
                    ffn_type=ffn_type,
                )
            )
        self.blocks = nn.ModuleList(layers)
        # Final layernorm applied to the stack output (no bias term).
        self.norm = nn.LayerNorm(d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        sequence_id: torch.Tensor | None = None,
        affine: Affine3D | None = None,
        affine_mask: torch.Tensor | None = None,
        chain_id: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the input through every block, then apply the final layernorm.

        Args:
            x (torch.Tensor): Input of shape (batch_size, sequence_length, d_model).
            sequence_id (torch.Tensor | None): Sequence-ID tensor of shape
                (batch_size, sequence_length); defaults to all ones.
            affine (Affine3D | None): Affine transformation input, or None.
            affine_mask (torch.Tensor | None): Affine mask, or None.
            chain_id (torch.Tensor | None): Per-position chain IDs of shape
                (batch_size, sequence_length); defaults to all ones. Only used
                by the geometric-attention blocks.

        Returns:
            post_norm: Layer-normed output, shape (batch_size, sequence_length, d_model).
            pre_norm: Pre-norm embedding, shape (batch_size, sequence_length, d_model).
        """
        batch_dims = x.shape[:-1]
        # Fill in all-ones defaults on x's device when IDs are not supplied.
        if sequence_id is None:
            sequence_id = x.new_ones(batch_dims, dtype=torch.int64)
        if chain_id is None:
            chain_id = x.new_ones(batch_dims, dtype=torch.int64)
        for layer in self.blocks:
            x = layer(x, sequence_id, affine, affine_mask, chain_id)
        embedding = x
        return self.norm(embedding), embedding