"""Minimal ``F.linear`` modules that differ only in their ``has_backward`` flag.

None of these classes creates ``weight`` or ``bias`` itself: both attributes
must be assigned on an instance before it is called, otherwise the attribute
lookup inside ``forward`` raises ``AttributeError``.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearImplicitBackward(nn.Module):
    """Linear module that declares no ``has_backward`` attribute at all."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this module's externally assigned parameters."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)


class LinearBackward(nn.Module):
    """Linear module whose class-level ``has_backward`` flag is ``True``."""

    # NOTE(review): presumably read by code outside this file to decide
    # whether a backward is available — confirm against the consumer.
    has_backward = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this module's externally assigned parameters."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)


class LinearNoBackward(nn.Module):
    """Linear module whose class-level ``has_backward`` flag is ``False``."""

    # NOTE(review): presumably read by code outside this file to decide
    # whether a backward is available — confirm against the consumer.
    has_backward = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``F.linear`` with this module's externally assigned parameters."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)


__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]