import torch
import torch.nn as nn
# NOTE(review): the `F` alias conventionally refers to `torch.nn.functional`;
# this binds `torch.functional` instead. Left unchanged because code outside
# this view may rely on it -- confirm which module is intended.
import torch.functional as F


class batch_norm(nn.Module):
    """Apply 2-D batch normalization followed by a ReLU activation.

    Args:
        inp: Value forwarded to ``nn.BatchNorm2d`` as its first positional
            argument (``num_features``), i.e. the channel count of the
            tensors this module will receive.
    """

    def __init__(self, inp: int) -> None:
        super().__init__()
        # Attribute names are part of the module's state_dict keys
        # (``batch.weight``, ``batch.running_mean``, ...); keep them stable.
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(batch_norm(x))``.

        Args:
            x: Input tensor accepted by ``nn.BatchNorm2d`` (per the PyTorch
                docs, a 4-D ``(N, C, H, W)`` tensor with ``C == inp``).

        Returns:
            A tensor of the same shape as ``x`` with non-negative entries.
        """
        normalized = self.batch(x)
        return self.relu(normalized)