shimu0215 committed
Commit 8d4f9a4 · verified · 1 Parent(s): eb6c373

Upload 3 files

Files changed (3)
  1. README.md +13 -0
  2. app.py +326 -0
  3. requirements.txt +8 -0
README.md ADDED
@@ -0,0 +1,13 @@
---
title: Sapiens Depth
emoji: 🦀
colorFrom: red
colorTo: yellow
sdk: gradio
sdk_version: 4.42.0
app_file: app.py
pinned: false
license: cc-by-nc-4.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
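This front matter is the standard Hugging Face Spaces configuration: the Space builds against Gradio 4.42.0 and runs app.py as its entry point, under a CC-BY-NC-4.0 license. As a sketch for local testing — assuming a CUDA GPU and that the TorchScript checkpoints named in app.py have been downloaded into assets/checkpoints/ — the demo can be started with `pip install -r requirements.txt` followed by `python app.py`.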
app.py ADDED
@@ -0,0 +1,326 @@
import os
import io
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import tempfile

class Config:
    # TorchScript-exported Sapiens checkpoints, keyed by model size.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), 'assets')
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_render_people_epoch_100_torchscript.pt2",
        "0.6b": "sapiens_0.6b_render_people_epoch_70_torchscript.pt2",
        "1b": "sapiens_1b_render_people_epoch_88_torchscript.pt2",
        "2b": "sapiens_2b_render_people_epoch_25_torchscript.pt2",
    }
    SEG_CHECKPOINTS = {
        "fg-bg-1b (recommended)": "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2",
        "no-bg-removal": None,
        "part-seg-1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }

class ModelManager:
    @staticmethod
    def load_model(checkpoint_name: str):
        # Load a TorchScript checkpoint and move it to the GPU.
        if checkpoint_name is None:
            return None
        checkpoint_path = os.path.join(Config.CHECKPOINTS_DIR, checkpoint_name)
        model = torch.jit.load(checkpoint_path)
        model.eval()
        model.to("cuda")
        return model

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        # Run inference and resize the prediction back to the original image size.
        output = model(input_tensor)
        return F.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)

class ImageProcessor:
    def __init__(self):
        # Sapiens expects 1024x768 inputs; mean/std are the training
        # statistics rescaled from [0, 255] to [0, 1].
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[123.5/255, 116.5/255, 103.5/255], std=[58.5/255, 57.0/255, 57.5/255]),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, depth_model_name: str, seg_model_name: str):
        depth_model = ModelManager.load_model(Config.CHECKPOINTS[depth_model_name])
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")
        depth_output = ModelManager.run_model(depth_model, input_tensor, image.height, image.width)
        depth_map = depth_output.squeeze().cpu().numpy()

        # Optional background removal: pixels the segmentation model assigns
        # to class 0 (background) are set to NaN in the depth map.
        if seg_model_name != "no-bg-removal":
            seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
            seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
            seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
            depth_map[seg_mask == 0] = np.nan

        depth_colored = self.colorize_depth_map(depth_map)
        # mkstemp avoids the race condition of the deprecated tempfile.mktemp.
        fd, npy_path = tempfile.mkstemp(suffix=".npy")
        os.close(fd)
        np.save(npy_path, depth_map)

        return Image.fromarray(depth_colored), npy_path

    @staticmethod
    def colorize_depth_map(depth_map):
        depth_foreground = depth_map[~np.isnan(depth_map)]
        if len(depth_foreground) > 0:
            # Normalize to [0, 1] over the foreground, then invert so nearer
            # surfaces map to the warm end of the colormap.
            min_val, max_val = np.nanmin(depth_foreground), np.nanmax(depth_foreground)
            depth_normalized = (depth_map - min_val) / (max_val - min_val)
            depth_normalized = 1 - depth_normalized
            depth_normalized = np.nan_to_num(depth_normalized, nan=0)
            cmap = plt.get_cmap('inferno')
            depth_colored = (cmap(depth_normalized) * 255).astype(np.uint8)[:, :, :3]
        else:
            # No foreground pixels: return an all-black image.
            depth_colored = np.zeros((depth_map.shape[0], depth_map.shape[1], 3), dtype=np.uint8)
        return depth_colored

class GradioInterface:
    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }

            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }

            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }

            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }

            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }

            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }

            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }

            .publication-link:hover {
                background-color: #555;
            }

            .publication-link i {
                margin-right: 8px;
            }

            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }

            /* Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }

            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }

            .gr-button:hover {
                background-color: #5a5a5a;
            }

            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }

            .gr-form {
                background-color: transparent;
            }

            .gr-panel {
                border: none;
                background-color: transparent;
            }

            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens: Depth Estimation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, a family of foundation models for human-centric vision tasks,
                pretrained on 300 million human images. This demo showcases the finetuned depth model.
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        # Force the Gradio dark theme by appending ?__theme=dark on first load.
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        def process_image(image, depth_model_name, seg_model_name):
            result, npy_path = self.image_processor.process_image(image, depth_model_name, seg_model_name)
            return result, npy_path

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    with gr.Row(elem_classes="control-panel"):
                        depth_model_name = gr.Dropdown(
                            label="Depth Model Size",
                            choices=list(Config.CHECKPOINTS.keys()),
                            value="1b",
                        )
                        seg_model_name = gr.Dropdown(
                            label="Background Removal Model",
                            choices=list(Config.SEG_CHECKPOINTS.keys()),
                            value="fg-bg-1b (recommended)",
                        )
                    example_model = gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=[
                            os.path.join(Config.ASSETS_DIR, "images", img)
                            # sorted() keeps the example order deterministic across restarts
                            for img in sorted(os.listdir(os.path.join(Config.ASSETS_DIR, "images")))
                        ],
                    )
                with gr.Column():
                    result_image = gr.Image(label="Depth Estimation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Output (.npy). Note: Background depth is NaN.")
                    run_button = gr.Button("Run", elem_classes="gr-button")

            run_button.click(
                fn=process_image,
                inputs=[input_image, depth_model_name, seg_model_name],
                outputs=[result_image, npy_output],
            )

        return demo

def main():
    # TF32 matmuls are a safe speedup on Ampere (compute capability 8.x) and newer GPUs.
    if torch.cuda.is_available() and torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    interface = GradioInterface()
    demo = interface.create_interface()
    demo.launch(share=False)

if __name__ == "__main__":
    main()
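The downloadable .npy file holds the raw depth map that app.py writes with np.save; when a background-removal model is selected, background pixels are NaN. A minimal consumer sketch — the filename depth.npy is a placeholder for whatever the demo returns:

import numpy as np

# "depth.npy" is a placeholder name for the file downloaded from the demo.
depth = np.load("depth.npy")  # 2D float array, shape (H, W)

# Background pixels are NaN when background removal was enabled.
foreground = depth[~np.isnan(depth)]
print("depth map shape:", depth.shape)
if foreground.size:
    print("foreground depth range:", foreground.min(), "to", foreground.max())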
requirements.txt ADDED
@@ -0,0 +1,8 @@
gradio
numpy
torch
torchvision
matplotlib
pillow
spaces
opencv-python
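These requirements are unpinned and resolve to the latest releases at build time; on Spaces, the Gradio version itself comes from the sdk_version field in README.md rather than from this file.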