Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .em import EM, EmptyClusterResolveError | |
class PQ(EM):
    """
    Quantizes the layer weights W with the standard Product Quantization
    technique. This learns a codebook of codewords or centroids of size
    block_size from W. For further reference on using PQ to quantize
    neural networks, see "And the Bit Goes Down: Revisiting the Quantization
    of Neural Networks", Stock et al., ICLR 2020.

    PQ is performed in two steps:
    (1) The matrix W (weights of a fully-connected or convolutional layer)
        is reshaped to (block_size, -1).
            - If W is fully-connected (2D, out_features x in_features), each
              row of length in_features is split into contiguous blocks of
              size block_size.
            - If W is convolutional (4D, out_channels x in_channels x k_h x k_w),
              each filter, flattened in row-major order to
              in_channels * k_h * k_w values, is split into contiguous blocks
              of size block_size.
    (2) We apply the standard EM/k-means algorithm to the resulting reshaped matrix.

    Args:
        - W: weight matrix to quantize, of size (out_features x in_features)
          for a fully-connected layer or
          (out_channels x in_channels x k_h x k_w) for a convolutional layer
        - block_size: size of the blocks (subvectors)
        - n_centroids: number of centroids
        - n_iter: number of k-means iterations
        - eps: for cluster reassignment when an empty cluster is found
        - max_tentatives: for cluster reassignment when an empty cluster is found
        - verbose: print information after each iteration

    Remarks:
        - block_size must divide in_features (fully-connected) or
          in_channels * k_h * k_w (convolutional)
    """

    def __init__(
        self,
        W,
        block_size,
        n_centroids=256,
        n_iter=20,
        eps=1e-6,
        max_tentatives=30,
        verbose=True,
    ):
        # block_size must be set before _reshape, which reads it.
        self.block_size = block_size
        W_reshaped = self._reshape(W)
        super().__init__(
            W_reshaped,
            n_centroids=n_centroids,
            n_iter=n_iter,
            eps=eps,
            max_tentatives=max_tentatives,
            verbose=verbose,
        )

    def _reshape(self, W):
        """
        Reshapes the matrix W as explained in step (1) of the class docstring.

        Side effect: records the layer dimensions on ``self``. These are
        (out_features, in_features) for 2D weights, or (out_channels,
        in_channels, k_h, k_w) for 4D weights. ``decode`` uses them to
        rebuild the original shape.

        Returns a matrix of shape (block_size, n_blocks * W.size(0)). Column
        ``b * W.size(0) + o`` holds block ``b`` of output unit ``o``.

        Raises:
            AssertionError: if block_size does not divide the per-output
                length (in_features, or in_channels * k_h * k_w).
            NotImplementedError: if W is neither 2D nor 4D.
        """
        n_dims = len(W.size())
        if n_dims == 2:
            # fully connected: by convention the weight has size out_features x in_features
            self.out_features, self.in_features = W.size()
            assert (
                self.in_features % self.block_size == 0
            ), "Linear: in_features must be a multiple of block_size"
        elif n_dims == 4:
            # convolutional: each filter is flattened to in_channels * k_h * k_w values
            self.out_channels, self.in_channels, self.k_h, self.k_w = W.size()
            assert (
                self.in_channels * self.k_h * self.k_w
            ) % self.block_size == 0, (
                "Conv2d: in_channels * k_h * k_w must be a multiple of block_size"
            )
        else:
            # not implemented
            raise NotImplementedError(W.size())

        # Both layouts share the same pipeline:
        # (out, n_blocks, block_size) -> permute -> (block_size, n_blocks, out)
        # -> flatten -> (block_size, n_blocks * out).
        return (
            W.reshape(W.size(0), -1, self.block_size)
            .permute(2, 1, 0)
            .flatten(1, 2)
        )

    def encode(self):
        """
        Performs up to self.n_iter EM steps.

        Calls ``initialize_centroids`` and then ``step(i)`` for each
        iteration. Both are inherited from the EM base class, which is not
        shown here. If a step raises EmptyClusterResolveError, iteration
        stops early. The error is deliberately not propagated, so the
        current state is kept.
        """
        self.initialize_centroids()
        for i in range(self.n_iter):
            try:
                self.step(i)
            except EmptyClusterResolveError:
                break

    def decode(self):
        """
        Returns the encoded full weight matrix. Must be called after
        the encode function.

        The result has shape (out_features, in_features) for a
        fully-connected layer, or (out_channels, in_channels, k_h, k_w) for
        a convolutional layer. It inverts the column layout produced by
        ``_reshape``.

        NOTE(review): assumes ``self.centroids`` and ``self.assignments`` are
        populated by the EM base class during ``encode``, with one centroid
        index per reshaped column. Confirm against EM.
        """
        # k_h is only set by _reshape for 4D (convolutional) weights.
        is_linear = "k_h" not in self.__dict__
        # One block_size-long codeword per column of the reshaped matrix.
        codewords = self.centroids[self.assignments]

        if is_linear:
            # fully connected case:
            # (n_blocks * out, bs) -> (n_blocks, out, bs) -> (out, n_blocks, bs) -> (out, in)
            return (
                codewords.reshape(-1, self.out_features, self.block_size)
                .permute(1, 0, 2)
                .flatten(1, 2)
            )

        # convolutional case: same inversion, then restore the 4D filter shape
        return (
            codewords.reshape(-1, self.out_channels, self.block_size)
            .permute(1, 0, 2)
            .reshape(self.out_channels, self.in_channels, self.k_h, self.k_w)
        )