File size: 3,183 Bytes
02c5426 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import math
from argparse import Namespace
import torch
import torch.nn as nn
import torch.nn.functional as F
from models import register
class gen_basis(nn.Module):
    """Supply five basis weight tensors and five bias tensors.

    Depending on ``args.state`` the tensors are either created and randomly
    initialised (``'train'``) or read from a checkpoint file (any other
    value).  Calling the module returns ``(basis, bias)`` where both are
    lists of five tensors.

    Expected attributes on ``args``:
        basis_num: number of rows of every generated tensor.
        hidden: hidden width used to size the columns of every tensor.
        state: ``'train'`` to initialise fresh parameters; anything else
            loads them from ``path``.
        path: checkpoint location, only read when ``state != 'train'``.

    NOTE: in the ``'train'`` state every call to ``forward`` creates a brand
    new set of ``nn.Parameter`` objects and re-registers them under the same
    attribute names, so it should be called once and the result kept.
    """

    # Column multipliers of the first and last weight tensors.
    # NOTE(review): presumably the input / output widths of the network these
    # bases parameterise -- confirm against the consumer of this module.
    _FIRST_WIDTH = 580
    _LAST_WIDTH = 3

    def __init__(self, args):
        super(gen_basis, self).__init__()
        self.basis_num = args.basis_num
        self.hidden = args.hidden
        self.state = args.state
        self.path = args.path

    def init_basis_bias(self):
        """Create, register and initialise the weight and bias parameters.

        Weights ``w0`` .. ``w4`` have shape ``(basis_num, cols)`` and are
        initialised with ``kaiming_uniform_(a=sqrt(5))``.  Biases ``bias1`` ..
        ``bias5`` are drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``
        where ``fan_in`` is taken from the weight at the same position.

        Returns:
            ``(basis, bias)``: two lists holding the five weight parameters
            and the five bias parameters, in layer order.
        """
        hidden = self.hidden
        weight_cols = [
            hidden * self._FIRST_WIDTH,
            hidden * hidden,
            hidden * hidden,
            hidden * hidden,
            self._LAST_WIDTH * hidden,
        ]
        bias_cols = [hidden, hidden, hidden, hidden, self._LAST_WIDTH]

        # All weights are registered (and initialised) before any bias so the
        # parameter order and the sequence of random draws stay
        # w0..w4 followed by bias1..bias5.
        basis = []
        for idx, cols in enumerate(weight_cols):
            weight = nn.Parameter(
                torch.empty(self.basis_num, cols), requires_grad=True
            )
            nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
            setattr(self, 'w{}'.format(idx), weight)
            basis.append(weight)

        bias = []
        for idx, cols in enumerate(bias_cols, start=1):
            offset = nn.Parameter(
                torch.empty(self.basis_num, cols), requires_grad=True
            )
            setattr(self, 'bias{}'.format(idx), offset)
            bias.append(offset)

        for weight, offset in zip(basis, bias):
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(offset, -bound, bound)
        return basis, bias

    def load_basis_for_test_kaiming(self, path):
        """Read the ten basis tensors from a checkpoint.

        The file must deserialise to a mapping with the layout
        ``ckpt['model']['sd']['basis.<name>']`` for ``<name>`` in
        ``w0..w4`` and ``bias1..bias5``.

        Args:
            path: checkpoint location handed to ``torch.load``.

        Returns:
            ``(basis, bias)``: two lists of five tensors each, taken straight
            from the checkpoint (nothing is registered on this module).

        Raises:
            ValueError: if ``path`` is ``None``.
            KeyError: if the checkpoint lacks one of the expected entries.
        """
        if path is None:
            raise ValueError(
                "a checkpoint path is required when state is not 'train'"
            )
        # Without CUDA, tensors saved from a GPU cannot be restored onto
        # their original device; fall back to CPU only in that situation so
        # behaviour on CUDA machines is untouched.
        map_location = None if torch.cuda.is_available() else 'cpu'
        model_spec = torch.load(path, map_location=map_location)['model']
        state_dict = model_spec['sd']
        basis = [state_dict['basis.w{}'.format(i)] for i in range(5)]
        bias = [state_dict['basis.bias{}'.format(i)] for i in range(1, 6)]
        torch.cuda.empty_cache()
        return basis, bias

    def forward(self):
        """Return ``(basis, bias)``, initialised or loaded according to ``state``."""
        if self.state == 'train':
            print('init_basis_use_kaiming')
            return self.init_basis_bias()
        print('load_basis_from_model')
        return self.load_basis_for_test_kaiming(self.path)
@register('basis')
def make_basis(basis_num=10, hidden=16, state=None, path=None):
    """Build a ``gen_basis`` module from plain keyword options.

    The four arguments are packed into an ``argparse.Namespace`` and handed
    to ``gen_basis`` unchanged.  The factory is passed through the project's
    ``register`` decorator with the name ``'basis'``.

    Args:
        basis_num: forwarded as ``args.basis_num``.
        hidden: forwarded as ``args.hidden``.
        state: forwarded as ``args.state``.
        path: forwarded as ``args.path``.

    Returns:
        A new ``gen_basis`` instance.
    """
    options = Namespace(
        basis_num=basis_num,
        hidden=hidden,
        state=state,
        path=path,
    )
    return gen_basis(options)
|