Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
import re | |
import warnings | |
from typing import Callable | |
import torch | |
# Lower bound applied to amax so the computed scale is never zero
# (avoids division by zero when calculating scale).
EPS = 1e-12


def scale(t, amax_t, dtype_t):
    """Rescale ``t`` into the finite range of ``dtype_t`` and cast it.

    The scale is ``clamp(amax_t.float(), min=EPS) / finfo(dtype_t).max``.
    ``t`` is divided by that scale, clamped to the finite range of
    ``dtype_t`` and converted to ``dtype_t``.

    Args:
        t: Tensor to quantize.
        amax_t: Tensor used as the magnitude reference for the scale;
            presumably the absolute maximum of ``t`` along some dimension
            (confirm against callers). It must broadcast against ``t``.
        dtype_t: Floating-point dtype to convert to (``torch.finfo`` is
            queried for its range).

    Returns:
        A ``(quantized, scale)`` tuple, where ``quantized`` has dtype
        ``dtype_t`` and ``scale`` is a float32 tensor shaped like ``amax_t``.
    """
    limits = torch.finfo(dtype_t)
    scale_t = amax_t.float().clamp(min=EPS) / limits.max
    rescaled = t / scale_t
    t_fp8 = rescaled.clamp(min=limits.min, max=limits.max).to(dtype_t)
    return t_fp8, scale_t
def matmul(
    first, amax_first, dtype_first, second_t, amax_second_t, dtype_second_t, bias
):
    """Quantize both operands with ``scale`` and multiply them via ``torch._scaled_mm``.

    Args:
        first: Left-hand operand.
        amax_first: Magnitude reference passed to ``scale`` for ``first``.
        dtype_first: Dtype ``first`` is quantized to.
        second_t: Right-hand operand, supplied transposed; it is transposed
            back with ``.t()`` before the multiply.
        amax_second_t: Magnitude reference passed to ``scale`` for ``second_t``.
        dtype_second_t: Dtype ``second_t`` is quantized to.
        bias: Forwarded unchanged as the ``bias`` argument of
            ``torch._scaled_mm`` (callers in this file also pass ``None``).

    Returns:
        The ``torch._scaled_mm`` result, requested as ``torch.bfloat16``.
    """
    lhs_q, lhs_scale = scale(first, amax_first, dtype_first)
    rhs_t_q, rhs_t_scale = scale(second_t, amax_second_t, dtype_second_t)
    # The right-hand operand and its scales are both held transposed, so
    # each is transposed back when handed to the kernel.
    return torch._scaled_mm(
        lhs_q,
        rhs_t_q.t(),
        scale_a=lhs_scale,
        scale_b=rhs_t_scale.t(),
        bias=bias,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )
class Fp8LinearFn(torch.autograd.Function):
    """Autograd function computing a linear layer through ``matmul``.

    The forward product and the gradient with respect to the input both go
    through ``matmul`` with ``torch.float8_e4m3fn`` operands; the weight
    gradient is a regular ``@`` product and the bias gradient a plain sum.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # they are invoked through ``Fp8LinearFn.apply`` rather than on an instance.
    @staticmethod
    def forward(ctx, a, b_t, bias):
        """Compute the product of ``a`` and ``b_t`` transposed, plus ``bias``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            a: Input; ``backward`` treats it as 2-D (callers in this file
                flatten leading dimensions first).
            b_t: Weight, supplied transposed relative to the product.
            bias: Optional bias tensor, or ``None``.

        Returns:
            The result of ``matmul`` (requested there as bfloat16).
        """
        # One amax per row (last dimension reduced, kept for broadcasting).
        amax_a = a.abs().amax(dim=-1, keepdim=True)
        amax_b_t = b_t.abs().amax(dim=-1, keepdim=True)
        out = matmul(
            a, amax_a, torch.float8_e4m3fn, b_t, amax_b_t, torch.float8_e4m3fn, bias
        )
        ctx.a_requires_grad = a.requires_grad
        ctx.b_requires_grad = b_t.requires_grad
        ctx.bias_requires_grad = bias.requires_grad if bias is not None else False
        # Only the global max of the per-row weight amax is kept: backward
        # quantizes the transposed weight, for which the per-row values of
        # ``b_t`` are not the per-row values needed.
        ctx.save_for_backward(a, b_t, amax_b_t.max())
        return out

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(a, b_t, bias)``; ``None`` where not required."""
        a, b_t, amax_b = ctx.saved_tensors

        if ctx.a_requires_grad:
            b = b_t.t().contiguous()
            amax_grad_out = grad_out.abs().amax(dim=-1, keepdim=True)
            # Broadcast the single saved amax to one value per row of ``b``.
            amax_b = amax_b.repeat(b.shape[0], 1)
            grad_a = matmul(
                grad_out,
                amax_grad_out,
                torch.float8_e4m3fn,
                b,
                amax_b,
                torch.float8_e4m3fn,
                None,
            )
        else:
            grad_a = None

        if ctx.b_requires_grad:
            # Weight gradient uses a regular matmul, not the scaled one.
            grad_b = grad_out.t() @ a
        else:
            grad_b = None

        if ctx.bias_requires_grad:
            grad_bias = grad_out.sum(dim=0)
        else:
            grad_bias = None

        return grad_a, grad_b, grad_bias
