# LiDAR-Diffusion / app.py
# Author: Hancy. Last commit 1615664: "modify on ZeroGPU" (file size 2.84 kB).
import gradio as gr
import spaces
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from app_config import CSS, HEADER, FOOTER
from sample_cond import CKPT_PATH, MODEL_CFG, load_model_from_config, sample
# Inference device: use the GPU whenever PyTorch reports CUDA as available, else the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
def load_model():
    """Build the conditional LiDAR model from the checkpoint at ``CKPT_PATH``.

    The checkpoint is loaded onto the CPU. Its ``"state_dict"`` entry is passed,
    together with ``MODEL_CFG.model``, to ``load_model_from_config``. Moving the
    model to the inference device is left to the caller.

    Returns:
        Whatever ``load_model_from_config`` builds (the model instance).
    """
    # NOTE: torch.load unpickles the file, so CKPT_PATH must point at a trusted checkpoint.
    # NOTE(review): torch>=2.6 defaults to weights_only=True; confirm this checkpoint
    # (presumably a PyTorch Lightning dump, given the "state_dict" key) still loads.
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")
    weights = checkpoint["state_dict"]
    return load_model_from_config(MODEL_CFG.model, weights)
def create_custom_colormap():
    """Return the 256-entry colormap used to paint range maps.

    Colors are interpolated linearly between five RGB anchors: green at 0,
    cyan at 0.38, blue at 0.6, magenta at 0.7 and yellow at 1.
    """
    anchors = [
        (0, (0, 1, 0)),     # green
        (0.38, (0, 1, 1)),  # cyan
        (0.6, (0, 0, 1)),   # blue
        (0.7, (1, 0, 1)),   # magenta
        (1, (1, 1, 0)),     # yellow
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', anchors, N=256)
def colorize_depth(depth, log_scale):
    """Paint a single-channel depth/range image with the custom colormap.

    Args:
        depth: 2-D array (the ``[:, :, :3]`` slice below relies on this). When
            ``log_scale`` is true it is assumed to hold values in [0, 255]
            (TODO confirm the range ``sample`` produces).
        log_scale: if true, ``depth`` is first remapped to
            ``uint8(log2(depth / 255 * 56 + 1) / 5.84 * 255)``. Because log2 is
            concave, small values are stretched over more of the colormap.

    Returns:
        Float array holding the first three (RGB) channels of the colormap output.
        Pixels whose (possibly remapped) value is exactly 0 are set to black.
    """
    if log_scale:
        compressed = np.log2((depth / 255.) * 56. + 1) / 5.84
        depth = (compressed * 255.).astype(np.uint8)
    # NOTE(review): with log_scale=False a float input is treated by matplotlib as a
    # [0, 1] fraction rather than a 0-255 index; confirm callers pass uint8 in that case.
    is_empty = depth == 0
    cmap = create_custom_colormap()
    colored = cmap(depth)[:, :, :3]
    colored[is_empty] = 0.
    return colored
@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Sample one LiDAR scene from ``model`` conditioned on ``cond``.

    Runs with autograd disabled. ``@spaces.GPU`` requests a GPU for the duration
    of the call on Hugging Face ZeroGPU Spaces.

    Returns:
        The two values produced by ``sample``, in order. They are presumably the
        range image and the point cloud (TODO confirm against ``sample_cond.sample``).
    """
    range_map, point_cloud = sample(model, cond)
    return range_map, point_cloud
def load_camera(image, *, num_views=4, device=None):
    """Turn a horizontally tiled multi-view camera image into a conditioning tensor.

    The image is converted to float32 and scaled from [0, 255] to [0, 1]. It is
    then moved to channel-first layout and cut along its width into ``num_views``
    equal-width views, which are stacked into one batch.

    Args:
        image: array-like of shape ``(H, W, C)``. Presumably the uint8 RGB array that
            ``gr.Image(type='numpy')`` delivers (TODO confirm the channel count the
            model expects).
        num_views: number of side-by-side views tiled across the image width.
            Defaults to 4, the value this function previously hard-coded.
        device: torch device for the result. ``None`` means the module-level ``DEVICE``.

    Returns:
        float32 tensor of shape ``(1, num_views, C, H, W // num_views)``.

    Raises:
        ValueError: if ``image`` is ``None``, is not 3-dimensional, or its width is
            not a positive multiple of ``num_views``.
    """
    # Fail clearly when nothing was uploaded; otherwise numpy raises an opaque TypeError.
    if image is None:
        raise ValueError("load_camera: no input image was provided")
    camera = np.asarray(image, dtype=np.float32) / 255.
    if camera.ndim != 3:
        raise ValueError(f"load_camera: expected an (H, W, C) image, got shape {camera.shape}")
    width = camera.shape[1]
    if num_views < 1 or width % num_views:
        raise ValueError(
            f"load_camera: image width {width} is not divisible by num_views={num_views}"
        )
    camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    views = np.split(camera, num_views, axis=2)  # equal-width slices along W, one per view
    batch = np.stack(views, axis=0)  # (num_views, C, H, W // num_views)
    target = DEVICE if device is None else device
    return torch.from_numpy(batch).unsqueeze(0).to(target)
# Build the model once at import time, then move it to the inference device.
model = load_model()
model = model.to(DEVICE)
with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)

    # Layout: input camera image beside a column holding the two generated outputs.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output', interactive=False)
    submit = gr.Button("Generate")

    def on_submit(image):
        """Generate a colorized range map and a point cloud from one camera image.

        Args:
            image: numpy array delivered by ``input_image``. It is ``None`` when
                nothing has been uploaded.

        Returns:
            ``[rgb_img, pcd]`` for ``output_image`` and ``output_pcd``.

        Raises:
            gr.Error: if no image was provided. The UI then shows a readable message
                instead of a conversion traceback.
        """
        if image is None:
            raise gr.Error("Please upload an input image before clicking Generate.")
        cond = load_camera(image)
        img, pcd = generate_lidar(model, cond)
        rgb_img = colorize_depth(img, log_scale=True)
        return [rgb_img, pcd]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd])

    # Skip hidden entries (e.g. .DS_Store, .gitkeep) so they are not offered as example images.
    example_files = [
        os.path.join('cam_examples', filename)
        for filename in sorted(os.listdir('cam_examples'))
        if not filename.startswith('.')
    ]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, output_pcd],
                           fn=on_submit, cache_examples=False)
    gr.Markdown(FOOTER)
if __name__ == '__main__':
    # Enable Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()