mac9087 committed on
Commit
27f2066
·
verified ·
1 Parent(s): d9a9271

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -118
app.py CHANGED
@@ -15,11 +15,10 @@ from huggingface_hub import snapshot_download, login, HfFileSystem
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
- from transformers import pipeline, AutoImageProcessor, AutoModelForDepthEstimation
19
  from scipy.ndimage import gaussian_filter
20
  from scipy import interpolate
21
  import cv2
22
- from rembg import remove
23
 
24
  app = Flask(__name__)
25
  CORS(app)
@@ -45,8 +44,6 @@ processing_jobs = {}
45
 
46
  # Model variables
47
  dpt_estimator = None
48
- depth_anything_model = None
49
- depth_anything_processor = None
50
  model_loaded = False
51
  model_loading = False
52
 
@@ -89,21 +86,36 @@ def allowed_file(filename):
89
 
90
  def remove_background(image_path):
91
  try:
92
- with open(image_path, "rb") as img_file:
93
- img_data = img_file.read()
94
- result = remove(img_data)
95
- img = Image.open(io.BytesIO(result)).convert("RGBA")
96
 
97
- # Check if image is fully transparent
98
- img_array = np.array(img)
99
- if np.all(img_array[:, :, 3] == 0):
100
- print(f"Warning: Image {image_path} is fully transparent or no object detected")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  return None
102
 
103
- # Create black background
104
- black_bg = Image.new("RGB", img.size, (0, 0, 0))
105
- black_bg.paste(img, (0, 0), img)
106
- return black_bg
 
107
  except Exception as e:
108
  print(f"Error in remove_background for {image_path}: {str(e)}")
109
  raise
@@ -111,7 +123,7 @@ def remove_background(image_path):
111
  def preprocess_image(image_path):
112
  img = remove_background(image_path)
113
  if img is None:
114
- raise ValueError("Image is fully transparent or no object detected")
115
 
116
  if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
117
  if img.width > img.height:
@@ -135,15 +147,15 @@ def preprocess_image(image_path):
135
  return img
136
 
137
  def load_models():
138
- global dpt_estimator, depth_anything_model, depth_anything_processor, model_loaded, model_loading
139
 
140
  if model_loaded:
141
- return dpt_estimator, depth_anything_model, depth_anything_processor
142
 
143
  if model_loading:
144
  while model_loading and not model_loaded:
145
  time.sleep(0.5)
146
- return dpt_estimator, depth_anything_model, depth_anything_processor
147
 
148
  try:
149
  model_loading = True
@@ -155,7 +167,8 @@ def load_models():
155
  login(token=hf_token)
156
  print("Authenticated with Hugging Face token")
157
  else:
158
- print("Warning: HF_TOKEN not found in environment")
 
159
 
160
  dpt_model_name = "Intel/dpt-large"
161
  fs = HfFileSystem(token=hf_token)
@@ -195,49 +208,8 @@ def load_models():
195
  print("DPT-Large loaded")
196
  gc.collect()
197
 
198
- da_model_name = "LiheYoung/depth-anything-v2-small"
199
- da_model_cached = os.path.exists(os.path.join(CACHE_DIR, "hub", "models--LiheYoung--depth-anything-v2-small"))
200
-
201
- if not da_model_cached:
202
- for attempt in range(max_retries):
203
- try:
204
- print(f"Attempting to download {da_model_name}, attempt {attempt+1}")
205
- snapshot_download(
206
- repo_id=da_model_name,
207
- cache_dir=CACHE_DIR,
208
- resume_download=True,
209
- token=hf_token
210
- )
211
- print(f"Successfully downloaded {da_model_name}")
212
- break
213
- except Exception as e:
214
- if attempt < max_retries - 1:
215
- print(f"Depth Anything download attempt {attempt+1} failed: {str(e)}. Retrying after {retry_delay}s...")
216
- time.sleep(retry_delay)
217
- retry_delay *= 2
218
- else:
219
- print(f"Failed to load Depth Anything: {str(e)}. Falling back to DPT-Large only.")
220
- depth_anything_model = None
221
- depth_anything_processor = None
222
- model_loaded = True
223
- return dpt_estimator, None, None
224
- else:
225
- print(f"{da_model_name} already cached in {CACHE_DIR}")
226
-
227
- depth_anything_processor = AutoImageProcessor.from_pretrained(
228
- da_model_name,
229
- cache_dir=CACHE_DIR,
230
- token=hf_token
231
- )
232
- depth_anything_model = AutoModelForDepthEstimation.from_pretrained(
233
- da_model_name,
234
- cache_dir=CACHE_DIR,
235
- token=hf_token
236
- ).to("cpu")
237
-
238
  model_loaded = True
239
- print("Depth Anything loaded")
240
- return dpt_estimator, depth_anything_model, depth_anything_processor
241
 
242
  except Exception as e:
243
  print(f"Error loading models: {str(e)}")
@@ -246,38 +218,6 @@ def load_models():
246
  finally:
247
  model_loading = False
248
 
