from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.general_utils import config_to_primitive

def get_activation(name):
    if name is None:
        return lambda x: x
    name = name.lower()
    if name == "none":
        return lambda x: x
    elif name == "sigmoid-mipnerf":
        # MipNeRF-style widened sigmoid: maps into (-0.001, 1.001) so the
        # output never saturates exactly at 0 or 1
        return lambda x: torch.sigmoid(x) * (1 + 2 * 0.001) - 0.001
    else:
        try:
            # Fall back to an activation of the same name in torch.nn.functional
            return getattr(F, name)
        except AttributeError:
            raise ValueError(f"Unknown activation function: {name}")

class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, config):
        super().__init__()
        # Convert a plain dict to MLPConfig if needed
        if isinstance(config, dict):
            config = MLPConfig(**config)
        self.n_neurons = config.n_neurons
        self.n_hidden_layers = config.n_hidden_layers
        # Input layer, then (n_hidden_layers - 1) hidden layers, then the output layer
        layers = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for i in range(self.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.n_neurons, self.n_neurons, is_first=False, is_last=False
                ),
                self.make_activation(),
            ]
        layers += [
            self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)
        ]
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(config.output_activation)

    def forward(self, x):
        # Disable autocast: with AMP autocast enabled, the parameters of this
        # module otherwise end up with empty gradients
        with torch.cuda.amp.autocast(enabled=False):
            x = self.layers(x)
            x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=False)
        return layer

    def make_activation(self):
        return nn.ReLU(inplace=True)

@dataclass
class MLPConfig:
    otype: str = "VanillaMLP"
    activation: str = "ReLU"
    output_activation: str = "none"
    n_neurons: int = 64
    n_hidden_layers: int = 2

def get_mlp(input_dim, output_dim, config):
    """Create an MLP network based on config."""
    # Convert a plain dict to MLPConfig
    if isinstance(config, dict):
        config = MLPConfig(**config)
    if config.otype == "VanillaMLP":
        network = VanillaMLP(input_dim, output_dim, config)
    else:
        raise ValueError(f"Unknown MLP type: {config.otype}")
    return network
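
# Minimal usage sketch (illustrative only, not part of the original module):
# build an MLP mapping 32-D features to 3-D outputs with three hidden layers
# and a sigmoid on the output; the dict is converted to MLPConfig internally.
#
#   mlp = get_mlp(
#       32, 3,
#       {"otype": "VanillaMLP", "n_neurons": 64, "n_hidden_layers": 3,
#        "output_activation": "sigmoid"},
#   )
#   y = mlp(torch.randn(4096, 32))  # -> shape (4096, 3), values in (0, 1)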