|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from contextlib import contextmanager |
|
|
|
import packaging.version |
|
import torch |
|
import transformers |
|
|
|
|
|
@contextmanager
def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
    """Context manager that gathers *param* via DeepSpeed when ZeRO-3 is active.

    When DeepSpeed ZeRO-3 is enabled, wraps the body in
    ``deepspeed.zero.GatheredParameters``; otherwise it is a no-op context.

    Args:
        param: The parameter (or collection of parameters) to gather.
        modifier_rank: Rank that is allowed to modify the gathered parameter.
        fwd_module: Optional module forwarded to ``GatheredParameters``.
    """
    # transformers moved the deepspeed integration helpers in 4.33.0, so pick
    # the import location based on the installed version.
    installed = packaging.version.parse(transformers.__version__)
    if installed >= packaging.version.parse("4.33.0"):
        from transformers.integrations import is_deepspeed_zero3_enabled
    else:
        from transformers.deepspeed import is_deepspeed_zero3_enabled

    if is_deepspeed_zero3_enabled():
        # Imported lazily so deepspeed is only required when actually in use.
        import deepspeed

        with deepspeed.zero.GatheredParameters(param, modifier_rank=modifier_rank, fwd_module=fwd_module):
            yield
    else:
        yield
|
|
|
|
|
def dequantize_module_weight(module: torch.nn.Module) -> torch.nn.Parameter:
    """
    Return the dequantized weight of *module*.

    Modules exposing a ``W_q`` attribute (e.g. HQQ layers — TODO confirm) are
    dequantized through their own ``dequantize()`` method. bitsandbytes
    ``Params4bit``/``Int8Params`` weights are handed to
    ``dequantize_bnb_weight``. Any other weight is returned unchanged.

    Raises:
        TypeError: If ``module.weight`` is not a ``torch.nn.Parameter``.
    """
    # Module knows how to dequantize itself.
    if hasattr(module, "W_q"):
        return module.dequantize()

    weight = module.weight
    if not isinstance(weight, torch.nn.Parameter):
        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

    # Not a bitsandbytes-quantized weight: nothing to do.
    if weight.__class__.__name__ not in ("Params4bit", "Int8Params"):
        return weight

    quant_state = getattr(module, "state", None)
    original_device = weight.device
    on_cpu = original_device.type == "cpu"

    dequantized = dequantize_bnb_weight(weight, state=quant_state)
    if on_cpu:
        # NOTE(review): restore the module weight to its original device in
        # case dequantization moved it — see the CPU->CUDA handling in
        # dequantize_bnb_weight.
        module.weight = module.weight.to(original_device)
    return dequantized
|
|
|
|
|
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
    """Dequantize a bitsandbytes 4-bit or 8-bit weight into a plain tensor.

    bitsandbytes cannot dequantize on CPU, so a CPU weight is temporarily
    moved to CUDA and the result is moved back to the original device.

    Args:
        weight: A bitsandbytes ``Params4bit`` or ``Int8Params`` parameter.
        state: The bnb quantization state; required for the 8-bit path.
    """
    import bitsandbytes as bnb

    original_device = weight.device
    needs_move_back = original_device.type == "cpu"
    if needs_move_back:
        weight = weight.to(torch.device("cuda"))

    if weight.__class__.__name__ == "Params4bit":
        result = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
        return result.to(original_device) if needs_move_back else result

    # 8-bit path: recover the full-precision weight by multiplying the
    # quantized matrix with an identity matrix via bnb's int8 matmul.
    if state.SCB is None:
        state.SCB = weight.SCB

    identity = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
    identity, _, SCim, _, _ = bnb.functional.double_quant(identity)
    identity, Sim = bnb.functional.transform(identity, "col32")
    # Lazily build and cache the transformed weight on the state object.
    if state.CxB is None:
        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
    out32, Sout32 = bnb.functional.igemmlt(identity, state.CxB, Sim, state.SB)
    result = bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
    return result.to(original_device) if needs_move_back else result
|
|