zerchen committed
Commit 2a7450b · 1 Parent(s): e3f8288

handle multiple hands

Files changed (1): app.py (+96 -71)
app.py CHANGED
@@ -62,28 +62,35 @@ hand_detector = hand_detector.to(device)
 hort_model = hort_model.to(device)
 wilor_model.eval()
 hort_model.eval()
-
 image_transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 
+
+def calculate_iou(box1, box2):
+    x1_inter = max(box1[0], box2[0])
+    y1_inter = max(box1[1], box2[1])
+    x2_inter = min(box1[2], box2[2])
+    y2_inter = min(box1[3], box2[3])
+    # Compute intersection area
+    inter_width = max(0, x2_inter - x1_inter)
+    inter_height = max(0, y2_inter - y1_inter)
+    intersection = inter_width * inter_height
+    # Compute areas of each box
+    area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    # Compute union
+    union = area_box1 + area_box2 - intersection
+    # Compute IoU
+    return intersection / union if union > 0 else 0.0
+
+
 @spaces.GPU()
 def run_model(image, conf, IoU_threshold=0.5):
     img_cv2 = image[..., ::-1]
     img_pil = Image.fromarray(image)
 
     pred_obj = sam_model.predict([img_pil], ["manipulated object"])
-    pred_hand = sam_model.predict([img_pil], ["hand"])
-
     bbox_obj = pred_obj[0]["boxes"][0].reshape((-1, 2))
-    mask_obj = pred_obj[0]["masks"][0]
-    bbox_hand = pred_hand[0]["boxes"][0].reshape((-1, 2))
-    mask_hand = pred_hand[0]["masks"][0]
 
-    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
-    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
-    box_size = br - tl
-    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
-    ho_bbox = process_bbox(bbox)
-
     detections = hand_detector(img_cv2, conf=conf, verbose=False, iou=IoU_threshold)[0]
 
     bboxes = []
@@ -92,60 +99,81 @@ def run_model(image, conf, IoU_threshold=0.5):
         Bbox = det.boxes.data.cpu().detach().squeeze().numpy()
         is_right.append(det.boxes.cls.cpu().detach().squeeze().item())
        bboxes.append(Bbox[:4].tolist())
+
+    if len(bboxes) == 0:
+        print("no hands in this image")
+    elif len(bboxes) == 1:
+        bbox_hand = np.array(bboxes[0]).reshape((-1, 2))
+    elif len(bboxes) > 1:
+        hand_idx = None
+        max_iou = -10.
+        for cur_idx, cur_bbox in enumerate(bboxes):
+            cur_iou = calculate_iou(cur_bbox, bbox_obj.reshape(-1).tolist())
+            if cur_iou >= max_iou:
+                hand_idx = cur_idx
+                max_iou = cur_iou
+        bbox_hand = np.array(bboxes[hand_idx]).reshape((-1, 2))
+        bboxes = [bboxes[hand_idx]]
+        is_right = [is_right[hand_idx]]
+
+    tl = np.min(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
+    br = np.max(np.concatenate([bbox_obj, bbox_hand], axis=0), axis=0)
+    box_size = br - tl
+    bbox = np.concatenate([tl - 10, box_size + 20], axis=0)
+    ho_bbox = process_bbox(bbox)
 
-    if len(bboxes) == 1:
-        boxes = np.stack(bboxes)
-        right = np.stack(is_right)
-        if not right:
-            new_x1 = img_cv2.shape[1] - boxes[0][2]
-            new_x2 = img_cv2.shape[1] - boxes[0][0]
-            boxes[0][0] = new_x1
-            boxes[0][2] = new_x2
-            ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
-            img_cv2 = cv2.flip(img_cv2, 1)
-            right[0] = 1.
-        crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
-
-        dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
-        dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
+    boxes = np.stack(bboxes)
+    right = np.stack(is_right)
+    if not right:
+        new_x1 = img_cv2.shape[1] - boxes[0][2]
+        new_x2 = img_cv2.shape[1] - boxes[0][0]
+        boxes[0][0] = new_x1
+        boxes[0][2] = new_x2
+        ho_bbox[0] = img_cv2.shape[1] - (ho_bbox[0] + ho_bbox[2])
+        img_cv2 = cv2.flip(img_cv2, 1)
+        right[0] = 1.
+    crop_img_cv2, _ = generate_patch_image(img_cv2, ho_bbox, (224, 224), 0, 1.0, 0)
+
+    dataset = ViTDetDataset(wilor_model_cfg, img_cv2, boxes, right, rescale_factor=2.0)
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0)
 
-        for batch in dataloader:
-            batch = recursive_to(batch, device)
+    for batch in dataloader:
+        batch = recursive_to(batch, device)
 
-            with torch.no_grad():
-                out = wilor_model(batch)
-
-            pred_cam = out['pred_cam']
-            box_center = batch["box_center"].float()
-            box_size = batch["box_size"].float()
-            img_size = batch["img_size"].float()
-            scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
-            pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()
+        with torch.no_grad():
+            out = wilor_model(batch)
+
+        pred_cam = out['pred_cam']
+        box_center = batch["box_center"].float()
+        box_size = batch["box_size"].float()
+        img_size = batch["img_size"].float()
+        scaled_focal_length = wilor_model_cfg.EXTRA.FOCAL_LENGTH / wilor_model_cfg.MODEL.IMAGE_SIZE * 224
+        pred_cam_t_full = cam_crop_to_new(pred_cam, box_center, box_size, img_size, torch.from_numpy(np.array(ho_bbox, dtype=np.float32))[None, :].to(img_size.device), scaled_focal_length).detach().cpu().numpy()
+
+        batch_size = batch['img'].shape[0]
+        for n in range(batch_size):
+            verts = out['pred_vertices'][n].detach().cpu().numpy()
+            joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
+
+            is_right = batch['right'][n].cpu().numpy()
+            palm = (verts[95] + verts[22]) / 2
+            cam_t = pred_cam_t_full[n]
+
+            img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
+            camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
+            cam_intr = camera.intrinsics
+
+            metas = dict()
+            metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
+            metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
+            metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
+            metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
+            with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
+                pc_results = hort_model(img_input, metas)
+            objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
+            pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3
 
-            batch_size = batch['img'].shape[0]
-            for n in range(batch_size):
-                verts = out['pred_vertices'][n].detach().cpu().numpy()
-                joints = out['pred_keypoints_3d'][n].detach().cpu().numpy()
-
-                is_right = batch['right'][n].cpu().numpy()
-                palm = (verts[95] + verts[22]) / 2
-                cam_t = pred_cam_t_full[n]
-
-                img_input = image_transform(crop_img_cv2[:, :, ::-1]).unsqueeze(0).cuda()
-                camera = PerspectiveCamera(5000 / 256 * 224, 5000 / 256 * 224, 112, 112)
-                cam_intr = camera.intrinsics
-
-                metas = dict()
-                metas["right_hand_verts_3d"] = torch.from_numpy((verts + cam_t)[None]).cuda()
-                metas["right_hand_joints_3d"] = torch.from_numpy((joints + cam_t)[None]).cuda()
-                metas["right_hand_palm"] = torch.from_numpy((palm + cam_t)[None]).cuda()
-                metas["cam_intr"] = torch.from_numpy(cam_intr[None]).cuda()
-                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
-                    pc_results = hort_model(img_input, metas)
-                objtrans = pc_results["objtrans"][0].detach().cpu().numpy()
-                pointclouds_up = pc_results["pointclouds_up"][0].detach().cpu().numpy() * 0.3
-
-                reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}
+            reconstructions = {'verts': verts, 'palm': palm, 'objtrans': objtrans, 'objpcs': pointclouds_up, 'cam_t': cam_t, 'right': is_right, 'img_size': 224, 'focal': scaled_focal_length}
 
         return crop_img_cv2[..., ::-1].astype(np.float32) / 255.0, len(detections), reconstructions
     else:
@@ -154,18 +182,15 @@ def run_model(image, conf, IoU_threshold=0.5):
 
 def render_reconstruction(image, conf, IoU_threshold=0.3):
     input_img, num_dets, reconstructions = run_model(image, conf, IoU_threshold=0.5)
-    if num_dets == 1:
     # Render front view
-        misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
-        cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)
+    misc_args = dict(mesh_base_color=LIGHT_PURPLE, point_base_color=STEEL_BLUE, scene_bg_color=(1, 1, 1), focal_length=reconstructions['focal'])
+    cam_view = renderer.render_rgba(reconstructions['verts'], reconstructions['objpcs'] + reconstructions['palm'] + reconstructions['objtrans'], cam_t=reconstructions['cam_t'], render_res=(224, 224), is_right=True, **misc_args)
 
-        # Overlay image
-        input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
-        input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
+    # Overlay image
+    input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel
+    input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:]
 
-        return input_img_overlay, f'{num_dets} hands detected'
-    else:
-        return input_img, f'{num_dets} hands detected'
+    return input_img_overlay, f'{num_dets} hands detected'
 
 
 header = ('''
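
For reference, a minimal standalone sketch (not part of app.py) of the idea behind this commit: when several hands are detected, run_model keeps the hand whose box overlaps the SAM object box the most, using the calculate_iou helper added above. The pixel boxes below are hypothetical, purely for illustration.

# Hypothetical [x1, y1, x2, y2] boxes; assumes calculate_iou from the diff above is in scope.
obj_box = [50, 40, 150, 160]                              # object box from sam_model
hand_boxes = [[140, 60, 220, 150], [300, 60, 380, 150]]   # two detected hand boxes

# Keep the hand with the highest IoU against the object box,
# the same idea as the selection loop added to run_model.
best_idx = max(range(len(hand_boxes)), key=lambda i: calculate_iou(hand_boxes[i], obj_box))
print(best_idx)  # -> 0: the hand that actually overlaps the object is kept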