249
- def fuse_depth_maps(dpt_depth, da_depth, detail_level='medium'):
250
- if isinstance(dpt_depth, Image.Image):
251
- dpt_depth = np.array(dpt_depth)
252
- if isinstance(da_depth, torch.Tensor):
253
- da_depth = da_depth.cpu().numpy()
254
- if len(dpt_depth.shape) > 2:
255
- dpt_depth = np.mean(dpt_depth, axis=2)
256
- if len(da_depth.shape) > 2:
257
- da_depth = np.mean(da_depth, axis=2)
258
-
259
- if dpt_depth.shape != da_depth.shape:
260
- da_depth = cv2.resize(da_depth, (dpt_depth.shape[1], dpt_depth.shape[0]), interpolation=cv2.INTER_CUBIC)
261
-
262
- p_low_dpt, p_high_dpt = np.percentile(dpt_depth, [1, 99])
263
- p_low_da, p_high_da = np.percentile(da_depth, [1, 99])
264
- dpt_depth = np.clip((dpt_depth - p_low_dpt) / (p_high_dpt - p_low_dpt), 0, 1) if p_high_dpt > p_low_dpt else dpt_depth
265
- da_depth = np.clip((da_depth - p_low_da) / (p_high_da - p_low_da), 0, 1) if p_high_da > p_low_da else da_depth
266
-
267
- if detail_level == 'high':
268
- weight_da = 0.7
269
- edges = cv2.Canny((da_depth * 255).astype(np.uint8), 50, 150)
270
- edge_mask = (edges > 0).astype(np.float32)
271
- dpt_weight = gaussian_filter(1 - edge_mask, sigma=1.0)
272
- da_weight = gaussian_filter(edge_mask, sigma=1.0)
273
- fused_depth = dpt_weight * dpt_depth + da_weight * da_depth * weight_da + (1 - weight_da) * dpt_depth
274
- else:
275
- weight_da = 0.5 if detail_level == 'medium' else 0.3
276
- fused_depth = (1 - weight_da) * dpt_depth + weight_da * da_depth
277
-
278
- fused_depth = np.clip(fused_depth, 0, 1)
279
- return fused_depth
280
-
281
  def enhance_depth_map(depth_map, detail_level='medium'):
282
  enhanced_depth = depth_map.copy().astype(np.float32)
283
  p_low, p_high = np.percentile(enhanced_depth, [1, 99])
@@ -420,7 +360,7 @@ def combine_meshes(meshes):
420
  def health_check():
421
  return jsonify({
422
  "status": "healthy",
423
- "model": "DPT-Large + Depth Anything (Multi-View)",
424
  "device": "cpu"
425
  }), 200
426
 
@@ -527,7 +467,7 @@ def convert_image_to_3d():
527
  processing_jobs[job_id]['progress'] = 10
528
 
529
  try:
530
- dpt_model, da_model, da_processor = load_models()
531
  processing_jobs[job_id]['progress'] = 20
532
  except Exception as e:
533
  processing_jobs[job_id]['status'] = 'error'
@@ -543,26 +483,13 @@ def convert_image_to_3d():
543
  dpt_result = dpt_model(image)
544
  dpt_depth = dpt_result["depth"]
545
 
