Update app.py

app.py CHANGED
@@ -8,9 +8,9 @@ from transformers import DPTImageProcessor, DPTForDepthEstimation
 import warnings
 warnings.filterwarnings("ignore")
 
-# Load segmentation model
-seg_processor = AutoImageProcessor.from_pretrained("
-seg_model = AutoModelForSemanticSegmentation.from_pretrained("
+# Load segmentation model - using SegFormer which is compatible with AutoModelForSemanticSegmentation
+seg_processor = AutoImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
+seg_model = AutoModelForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
 
 # Load depth estimation model
 depth_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
@@ -99,22 +99,44 @@ def apply_depth_blur(image, depth_map, max_sigma=25):
     return result
 
 def get_segmentation_mask(image_pil):
-    """Get segmentation mask for person
+    """Get segmentation mask for person/foreground from an image."""
+    # Resize the image to the size expected by the segmentation model
+    width, height = image_pil.size
+    image_pil_resized = image_pil.resize((512, 512))
+
     # Process the image with the segmentation model
-    inputs = seg_processor(images=
-
+    inputs = seg_processor(images=image_pil_resized, return_tensors="pt")
+    with torch.no_grad():
+        outputs = seg_model(**inputs)
 
     # Get the predicted segmentation mask
-
+    logits = outputs.logits
+    upsampled_logits = torch.nn.functional.interpolate(
+        logits,
+        size=(512, 512),
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    # Get the predicted segmentation mask
+    predicted_mask = upsampled_logits.argmax(dim=1)[0]
 
     # Convert the mask to a numpy array
     mask_np = predicted_mask.cpu().numpy()
 
-    #
-
-
+    # Create a foreground mask - considering classes that are likely to be foreground
+    # The ADE20K dataset has 150 classes, so we need to choose which ones to consider as foreground
+    # Common foreground classes: person (12), animal classes, and objects like furniture
+    # This is a simplified approach - you may need to adjust based on your needs
+    foreground_classes = [12, 13, 14, 15, 16, 17, 18, 19, 20]  # Person and some objects
+    foreground_mask = np.zeros_like(mask_np)
+    for cls in foreground_classes:
+        foreground_mask[mask_np == cls] = 1
+
+    # Resize back to original image size
+    foreground_mask = cv2.resize(foreground_mask, (width, height))
 
-    return
+    return foreground_mask
 
 def get_depth_map(image_pil):
     """Get depth map from an image."""
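The hard-coded foreground_classes indices refer to the ADE20K label set that this SegFormer checkpoint was fine-tuned on, and the commit's own comment flags them as a simplified choice that may need adjusting. One way to tune the list is to read the id-to-label map out of the checkpoint's config; a minimal sketch, assuming the same checkpoint the commit loads:

from transformers import AutoConfig

# Same SegFormer checkpoint that app.py now loads
config = AutoConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Print every ADE20K class index with its label so foreground_classes can be
# adjusted; index 12 should print as "person", matching the comment in the diff.
for idx, label in sorted(config.id2label.items(), key=lambda kv: int(kv[0])):
    print(idx, label)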
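A quick way to sanity-check the new function end to end is sketched below. This is a hypothetical smoke test, not part of the commit: the filename is a placeholder, and it assumes the definitions from app.py above are in scope along with its imports (torch, numpy, cv2, PIL). One caveat worth knowing: argmax(dim=1) typically yields an int64 tensor, and cv2.resize rejects 64-bit integer arrays, so the final resize may fail until mask_np is cast to a supported dtype such as np.uint8.

import numpy as np
from PIL import Image

# "example.jpg" is a placeholder; any RGB photo works
img = Image.open("example.jpg").convert("RGB")

mask = get_segmentation_mask(img)

print(mask.shape)        # (height, width), matching the original image
print(np.unique(mask))   # subset of [0 1]; 1 marks foreground pixels

# Save the mask for visual inspection
Image.fromarray((mask * 255).astype(np.uint8)).save("mask.png")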