Chaerin5 committed
Commit fd8c448 · 1 Parent(s): e86bac5

instruction renovation; allow manual keypoints in the Edit Hand Poses tab
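The manual-keypoint feature this message refers to is implemented in the deleted file's `get_kps` / `undo_kps` / `reset_kps` handlers: each click on a Gradio image arrives as a `gr.SelectData` event whose `index` field carries the clicked (x, y) pixel, and the clicks are accumulated in a `gr.State`. The sketch below is a minimal, self-contained illustration of that pattern, not the app's actual code; the component names and the blank placeholder canvas are assumptions made for the example.

```python
import gradio as gr
import numpy as np

MAX_KPTS = 21  # an OpenPose-style hand has 21 keypoints


def add_keypoint(img, keypoints, evt: gr.SelectData):
    """Append the clicked pixel to the keypoint list and draw all clicks on the image."""
    if keypoints is None:
        keypoints = []
    if len(keypoints) < MAX_KPTS:
        keypoints.append(list(evt.index))  # evt.index is the (x, y) pixel of the click
    vis = img.copy()
    for x, y in keypoints:
        vis[max(y - 2, 0) : y + 3, max(x - 2, 0) : x + 3] = [255, 0, 0]  # small red dot
    return vis, keypoints, f"{len(keypoints)} / {MAX_KPTS} keypoints selected"


with gr.Blocks() as demo:
    base = gr.State(value=np.zeros((256, 256, 3), dtype=np.uint8))  # placeholder canvas
    kpts = gr.State(value=None)
    canvas = gr.Image(
        value=np.zeros((256, 256, 3), dtype=np.uint8),
        type="numpy",
        interactive=False,
        label="Click to add keypoints",
    )
    info = gr.Markdown("0 / 21 keypoints selected")
    canvas.select(add_keypoint, [base, kpts], [canvas, kpts, info])

if __name__ == "__main__":
    demo.launch()
```

Keeping the raw clicks in a plain list inside `gr.State`, as the app does, keeps Undo and Reset trivial: they simply pop or clear the list and redraw the overlay.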

Files changed (1)
  1. app_regular_gpu.py +0 -2003
app_regular_gpu.py DELETED
@@ -1,2003 +0,0 @@
1
- import os
2
- import torch
3
- from dataclasses import dataclass
4
- import gradio as gr
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
- import cv2
8
- import mediapipe as mp
9
- from torchvision.transforms import Compose, Resize, ToTensor, Normalize
10
- import vqvae
11
- import vit
12
- from typing import Literal
13
- from diffusion import create_diffusion
14
- from utils import scale_keypoint, keypoint_heatmap, check_keypoints_validity
15
- from segment_hoi import init_sam
16
- from io import BytesIO
17
- from PIL import Image
18
- import random
19
- from copy import deepcopy
20
- from typing import Optional
21
- import requests
22
- from huggingface_hub import hf_hub_download
23
- # import spaces
24
-
25
- MAX_N = 6
26
- FIX_MAX_N = 6
27
-
28
- placeholder = cv2.cvtColor(cv2.imread("placeholder.png"), cv2.COLOR_BGR2RGB)
29
- NEW_MODEL = True
30
- MODEL_EPOCH = 6
31
- REF_POSE_MASK = True
32
-
33
- def set_seed(seed):
34
- seed = int(seed)
35
- torch.manual_seed(seed)
36
- np.random.seed(seed)
37
- torch.cuda.manual_seed_all(seed)
38
- random.seed(seed)
39
-
40
- # if torch.cuda.is_available():
41
- device = "cuda"
42
- # else:
43
- # device = "cpu"
44
-
45
- def remove_prefix(text, prefix):
46
- if text.startswith(prefix):
47
- return text[len(prefix) :]
48
- return text
49
-
50
-
51
- def unnormalize(x):
52
- return (((x + 1) / 2) * 255).astype(np.uint8)
53
-
54
-
55
- def visualize_hand(all_joints, img, side=["right", "left"], n_avail_joints=21):
56
- # Define the connections between joints for drawing lines and their corresponding colors
57
- connections = [
58
- ((0, 1), "red"),
59
- ((1, 2), "green"),
60
- ((2, 3), "blue"),
61
- ((3, 4), "purple"),
62
- ((0, 5), "orange"),
63
- ((5, 6), "pink"),
64
- ((6, 7), "brown"),
65
- ((7, 8), "cyan"),
66
- ((0, 9), "yellow"),
67
- ((9, 10), "magenta"),
68
- ((10, 11), "lime"),
69
- ((11, 12), "indigo"),
70
- ((0, 13), "olive"),
71
- ((13, 14), "teal"),
72
- ((14, 15), "navy"),
73
- ((15, 16), "gray"),
74
- ((0, 17), "lavender"),
75
- ((17, 18), "silver"),
76
- ((18, 19), "maroon"),
77
- ((19, 20), "fuchsia"),
78
- ]
79
- H, W, C = img.shape
80
-
81
- # Create a figure and axis
82
- plt.figure()
83
- ax = plt.gca()
84
- # Plot joints as points
85
- ax.imshow(img)
86
- start_is = []
87
- if "right" in side:
88
- start_is.append(0)
89
- if "left" in side:
90
- start_is.append(21)
91
- for start_i in start_is:
92
- joints = all_joints[start_i : start_i + n_avail_joints]
93
- if len(joints) == 1:
94
- ax.scatter(joints[0][0], joints[0][1], color="red", s=10)
95
- else:
96
- for connection, color in connections[: len(joints) - 1]:
97
- joint1 = joints[connection[0]]
98
- joint2 = joints[connection[1]]
99
- ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], color=color)
100
-
101
- ax.set_xlim([0, W])
102
- ax.set_ylim([0, H])
103
- ax.grid(False)
104
- ax.set_axis_off()
105
- ax.invert_yaxis()
106
- # plt.subplots_adjust(wspace=0.01)
107
- # plt.show()
108
- buf = BytesIO()
109
- plt.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
110
- plt.close()
111
-
112
- # Convert BytesIO object to numpy array
113
- buf.seek(0)
114
- img_pil = Image.open(buf)
115
- img_pil = img_pil.resize((W, H))  # PIL resize takes (width, height)
116
- numpy_img = np.array(img_pil)
117
-
118
- return numpy_img
119
-
120
-
121
- def mask_image(image, mask, color=[0, 0, 0], alpha=0.6, transparent=True):
122
- """Overlay mask on image for visualization purpose.
123
- Args:
124
- image (H, W, 3) or (H, W): input image
125
- mask (H, W): mask to be overlaid
126
- color: the color of overlaid mask
127
- alpha: the transparency of the mask
128
- """
129
- out = deepcopy(image)
130
- img = deepcopy(image)
131
- img[mask == 1] = color
132
- if transparent:
133
- out = cv2.addWeighted(img, alpha, out, 1 - alpha, 0, out)
134
- else:
135
- out = img
136
- return out
137
-
138
-
139
- def scale_keypoint(keypoint, original_size, target_size):
140
- """Scale a keypoint based on the resizing of the image."""
141
- keypoint_copy = keypoint.copy()
142
- keypoint_copy[:, 0] *= target_size[0] / original_size[0]
143
- keypoint_copy[:, 1] *= target_size[1] / original_size[1]
144
- return keypoint_copy
145
-
146
-
147
- print("Configure...")
148
-
149
-
150
- @dataclass
151
- class HandDiffOpts:
152
- run_name: str = "ViT_256_handmask_heatmap_nvs_b25_lr1e-5"
153
- sd_path: str = "/users/kchen157/scratch/weights/SD/sd-v1-4.ckpt"
154
- log_dir: str = "/users/kchen157/scratch/log"
155
- data_root: str = "/users/kchen157/data/users/kchen157/dataset/handdiff"
156
- image_size: tuple = (256, 256)
157
- latent_size: tuple = (32, 32)
158
- latent_dim: int = 4
159
- mask_bg: bool = False
160
- kpts_form: str = "heatmap"
161
- n_keypoints: int = 42
162
- n_mask: int = 1
163
- noise_steps: int = 1000
164
- test_sampling_steps: int = 250
165
- ddim_steps: int = 100
166
- ddim_discretize: str = "uniform"
167
- ddim_eta: float = 0.0
168
- beta_start: float = 8.5e-4
169
- beta_end: float = 0.012
170
- latent_scaling_factor: float = 0.18215
171
- cfg_pose: float = 5.0
172
- cfg_appearance: float = 3.5
173
- batch_size: int = 25
174
- lr: float = 1e-5
175
- max_epochs: int = 500
176
- log_every_n_steps: int = 100
177
- limit_val_batches: int = 1
178
- n_gpu: int = 8
179
- num_nodes: int = 1
180
- precision: str = "16-mixed"
181
- profiler: str = "simple"
182
- swa_epoch_start: int = 10
183
- swa_lrs: float = 1e-3
184
- num_workers: int = 10
185
- n_val_samples: int = 4
186
-
187
- # load models
188
- token = os.getenv("HF_TOKEN")
189
- if NEW_MODEL:
190
- opts = HandDiffOpts()
191
- if MODEL_EPOCH == 7:
192
- model_path = './DINO_EMA_11M_b50_lr1e-5_epoch7_step380k.ckpt'
193
- elif MODEL_EPOCH == 6:
194
- # model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt"
195
- model_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="DINO_EMA_11M_b50_lr1e-5_epoch6_step320k.ckpt", token=token)
196
- elif MODEL_EPOCH == 4:
197
- model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch4_step210k.ckpt"
198
- elif MODEL_EPOCH == 10:
199
- model_path = "./DINO_EMA_11M_b50_lr1e-5_epoch10_step550k.ckpt"
200
- else:
201
- raise ValueError(f"new model epoch should be either 6 or 7, got {MODEL_EPOCH}")
202
- # vae_path = './vae-ft-mse-840000-ema-pruned.ckpt'
203
- vae_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="vae-ft-mse-840000-ema-pruned.ckpt", token=token)
204
- # sd_path = './sd-v1-4.ckpt'
205
- print('Load diffusion model...')
206
- diffusion = create_diffusion(str(opts.test_sampling_steps))
207
- model = vit.DiT_XL_2(
208
- input_size=opts.latent_size[0],
209
- latent_dim=opts.latent_dim,
210
- in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
211
- learn_sigma=True,
212
- ).to(device)
213
- # ckpt_state_dict = torch.load(model_path)['model_state_dict']
214
- ckpt_state_dict = torch.load(model_path, map_location='cpu')['ema_state_dict']
215
- missing_keys, extra_keys = model.load_state_dict(ckpt_state_dict, strict=False)
216
- model = model.to(device)
217
- model.eval()
218
- print(missing_keys, extra_keys)
219
- assert len(missing_keys) == 0
220
- vae_state_dict = torch.load(vae_path, map_location='cpu')['state_dict']
221
- print(f"vae_state_dict encoder dtype: {vae_state_dict['encoder.conv_in.weight'].dtype}")
222
- autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False)
223
- print(f"autoencoder encoder dtype: {next(autoencoder.encoder.parameters()).dtype}")
224
- print(f"encoder before load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
225
- print(f"encoder before load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
226
- missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
227
- print(f"encoder after load_state_dict parameters min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
228
- print(f"encoder after load_state_dict parameters max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
229
- autoencoder = autoencoder.to(device)
230
- autoencoder.eval()
231
- print(f"encoder after eval() min: {min([p.min() for p in autoencoder.encoder.parameters()])}")
232
- print(f"encoder after eval() max: {max([p.max() for p in autoencoder.encoder.parameters()])}")
233
- print(f"autoencoder encoder after eval() dtype: {next(autoencoder.encoder.parameters()).dtype}")
234
- assert len(missing_keys) == 0
235
- # else:
236
- # opts = HandDiffOpts()
237
- # model_path = './finetune_epoch=5-step=130000.ckpt'
238
- # sd_path = './sd-v1-4.ckpt'
239
- # print('Load diffusion model...')
240
- # diffusion = create_diffusion(str(opts.test_sampling_steps))
241
- # model = vit.DiT_XL_2(
242
- # input_size=opts.latent_size[0],
243
- # latent_dim=opts.latent_dim,
244
- # in_channels=opts.latent_dim+opts.n_keypoints+opts.n_mask,
245
- # learn_sigma=True,
246
- # ).to(device)
247
- # ckpt_state_dict = torch.load(model_path)['state_dict']
248
- # dit_state_dict = {remove_prefix(k, 'diffusion_backbone.'): v for k, v in ckpt_state_dict.items() if k.startswith('diffusion_backbone')}
249
- # vae_state_dict = {remove_prefix(k, 'autoencoder.'): v for k, v in ckpt_state_dict.items() if k.startswith('autoencoder')}
250
- # missing_keys, extra_keys = model.load_state_dict(dit_state_dict, strict=False)
251
- # model.eval()
252
- # assert len(missing_keys) == 0 and len(extra_keys) == 0
253
- # autoencoder = vqvae.create_model(3, 3, opts.latent_dim).eval().requires_grad_(False).to(device)
254
- # missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
255
- # autoencoder.eval()
256
- # assert len(missing_keys) == 0 and len(extra_keys) == 0
257
- sam_path = hf_hub_download(repo_id="Chaerin5/FoundHand-weights", filename="sam_vit_h_4b8939.pth", token=token)
258
- sam_predictor = init_sam(ckpt_path=sam_path, device='cuda')
259
-
260
-
261
- print("Mediapipe hand detector and SAM ready...")
262
- mp_hands = mp.solutions.hands
263
- hands = mp_hands.Hands(
264
- static_image_mode=True, # Use False if image is part of a video stream
265
- max_num_hands=2, # Maximum number of hands to detect
266
- min_detection_confidence=0.1,
267
- )
268
-
269
- def prepare_ref_anno(ref):
270
- if ref is None:
271
- return (
272
- None,
273
- None,
274
- None,
275
- None,
276
- None,
277
- )
278
- missing_keys, extra_keys = autoencoder.load_state_dict(vae_state_dict, strict=False)
279
-
280
- img = ref["composite"][..., :3]
281
- img = cv2.resize(img, opts.image_size, interpolation=cv2.INTER_AREA)
282
- keypts = np.zeros((42, 2))
283
- # if REF_POSE_MASK:
284
- mp_pose = hands.process(img)
285
- # detected = np.array([0, 0])
286
- # start_idx = 0
287
- if mp_pose.multi_hand_landmarks:
288
- # handedness is flipped assuming the input image is mirrored in MediaPipe
289
- for hand_landmarks, handedness in zip(
290
- mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
291
- ):
292
- # actually right hand
293
- if handedness.classification[0].label == "Left":
294
- start_idx = 0
295
- # detected[0] = 1
296
- # actually left hand
297
- elif handedness.classification[0].label == "Right":
298
- start_idx = 21
299
- # detected[1] = 1
300
- for i, landmark in enumerate(hand_landmarks.landmark):
301
- keypts[start_idx + i] = [
302
- landmark.x * opts.image_size[1],
303
- landmark.y * opts.image_size[0],
304
- ]
305
-
306
- # sam_predictor.set_image(img)
307
- # l = keypts[:21].shape[0]
308
- # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
309
- # input_point = np.array([keypts[0], keypts[21]])
310
- # input_label = np.array([1, 1])
311
- # elif keypts[0].sum() != 0:
312
- # input_point = np.array(keypts[:1])
313
- # input_label = np.array([1])
314
- # elif keypts[21].sum() != 0:
315
- # input_point = np.array(keypts[21:22])
316
- # input_label = np.array([1])
317
- # masks, _, _ = sam_predictor.predict(
318
- # point_coords=input_point,
319
- # point_labels=input_label,
320
- # multimask_output=False,
321
- # )
322
- # hand_mask = masks[0]
323
- # masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
324
- # ref_pose = visualize_hand(keypts, masked_img)
325
- print(f"keypts.max(): {keypts.max()}, keypts.min(): {keypts.min()}")
326
- return img, keypts
327
- else:
328
- return img, None
329
- # raise gr.Error("No hands detected in the reference image.")
330
- # else:
331
- # hand_mask = np.zeros_like(img[:,:, 0])
332
- # ref_pose = np.zeros_like(img)
333
-
334
- def get_ref_anno(img, keypts):
335
- if keypts is None:
336
- no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
337
- return None, no_hands, None
338
- if isinstance(keypts, list):
339
- if len(keypts[0]) == 0:
340
- keypts[0] = np.zeros((21, 2))
341
- elif len(keypts[0]) == 21:
342
- keypts[0] = np.array(keypts[0], dtype=np.float32)
343
- else:
344
- gr.Info("Number of right hand keypoints should be either 0 or 21.")
345
- return None, None
346
-
347
- if len(keypts[1]) == 0:
348
- keypts[1] = np.zeros((21, 2))
349
- elif len(keypts[1]) == 21:
350
- keypts[1] = np.array(keypts[1], dtype=np.float32)
351
- else:
352
- gr.Info("Number of left hand keypoints should be either 0 or 21.")
353
- return None, None
354
-
355
- keypts = np.concatenate(keypts, axis=0)
356
- # keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
357
- if REF_POSE_MASK:
358
- sam_predictor.set_image(img)
359
- # l = keypts[:21].shape[0]
360
- if keypts[0].sum() != 0 and keypts[21].sum() != 0:
361
- input_point = np.array([keypts[0], keypts[21]])
362
- input_label = np.array([1, 1])
363
- elif keypts[0].sum() != 0:
364
- input_point = np.array(keypts[:1])
365
- input_label = np.array([1])
366
- elif keypts[21].sum() != 0:
367
- input_point = np.array(keypts[21:22])
368
- input_label = np.array([1])
369
- masks, _, _ = sam_predictor.predict(
370
- point_coords=input_point,
371
- point_labels=input_label,
372
- multimask_output=False,
373
- )
374
- hand_mask = masks[0]
375
- masked_img = img * hand_mask[..., None] + 255 * (1 - hand_mask[..., None])
376
- ref_pose = visualize_hand(keypts, masked_img)
377
- else:
378
- hand_mask = np.zeros_like(img[:,:, 0])
379
- ref_pose = np.zeros_like(img)
380
- def make_ref_cond(
381
- img,
382
- keypts,
383
- hand_mask,
384
- device="cuda",
385
- target_size=(256, 256),
386
- latent_size=(32, 32),
387
- ):
388
- image_transform = Compose(
389
- [
390
- ToTensor(),
391
- Resize(target_size),
392
- Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
393
- ]
394
- )
395
- image = image_transform(img).to(device)
396
- kpts_valid = check_keypoints_validity(keypts, target_size)
397
- heatmaps = torch.tensor(
398
- keypoint_heatmap(
399
- scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
400
- )
401
- * kpts_valid[:, None, None],
402
- dtype=torch.float,
403
- device=device
404
- )[None, ...]
405
- mask = torch.tensor(
406
- cv2.resize(
407
- hand_mask.astype(int),
408
- dsize=latent_size,
409
- interpolation=cv2.INTER_NEAREST,
410
- ),
411
- dtype=torch.float,
412
- device=device,
413
- ).unsqueeze(0)[None, ...]
414
- return image[None, ...], heatmaps, mask
415
-
416
- print(f"img.max(): {img.max()}, img.min(): {img.min()}")
417
- image, heatmaps, mask = make_ref_cond(
418
- img,
419
- keypts,
420
- hand_mask,
421
- device="cuda",
422
- target_size=opts.image_size,
423
- latent_size=opts.latent_size,
424
- )
425
- print(f"image.max(): {image.max()}, image.min(): {image.min()}")
426
- print(f"opts.latent_scaling_factor: {opts.latent_scaling_factor}")
427
- print(f"autoencoder encoder before operating max: {min([p.min() for p in autoencoder.encoder.parameters()])}")
428
- print(f"autoencoder encoder before operating min: {max([p.max() for p in autoencoder.encoder.parameters()])}")
429
- print(f"autoencoder encoder before operating dtype: {next(autoencoder.encoder.parameters()).dtype}")
430
- latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
431
- print(f"latent.max(): {latent.max()}, latent.min(): {latent.min()}")
432
- if not REF_POSE_MASK:
433
- heatmaps = torch.zeros_like(heatmaps)
434
- mask = torch.zeros_like(mask)
435
- print(f"heatmaps.max(): {heatmaps.max()}, heatmaps.min(): {heatmaps.min()}")
436
- print(f"mask.max(): {mask.max()}, mask.min(): {mask.min()}")
437
- ref_cond = torch.cat([latent, heatmaps, mask], 1)
438
- print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
439
-
440
- return img, ref_pose, ref_cond
441
-
442
- def get_target_anno(target):
443
- if target is None:
444
- return (
445
- gr.State.update(value=None),
446
- gr.Image.update(value=None),
447
- gr.State.update(value=None),
448
- gr.State.update(value=None),
449
- )
450
- pose_img = target["composite"][..., :3]
451
- pose_img = cv2.resize(pose_img, opts.image_size, interpolation=cv2.INTER_AREA)
452
- # detect keypoints
453
- mp_pose = hands.process(pose_img)
454
- target_keypts = np.zeros((42, 2))
455
- detected = np.array([0, 0])
456
- start_idx = 0
457
- if mp_pose.multi_hand_landmarks:
458
- # handedness is flipped assuming the input image is mirrored in MediaPipe
459
- for hand_landmarks, handedness in zip(
460
- mp_pose.multi_hand_landmarks, mp_pose.multi_handedness
461
- ):
462
- # actually right hand
463
- if handedness.classification[0].label == "Left":
464
- start_idx = 0
465
- detected[0] = 1
466
- # actually left hand
467
- elif handedness.classification[0].label == "Right":
468
- start_idx = 21
469
- detected[1] = 1
470
- for i, landmark in enumerate(hand_landmarks.landmark):
471
- target_keypts[start_idx + i] = [
472
- landmark.x * opts.image_size[1],
473
- landmark.y * opts.image_size[0],
474
- ]
475
-
476
- target_pose = visualize_hand(target_keypts, pose_img)
477
- kpts_valid = check_keypoints_validity(target_keypts, opts.image_size)
478
- target_heatmaps = torch.tensor(
479
- keypoint_heatmap(
480
- scale_keypoint(target_keypts, opts.image_size, opts.latent_size),
481
- opts.latent_size,
482
- var=1.0,
483
- )
484
- * kpts_valid[:, None, None],
485
- dtype=torch.float,
486
- # device=device,
487
- )[None, ...]
488
- target_cond = torch.cat(
489
- [target_heatmaps, torch.zeros_like(target_heatmaps)[:, :1]], 1
490
- )
491
- else:
492
- raise gr.Error("No hands detected in the target image.")
493
-
494
- return pose_img, target_pose, target_cond, target_keypts
495
-
496
-
497
- def get_mask_inpaint(ref):
498
- inpaint_mask = np.array(ref["layers"][0])[..., -1]
499
- inpaint_mask = cv2.resize(
500
- inpaint_mask, opts.image_size, interpolation=cv2.INTER_AREA
501
- )
502
- inpaint_mask = (inpaint_mask >= 128).astype(np.uint8)
503
- return inpaint_mask
504
-
505
-
506
- def visualize_ref(crop, brush):
507
- if crop is None or brush is None:
508
- return None
509
- inpainted = brush["layers"][0][..., -1]
510
- img = crop["background"][..., :3]
511
- img = cv2.resize(img, inpainted.shape[::-1], interpolation=cv2.INTER_AREA)
512
- mask = inpainted < 128
513
- # img = img.astype(np.int32)
514
- # img[mask, :] = img[mask, :] - 50
515
- # img[np.any(img<0, axis=-1)]=0
516
- # img = img.astype(np.uint8)
517
- img = mask_image(img, mask)
518
- return img
519
-
520
-
521
- def get_kps(img, keypoints, side: Literal["right", "left"], evt: gr.SelectData):
522
- if keypoints is None:
523
- keypoints = [[], []]
524
- kps = np.zeros((42, 2))
525
- if side == "right":
526
- if len(keypoints[0]) == 21:
527
- gr.Info("21 keypoints for right hand already selected. Try reset if something looks wrong.")
528
- else:
529
- keypoints[0].append(list(evt.index))
530
- len_kps = len(keypoints[0])
531
- kps[:len_kps] = np.array(keypoints[0])
532
- elif side == "left":
533
- if len(keypoints[1]) == 21:
534
- gr.Info("21 keypoints for left hand already selected. Try reset if something looks wrong.")
535
- else:
536
- keypoints[1].append(list(evt.index))
537
- len_kps = len(keypoints[1])
538
- kps[21 : 21 + len_kps] = np.array(keypoints[1])
539
- vis_hand = visualize_hand(kps, img, side, len_kps)
540
- return vis_hand, keypoints
541
-
542
-
543
- def undo_kps(img, keypoints, side: Literal["right", "left"]):
544
- if keypoints is None:
545
- return img, None
546
- kps = np.zeros((42, 2))
547
- if side == "right":
548
- if len(keypoints[0]) == 0:
549
- return img, keypoints
550
- keypoints[0].pop()
551
- len_kps = len(keypoints[0])
552
- kps[:len_kps] = np.array(keypoints[0])
553
- elif side == "left":
554
- if len(keypoints[1]) == 0:
555
- return img, keypoints
556
- keypoints[1].pop()
557
- len_kps = len(keypoints[1])
558
- kps[21 : 21 + len_kps] = np.array(keypoints[1])
559
- vis_hand = visualize_hand(kps, img, side, len_kps)
560
- return vis_hand, keypoints
561
-
562
-
563
- def reset_kps(img, keypoints, side: Literal["right", "left"]):
564
- if keypoints is None:
565
- return img, None
566
- if side == "right":
567
- keypoints[0] = []
568
- elif side == "left":
569
- keypoints[1] = []
570
- return img, keypoints
571
-
572
- # @spaces.GPU(duration=60)
573
- def sample_diff(ref_cond, target_cond, target_keypts, num_gen, seed, cfg):
574
- set_seed(seed)
575
- z = torch.randn(
576
- (num_gen, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]),
577
- device=device,
578
- )
579
- print(f"z.device: {z.device}")
580
- target_cond = target_cond.repeat(num_gen, 1, 1, 1).to(z.device)
581
- ref_cond = ref_cond.repeat(num_gen, 1, 1, 1).to(z.device)
582
- print(f"target_cond.max(): {target_cond.max()}, target_cond.min(): {target_cond.min()}")
583
- print(f"ref_cond.max(): {ref_cond.max()}, ref_cond.min(): {ref_cond.min()}")
584
- # novel view synthesis mode = off
585
- nvs = torch.zeros(num_gen, dtype=torch.int, device=device)
586
- z = torch.cat([z, z], 0)
587
- model_kwargs = dict(
588
- target_cond=torch.cat([target_cond, torch.zeros_like(target_cond)]),
589
- ref_cond=torch.cat([ref_cond, torch.zeros_like(ref_cond)]),
590
- nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
591
- cfg_scale=cfg,
592
- )
593
-
594
- samples, _ = diffusion.p_sample_loop(
595
- model.forward_with_cfg,
596
- z.shape,
597
- z,
598
- clip_denoised=False,
599
- model_kwargs=model_kwargs,
600
- progress=True,
601
- device=device,
602
- ).chunk(2)
603
- sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
604
- sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
605
- sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
606
-
607
- results = []
608
- results_pose = []
609
- for i in range(MAX_N):
610
- if i < num_gen:
611
- results.append(sampled_images[i])
612
- results_pose.append(visualize_hand(target_keypts, sampled_images[i]))
613
- else:
614
- results.append(placeholder)
615
- results_pose.append(placeholder)
616
- print(f"results[0].max(): {results[0].max()}")
617
- return results, results_pose
618
-
619
- # @spaces.GPU(duration=120)
620
- def ready_sample(img_ori, inpaint_mask, keypts):
621
- img = cv2.resize(img_ori[..., :3], opts.image_size, interpolation=cv2.INTER_AREA)
622
- sam_predictor.set_image(img)
623
- if len(keypts[0]) == 0:
624
- keypts[0] = np.zeros((21, 2))
625
- elif len(keypts[0]) == 21:
626
- keypts[0] = np.array(keypts[0], dtype=np.float32)
627
- else:
628
- gr.Info("Number of right hand keypoints should be either 0 or 21.")
629
- return None, None
630
-
631
- if len(keypts[1]) == 0:
632
- keypts[1] = np.zeros((21, 2))
633
- elif len(keypts[1]) == 21:
634
- keypts[1] = np.array(keypts[1], dtype=np.float32)
635
- else:
636
- gr.Info("Number of left hand keypoints should be either 0 or 21.")
637
- return None, None
638
-
639
- keypts = np.concatenate(keypts, axis=0)
640
- keypts = scale_keypoint(keypts, (LENGTH, LENGTH), opts.image_size)
641
- # if keypts[0].sum() != 0 and keypts[21].sum() != 0:
642
- # input_point = np.array([keypts[0], keypts[21]])
643
- # # input_point = keypts
644
- # input_label = np.array([1, 1])
645
- # # input_label = np.ones_like(input_point[:, 0])
646
- # elif keypts[0].sum() != 0:
647
- # input_point = np.array(keypts[:1])
648
- # # input_point = keypts[:21]
649
- # input_label = np.array([1])
650
- # # input_label = np.ones_like(input_point[:21, 0])
651
- # elif keypts[21].sum() != 0:
652
- # input_point = np.array(keypts[21:22])
653
- # # input_point = keypts[21:]
654
- # input_label = np.array([1])
655
- # # input_label = np.ones_like(input_point[21:, 0])
656
-
657
- box_shift_ratio = 0.5
658
- box_size_factor = 1.2
659
-
660
- if keypts[0].sum() != 0 and keypts[21].sum() != 0:
661
- input_point = np.array(keypts)
662
- input_box = np.stack([keypts.min(axis=0), keypts.max(axis=0)])
663
- elif keypts[0].sum() != 0:
664
- input_point = np.array(keypts[:21])
665
- input_box = np.stack([keypts[:21].min(axis=0), keypts[:21].max(axis=0)])
666
- elif keypts[21].sum() != 0:
667
- input_point = np.array(keypts[21:])
668
- input_box = np.stack([keypts[21:].min(axis=0), keypts[21:].max(axis=0)])
669
- else:
670
- raise ValueError(
671
- "Something wrong. If no hand detected, it should not reach here."
672
- )
673
-
674
- input_label = np.ones_like(input_point[:, 0]).astype(np.int32)
675
- box_trans = input_box[0] * box_shift_ratio + input_box[1] * (1 - box_shift_ratio)
676
- input_box = ((input_box - box_trans) * box_size_factor + box_trans).reshape(-1)
677
-
678
- masks, _, _ = sam_predictor.predict(
679
- point_coords=input_point,
680
- point_labels=input_label,
681
- box=input_box[None, :],
682
- multimask_output=False,
683
- )
684
- hand_mask = masks[0]
685
-
686
- inpaint_latent_mask = torch.tensor(
687
- cv2.resize(
688
- inpaint_mask, dsize=opts.latent_size, interpolation=cv2.INTER_NEAREST
689
- ),
690
- dtype=torch.float,
691
- # device=device,
692
- ).unsqueeze(0)[None, ...]
693
-
694
- def make_ref_cond(
695
- img,
696
- keypts,
697
- hand_mask,
698
- device=device,
699
- target_size=(256, 256),
700
- latent_size=(32, 32),
701
- ):
702
- image_transform = Compose(
703
- [
704
- ToTensor(),
705
- Resize(target_size),
706
- Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
707
- ]
708
- )
709
- image = image_transform(img)
710
- kpts_valid = check_keypoints_validity(keypts, target_size)
711
- heatmaps = torch.tensor(
712
- keypoint_heatmap(
713
- scale_keypoint(keypts, target_size, latent_size), latent_size, var=1.0
714
- )
715
- * kpts_valid[:, None, None],
716
- dtype=torch.float,
717
- # device=device,
718
- )[None, ...]
719
- mask = torch.tensor(
720
- cv2.resize(
721
- hand_mask.astype(int),
722
- dsize=latent_size,
723
- interpolation=cv2.INTER_NEAREST,
724
- ),
725
- dtype=torch.float,
726
- # device=device,
727
- ).unsqueeze(0)[None, ...]
728
- return image[None, ...], heatmaps, mask
729
-
730
- image, heatmaps, mask = make_ref_cond(
731
- img,
732
- keypts,
733
- hand_mask * (1 - inpaint_mask),
734
- device=device,
735
- target_size=opts.image_size,
736
- latent_size=opts.latent_size,
737
- )
738
- latent = opts.latent_scaling_factor * autoencoder.encode(image).sample()
739
- target_cond = torch.cat([heatmaps, torch.zeros_like(mask)], 1)
740
- ref_cond = torch.cat([latent, heatmaps, mask], 1)
741
- ref_cond = torch.zeros_like(ref_cond)
742
-
743
- img32 = cv2.resize(img, opts.latent_size, interpolation=cv2.INTER_NEAREST)
744
- assert mask.max() == 1
745
- vis_mask32 = mask_image(
746
- img32, inpaint_latent_mask[0,0].cpu().numpy(), (255,255,255), transparent=False
747
- ).astype(np.uint8) # 1.0 - mask[0, 0].cpu().numpy()
748
-
749
- assert np.unique(inpaint_mask).shape[0] <= 2
750
- assert hand_mask.dtype == bool
751
- mask256 = inpaint_mask # hand_mask * (1 - inpaint_mask)
752
- vis_mask256 = mask_image(img, mask256, (255,255,255), transparent=False).astype(
753
- np.uint8
754
- ) # 1 - mask256
755
-
756
- return (
757
- ref_cond,
758
- target_cond,
759
- latent,
760
- inpaint_latent_mask,
761
- keypts,
762
- vis_mask32,
763
- vis_mask256,
764
- )
765
-
766
-
767
- def switch_mask_size(radio):
768
- if radio == "256x256":
769
- out = (gr.update(visible=False), gr.update(visible=True))
770
- elif radio == "latent size (32x32)":
771
- out = (gr.update(visible=True), gr.update(visible=False))
772
- return out
773
-
774
- # @spaces.GPU(duration=300)
775
- def sample_inpaint(
776
- ref_cond,
777
- target_cond,
778
- latent,
779
- inpaint_latent_mask,
780
- keypts,
781
- num_gen,
782
- seed,
783
- cfg,
784
- quality,
785
- ):
786
- set_seed(seed)
787
- N = num_gen
788
- jump_length = 10
789
- jump_n_sample = quality
790
- cfg_scale = cfg
791
- z = torch.randn(
792
- (N, opts.latent_dim, opts.latent_size[0], opts.latent_size[1]), device=device
793
- )
794
- target_cond_N = target_cond.repeat(N, 1, 1, 1).to(z.device)
795
- ref_cond_N = ref_cond.repeat(N, 1, 1, 1).to(z.device)
796
- # novel view synthesis mode = off
797
- nvs = torch.zeros(N, dtype=torch.int, device=device)
798
- z = torch.cat([z, z], 0)
799
- model_kwargs = dict(
800
- target_cond=torch.cat([target_cond_N, torch.zeros_like(target_cond_N)]),
801
- ref_cond=torch.cat([ref_cond_N, torch.zeros_like(ref_cond_N)]),
802
- nvs=torch.cat([nvs, 2 * torch.ones_like(nvs)]),
803
- cfg_scale=cfg_scale,
804
- )
805
-
806
- samples, _ = diffusion.inpaint_p_sample_loop(
807
- model.forward_with_cfg,
808
- z.shape,
809
- latent.to(z.device),
810
- inpaint_latent_mask.to(z.device),
811
- z,
812
- clip_denoised=False,
813
- model_kwargs=model_kwargs,
814
- progress=True,
815
- device=z.device,
816
- jump_length=jump_length,
817
- jump_n_sample=jump_n_sample,
818
- ).chunk(2)
819
- sampled_images = autoencoder.decode(samples / opts.latent_scaling_factor)
820
- sampled_images = torch.clamp(sampled_images, min=-1.0, max=1.0)
821
- sampled_images = unnormalize(sampled_images.permute(0, 2, 3, 1).cpu().numpy())
822
-
823
- # visualize
824
- results = []
825
- results_pose = []
826
- for i in range(FIX_MAX_N):
827
- if i < num_gen:
828
- results.append(sampled_images[i])
829
- results_pose.append(visualize_hand(keypts, sampled_images[i]))
830
- else:
831
- results.append(placeholder)
832
- results_pose.append(placeholder)
833
- return results, results_pose
834
-
835
-
836
- def flip_hand(
837
- img, pose_img, cond: Optional[torch.Tensor], keypts: Optional[torch.Tensor] = None, pose_manual_img = None,
838
- manual_kp_right=None, manual_kp_left=None
839
- ):
840
- if cond is None: # clear clicked
841
- return None, None, None, None
842
- img["composite"] = img["composite"][:, ::-1, :]
843
- img["background"] = img["background"][:, ::-1, :]
844
- img["layers"] = [layer[:, ::-1, :] for layer in img["layers"]]
845
- pose_img = pose_img[:, ::-1, :]
846
- cond = cond.flip(-1)
847
- if keypts is not None: # cond is target_cond
848
- if keypts[:21, :].sum() != 0:
849
- keypts[:21, 0] = opts.image_size[1] - keypts[:21, 0]
850
- # keypts[:21, 1] = opts.image_size[0] - keypts[:21, 1]
851
- if keypts[21:, :].sum() != 0:
852
- keypts[21:, 0] = opts.image_size[1] - keypts[21:, 0]
853
- # keypts[21:, 1] = opts.image_size[0] - keypts[21:, 1]
854
- if pose_manual_img is not None:
855
- pose_manual_img = pose_manual_img[:, ::-1, :]
856
- manual_kp_right = manual_kp_right[:, ::-1, :]
857
- manual_kp_left = manual_kp_left[:, ::-1, :]
858
- return img, pose_img, cond, keypts, pose_manual_img, manual_kp_right, manual_kp_left
859
-
860
-
861
- def resize_to_full(img):
862
- img["background"] = cv2.resize(img["background"], (LENGTH, LENGTH))
863
- img["composite"] = cv2.resize(img["composite"], (LENGTH, LENGTH))
864
- img["layers"] = [cv2.resize(layer, (LENGTH, LENGTH)) for layer in img["layers"]]
865
- return img
866
-
867
-
868
- def clear_all():
869
- return (
870
- None,
871
- None,
872
- None,
873
- None,
874
- None,
875
- False,
876
- None,
877
- None,
878
- False,
879
- None,
880
- None,
881
- None,
882
- None,
883
- None,
884
- None,
885
- None,
886
- 1,
887
- 42,
888
- 3.0,
889
- gr.update(interactive=False),
890
- []
891
- )
892
-
893
-
894
- def fix_clear_all():
895
- return (
896
- None,
897
- None,
898
- None,
899
- None,
900
- None,
901
- None,
902
- None,
903
- None,
904
- None,
905
- None,
906
- None,
907
- None,
908
- None,
909
- None,
910
- None,
911
- None,
912
- None,
913
- 1,
914
- # (0,0),
915
- 42,
916
- 3.0,
917
- 10,
918
- )
919
-
920
-
921
- def enable_component(image1, image2):
922
- if image1 is None or image2 is None:
923
- return gr.update(interactive=False)
924
- if "background" in image1 and "layers" in image1 and "composite" in image1:
925
- if (
926
- image1["background"].sum() == 0
927
- and (sum([im.sum() for im in image1["layers"]]) == 0)
928
- and image1["composite"].sum() == 0
929
- ):
930
- return gr.update(interactive=False)
931
- if "background" in image2 and "layers" in image2 and "composite" in image2:
932
- if (
933
- image2["background"].sum() == 0
934
- and (sum([im.sum() for im in image2["layers"]]) == 0)
935
- and image2["composite"].sum() == 0
936
- ):
937
- return gr.update(interactive=False)
938
- return gr.update(interactive=True)
939
-
940
-
941
- def set_visible(checkbox, kpts, img_clean, img_pose_right, img_pose_left, done=None, done_info=None):
942
- if kpts is None:
943
- kpts = [[], []]
944
- if "Right hand" not in checkbox:
945
- kpts[0] = []
946
- vis_right = img_clean
947
- update_right = gr.update(visible=False)
948
- update_r_info = gr.update(visible=False)
949
- else:
950
- vis_right = img_pose_right
951
- update_right = gr.update(visible=True)
952
- update_r_info = gr.update(visible=True)
953
-
954
- if "Left hand" not in checkbox:
955
- kpts[1] = []
956
- vis_left = img_clean
957
- update_left = gr.update(visible=False)
958
- update_l_info = gr.update(visible=False)
959
- else:
960
- vis_left = img_pose_left
961
- update_left = gr.update(visible=True)
962
- update_l_info = gr.update(visible=True)
963
-
964
- ret = [
965
- kpts,
966
- vis_right,
967
- vis_left,
968
- update_right,
969
- update_right,
970
- update_right,
971
- update_left,
972
- update_left,
973
- update_left,
974
- update_r_info,
975
- update_l_info,
976
- ]
977
- if done is not None:
978
- if not checkbox:
979
- ret.append(gr.update(visible=False))
980
- ret.append(gr.update(visible=False))
981
- else:
982
- ret.append(gr.update(visible=True))
983
- ret.append(gr.update(visible=True))
984
- return tuple(ret)
985
-
986
- def set_unvisible():
987
- return (
988
- gr.update(visible=False),
989
- gr.update(visible=False),
990
- gr.update(visible=False),
991
- gr.update(visible=False),
992
- gr.update(visible=False),
993
- gr.update(visible=False),
994
- gr.update(visible=False),
995
- gr.update(visible=False),
996
- gr.update(visible=False),
997
- gr.update(visible=False),
998
- gr.update(visible=False),
999
- gr.update(visible=False)
1000
- )
1001
-
1002
- def set_no_hands(decider, component):
1003
- if decider is None:
1004
- no_hands = cv2.resize(np.array(Image.open("no_hands.png"))[..., :3], (LENGTH, LENGTH))
1005
- return no_hands
1006
- else:
1007
- return component
1008
-
1009
- # def visible_component(decider, component):
1010
- # if decider is not None:
1011
- # update_component = gr.update(visible=True)
1012
- # else:
1013
- # update_component = gr.update(visible=False)
1014
- # return update_component
1015
-
1016
- def unvisible_component(decider, component):
1017
- if decider is not None:
1018
- update_component = gr.update(visible=False)
1019
- else:
1020
- update_component = gr.update(visible=True)
1021
- return update_component
1022
-
1023
- def make_change(decider, state):
1024
- '''
1025
- If decider is not None, toggle the state's value; whether it ends up True or False does not matter.
1026
- '''
1027
- if decider is not None:
1028
- if state:
1029
- state = False
1030
- else:
1031
- state = True
1032
- return state
1033
- else:
1034
- return state
1035
-
1036
- LENGTH = 480
1037
-
1038
- example_ref_imgs = [
1039
- [
1040
- "sample_images/sample1.jpg",
1041
- ],
1042
- [
1043
- "sample_images/sample2.jpg",
1044
- ],
1045
- [
1046
- "sample_images/sample3.jpg",
1047
- ],
1048
- [
1049
- "sample_images/sample4.jpg",
1050
- ],
1051
- # [
1052
- # "sample_images/sample5.jpg",
1053
- # ],
1054
- [
1055
- "sample_images/sample6.jpg",
1056
- ],
1057
- # [
1058
- # "sample_images/sample7.jpg",
1059
- # ],
1060
- # [
1061
- # "sample_images/sample8.jpg",
1062
- # ],
1063
- # [
1064
- # "sample_images/sample9.jpg",
1065
- # ],
1066
- # [
1067
- # "sample_images/sample10.jpg",
1068
- # ],
1069
- # [
1070
- # "sample_images/sample11.jpg",
1071
- # ],
1072
- # ["pose_images/pose1.jpg"],
1073
- # ["pose_images/pose2.jpg"],
1074
- # ["pose_images/pose3.jpg"],
1075
- # ["pose_images/pose4.jpg"],
1076
- # ["pose_images/pose5.jpg"],
1077
- # ["pose_images/pose6.jpg"],
1078
- # ["pose_images/pose7.jpg"],
1079
- # ["pose_images/pose8.jpg"],
1080
- ]
1081
- example_target_imgs = [
1082
- # [
1083
- # "sample_images/sample1.jpg",
1084
- # ],
1085
- # [
1086
- # "sample_images/sample2.jpg",
1087
- # ],
1088
- # [
1089
- # "sample_images/sample3.jpg",
1090
- # ],
1091
- # [
1092
- # "sample_images/sample4.jpg",
1093
- # ],
1094
- [
1095
- "sample_images/sample5.jpg",
1096
- ],
1097
- # [
1098
- # "sample_images/sample6.jpg",
1099
- # ],
1100
- # [
1101
- # "sample_images/sample7.jpg",
1102
- # ],
1103
- # [
1104
- # "sample_images/sample8.jpg",
1105
- # ],
1106
- [
1107
- "sample_images/sample9.jpg",
1108
- ],
1109
- [
1110
- "sample_images/sample10.jpg",
1111
- ],
1112
- [
1113
- "sample_images/sample11.jpg",
1114
- ],
1115
- ["pose_images/pose1.jpg"],
1116
- # ["pose_images/pose2.jpg"],
1117
- # ["pose_images/pose3.jpg"],
1118
- # ["pose_images/pose4.jpg"],
1119
- # ["pose_images/pose5.jpg"],
1120
- # ["pose_images/pose6.jpg"],
1121
- # ["pose_images/pose7.jpg"],
1122
- # ["pose_images/pose8.jpg"],
1123
- ]
1124
- fix_example_imgs = [
1125
- ["bad_hands/1.jpg"], # "bad_hands/1_mask.jpg"],
1126
- # ["bad_hands/2.jpg"], # "bad_hands/2_mask.jpg"],
1127
- ["bad_hands/3.jpg"], # "bad_hands/3_mask.jpg"],
1128
- # ["bad_hands/4.jpg"], # "bad_hands/4_mask.jpg"],
1129
- ["bad_hands/5.jpg"], # "bad_hands/5_mask.jpg"],
1130
- ["bad_hands/6.jpg"], # "bad_hands/6_mask.jpg"],
1131
- ["bad_hands/7.jpg"], # "bad_hands/7_mask.jpg"],
1132
- # ["bad_hands/8.jpg"], # "bad_hands/8_mask.jpg"],
1133
- # ["bad_hands/9.jpg"], # "bad_hands/9_mask.jpg"],
1134
- # ["bad_hands/10.jpg"], # "bad_hands/10_mask.jpg"],
1135
- # ["bad_hands/11.jpg"], # "bad_hands/11_mask.jpg"],
1136
- # ["bad_hands/12.jpg"], # "bad_hands/12_mask.jpg"],
1137
- # ["bad_hands/13.jpg"], # "bad_hands/13_mask.jpg"],
1138
- ["bad_hands/14.jpg"],
1139
- ["bad_hands/15.jpg"],
1140
- ]
1141
- custom_css = """
1142
- .gradio-container .examples img {
1143
- width: 240px !important;
1144
- height: 240px !important;
1145
- }
1146
- """
1147
-
1148
- _HEADER_ = '''
1149
- <div style="text-align: center;">
1150
- <h1><b>FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation</b></h1>
1151
- <h2 style="color: #777777;">CVPR 2025</h2>
1152
- <style>
1153
- .link-spacing {
1154
- margin-right: 20px;
1155
- }
1156
- </style>
1157
- <p style="font-size: 15px;">
1158
- <span style="display: inline-block; margin-right: 30px;">Brown University</span>
1159
- <span style="display: inline-block;">Meta Reality Labs</span>
1160
- </p>
1161
- <h3>
1162
- <a href='https://arxiv.org/abs/2412.02690' target='_blank' class="link-spacing">Paper</a>
1163
- <a href='https://ivl.cs.brown.edu/research/foundhand.html' target='_blank' class="link-spacing">Project Page</a>
1164
- <a href='' target='_blank' class="link-spacing">Code</a>
1165
- <a href='' target='_blank'>Model Weights</a>
1166
- </h3>
1167
- <p>Below are two key abilities of our model. First, we can <b>edit hand poses</b> given two hand images: one is the image to edit, and the other provides the target hand pose. Second, we can automatically <b>fix malformed hand images</b>, following a user-provided target hand pose and the area to fix.</p>
1168
- </div>
1169
- '''
1170
-
1171
- _CITE_ = r"""
1172
- ```
1173
- @article{chen2024foundhand,
1174
- title={FoundHand: Large-Scale Domain-Specific Learning for Controllable Hand Image Generation},
1175
- author={Chen, Kefan and Min, Chaerin and Zhang, Linguang and Hampali, Shreyas and Keskin, Cem and Sridhar, Srinath},
1176
- journal={arXiv preprint arXiv:2412.02690},
1177
- year={2024}
1178
- }
1179
- ```
1180
- """
1181
-
1182
- with gr.Blocks(css=custom_css, theme="soft") as demo:
1183
- gr.Markdown(_HEADER_)
1184
- with gr.Tab("Edit Hand Poses"):
1185
- ref_img = gr.State(value=None)
1186
- ref_im_raw = gr.State(value=None)
1187
- ref_kp_raw = gr.State(value=0)
1188
- ref_kp_got = gr.State(value=None)
1189
- dump = gr.State(value=None)
1190
- ref_cond = gr.State(value=None)
1191
- ref_manual_cond = gr.State(value=None)
1192
- ref_auto_cond = gr.State(value=None)
1193
- keypts = gr.State(value=None)
1194
- target_img = gr.State(value=None)
1195
- target_cond = gr.State(value=None)
1196
- target_keypts = gr.State(value=None)
1197
- dump = gr.State(value=None)
1198
- with gr.Row():
1199
- with gr.Column():
1200
- gr.Markdown(
1201
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a hand image to edit 📥</p>"""
1202
- )
1203
- gr.Markdown(
1204
- """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
1205
- )
1206
- # gr.Markdown("""<p style="text-align: center;"><br></p>""")
1207
- ref = gr.ImageEditor(
1208
- type="numpy",
1209
- label="Reference",
1210
- show_label=True,
1211
- height=LENGTH,
1212
- width=LENGTH,
1213
- brush=False,
1214
- layers=False,
1215
- crop_size="1:1",
1216
- )
1217
- gr.Examples(example_ref_imgs, [ref], examples_per_page=20)
1218
- gr.Markdown(
1219
- """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get hand pose</p>"""
1220
- )
1221
- ref_finish_crop = gr.Button(value="Finish Cropping", interactive=False)
1222
- with gr.Tab("Automatic hand keypoints"):
1223
- ref_pose = gr.Image(
1224
- type="numpy",
1225
- label="Reference Pose",
1226
- show_label=True,
1227
- height=LENGTH,
1228
- width=LENGTH,
1229
- interactive=False,
1230
- )
1231
- ref_use_auto = gr.Button(value="Click here to use automatic, not manual", interactive=False, visible=True)
1232
- with gr.Tab("Manual hand keypoints"):
1233
- ref_manual_checkbox_info = gr.Markdown(
1234
- """<p style="text-align: center;"><b>Step 1.</b> Tell us if this is right, left, or both hands.</p>""",
1235
- visible=True,
1236
- )
1237
- ref_manual_checkbox = gr.CheckboxGroup(
1238
- ["Right hand", "Left hand"],
1239
- # label="Hand side",
1240
- # info="Hand pose failed to automatically detected. Now let's enable user-provided hand pose. First of all, please tell us if this is right, left, or both hands",
1241
- show_label=False,
1242
- visible=True,
1243
- interactive=True,
1244
- )
1245
- ref_manual_kp_r_info = gr.Markdown(
1246
- """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>right</b> hand. See \"OpenPose Keypoint Convention\" for guidance.</p>""",
1247
- visible=False,
1248
- )
1249
- ref_manual_kp_right = gr.Image(
1250
- type="numpy",
1251
- label="Keypoint Selection (right hand)",
1252
- show_label=True,
1253
- height=LENGTH,
1254
- width=LENGTH,
1255
- interactive=False,
1256
- visible=False,
1257
- sources=[],
1258
- )
1259
- with gr.Row():
1260
- ref_manual_undo_right = gr.Button(
1261
- value="Undo", interactive=True, visible=False
1262
- )
1263
- ref_manual_reset_right = gr.Button(
1264
- value="Reset", interactive=True, visible=False
1265
- )
1266
- ref_manual_kp_l_info = gr.Markdown(
1267
- """<p style="text-align: center;"><b>Step 2.</b> Click on image to provide hand keypoints for <b>left</b> hand. See \"OpenPose keypoint convention\" for guidance.</p>""",
1268
- visible=False
1269
- )
1270
- ref_manual_kp_left = gr.Image(
1271
- type="numpy",
1272
- label="Keypoint Selection (left hand)",
1273
- show_label=True,
1274
- height=LENGTH,
1275
- width=LENGTH,
1276
- interactive=False,
1277
- visible=False,
1278
- sources=[],
1279
- )
1280
- with gr.Row():
1281
- ref_manual_undo_left = gr.Button(
1282
- value="Undo", interactive=True, visible=False
1283
- )
1284
- ref_manual_reset_left = gr.Button(
1285
- value="Reset", interactive=True, visible=False
1286
- )
1287
- ref_manual_done_info = gr.Markdown(
1288
- """<p style="text-align: center;"><b>Step 3.</b> Hit \"Done\" button to confirm.</p>""",
1289
- visible=False,
1290
- )
1291
- ref_manual_done = gr.Button(value="Done", interactive=True, visible=False)
1292
- ref_manual_pose = gr.Image(
1293
- type="numpy",
1294
- label="Reference Pose",
1295
- show_label=True,
1296
- height=LENGTH,
1297
- width=LENGTH,
1298
- interactive=False,
1299
- visible=False
1300
- )
1301
- ref_use_manual = gr.Button(value="Click here to use manual, not automatic", interactive=True, visible=False)
1302
- ref_manual_instruct = gr.Markdown(
1303
- value="""<p style="text-align: left; font-weight: bold; ">OpenPose Keypoints Convention</p>""",
1304
- visible=True
1305
- )
1306
- ref_manual_openpose = gr.Image(
1307
- value="openpose.png",
1308
- type="numpy",
1309
- # label="OpenPose keypoints convention",
1310
- show_label=False,
1311
- height=LENGTH // 2,
1312
- width=LENGTH // 2,
1313
- interactive=False,
1314
- visible=True
1315
- )
1316
- gr.Markdown(
1317
- """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1318
- )
1319
- ref_flip = gr.Checkbox(
1320
- value=False, label="Flip Handedness (Reference)", interactive=False
1321
- )
1322
- with gr.Column():
1323
- gr.Markdown(
1324
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Upload a hand image for target hand pose 📥</p>"""
1325
- )
1326
- gr.Markdown(
1327
- """<p style="text-align: center;">&#9312; Optionally crop the image</p>"""
1328
- )
1329
- target = gr.ImageEditor(
1330
- type="numpy",
1331
- label="Target",
1332
- show_label=True,
1333
- height=LENGTH,
1334
- width=LENGTH,
1335
- brush=False,
1336
- layers=False,
1337
- crop_size="1:1",
1338
- )
1339
- gr.Examples(example_target_imgs, [target], examples_per_page=20)
1340
- gr.Markdown(
1341
- """<p style="text-align: center;">&#9313; Hit the &quot;Finish Cropping&quot; button to get hand pose</p>"""
1342
- )
1343
- target_finish_crop = gr.Button(
1344
- value="Finish Cropping", interactive=False
1345
- )
1346
- target_pose = gr.Image(
1347
- type="numpy",
1348
- label="Target Pose",
1349
- show_label=True,
1350
- height=LENGTH,
1351
- width=LENGTH,
1352
- interactive=False,
1353
- )
1354
- gr.Markdown(
1355
- """<p style="text-align: center;">&#9314; Optionally flip the hand</p>"""
1356
- )
1357
- target_flip = gr.Checkbox(
1358
- value=False, label="Flip Handedness (Target)", interactive=False
1359
- )
1360
- with gr.Column():
1361
- gr.Markdown(
1362
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Run&quot; to get the edited results 🎯</p>"""
1363
- )
1364
- # gr.Markdown(
1365
- # """<p style="text-align: center;">[NOTE] Run will be enabled after the previous steps have been completed</p>"""
1366
- # )
1367
- run = gr.Button(value="Run", interactive=False)
1368
- gr.Markdown(
1369
- """<p style="text-align: center;">⚠️ ~20s per generation with RTX3090. ~50s with A100. <br>(For example, if you set Number of generations as 2, it would take around 40s)</p>"""
1370
- )
1371
- results = gr.Gallery(
1372
- type="numpy",
1373
- label="Results",
1374
- show_label=True,
1375
- height=LENGTH,
1376
- min_width=LENGTH,
1377
- columns=MAX_N,
1378
- interactive=False,
1379
- preview=True,
1380
- )
1381
- results_pose = gr.Gallery(
1382
- type="numpy",
1383
- label="Results Pose",
1384
- show_label=True,
1385
- height=LENGTH,
1386
- min_width=LENGTH,
1387
- columns=MAX_N,
1388
- interactive=False,
1389
- preview=True,
1390
- )
1391
- gr.Markdown(
1392
- """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
1393
- )
1394
- clear = gr.ClearButton()
1395
-
1396
- # gr.Markdown(
1397
- # """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
1398
- # )
1399
- with gr.Tab("More options"):
1400
- with gr.Row():
1401
- n_generation = gr.Slider(
1402
- label="Number of generations",
1403
- value=1,
1404
- minimum=1,
1405
- maximum=MAX_N,
1406
- step=1,
1407
- randomize=False,
1408
- interactive=True,
1409
- )
1410
- seed = gr.Slider(
1411
- label="Seed",
1412
- value=42,
1413
- minimum=0,
1414
- maximum=10000,
1415
- step=1,
1416
- randomize=False,
1417
- interactive=True,
1418
- )
1419
- cfg = gr.Slider(
1420
- label="Classifier free guidance scale",
1421
- value=2.5,
1422
- minimum=0.0,
1423
- maximum=10.0,
1424
- step=0.1,
1425
- randomize=False,
1426
- interactive=True,
1427
- )
1428
-
1429
- ref.change(enable_component, [ref, ref], ref_finish_crop)
1430
- # ref_finish_crop.click(get_ref_anno, [ref], [ref_img, ref_pose, ref_cond])
1431
- ref_finish_crop.click(prepare_ref_anno, [ref], [ref_im_raw, ref_kp_raw])
1432
- # ref_kp_raw.change(make_change, [ref_kp_raw, ref_kp_watcher], ref_kp_watcher)
1433
- # ref_kp_raw.change(set_no_hands, [ref_kp_raw, ref_pose], ref_pose)
1434
- ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_right)
1435
- ref_kp_raw.change(lambda x: x, ref_im_raw, ref_manual_kp_left)
1436
- # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_checkbox], ref_manual_checkbox)
1437
- # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_checkbox_info], ref_manual_checkbox_info)
1438
- # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_openpose], ref_manual_openpose)
1439
- # ref_kp_raw.change(unvisible_component, [ref_kp_raw, ref_manual_instruct], ref_manual_instruct)
1440
- # ref_kp_raw.change(lambda x: x, ref_kp_raw, ref_kp_got)
1441
- ref_manual_checkbox.select(
1442
- set_visible,
1443
- [ref_manual_checkbox, ref_kp_got, ref_im_raw, ref_manual_kp_right, ref_manual_kp_left, ref_manual_done],
1444
- [
1445
- ref_kp_got,
1446
- ref_manual_kp_right,
1447
- ref_manual_kp_left,
1448
- ref_manual_kp_right,
1449
- ref_manual_undo_right,
1450
- ref_manual_reset_right,
1451
- ref_manual_kp_left,
1452
- ref_manual_undo_left,
1453
- ref_manual_reset_left,
1454
- ref_manual_kp_r_info,
1455
- ref_manual_kp_l_info,
1456
- ref_manual_done,
1457
- ref_manual_done_info
1458
- ]
1459
- )
1460
- ref_manual_kp_right.select(
1461
- get_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1462
- )
1463
- ref_manual_undo_right.click(
1464
- undo_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1465
- )
1466
- ref_manual_reset_right.click(
1467
- reset_kps, [ref_im_raw, ref_kp_got, gr.State("right")], [ref_manual_kp_right, ref_kp_got]
1468
- )
1469
- ref_manual_kp_left.select(
1470
- get_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1471
- )
1472
- ref_manual_undo_left.click(
1473
- undo_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1474
- )
1475
- ref_manual_reset_left.click(
1476
- reset_kps, [ref_im_raw, ref_kp_got, gr.State("left")], [ref_manual_kp_left, ref_kp_got]
1477
- )
1478
- # ref_manual_done.click(lambda x: ~x, ref_kp_watcher, ref_kp_watcher)
1479
- ref_manual_done.click(get_ref_anno, [ref_im_raw, ref_kp_got], [ref_img, ref_manual_pose, ref_manual_cond])
1480
- ref_manual_cond.change(lambda x: x, ref_manual_cond, ref_cond)
1481
- ref_use_manual.click(lambda x: x, ref_manual_cond, ref_cond)
1482
- ref_use_manual.click(lambda x: gr.Info("Manual hand keypoints will be used for 'Reference'", duration=3))
1483
- ref_manual_done.click(lambda x: gr.update(visible=True), ref_manual_pose, ref_manual_pose)
1484
- ref_manual_done.click(lambda x: gr.update(visible=True), ref_use_manual, ref_use_manual)
1485
- ref_manual_pose.change(enable_component, [ref_manual_pose, ref_manual_pose], ref_manual_done)
1486
- # ref_pose.change(enable_component, [ref_pose, gr.State(value=True)], ref_ok)
1487
- ref_kp_raw.change(get_ref_anno, [ref_im_raw, ref_kp_raw], [ref_img, ref_pose, ref_auto_cond])
1488
- ref_auto_cond.change(lambda x: x, ref_auto_cond, ref_cond)
1489
- ref_use_auto.click(lambda x: x, ref_auto_cond, ref_cond)
1490
- ref_use_auto.click(lambda x: gr.Info("Automatic hand keypoints will be used for 'Reference'", duration=3))
1491
- ref_pose.change(enable_component, [ref_kp_raw, ref_pose], ref_use_auto)
1492
- ref_pose.change(enable_component, [ref_img, ref_pose], ref_flip)
1493
- ref_manual_pose.change(enable_component, [ref_img, ref_manual_pose], ref_flip)
1494
- ref_flip.select(
1495
- flip_hand, [ref, ref_pose, ref_cond, gr.State(value=None), ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left], [ref, ref_pose, ref_cond, dump, ref_manual_pose, ref_manual_kp_right, ref_manual_kp_left]
1496
- )
1497
- target.change(enable_component, [target, target], target_finish_crop)
1498
- target_finish_crop.click(
1499
- get_target_anno,
1500
- [target],
1501
- [target_img, target_pose, target_cond, target_keypts],
1502
- )
1503
- target_pose.change(enable_component, [target_img, target_pose], target_flip)
1504
- target_flip.select(
1505
- flip_hand,
1506
- [target, target_pose, target_cond, target_keypts],
1507
- [target, target_pose, target_cond, target_keypts],
1508
- )
1509
- ref_pose.change(enable_component, [ref_pose, target_pose], run)
1510
- ref_manual_pose.change(enable_component, [ref_manual_pose, target_pose], run)
1511
- target_pose.change(enable_component, [ref_pose, target_pose], run)
1512
- run.click(
1513
- sample_diff,
1514
- [ref_cond, target_cond, target_keypts, n_generation, seed, cfg],
1515
- [results, results_pose],
1516
- )
1517
- clear.click(
1518
- clear_all,
1519
- [],
1520
- [
1521
- ref,
1522
- ref_manual_kp_right,
1523
- ref_manual_kp_left,
1524
- ref_pose,
1525
- ref_manual_pose,
1526
- ref_flip,
1527
- target,
1528
- target_pose,
1529
- target_flip,
1530
- results,
1531
- results_pose,
1532
- ref_img,
1533
- ref_cond,
1534
- # mask,
1535
- target_img,
1536
- target_cond,
1537
- target_keypts,
1538
- n_generation,
1539
- seed,
1540
- cfg,
1541
- ref_kp_raw,
1542
- ref_manual_checkbox
1543
- ],
1544
- )
1545
- clear.click(
1546
- set_unvisible,
1547
- [],
1548
- [
1549
- # ref_manual_checkbox,
1550
- # ref_manual_instruct,
1551
- # ref_manual_openpose,
1552
- ref_manual_kp_r_info,
1553
- ref_manual_kp_l_info,
1554
- ref_manual_undo_left,
1555
- ref_manual_undo_right,
1556
- ref_manual_reset_left,
1557
- ref_manual_reset_right,
1558
- ref_manual_done,
1559
- ref_manual_done_info,
1560
- ref_manual_pose,
1561
- ref_use_manual,
1562
- ref_manual_kp_right,
1563
- ref_manual_kp_left
1564
- ]
1565
- )
1566
-
1567
- # gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
1568
- # with gr.Tab("Reference"):
1569
- # with gr.Row():
1570
- # gr.Examples(example_imgs, [ref], examples_per_page=20)
1571
- # with gr.Tab("Target"):
1572
- # with gr.Row():
1573
- # gr.Examples(example_imgs, [target], examples_per_page=20)
1574
- with gr.Tab("Fix Hands"):
1575
- fix_inpaint_mask = gr.State(value=None)
1576
- fix_original = gr.State(value=None)
1577
- fix_img = gr.State(value=None)
1578
- fix_kpts = gr.State(value=None)
1579
- fix_kpts_np = gr.State(value=None)
1580
- fix_ref_cond = gr.State(value=None)
1581
- fix_target_cond = gr.State(value=None)
1582
- fix_latent = gr.State(value=None)
1583
- fix_inpaint_latent = gr.State(value=None)
1584
- # fix_size_memory = gr.State(value=(0, 0))
1585
- # gr.Markdown("""<p style="text-align: center; font-size: 25px; font-weight: bold; ">⚠️ Note</p>""")
1586
- # gr.Markdown("""<p>"Fix Hands" with A100 needs around 6 mins, which is beyond the ZeroGPU quota (5 mins). Please either purchase additional gpus from Hugging Face or wait for us to open-source our code soon so that you can use your own gpus🙏 </p>""")
1587
- with gr.Row():
1588
- with gr.Column():
1589
- # gr.Markdown(
1590
- # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">1. Image Cropping & Brushing</p>"""
1591
- # )
1592
- gr.Markdown(
1593
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">1. Upload a malformed hand image to fix 📥</p>"""
1594
- )
1595
- gr.Markdown(
1596
- """<p style="text-align: center;">&#9312; Optionally crop the image around the hand</p>"""
1597
- )
1598
- # gr.Markdown(
1599
- # """<p style="text-align: center; font-size: 20px; font-weight: bold; ">A. Crop</p>"""
1600
- # )
1601
- fix_crop = gr.ImageEditor(
1602
- type="numpy",
1603
- sources=["upload", "webcam", "clipboard"],
1604
- label="Image crop",
1605
- show_label=True,
1606
- height=LENGTH,
1607
- width=LENGTH,
1608
- layers=False,
1609
- crop_size="1:1",
1610
- brush=False,
1611
- image_mode="RGBA",
1612
- container=False,
1613
- )
1614
- fix_example = gr.Examples(
1615
- fix_example_imgs,
1616
- inputs=[fix_crop],
1617
- examples_per_page=20,
1618
- )
1619
- gr.Markdown(
1620
- """<p style="text-align: center;">&#9313; Brush area (e.g., wrong finger) that needs to be fixed. This will serve as an inpaint mask</p>"""
1621
- )
1622
- # gr.Markdown(
1623
- # """<p style="text-align: center; font-size: 20px; font-weight: bold; ">B. Brush</p>"""
1624
- # )
1625
- fix_ref = gr.ImageEditor(
1626
- type="numpy",
1627
- label="Image brush",
1628
- sources=(),
1629
- show_label=True,
1630
- height=LENGTH,
1631
- width=LENGTH,
1632
- layers=False,
1633
- transforms=("brush"),
1634
- brush=gr.Brush(
1635
- colors=["rgb(255, 255, 255)"], default_size=20
1636
- ), # 204, 50, 50
1637
- image_mode="RGBA",
1638
- container=False,
1639
- interactive=False,
1640
- )
1641
- fix_finish_crop = gr.Button(
1642
- value="Finish Croping & Brushing", interactive=False
1643
- )
1644
- with gr.Column():
1645
- # gr.Markdown(
1646
- # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">2. Keypoint Selection</p>"""
1647
- # )
1648
- gr.Markdown(
1649
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">2. Click on hand to get target hand pose</p>"""
1650
- )
1651
- # gr.Markdown(
1652
- # """<p style="text-align: center;">On the hand, select 21 keypoints that you hope the output to be. <br>Please see the \"OpenPose keypoints convention\"</p>"""
1653
- # )
1654
- gr.Markdown(
1655
- """<p style="text-align: center;">&#9312; Tell us if this is right, left, or both hands</p>"""
1656
- )
1657
- fix_checkbox = gr.CheckboxGroup(
1658
- ["Right hand", "Left hand"],
1659
- # value=["Right hand", "Left hand"],
1660
- # label="Hand side",
1661
- # info="Which side this hand is? Could be both.",
1662
- show_label=False,
1663
- interactive=False,
1664
- )
1665
- gr.Markdown(
1666
- """<p style="text-align: center;">&#9313; On the image, click 21 hand keypoints. This will serve as target hand poses. See the \"OpenPose keypoints convention\" for guidance.</p>"""
1667
- )
1668
- fix_kp_r_info = gr.Markdown(
1669
- """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select right only</p>""",
1670
- visible=False,
1671
- )
1672
- fix_kp_right = gr.Image(
1673
- type="numpy",
1674
- label="Keypoint Selection (right hand)",
1675
- show_label=True,
1676
- height=LENGTH,
1677
- width=LENGTH,
1678
- interactive=False,
1679
- visible=False,
1680
- sources=[],
1681
- )
1682
- with gr.Row():
1683
- fix_undo_right = gr.Button(
1684
- value="Undo", interactive=False, visible=False
1685
- )
1686
- fix_reset_right = gr.Button(
1687
- value="Reset", interactive=False, visible=False
1688
- )
1689
- fix_kp_l_info = gr.Markdown(
1690
- """<p style="text-align: center; font-size: 20px; font-weight: bold; ">Select left only</p>""",
1691
- visible=False
1692
- )
1693
- fix_kp_left = gr.Image(
1694
- type="numpy",
1695
- label="Keypoint Selection (left hand)",
1696
- show_label=True,
1697
- height=LENGTH,
1698
- width=LENGTH,
1699
- interactive=False,
1700
- visible=False,
1701
- sources=[],
1702
- )
1703
- with gr.Row():
1704
- fix_undo_left = gr.Button(
1705
- value="Undo", interactive=False, visible=False
1706
- )
1707
- fix_reset_left = gr.Button(
1708
- value="Reset", interactive=False, visible=False
1709
- )
1710
- gr.Markdown(
1711
- """<p style="text-align: left; font-weight: bold; ">OpenPose keypoints convention</p>"""
1712
- )
1713
- fix_openpose = gr.Image(
1714
- value="openpose.png",
1715
- type="numpy",
1716
- # label="OpenPose keypoints convention",
1717
- show_label=False,
1718
- height=LENGTH // 2,
1719
- width=LENGTH // 2,
1720
- interactive=False,
1721
- )
1722
- with gr.Column():
1723
- # gr.Markdown(
1724
- # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">3. Prepare Mask</p>"""
1725
- # )
1726
- gr.Markdown(
1727
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">3. Press &quot;Ready&quot; to start pre-processing</p>"""
1728
- )
1729
- fix_ready = gr.Button(value="Ready", interactive=False)
1730
- # fix_mask_size = gr.Radio(
1731
- # ["256x256", "latent size (32x32)"],
1732
- # label="Visualized inpaint mask size",
1733
- # interactive=False,
1734
- # value="256x256",
1735
- # )
1736
- gr.Markdown(
1737
- """<p style="text-align: center; font-weight: bold; ">Visualized (256, 256) Inpaint Mask</p>"""
1738
- )
1739
- fix_vis_mask32 = gr.Image(
1740
- type="numpy",
1741
- label=f"Visualized {opts.latent_size} Inpaint Mask",
1742
- show_label=True,
1743
- height=opts.latent_size,
1744
- width=opts.latent_size,
1745
- interactive=False,
1746
- visible=False,
1747
- )
1748
- fix_vis_mask256 = gr.Image(
1749
- type="numpy",
1750
- # label=f"Visualized {opts.image_size} Inpaint Mask",
1751
- visible=True,
1752
- show_label=False,
1753
- height=opts.image_size,
1754
- width=opts.image_size,
1755
- interactive=False,
1756
- )
1757
- gr.Markdown(
1758
- """<p style="text-align: center;">[NOTE] Above should be inpaint mask that you brushed, NOT the segmentation mask of the entire hand. </p>"""
1759
- )
1760
- with gr.Column():
1761
- # gr.Markdown(
1762
- # """<p style="text-align: center; font-size: 25px; font-weight: bold; ">4. Results</p>"""
1763
- # )
1764
- gr.Markdown(
1765
- """<p style="text-align: center; font-size: 20px; font-weight: bold;">4. Press &quot;Run&quot; to get the fixed hand image 🎯</p>"""
1766
- )
1767
- fix_run = gr.Button(value="Run", interactive=False)
1768
- gr.Markdown(
1769
- """<p style="text-align: center;">⚠️ >3min and ~24GB per generation</p>"""
1770
- )
1771
- fix_result = gr.Gallery(
1772
- type="numpy",
1773
- label="Results",
1774
- show_label=True,
1775
- height=LENGTH,
1776
- min_width=LENGTH,
1777
- columns=FIX_MAX_N,
1778
- interactive=False,
1779
- preview=True,
1780
- )
1781
- fix_result_pose = gr.Gallery(
1782
- type="numpy",
1783
- label="Results Pose",
1784
- show_label=True,
1785
- height=LENGTH,
1786
- min_width=LENGTH,
1787
- columns=FIX_MAX_N,
1788
- interactive=False,
1789
- preview=True,
1790
- )
1791
- gr.Markdown(
1792
- """<p style="text-align: center;">✨ Hit &quot;Clear&quot; to restart from the beginning</p>"""
1793
- )
1794
- fix_clear = gr.ClearButton()
1795
-
1796
- gr.Markdown(
1797
- """<p style="text-align: left; font-size: 25px;"><b>More options</b></p>"""
1798
- )
1799
- gr.Markdown(
1800
- "⚠️ Currently, Number of generation > 1 could lead to out-of-memory"
1801
- )
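- # Sampling options shared by all generations in this tab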
1802
- with gr.Row():
1803
- fix_n_generation = gr.Slider(
1804
- label="Number of generations",
1805
- value=1,
1806
- minimum=1,
1807
- maximum=FIX_MAX_N,
1808
- step=1,
1809
- randomize=False,
1810
- interactive=True,
1811
- )
1812
- fix_seed = gr.Slider(
1813
- label="Seed",
1814
- value=42,
1815
- minimum=0,
1816
- maximum=10000,
1817
- step=1,
1818
- randomize=False,
1819
- interactive=True,
1820
- )
1821
- fix_cfg = gr.Slider(
1822
- label="Classifier free guidance scale",
1823
- value=3.0,
1824
- minimum=0.0,
1825
- maximum=10.0,
1826
- step=0.1,
1827
- randomize=False,
1828
- interactive=True,
1829
- )
1830
- fix_quality = gr.Slider(
1831
- label="Quality",
1832
- value=10,
1833
- minimum=1,
1834
- maximum=10,
1835
- step=1,
1836
- randomize=False,
1837
- interactive=True,
1838
- )
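- # Event wiring for the "Fix Hands" tab: each step enables the next (crop -> brush -> finish -> keypoints / "Ready" -> "Run")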
1839
- fix_crop.change(enable_component, [fix_crop, fix_crop], fix_ref)
1840
- fix_crop.change(resize_to_full, fix_crop, fix_ref)
1841
- fix_ref.change(enable_component, [fix_ref, fix_ref], fix_finish_crop)
1842
- fix_finish_crop.click(get_mask_inpaint, [fix_ref], [fix_inpaint_mask])
1843
- # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_right])
1844
- # fix_finish_crop.click(lambda x: x["background"], [fix_ref], [fix_kp_left])
1845
- fix_finish_crop.click(lambda x: x["background"], [fix_crop], [fix_original])
1846
- fix_finish_crop.click(visualize_ref, [fix_crop, fix_ref], [fix_img])
1847
- fix_img.change(lambda x: x, [fix_img], [fix_kp_right])
1848
- fix_img.change(lambda x: x, [fix_img], [fix_kp_left])
1849
- fix_inpaint_mask.change(
1850
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_checkbox
1851
- )
1852
- fix_inpaint_mask.change(
1853
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_right
1854
- )
1855
- fix_inpaint_mask.change(
1856
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_right
1857
- )
1858
- fix_inpaint_mask.change(
1859
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_right
1860
- )
1861
- fix_inpaint_mask.change(
1862
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_kp_left
1863
- )
1864
- fix_inpaint_mask.change(
1865
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_undo_left
1866
- )
1867
- fix_inpaint_mask.change(
1868
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_reset_left
1869
- )
1870
- fix_inpaint_mask.change(
1871
- enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_ready
1872
- )
1873
- # fix_inpaint_mask.change(
1874
- # enable_component, [fix_inpaint_mask, fix_inpaint_mask], fix_run
1875
- # )
1876
- fix_checkbox.select(
1877
- set_visible,
1878
- [fix_checkbox, fix_kpts, fix_img, fix_kp_right, fix_kp_left],
1879
- [
1880
- fix_kpts,
1881
- fix_kp_right,
1882
- fix_kp_left,
1883
- fix_kp_right,
1884
- fix_undo_right,
1885
- fix_reset_right,
1886
- fix_kp_left,
1887
- fix_undo_left,
1888
- fix_reset_left,
1889
- fix_kp_r_info,
1890
- fix_kp_l_info,
1891
- ],
1892
- )
1893
- fix_kp_right.select(
1894
- get_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1895
- )
1896
- fix_undo_right.click(
1897
- undo_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1898
- )
1899
- fix_reset_right.click(
1900
- reset_kps, [fix_img, fix_kpts, gr.State("right")], [fix_kp_right, fix_kpts]
1901
- )
1902
- fix_kp_left.select(
1903
- get_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1904
- )
1905
- fix_undo_left.click(
1906
- undo_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1907
- )
1908
- fix_reset_left.click(
1909
- reset_kps, [fix_img, fix_kpts, gr.State("left")], [fix_kp_left, fix_kpts]
1910
- )
1911
- # fix_kpts.change(check_keypoints, [fix_kpts], [fix_kp_right, fix_kp_left, fix_run])
1912
- # fix_run.click(lambda x:gr.update(value=None), [], [fix_result, fix_result_pose])
1913
- fix_vis_mask32.change(
1914
- enable_component, [fix_vis_mask32, fix_vis_mask256], fix_run
1915
- )
1916
- # fix_vis_mask32.change(
1917
- # enable_component, [fix_vis_mask32, fix_vis_mask256], fix_mask_size
1918
- # )
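- # "Ready": pre-process the brushed image and clicked keypoints into conditions, latents, and the visualized inpaint masks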
1919
- fix_ready.click(
1920
- ready_sample,
1921
- [fix_original, fix_inpaint_mask, fix_kpts],
1922
- [
1923
- fix_ref_cond,
1924
- fix_target_cond,
1925
- fix_latent,
1926
- fix_inpaint_latent,
1927
- fix_kpts_np,
1928
- fix_vis_mask32,
1929
- fix_vis_mask256,
1930
- ],
1931
- )
1932
- # fix_mask_size.select(
1933
- # switch_mask_size, [fix_mask_size], [fix_vis_mask32, fix_vis_mask256]
1934
- # )
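- # "Run": sample the diffusion model with inpainting to produce the fixed hand images and their pose visualizations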
1935
- fix_run.click(
1936
- sample_inpaint,
1937
- [
1938
- fix_ref_cond,
1939
- fix_target_cond,
1940
- fix_latent,
1941
- fix_inpaint_latent,
1942
- fix_kpts_np,
1943
- fix_n_generation,
1944
- fix_seed,
1945
- fix_cfg,
1946
- fix_quality,
1947
- ],
1948
- [fix_result, fix_result_pose],
1949
- )
1950
- fix_clear.click(
1951
- fix_clear_all,
1952
- [],
1953
- [
1954
- fix_crop,
1955
- fix_ref,
1956
- fix_kp_right,
1957
- fix_kp_left,
1958
- fix_result,
1959
- fix_result_pose,
1960
- fix_inpaint_mask,
1961
- fix_original,
1962
- fix_img,
1963
- fix_vis_mask32,
1964
- fix_vis_mask256,
1965
- fix_kpts,
1966
- fix_kpts_np,
1967
- fix_ref_cond,
1968
- fix_target_cond,
1969
- fix_latent,
1970
- fix_inpaint_latent,
1971
- fix_n_generation,
1972
- # fix_size_memory,
1973
- fix_seed,
1974
- fix_cfg,
1975
- fix_quality,
1976
- ],
1977
- )
1978
-
1979
- # gr.Markdown("""<p style="font-size: 25px; font-weight: bold;">Examples</p>""")
1980
- # fix_dump_ex = gr.Image(value=None, label="Original Image", visible=False)
1981
- # fix_dump_ex_masked = gr.Image(value=None, label="After Brushing", visible=False)
1982
- # with gr.Column():
1983
- # fix_example = gr.Examples(
1984
- # fix_example_imgs,
1985
- # # run_on_click=True,
1986
- # # fn=parse_fix_example,
1987
- # # inputs=[fix_dump_ex, fix_dump_ex_masked],
1988
- # # outputs=[fix_original, fix_ref, fix_img, fix_inpaint_mask],
1989
- # inputs=[fix_crop],
1990
- # examples_per_page=20,
1991
- # )
1992
-
1993
- gr.Markdown("<h1>Citation</h1>")
1994
- gr.Markdown(
1995
- """<p style="text-align: left;">If this was useful, please cite us! ❤️</p>"""
1996
- )
1997
- gr.Markdown(_CITE_)
1998
-
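- # Launch the Gradio app with a public share link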
1999
- print("Ready to launch..")
2000
- _, _, shared_url = demo.queue().launch(
2001
- share=True, server_name="0.0.0.0", server_port=7739
2002
- )
2003
- # demo.launch(share=True)