|
import math |
|
|
|
import torch.nn as nn |
|
|
|
|
|
class Upsample(nn.Module):
    """Upsample module built from conv + PixelShuffle stages.

    For a power-of-two ``scale == 2**n`` the module stacks ``n`` stages of
    ``Conv2d(num_feat, 4 * num_feat, 3, 1, 1)`` followed by ``PixelShuffle(2)``.
    Each stage doubles height and width and keeps ``num_feat`` channels.
    ``scale == 1`` (``2**0``) therefore produces an empty ``nn.Sequential``,
    i.e. the identity. For ``scale == 3`` a single
    ``Conv2d(num_feat, 9 * num_feat, 3, 1, 1)`` is followed by
    ``PixelShuffle(3)``.

    Args:
        scale (int): Scale factor. Supported scales: 2^n (n >= 0) and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        super(Upsample, self).__init__()
        m = []
        # ``scale > 0`` is required because 0 also satisfies the bit trick
        # (0 & -1 == 0). Without the guard it reaches the stage loop and fails
        # with an unrelated error instead of the "not supported" message.
        if scale > 0 and (scale & (scale - 1)) == 0:
            # Exact integer log2 of a power of two; no floating-point rounding
            # as with int(math.log(scale, 2)).
            num_stages = int(scale).bit_length() - 1
            for _ in range(num_stages):
                # Keep conv/shuffle interleaved in this order so state_dict keys
                # (up.0, up.2, ...) stay compatible with existing checkpoints.
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(
                f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
            )
        self.up = nn.Sequential(*m)

    def forward(self, x):
        """Apply the upsampling stages; ``x`` must have ``num_feat`` channels."""
        return self.up(x)
|
|
|
|
|
class UpsampleOneStep(nn.Module):
    """Single-step upsampler: exactly one 3x3 conv and one PixelShuffle.

    Unlike ``Upsample``, this never stacks multiple stages. The convolution
    projects straight to ``num_out_ch * scale**2`` channels, and a single
    ``PixelShuffle(scale)`` rearranges them into a ``scale``-times larger
    image with ``num_out_ch`` channels. It is used in lightweight SR to save
    parameters.

    Args:
        scale (int): Upscaling factor handed to ``nn.PixelShuffle``. No
            restriction to 2^n or 3 is applied here.
        num_feat (int): Number of input feature channels.
        num_out_ch (int): Number of output channels after the pixel shuffle.
    """

    def __init__(self, scale, num_feat, num_out_ch):
        super().__init__()
        # Stored as an attribute; not read anywhere inside this class.
        self.num_feat = num_feat
        shuffle_channels = num_out_ch * scale ** 2
        # Index 0 = conv, index 1 = shuffle: preserves state_dict keys up.0.*.
        self.up = nn.Sequential(
            nn.Conv2d(num_feat, shuffle_channels, 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, x):
        """Run the conv + pixel-shuffle pair on ``x``."""
        return self.up(x)
|
|