Flexstorydiff / xformers /stubs /torch_stub_tests.py
FlexTheAi's picture
Upload folder using huggingface_hub
e202b16 verified
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple, TypeVar
import torch
import torch.nn as nn
from pyre_extensions import TypeVarTuple, Unpack
from torch import Tensor
from typing_extensions import Literal as L
Ts = TypeVarTuple("Ts")
N = TypeVar("N", bound=int)
# flake8: noqa
"""
Tensor shape signatures can get complicated and hard to debug. We are basically
writing code at the level of types.
It's helpful to have type-level unit tests for the stubs.
Take care to add both a positive and a negative test for your stub. That way,
even if someone changes the stub to return a bad type like `Any`, we will still
be warned by an unused-ignore error. Otherwise, `y: Tensor[int, L[2], L[3]] =
foo(x)` would silently pass because `Any` is compatible with any type.
Use `pyre --output=json | pyre-upgrade` to add the `pyre-fixme` comment for you.
"""
def test_sin() -> None:
x: Tensor[int, L[2], L[3]]
same_shape_as_x: Tensor[int, L[2], L[3]]
not_same_shape_as_x: Tensor[int, L[2], L[99]]
y: Tensor[int, L[2], L[3]] = torch.sin(x)
# pyre-fixme[9]: y2 has type `Tensor[int, typing_extensions.Literal[2],
# typing_extensions.Literal[4]]`; used as `Tensor[int,
# typing_extensions.Literal[2], typing_extensions.Literal[3]]`.
y2: Tensor[int, L[2], L[4]] = torch.sin(x)
y3: Tensor[int, L[2], L[3]] = torch.sin(x, out=same_shape_as_x)
# pyre-fixme[6]: Expected `Tensor[Variable[torch.DType], *torch.Ts]` for 2nd
# param but got `Tensor[int, int, int]`.
# pyre-fixme[9]: y4 has type `Tensor[int, typing_extensions.Literal[2],
# typing_extensions.Literal[4]]`; used as `Tensor[int,
# typing_extensions.Literal[2], typing_extensions.Literal[3]]`.
y4: Tensor[int, L[2], L[4]] = torch.sin(x, out=not_same_shape_as_x)
y5: Tensor[int, L[2], L[3]] = torch.sin(x, out=None)
def test_unsqueeze() -> None:
x: Tensor[int, L[2], L[3]]
y: Tensor[int, L[1], L[2], L[3]] = x.unsqueeze(0)
y_torch_function: Tensor[int, L[1], L[2], L[3]] = torch.unsqueeze(x, 0)
y2: Tensor[int, L[2], L[1], L[3]] = x.unsqueeze(1)
y3: Tensor[int, L[2], L[3], L[1]] = x.unsqueeze(-1)
# pyre-fixme[9]: y4 has type `Tensor[int, typing_extensions.Literal[99]]`; used
# as `Tensor[int, typing_extensions.Literal[1], typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y4: Tensor[int, L[99]] = x.unsqueeze(0)
empty: Tensor[int]
y5: Tensor[int, L[1]] = empty.unsqueeze(0)
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 1st param but got
# `typing_extensions.Literal[1]`.
y6: Tensor[int, L[1]] = empty.unsqueeze(1)
y7: Tensor[int, L[2], L[3], L[1]] = x.unsqueeze(2)
def test_unsqueeze_() -> None:
x: Tensor[int, L[2], L[3]]
y: Tensor[int, L[1], L[2], L[3]] = x.unsqueeze_(0)
y_error: Tensor[int, L[1], L[2], L[3]] = x.unsqueeze_(0)
# pyre-ignore[9]: `unsqueeze_` is an in-place shape-transforming function. But Pyre cannot
# update a variable's shape type.
z: Tensor[int, L[1], L[2], L[3]] = x
def test_squeeze_() -> None:
x: Tensor[int, L[1], L[2], L[3]]
out: Tensor
y: Tensor[int, L[2], L[3]] = x.squeeze_(out=out)
# pyre-ignore[9]: Expected error.
y_error: Tensor[int, L[2], L[99]] = x.squeeze_()
y2: Tensor[int, L[2], L[3]] = x.squeeze_().squeeze_()
x2: Tensor[int, L[2], L[3], L[1], L[1]]
x3: Tensor[int, L[2], L[3], L[1]]
y3: Tensor[int, L[2], L[3]] = x2.squeeze_()
y4: Tensor[int, L[2], L[3]] = x3.squeeze_()
y5: Tensor[int, L[2], L[3]] = x.squeeze_(0)
y6: Tensor[int, L[2], L[3], L[1]] = x2.squeeze_(-1)
def test_squeeze() -> None:
x: Tensor[int, L[1], L[2], L[3]]
out: Tensor
y: Tensor[int, L[2], L[3]] = x.squeeze(out=out)
# pyre-ignore[9]: Expected error.
y_error: Tensor[int, L[2], L[99]] = x.squeeze()
y2: Tensor[int, L[2], L[3]] = x.squeeze().squeeze()
x2: Tensor[int, L[2], L[3], L[1], L[1]]
x3: Tensor[int, L[2], L[3], L[1]]
y3: Tensor[int, L[2], L[3]] = x2.squeeze()
y4: Tensor[int, L[2], L[3]] = x3.squeeze()
y5: Tensor[int, L[2], L[3]] = x.squeeze(0)
y6: Tensor[int, L[2], L[3], L[1]] = x2.squeeze(-1)
def test_repeat() -> None:
x: Tensor[int, L[2], L[3]]
y: Tensor[int, L[8], L[15]] = x.repeat(4, 5)
# pyre-fixme[9]
y2: Tensor[int, L[8], L[16]] = x.repeat(4, 5)
# TODO(T96315150): This is passing by coincidence right now.
y3: Tensor[int, L[4], L[10], L[18]] = x.repeat(4, 5, 6)
# pyre-ignore[9]: Doesn't error as expected because we have limited overloads.
y3_error: Tensor[int, L[4], L[10], L[99]] = x.repeat(4, 5, 6)
# pyre-ignore[9, 19]
not_yet_supported: Tensor[int, L[4], L[5], L[12], L[21]] = x.repeat(4, 5, 6, 7)
# Fewer dimensions than the Tensor. Should raise a different error.
x.repeat(2)
one_dimension: Tensor[int, L[2]]
y4: Tensor[int, L[8]] = x.repeat(4)
# pyre-ignore[9]
y4_error: Tensor[int, L[99]] = x.repeat(4)
def test_multiply() -> None:
x: Tensor[torch.int64, L[2], L[3]]
y: Tensor[torch.float32, L[2], L[3]] = x * 2
# pyre-fixme[9]: y_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: Tensor[torch.bool, L[2], L[99]] = x * 2
y2: Tensor[torch.float32, L[2], L[3]] = 2 * x
# pyre-fixme[9]: y2_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y2_error: Tensor[torch.bool, L[2], L[99]] = 2 * x
y3: Tensor[torch.float32, L[2], L[3]] = x * 2.0
# pyre-fixme[9]: y3_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[4]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y3_error: Tensor[torch.float32, L[2], L[4]] = x * 2.0
z: Tensor[torch.int64, L[4], L[1], L[1]]
z_bad: Tensor[torch.int64, L[4], L[2], L[99]]
y4: Tensor[torch.int64, L[4], L[2], L[3]] = x * z
# pyre-fixme[2001]: Broadcast error at expression `x.__mul__(z_bad)`; types
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[4], typing_extensions.Literal[2],
# typing_extensions.Literal[99]]` cannot be broadcasted together.
x * z_bad
x4: Tensor[torch.float32, L[2], L[3]]
x5: Tensor[torch.float32, L[2], L[3]]
x5_bad: Tensor[torch.float32, L[2], L[99]]
x4 *= x5
x4 *= 4
y5: Tensor[torch.float32, L[2], L[3]] = x5
# pyre-fixme[2001]: Broadcast error at expression `x4.__imul__(x5_bad)`; types
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[99]]` cannot be
# broadcasted together.
x4 *= x5_bad
def test_floor_division() -> None:
x: Tensor[torch.int64, L[2], L[3]]
x2: Tensor[torch.int64, L[2], L[1]]
y: Tensor[torch.int64, L[2], L[3]] = x // 2
# pyre-fixme[9]: y_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: Tensor[torch.bool, L[2], L[99]] = x // 2
y2: Tensor[torch.int64, L[2], L[3]] = 2 // x
# pyre-fixme[9]: y2_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y2_error: Tensor[torch.bool, L[2], L[99]] = 2 // x
y3: Tensor[torch.int64, L[2], L[3]] = x // x2
x3: Tensor[torch.float32, L[2], L[3]]
x4: Tensor[torch.float32, L[2], L[3]]
x4_bad: Tensor[torch.float32, L[2], L[99]]
x3 //= x4
x3 //= 4
y5: Tensor[torch.float32, L[2], L[3]] = x3
# pyre-fixme[2001]: Broadcast error at expression `x3.__ifloordiv__(x4_bad)`;
# types `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[99]]` cannot be
# broadcasted together.
x3 //= x4_bad
def test_division() -> None:
x: Tensor[torch.int64, L[2], L[3]]
x2: Tensor[torch.int64, L[2], L[1]]
y: Tensor[torch.float32, L[2], L[3]] = x / 2
# pyre-fixme[9]: y_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: Tensor[torch.bool, L[2], L[99]] = x / 2
y2: Tensor[torch.float32, L[2], L[3]] = 2 / x
# pyre-fixme[9]: y2_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y2_error: Tensor[torch.bool, L[2], L[99]] = 2 / x
x3: Tensor[torch.float32, L[2], L[3]]
y3: Tensor[torch.float32, L[2], L[3]] = x3 / 2
y4: Tensor[torch.float32, L[2], L[3]] = 2 / x3
y5: Tensor[torch.float32, L[2], L[3]] = x / x2
x5: Tensor[torch.float32, L[2], L[3]]
x6: Tensor[torch.float32, L[2], L[3]]
x6_bad: Tensor[torch.float32, L[2], L[99]]
x5 /= x6
x5 /= 4
y6: Tensor[torch.float32, L[2], L[3]] = x5
# pyre-fixme[2001]: Broadcast error at expression `x5.__itruediv__(x6_bad)`;
# types `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[99]]` cannot be
# broadcasted together.
x5 /= x6_bad
def test_setitem() -> None:
x: Tensor[torch.int64, L[2], L[3]]
x[0, 0] = 1
def test_arange(n: N) -> None:
y: Tensor[torch.int64, L[5]] = torch.arange(5)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.int64,
# typing_extensions.Literal[5]]`.
y_error: Tensor[torch.int64, L[99]] = torch.arange(5)
y2: Tensor[torch.int64, L[4]] = torch.arange(1, 5)
y3: Tensor[torch.int64, L[2]] = torch.arange(1, 6, 2)
y_float: Tensor[torch.float32, L[5]] = torch.arange(5, dtype=torch.float32)
y_float2: Tensor[torch.float32, L[2]] = torch.arange(1, 6, 2, dtype=torch.float32)
device: torch.device
y_generic: Tensor[torch.float32, N] = torch.arange(
0, n, device=device, dtype=torch.float32
)
# pyre-fixme[9]: Expected error.
y_generic_error: Tensor[torch.float32, L[99]] = torch.arange(
0, n, device=device, dtype=torch.float32
)
def test_embedding() -> None:
embedding = nn.Embedding(10, 20)
y: Tensor[torch.float32, L[10], L[20]] = embedding.weight
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[10], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[10],
# typing_extensions.Literal[20]]`.
y_error: Tensor[torch.float32, L[10], L[99]] = embedding.weight
x: Tensor[torch.float32, L[2], L[3], L[4]]
y2: Tensor[torch.float32, L[2], L[3], L[4], L[20]] = embedding(x)
# pyre-fixme[9]: y2_error has type `Tensor[torch.float32, typing_extensions.Liter...
y2_error: Tensor[torch.float32, L[2], L[3], L[4], L[99]] = embedding(x)
weight: Tensor[torch.float32, L[3], L[4]]
embedding2: nn.Embedding[L[3], L[4]] = nn.Embedding.from_pretrained(weight)
# pyre-fixme[9]: embedding2_error has type
# `Embedding[typing_extensions.Literal[3], typing_extensions.Literal[99]]`; used
# as `Embedding[typing_extensions.Literal[3], typing_extensions.Literal[4]]`.
embedding2_error: nn.Embedding[L[3], L[99]] = nn.Embedding.from_pretrained(weight)
y3: Tensor[torch.float32, L[2], L[3], L[4], L[4]] = embedding2(x)
def test_init_normal() -> None:
x: Tensor[torch.float32, L[5], L[10]]
y: Tensor[torch.float32, L[5], L[10]] = nn.init.normal_(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[5], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[5],
# typing_extensions.Literal[10]]`.
y_error: Tensor[torch.float32, L[5], L[99]] = nn.init.normal_(x)
def test_view() -> None:
x: Tensor[torch.float32, L[4], L[4]]
y: Tensor[torch.float32, L[16]] = x.view(16)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32,
# typing_extensions.Literal[16]]`.
y_error: Tensor[torch.float32, L[99]] = x.view(16)
# Should be an error because 4 * 4 != 99. Don't think this is going to be
# feasible any time soon.
y_error2: Tensor[torch.float32, L[99]] = x.view(99)
y_error3: Tensor[torch.float32, L[2], L[3], L[4], L[5]] = x.view(2, 3, 4, 5)
y2: Tensor[torch.float32, L[2], L[8]] = x.view(-1, 8)
# pyre-fixme[9]: y2_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[8]]`.
y2_error: Tensor[torch.float32, L[2], L[99]] = x.view(-1, 8)
x3: Tensor[torch.float32, L[2], L[3], L[4]]
y3: Tensor[torch.float32, L[24]] = x3.view(-1)
y4: Tensor[torch.float32, L[8], L[3]] = x3.view(-1, 3)
# pyre-fixme[9]: y4_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[8],
# typing_extensions.Literal[3]]`.
y4_error: Tensor[torch.float32, L[99], L[3]] = x3.view(-1, 3)
y5: Tensor[torch.float32, L[2], L[6], L[2]] = x3.view(2, -1, 2)
x4: Tensor[torch.float32, L[2], L[3], L[4], L[5]]
y6: Tensor[torch.float32, L[3], L[5], L[8]] = x4.view(3, 5, -1)
def test_reshape() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[24]] = torch.reshape(x, (-1,))
y2: Tensor[torch.float32, L[8], L[3]] = torch.reshape(x, (-1, 3))
# pyre-fixme[9]: y2_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[8],
# typing_extensions.Literal[3]]`.
y2_error: Tensor[torch.float32, L[99], L[3]] = torch.reshape(x, (-1, 3))
y3: Tensor[torch.float32, L[6], L[2], L[2]] = torch.reshape(x, (-1, 2, 2))
y4: Tensor[torch.float32, L[2], L[6], L[2]] = torch.reshape(x, (2, -1, 2))
y5: Tensor[torch.float32, L[4], L[3], L[2]] = torch.reshape(x, (4, 3, 2))
def test_transpose() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4], L[5], L[6]]
y: Tensor[torch.float32, L[2], L[3], L[4], L[6], L[5]] = x.transpose(-2, -1)
y_function: Tensor[torch.float32, L[2], L[3], L[4], L[6], L[5]] = torch.transpose(
x, -2, -1
)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: Tensor[torch.float32, L[2], L[4], L[99]] = x.transpose(-2, -1)
y2: Tensor[torch.float32, L[2], L[4], L[3], L[5], L[6]] = x.transpose(1, 2)
y3: Tensor[torch.float32, L[3], L[2], L[4], L[5], L[6]] = x.transpose(0, 1)
y4: Tensor[torch.float32, L[3], L[2], L[4], L[5], L[6]] = x.transpose(1, 0)
y5: Tensor[torch.float32, L[2], L[3], L[4], L[6], L[5]] = x.transpose(-1, -2)
not_yet_supported: Tensor[
torch.float32,
L[3],
L[2],
L[4],
L[5],
L[6]
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[4]`.
] = x.transpose(1, 4)
def test_flatten() -> None:
x: Tensor[torch.float32, L[2], L[3]]
x_large: Tensor[torch.float32, L[2], L[3], L[4], L[5]]
y: Tensor[torch.float32, L[6]] = x.flatten()
y_default: Tensor[torch.float32, L[6]] = torch.flatten(x)
y_large: Tensor[torch.float32, L[120]] = x_large.flatten()
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32,
# typing_extensions.Literal[6]]`.
y_error: Tensor[torch.float32, L[99]] = x.flatten()
z: Tensor[torch.float32, L[2], L[3], L[4]]
y2: Tensor[torch.float32, L[6], L[4]] = z.flatten(0, 1)
y2_keyword: Tensor[torch.float32, L[6], L[4]] = z.flatten(start_dim=0, end_dim=1)
y3: Tensor[torch.float32, L[2], L[12]] = z.flatten(1, 2)
y3_large: Tensor[torch.float32, L[2], L[12], L[5]] = x_large.flatten(1, 2)
y4: Tensor[torch.float32, L[2], L[3], L[20]] = x_large.flatten(2, 3)
x_6d: Tensor[torch.float32, L[2], L[3], L[4], L[5], L[6], L[7]]
y4_large: Tensor[torch.float32, L[2], L[3], L[20], L[6], L[7]] = x_6d.flatten(2, 3)
# Out of bounds.
# pyre-fixme[9]: y5_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[12]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[6]]`.
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 1st param but got
# `typing_extensions.Literal[99]`.
y5_error: Tensor[torch.float32, L[2], L[12]] = x.flatten(99, 100)
x_0d: Tensor[torch.float32]
y_0d: Tensor[torch.float32, L[1]] = x_0d.flatten()
def test_empty() -> None:
x: Tuple[L[1], L[2], L[3]]
y: Tensor
device: torch.device
result1: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.empty(
*x,
device=device,
layout=torch.strided,
requires_grad=True,
out=y,
pin_memory=False,
memory_format=torch.memory_format(),
)
# pyre-fixme[9]: bad1 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad1: torch.Tensor[torch.float32, L[99], L[2], L[3]] = torch.empty(*x)
result2: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.empty(
*x, device=device, layout=torch.strided, requires_grad=True, out=y
)
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad2: torch.Tensor[torch.float32, L[99], L[2], L[3]] = torch.empty(
*x, device=device, layout=torch.strided, requires_grad=True, out=y
)
result4: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.empty(x)
result5: torch.Tensor[torch.float32, L[4]] = torch.empty(4)
result6: torch.Tensor[torch.int64, L[1], L[2], L[3]] = torch.empty(
x, dtype=torch.int64
)
result7: torch.Tensor[torch.int64, L[1], L[2], L[3]] = torch.empty(
*x, dtype=torch.int64
)
def test_empty_like() -> None:
x: torch.Tensor[torch.float32, L[1], L[2], L[3]]
out: Tensor
device: torch.device
y1: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.empty_like(
x, device=device, layout=torch.strided, requires_grad=True, out=out
)
# pyre-fixme[9]: Expected error.
y1_error: torch.Tensor[torch.float32, L[99], L[2], L[3]] = torch.empty_like(
x, device=device, layout=torch.strided, requires_grad=True, out=out
)
y2: torch.Tensor[torch.int64, L[1], L[2], L[3]] = torch.empty_like(
x,
dtype=torch.int64,
device=device,
layout=torch.strided,
requires_grad=True,
out=out,
)
def test_randn() -> None:
x: Tuple[L[1], L[2], L[3]]
y: Tensor
device: torch.device
result1: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.randn(
*x, device=device, layout=torch.strided, requires_grad=True, out=y
)
# pyre-fixme[9]: bad1 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad1: torch.Tensor[torch.float32, L[99], L[2], L[3]] = torch.randn(*x)
result2: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.randn(
*x, device=device, layout=torch.strided, requires_grad=True, out=y
)
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad2: torch.Tensor[torch.float32, L[99], L[2], L[3]] = torch.randn(
*x, device=device, layout=torch.strided, requires_grad=True, out=y
)
result4: torch.Tensor[torch.float32, L[1], L[2], L[3]] = torch.randn(x)
result5: torch.Tensor[torch.float32, L[4]] = torch.randn(4)
result6: torch.Tensor[torch.int64, L[1], L[2], L[3]] = torch.randn(
x, dtype=torch.int64
)
result7: torch.Tensor[torch.int64, L[1], L[2], L[3]] = torch.randn(
*x, dtype=torch.int64
)
def test_all() -> None:
x: torch.Tensor[torch.float32, L[1], L[2], L[3]]
device: torch.device
y: torch.Tensor[torch.bool, L[1]] = torch.all(x)
# pyre-fixme[9]: bad1 has type `Tensor[torch.bool,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.bool,
# typing_extensions.Literal[1]]`.
y_error: torch.Tensor[torch.bool, L[99]] = torch.all(x)
y2: torch.Tensor[torch.bool, L[2], L[3]] = torch.all(x, dim=0)
y3: torch.Tensor[torch.bool, L[1], L[3]] = torch.all(x, dim=1)
y4: torch.Tensor[torch.bool, L[1]] = x.all()
def test_where() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
good: Tuple[torch.LongTensor[int, int], torch.LongTensor[int, int]] = torch.where(x)
bad: Tuple[
torch.LongTensor[int, int], torch.LongTensor[int, int], L[99]
] = torch.where(x)
y: torch.Tensor[torch.float32, L[2], L[1]]
not_broadcastable: torch.Tensor[torch.float32, L[2], L[99]]
good: Tuple[torch.LongTensor[int, int], torch.LongTensor[int, int]] = torch.where(x)
good2: torch.Tensor[torch.float32, L[2], L[3]] = torch.where(x > 0, x, y)
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
bad2: torch.Tensor[torch.float32, L[2], L[99]] = torch.where(x > 0, x, y)
# pyre-fixme[2001]: Broadcast error at expression `torch.where(x > 0, x,
# not_broadcastable)`; types `Tuple[typing_extensions.Literal[2],
# typing_extensions.Literal[3]]` and `Tuple[typing_extensions.Literal[2],
# typing_extensions.Literal[99]]` cannot be broadcasted together.
z = torch.where(x > 0, x, not_broadcastable)
def test_getitem() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good1: torch.Tensor[torch.float32, L[3], L[4]] = x[0]
# pyre-fixme[9]: bad1 has type `Tensor[torch.float32,
# typing_extensions.Literal[99], typing_extensions.Literal[4]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[3],
# typing_extensions.Literal[4]]`.
bad1: torch.Tensor[torch.float32, L[99], L[4]] = x[0]
good2: torch.Tensor[torch.float32, L[1], L[2], L[3], L[4]] = x[None]
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad2: torch.Tensor[torch.float32, L[99], L[2], L[3], L[4]] = x[None]
mask: torch.Tensor[torch.bool, L[2], L[3], L[4]]
good3: torch.Tensor[torch.float32, int] = x[mask]
# pyre-fixme[9]: bad3 has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32, int]`.
bad3: torch.Tensor[torch.float32, L[99]] = x[mask]
any1: Tuple[int, str, float] = x[2]
any2: Tuple[float, str, int] = x[2]
def test_expand() -> None:
x: torch.Tensor[torch.float32, L[1], L[2], L[3]]
shape: Tuple[L[4], L[1], L[3]]
good1: torch.Tensor[torch.float32, L[4], L[2], L[3]] = x.expand(shape)
# pyre-fixme[9]: bad1 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad1: torch.Tensor[torch.float32, L[99], L[2], L[3]] = x.expand(shape)
# pyre-fixme[2001]: Broadcast error at expression `x.expand((4, 99, 3))`; types `...
x.expand((4, 99, 3))
good2: torch.Tensor[torch.float32, L[4], L[2], L[3]] = x.expand(4, 1, 3)
def test_to() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good1: torch.Tensor[torch.int32, L[2], L[3], L[4]] = x.to(torch.int32)
# pyre-fixme[9]: bad1 has type `Tensor[torch.int32, typing_extensions.Literal[99]...
bad1: torch.Tensor[torch.int32, L[99], L[3], L[4]] = x.to(torch.int32)
device: torch.device
good2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.to(device)
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad2: torch.Tensor[torch.float32, L[99], L[3], L[4]] = x.to(device)
y: torch.Tensor[torch.int32, L[2], L[3], L[4]]
good3: torch.Tensor[torch.float32, L[2], L[3], L[4]] = y.to(torch.float32, device)
# pyre-fixme[9]: bad3 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad3: torch.Tensor[torch.float32, L[99], L[3], L[4]] = y.to(torch.float32, device)
def test_Linear_to() -> None:
linear: nn.Linear[L[10], L[20]]
device: torch.device
linear.to(dtype=torch.int64, device=device)
def test_Module_eval() -> None:
module: nn.Module
module.eval()
def test_Module_train() -> None:
module: nn.Module
module.train(mode=True)
y: bool = module.training
def test_Linear_bias() -> None:
linear: nn.Linear[L[10], L[20]]
x: nn.Parameter = linear.bias
def test_sum() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y1: torch.Tensor[torch.float32, L[2], L[3]] = x.sum(-1, dtype=None)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[99], L[3]] = x.sum(-1, dtype=None)
y2: torch.Tensor[torch.float32, L[2], L[4]] = x.sum(-2)
y3: torch.Tensor[torch.float32] = x.sum()
y4: torch.Tensor[torch.float32, L[3], L[4]] = x.sum(0)
y5: torch.Tensor[torch.float32, L[2], L[4]] = x.sum(1)
y6: torch.Tensor[torch.float32] = torch.sum(x)
def test_cumsum() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good1: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.cumsum()
# pyre-fixme[9]: bad1 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad1: torch.Tensor[torch.float32, L[99], L[3], L[4]] = x.cumsum()
good2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.cumsum(dim=0)
# pyre-fixme[9]: bad2 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad2: torch.Tensor[torch.float32, L[99], L[3], L[4]] = x.cumsum(dim=0)
good3: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.cumsum(dtype=None)
# pyre-fixme[9]: bad3 has type `Tensor[torch.float32, typing_extensions.Literal[9...
bad3: torch.Tensor[torch.float32, L[99], L[3], L[4]] = x.cumsum(dtype=None)
def test_contiguous() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.contiguous()
# pyre-fixme[9]: bad has type `Tensor[torch.float32, typing_extensions.Literal[99...
bad: torch.Tensor[torch.float32, L[99], L[3], L[4]] = x.contiguous()
def test_diff() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good: torch.Tensor[torch.float32, L[2], L[3], L[3]] = torch.diff(x)
# pyre-fixme[9]: bad has type `Tensor[torch.float32, typing_extensions.Literal[99...
bad: torch.Tensor[torch.float32, L[99], L[3], L[3]] = torch.diff(x)
good2: torch.Tensor[torch.float32, L[1], L[3], L[4]] = torch.diff(x, dim=0)
good3: torch.Tensor[torch.float32, L[2], L[2], L[4]] = torch.diff(x, dim=1)
good4: torch.Tensor[torch.float32, L[2], L[3], L[3]] = torch.diff(x, dim=-1)
good5: torch.Tensor[torch.float32, L[2], L[2], L[4]] = torch.diff(x, dim=-2)
good6: torch.Tensor[torch.float32, L[2], L[2], L[4]] = x.diff(dim=-2)
def test_argsort() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good1: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.argsort(x)
# pyre-fixme[9]: bad1 has type `LongTensor[torch.float32, typing_extensions.Liter...
bad1: torch.Tensor[torch.float32, L[99], L[3], L[4]] = torch.argsort(x)
good2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.argsort(x, dim=0)
# pyre-fixme[9]: bad2 has type `LongTensor[torch.float32, typing_extensions.Liter...
bad2: torch.Tensor[torch.float32, L[99], L[3], L[4]] = torch.argsort(x, dim=0)
good3: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.argsort(
x, descending=True
)
# pyre-fixme[9]: bad3 has type `LongTensor[torch.float32, typing_extensions.Liter...
bad3: torch.Tensor[torch.float32, L[99], L[3], L[4]] = torch.argsort(
x, descending=True
)
good4: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.argsort(dim=-1)
def test_functional_pad() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good: torch.Tensor[torch.float32, L[2], L[3], L[5]] = nn.functional.pad(x, (1, 0))
bad: torch.Tensor[torch.float32, L[99], L[3], L[5]] = nn.functional.pad(x, (1, 0))
good2: torch.Tensor[torch.float32, L[2], L[10], L[7]] = nn.functional.pad(
x, (1, 2, 3, 4), "constant", value=0.0
)
def test_allclose() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[2], L[1]]
not_broadcastable: torch.Tensor[torch.float32, L[3], L[4]]
good: bool = torch.allclose(x, y, atol=0.0, rtol=0.0, equal_nan=True)
# This should complain about non-broadcastable tensors but we don't have a
# way to constrain two parameter types to be broadcastable.
should_error: bool = torch.allclose(x, not_broadcastable)
def test_new_ones() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[8], L[9]] = x.new_ones((8, 9))
# pyre-fixme[9]: Expected error.
y_error: torch.Tensor[torch.float32, L[8], L[99]] = x.new_ones((8, 9))
y2: torch.Tensor[torch.int64, L[8], L[9]] = x.new_ones(
(8, 9), dtype=torch.int64, device="cuda", requires_grad=True
)
def test_ones_like() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
device: torch.device
good: torch.Tensor[torch.int64, L[2], L[3]] = torch.ones_like(
x, dtype=torch.int64, device=device
)
# pyre-fixme[9]: bad has type `Tensor[torch.int64,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.int64, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
bad: torch.Tensor[torch.int64, L[99], L[3]] = torch.ones_like(
x, dtype=torch.int64, device=device
)
bad2: torch.Tensor[torch.float32, L[2], L[3]] = torch.ones_like(
x,
)
def test_sparse_softmax() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.sparse.softmax(x, dim=-1)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[99], L[3]] = torch.sparse.softmax(x, dim=-1)
dtype: torch.int64
y2: torch.Tensor[torch.int64, L[2], L[3]] = torch.sparse.softmax(
x, dim=-1, dtype=dtype
)
def test_eye() -> None:
y: torch.Tensor[torch.int64, L[2], L[3]] = torch.eye(2, 3, dtype=torch.int64)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99], typing_extensions.Literal[3]]`; used as
# `Tensor[torch.int64, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.int64, L[99], L[3]] = torch.eye(2, 3, dtype=torch.int64)
y2: torch.Tensor[torch.float32, L[3], L[3]] = torch.eye(3)
def test_adaptive_average_pool2d() -> None:
model: nn.AdaptiveAvgPool2d[L[5], L[7]] = nn.AdaptiveAvgPool2d((5, 7))
# pyre-fixme[9]: model_error has type
# `AdaptiveAvgPool2d[typing_extensions.Literal[5],
# typing_extensions.Literal[99]]`; used as
# `AdaptiveAvgPool2d[typing_extensions.Literal[5], typing_extensions.Literal[7]]`.
model_error: nn.AdaptiveAvgPool2d[L[5], L[99]] = nn.AdaptiveAvgPool2d((5, 7))
model2: nn.AdaptiveAvgPool2d[L[5], L[5]] = nn.AdaptiveAvgPool2d(5)
# TODO(T100083794): This should be an error.
model2_error: nn.AdaptiveAvgPool2d[L[5], L[99]] = nn.AdaptiveAvgPool2d(5)
model3: nn.AdaptiveAvgPool2d[L[5], L[-1]] = nn.AdaptiveAvgPool2d((5, None))
# TODO(T100083794): This should be an error.
model3_error: nn.AdaptiveAvgPool2d[L[5], L[99]] = nn.AdaptiveAvgPool2d((5, None))
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32, L[2], L[5], L[7]] = model(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[99], L[7]] = model(x)
y2: torch.Tensor[torch.float32, L[2], L[5], L[5]] = model2(x)
y3: torch.Tensor[torch.float32, L[2], L[5], L[4]] = model3(x)
def test_randperm() -> None:
y: torch.Tensor[torch.int64, L[10]] = torch.randperm(10, dtype=torch.int64)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.int64,
# typing_extensions.Literal[10]]`.
y_error: torch.Tensor[torch.int64, L[99]] = torch.randperm(10, dtype=torch.int64)
def test_sqrt() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.sqrt(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[2], L[99]] = torch.sqrt(x)
def test_multinomial() -> None:
x: torch.Tensor[torch.float32, L[2], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.multinomial(x, 3)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[2], L[99]] = torch.multinomial(x, 3)
x2: torch.Tensor[torch.float32, L[4]]
y2: torch.Tensor[torch.float32, L[3]] = torch.multinomial(x2, 3)
y2: torch.Tensor[torch.float32, L[3]] = x2.multinomial(3)
def test_bmm() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
matrix: torch.Tensor[torch.float32, L[2], L[4], L[5]]
y: torch.Tensor[torch.float32, L[2], L[3], L[5]] = torch.bmm(x, matrix)
y2: torch.Tensor[torch.float32, L[2], L[3], L[5]] = x.bmm(matrix)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = torch.bmm(x, matrix)
bad_matrix: torch.Tensor[torch.float32, L[2], L[99], L[5]]
# Should raise an error but doesn't because we solve `L[99] <: M && L[4] <:
# M` to be M = int.
torch.bmm(x, bad_matrix)
def test_subtract() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[1]]
x2: torch.Tensor[torch.float32, L[2], L[1], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x - x2
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = x - x2
y2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x2 - x
y3: torch.Tensor[torch.float32, L[2], L[3], L[1]] = x - 42.0
y4: torch.Tensor[torch.float32, L[2], L[3], L[1]] = 42.0 - x
z: Any
# Should not error.
x - z
x5: Tensor[torch.float32, L[2], L[3]]
x6: Tensor[torch.float32, L[2], L[3]]
x6_bad: Tensor[torch.float32, L[2], L[99]]
x5 -= x6
x5 -= 4
y5: Tensor[torch.float32, L[2], L[3]] = x5
# pyre-fixme[2001]: Broadcast error at expression `x5.__isub__(x6_bad)`; types
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[99]]` cannot be
# broadcasted together.
x5 -= x6_bad
def test_add() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[1]]
x2: torch.Tensor[torch.float32, L[2], L[1], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x + x2
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = x + x2
y2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x2 + x
y3: torch.Tensor[torch.float32, L[2], L[3], L[1]] = x + 42.0
y4: torch.Tensor[torch.float32, L[2], L[3], L[1]] = 42.0 + x
x5: Tensor[torch.float32, L[2], L[3]]
x6: Tensor[torch.float32, L[2], L[3]]
x6_bad: Tensor[torch.float32, L[2], L[99]]
x5 += x6
x5 += 4
y5: Tensor[torch.float32, L[2], L[3]] = x5
# pyre-fixme[2001]: Broadcast error at expression `x5.__iadd__(x6_bad)`; types
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3]]` and
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[99]]` cannot be
# broadcasted together.
x5 += x6_bad
def test_torch_fft() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.complex64, L[2], L[3], L[4]] = torch.fft.fft(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.complex64, typing_extensions.Lite...
y_error: torch.Tensor[torch.complex64, L[2], L[3], L[99]] = torch.fft.fft(x)
y2: torch.Tensor[torch.complex64, L[2], L[3], L[4]] = torch.fft.fft(x, dim=-2)
def test_torch_real() -> None:
x: torch.Tensor[torch.complex64, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.real(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = torch.real(x)
x2: torch.Tensor[torch.complex128, L[2], L[3], L[4]]
y2: torch.Tensor[torch.float64, L[2], L[3], L[4]] = torch.real(x2)
bad: torch.Tensor[torch.float32, L[2], L[3], L[4]]
# pyre-fixme[6]: Expected `Tensor[torch.complex64, *torch.Ts]` for 1st param but
# got `Tensor[torch.float32, int, int, int]`.
torch.real(bad)
def test_logical_and() -> None:
x: torch.Tensor[torch.complex64, L[2], L[1], L[4]]
x2: torch.Tensor[torch.float32, L[2], L[3], L[1]]
y: torch.Tensor[torch.bool, L[2], L[3], L[4]] = torch.logical_and(x, x2)
# pyre-fixme[9]: y_error has type `Tensor[torch.bool, typing_extensions.Literal[2...
y_error: torch.Tensor[torch.bool, L[2], L[3], L[99]] = torch.logical_and(x, x2)
y2: torch.Tensor[torch.bool, L[2], L[3], L[4]] = x.logical_and(x2)
not_broadcastable: torch.Tensor[torch.float32, L[2], L[3], L[99]]
# pyre-fixme[2001]: Broadcast error at expression `torch.logical_and(x, not_broad...
torch.logical_and(x, not_broadcastable)
x3: torch.Tensor[torch.complex64, L[2], L[1], L[1]]
# In-place version.
x.logical_and_(x3)
# This is actually an error because the output type (2, 3, 4) is not
# assignable to x. But we can't catch that because the typechecker doesn't
# know this is an in-place operator. Leaving this as is for now.
x.logical_and_(x2)
def test_and() -> None:
x_bool: torch.Tensor[torch.bool, L[2], L[1], L[4]]
x_bool2: torch.Tensor[torch.bool, L[2], L[3], L[1]]
y3: torch.Tensor[torch.bool, L[2], L[3], L[4]] = x_bool & x_bool2
# This broadcasts to (2, 1, 4), which is assignable to x_bool.
x_bool3: torch.Tensor[torch.bool, L[2], L[1], L[1]]
x_bool &= x_bool3
# This broadcasts to (2, 3, 4), which is not assignable to x_bool.
# pyre-fixme[9]: x_bool has type `Tensor[torch.bool, typing_extensions.Literal[2]...
x_bool &= x_bool2
x: torch.Tensor[torch.complex64, L[2], L[1], L[4]]
x2: torch.Tensor[torch.float32, L[2], L[3], L[1]]
# pyre-fixme[58]: `&` is not supported for operand types
# `Tensor[torch.complex64, int, int, int]` and `Tensor[torch.float32, int, int,
# int]`.
x & x2
def test_linalg_pinv() -> None:
x: torch.Tensor[torch.float32, L[2], L[2], L[3], L[4]]
y: torch.Tensor[torch.float32, L[2], L[2], L[4], L[3]] = torch.linalg.pinv(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[4], L[99]] = torch.linalg.pinv(x)
wrong_datatype: torch.Tensor[torch.bool, L[2], L[3], L[4]]
# pyre-fixme[6]: Expected `Tensor[Variable[torch.linalg.FloatOrDouble <:
# [torch.float32, torch.float64, torch.complex64, torch.complex128]],
# *torch.linalg.Ts, Variable[N1 (bound to int)], Variable[N2 (bound to int)]]` for
# 1st param but got `Tensor[torch.bool, int, int, int]`.
torch.linalg.pinv(wrong_datatype)
torch.linalg.pinv(x, hermitian=True)
# Last two dimensions have to be equal.
x_square: torch.Tensor[torch.float32, L[2], L[3], L[4], L[4]]
y2: torch.Tensor[torch.float32, L[2], L[3], L[4], L[4]] = torch.linalg.pinv(
x_square, hermitian=True
)
def test_linalg_qr() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: Tuple[
torch.Tensor[torch.float32, L[2], L[3], L[3]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
] = torch.linalg.qr(x)
# pyre-fixme[9]: y_error has type `Tuple[Tensor[torch.float32, typing_extensions....
y_error: Tuple[
torch.Tensor[torch.float32, L[2], L[3], L[99]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
] = torch.linalg.qr(x)
y2: Tuple[
torch.Tensor[torch.float32, L[2], L[3], L[3]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
] = torch.linalg.qr(x, mode="complete")
def test_torch_matmul() -> None:
x: torch.Tensor[torch.float32, L[2], L[1], L[3], L[4]]
x2: torch.Tensor[torch.float32, L[1], L[5], L[4], L[3]]
y: torch.Tensor[torch.float32, L[2], L[5], L[3], L[3]] = torch.matmul(x, x2)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[5], L[3], L[99]] = torch.matmul(x, x2)
y2: torch.Tensor[torch.float32, L[2], L[5], L[3], L[3]] = x.matmul(x2)
y3: torch.Tensor[torch.float32, L[2], L[5], L[3], L[3]] = x.__matmul__(x2)
bad_x: torch.Tensor[torch.float32, L[1], L[5], L[99], L[3]]
torch.matmul(x, bad_x)
x_1d: torch.Tensor[torch.float32, L[3]]
x2_1d: torch.Tensor[torch.float32, L[3]]
y4: torch.Tensor[torch.float32] = torch.matmul(x_1d, x2_1d)
x3_1d_different: torch.Tensor[torch.float32, L[1]]
torch.matmul(x_1d, x3_1d_different)
def test_torch_optim() -> None:
block_parameters: Any
torch.optim.SGD(block_parameters, lr=1.0)
def test_torch_cuda() -> None:
torch.cuda.reset_peak_memory_stats()
def test_torch_profiler() -> None:
torch.profiler.profile()
def test_mse_loss() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
x2: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32] = nn.MSELoss(
size_average=True, reduce=True, reduction="mean"
)(x, x2)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32]`.
y_error: torch.Tensor[torch.float32, L[99]] = nn.MSELoss()(x, x2)
def test_clip_grad_norm() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor = nn.utils.clip_grad_norm_(
x, max_norm=0.0, norm_type=0.0, error_if_nonfinite=True
)
# pyre-fixme[9]: y_error has type `int`; used as `Tensor[typing.Any,
# *Tuple[typing.Any, ...]]`.
y_error: int = nn.utils.clip_grad_norm_(
x, max_norm=0.0, norm_type=0.0, error_if_nonfinite=True
)
def test_clip_grad_value() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
nn.utils.clip_grad_value_([x], clip_value=0.0)
def test_bitwise_not() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.bitwise_not(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[2], L[99]] = torch.bitwise_not(x)
y2: torch.Tensor[torch.float32, L[2], L[3]] = x.bitwise_not()
# In-place.
y3: torch.Tensor[torch.float32, L[2], L[3]] = x.bitwise_not_()
y4: torch.Tensor[torch.float32, L[2], L[3]] = ~x
def test_cdist() -> None:
x: torch.Tensor[torch.float32, L[5], L[1], L[2], L[3]]
x2: torch.Tensor[torch.float32, L[1], L[7], L[4], L[3]]
y: torch.Tensor[torch.float32, L[5], L[7], L[2], L[4]] = torch.cdist(x, x2)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[5], L[7], L[2], L[99]] = torch.cdist(x, x2)
not_broadcastable: torch.Tensor[torch.float32, L[99], L[1], L[2], L[3]]
# pyre-fixme[2001]: Broadcast error at expression `torch.cdist(x,
# not_broadcastable)`; types `Tuple[typing_extensions.Literal[5],
# typing_extensions.Literal[1]]` and `Tuple[typing_extensions.Literal[99],
# typing_extensions.Literal[1]]` cannot be broadcasted together.
torch.cdist(x, not_broadcastable)
def test_random_manual_seed() -> None:
torch.random.manual_seed(42)
def test_clone() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.clone(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[2], L[99]] = torch.clone(x)
y2: torch.Tensor[torch.float32, L[2], L[3]] = x.clone()
def test_equal() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor[torch.bool, L[2], L[3]] = x == 42
# pyre-fixme[9]: y_error has type `Tensor[torch.bool,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.bool, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.bool, L[2], L[99]] = x == 42
# This doesn't return a Tensor as expected because `int.__eq__` accepts `object`.
y2: int = 42 == x
x2: torch.Tensor[torch.float32, L[2], L[1]]
x3: torch.Tensor[torch.float32, L[1], L[3]]
y3: torch.Tensor[torch.bool, L[2], L[3]] = x2 == x3
def test_diag_embed() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor = torch.diag_embed(x)
def test_unbind() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: Tuple[torch.Tensor[torch.float32, L[2], L[4]], ...] = torch.unbind(x, dim=1)
# pyre-fixme[9]: y_error has type `Tuple[Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]], ...]`; used as
# `Tuple[Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[4]], ...]`.
y_error: Tuple[torch.Tensor[torch.float32, L[2], L[99]], ...] = torch.unbind(
x, dim=1
)
y2: Tuple[torch.Tensor[torch.float32, L[2], L[3]], ...] = torch.unbind(x, dim=-1)
y3: Tuple[torch.Tensor[torch.float32, L[3], L[4]], ...] = torch.unbind(x)
y4: Tuple[torch.Tensor[torch.float32, L[3], L[4]], ...] = x.unbind()
def test_size() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: Tuple[L[2], L[3], L[4]] = x.size()
# pyre-fixme[9]: y_error has type `Tuple[typing_extensions.Literal[2],
# typing_extensions.Literal[3], typing_extensions.Literal[99]]`; used as
# `Tuple[typing_extensions.Literal[2], typing_extensions.Literal[3],
# typing_extensions.Literal[4]]`.
y_error: Tuple[L[2], L[3], L[99]] = x.size()
y2: L[2] = x.size(0)
y3: L[3] = x.size(1)
y4: L[4] = x.size(-1)
y5: L[3] = x.size(-2)
def test_stack(
arbitary_length_tuple: Tuple[torch.Tensor[torch.float32, L[3], L[4], L[5]], ...],
variadic_tuple: Tuple[Unpack[Ts]],
) -> None:
x: torch.Tensor[torch.float32, L[3], L[4], L[5]]
x_incompatible: torch.Tensor[torch.float32, L[3], L[4], L[99]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4], L[5]] = torch.stack((x, x))
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[4], L[99]] = torch.stack((x, x))
y_incompatible_tensors: torch.Tensor = torch.stack((x, x_incompatible))
y2: torch.Tensor[torch.float32, L[3], L[2], L[4], L[5]] = torch.stack((x, x), dim=1)
y3: torch.Tensor[torch.float32, L[3], L[3], L[4], L[5]] = torch.stack(
(x, x, x), dim=1
)
y4: torch.Tensor[torch.float32, L[3], L[3], L[4], L[5]] = torch.stack((x, x, x))
# Arbitrary-length tuples make it return an arbitrary Tensor.
y5: torch.Tensor = torch.stack(arbitary_length_tuple)
y6: torch.Tensor = torch.stack(variadic_tuple)
def test_repeat_interleave() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
repeats: torch.Tensor[torch.float32, L[2]]
y: torch.Tensor[torch.float32, L[72]] = torch.repeat_interleave(x, 3)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32,
# typing_extensions.Literal[72]]`.
y_error: torch.Tensor[torch.float32, L[99]] = torch.repeat_interleave(x, 3)
y2: torch.Tensor[torch.float32, L[4], L[3], L[4]] = torch.repeat_interleave(
x, 2, dim=0
)
y3: torch.Tensor[torch.float32, L[2], L[6], L[4]] = torch.repeat_interleave(
x, 2, dim=1
)
y4: torch.Tensor[torch.float32, L[2], L[3], L[8]] = torch.repeat_interleave(
x, 2, dim=-1
)
# Too dynamic because the output shape depends on the contents of repeats.
y5: torch.Tensor[torch.float32, L[0], L[3], L[4]] = torch.repeat_interleave(
x, repeats, dim=0
)
y6: torch.Tensor[torch.float32, L[2], L[3], L[8]] = x.repeat_interleave(2, dim=-1)
def test_meshgrid() -> None:
x1: torch.Tensor[torch.float32, L[2]]
x2: torch.Tensor[torch.float32, L[3]]
x3: torch.Tensor[torch.float32, L[4]]
y: Tuple[
torch.Tensor[torch.float32, L[2], L[3], L[4]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
] = torch.meshgrid(x1, x2, x3)
# pyre-fixme[9]: y_error has type `Tuple[Tensor[torch.float32, typing_extensions....
y_error: Tuple[
torch.Tensor[torch.float32, L[2], L[3], L[4]],
torch.Tensor[torch.float32, L[2], L[3], L[4]],
torch.Tensor[torch.float32, L[2], L[3], L[99]],
] = torch.meshgrid(x1, x2, x3)
y2: Tuple[
torch.Tensor[torch.float32, L[2], L[3]],
torch.Tensor[torch.float32, L[2], L[3]],
] = torch.meshgrid(x1, x2)
y3: Tuple[
torch.Tensor[torch.float32, L[2]],
] = torch.meshgrid(x1)
x4: Tensor
xs = tuple(x4 for _ in range(5))
y4: Tuple[torch.Tensor, ...] = torch.meshgrid(*xs)
xs2 = [x4 for _ in range(5)]
y5: Tuple[torch.Tensor, ...] = torch.meshgrid(*xs2)
def test_argmax() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.LongTensor[torch.int64] = torch.argmax(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.int64]`.
y_error: torch.LongTensor[torch.int64, L[99]] = torch.argmax(x)
y2: torch.LongTensor[torch.int64, L[3], L[4]] = torch.argmax(x, dim=0)
y3: torch.LongTensor[torch.int64, L[1], L[3], L[4]] = torch.argmax(
x, dim=0, keepdim=True
)
y4: torch.LongTensor[torch.int64, L[2], L[4]] = torch.argmax(x, dim=1)
y5: torch.LongTensor[torch.int64, L[2], L[1], L[4]] = torch.argmax(
x, dim=1, keepdim=True
)
y6: torch.LongTensor[torch.int64, L[2], L[3]] = torch.argmax(x, dim=2)
y7: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = torch.argmax(
x, dim=2, keepdim=True
)
y8: torch.LongTensor[torch.int64, L[2], L[3]] = torch.argmax(x, dim=-1)
y9: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = torch.argmax(
x, dim=-1, keepdim=True
)
y10: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = x.argmax(
dim=-1, keepdim=True
)
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[3]`.
torch.argmax(x, dim=3)
def test_argmin() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.LongTensor[torch.int64] = torch.argmin(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.int64]`.
y_error: torch.LongTensor[torch.int64, L[99]] = torch.argmin(x)
y2: torch.LongTensor[torch.int64, L[3], L[4]] = torch.argmin(x, dim=0)
y3: torch.LongTensor[torch.int64, L[1], L[3], L[4]] = torch.argmin(
x, dim=0, keepdim=True
)
y4: torch.LongTensor[torch.int64, L[2], L[4]] = torch.argmin(x, dim=1)
y5: torch.LongTensor[torch.int64, L[2], L[1], L[4]] = torch.argmin(
x, dim=1, keepdim=True
)
y6: torch.LongTensor[torch.int64, L[2], L[3]] = torch.argmin(x, dim=2)
y7: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = torch.argmin(
x, dim=2, keepdim=True
)
y8: torch.LongTensor[torch.int64, L[2], L[3]] = torch.argmin(x, dim=-1)
y9: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = torch.argmin(
x, dim=-1, keepdim=True
)
y10: torch.LongTensor[torch.int64, L[2], L[3], L[1]] = x.argmin(
dim=-1, keepdim=True
)
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[3]`.
torch.argmin(x, dim=3)
def test_mean() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32] = torch.mean(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.float32]`.
y_error: torch.Tensor[torch.float32, L[99]] = torch.mean(x)
y2: torch.Tensor[torch.float32, L[3], L[4]] = torch.mean(x, dim=0)
y3: torch.Tensor[torch.float32, L[1], L[3], L[4]] = torch.mean(
x, dim=0, keepdim=True
)
y4: torch.Tensor[torch.float32, L[2], L[4]] = torch.mean(x, dim=1)
y5: torch.Tensor[torch.float32, L[2], L[1], L[4]] = torch.mean(
x, dim=1, keepdim=True
)
y6: torch.Tensor[torch.float32, L[2], L[3]] = torch.mean(x, dim=2)
y7: torch.Tensor[torch.float32, L[2], L[3], L[1]] = torch.mean(
x, dim=2, keepdim=True
)
y8: torch.Tensor[torch.float32, L[2], L[3]] = torch.mean(x, dim=-1)
y9: torch.Tensor[torch.float32, L[2], L[3], L[1]] = torch.mean(
x, dim=-1, keepdim=True
)
y10: torch.Tensor[torch.float32, L[2], L[3], L[1]] = x.mean(dim=-1, keepdim=True)
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[3]`.
torch.mean(x, dim=3)
def test_count_nonzero() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.int64] = torch.count_nonzero(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.int64,
# typing_extensions.Literal[99]]`; used as `Tensor[torch.int64]`.
y_error: torch.Tensor[torch.int64, L[99]] = torch.count_nonzero(x)
y2: torch.Tensor[torch.int64, L[3], L[4]] = torch.count_nonzero(x, dim=0)
y3: torch.Tensor[torch.int64, L[2], L[4]] = torch.count_nonzero(x, dim=1)
y4: torch.Tensor[torch.int64, L[2], L[3]] = torch.count_nonzero(x, dim=2)
y5: torch.Tensor[torch.int64, L[2], L[3]] = x.count_nonzero(dim=-1)
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[3]`.
torch.count_nonzero(x, dim=3)
def test_cat() -> None:
x1: torch.Tensor[torch.float32, L[2], L[3], L[4]]
x1_first_is_3: torch.Tensor[torch.float32, L[3], L[3], L[4]]
x1_first_is_4: torch.Tensor[torch.float32, L[4], L[3], L[4]]
x1_second_is_4: torch.Tensor[torch.float32, L[2], L[4], L[4]]
x1_second_is_5: torch.Tensor[torch.float32, L[2], L[5], L[4]]
x1_last_is_5: torch.Tensor[torch.float32, L[2], L[3], L[5]]
x1_last_is_6: torch.Tensor[torch.float32, L[2], L[3], L[6]]
# 2-element tuple.
y: torch.Tensor[torch.float32, L[5], L[3], L[4]] = torch.cat((x1, x1_first_is_3))
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[99], L[3], L[4]] = torch.cat(
(x1, x1_first_is_3)
)
y2: torch.Tensor[torch.float32, L[2], L[7], L[4]] = torch.cat(
(x1, x1_second_is_4), dim=1
)
y3: torch.Tensor[torch.float32, L[2], L[3], L[9]] = torch.cat(
(x1, x1_last_is_5), dim=-1
)
y3_shape_mismatch: torch.Tensor[torch.float32, Unpack[Tuple[Any, ...]]] = torch.cat(
(x1, x1_second_is_4), dim=-1
)
# 3-element tuple.
y4: torch.Tensor[torch.float32, L[9], L[3], L[4]] = torch.cat(
(x1, x1_first_is_3, x1_first_is_4)
)
y5: torch.Tensor[torch.float32, L[2], L[12], L[4]] = torch.cat(
(x1, x1_second_is_4, x1_second_is_5), dim=1
)
y6: torch.Tensor[torch.float32, L[2], L[3], L[15]] = torch.cat(
(x1, x1_last_is_5, x1_last_is_6), dim=-1
)
y_many_element_tuple: torch.Tensor[
torch.float32, Unpack[Tuple[Any, ...]]
] = torch.cat((x1, x1, x1, x1))
y_list: torch.Tensor[torch.float32, Unpack[Tuple[Any, ...]]] = torch.cat([x1, x1])
def test_sign() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.sign(x)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = torch.sign(x)
y2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.sign()
def test_diagonal() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4], L[5]]
y: torch.Tensor = torch.diagonal(x)
def test_diag() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
y: torch.Tensor = torch.diag(x)
def test_module_list() -> None:
x: torch.Tensor[torch.float32, L[2], L[3]]
modules = nn.ModuleList([nn.AdaptiveAvgPool2d(0), nn.AdaptiveAvgPool2d(1)])
for module in modules:
y: Tensor = module(x)
z: int = len(modules)
def test_sparse_coo_tensor() -> None:
y: torch.Tensor[torch.float32, L[2], L[3]] = torch.sparse_coo_tensor(
torch.randn(5), [6, 7, 8], size=(2, 3)
)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32,
# typing_extensions.Literal[2], typing_extensions.Literal[99]]`; used as
# `Tensor[torch.float32, typing_extensions.Literal[2],
# typing_extensions.Literal[3]]`.
y_error: torch.Tensor[torch.float32, L[2], L[99]] = torch.sparse_coo_tensor(
torch.randn(5), [6, 7, 8], size=(2, 3)
)
y2: torch.Tensor = torch.sparse_coo_tensor(torch.randn(5), [6, 7, 8])
def test_max() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32] = torch.max(x)
y2: torch.Tensor[torch.float32, L[3], L[4]] = torch.max(x, dim=0).values
y2_indices: torch.Tensor[torch.int64, L[3], L[4]] = torch.max(x, dim=0).indices
y2_getitem: torch.Tensor[torch.int64, L[3], L[4]] = torch.max(x, dim=0)[1]
y3: torch.Tensor[torch.float32, L[1], L[3], L[4]] = torch.max(
x, dim=0, keepdim=True
).values
y4: torch.Tensor[torch.float32, L[2], L[4]] = torch.max(x, dim=1).values
y5: torch.Tensor[torch.float32, L[2], L[1], L[4]] = torch.max(
x, dim=1, keepdim=True
).values
y6: torch.Tensor[torch.float32, L[2], L[3]] = torch.max(x, dim=2).values
y7: torch.Tensor[torch.float32, L[2], L[3], L[1]] = torch.max(
x, dim=2, keepdim=True
).values
y8: torch.Tensor[torch.float32, L[2], L[3]] = torch.max(x, dim=-1).values
y9: torch.Tensor[torch.float32, L[2], L[3], L[1]] = torch.max(
x, dim=-1, keepdim=True
).values
y10: torch.Tensor[torch.float32, L[2], L[4]] = torch.max(x, dim=-2).values
y11: torch.Tensor[torch.float32, L[2], L[1], L[4]] = torch.max(
x, dim=-2, keepdim=True
).values
y12: torch.Tensor[torch.float32, L[2], L[3], L[1]] = x.max(
dim=-1, keepdim=True
).values
# pyre-fixme[6]: Expected `typing_extensions.Literal[0]` for 2nd param but got
# `typing_extensions.Literal[3]`.
torch.max(x, dim=3).values
def test_einsum() -> None:
x: Tensor = torch.einsum("ii", torch.randn(4, 4))
def test_type_as() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
x2: torch.Tensor[torch.int64, L[2], L[3], L[4]]
y: torch.Tensor[torch.int64, L[2], L[3], L[4]] = x.type_as(x2)
def test_softmax() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = torch.softmax(x, dim=1)
# pyre-fixme[9]: y_error has type `Tensor[torch.float32, typing_extensions.Litera...
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = torch.softmax(x, dim=1)
y2: torch.Tensor[torch.int64, L[2], L[3], L[4]] = torch.softmax(
x, dim=1, dtype=torch.int64
)
y3: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.softmax(dim=1)
def test_conv2d() -> None:
x: Tensor[torch.float32, L[20], L[16], L[50], L[100]]
y7: Tensor[torch.float32, L[20], L[33], L[56], L[100]] = nn.Conv2d(
16, 33, (3, 5), padding=(4, 2), bias=False
)(x)
# pyre-fixme[9]: y7_error has type `Tensor[torch.float32, typing_extensions.Liter...
y7_error: Tensor[torch.float32, L[20], L[33], L[56], L[99]] = nn.Conv2d(
16, 33, (3, 5), padding=(4, 2)
)(x)
module: nn.Module = nn.Conv2d(16, 33, (3, 5), padding=(4, 2))
def test_nn_Parameter() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = nn.Parameter(x)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = nn.Parameter(x)
def test_torch_datatypes() -> None:
x: torch.float16
x2: torch.int
def test_norm() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
x_out: Tensor[torch.float32, L[2], L[3], L[4]]
y1: Tensor[torch.float32] = torch.norm(x)
y2: Tensor[torch.float32, L[3], L[4]] = torch.norm(x, dim=0, out=x_out, p=1)
# pyre-fixme[9]: Expected error.
y2_error: Tensor[torch.float32, L[3], L[99]] = torch.norm(x, dim=0)
y3: Tensor[torch.float32, L[1], L[3], L[4]] = torch.norm(x, dim=0, keepdim=True)
def test_rand() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
x_out: Tensor[torch.float32, L[2], L[3], L[4]]
device: torch.device
y1: Tensor[torch.float32, L[2], L[3], L[4]] = torch.rand(2, 3, 4)
# pyre-fixme[9]: Expected Error.
y1_error: Tensor[torch.float32, L[2], L[3], L[99]] = torch.rand(2, 3, 4)
y2: Tensor[torch.int64, L[2], L[3], L[4]] = torch.rand(
2,
3,
4,
dtype=torch.int64,
device=device,
layout=torch.strided,
out=x_out,
requires_grad=True,
generator=torch.default_generator,
)
y3: Tensor[torch.float32, L[2], L[3], L[4]] = torch.rand((2, 3, 4))
def test_randint() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
x_out: Tensor[torch.float32, L[2], L[3], L[4]]
device: torch.device
y1: Tensor[torch.int64, L[2], L[3], L[4]] = torch.randint(0, 3, (2, 3, 4))
# pyre-fixme[9]: Expected error.
y1_error: Tensor[torch.int64, L[2], L[3], L[99]] = torch.randint(0, 3, (2, 3, 4))
y2: Tensor[torch.int64, L[2], L[3], L[4]] = torch.randint(
3,
(2, 3, 4),
dtype=torch.int64,
device=device,
layout=torch.strided,
out=x_out,
requires_grad=True,
generator=torch.default_generator,
)
def test_zeros() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
x_out: Tensor[torch.float32, L[2], L[3], L[4]]
device: torch.device
y1: Tensor[torch.float32, L[2], L[3], L[4]] = torch.zeros(2, 3, 4)
# pyre-fixme[9]: Expected Error.
y1_error: Tensor[torch.float32, L[2], L[3], L[99]] = torch.zeros(2, 3, 4)
y2: Tensor[torch.int64, L[2], L[3], L[4]] = torch.zeros(
2,
3,
4,
dtype=torch.int64,
device=device,
layout=torch.strided,
out=x_out,
requires_grad=True,
)
y3: Tensor[torch.float32, L[2], L[3], L[4]] = torch.zeros((2, 3, 4))
def test_stride() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tuple[L[2], L[3], L[4]] = x.stride()
# pyre-fixme[9]: Expected error.
y_error: Tuple[L[2], L[3], L[99]] = x.stride()
y2: L[12] = x.stride(0)
y3: L[4] = x.stride(1)
y4: L[1] = x.stride(2)
def test_chunk() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tuple[
Tensor[torch.float32, L[2], L[3], L[2]], Tensor[torch.float32, L[2], L[3], L[2]]
] = torch.chunk(x, 2, dim=-1)
# pyre-fixme[9]: Expected error.
y_error: Tuple[
Tensor[torch.float32, L[2], L[3], L[99]],
Tensor[torch.float32, L[2], L[3], L[2]],
] = torch.chunk(x, 2, dim=-1)
y2: Tuple[
Tensor[torch.float32, L[1], L[3], L[4]], Tensor[torch.float32, L[1], L[3], L[4]]
] = torch.chunk(x, 2, dim=0)
y3: Tuple[
Tensor[torch.float32, L[1], L[3], L[4]], Tensor[torch.float32, L[1], L[3], L[4]]
] = x.chunk(2, dim=0)
def test_abs() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = x.abs()
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = x.abs()
def test_enable_grad() -> None:
with torch.enable_grad():
pass
def test_normal() -> None:
y: Tensor[torch.float32, L[2], L[3], L[4]] = torch.normal(
0, 1, size=(2, 3, 4), device="cuda", requires_grad=True
)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = torch.normal(
0, 1, size=(2, 3, 4), device="cuda", requires_grad=True
)
def test_dim() -> None:
x0: Tensor[torch.float32]
x1: Tensor[torch.float32, L[2]]
x2: Tensor[torch.float32, L[2], L[3]]
x3: Tensor[torch.float32, L[2], L[3], L[4]]
y: L[3] = x3.dim()
# pyre-fixme[9]: Expected error.
y_error: L[5] = x3.dim()
y2: L[0] = x0.dim()
y3: L[1] = x1.dim()
y4: L[2] = x2.dim()
def test_is_cuda() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: bool = x.is_cuda
def test_autograd_backward() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
torch.autograd.backward(x, x)
def test_linalg_norm() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2]] = torch.linalg.norm(x, dim=(-2, -1))
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[99]] = torch.linalg.norm(x, dim=(-2, -1))
def test_Sized() -> None:
x: torch.Size = torch.Size((2, 3, 4))
def test_initial_seed() -> None:
x: int = torch.initial_seed()
def test_log_softmax() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = torch.log_softmax(x, dim=1)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = torch.log_softmax(x, dim=1)
y2: Tensor[torch.int64, L[2], L[3], L[4]] = torch.log_softmax(
x, dtype=torch.int64, dim=1
)
def test_masked_select() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
mask: Tensor[torch.bool, L[2], L[3], L[4]]
out: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor = x.masked_select(mask, out=out)
y2: Tensor = torch.masked_select(x, mask, out=out)
def test__lt__() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.bool, L[2], L[3], L[4]] = x < 3.0
def test_pow() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = x**4
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = x**4
def test_item() -> None:
x: Tensor[torch.float32]
x2: Tensor[torch.float32, L[1]]
y: torch.float32 = x.item()
# pyre-fixme[9]: Expected error.
y_error: torch.int64 = x.item()
y2: torch.float32 = x.item()
def test_uniform_() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = nn.init.uniform_(x, a=1.0, b=2.0)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = nn.init.uniform_(
x, a=1.0, b=2.0
)
def test_kaiming_uniform_() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = nn.init.kaiming_uniform_(
x, a=1.0, mode="fan_in", nonlinearity="leaky_relu"
)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = nn.init.kaiming_uniform_(x)
def test_constant_() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = nn.init.constant_(x, val=1.0)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = nn.init.constant_(x, val=1.0)
def test_leaky_relu() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = nn.LeakyReLU(
negative_slope=1.0, inplace=True
)(x)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = nn.LeakyReLU(
negative_slope=1.0, inplace=True
)(x)
def test_fft_fft2() -> None:
x: Tensor[torch.complex64, L[2], L[3], L[4]]
y: Tensor[torch.complex64, L[2], L[3], L[4]] = torch.fft.fft2(x)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.complex64, L[2], L[3], L[99]] = torch.fft.fft2(x)
def test_real() -> None:
x: Tensor[torch.complex64, L[2], L[3], L[4]]
y: Tensor[torch.float32, L[2], L[3], L[4]] = x.real
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = x.real
x2: Tensor[torch.complex128, L[2], L[3], L[4]]
y2: Tensor[torch.float64, L[2], L[3], L[4]] = x2.real
not_complex: Tensor[torch.float64, L[2], L[3], L[4]]
# Should error but we don't have overloads for @property.
not_complex.real
def test_Tensor_init() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
# pyre-fixme[9]: Unexpected error because the constructor doesn't bind DType.
y: Tensor[torch.float32, L[2], L[3], L[4]] = Tensor((2, 3, 4), device="cuda")
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[2], L[3], L[99]] = Tensor((2, 3, 4), device="cuda")
y2: Tensor[torch.float32, L[2], L[3], L[4]] = Tensor(2, 3, 4, device="cuda")
y3: Tensor[torch.float32, L[2], L[3], L[4]] = Tensor(x)
def test_reflection_pad2d() -> None:
module: nn.Module = nn.ReflectionPad2d(4)
x: Tensor[torch.float32, L[20], L[16], L[50], L[100]]
y: Tensor[torch.float32, L[20], L[16], L[58], L[108]] = nn.ReflectionPad2d(4)(x)
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.float32, L[20], L[16], L[58], L[99]] = nn.ReflectionPad2d(4)(
x
)
def test_half() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
good1: torch.Tensor[torch.float16, L[2], L[3], L[4]] = x.half(torch.memory_format())
# pyre-fixme[9]: Expected error.
bad1: torch.Tensor[torch.float16, L[99], L[3], L[4]] = x.half()
def test_is_contiguous() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
y: bool = x.is_contiguous(torch.memory_format())
def test_scatter() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
# We don't really check for the shape of index or src.
index: torch.LongTensor[torch.float32, L[99]]
src: torch.Tensor[torch.float32, L[99], L[99]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.scatter(0, index, src)
# pyre-fixme[9]: Expected error.
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = x.scatter(0, index, src)
y2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.scatter(2, index, src)
def test_scatter_() -> None:
x: torch.Tensor[torch.float32, L[2], L[3], L[4]]
# We don't really check for the shape of index or src.
index: torch.LongTensor[torch.float32, L[99]]
src: torch.Tensor[torch.float32, L[99], L[99]]
y: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.scatter_(0, index, src)
# pyre-fixme[9]: Expected error.
y_error: torch.Tensor[torch.float32, L[2], L[3], L[99]] = x.scatter_(0, index, src)
y2: torch.Tensor[torch.float32, L[2], L[3], L[4]] = x.scatter_(2, index, src)
def test_bool() -> None:
x: Tensor[torch.float32, L[2], L[3], L[4]]
y: Tensor[torch.bool, L[2], L[3], L[4]] = x.bool()
# pyre-fixme[9]: Expected error.
y_error: Tensor[torch.bool, L[2], L[3], L[99]] = x.bool()