File size: 5,539 Bytes
33a65b5
 
 
 
89a1e10
72cd992
89a1e10
33a65b5
72cd992
 
ecb2ac2
29aeef1
ecb2ac2
642c115
33a65b5
72cd992
33a65b5
72cd992
33a65b5
d79ba44
 
72cd992
d79ba44
72cd992
d79ba44
72cd992
d79ba44
33a65b5
 
 
 
72cd992
89a1e10
72cd992
642c115
72cd992
33a65b5
 
72cd992
d79ba44
 
 
 
 
 
33a65b5
d79ba44
72cd992
33a65b5
 
 
 
 
 
642c115
33a65b5
642c115
33a65b5
642c115
 
72cd992
d79ba44
642c115
d79ba44
72cd992
642c115
72cd992
642c115
33a65b5
72cd992
 
642c115
 
 
 
 
 
d79ba44
642c115
 
 
d79ba44
 
72cd992
642c115
 
 
ecb2ac2
 
89a1e10
72cd992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33a65b5
d79ba44
 
72cd992
d79ba44
72cd992
 
642c115
72cd992
33a65b5
642c115
 
33a65b5
 
 
 
89a1e10
642c115
 
 
 
33a65b5
29aeef1
d71ed99
29aeef1
33a65b5
 
89a1e10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr
import torch
from PIL import Image
import numpy as np
from distillanydepth.modeling.archs.dam.dam import DepthAnything
from distillanydepth.utils.image_util import chw2hwc, colorize_depth_maps
from distillanydepth.midas.transforms import Resize, NormalizeImage, PrepareForNet
from torchvision.transforms import Compose
import cv2
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from gradio_imageslider import ImageSlider
import spaces
import tempfile

# Helper that builds a model and loads its weights from a local safetensors file.
def load_model_by_name(arch_name, checkpoint_path, device):
    """Construct the model named by ``arch_name`` and load weights into it.

    Args:
        arch_name: Architecture identifier; only ``'depthanything'`` is supported.
        checkpoint_path: Path to a ``.safetensors`` weights file on local disk.
        device: Device the model is moved to via ``Module.to``.

    Returns:
        The model with the checkpoint weights loaded, placed on ``device``.

    Raises:
        NotImplementedError: If ``arch_name`` is not ``'depthanything'``.
    """
    if arch_name != 'depthanything':
        raise NotImplementedError(f"Unknown architecture: {arch_name}")

    state_dict = load_file(checkpoint_path)
    # NOTE(review): constructed with checkpoint_path=None only, presumably using
    # DepthAnything's default config — confirm it matches the checkpoint.
    net = DepthAnything(checkpoint_path=None).to(device)
    net.load_state_dict(state_dict)
    return net.to(device)

# Image processing helpers
def _normalize_depth(depth):
    """Min-max scale ``depth`` into ``[0, 1]``.

    A constant map has zero range; dividing by it would produce NaNs that then
    corrupt the integer casts downstream, so an all-zero map is returned instead.
    """
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range == 0:
        return np.zeros_like(depth)
    return (depth - depth_min) / depth_range