546
- if da_model and da_processor:
547
- inputs = da_processor(images=image, return_tensors="pt")
548
- inputs = {k: v.to("cpu") for k, v in inputs.items()}
549
- outputs = da_model(**inputs)
550
- da_depth = outputs.predicted_depth.squeeze()
551
- da_depth = torch.nn.functional.interpolate(
552
- da_depth.unsqueeze(0).unsqueeze(0),
553
- size=(image.height, image.width),
554
- mode='bicubic',
555
- align_corners=False
556
- ).squeeze()
557
- fused_depth = fuse_depth_maps(dpt_depth, da_depth, detail_level)
558
- else:
559
- fused_depth = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
560
- if len(fused_depth.shape) > 2:
561
- fused_depth = np.mean(fused_depth, axis=2)
562
- p_low, p_high = np.percentile(fused_depth, [1, 99])
563
- fused_depth = np.clip((fused_depth - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else fused_depth
564
 
565
- mesh = depth_to_mesh(fused_depth, image, resolution=mesh_resolution, detail_level=detail_level, view_angle=view_angles[view])
566
  meshes.append(mesh)
567
  gc.collect()
568
 
@@ -748,7 +675,7 @@ def model_info(job_id):
748
  @app.route('/', methods=['GET'])
749
  def index():
750
  return jsonify({
751
- "message": "Multi-View Image to 3D API (DPT-Large + Depth Anything)",
752
  "endpoints": [
753
  "/convert",
754
  "/progress/<job_id>",
@@ -766,7 +693,7 @@ def index():
766
  "detail_level": "low, medium, or high",
767
  "texture_quality": "low, medium, or high"
768
  },
769
- "description": "Creates high-quality 3D models from multiple 2D images using DPT-Large and Depth Anything."
770
  }), 200
771
 
772
  if __name__ == '__main__':
 
15
  from flask_cors import CORS
16
  import numpy as np
17
  import trimesh
18
+ from transformers import pipeline
19
  from scipy.ndimage import gaussian_filter
20
  from scipy import interpolate
21
  import cv2
 
22
 
23
  app = Flask(__name__)
24
  CORS(app)
 
44
 
45
  # Model variables
46
  dpt_estimator = None
 
 
47
  model_loaded = False
48
  model_loading = False
49
 
 
86
 
87
  def remove_background(image_path):
88
  try:
89
+ # Load image
90
+ img = cv2.imread(image_path)
91
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
92
 
93
+ # Initialize mask and models for GrabCut
94
+ mask = np.zeros(img.shape[:2], np.uint8)
95
+ bgd_model = np.zeros((1, 65), np.float64)
96
+ fgd_model = np.zeros((1, 65), np.float64)
97
+
98
+ # Define initial rectangle (10% border margin)
99
+ h, w = img.shape[:2]
100
+ margin = int(min(w, h) * 0.1)
101
+ rect = (margin, margin, w - 2 * margin, h - 2 * margin)
102
+
103
+ # Run GrabCut
104
+ cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)
105
+
106
+ # Create final mask (0 for background, 1 for foreground)
107
+ mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
108
+
109
+ # Check if foreground exists
110
+ if np.sum(mask2) == 0:
111
+ print(f"Warning: No foreground detected in {image_path}")
112
  return None
113
 
114
+ # Apply mask and set background to black
115
+ img = img * mask2[:, :, np.newaxis]
116
+ img_pil = Image.fromarray(img).convert("RGB")
117
+
118
+ return img_pil
119
  except Exception as e:
120
  print(f"Error in remove_background for {image_path}: {str(e)}")
121
  raise
 
123
  def preprocess_image(image_path):
124
  img = remove_background(image_path)
125
  if img is None:
126
+ raise ValueError("No foreground detected in image")
127
 
128
  if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
129
  if img.width > img.height:
 
147
  return img
148
 
149
  def load_models():
150
+ global dpt_estimator, model_loaded, model_loading
151
 
152
  if model_loaded:
153
+ return dpt_estimator
154
 
155
  if model_loading:
156
  while model_loading and not model_loaded:
157
  time.sleep(0.5)
158
+ return dpt_estimator
159
 
160
  try:
161
  model_loading = True
 
167
  login(token=hf_token)
168
  print("Authenticated with Hugging Face token")
169
  else:
170
+ print("Error: HF_TOKEN not found in environment. Intel/dpt-large requires authentication.")
171
+ raise ValueError("HF_TOKEN is required for Intel/dpt-large")
172
 
173
  dpt_model_name = "Intel/dpt-large"
174
  fs = HfFileSystem(token=hf_token)
 
208
  print("DPT-Large loaded")
209
  gc.collect()
210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  model_loaded = True
212
+ return dpt_estimator
 
213
 
214
  except Exception as e:
215
  print(f"Error loading models: {str(e)}")
 
218
  finally:
219
  model_loading = False
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  def enhance_depth_map(depth_map, detail_level='medium'):
222
  enhanced_depth = depth_map.copy().astype(np.float32)
223
  p_low, p_high = np.percentile(enhanced_depth, [1, 99])
 
360
  def health_check():
361
  return jsonify({
362
  "status": "healthy",
363
+ "model": "DPT-Large (Multi-View)",
364
  "device": "cpu"
365
  }), 200
366
 
 
467
  processing_jobs[job_id]['progress'] = 10
468
 
469
  try:
470
+ dpt_model = load_models()
471
  processing_jobs[job_id]['progress'] = 20
472
  except Exception as e:
473
  processing_jobs[job_id]['status'] = 'error'
 
483
  dpt_result = dpt_model(image)
484
  dpt_depth = dpt_result["depth"]
485
 
486
+ depth_map = np.array(dpt_depth) if isinstance(dpt_depth, Image.Image) else dpt_depth
487
+ if len(depth_map.shape) > 2:
488
+ depth_map = np.mean(depth_map, axis=2)
489
+ p_low, p_high = np.percentile(depth_map, [1, 99])
490
+ depth_map = np.clip((depth_map - p_low) / (p_high - p_low), 0, 1) if p_high > p_low else depth_map
 
 
 
 
 
 
 
 
 
 
 
 
 
491
 
492
+ mesh = depth_to_mesh(depth_map, image, resolution=mesh_resolution, detail_level=detail_level, view_angle=view_angles[view])
493
  meshes.append(mesh)
494
  gc.collect()
495
 
 
675
  @app.route('/', methods=['GET'])
676
  def index():
677
  return jsonify({
678
+ "message": "Multi-View Image to 3D API (DPT-Large)",
679
  "endpoints": [
680
  "/convert",
681
  "/progress/<job_id>",
 
693
  "detail_level": "low, medium, or high",
694
  "texture_quality": "low, medium, or high"
695
  },
696
+ "description": "Creates 3D models from multiple 2D images using Intel DPT-Large with custom background removal."
697
  }), 200
698
 
699
  if __name__ == '__main__':