|
import math |
|
from argparse import Namespace |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import models |
|
from models import register |
|
import numpy as np |
|
|
|
class ExpansionNet(nn.Module):
    """Fully connected network applied to the last axis of a 3-D tensor.

    The network is a stack of ``Linear -> ReLU`` pairs (one pair per entry
    of ``args.hidden_list``) followed by a final ``Linear`` projection to
    ``args.out_dim``. The projected vectors are L2-normalized.

    Args:
        args: Object exposing ``in_dim`` (int), ``out_dim`` (int) and
            ``hidden_list`` (iterable of hidden widths, or ``None`` for no
            hidden layers).
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.in_dim = args.in_dim
        self.out_dim = args.out_dim
        self.hidden_list = args.hidden_list

        # ``hidden_list`` may legitimately be None (it is the default of the
        # factory function); treat that as "no hidden layers" rather than
        # failing while iterating over None.
        hidden_widths = [] if self.hidden_list is None else list(self.hidden_list)

        modules = []
        width = self.in_dim
        for hidden in hidden_widths:
            modules.append(nn.Linear(width, hidden))
            modules.append(nn.ReLU())
            width = hidden
        modules.append(nn.Linear(width, self.out_dim))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Project every vector along the last axis and L2-normalize it.

        Args:
            x: Tensor with exactly three dimensions; the last one is fed to
                the first linear layer.

        Returns:
            Tensor whose first two dimensions match ``x`` and whose last
            dimension is ``self.out_dim``, normalized along that axis.
        """
        batch, length, channels = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted; contiguous inputs still avoid a copy.
        flat = x.reshape(-1, channels)
        logits = self.layers(flat)
        normalized = nn.functional.normalize(logits, dim=1)
        return normalized.view(batch, length, self.out_dim)
|
|
|
|
|
@register('ExpansionNet')
def make_ExpansionNet(in_dim=580, out_dim=10, hidden_list=None):
    """Build an ``ExpansionNet`` from plain keyword arguments.

    Args:
        in_dim: Size of the input feature axis.
        out_dim: Size of the output feature axis.
        hidden_list: Iterable of hidden-layer widths. ``None`` (the default)
            is passed on as an empty list, i.e. no hidden layers.

    Returns:
        The constructed ``ExpansionNet`` instance.
    """
    args = Namespace()
    args.in_dim = in_dim
    args.out_dim = out_dim
    # The network iterates over hidden_list, so never forward None:
    # calling this factory with its defaults must not raise.
    args.hidden_list = [] if hidden_list is None else hidden_list
    return ExpansionNet(args)
|
|