import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sized, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel
from jaxtyping import Float
from torch import Tensor

from ..networks import get_mlp
from ...utils.general_utils import (
    config_to_primitive,
    contract_to_unisphere_custom,
    sample_from_planes,
)


@dataclass
class StableDiffusionTriplaneDualAttentionConfig:
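    """Configuration for the StableDiffusionTriplaneDualAttention geometry module."""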
    n_feature_dims: int = 3
    space_generator_config: dict = field(
        default_factory=lambda: {
            "pretrained_model_name_or_path": "stable-diffusion-2-1-base",
            "training_type": "self_lora_rank_16-cross_lora_rank_16-locon_rank_16",
            "output_dim": 32,
            "gradient_checkpoint": False,
            "self_lora_type": "hexa_v1",
            "cross_lora_type": "hexa_v1",
            "locon_type": "vanilla_v1",
        }
    )
    mlp_network_config: dict = field(
        default_factory=lambda: {
            "otype": "VanillaMLP",
            "activation": "ReLU",
            "output_activation": "none",
            "n_neurons": 64,
            "n_hidden_layers": 2,
        }
    )
    backbone: str = "one_step_triplane_dual_stable_diffusion"
    finite_difference_normal_eps: Union[float, str] = 0.01  # a float eps or "progressive"
    sdf_bias: Union[float, str] = 0.0
    sdf_bias_params: Optional[Any] = None
    isosurface_remove_outliers: bool = False
    # Rotate the planes to match the conventional orientation of images
    # generated by SD, in a right-handed coordinate system:
    # the xy plane should look like an image from a top-down / bottom-up view,
    # the xz plane like an image from a right-left / left-right view,
    # and the yz plane like an image from a front-back / back-front view.
    rotate_planes: Optional[str] = None
    split_channels: Optional[str] = None
    geo_interpolate: str = "v1"
    tex_interpolate: str = "v1"
    isosurface_deformable_grid: bool = True


class StableDiffusionTriplaneDualAttention(nn.Module):
    def __init__(
        self,
        config: StableDiffusionTriplaneDualAttentionConfig,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
    ):
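        """Build the triplane space generator plus the SDF, feature, and
        deformation MLP heads from the given config, VAE, and UNet."""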
        super().__init__()
        self.cfg = (
            StableDiffusionTriplaneDualAttentionConfig(**config)
            if isinstance(config, dict)
            else config
        )

        # set up the space generator
        from ...extern.sd_dual_triplane_modules import (
            OneStepTriplaneDualStableDiffusion as Generator,
        )

        self.space_generator = Generator(
            self.cfg.space_generator_config,
            vae=vae,
            unet=unet,
        )

        input_dim = self.space_generator.output_dim  # feat_xy + feat_xz + feat_yz
        assert self.cfg.split_channels in [None, "v1"]
        if self.cfg.split_channels == "v1":  # split geometry and texture
            input_dim = input_dim // 2

        assert self.cfg.geo_interpolate in ["v1", "v2"]
        if self.cfg.geo_interpolate == "v2":
            geo_input_dim = input_dim * 3  # concat[feat_xy, feat_xz, feat_yz]
        else:
            geo_input_dim = input_dim  # feat_xy + feat_xz + feat_yz

        assert self.cfg.tex_interpolate in ["v1", "v2"]
        if self.cfg.tex_interpolate == "v2":
            tex_input_dim = input_dim * 3  # concat[feat_xy, feat_xz, feat_yz]
        else:
            tex_input_dim = input_dim  # feat_xy + feat_xz + feat_yz

        self.sdf_network = get_mlp(
            geo_input_dim,
            1,
            self.cfg.mlp_network_config,
        )
        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                tex_input_dim,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )
        if self.cfg.isosurface_deformable_grid:
            self.deformation_network = get_mlp(
                geo_input_dim,
                3,
                self.cfg.mlp_network_config,
            )

        # hard-coded for now
        self.unbounded = False
        radius = 1.0
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-radius, -radius, -radius],
                    [radius, radius, radius],
                ],
                dtype=torch.float32,
            ),
        )

    def initialize_shape(self) -> None:
        # not used
        pass

    def get_shifted_sdf(
        self,
        points: Float[Tensor, "*N Di"],
        sdf: Float[Tensor, "*N 1"],
    ) -> Float[Tensor, "*N 1"]:
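        """Shift the raw SDF by `cfg.sdf_bias`: either a constant float, or the
        pseudo signed distance of a sphere / ellipsoid given by
        `cfg.sdf_bias_params`."""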
        sdf_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.sdf_bias == "ellipsoid":
            assert (
                isinstance(self.cfg.sdf_bias_params, Sized)
                and len(self.cfg.sdf_bias_params) == 3
            )
            size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
            sdf_bias = ((points / size) ** 2).sum(
                dim=-1, keepdim=True
            ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid
        elif self.cfg.sdf_bias == "sphere":
            assert isinstance(self.cfg.sdf_bias_params, float)
            radius = self.cfg.sdf_bias_params
            sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
        elif isinstance(self.cfg.sdf_bias, float):
            sdf_bias = self.cfg.sdf_bias
        else:
            raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
        return sdf + sdf_bias

    def generate_space_cache(
        self,
        styles: Float[Tensor, "B Z"],
        text_embed: Float[Tensor, "B C"],
    ) -> Any:
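        """Generate the triplane space cache from the latent `styles` and the
        text embedding via the one-step generator."""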
        output = self.space_generator(
            text_embed=text_embed,
            styles=styles,
        )
        return output

    def denoise(
        self,
        noisy_input: Any,
        text_embed: Float[Tensor, "B C"],
        timestep,
    ) -> Any:
        output = self.space_generator.forward_denoise(
            text_embed=text_embed,
            noisy_input=noisy_input,
            t=timestep,
        )
        return output

    def decode(
        self,
        latents: Any,
    ) -> Any:
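        """Decode latents into triplane features. With `split_channels == "v1"`,
        the first half of the channels goes to the three geometry planes and the
        second half to the three texture planes, yielding [B, 6, C//2, H, W]."""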
        triplane = self.space_generator.forward_decode(
            latents=latents
        )
        if self.cfg.split_channels is None:
            return triplane
        elif self.cfg.split_channels == "v1":
            B, _, C, H, W = triplane.shape
            # geometry planes use the first output_dim // 2 channels,
            # texture planes use the last output_dim // 2 channels
            half_dim = self.space_generator.output_dim // 2
            used_indices_geo = torch.tensor([True] * half_dim + [False] * half_dim)
            used_indices_tex = torch.tensor([False] * half_dim + [True] * half_dim)
            used_indices = torch.stack(
                [used_indices_geo] * 3 + [used_indices_tex] * 3, dim=0
            ).to(triplane.device)
            return triplane[:, used_indices].view(B, 6, C // 2, H, W)

    def interpolate_encodings(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
        only_geo: bool = False,
    ):
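        """Sample per-point features from the (rotated) triplanes.

        Planes 0-2 of `space_cache` carry geometry features and planes 3-5
        carry texture features. Returns `geo_feat` alone if `only_geo` is True,
        otherwise the tuple `(geo_feat, tex_feat)`.
        """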
        batch_size, n_points, n_dims = points.shape

        # the following plane sampling is similar to EG3D / OpenLRM
        assert self.cfg.rotate_planes in [None, "v1", "v2"]
        if self.cfg.rotate_planes is None:
            raise NotImplementedError("rotate_planes == None is not implemented yet.")

        space_cache_rotated = torch.zeros_like(space_cache)
        if self.cfg.rotate_planes == "v1":
            # xy plane: transpose (flip about the diagonal)
            space_cache_rotated[:, 0::3] = torch.transpose(
                space_cache[:, 0::3], 3, 4
            )
            # xz plane: rotate 180°
            space_cache_rotated[:, 1::3] = torch.rot90(
                space_cache[:, 1::3], k=2, dims=(3, 4)
            )
            # zy plane: rotate 90° clockwise
            space_cache_rotated[:, 2::3] = torch.rot90(
                space_cache[:, 2::3], k=-1, dims=(3, 4)
            )
        elif self.cfg.rotate_planes == "v2":
            # same as v1, except that the xy plane is flipped row-wise
            space_cache_rotated[:, 0::3] = torch.flip(
                space_cache[:, 0::3], dims=(4,)
            )
            # xz plane: rotate 180°
            space_cache_rotated[:, 1::3] = torch.rot90(
                space_cache[:, 1::3], k=2, dims=(3, 4)
            )
            # zy plane: rotate 90° clockwise
            space_cache_rotated[:, 2::3] = torch.rot90(
                space_cache[:, 2::3], k=-1, dims=(3, 4)
            )

        # planes 0, 1, 2 of space_cache_rotated carry the geometry features
        geo_feat = sample_from_planes(
            plane_features=space_cache_rotated[:, 0:3].contiguous(),
            coordinates=points,
            interpolate_feat=self.cfg.geo_interpolate,
        ).view(*points.shape[:-1], -1)

        if only_geo:
            return geo_feat
        else:
            # planes 3, 4, 5 of space_cache_rotated carry the texture features
            tex_feat = sample_from_planes(
                plane_features=space_cache_rotated[:, 3:6].contiguous(),
                coordinates=points,
                interpolate_feat=self.cfg.tex_interpolate,
            ).view(*points.shape[:-1], -1)
            return geo_feat, tex_feat

    def rescale_points(
        self,
        points: Float[Tensor, "*N Di"],
    ):
        # transform points from the original space to [-1, 1]^3
        points = contract_to_unisphere_custom(
            points,
            self.bbox,
            self.unbounded,
        )
        return points

    def forward(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Any,
    ) -> Dict[str, Float[Tensor, "..."]]:
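        """Query the SDF (and texture features, if enabled) at `points`.

        Returns a dict with "sdf" of shape [B*N, 1] and, when
        `cfg.n_feature_dims > 0`, "features" of shape [B*N, n_feature_dims].
        """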
        batch_size, n_points, n_dims = points.shape

        points_unscaled = points
        points = self.rescale_points(points)

        enc_geo, enc_tex = self.interpolate_encodings(points, space_cache)
        sdf_orig = self.sdf_network(enc_geo).view(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf_orig)
        output = {
            "sdf": sdf.view(batch_size * n_points, 1),  # reshape to [B*N, 1]
        }
        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc_tex).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update(
                {
                    "features": features.view(
                        batch_size * n_points, self.cfg.n_feature_dims
                    )
                }
            )
        return output

    def forward_sdf(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
    ) -> Float[Tensor, "*N 1"]:
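        """Query only the (shifted) SDF at `points`."""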
        batch_size = points.shape[0]
        assert (
            space_cache.shape[0] == batch_size
        ), "points and space_cache should have the same batch size in forward_sdf"
        points_unscaled = points
        points = self.rescale_points(points)

        # sample from the geometry planes only
        enc_geo = self.interpolate_encodings(
            points.reshape(batch_size, -1, 3),
            space_cache,
            only_geo=True,
        ).reshape(*points.shape[:-1], -1)
        sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        return sdf

    def forward_field(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
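        """Query the SDF and, when `cfg.isosurface_deformable_grid` is set, the
        per-point deformation used by the deformable isosurface grid."""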
        batch_size = points.shape[0]
        assert (
            space_cache.shape[0] == batch_size
        ), "points and space_cache should have the same batch size in forward_field"
        points_unscaled = points
        points = self.rescale_points(points)

        # sample from the geometry planes only
        enc_geo = self.interpolate_encodings(points, space_cache, only_geo=True)
        sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        deformation: Optional[Float[Tensor, "*N 3"]] = None
        if self.cfg.isosurface_deformable_grid:
            deformation = self.deformation_network(enc_geo).reshape(
                *points.shape[:-1], 3
            )
        return sdf, deformation

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
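        """Shift the field so that the isosurface at `threshold` becomes the
        zero level set."""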
        # TODO: is this function correct?
        return field - threshold

    def export(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
        **kwargs,
    ) -> Dict[str, Any]:
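        """Export per-point texture features (assumes a batch size of 1)."""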
        # TODO: is this function correct?
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out

        orig_shape = points.shape
        points = points.view(1, -1, 3)  # assume the batch size is 1

        points_unscaled = points
        points = self.rescale_points(points)

        # sample from the texture planes
        _, enc_tex = self.interpolate_encodings(points, space_cache)
        features = self.feature_network(enc_tex).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features.view(
                    orig_shape[:-1] + (self.cfg.n_feature_dims,)
                )
            }
        )
        return out

    def train(self, mode=True):
        # explicitly propagate the train / eval mode to the space generator
        super().train(mode)
        self.space_generator.train(mode)

    def eval(self):
        super().eval()
        self.space_generator.eval()
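

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the checkpoint id, tensors, and the exact
# call order below are assumptions, not part of this module's API.
#
#   from diffusers import StableDiffusionPipeline
#
#   pipe = StableDiffusionPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-2-1-base"
#   )
#   geometry = StableDiffusionTriplaneDualAttention(
#       config={"rotate_planes": "v1"}, vae=pipe.vae, unet=pipe.unet
#   )
#   text_embed = ...  # prompt embeddings from the SD text encoder
#   styles = ...      # [B, Z] latent noise
#   space_cache = geometry.generate_space_cache(styles, text_embed)
#   # (depending on the training pipeline, denoise() / decode() may be applied
#   # to the generator output before querying)
#   out = geometry(points, space_cache)  # {"sdf": [B*N, 1], "features": ...}
# ---------------------------------------------------------------------------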