triposr-s3 / tsr /bake_texture.py
ashh757's picture
Upload 18 files
8c93108 verified
raw
history blame contribute delete
5.37 kB
import numpy as np
import torch
import xatlas
import trimesh
import moderngl
from PIL import Image
def make_atlas(mesh, texture_resolution, texture_padding):
    """Unwrap ``mesh`` into a UV atlas using xatlas.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` (both are handed
            directly to ``xatlas.Atlas.add_mesh``).
        texture_resolution: Value assigned to the packer's ``resolution``.
        texture_padding: Value assigned to the packer's ``padding``.

    Returns:
        A dict with the keys ``"vmapping"``, ``"indices"`` and ``"uvs"``,
        holding the three values xatlas reports for the first (and only)
        mesh added to the atlas.
    """
    # Configure the packer first; it is independent of the atlas object.
    pack_options = xatlas.PackOptions()
    pack_options.resolution = texture_resolution
    pack_options.padding = texture_padding
    pack_options.bilinear = True

    unwrapper = xatlas.Atlas()
    unwrapper.add_mesh(mesh.vertices, mesh.faces)
    unwrapper.generate(pack_options=pack_options)

    # Only one mesh was added, so its result lives at index 0.
    vertex_map, triangle_indices, uv_coords = unwrapper[0]
    return dict(vmapping=vertex_map, indices=triangle_indices, uvs=uv_coords)
def rasterize_position_atlas(
    mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
):
    """Render the mesh's 3D positions into its UV atlas with OpenGL.

    Every atlas vertex is placed at its UV coordinate (mapped to clip space
    via ``uv * 2 - 1``) and carries the 3D position of the mesh vertex it
    came from, so the rasterizer interpolates positions across each chart.

    Two draws are issued into the same float framebuffer:

    1. A geometry-shader pass that turns each triangle edge into a strip
       extruded ``texture_padding`` texels to either side, which dilates the
       charts beyond their exact borders.
    2. A plain triangle pass, issued second so that the exact chart
       interiors are drawn after the dilated borders.

    Args:
        mesh: Object exposing ``vertices``; indexed with ``atlas_vmapping``.
        atlas_vmapping: Index array selecting, for each atlas vertex, the
            row of ``mesh.vertices`` to use.
        atlas_indices: Triangle index array for the atlas vertices.
        atlas_uvs: Per-atlas-vertex UV coordinates (two floats per vertex).
        texture_resolution: Width and height of the square render target.
        texture_padding: Dilation radius, in texels, for the edge pass.

    Returns:
        A read-only ``float32`` array of shape
        ``(texture_resolution, texture_resolution, 4)``. The first three
        channels are the interpolated position; the last channel is ``1.0``
        where something was drawn and ``0.0`` (the clear value) elsewhere.
    """
    ctx = moderngl.create_context(standalone=True)
    # The standalone context (and every GL object created from it) is only
    # needed until the framebuffer has been read back; release it afterwards
    # so repeated calls do not accumulate GL contexts.
    try:
        basic_prog = ctx.program(
            vertex_shader="""
#version 330
in vec2 in_uv;
in vec3 in_pos;
out vec3 v_pos;
void main() {
v_pos = in_pos;
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
}
""",
            fragment_shader="""
#version 330
in vec3 v_pos;
out vec4 o_col;
void main() {
o_col = vec4(v_pos, 1.0);
}
""",
        )
        gs_prog = ctx.program(
            vertex_shader="""
#version 330
in vec2 in_uv;
in vec3 in_pos;
out vec3 vg_pos;
void main() {
vg_pos = in_pos;
gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
}
""",
            geometry_shader="""
#version 330
uniform float u_resolution;
uniform float u_dilation;
layout (triangles) in;
layout (triangle_strip, max_vertices = 12) out;
in vec3 vg_pos[];
out vec3 vf_pos;
void lineSegment(int aidx, int bidx) {
vec2 a = gl_in[aidx].gl_Position.xy;
vec2 b = gl_in[bidx].gl_Position.xy;
vec3 aCol = vg_pos[aidx];
vec3 bCol = vg_pos[bidx];
vec2 dir = normalize((b - a) * u_resolution);
vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
gl_Position = vec4(a + offset, 0.0, 1.0);
vf_pos = aCol;
EmitVertex();
gl_Position = vec4(a - offset, 0.0, 1.0);
vf_pos = aCol;
EmitVertex();
gl_Position = vec4(b + offset, 0.0, 1.0);
vf_pos = bCol;
EmitVertex();
gl_Position = vec4(b - offset, 0.0, 1.0);
vf_pos = bCol;
EmitVertex();
}
void main() {
lineSegment(0, 1);
lineSegment(1, 2);
lineSegment(2, 0);
EndPrimitive();
}
""",
            fragment_shader="""
#version 330
in vec3 vf_pos;
out vec4 o_col;
void main() {
o_col = vec4(vf_pos, 1.0);
}
""",
        )

        # Flatten to tightly packed 32-bit buffers matching the "2f"/"3f"
        # attribute layouts and the integer index buffer.
        uvs = atlas_uvs.flatten().astype("f4")
        pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
        indices = atlas_indices.flatten().astype("i4")
        vbo_uvs = ctx.buffer(uvs)
        vbo_pos = ctx.buffer(pos)
        ibo = ctx.buffer(indices)
        vao_content = [
            vbo_uvs.bind("in_uv", layout="2f"),
            vbo_pos.bind("in_pos", layout="3f"),
        ]
        # Both programs consume the same vertex/index data.
        basic_vao = ctx.vertex_array(basic_prog, vao_content, ibo)
        gs_vao = ctx.vertex_array(gs_prog, vao_content, ibo)

        fbo = ctx.framebuffer(
            color_attachments=[
                ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
            ]
        )
        fbo.use()
        # Alpha 0 marks texels that no primitive touches.
        fbo.clear(0.0, 0.0, 0.0, 0.0)
        gs_prog["u_resolution"].value = texture_resolution
        gs_prog["u_dilation"].value = texture_padding
        # Dilated edges first, exact triangles second.
        gs_vao.render()
        basic_vao.render()

        # read() copies the texels into a Python bytes object, so the array
        # built on top of it stays valid after the context is released.
        fbo_bytes = fbo.color_attachments[0].read()
        fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
            texture_resolution, texture_resolution, 4
        )
        return fbo_np
    finally:
        ctx.release()