def process_image(image, model, device):
    """Run depth estimation on one image and build the demo outputs.

    Args:
        image: Input PIL image (the Gradio input component uses ``type="pil"``).
        model: Depth network called as ``model(tensor)`` and returning a
            ``(disparity, _)`` pair, or ``None``.
        device: Device the input tensor is moved to.

    Returns:
        ``(image, colored_depth, gray_depth, raw_path)``: the input image
        unchanged; colorized and 3-channel grayscale depth as PIL images resized
        to the input image's size; and the path of a temporary ``.npy`` file
        holding the min-max normalized disparity as uint16 at model output
        resolution. Returns four ``None`` values when ``model`` is ``None``.
    """
    if model is None:
        return None, None, None, None

    # Reverse the channel axis (RGB -> BGR) and scale to [0, 1].
    # NOTE(review): assumes a 3-channel input and that the network expects BGR
    # even though the mean/std below are in ImageNet RGB order — confirm.
    image_np = np.array(image)[..., ::-1] / 255

    # Rescale so the longer side is 1920 px (smaller inputs are upsampled too).
    orig_h, orig_w = image_np.shape[:2]
    scale = 1920 / max(orig_h, orig_w)
    new_h, new_w = int(orig_h * scale), int(orig_w * scale)
    image_np = cv2.resize(image_np, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

    # MiDaS-style Resize takes (width, height) positionally; with
    # keep_aspect_ratio + lower_bound, swapping them would inflate non-square
    # inputs far beyond new_w x new_h. (Verify against distillanydepth.midas.transforms.)
    transform = Compose([
        Resize(new_w, new_h, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet()
    ])

    image_tensor = transform({'image': image_np})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)

    with torch.no_grad():
        pred_disp, _ = model(image_tensor)
    torch.cuda.empty_cache()  # release cached GPU memory between requests

    # Take the first batch item / first channel as a 2-D disparity map.
    pred_disp_np = pred_disp.detach().cpu().numpy()[0, 0, :, :]

    pred_disp_normalized = _normalize_depth(pred_disp_np)
    pred_disp_16bit = (pred_disp_normalized * 65535).astype(np.uint16)

    # Colorized depth map; colorize_depth_maps output is converted CHW -> HWC.
    depth_colored = colorize_depth_maps(pred_disp_normalized[None, ..., None], 0, 1, cmap="Spectral_r").squeeze()
    depth_colored_hwc = chw2hwc((depth_colored * 255).astype(np.uint8))

    # Grayscale depth map replicated to 3 channels.
    depth_gray = (pred_disp_normalized * 255).astype(np.uint8)
    depth_gray_hwc = np.stack([depth_gray] * 3, axis=-1)

    # Save the raw 16-bit map; write through the open handle rather than
    # reopening the file by name while it is still open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as temp_file:
        np.save(temp_file, pred_disp_16bit)
        depth_raw_path = temp_file.name

    # Match the returned input image's size so the slider pair lines up.
    # cv2.resize's third positional argument is `dst`, so interpolation is
    # passed by keyword.
    depth_colored_hwc = cv2.resize(depth_colored_hwc, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
    depth_gray_hwc = cv2.resize(depth_gray_hwc, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    return image, Image.fromarray(depth_colored_hwc), Image.fromarray(depth_gray_hwc), depth_raw_path

# Gradio interface function with GPU support
@spaces.GPU
def gradio_interface(image):
    """Load the Distill-Any-Depth ViT-L model and estimate depth for ``image``.

    Args:
        image: Input PIL image from the ``gr.Image(type="pil")`` component.

    Returns:
        ``((image, colored_depth), gray_depth, raw_depth_path)``, matching the
        three interface outputs: the slider image pair, the grayscale depth
        image, and the path to the raw ``.npy`` depth file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_kwargs = dict(
        vitb=dict(
            encoder='vitb',
            features=128,
            out_channels=[96, 192, 384, 768],
        ),
        vitl=dict(
            encoder="vitl",
            features=256,
            out_channels=[256, 512, 1024, 1024],
            use_bn=False,
            use_clstoken=False,
            max_depth=150.0,
            mode='disparity',
            pretrain_type='dinov2',
            del_mask_token=False
        )
    )

    # NOTE(review): the model is rebuilt and its weights reloaded on every
    # request. Caching it across @spaces.GPU calls would be much faster, but
    # needs care about where CUDA tensors may live on ZeroGPU — TODO.
    model = DepthAnything(**model_kwargs['vitl']).to(device)
    checkpoint_path = hf_hub_download(repo_id="xingyang1/Distill-Any-Depth", filename="large/model.safetensors", repo_type="model")
    model.load_state_dict(load_file(checkpoint_path))
    # nn.Module starts in training mode; switch to eval so any dropout or
    # normalization layers behave deterministically during inference.
    model.eval()

    image, depth_image, depth_gray, depth_raw = process_image(image, model, device)
    return (image, depth_image), depth_gray, depth_raw

# Single image input (no mode selection).
image_input = gr.Image(type="pil")

# Outputs: a before/after slider with the colorized depth, the grayscale
# depth image, and a downloadable raw-depth NumPy file.
depth_outputs = [
    ImageSlider(label="Depth slider", type="pil", slider_color="pink"),
    gr.Image(type="pil", label="Gray Depth"),
    gr.File(label="Raw Depth (NumPy File)"),
]

# Wire gradio_interface into a Gradio app; the example paths are relative,
# and cache_examples=True asks Gradio to cache their outputs.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=image_input,
    outputs=depth_outputs,
    title="Depth Estimation Demo",
    description="Upload an image to see the depth estimation results. Our model is running on GPU for faster processing.",
    examples=["1.jpg", "2.jpg", "4.png", "5.jpg", "6.jpg"],
    cache_examples=True,
)

# Start serving the app.
iface.launch()