Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import os | |
import random | |
from collections import Counter | |
import torch | |
class EM: | |
""" | |
EM algorithm used to quantize the columns of W to minimize | |
||W - W_hat||^2 | |
Args: | |
- W: weight matrix of size (in_features x out_features) | |
- n_iter: number of k-means iterations | |
- n_centroids: number of centroids (size of codebook) | |
- eps: for cluster reassignment when an empty cluster is found | |
- max_tentatives for cluster reassignment when an empty cluster is found | |
- verbose: print error after each iteration | |
Remarks: | |
- If one cluster is empty, the most populated cluster is split into | |
two clusters | |
- All the relevant dimensions are specified in the code | |
""" | |
def __init__( | |
self, W, n_centroids=256, n_iter=20, eps=1e-6, max_tentatives=30, verbose=True | |
): | |
self.W = W | |
self.n_centroids = n_centroids | |
self.n_iter = n_iter | |
self.eps = eps | |
self.max_tentatives = max_tentatives | |
self.verbose = verbose | |
self.centroids = torch.Tensor() | |
self.assignments = torch.Tensor() | |
self.objective = [] | |
def initialize_centroids(self): | |
""" | |
Initializes the centroids by sampling random columns from W. | |
""" | |
in_features, out_features = self.W.size() | |
indices = torch.randint( | |
low=0, high=out_features, size=(self.n_centroids,) | |
).long() | |
self.centroids = self.W[:, indices].t() # (n_centroids x in_features) | |
def step(self, i): | |
""" | |
There are two standard steps for each iteration: expectation (E) and | |
minimization (M). The E-step (assignment) is performed with an exhaustive | |
search and the M-step (centroid computation) is performed with | |
the exact solution. | |
Args: | |
- i: step number | |
Remarks: | |
- The E-step heavily uses PyTorch broadcasting to speed up computations | |
and reduce the memory overhead | |
""" | |
# assignments (E-step) | |
distances = self.compute_distances() # (n_centroids x out_features) | |
self.assignments = torch.argmin(distances, dim=0) # (out_features) | |
n_empty_clusters = self.resolve_empty_clusters() | |
# centroids (M-step) | |
for k in range(self.n_centroids): | |
W_k = self.W[:, self.assignments == k] # (in_features x size_of_cluster_k) | |
self.centroids[k] = W_k.mean(dim=1) # (in_features) | |
# book-keeping | |
obj = (self.centroids[self.assignments].t() - self.W).norm(p=2).item() | |
self.objective.append(obj) | |
if self.verbose: | |
logging.info( | |
f"Iteration: {i},\t" | |
f"objective: {obj:.6f},\t" | |
f"resolved empty clusters: {n_empty_clusters}" | |
) | |
def resolve_empty_clusters(self): | |
""" | |
If one cluster is empty, the most populated cluster is split into | |
two clusters by shifting the respective centroids. This is done | |
iteratively for a fixed number of tentatives. | |
""" | |
# empty clusters | |
counts = Counter(map(lambda x: x.item(), self.assignments)) | |
empty_clusters = set(range(self.n_centroids)) - set(counts.keys()) | |
n_empty_clusters = len(empty_clusters) | |
tentatives = 0 | |
while len(empty_clusters) > 0: | |
# given an empty cluster, find most populated cluster and split it into two | |
k = random.choice(list(empty_clusters)) | |
m = counts.most_common(1)[0][0] | |
e = torch.randn_like(self.centroids[m]) * self.eps | |
self.centroids[k] = self.centroids[m].clone() | |
self.centroids[k] += e | |
self.centroids[m] -= e | |
# recompute assignments | |
distances = self.compute_distances() # (n_centroids x out_features) | |
self.assignments = torch.argmin(distances, dim=0) # (out_features) | |
# check for empty clusters | |
counts = Counter(map(lambda x: x.item(), self.assignments)) | |
empty_clusters = set(range(self.n_centroids)) - set(counts.keys()) | |
# increment tentatives | |
if tentatives == self.max_tentatives: | |
logging.info( | |
f"Could not resolve all empty clusters, {len(empty_clusters)} remaining" | |
) | |
raise EmptyClusterResolveError | |
tentatives += 1 | |
return n_empty_clusters | |
def compute_distances(self): | |
""" | |
For every centroid m, computes | |
||M - m[None, :]||_2 | |
Remarks: | |
- We rely on PyTorch's broadcasting to speed up computations | |
and reduce the memory overhead | |
- Without chunking, the sizes in the broadcasting are modified as: | |
(n_centroids x n_samples x out_features) -> (n_centroids x out_features) | |
- The broadcasting computation is automatically chunked so that | |
the tensors fit into the memory of the GPU | |
""" | |
nb_centroids_chunks = 1 | |
while True: | |
try: | |
return torch.cat( | |
[ | |
(self.W[None, :, :] - centroids_c[:, :, None]).norm(p=2, dim=1) | |
for centroids_c in self.centroids.chunk( | |
nb_centroids_chunks, dim=0 | |
) | |
], | |
dim=0, | |
) | |
except RuntimeError: | |
nb_centroids_chunks *= 2 | |
def assign(self): | |
""" | |
Assigns each column of W to its closest centroid, thus essentially | |
performing the E-step in train(). | |
Remarks: | |
- The function must be called after train() or after loading | |
centroids using self.load(), otherwise it will return empty tensors | |
""" | |
distances = self.compute_distances() # (n_centroids x out_features) | |
self.assignments = torch.argmin(distances, dim=0) # (out_features) | |
def save(self, path, layer): | |
""" | |
Saves centroids and assignments. | |
Args: | |
- path: folder used to save centroids and assignments | |
""" | |
torch.save(self.centroids, os.path.join(path, "{}_centroids.pth".format(layer))) | |
torch.save( | |
self.assignments, os.path.join(path, "{}_assignments.pth".format(layer)) | |
) | |
torch.save(self.objective, os.path.join(path, "{}_objective.pth".format(layer))) | |
def load(self, path, layer): | |
""" | |
Loads centroids and assignments from a given path | |
Args: | |
- path: folder use to load centroids and assignments | |
""" | |
self.centroids = torch.load( | |
os.path.join(path, "{}_centroids.pth".format(layer)) | |
) | |
self.assignments = torch.load( | |
os.path.join(path, "{}_assignments.pth".format(layer)) | |
) | |
self.objective = torch.load( | |
os.path.join(path, "{}_objective.pth".format(layer)) | |
) | |
class EmptyClusterResolveError(Exception): | |
pass | |