# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import mp_utils
import torch
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
@dataclass
class ModelArgs:
dim: int = 512
n_layers: int = 8
n_heads: int = 8
n_kv_heads: Optional[int] = None
vocab_size: int = -1
ffn_dim_multiplier: Optional[float] = None
multiple_of: int = 256
"""
Enforces that the SwiGLU hidden layer size is a multiple
    of a large power of 2.
"""
norm_eps: float = 1e-5
rope_theta: float = 10000.0
"""
Positional encoding parameter; increase to 1e6 to run
Code Llama models with long contexts.
"""
LayerCache = Tuple[torch.Tensor, torch.Tensor]
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int,
n_heads: int,
n_kv_heads: int,
rope_theta: float,
):
super().__init__()
mp_size = mp_utils.get_world_size()
self.head_dim = head_dim
self.rope_theta = rope_theta
self.n_local_heads = n_heads // mp_size
self.n_local_kv_heads = n_kv_heads // mp_size
self.wqkv = nn.Linear(
dim,
(self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
bias=False,
)
self.wo = nn.Linear(
self.n_local_heads * head_dim,
dim,
bias=False,
)
self._register_load_state_dict_pre_hook(self.load_hook)
    # This adapter makes sure we can load vanilla
    # Llama checkpoints where wq, wk, and wv are
    # not fused into a single parameter.
def load_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
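    # Example (sketch): for a vanilla checkpoint, the hook above pops keys
    # such as "layers.0.attention.wq.weight", "layers.0.attention.wk.weight"
    # and "layers.0.attention.wv.weight" (the prefix depends on how this
    # module is nested) and stores their concatenation along dim 0 under
    # "layers.0.attention.wqkv.weight", matching the fused ``wqkv`` Linear.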
def forward(
self,
x: torch.Tensor,
cache: LayerCache,
attn_bias: AttnBias,
position_index: Optional[torch.Tensor],
) -> torch.Tensor:
# x.shape is (sum(seq_lens), dim)
#
        # Since we support heterogeneous sequence
# lengths, the hidden states are all
# concatenated together along the usual
# sequence dimension. The attention below
# finds out where sequences start & end
# using the provided attention bias.
xqkv = self.wqkv(x)
xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
xk, xv = xkv.chunk(2, 1)
output_shape = xq.shape
heads_per_group = self.n_local_heads // self.n_local_kv_heads
xq = xq.view(
1, xq.shape[0], self.n_local_kv_heads, heads_per_group, self.head_dim
)
xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, 1, self.head_dim)
xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, 1, self.head_dim)
cache_k, cache_v = cache
xq = rope_padded(
xq=xq,
xk=xk,
xv=xv,
cache_k=cache_k,
cache_v=cache_v,
attn_bias=attn_bias,
theta=self.rope_theta,
)
# rope_padded() updated the caches, so we
# call attention directly
output = fmha.memory_efficient_attention_forward(
xq, cache_k, cache_v, attn_bias
)
output = output.reshape(output_shape)
if position_index is not None:
output = output[position_index]
output = self.wo(output)
mp_utils.all_reduce(output)
return output
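# Shape walkthrough (sketch, assuming a single model-parallel rank, i.e.
# mp_size == 1): with dim=512, n_heads=8, n_kv_heads=2 and head_dim=64, the
# fused ``wqkv`` projection maps each of the sum(seq_lens) packed tokens to
# (8 + 2 * 2) * 64 = 768 features. The first 8 * 64 = 512 columns are the
# queries; the remaining 256 split evenly into keys and values. The views
# then expose the grouped-query layout expected by xformers:
#
#   xq: (1, sum(seq_lens), 2, 4, 64)   # heads_per_group = 8 // 2 = 4
#   xk: (1, sum(seq_lens), 2, 1, 64)
#   xv: (1, sum(seq_lens), 2, 1, 64)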
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
mp_size = mp_utils.get_world_size()
hidden_dim = int(2 * hidden_dim / 3)
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
assert hidden_dim % mp_size == 0
self.w13 = nn.Linear(
dim,
2 * hidden_dim // mp_size,
bias=False,
)
self.w2 = nn.Linear(
hidden_dim // mp_size,
dim,
bias=False,
)
self._register_load_state_dict_pre_hook(self.load_hook)
    # This adapter makes sure we can load vanilla
    # Llama checkpoints where w1 and w3 are not
    # fused into a single parameter.
def load_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
if prefix + "w1.weight" in state_dict:
w1 = state_dict.pop(prefix + "w1.weight")
w3 = state_dict.pop(prefix + "w3.weight")
state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])
def forward(self, x: torch.Tensor) -> torch.Tensor:
x13 = self.w13(x)
x1, x3 = x13.chunk(2, -1)
output = self.w2(F.silu(x1) * x3)
mp_utils.all_reduce(output)
return output
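# Worked example (sketch): TransformerBlock passes ``hidden_dim = 4 * dim``,
# so with the default ModelArgs (dim=512, multiple_of=256, no multiplier) and
# a single model-parallel rank:
#
#   hidden_dim = 4 * 512                      # 2048
#   hidden_dim = int(2 * 2048 / 3)            # 1365
#   hidden_dim = 256 * ((1365 + 255) // 256)  # 1536, next multiple of 256
#
# giving a w13 weight of shape (2 * 1536, 512) and a w2 weight of shape
# (512, 1536).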
class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs, layer_index: int):
super().__init__()
assert args.dim % args.n_heads == 0
head_dim = args.dim // args.n_heads
if args.n_kv_heads is not None:
n_kv_heads = args.n_kv_heads
else:
n_kv_heads = args.n_heads
mp_size = mp_utils.get_world_size()
assert args.n_heads % n_kv_heads == 0
assert args.n_heads % mp_size == 0
assert n_kv_heads % mp_size == 0
self.is_last_layer = layer_index + 1 == args.n_layers
self.attention = Attention(
dim=args.dim,
head_dim=head_dim,
n_heads=args.n_heads,
n_kv_heads=n_kv_heads,
rope_theta=args.rope_theta,
)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
cache: LayerCache,
attn_bias: AttnBias,
) -> torch.Tensor:
position_index = None
if self.is_last_layer and attn_bias.q_seqinfo.max_seqlen > 1:
position_index = attn_bias.q_seqinfo.seqstart[1:] - 1
h = self.attention.forward(
self.attention_norm(x),
cache,
attn_bias,
position_index=position_index,
)
if position_index is not None:
x = x[position_index]
h = h + x
out = h + self.feed_forward(self.ffn_norm(h))
return out
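# Note on ``position_index`` (sketch): during prefill, i.e. when the longest
# query in the batch spans more than one token, the last layer keeps only the
# hidden state of each sequence's final token, since only those rows are
# needed to predict the next token. For two packed prompts of lengths 3 and
# 5, ``attn_bias.q_seqinfo.seqstart`` is [0, 3, 8], so ``position_index`` is
# [2, 7] and the residual stream is narrowed from 8 rows to 2.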
class Transformer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
mp_size = mp_utils.get_world_size()
assert args.dim % mp_size == 0
assert args.vocab_size > 0
assert args.vocab_size % mp_size == 0
self.tok_embeddings = nn.Embedding(
num_embeddings=args.vocab_size,
embedding_dim=args.dim // mp_size,
)
self.layers = nn.ModuleList()
for layer_index in range(args.n_layers):
self.layers.append(TransformerBlock(args, layer_index))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(
args.dim,
args.vocab_size // mp_size,
bias=False,
)
@torch.no_grad()
def forward_with_attn_bias(
self,
token_values: torch.Tensor,
attn_bias: AttnBias,
cache: list[LayerCache],
) -> torch.Tensor:
h_parallel = self.tok_embeddings(token_values)
h = mp_utils.all_gather(h_parallel)
for i, layer in enumerate(self.layers):
h = layer(h, cache[i], attn_bias)
logits_parallel = self.output(self.norm(h))
logits = mp_utils.all_gather(logits_parallel)
return logits.float()
def forward(
self,
token_values: torch.Tensor,
token_lengths: torch.Tensor,
start_pos: torch.Tensor,
cache: list[LayerCache],
kv_padding: int,
) -> torch.Tensor:
attn_bias = AttnBias.from_seqlens(
q_seqlen=token_lengths.tolist(),
kv_seqlen=(start_pos + token_lengths).tolist(),
kv_padding=kv_padding,
)
return self.forward_with_attn_bias(token_values, attn_bias, cache)
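# Usage sketch (assuming ``mp_utils`` is initialized for the current process
# group, ``model`` is a Transformer on the right device, and ``cache`` comes
# from ``make_cache`` below; the token ids and kv_padding are illustrative):
# two prompts of lengths 3 and 5 are packed into one flat tensor of 8 ids.
#
#   token_values = torch.tensor([5, 12, 7, 9, 3, 42, 17, 8], device="cuda")
#   token_lengths = torch.tensor([3, 5])
#   start_pos = torch.zeros(2, dtype=torch.long)   # fresh sequences
#   logits = model.forward(
#       token_values, token_lengths, start_pos, cache, kv_padding=256
#   )
#
# Because of the last-token selection in TransformerBlock, ``logits`` has one
# row per sequence: the prediction for the token that follows each prompt.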
def make_cache(
args: ModelArgs,
length: int,
device: Optional[Union[str, torch.device]] = None,
n_layers: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
) -> list[LayerCache]:
"""
Allocate a cache to be used with the Transformer module.
Args:
args (ModelArgs): the model configuration.
        length (int): per-layer cache size.
            It is usually budgeted as ``max_batch * max_seq``.
device (torch.device, optional): the device on which
the cache should be allocated.
n_layers (int, optional): the number of layers to
allocate a cache for (defaults to the model
settings).
dtype (torch.dtype, optional): the dtype to use for
cache entries (defaults to the default dtype).
Returns:
        The cache object to pass to ``Transformer.forward``.
"""
head_dim = args.dim // args.n_heads
n_kv_heads = args.n_kv_heads
if n_kv_heads is None:
n_kv_heads = args.n_heads
n_local_kv_heads = n_kv_heads // mp_utils.get_world_size()
if n_layers is None:
n_layers = args.n_layers
shape = (1, length, n_local_kv_heads, 1, head_dim)
heads_per_group = args.n_heads // n_kv_heads
expansion = (-1, -1, -1, heads_per_group, -1)
return [
(
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
torch.zeros(shape, device=device, dtype=dtype).expand(expansion),
)
for _ in range(n_layers)
]
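# Example (sketch): budgeting a cache for up to 4 sequences of at most 1024
# tokens each, in bf16 on the current CUDA device; the numbers are
# illustrative only.
#
#   cache = make_cache(args, length=4 * 1024, device="cuda",
#                      dtype=torch.bfloat16)
#
# Each layer entry is a (keys, values) pair viewed as
# (1, length, n_local_kv_heads, heads_per_group, head_dim); the expand() call
# lets all query heads in a group read the same key/value storage without
# copying it.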
def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
"""
Take a prefix view of a larger cache.
    The original cache object keeps its full size and remains
    valid after the shrunken alias has been used. This function
    is useful when a cache was allocated for a larger batch size
    than is necessary.
Args:
cache: the cache to take a view in.
        length (int): the desired length.
    Returns:
        A view into the input cache object.
"""
if len(cache) > 0:
assert cache[0][0].shape[1] >= length
return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
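# Example (sketch, continuing the hypothetical budget above): if only 2
# sequences of at most 1024 tokens are in flight, a prefix view of the same
# storage can be passed to the model instead of allocating a new cache.
#
#   small_cache = cache_prefix(cache, length=2 * 1024)
#   logits = model.forward_with_attn_bias(token_values, attn_bias, small_cache)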