File size: 660 Bytes
64af3aa f60b4b0 64af3aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
from torch.nn import functional as F
class LinearImplicitBackward(nn.Module):
    """Linear layer that does not declare a ``has_backward`` flag at all.

    ``forward`` applies ``torch.nn.functional.linear`` using the instance's
    ``weight`` and ``bias`` attributes. Neither attribute is created by this
    class (there is no ``__init__``).

    NOTE(review): ``weight``/``bias`` are presumably supplied by the module
    this layer's ``forward`` is attached to, and the absent ``has_backward``
    presumably falls back to a default defined elsewhere -- confirm with the
    callers before adding a constructor or any other member to this class.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the affine map of ``input`` by ``self.weight``/``self.bias``.

        ``input`` keeps its name to mirror ``nn.Linear.forward``'s signature.
        """
        weight = self.weight
        bias = self.bias
        # nn.functional is the same module the file aliases as ``F``.
        return nn.functional.linear(input, weight, bias)
class LinearBackward(nn.Module):
    """Linear layer that explicitly sets ``has_backward = True``.

    The computation is ``F.linear`` over the instance's ``weight`` and
    ``bias``; this class does not create those attributes itself.

    NOTE(review): the flag is presumably inspected by external code to decide
    whether the layer may be used when gradients are required -- confirm.
    """

    has_backward = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``input @ weight.T + bias`` (bias skipped when ``None``)."""
        params = (self.weight, self.bias)
        return F.linear(input, *params)
class LinearNoBackward(nn.Module):
    """Linear layer that explicitly sets ``has_backward = False``.

    Forward math is identical to a plain ``F.linear`` call on the instance's
    ``weight``/``bias``, which this class never defines itself.

    NOTE(review): ``False`` presumably marks the layer as unsuitable when a
    backward pass is needed; the consumer of the flag is not visible here.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this layer's weight and optional bias."""
        return F.linear(input, self.weight, bias=self.bias)

    # Class-level opt-out flag; its position in the class body is irrelevant.
    has_backward = False
# Public API of this module: the three linear-layer variants, differing only
# in how (or whether) they declare ``has_backward``.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]
|