Update app.py

app.py
CHANGED
@@ -15,11 +15,10 @@ from huggingface_hub import snapshot_download, login, HfFileSystem
 from flask_cors import CORS
 import numpy as np
 import trimesh
-from transformers import pipeline, AutoImageProcessor, AutoModelForDepthEstimation
+from transformers import pipeline
 from scipy.ndimage import gaussian_filter
 from scipy import interpolate
 import cv2
-from rembg import remove
 
 app = Flask(__name__)
 CORS(app)
@@ -45,8 +44,6 @@ processing_jobs = {}
 
 # Model variables
 dpt_estimator = None
-depth_anything_model = None
-depth_anything_processor = None
 model_loaded = False
 model_loading = False
 
@@ -89,21 +86,36 @@ def allowed_file(filename):
 
 def remove_background(image_path):
     try:
-
-
-
-        img = Image.open(io.BytesIO(result)).convert("RGBA")
+        # Load image
+        img = cv2.imread(image_path)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
-        #
-
-
-
+        # Initialize mask and models for GrabCut
+        mask = np.zeros(img.shape[:2], np.uint8)
+        bgd_model = np.zeros((1, 65), np.float64)
+        fgd_model = np.zeros((1, 65), np.float64)
+
+        # Define initial rectangle (10% border margin)
+        h, w = img.shape[:2]
+        margin = int(min(w, h) * 0.1)
+        rect = (margin, margin, w - 2 * margin, h - 2 * margin)
+
+        # Run GrabCut
+        cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
+
+        # Create final mask (0 for background, 1 for foreground)
+        mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
+
+        # Check if foreground exists
+        if np.sum(mask2) == 0:
+            print(f"Warning: No foreground detected in {image_path}")
             return None
 
-        #
-
-
-
+        # Apply mask and set background to black
+        img = img * mask2[:, :, np.newaxis]
+        img_pil = Image.fromarray(img).convert("RGB")
+
+        return img_pil
     except Exception as e:
         print(f"Error in remove_background for {image_path}: {str(e)}")
         raise
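The rewritten remove_background trades rembg's learned matting for classical GrabCut, dropping a heavyweight dependency. A minimal standalone sketch of the same technique, assuming a placeholder input path and a subject that sits inside the 10% border margin:

import cv2
import numpy as np

# "input.jpg" is a placeholder path, not a file from this repo.
img = cv2.imread("input.jpg")                       # BGR uint8
mask = np.zeros(img.shape[:2], np.uint8)            # per-pixel GrabCut labels
bgd_model = np.zeros((1, 65), np.float64)           # internal GMM state
fgd_model = np.zeros((1, 65), np.float64)

h, w = img.shape[:2]
m = int(min(w, h) * 0.1)                            # assume a centered subject
cv2.grabCut(img, mask, (m, m, w - 2 * m, h - 2 * m),
            bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

# GC_BGD (0) and GC_PR_BGD (2) count as background; the rest is foreground.
fg = np.where((mask == cv2.GC_BGD) | (mask == cv2.GC_PR_BGD), 0, 1).astype("uint8")
cv2.imwrite("cutout.png", img * fg[:, :, None])     # background zeroed to black

The rectangle seed is the main trade-off: subjects touching the image border fall outside the assumed margin and can be cut into, a case rembg handled.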
@@ -111,7 +123,7 @@ def remove_background(image_path):
 def preprocess_image(image_path):
     img = remove_background(image_path)
     if img is None:
-        raise ValueError("
+        raise ValueError("No foreground detected in image")
 
     if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
         if img.width > img.height:
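The context lines above lead into an aspect-preserving downscale whose body falls outside this hunk. The usual pattern looks roughly like the sketch below; MAX_DIMENSION's value here is a placeholder, since the real constant is defined earlier in app.py:

from PIL import Image

MAX_DIMENSION = 512  # placeholder value, not taken from this diff

def clamp_size(img: Image.Image) -> Image.Image:
    # Downscale so the longer side is MAX_DIMENSION, keeping aspect ratio.
    if img.width <= MAX_DIMENSION and img.height <= MAX_DIMENSION:
        return img
    if img.width > img.height:
        size = (MAX_DIMENSION, round(img.height * MAX_DIMENSION / img.width))
    else:
        size = (round(img.width * MAX_DIMENSION / img.height), MAX_DIMENSION)
    return img.resize(size, Image.LANCZOS)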
@@ -135,15 +147,15 @@ def preprocess_image(image_path):
     return img
 
 def load_models():
-    global dpt_estimator, depth_anything_model, depth_anything_processor, model_loaded, model_loading
+    global dpt_estimator, model_loaded, model_loading
 
     if model_loaded:
-        return dpt_estimator, depth_anything_model, depth_anything_processor
+        return dpt_estimator
 
     if model_loading:
         while model_loading and not model_loaded:
             time.sleep(0.5)
-        return dpt_estimator, depth_anything_model, depth_anything_processor
+        return dpt_estimator
 
     try:
         model_loading = True
@@ -155,7 +167,8 @@ def load_models():
             login(token=hf_token)
             print("Authenticated with Hugging Face token")
         else:
-            print("
+            print("Error: HF_TOKEN not found in environment. Intel/dpt-large requires authentication.")
+            raise ValueError("HF_TOKEN is required for Intel/dpt-large")
 
         dpt_model_name = "Intel/dpt-large"
         fs = HfFileSystem(token=hf_token)
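A missing token is now a hard failure rather than a logged warning, so the Space fails at startup instead of mid-request. Stripped of the surrounding model-download logic, the check reduces to this pattern (a sketch under that assumption):

import os
from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)           # authenticates the whole process
    print("Authenticated with Hugging Face token")
else:
    raise ValueError("HF_TOKEN is required for Intel/dpt-large")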
@@ -195,49 +208,8 @@ def load_models():
         print("DPT-Large loaded")
         gc.collect()
 
-        da_model_name = "LiheYoung/depth-anything-v2-small"
-        da_model_cached = os.path.exists(os.path.join(CACHE_DIR, "hub", "models--LiheYoung--depth-anything-v2-small"))
-
-        if not da_model_cached:
-            for attempt in range(max_retries):
-                try:
-                    print(f"Attempting to download {da_model_name}, attempt {attempt+1}")
-                    snapshot_download(
-                        repo_id=da_model_name,
-                        cache_dir=CACHE_DIR,
-                        resume_download=True,
-                        token=hf_token
-                    )
-                    print(f"Successfully downloaded {da_model_name}")
-                    break
-                except Exception as e:
-                    if attempt < max_retries - 1:
-                        print(f"Depth Anything download attempt {attempt+1} failed: {str(e)}. Retrying after {retry_delay}s...")
-                        time.sleep(retry_delay)
-                        retry_delay *= 2
-                    else:
-                        print(f"Failed to load Depth Anything: {str(e)}. Falling back to DPT-Large only.")
-                        depth_anything_model = None
-                        depth_anything_processor = None
-                        model_loaded = True
-                        return dpt_estimator, None, None
-        else:
-            print(f"{da_model_name} already cached in {CACHE_DIR}")
-
-        depth_anything_processor = AutoImageProcessor.from_pretrained(
-            da_model_name,
-            cache_dir=CACHE_DIR,
-            token=hf_token
-        )
-        depth_anything_model = AutoModelForDepthEstimation.from_pretrained(
-            da_model_name,
-            cache_dir=CACHE_DIR,
-            token=hf_token
-        ).to("cpu")
-
         model_loaded = True
-
-        return dpt_estimator, depth_anything_model, depth_anything_processor
+        return dpt_estimator
 
     except Exception as e:
         print(f"Error loading models: {str(e)}")
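With the Depth Anything branch gone, load_models() manages a single estimator and returns it directly instead of a three-way tuple. The load path reduces to one transformers pipeline call; a sketch, with retry and cache-path kwargs omitted since they sit outside this hunk:

import os
from transformers import pipeline

# "Intel/dpt-large" matches dpt_model_name above; token comes from HF_TOKEN.
dpt_estimator = pipeline(
    "depth-estimation",
    model="Intel/dpt-large",
    token=os.environ.get("HF_TOKEN"),
    device=-1,  # CPU, matching the health check's "device": "cpu"
)

Calling dpt_estimator(image) returns a dict whose "depth" entry is a PIL image, which is what the dpt_result["depth"] access in the /convert hunk below relies on.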
@@ -246,38 +218,6 @@ def load_models():
     finally:
         model_loading = False
 
-def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
-    if isinstance(dpt_depth, Image.Image):
-        dpt_depth = np.array(dpt_depth)
-    if isinstance(da_depth, torch.Tensor):
-        da_depth = da_depth.cpu().numpy()
-    if len(dpt_depth.shape) > 2:
-        dpt_depth = np.mean(dpt_depth, axis=2)
-    if len(da_depth.shape) > 2:
-        da_depth = np.mean(da_depth, axis=2)
-
-    if dpt_depth.shape != da_depth.shape:
-        da_depth = cv2.resize(da_depth, (dpt_depth.shape[1], dpt_depth.shape[0]), interpolation=cv2.INTER_CUBIC)
-
-    p_low_dpt, p_high_dpt = np.percentile(dpt_depth, [1, 99])
-    p_low_da, p_high_da = np.percentile(da_depth, [1, 99])
-    dpt_depth = np.clip((dpt_depth - p_low_dpt) / (p_high_dpt - p_low_dpt), 0, 1) if p_high_dpt > p_low_dpt else dpt_depth
-    da_depth = np.clip((da_depth - p_low_da) / (p_high_da - p_low_da), 0, 1) if p_high_da > p_low_da else da_depth
-
-    if detail_level == 'high':
-        weight_da = 0.7
-        edges = cv2.Canny((da_depth * 255).astype(np.uint8), 50, 150)
-        edge_mask = (edges > 0).astype(np.float32)
-        dpt_weight = gaussian_filter(1 - edge_mask, sigma=1.0)
-        da_weight = gaussian_filter(edge_mask, sigma=1.0)
-        fused_depth = dpt_weight * dpt_depth + da_weight * da_depth * weight_da + (1 - weight_da) * dpt_depth
-    else:
-        weight_da = 0.5 if detail_level == 'medium' else 0.3
-        fused_depth = (1 - weight_da) * dpt_depth + weight_da * da_depth
-
-    fused_depth = np.clip(fused_depth, 0, 1)
-    return fused_depth
-
 def enhance_depth_map(depth_map, detail_level='medium'):
     enhanced_depth = depth_map.copy().astype(np.float32)
     p_low, p_high = np.percentile(enhanced_depth, [1, 99])
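fuse_depth_maps is removed outright, but its percentile normalization survives both in enhance_depth_map here and in the /convert route below. As a standalone helper, the shared pattern is (a sketch, not code from app.py):

import numpy as np

def percentile_stretch(depth: np.ndarray) -> np.ndarray:
    # Clip to the 1st/99th percentiles and rescale to [0, 1] so a few
    # outlier depth values don't flatten the useful range.
    d = depth.astype(np.float32)
    p_low, p_high = np.percentile(d, [1, 99])
    if p_high <= p_low:          # constant map: nothing to stretch
        return d
    return np.clip((d - p_low) / (p_high - p_low), 0.0, 1.0)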
@@ -420,7 +360,7 @@ def combine_meshes(meshes):
 def health_check():
     return jsonify({
         "status": "healthy",
-        "model": "DPT-Large
+        "model": "DPT-Large (Multi-View)",
         "device": "cpu"
     }), 200
 
@@ -527,7 +467,7 @@ def convert_image_to_3d():
         processing_jobs[job_id]['progress'] = 10
 
         try:
-            dpt_model, depth_anything_model, depth_anything_processor = load_models()
+            dpt_model = load_models()
             processing_jobs[job_id]['progress'] = 20
         except Exception as e:
             processing_jobs[job_id]['status'] = 'error'
@@ -543,26 +483,13 @@ def convert_image_to_3d():
             dpt_result = dpt_model(image)
             dpt_depth = dpt_result["depth"]
 
-            if
-
-
-
-
-                da_depth = torch.nn.functional.interpolate(
-                    da_depth.unsqueeze(0).unsqueeze(0),
-                    size=(image.height, image.width),
-                    mode='bicubic',
-                    align_corners=False
-                ).squeeze()
-                fused_depth = fuse_depth_maps(dpt_depth, da_depth, detail_level)
-            else:
-                fused_depth = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
-                if len(fused_depth.shape) > 2:
-                    fused_depth = np.mean(fused_depth, axis=2)
-                p_low, p_high = np.percentile(fused_depth, [1, 99])
-                fused_depth = np.clip((fused_depth - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else fused_depth
+            depth_map = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
+            if len(depth_map.shape) > 2:
+                depth_map = np.mean(depth_map, axis=2)
+            p_low, p_high = np.percentile(depth_map, [1, 99])
+            depth_map = np.clip((depth_map - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else depth_map
 
-            mesh = depth_to_mesh(
+            mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution, detail_level=detail_level, view_angle=view_angles[view])
             meshes.append(mesh)
             gc.collect()
 
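The replacement block reduces the per-view step to: run DPT, coerce the PIL depth output to a 2-D float array, percentile-normalize, then hand the result to depth_to_mesh. Gathered into one function for clarity (names mirror the diff; depth_to_mesh itself is defined elsewhere in app.py):

import numpy as np
from PIL import Image

def view_depth(dpt_model, image: Image.Image) -> np.ndarray:
    # dpt_model is the pipeline from load_models(); "depth" is a PIL image.
    dpt_depth = dpt_model(image)["depth"]
    depth_map = np.array(dpt_depth, dtype=np.float32)
    if depth_map.ndim > 2:                      # collapse any channel axis
        depth_map = depth_map.mean(axis=2)
    p_low, p_high = np.percentile(depth_map, [1, 99])
    if p_high > p_low:
        depth_map = np.clip((depth_map - p_low) / (p_high - p_low), 0, 1)
    return depth_map                            # input for depth_to_mesh(...)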
@@ -748,7 +675,7 @@ def model_info(job_id):
 @app.route('/', methods=['GET'])
 def index():
     return jsonify({
-        "message": "Multi-View Image to 3D API (DPT-Large
+        "message": "Multi-View Image to 3D API (DPT-Large)",
         "endpoints": [
             "/convert",
             "/progress/<job_id>",
@@ -766,7 +693,7 @@ def index():
             "detail_level": "low, medium, or high",
             "texture_quality": "low, medium, or high"
         },
-        "description": "Creates
+        "description": "Creates 3D models from multiple 2D images using Intel DPT-Large with custom background removal."
     }), 200
 
 if __name__ == '__main__':