import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Sized, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from jaxtyping import Float
from torch import Tensor

from ..networks import get_mlp
from ...utils.general_utils import (
    config_to_primitive,
    contract_to_unisphere_custom,
    sample_from_planes,
)


@dataclass
class StableDiffusionTriplaneDualAttentionConfig:
    n_feature_dims: int = 3
    space_generator_config: dict = field(
        default_factory=lambda: {
            "pretrained_model_name_or_path": "stable-diffusion-2-1-base",
            "training_type": "self_lora_rank_16-cross_lora_rank_16-locon_rank_16",
            "output_dim": 32,
            "gradient_checkpoint": False,
            "self_lora_type": "hexa_v1",
            "cross_lora_type": "hexa_v1",
            "locon_type": "vanilla_v1",
        }
    )
    mlp_network_config: dict = field(
        default_factory=lambda: {
            "otype": "VanillaMLP",
            "activation": "ReLU",
            "output_activation": "none",
            "n_neurons": 64,
            "n_hidden_layers": 2,
        }
    )
    backbone: str = "one_step_triplane_dual_stable_diffusion"
    finite_difference_normal_eps: Union[float, str] = 0.01  # in [float, "progressive"]
    sdf_bias: Union[float, str] = 0.0
    sdf_bias_params: Optional[Any] = None
    isosurface_remove_outliers: bool = False
    # Rotate planes to fit the conventional orientation of images generated by SD,
    # in a right-handed coordinate system:
    #   the xy plane should look like an image seen from a top-down / bottom-up view
    #   the xz plane should look like an image seen from a right-left / left-right view
    #   the yz plane should look like an image seen from a front-back / back-front view
    rotate_planes: Optional[str] = None
    split_channels: Optional[str] = None
    geo_interpolate: str = "v1"
    tex_interpolate: str = "v1"
    isosurface_deformable_grid: bool = True


class StableDiffusionTriplaneDualAttention(nn.Module):
    def __init__(
        self,
        config: StableDiffusionTriplaneDualAttentionConfig,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
    ):
        super().__init__()
        self.cfg = (
            StableDiffusionTriplaneDualAttentionConfig(**config)
            if isinstance(config, dict)
            else config
        )

        # set up the space generator
        from ...extern.sd_dual_triplane_modules import (
            OneStepTriplaneDualStableDiffusion as Generator,
        )

        self.space_generator = Generator(
            self.cfg.space_generator_config,
            vae=vae,
            unet=unet,
        )

        input_dim = self.space_generator.output_dim  # feat_xy + feat_xz + feat_yz

        assert self.cfg.split_channels in [None, "v1"]
        if self.cfg.split_channels in ["v1"]:
            # split the channels between geometry and texture
            input_dim = input_dim // 2

        assert self.cfg.geo_interpolate in ["v1", "v2"]
        if self.cfg.geo_interpolate in ["v2"]:
            geo_input_dim = input_dim * 3  # concat[feat_xy, feat_xz, feat_yz]
        else:
            geo_input_dim = input_dim  # feat_xy + feat_xz + feat_yz

        assert self.cfg.tex_interpolate in ["v1", "v2"]
        if self.cfg.tex_interpolate in ["v2"]:
            tex_input_dim = input_dim * 3  # concat[feat_xy, feat_xz, feat_yz]
        else:
            tex_input_dim = input_dim  # feat_xy + feat_xz + feat_yz

        self.sdf_network = get_mlp(
            geo_input_dim,
            1,
            self.cfg.mlp_network_config,
        )
        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                tex_input_dim,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )
        if self.cfg.isosurface_deformable_grid:
            self.deformation_network = get_mlp(
                geo_input_dim,
                3,
                self.cfg.mlp_network_config,
            )

        # hard-coded for now
        self.unbounded = False
        radius = 1.0
        self.register_buffer(
            "bbox",
            torch.as_tensor(
                [
                    [-radius, -radius, -radius],
                    [radius, radius, radius],
                ],
                dtype=torch.float32,
            ),
        )
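    # A minimal construction sketch (a sketch only: the checkpoint id below and
    # the dict-style config are assumptions, not the project's documented API):
    #
    #   vae = AutoencoderKL.from_pretrained(
    #       "stabilityai/stable-diffusion-2-1-base", subfolder="vae"
    #   )
    #   unet = UNet2DConditionModel.from_pretrained(
    #       "stabilityai/stable-diffusion-2-1-base", subfolder="unet"
    #   )
    #   geometry = StableDiffusionTriplaneDualAttention(
    #       config={"rotate_planes": "v1", "split_channels": "v1"},
    #       vae=vae,
    #       unet=unet,
    #   )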
    def initialize_shape(self) -> None:
        # not used
        pass

    def get_shifted_sdf(
        self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
    ) -> Float[Tensor, "*N 1"]:
        sdf_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.sdf_bias == "ellipsoid":
            assert (
                isinstance(self.cfg.sdf_bias_params, Sized)
                and len(self.cfg.sdf_bias_params) == 3
            )
            size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
            sdf_bias = ((points / size) ** 2).sum(
                dim=-1, keepdim=True
            ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid
        elif self.cfg.sdf_bias == "sphere":
            assert isinstance(self.cfg.sdf_bias_params, float)
            radius = self.cfg.sdf_bias_params
            sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
        elif isinstance(self.cfg.sdf_bias, float):
            sdf_bias = self.cfg.sdf_bias
        else:
            raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
        return sdf + sdf_bias

    def generate_space_cache(
        self,
        styles: Float[Tensor, "B Z"],
        text_embed: Float[Tensor, "B C"],
    ) -> Any:
        output = self.space_generator(
            text_embed=text_embed,
            styles=styles,
        )
        return output

    def denoise(
        self,
        noisy_input: Any,
        text_embed: Float[Tensor, "B C"],
        timestep,
    ) -> Any:
        output = self.space_generator.forward_denoise(
            text_embed=text_embed,
            noisy_input=noisy_input,
            t=timestep,
        )
        return output

    def decode(
        self,
        latents: Any,
    ) -> Any:
        triplane = self.space_generator.forward_decode(latents=latents)
        if self.cfg.split_channels is None:
            return triplane
        elif self.cfg.split_channels == "v1":
            B, _, C, H, W = triplane.shape
            # the geometry triplane uses the first output_dim // 2 channels,
            # the texture triplane uses the last output_dim // 2 channels
            used_indices_geo = torch.tensor(
                [True] * (self.space_generator.output_dim // 2)
                + [False] * (self.space_generator.output_dim // 2)
            )
            used_indices_tex = torch.tensor(
                [False] * (self.space_generator.output_dim // 2)
                + [True] * (self.space_generator.output_dim // 2)
            )
            used_indices = torch.stack(
                [used_indices_geo] * 3 + [used_indices_tex] * 3, dim=0
            ).to(triplane.device)
            return triplane[:, used_indices].view(B, 6, C // 2, H, W)

    def interpolate_encodings(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
        only_geo: bool = False,
    ):
        batch_size, n_points, n_dims = points.shape
        # the following code is similar to EG3D / OpenLRM
        assert self.cfg.rotate_planes in [None, "v1", "v2"]

        if self.cfg.rotate_planes is None:
            raise NotImplementedError("rotate_planes == None is not implemented yet.")

        # the strided slices (0::3, 1::3, 2::3) pick the xy / xz / zy plane of
        # both the geometry group (planes 0-2) and the texture group (planes 3-5)
        space_cache_rotated = torch.zeros_like(space_cache)
        if self.cfg.rotate_planes == "v1":
            # xy plane, diagonal-wise
            space_cache_rotated[:, 0::3] = torch.transpose(
                space_cache[:, 0::3], 3, 4
            )
            # xz plane, rotate 180° counterclockwise
            space_cache_rotated[:, 1::3] = torch.rot90(
                space_cache[:, 1::3], k=2, dims=(3, 4)
            )
            # zy plane, rotate 90° clockwise
            space_cache_rotated[:, 2::3] = torch.rot90(
                space_cache[:, 2::3], k=-1, dims=(3, 4)
            )
        elif self.cfg.rotate_planes == "v2":
            # the same as v1, except for the xy plane
            # xy plane, row-wise flip
            space_cache_rotated[:, 0::3] = torch.flip(
                space_cache[:, 0::3], dims=(4,)
            )
            # xz plane, rotate 180° counterclockwise
            space_cache_rotated[:, 1::3] = torch.rot90(
                space_cache[:, 1::3], k=2, dims=(3, 4)
            )
            # zy plane, rotate 90° clockwise
            space_cache_rotated[:, 2::3] = torch.rot90(
                space_cache[:, 2::3], k=-1, dims=(3, 4)
            )

        # planes 0, 1, 2 of space_cache_rotated hold the geometry features
        geo_feat = sample_from_planes(
            plane_features=space_cache_rotated[:, 0:3].contiguous(),
            coordinates=points,
            interpolate_feat=self.cfg.geo_interpolate,
        ).view(*points.shape[:-1], -1)

        if only_geo:
            return geo_feat
        else:
            # planes 3, 4, 5 of space_cache_rotated hold the texture features
            tex_feat = sample_from_planes(
                plane_features=space_cache_rotated[:, 3:6].contiguous(),
                coordinates=points,
                interpolate_feat=self.cfg.tex_interpolate,
            ).view(*points.shape[:-1], -1)
            return geo_feat, tex_feat
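    # A quick way to sanity-check the "v1" rotations above on a dummy cache
    # (a sketch; shapes follow the [B, 6, C, H, W] layout produced by decode):
    #
    #   cache = torch.randn(1, 6, 4, 8, 8)
    #   xy = torch.transpose(cache[:, 0::3], 3, 4)          # swap H and W
    #   xz = torch.rot90(cache[:, 1::3], k=2, dims=(3, 4))  # 180° rotation
    #   zy = torch.rot90(cache[:, 2::3], k=-1, dims=(3, 4)) # 90° clockwise
    #   assert xy.shape == xz.shape == zy.shape == (1, 2, 4, 8, 8)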
    def rescale_points(
        self,
        points: Float[Tensor, "*N Di"],
    ):
        # transform points from the original space to [-1, 1]^3
        points = contract_to_unisphere_custom(points, self.bbox, self.unbounded)
        return points

    def forward(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Any,
    ) -> Dict[str, Float[Tensor, "..."]]:
        batch_size, n_points, n_dims = points.shape

        points_unscaled = points
        points = self.rescale_points(points)

        enc_geo, enc_tex = self.interpolate_encodings(points, space_cache)
        sdf = self.sdf_network(enc_geo).view(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        output = {
            "sdf": sdf.view(batch_size * n_points, 1),  # reshape to [B*N, 1]
        }
        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc_tex).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update(
                {
                    "features": features.view(
                        batch_size * n_points, self.cfg.n_feature_dims
                    )
                }
            )
        return output

    def forward_sdf(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
    ) -> Float[Tensor, "*N 1"]:
        batch_size = points.shape[0]
        assert (
            space_cache.shape[0] == batch_size
        ), "points and space_cache should have the same batch size in forward_sdf"
        points_unscaled = points
        points = self.rescale_points(points)

        # sample from planes
        enc_geo = self.interpolate_encodings(
            points.reshape(batch_size, -1, 3), space_cache, only_geo=True
        ).reshape(*points.shape[:-1], -1)
        sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        return sdf

    def forward_field(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        batch_size = points.shape[0]
        assert (
            space_cache.shape[0] == batch_size
        ), "points and space_cache should have the same batch size in forward_field"
        points_unscaled = points
        points = self.rescale_points(points)

        # sample from planes
        enc_geo = self.interpolate_encodings(points, space_cache, only_geo=True)
        sdf = self.sdf_network(enc_geo).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        deformation: Optional[Float[Tensor, "*N 3"]] = None
        if self.cfg.isosurface_deformable_grid:
            deformation = self.deformation_network(enc_geo).reshape(
                *points.shape[:-1], 3
            )
        return sdf, deformation

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        # TODO: is this function correct?
        return field - threshold
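    # Sketch of how forward_field and forward_level compose during isosurface
    # extraction (the `marching_cubes` call is a placeholder, not an API of
    # this project; the deformable-grid handling is an assumption):
    #
    #   sdf, deformation = geometry.forward_field(grid_points, space_cache)
    #   level = geometry.forward_level(sdf, threshold=0.0)
    #   verts, faces = marching_cubes(level, grid_points + deformation)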
    def export(
        self,
        points: Float[Tensor, "*N Di"],
        space_cache: Float[Tensor, "B 3 C//3 H W"],
        **kwargs,
    ) -> Dict[str, Any]:
        # TODO: is this function correct?
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out

        orig_shape = points.shape
        points = points.view(1, -1, 3)  # assume the batch size is 1

        points_unscaled = points
        points = self.rescale_points(points)

        # sample from planes
        _, enc_tex = self.interpolate_encodings(points, space_cache)
        features = self.feature_network(enc_tex).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features.view(
                    orig_shape[:-1] + (self.cfg.n_feature_dims,)
                )
            }
        )
        return out

    def train(self, mode=True):
        super().train(mode)
        self.space_generator.train(mode)

    def eval(self):
        super().eval()
        self.space_generator.eval()
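

if __name__ == "__main__":
    # Standalone sanity check of the `split_channels == "v1"` masking used in
    # `decode` (torch-only; the sizes are arbitrary). A sketch: this module's
    # relative imports prevent running the file directly as a script, so the
    # snippet is meant to be copied into a scratch session.
    B, C, H, W = 2, 8, 4, 4
    triplane = torch.randn(B, 6, C, H, W)
    geo_mask = torch.tensor([True] * (C // 2) + [False] * (C // 2))
    used_indices = torch.stack([geo_mask] * 3 + [~geo_mask] * 3, dim=0)
    split = triplane[:, used_indices].view(B, 6, C // 2, H, W)
    assert split.shape == (B, 6, C // 2, H, W)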