import os
import imageio
import numpy as np
import torch
from tqdm import tqdm
from pytorch3d.renderer import (
    PerspectiveCameras,
    TexturesVertex,
    PointLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
)
from pytorch3d.renderer.mesh.shader import ShaderBase
from pytorch3d.structures import Meshes


class NormalShader(ShaderBase):
    """Shader that outputs interpolated surface normals instead of shaded colors."""

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device=device, **kwargs)

    def forward(self, fragments, meshes, **kwargs):
        blend_params = kwargs.get("blend_params", self.blend_params)
        texels = fragments.bary_coords.clone()
        texels = texels.permute(0, 3, 1, 2, 4)
        texels = texels * 2 - 1  # map bary_coords to [-1, 1]

        # interpolate per-vertex normals into per-pixel normals
        verts_normals = meshes.verts_normals_packed()
        faces_normals = verts_normals[meshes.faces_packed()]
        bary_coords = fragments.bary_coords
        pixel_normals = (bary_coords[..., None] * faces_normals[fragments.pix_to_face]).sum(dim=-2)
        pixel_normals = pixel_normals / pixel_normals.norm(dim=-1, keepdim=True)

        # map the normals into color space
        # colors = (pixel_normals + 1) / 2  # map normals to [0, 1]
        colors = torch.clamp(pixel_normals, -1, 1)

        # pix_to_face is -1 for background pixels, so valid pixels are >= 0
        mask = (fragments.pix_to_face >= 0).float()
        colors = torch.cat([colors, mask.unsqueeze(-1)], dim=-1)
        # colors[fragments.pix_to_face < 0] = 0

        # blend the colors
        # images = self.blend(texels, colors, fragments, blend_params)
        return colors


def overlay_image_onto_background(image, mask, bbox, background):
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()

    out_image = background.copy()
    bbox = bbox[0].int().cpu().numpy().copy()
    roi_image = out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]]

    # skip crops that are empty in either dimension
    if roi_image.shape[0] < 1 or roi_image.shape[1] < 1:
        return out_image

    roi_image[mask] = image[mask]
    out_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] = roi_image
    return out_image
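

# Hedged example of the overlay above; all arrays are made-up placeholders,
# not values from this app:
#
#     background = np.zeros((480, 640, 3), dtype=np.uint8)
#     image = np.full((100, 100, 3), 255.)            # rendered crop
#     mask = np.ones((100, 100), dtype=bool)          # foreground pixels
#     bbox = torch.tensor([[50., 50., 150., 150.]])   # (left, upper, right, lower)
#     out = overlay_image_onto_background(image, mask, bbox, background)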


def update_intrinsics_from_bbox(K_org, bbox):
    '''
    Update intrinsics for cropped images.
    '''
    device, dtype = K_org.device, K_org.dtype

    K = torch.zeros((K_org.shape[0], 4, 4)).to(device=device, dtype=dtype)
    K[:, :3, :3] = K_org.clone()
    K[:, 2, 2] = 0
    K[:, 2, -1] = 1
    K[:, -1, 2] = 1

    image_sizes = []
    for idx, box in enumerate(bbox):
        left, upper, right, lower = box
        cx, cy = K[idx, 0, 2], K[idx, 1, 2]

        new_cx = cx - left
        new_cy = cy - upper
        new_height = max(lower - upper, 1)
        new_width = max(right - left, 1)
        # the render is flipped in both axes downstream, so mirror the
        # principal point within the crop as well
        new_cx = new_width - new_cx
        new_cy = new_height - new_cy

        K[idx, 0, 2] = new_cx
        K[idx, 1, 2] = new_cy
        image_sizes.append((int(new_height), int(new_width)))

    return K, image_sizes
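

# Hedged example of the intrinsics update above; the 3x3 K and the bbox are
# made-up values for illustration:
#
#     K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])
#     bbox = torch.tensor([[100., 50., 400., 350.]])  # (left, upper, right, lower)
#     K_full, image_sizes = update_intrinsics_from_bbox(K, bbox)
#     # image_sizes == [(300, 300)]; K_full is the 4x4 form PyTorch3D expects
#     # for PerspectiveCameras with in_ndc=False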


def perspective_projection(x3d, K, R=None, T=None):
    if R is not None:
        x3d = torch.matmul(R, x3d.transpose(1, 2)).transpose(1, 2)
    if T is not None:
        x3d = x3d + T.transpose(1, 2)

    x2d = torch.div(x3d, x3d[..., 2:])
    x2d = torch.matmul(K, x2d.transpose(-1, -2)).transpose(-1, -2)[..., :2]
    return x2d
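

# Hedged example of the pinhole projection above: a point on the optical axis
# lands on the principal point. K and x3d are made-up values:
#
#     K = torch.tensor([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])
#     x3d = torch.tensor([[[0., 0., 2.]]])    # (batch, num_points, 3)
#     x2d = perspective_projection(x3d, K)    # -> tensor([[[320., 240.]]])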


def compute_bbox_from_points(X, img_w, img_h, scaleFactor=1.2):
    left = torch.clamp(X.min(1)[0][:, 0], min=0, max=img_w)
    right = torch.clamp(X.max(1)[0][:, 0], min=0, max=img_w)
    top = torch.clamp(X.min(1)[0][:, 1], min=0, max=img_h)
    bottom = torch.clamp(X.max(1)[0][:, 1], min=0, max=img_h)

    cx = (left + right) / 2
    cy = (top + bottom) / 2
    width = (right - left)
    height = (bottom - top)

    new_left = torch.clamp(cx - width / 2 * scaleFactor, min=0, max=img_w - 1)
    new_right = torch.clamp(cx + width / 2 * scaleFactor, min=1, max=img_w)
    new_top = torch.clamp(cy - height / 2 * scaleFactor, min=0, max=img_h - 1)
    new_bottom = torch.clamp(cy + height / 2 * scaleFactor, min=1, max=img_h)

    bbox = torch.stack((new_left.detach(), new_top.detach(),
                        new_right.detach(), new_bottom.detach())).int().float().T
    return bbox
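

# Hedged example of the bbox padding above; the points are made-up. Two points
# spanning (100, 50)-(300, 250) in a 640x480 image, padded by scaleFactor=1.2:
#
#     X = torch.tensor([[[100., 50.], [300., 250.]]])
#     bbox = compute_bbox_from_points(X, 640, 480, scaleFactor=1.2)
#     # -> tensor([[ 80.,  30., 320., 270.]])  # (left, top, right, bottom)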


class Renderer():
    def __init__(self, width, height, K, device, faces=None):
        self.width = width
        self.height = height
        self.K = K
        self.device = device
        if faces is not None:
            self.faces = torch.from_numpy(
                faces.astype('int')
            ).unsqueeze(0).to(self.device)

        self.initialize_camera_params()
        self.lights = PointLights(device=device, location=[[0.0, 0.0, -10.0]])
        self.create_renderer()

    def create_camera(self, R=None, T=None):
        if R is not None:
            self.R = R.clone().view(1, 3, 3).to(self.device)
        if T is not None:
            self.T = T.clone().view(1, 3).to(self.device)

        return PerspectiveCameras(
            device=self.device,
            R=self.R.mT,
            T=self.T,
            K=self.K_full,
            image_size=self.image_sizes,
            in_ndc=False)

    def create_renderer(self):
        self.renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                raster_settings=RasterizationSettings(
                    image_size=self.image_sizes[0],
                    blur_radius=1e-5),
            ),
            shader=SoftPhongShader(
                device=self.device,
                lights=self.lights,
            )
        )

    def create_normal_renderer(self):
        normal_renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=self.cameras,
                raster_settings=RasterizationSettings(
                    image_size=self.image_sizes[0],
                ),
            ),
            shader=NormalShader(device=self.device),
        )
        return normal_renderer

    def initialize_camera_params(self):
        """Hard-coded default camera parameters.
        TODO: make these configurable."""
        # Extrinsics: identity rotation, zero translation
        self.R = torch.diag(
            torch.tensor([1, 1, 1])
        ).float().to(self.device).unsqueeze(0)

        self.T = torch.tensor(
            [0, 0, 0]
        ).unsqueeze(0).float().to(self.device)

        # Intrinsics
        self.K = self.K.unsqueeze(0).float().to(self.device)
        self.bboxes = torch.tensor([[0, 0, self.width, self.height]]).float()
        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, self.bboxes)
        self.cameras = self.create_camera()

    def render_normal(self, vertices):
        vertices = vertices.unsqueeze(0)
        mesh = Meshes(verts=vertices, faces=self.faces)
        normal_renderer = self.create_normal_renderer()
        results = normal_renderer(mesh)
        results = torch.flip(results, [1, 2])
        return results

    def render_mesh(self, vertices, background, colors=[0.8, 0.8, 0.8]):
        # crop the camera to a padded bbox around (a subsample of) the vertices
        self.update_bbox(vertices[::50], scale=1.2)
        vertices = vertices.unsqueeze(0)

        if colors[0] > 1:
            colors = [c / 255. for c in colors]
        verts_features = torch.tensor(colors).reshape(1, 1, 3).to(device=vertices.device, dtype=vertices.dtype)
        verts_features = verts_features.repeat(1, vertices.shape[1], 1)
        textures = TexturesVertex(verts_features=verts_features)

        mesh = Meshes(verts=vertices,
                      faces=self.faces,
                      textures=textures)

        materials = Materials(
            device=self.device,
            specular_color=(colors, ),
            shininess=0
        )

        results = torch.flip(
            self.renderer(mesh, materials=materials, cameras=self.cameras, lights=self.lights),
            [1, 2]
        )
        image = results[0, ..., :3] * 255
        mask = results[0, ..., -1] > 1e-3

        image = overlay_image_onto_background(image, mask, self.bboxes, background.copy())
        self.reset_bbox()
        return image

    def update_bbox(self, x3d, scale=2.0, mask=None):
        """Update the camera crop (bbox) from the given 3D points.

        x3d: input 3D keypoints (or vertices), (num_frames, num_points, 3)
        """
        if x3d.size(-1) != 3:
            x2d = x3d.unsqueeze(0)
        else:
            x2d = perspective_projection(x3d.unsqueeze(0), self.K, self.R, self.T.reshape(1, 3, 1))

        if mask is not None:
            x2d = x2d[:, ~mask]

        bbox = compute_bbox_from_points(x2d, self.width, self.height, scale)
        self.bboxes = bbox

        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
        self.cameras = self.create_camera()
        self.create_renderer()

    def reset_bbox(self):
        # restore the full-frame camera after a cropped render
        bbox = torch.zeros((1, 4)).float().to(self.device)
        bbox[0, 2] = self.width
        bbox[0, 3] = self.height
        self.bboxes = bbox

        self.K_full, self.image_sizes = update_intrinsics_from_bbox(self.K, bbox)
        self.cameras = self.create_camera()
        self.create_renderer()
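

# Hedged usage sketch for Renderer; every value below is a placeholder
# (faces_np, verts, and background are assumed inputs, not part of this file):
#
#     K = torch.tensor([[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]])
#     renderer = Renderer(640, 480, K, 'cuda', faces=faces_np)  # faces_np: (F, 3) int ndarray
#     frame = renderer.render_mesh(verts, background)  # verts: (V, 3) tensor on 'cuda',
#                                                      # background: (480, 640, 3) uint8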


class RendererUtil():
    def __init__(self, K, w, h, device, faces, keep_origin=True):
        self.keep_origin = keep_origin
        self.default_R = torch.eye(3)
        self.default_T = torch.zeros(3)
        self.device = device
        self.renderer = Renderer(w, h, K, device, faces)

    def set_extrinsic(self, R, T):
        self.default_R = R
        self.default_T = T

    def render_normal(self, verts_list):
        # normal rendering currently supports a single mesh only
        if len(verts_list) != 1:
            return None
        self.renderer.create_camera(self.default_R, self.default_T)
        normal_map = self.renderer.render_normal(verts_list[0])
        return normal_map[0, :, :, 0]

    def render_frame(self, humans, pred_rend_array, verts_list=None, color_list=None):
        if not isinstance(pred_rend_array, np.ndarray):
            pred_rend_array = np.asarray(pred_rend_array)
        self.renderer.create_camera(self.default_R, self.default_T)
        _img = pred_rend_array
        if humans is not None:
            for human in humans:
                _img = self.renderer.render_mesh(human['v3d'].to(self.device), _img)
        else:
            for i, verts in enumerate(verts_list):
                if color_list is None:
                    _img = self.renderer.render_mesh(verts.to(self.device), _img)
                else:
                    _img = self.renderer.render_mesh(verts.to(self.device), _img, color_list[i])
        if self.keep_origin:
            # place the original frame and the rendered frame side by side
            _img = np.concatenate([np.asarray(pred_rend_array), _img], 1).astype(np.uint8)
        return _img

    def render_video(self, results, pil_bis_frames, fps, out_path):
        writer = imageio.get_writer(
            out_path,
            fps=fps, mode='I', format='FFMPEG', macro_block_size=1
        )
        for i, humans in enumerate(tqdm(results)):
            pred_rend_array = pil_bis_frames[i]
            _img = self.render_frame(humans, pred_rend_array)
            try:
                writer.append_data(_img)
            except Exception:
                print('Error in writing video')
                print(type(_img))
        writer.close()


def render_frame(renderer, humans, pred_rend_array, default_R, default_T, device, keep_origin=True):
    if not isinstance(pred_rend_array, np.ndarray):
        pred_rend_array = np.asarray(pred_rend_array)
    renderer.create_camera(default_R, default_T)
    _img = pred_rend_array
    if humans is None:
        humans = []
    if isinstance(humans, dict):
        humans = [humans]
    for human in humans:
        if isinstance(human, dict):
            v3d = human['v3d'].to(device)
        else:
            v3d = human
        _img = renderer.render_mesh(v3d, _img)
    if keep_origin:
        _img = np.concatenate([np.asarray(pred_rend_array), _img], 1).astype(np.uint8)
    return _img


def render_video(results, faces, K, pil_bis_frames, fps, out_path, device, keep_origin=True):
    # results: [F, N, ...]
    if isinstance(pil_bis_frames[0], np.ndarray):
        height, width, _ = pil_bis_frames[0].shape
    else:
        # PIL.Image.size is (width, height)
        width, height = pil_bis_frames[0].size
    renderer = Renderer(width, height, K[0], device, faces)
    # build the default camera
    default_R, default_T = torch.eye(3), torch.zeros(3)

    writer = imageio.get_writer(
        out_path,
        fps=fps, mode='I', format='FFMPEG', macro_block_size=1
    )
    for i, humans in enumerate(tqdm(results)):
        pred_rend_array = pil_bis_frames[i]
        _img = render_frame(renderer, humans, pred_rend_array, default_R, default_T, device, keep_origin)
        try:
            writer.append_data(_img)
        except Exception:
            print('Error in writing video')
            print(type(_img))
    writer.close()
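

# Hedged smoke test: exercises the full pipeline on one blank frame with a
# single made-up triangle. Every value here (intrinsics, vertices, output
# path) is a placeholder for illustration, not part of the original app.
if __name__ == "__main__":
    K = torch.tensor([[500., 0., 320.],
                      [0., 500., 240.],
                      [0., 0., 1.]])
    faces = np.array([[0, 1, 2]])                       # one triangle
    verts = torch.tensor([[0.0, 0.0, 2.0],
                          [0.2, 0.0, 2.0],
                          [0.0, 0.2, 2.0]])             # 2 m in front of the camera
    frames = [np.zeros((480, 640, 3), dtype=np.uint8)]  # one blank frame
    results = [[{'v3d': verts}]]                        # one frame, one mesh
    render_video(results, faces, [K], frames, fps=1,
                 out_path='demo.mp4', device='cpu', keep_origin=True)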