from fnmatch import fnmatch
from typing import List, Optional, Union

import torch
from optimum.quanto.quantize import _quantize_submodule
from optimum.quanto.tensor import Optimizer, qtype, qtypes
from torchao.quantization.quant_api import (
    quantize_ as torchao_quantize_,
    Float8WeightOnlyConfig,
    UIntXWeightOnlyConfig,
)
# Reimplemented here because quanto's quantize() had a bug where it was using
# `exclude` instead of `include`.
# Module classes that optimum-quanto has already quantized; these are skipped below.
Q_MODULES = ['QLinear', 'QConv2d', 'QEmbedding', 'QBatchNorm2d', 'QLayerNorm', 'QConvTranspose2d', 'QEmbeddingBag']

# torchao weight-only quantization configs addressable by name, analogous to quanto's `qtypes`.
torchao_qtypes = {
    # "int4": Int4WeightOnlyConfig(),
    "uint2": UIntXWeightOnlyConfig(torch.uint2),
    "uint3": UIntXWeightOnlyConfig(torch.uint3),
    "uint4": UIntXWeightOnlyConfig(torch.uint4),
    "uint5": UIntXWeightOnlyConfig(torch.uint5),
    "uint6": UIntXWeightOnlyConfig(torch.uint6),
    "uint7": UIntXWeightOnlyConfig(torch.uint7),
    "uint8": UIntXWeightOnlyConfig(torch.uint8),
    "float8": Float8WeightOnlyConfig(),
}


class aotype:
    """Lightweight wrapper pairing a torchao config with its string name."""

    def __init__(self, name: str):
        self.name = name
        self.config = torchao_qtypes[name]


def get_qtype(qtype: Union[str, qtype]) -> Union[qtype, aotype]:
    """Resolve a string or qtype into either a quanto qtype or a torchao aotype."""
    if qtype in torchao_qtypes:
        # torchao names (e.g. "uint4", "float8") take precedence over quanto names
        return aotype(qtype)
    if isinstance(qtype, str):
        return qtypes[qtype]
    return qtype
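
# Illustrative mapping of names to backends (a sketch, not executed on import;
# "qint8" is one of optimum-quanto's built-in qtype names, not defined here):
#   get_qtype("uint4")          -> aotype wrapping UIntXWeightOnlyConfig(torch.uint4)
#   get_qtype("qint8")          -> the quanto qtype looked up in `qtypes`
#   get_qtype(qtypes["qint8"])  -> returned unchanged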


def quantize(
    model: torch.nn.Module,
    weights: Optional[Union[str, qtype, aotype]] = None,
    activations: Optional[Union[str, qtype]] = None,
    optimizer: Optional[Optimizer] = None,
    include: Optional[Union[str, List[str]]] = None,
    exclude: Optional[Union[str, List[str]]] = None,
):
"""Quantize the specified model submodules | |
Recursively quantize the submodules of the specified parent model. | |
Only modules that have quantized counterparts will be quantized. | |
If include patterns are specified, the submodule name must match one of them. | |
If exclude patterns are specified, the submodule must not match one of them. | |
Include or exclude patterns are Unix shell-style wildcards which are NOT regular expressions. See | |
https://docs.python.org/3/library/fnmatch.html for more details. | |
Note: quantization happens in-place and modifies the original model and its descendants. | |
Args: | |
model (`torch.nn.Module`): the model whose submodules will be quantized. | |
weights (`Optional[Union[str, qtype]]`): the qtype for weights quantization. | |
activations (`Optional[Union[str, qtype]]`): the qtype for activations quantization. | |
include (`Optional[Union[str, List[str]]]`): | |
Patterns constituting the allowlist. If provided, module names must match at | |
least one pattern from the allowlist. | |
exclude (`Optional[Union[str, List[str]]]`): | |
Patterns constituting the denylist. If provided, module names must not match | |
any patterns from the denylist. | |
""" | |
    if include is not None:
        include = [include] if isinstance(include, str) else include
    if exclude is not None:
        exclude = [exclude] if isinstance(exclude, str) else exclude
    for name, m in model.named_modules():
        if include is not None and not any(fnmatch(name, pattern) for pattern in include):
            continue
        if exclude is not None and any(fnmatch(name, pattern) for pattern in exclude):
            continue
        try:
            # skip modules that quanto has already replaced with a quantized counterpart
            if m.__class__.__name__ in Q_MODULES:
                continue
            if isinstance(weights, aotype):
                # torchao path: weight-only quantization; activations and optimizer are ignored
                torchao_quantize_(m, weights.config)
            else:
                _quantize_submodule(model, name, m, weights=weights,
                                    activations=activations, optimizer=optimizer)
        except Exception as e:
            print(f"Failed to quantize {name}: {e}")
            raise
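

if __name__ == "__main__":
    # Minimal usage sketch (assumes optimum-quanto and torchao are installed).
    # TinyNet and the pattern names below are illustrative placeholders, not
    # part of the original module.
    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj_in = torch.nn.Linear(16, 32)
            self.proj_out = torch.nn.Linear(32, 4)

        def forward(self, x):
            return self.proj_out(torch.relu(self.proj_in(x)))

    model = TinyNet()
    # quanto path: "qint8" resolves to a quanto qtype; proj_out stays in full precision
    quantize(model, weights=get_qtype("qint8"), exclude=["proj_out"])
    # torchao path: an aotype such as get_qtype("uint4") would route matching
    # submodules to torchao_quantize_ instead, e.g.:
    #   quantize(model, weights=get_qtype("uint4"), include=["proj_*"])
    print(model)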