class Fp8Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose forward pass is routed through ``Fp8LinearFn``."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``input``, preserving its leading dimensions.

        All dimensions except the last are collapsed into one before calling
        ``Fp8LinearFn.apply`` and restored on the result afterwards.
        """
        leading_shape = input.shape[:-1]
        flat_input = input.flatten(end_dim=-2)
        flat_out = Fp8LinearFn.apply(flat_input, self.weight, self.bias)
        return flat_out.unflatten(0, leading_shape)
def named_replace(
    fn: Callable[[torch.nn.Module, str], torch.nn.Module],
    module: torch.nn.Module,
    name="",
) -> torch.nn.Module:
    """Rebuild a module tree bottom-up by passing every module through ``fn``.

    Children are processed (recursively) before their parent, and each child
    attribute is reassigned to whatever ``fn`` returned for it. Finally
    ``fn`` is applied to ``module`` itself.

    Args:
        fn: Called as ``fn(module, qualified_name)``; returns the module to
            keep in that position (the same object or a replacement).
        module: Root of the subtree to process.
        name: Dotted name of ``module``; empty for the root.

    Returns:
        The result of ``fn(module, name)``.
    """
    prefix = f"{name}." if name else ""
    # Snapshot the children first: the setattr below modifies the module's
    # child registry while the loop is running.
    children = list(module.named_children())
    for attr, child in children:
        setattr(module, attr, named_replace(fn, child, prefix + attr))
    return fn(module, name)
def convert_linears_to_fp8(
    root_module: torch.nn.Module, recipe: str, filter: str
) -> torch.nn.Module:
    """Replace matching ``torch.nn.Linear`` modules with ``Fp8Linear``.

    Every module that is exactly a ``torch.nn.Linear`` and whose dotted name
    is matched by ``filter`` (``re.search``) is swapped for an ``Fp8Linear``
    that shares the original ``weight`` and ``bias`` parameters.

    Side effects: sets ``torch._inductor.config.triton.multi_kernel = 1``,
    resets the dynamo code caches and resets the cudagraph trees.

    Args:
        root_module: Root of the module tree to convert (modified in place).
        recipe: Scaling recipe; only ``"rowwise"`` is supported.
        filter: Regular expression selecting modules by qualified name.

    Returns:
        The converted root (a new module if the root itself was replaced).

    Raises:
        RuntimeError: If ``recipe`` is not a known recipe.
        AssertionError: If a matching module is a ``torch.nn.Linear``
            subclass rather than ``torch.nn.Linear`` itself.
    """
    if recipe not in ["rowwise"]:
        raise RuntimeError(f"Unknown float8 recipe {recipe!r}")

    if recipe == "rowwise" and torch.__version__ < "2.5":
        # We need https://github.com/pytorch/pytorch/pull/134781.
        warnings.warn(
            "Float8 row-wise scaling is slow in PyTorch prior to v2.5.0",
            stacklevel=2,
        )

    # Multi-kernel makes Inductor auto-tune between a regular "streaming"-based
    # reduction kernel and a "persistent" reduction kernel. Since fp8 has some
    # multi-pass steps (e.g., first get amax, then scale), persistent kernels
    # should perform better.
    torch._inductor.config.triton.multi_kernel = 1

    filter_re = re.compile(filter)

    def replace(module: torch.nn.Module, name: str) -> torch.nn.Module:
        """Return an ``Fp8Linear`` for a selected Linear, else ``module`` unchanged."""
        if not isinstance(module, torch.nn.Linear) or not filter_re.search(name):
            return module
        # Explicit raises instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would let execution fall through to an
        # unbound ``new_module``.
        if type(module) is not torch.nn.Linear:
            raise AssertionError(str(type(module)))
        if recipe != "rowwise":
            raise AssertionError(recipe)
        new_module = Fp8Linear(
            in_features=module.in_features,
            out_features=module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        # Share the existing parameters instead of the freshly created ones.
        new_module.weight = module.weight
        new_module.bias = module.bias
        return new_module

    out = named_replace(replace, root_module)

    # Force re-compile everything
    torch._dynamo.reset_code_caches()
    from torch._inductor.cudagraph_trees import reset_cudagraph_trees

    reset_cudagraph_trees()

    return out