# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math

import pytest
import torch

from xformers.components.attention import FavorAttention, ScaledDotProduct
from xformers.components.attention.feature_maps import (
    FeatureMapType,
    NormDistribution,
    SMHyperbolic,
    SMOrf,
    SMReg,
)

_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# Exercise every feature map flavour imported above
@pytest.mark.parametrize("features", [SMOrf, SMHyperbolic, SMReg])
def test_random_matrix(features):
    torch.random.manual_seed(0)

    DRAWS = 100
    DIM = 10
    for _ in range(DRAWS):
        q = features._get_random_ortho_matrix(
            1, DIM, device=_device, norm_distribution=NormDistribution.Xi
        ).squeeze(0)

        # Check that the matrix is indeed orthonormal
        torch.allclose(
            torch.diag(q @ q.transpose(0, 1)),
            torch.diag(torch.ones(10, device=_device)),
        )

        # Check that the row norm is in the right ballpark (sqrt(dim))
        assert abs(torch.mean(torch.norm(q, dim=1)).item() - math.sqrt(DIM)) < 1.0
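

# A minimal sketch of the construction the test above expects (an assumption for
# illustration, not the xformers implementation): mutually orthogonal rows whose
# norms are chi(dim) distributed, so the mean row norm lands close to sqrt(dim).
def _demo_ortho_rows_with_xi_norms(dim: int = 10) -> torch.Tensor:
    q, _ = torch.linalg.qr(torch.randn(dim, dim))  # orthonormal rows
    row_norms = torch.randn(dim, dim).norm(dim=1)  # chi(dim) samples, mean ~ sqrt(dim)
    return q * row_norms.unsqueeze(1)  # rows stay orthogonal, norms follow chi(dim)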


def _plot_distribution(ortho_feature_map):
    # Debug helper, check the uniformity of the random matrix draws
    DRAWS = 1000
    DIM = 50
    q = ortho_feature_map._get_random_ortho_matrix(DRAWS, DIM, device=_device)
    x, y = [], []

    for qq in q:
        # For every matrix, look at the real and imaginary parts of its eigenvalues
        e = torch.linalg.eigvals(qq)
        x.append(e.real)
        y.append(e.imag)

    # Ideally the distribution of the real and imaginary parts of the eigenvalues
    # should form a circle in the complex plane
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.kdeplot(x=torch.cat(x).cpu().numpy(), y=torch.cat(y).cpu().numpy())
    plt.axis("equal")
    plt.savefig("kde.png")


def _get_rng_data(device):
    emb = 10
    batch_size = 2
    seq_len = 20
    num_heads = 1
    shape = (batch_size * num_heads, seq_len, emb)
    return torch.randn(shape, device=device)


def test_feature_map_shape():
    # Check the delayed initialization of the feature map
    nb_random_features = 1000
    batch = _get_rng_data(_device)

    att = FavorAttention(
        dropout=0.0,
        dim_features=nb_random_features,
        feature_map_type=FeatureMapType.SMOrf,
    )
    _ = att(batch, batch, batch)

    assert att.feature_map.features.shape[0] == batch.shape[-1]
    assert att.feature_map.features.shape[1] == nb_random_features
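

# A minimal sketch of the lazy initialization pattern the test above relies on
# (hypothetical helper, not the xformers FeatureMap API): the random projection is
# only materialized on the first call, once the input dimension is known, which is
# why `att.feature_map.features` is only populated after a forward pass.
class _LazyFeaturesDemo:
    def __init__(self, dim_features: int):
        self.dim_features = dim_features
        self.features = None  # drawn lazily on the first call

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if self.features is None:
            self.features = torch.randn(x.shape[-1], self.dim_features, device=x.device)
        return x @ self.features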


def test_feature_map_redraw():
    # Check the delayed initialization of the feature map
    nb_random_features = 1000
    batch = _get_rng_data(_device)

    def check(should_redraw: bool):
        att = FavorAttention(
            dropout=0.0,
            dim_features=nb_random_features,
            feature_map_type=FeatureMapType.SMOrf,
            iter_before_redraw=1 if should_redraw else 100,
        )
        v0 = att(batch, batch, batch)
        assert att.feature_map is not None
        f0 = att.feature_map.features

        v1 = att(batch, batch, batch)
        f1 = att.feature_map.features

        # The outputs and the feature maps should differ if and only if
        # a redraw happened between the two calls
        assert should_redraw != torch.allclose(v0, v1)
        assert should_redraw != torch.allclose(f0, f1)  # type: ignore

    check(should_redraw=True)
    check(should_redraw=False)


# Sweep the feature maps, causality and input normalization, on whichever device is available
@pytest.mark.parametrize(
    "feature", [FeatureMapType.SMOrf, FeatureMapType.SMHyperbolic, FeatureMapType.SMReg]
)
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("normalize_inputs", [True, False])
@pytest.mark.parametrize("device", [_device])
def test_favor_approximation_accuracy(feature, causal, normalize_inputs, device):
    # Run two attentions in parallel, the normal scaled dot product and the favor approximation
    torch.random.manual_seed(0)
    query, key, value = (
        _get_rng_data(device),
        _get_rng_data(device),
        _get_rng_data(device),
    )

    for x in (query, key, value):
        x.requires_grad = True

    # Build the two attention heads
    sdp_attention = ScaledDotProduct(dropout=0.0, causal=causal).to(device)
    approx_attention = FavorAttention(
        dropout=0.0,
        causal=causal,
        dim_head=10,
        feature_map_type=FeatureMapType(feature),
        normalize_inputs=normalize_inputs,
    ).to(device)

    with torch.cuda.amp.autocast(enabled=_device.type == "cuda"):
        standard_attention_result = sdp_attention(query, key, value)
        approx_attention_result = approx_attention(query, key, value)

        mismatch = torch.mean(
            (standard_attention_result - approx_attention_result) ** 2
        ).item()

        if causal:
            # FIXME(@lefaudeux) the causal case seems significantly worse, not obvious why,
            # could be worth investigating
            assert mismatch < 0.6
        else:
            assert mismatch < 0.23

        # Check trainability
        torch.sum(approx_attention_result).backward()
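

# A minimal sketch of the identity the approximation test above relies on
# (illustrative only, not the xformers implementation): with w ~ N(0, I), the
# positive random features phi(x) = exp(w @ x - |x|^2 / 2) satisfy
# E[phi(q) . phi(k)] = exp(q . k), i.e. they estimate the softmax kernel.
def _demo_softmax_kernel_estimate(dim: int = 10, nb_features: int = 100_000):
    torch.random.manual_seed(0)
    q = torch.randn(dim) / math.sqrt(dim)
    k = torch.randn(dim) / math.sqrt(dim)
    w = torch.randn(nb_features, dim)  # plain iid Gaussian features, no orthogonalization

    phi_q = torch.exp(w @ q - q.dot(q) / 2) / math.sqrt(nb_features)
    phi_k = torch.exp(w @ k - k.dot(k) / 2) / math.sqrt(nb_features)

    exact = torch.exp(q.dot(k)).item()
    approx = phi_q.dot(phi_k).item()
    return exact, approx  # the two values should be close for large nb_features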


if __name__ == "__main__":
    _plot_distribution(SMOrf)