Spaces:
Saad0KH
/
Running on Zero

Saad0KH commited on
Commit
bb5afb0
ยท
verified ยท
1 Parent(s): a24960d

Update utils_mask.py

Browse files
Files changed (1) hide show
  1. utils_mask.py +23 -3
utils_mask.py CHANGED
@@ -51,7 +51,7 @@ def refine_mask(mask):
51
 
52
  return refine_mask
53
 
54
- def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384,height=512):
55
  im_parse = model_parse.resize((width, height), Image.NEAREST)
56
  parse_array = np.array(im_parse)
57
 
@@ -60,7 +60,7 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
60
  elif model_type == 'dc':
61
  arm_width = 45
62
  else:
63
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
64
 
65
  parse_head = (parse_array == 1).astype(np.float32) + \
66
  (parse_array == 3).astype(np.float32) + \
@@ -82,7 +82,6 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
82
  (parse_array == 4).astype(np.float32) + \
83
  (parse_array == 5).astype(np.float32) + \
84
  (parse_array == 6).astype(np.float32)
85
-
86
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
87
 
88
  elif category == 'upper_body':
@@ -91,6 +90,7 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
91
  (parse_array == label_map["pants"]).astype(np.float32)
92
  parser_mask_fixed += parser_mask_fixed_lower_cloth
93
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
 
94
  elif category == 'lower_body':
95
  parse_mask = (parse_array == 6).astype(np.float32) + \
96
  (parse_array == 12).astype(np.float32) + \
@@ -100,9 +100,29 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
100
  (parse_array == 14).astype(np.float32) + \
101
  (parse_array == 15).astype(np.float32)
102
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  else:
104
  raise NotImplementedError
105
 
 
 
 
106
  # Load pose points
107
  pose_data = keypoint["pose_keypoints_2d"]
108
  pose_data = np.array(pose_data)
 
51
 
52
  return refine_mask
53
 
54
+ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384, height=512):
55
  im_parse = model_parse.resize((width, height), Image.NEAREST)
56
  parse_array = np.array(im_parse)
57
 
 
60
  elif model_type == 'dc':
61
  arm_width = 45
62
  else:
63
+ raise ValueError("model_type must be 'hd' or 'dc'!")
64
 
65
  parse_head = (parse_array == 1).astype(np.float32) + \
66
  (parse_array == 3).astype(np.float32) + \
 
82
  (parse_array == 4).astype(np.float32) + \
83
  (parse_array == 5).astype(np.float32) + \
84
  (parse_array == 6).astype(np.float32)
 
85
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
86
 
87
  elif category == 'upper_body':
 
90
  (parse_array == label_map["pants"]).astype(np.float32)
91
  parser_mask_fixed += parser_mask_fixed_lower_cloth
92
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
93
+
94
  elif category == 'lower_body':
95
  parse_mask = (parse_array == 6).astype(np.float32) + \
96
  (parse_array == 12).astype(np.float32) + \
 
100
  (parse_array == 14).astype(np.float32) + \
101
  (parse_array == 15).astype(np.float32)
102
  parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
103
+
104
+ elif category == 'full_body':
105
+ # Combine lower_body and upper_body
106
+ parse_mask_upper = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
107
+ parse_mask_lower = (parse_array == 6).astype(np.float32) + \
108
+ (parse_array == 12).astype(np.float32) + \
109
+ (parse_array == 13).astype(np.float32) + \
110
+ (parse_array == 5).astype(np.float32)
111
+ parse_mask = parse_mask_upper + parse_mask_lower
112
+
113
+ parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
114
+ (parse_array == label_map["skirt"]).astype(np.float32) + \
115
+ (parse_array == label_map["pants"]).astype(np.float32) + \
116
+ (parse_array == 14).astype(np.float32) + \
117
+ (parse_array == 15).astype(np.float32)
118
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
119
+
120
  else:
121
  raise NotImplementedError
122
 
123
+ # Rest of the function logic remains the same...
124
+
125
+
126
  # Load pose points
127
  pose_data = keypoint["pose_keypoints_2d"]
128
  pose_data = np.array(pose_data)