# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import mp_utils
import torch
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)


@dataclass
class ModelArgs:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1
    ffn_dim_multiplier: Optional[float] = None
    multiple_of: int = 256
    """
    Enforces that the SwiGLU hidden layer size is a multiple
    of a large power of 2.
    """
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    """
    Positional encoding parameter; increase to 1e6 to run
    Code Llama models with long contexts.
    """


LayerCache = Tuple[torch.Tensor, torch.Tensor]


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
    ):
        super().__init__()

        mp_size = mp_utils.get_world_size()

        self.head_dim = head_dim
        self.rope_theta = rope_theta

        self.n_local_heads = n_heads // mp_size
        self.n_local_kv_heads = n_kv_heads // mp_size

        self.wqkv = nn.Linear(
            dim,
            (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
            bias=False,
        )
        self.wo = nn.Linear(
            self.n_local_heads * head_dim,
            dim,
            bias=False,
        )

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where wq, wk, and wv are
    # not fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        position_index: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # x.shape is (sum(seq_lens), dim)
        #
        # Since we support heterogeneous sequence
        # lengths, the hidden states are all
        # concatenated together along the usual
        # sequence dimension. The attention below
        # finds out where sequences start & end
        # using the provided attention bias.
        xqkv = self.wqkv(x)
        xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
        xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
        xk, xv = xkv.chunk(2, 1)

        output_shape = xq.shape
        heads_per_group = self.n_local_heads // self.n_local_kv_heads
        xq = xq.view(
            1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
        )
        xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
        xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)

        cache_k, cache_v = cache
        xq = rope_padded(
            xq=xq,
            xk=xk,
            xv=xv,
            cache_k=cache_k,
            cache_v=cache_v,
            attn_bias=attn_bias,
            theta=self.rope_theta,
        )

        # rope_padded() updated the caches, so we
        # call attention directly
        output = fmha.memory_efficient_attention_forward(
            xq, cache_k, cache_v, attn_bias
        )

        output = output.reshape(output_shape)

        # When position_index is set (last layer during prefill), keep only
        # the rows for the last token of each sequence before projecting.
        if position_index is not None:
            output = output[position_index]

        output = self.wo(output)
        mp_utils.all_reduce(output)
        return output
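

# Illustrative sketch, not used by the model: how a batch of variable-length
# prompts is packed for Attention.forward. The lengths 5, 3, 1 and the
# kv_padding of 16 below are hypothetical; kv_padding is the per-sequence
# key/value budget in the cache.
def _example_attn_bias() -> AttnBias:
    attn_bias = AttnBias.from_seqlens(
        q_seqlen=[5, 3, 1],
        kv_seqlen=[5, 3, 1],  # prefill: nothing is in the cache yet
        kv_padding=16,
    )
    # The matching hidden states are concatenated into one (5 + 3 + 1, dim)
    # tensor. attn_bias.q_seqinfo.seqstart holds [0, 5, 8, 9], so the last
    # token of each sequence sits at rows seqstart[1:] - 1 = [4, 7, 8], which
    # are the rows TransformerBlock keeps at the last layer.
    return attn_bias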


class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()

        mp_size = mp_utils.get_world_size()

        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert hidden_dim % mp_size == 0

        self.w13 = nn.Linear(
            dim,
            2 * hidden_dim // mp_size,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim // mp_size,
            dim,
            bias=False,
        )

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where w1 and w3 are not
    # fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x13 = self.w13(x)
        x1, x3 = x13.chunk(2, -1)
        output = self.w2(F.silu(x1) * x3)
        mp_utils.all_reduce(output)
        return output
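

# Illustrative sketch, not used by the model: what the two load hooks above do
# to a vanilla Llama checkpoint. The "layers.<i>.attention." and
# "layers.<i>.feed_forward." key layout is the usual Llama naming and is
# assumed here; the hooks let load_state_dict() accept such checkpoints
# directly, which is equivalent to pre-fusing the weights like this:
def _example_fuse_vanilla_llama_layer(state_dict: dict, layer: int) -> None:
    attn = f"layers.{layer}.attention."
    ffn = f"layers.{layer}.feed_forward."
    if attn + "wq.weight" in state_dict:
        state_dict[attn + "wqkv.weight"] = torch.cat(
            [
                state_dict.pop(attn + "wq.weight"),
                state_dict.pop(attn + "wk.weight"),
                state_dict.pop(attn + "wv.weight"),
            ]
        )
    if ffn + "w1.weight" in state_dict:
        state_dict[ffn + "w13.weight"] = torch.cat(
            [
                state_dict.pop(ffn + "w1.weight"),
                state_dict.pop(ffn + "w3.weight"),
            ]
        )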


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs, layer_index: int):
        super().__init__()

        assert args.dim % args.n_heads == 0
        head_dim = args.dim // args.n_heads

        if args.n_kv_heads is not None:
            n_kv_heads = args.n_kv_heads
        else:
            n_kv_heads = args.n_heads

        mp_size = mp_utils.get_world_size()
        assert args.n_heads % n_kv_heads == 0
        assert args.n_heads % mp_size == 0
        assert n_kv_heads % mp_size == 0

        self.is_last_layer = layer_index + 1 == args.n_layers

        self.attention = Attention(
            dim=args.dim,
            head_dim=head_dim,
            n_heads=args.n_heads,
            n_kv_heads=n_kv_heads,
            rope_theta=args.rope_theta,
        )
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )

        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
    ) -> torch.Tensor:
        position_index = None
        if self.is_last_layer and attn_bias.q_seqinfo.max_seqlen > 1:
            # During prefill, only the hidden state of the last token of each
            # sequence is needed to predict the next token, so the last layer
            # keeps just those rows (seqstart[1:] - 1 indexes them).
            position_index = attn_bias.q_seqinfo.seqstart[1:] - 1

        h = self.attention.forward(
            self.attention_norm(x),
            cache,
            attn_bias,
            position_index=position_index,
        )
        if position_index is not None:
            x = x[position_index]
        h = h + x

        out = h + self.feed_forward(self.ffn_norm(h))
        return out


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        mp_size = mp_utils.get_world_size()
        assert args.dim % mp_size == 0
        assert args.vocab_size > 0
        assert args.vocab_size % mp_size == 0

        self.tok_embeddings = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.dim // mp_size,
        )

        self.layers = nn.ModuleList()
        for layer_index in range(args.n_layers):
            self.layers.append(TransformerBlock(args, layer_index))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(
            args.dim,
            args.vocab_size // mp_size,
            bias=False,
        )

    def forward_with_attn_bias(
        self,
        token_values: torch.Tensor,
        attn_bias: AttnBias,
        cache: list[LayerCache],
    ) -> torch.Tensor:
        h_parallel = self.tok_embeddings(token_values)
        h = mp_utils.all_gather(h_parallel)

        for i, layer in enumerate(self.layers):
            h = layer(h, cache[i], attn_bias)

        logits_parallel = self.output(self.norm(h))
        logits = mp_utils.all_gather(logits_parallel)
        return logits.float()
    def forward(
        self,
        token_values: torch.Tensor,
        token_lengths: torch.Tensor,
        start_pos: torch.Tensor,
        cache: list[LayerCache],
        kv_padding: int,
    ) -> torch.Tensor:
        attn_bias = AttnBias.from_seqlens(
            q_seqlen=token_lengths.tolist(),
            kv_seqlen=(start_pos + token_lengths).tolist(),
            kv_padding=kv_padding,
        )
        return self.forward_with_attn_bias(token_values, attn_bias, cache)
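

# Usage sketch, not called by the model (assumptions: mp_utils is already set
# up, and the batch state below is hypothetical): a decode step after
# prefilling two prompts of 5 and 3 tokens. Each sequence feeds exactly one
# new token back in, and start_pos tells the attention bias how many key/value
# entries each sequence already has in the cache.
def _example_decode_step(
    model: Transformer,
    next_tokens: torch.Tensor,  # shape (2,): one new token id per sequence
    cache: list[LayerCache],
) -> torch.Tensor:
    return model.forward(
        next_tokens,
        token_lengths=torch.tensor([1, 1]),
        start_pos=torch.tensor([5, 3]),
        cache=cache,
        kv_padding=1024,  # must match the per-sequence budget of `cache`
    )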


def make_cache(
    args: ModelArgs,
    length: int,
    device: Optional[Union[str, torch.device]] = None,
    n_layers: Optional[int] = None,
    dtype: Optional[torch.dtype] = None,
) -> list[LayerCache]:
    """
    Allocate a cache to be used with the Transformer module.

    Args:
        args (ModelArgs): the model configuration.
        length (int): per-layer cache size.
            It is usually budgeted as ``max_batch * max_seq``.
        device (torch.device, optional): the device on which
            the cache should be allocated.
        n_layers (int, optional): the number of layers to
            allocate a cache for (defaults to the model
            settings).
        dtype (torch.dtype, optional): the dtype to use for
            cache entries (defaults to the default dtype).

    Returns:
        The cache object to pass to ``Transformer.forward``.
    """
    head_dim = args.dim // args.n_heads
    n_kv_heads = args.n_kv_heads
    if n_kv_heads is None:
        n_kv_heads = args.n_heads
    n_local_kv_heads = n_kv_heads // mp_utils.get_world_size()

    if n_layers is None:
        n_layers = args.n_layers

    shape = (1, length, n_local_kv_heads, 1, head_dim)
    heads_per_group = args.n_heads // n_kv_heads
    expansion = (-1, -1, -1, heads_per_group, -1)
    return [
        (
            torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
            torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
        )
        for _ in range(n_layers)
    ]
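

# Usage sketch, not called by the model (assumptions: mp_utils reports a world
# size of 1 and the sizes are hypothetical): allocate a cache able to hold two
# sequences of up to 1024 key/value entries each, then prefill two prompts of
# 5 and 3 tokens with it.
def _example_prefill(model: Transformer, args: ModelArgs) -> torch.Tensor:
    kv_padding = 1024
    cache = make_cache(args, length=2 * kv_padding)
    token_values = torch.randint(0, args.vocab_size, (5 + 3,))
    logits = model.forward(
        token_values,
        token_lengths=torch.tensor([5, 3]),
        start_pos=torch.tensor([0, 0]),
        cache=cache,
        kv_padding=kv_padding,
    )
    # Thanks to the last-layer shortcut in TransformerBlock, `logits` has one
    # row per sequence: the prediction that follows the final prompt token.
    return logits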


def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
    """
    Take a prefix view of a larger cache.

    The original cache object keeps its full size and remains valid
    after the smaller alias has been used. This function is useful
    when a cache was allocated for a larger batch size than is
    currently necessary.

    Args:
        cache: the cache to take a view in.
        length (int): the desired length.

    Returns:
        A view in the input cache object.
    """
    if len(cache) > 0:
        assert cache[0][0].shape[1] >= length
    return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
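

# Usage sketch, not called by the model (hypothetical sizes): a cache budgeted
# for two sequences of 1024 key/value entries can serve a single-sequence
# batch through a prefix view; the view aliases the same storage, so the
# full-size cache stays valid and reusable afterwards.
def _example_cache_prefix(args: ModelArgs) -> list[LayerCache]:
    full_cache = make_cache(args, length=2 * 1024)
    return cache_prefix(full_cache, length=1024)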