# Author: Pavan2k4 (profile picture caption from the source page)
# Project directory: app
# Commit: 35d85a5
import torch
import torch.nn as nn
import torch.functional as F
class batch_norm(nn.Module):
    """2-D batch normalization immediately followed by a ReLU activation.

    Args:
        inp: Number of channels passed to ``nn.BatchNorm2d`` as
            ``num_features``.
    """

    def __init__(self, inp: int) -> None:
        super().__init__()
        # The submodule attribute names ``batch`` and ``relu`` are kept as-is:
        # they form the state_dict key prefixes (e.g. ``batch.weight``).
        self.batch = nn.BatchNorm2d(inp)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(batch_norm(x))``."""
        return self.relu(self.batch(x))