def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
    """Look up a color for every texel of a rasterized position atlas.

    Args:
        model: Object exposing ``renderer.query_triplane`` and ``decoder``.
        scene_code: Scene representation forwarded unchanged to
            ``query_triplane``. If it has a ``device`` attribute (e.g. it is
            a ``torch.Tensor``), the query positions are moved to that
            device before the call.
        positions_texture: NumPy array that reshapes to ``(-1, 4)``; the
            first three channels of each texel are the query position and
            the last channel is its coverage (``0.0`` means "empty").
        texture_resolution: Side length of the square output texture.

    Returns:
        NumPy array of shape ``(texture_resolution, texture_resolution, 4)``
        holding the queried RGB plus the coverage channel as alpha. Texels
        whose alpha is ``0.0`` are zeroed in all four channels.
    """
    texels = positions_texture.reshape(-1, 4)
    coverage = texels[:, -1]

    # torch.tensor copies, so a read-only input array is fine here.
    positions = torch.tensor(texels[:, :-1])
    # The positions are created on the CPU; follow the scene code's device
    # so a GPU-resident model does not hit a device-mismatch error.
    target_device = getattr(scene_code, "device", None)
    if target_device is not None:
        positions = positions.to(target_device)

    with torch.no_grad():
        queried_grid = model.renderer.query_triplane(
            model.decoder,
            positions,
            scene_code,
        )

    # detach() + cpu() make the NumPy conversion valid for tensors that live
    # on an accelerator or that still carry autograd state.
    rgb_f = queried_grid["color"].detach().cpu().numpy().reshape(-1, 3)
    # Append the coverage channel as alpha (column index 3).
    rgba_f = np.insert(rgb_f, 3, coverage, axis=1)
    # Empty texels were queried at a meaningless position; blank them out.
    rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
    return rgba_f.reshape(texture_resolution, texture_resolution, 4)
def bake_texture(mesh, model, scene_code, texture_resolution):
    """Bake a color texture for ``mesh`` by sampling ``model`` in UV space.

    The mesh is unwrapped into an atlas, its surface positions are
    rasterized into that atlas, and the resulting position map is converted
    to colors.

    Args:
        mesh: Mesh passed to the atlas and rasterization steps.
        model: Model passed to the color lookup step.
        scene_code: Scene representation passed to the color lookup step.
        texture_resolution: Side length of the square texture.

    Returns:
        A dict with the atlas data under ``"vmapping"``, ``"indices"`` and
        ``"uvs"``, plus the baked texture under ``"colors"``.
    """
    # Padding grows with resolution (resolution / 256) but is never below 2.
    padding = round(max(2, texture_resolution / 256))

    unwrap = make_atlas(mesh, texture_resolution, padding)
    vmapping = unwrap["vmapping"]
    indices = unwrap["indices"]
    uvs = unwrap["uvs"]

    position_map = rasterize_position_atlas(
        mesh, vmapping, indices, uvs, texture_resolution, padding
    )
    baked_colors = positions_to_colors(
        model, scene_code, position_map, texture_resolution
    )

    return {
        "vmapping": vmapping,
        "indices": indices,
        "uvs": uvs,
        "colors": baked_colors,
    }