# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import math

import pytest
import torch

from xformers.components.attention import FavorAttention, ScaledDotProduct
from xformers.components.attention.feature_maps import (
    FeatureMapType,
    NormDistribution,
    SMHyperbolic,
    SMOrf,
    SMReg,
)

_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


@pytest.mark.parametrize("features", [SMOrf, SMHyperbolic, SMReg])
def test_random_matrix(features):
    torch.random.manual_seed(0)

    DRAWS = 100
    DIM = 10
    for _ in range(DRAWS):
        q = features._get_random_ortho_matrix(
            1, DIM, device=_device, norm_distribution=NormDistribution.Xi
        ).squeeze(0)

        # Check that the rows are orthogonal (orthonormal once the row norms are divided out)
        q_normalized = torch.nn.functional.normalize(q, dim=1)
        assert torch.allclose(
            q_normalized @ q_normalized.transpose(0, 1),
            torch.eye(DIM, device=_device),
            atol=1e-4,
        )

        # Check that the row norm is in the right ballpark (sqrt(dim))
        assert abs(torch.mean(torch.norm(q, dim=1)).item() - math.sqrt(DIM)) < 1.0


def _plot_distribution(ortho_feature_map):
    # Debug helper: check the uniformity of the random matrix draws
    DRAWS = 1000
    DIM = 50
    q = ortho_feature_map._get_random_ortho_matrix(DRAWS, DIM, device=_device)

    x, y = [], []
    for qq in q:
        # For every matrix, look at the real and imaginary parts of its eigenvalues
        e = torch.linalg.eigvals(qq)
        x.append(e.real)
        y.append(e.imag)

    # Ideally the eigenvalues should be spread uniformly over a circle in the complex plane
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.kdeplot(x=torch.cat(x).cpu().numpy(), y=torch.cat(y).cpu().numpy())
    plt.axis("equal")
    plt.savefig("kde.png")


def _get_rng_data(device):
    emb = 10
    batch_size = 2
    seq_len = 20
    num_heads = 1
    shape = (batch_size * num_heads, seq_len, emb)
    return torch.randn(shape, device=device)


def test_feature_map_shape():
    # Check the delayed initialization of the feature map
    nb_random_features = 1000
    batch = _get_rng_data(_device)

    att = FavorAttention(
        dropout=0.0,
        dim_features=nb_random_features,
        feature_map_type=FeatureMapType.SMOrf,
    )
    _ = att(batch, batch, batch)

    assert att.feature_map.features.shape[0] == batch.shape[-1]
    assert att.feature_map.features.shape[1] == nb_random_features


def test_feature_map_redraw():
    # Check that the feature map is redrawn (or kept) depending on iter_before_redraw
    nb_random_features = 1000
    batch = _get_rng_data(_device)

    def check(should_redraw: bool):
        att = FavorAttention(
            dropout=0.0,
            dim_features=nb_random_features,
            feature_map_type=FeatureMapType.SMOrf,
            iter_before_redraw=1 if should_redraw else 100,
        )
        v0 = att(batch, batch, batch)
        assert att.feature_map is not None

        f0 = att.feature_map.features

        v1 = att(batch, batch, batch)
        f1 = att.feature_map.features

        # The feature map should only have been redrawn between v0 and v1 if requested
        assert should_redraw != torch.allclose(v0, v1)
        assert should_redraw != torch.allclose(f0, f1)  # type: ignore

    check(should_redraw=True)
    check(should_redraw=False)


@pytest.mark.parametrize("feature", ["sm_orf", "sm_hyp", "sm_reg"])
@pytest.mark.parametrize("causal", [True, False])
@pytest.mark.parametrize("normalize_inputs", [True, False])
@pytest.mark.parametrize("device", [_device])
def test_favor_approximation_accuracy(feature, causal, normalize_inputs, device):
    # Run two attentions in parallel: the standard scaled dot-product and the FAVOR approximation
    torch.random.manual_seed(0)
    query, key, value = (
        _get_rng_data(device),
        _get_rng_data(device),
        _get_rng_data(device),
    )

    for x in (query, key, value):
        x.requires_grad = True
    # Build the two attention heads
    sdp_attention = ScaledDotProduct(dropout=0.0, causal=causal).to(device)
    approx_attention = FavorAttention(
        dropout=0.0,
        causal=causal,
        dim_head=10,
        feature_map_type=FeatureMapType(feature),
        normalize_inputs=normalize_inputs,
    ).to(device)

    with torch.cuda.amp.autocast(enabled=_device.type == "cuda"):
        standard_attention_result = sdp_attention(query, key, value)
        approx_attention_result = approx_attention(query, key, value)

        mismatch = torch.mean(
            (standard_attention_result - approx_attention_result) ** 2
        ).item()

        if causal:
            # FIXME(@lefaudeux) the causal case seems significantly worse, not obvious why,
            # could be worth investigating
            assert mismatch < 0.6
        else:
            assert mismatch < 0.23

        # Check trainability
        torch.sum(approx_attention_result).backward()


if __name__ == "__main__":
    _plot_distribution(SMOrf)
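

# A minimal standalone sketch (an assumption, not part of the test suite above): it
# condenses the FavorAttention usage exercised in test_feature_map_shape into a single
# helper, handy for quick manual debugging outside of pytest. The name
# `_favor_smoke_check` and the 256-feature setting are illustrative choices only.
def _favor_smoke_check(device=_device):
    att = FavorAttention(
        dropout=0.0,
        dim_features=256,
        feature_map_type=FeatureMapType.SMOrf,
    ).to(device)
    batch = _get_rng_data(device)  # (batch * heads, seq_len, emb) random inputs
    return att(batch, batch, batch)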