Spaces:
Sleeping
Sleeping
Dit-document-layout-analysis
/
unilm
/decoding
/GAD
/fairseq
/modules
/quantization
/scalar
/utils.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging | |
from operator import attrgetter | |
import torch.distributed as dist | |
import torch.nn as nn | |
from ..pq.utils import attrsetter, get_layers | |
from .modules import ActivationQuantizer, IntConv2d, IntEmbedding, IntLinear | |
# Maps each supported torch.nn layer type to the scalar-quantized class that
# replaces it in quantize_model_.
MAPPING = {
    nn.Linear: IntLinear,
    nn.Embedding: IntEmbedding,
    nn.Conv2d: IntConv2d,
}
def quantize_model_(model, p=0.2, bits=8, update_step=3000):
    """
    Replaces all supported modules with their scalar quantized counterpart and
    registers hooks to quantize the post-activations of those modules.

    The model is modified in place. A module is supported when it is an
    instance (including subclasses) of one of the keys of MAPPING; any other
    module is logged and left untouched.

    Args:
        - model: a nn.Module
        - p: amount of noise (0 for no noise, 1 to quantize all the weights/activations)
        - bits: number of bits
        - update_step: update quantization parameters every update_step steps

    Returns:
        The list of layer names returned by get_layers(model, "(.*?)"). Note
        that this list also contains the names of layers that were skipped
        because their module type is not supported.
    """
    # book-keeping: only the master process (or a non-distributed run) logs.
    # This does not depend on the layer, so it is computed once up front.
    is_master_process = (not dist.is_initialized()) or dist.get_rank() == 0

    # quantize all layers
    quantized_layers = get_layers(model, "(.*?)")

    for layer in quantized_layers:
        # recover module
        module = attrgetter(layer)(model)
        if is_master_process:
            logging.info(
                f"Quantizing layer {layer} with bits={bits} and QuantNoise={p}"
            )

        # Resolve the quantized counterpart with isinstance rather than an
        # exact-class dict lookup, so subclasses of a supported type (which
        # would otherwise raise KeyError on MAPPING[module.__class__]) are
        # handled as well.
        QuantizedModule = next(
            (
                quantized_cls
                for float_cls, quantized_cls in MAPPING.items()
                if isinstance(module, float_cls)
            ),
            None,
        )
        if QuantizedModule is None:
            if is_master_process:
                logging.info(f"Module {module} not yet supported for quantization")
            continue

        # quantization params
        q_params = {
            "p": p,
            "update_step": update_step,
            "bits": bits,
            "method": "histogram",
            "counter": 0,
        }

        # instantiate the quantized counterpart without running __init__, then
        # give it the state of the original module plus the quantization
        # params. The original module's __dict__ is only read, never mutated.
        quantized_module = QuantizedModule.__new__(QuantizedModule)
        quantized_module.__dict__.update(module.__dict__)
        quantized_module.__dict__.update(q_params)

        # activation quantization: the instance is not kept, so it presumably
        # attaches its hooks to quantized_module on construction.
        ActivationQuantizer(quantized_module, p=0, bits=bits, method="histogram")

        # replace layer by its quantized counterpart
        attrsetter(layer)(model, quantized_module)

    # return the names of all layers that were considered
    return quantized_layers