mac9087 committed (verified)
Commit: e51a639
Parent(s): 05ec4e5

Update app.py

Files changed (1): app.py (+206 −52)

app.py CHANGED
@@ -110,20 +110,20 @@ def get_adaptive_parameters():
     """Adjust parameters based on current system resources"""
     mem = psutil.virtual_memory()
 
-    # Base parameters
+    # Base parameters - more conservative to prevent memory issues
     params = {
-        'karras_steps': 8,
+        'karras_steps': 6,  # Reduced from 8 to 6 as default
         'batch_size': 1,
         'guidance_scale': 15.0
     }
 
     # If memory is tight, reduce steps further
     if mem.percent > 70:
-        params['karras_steps'] = 6
+        params['karras_steps'] = 4  # Even more conservative
 
     # If we have more memory to spare, can be slightly more generous
     if mem.percent < 50:
-        params['karras_steps'] = 10
+        params['karras_steps'] = 8
 
     print(f"Adaptive parameters chosen: karras_steps={params['karras_steps']}, mem={mem.percent}%")
     return params
@@ -131,13 +131,13 @@ def get_adaptive_parameters():
 def check_memory_pressure():
     """Check if memory is getting too high and take action if needed"""
     mem = psutil.virtual_memory()
-    if mem.percent > 85:  # Critical threshold
+    if mem.percent > 80:  # Reduced threshold from 85 to 80
         print("WARNING: Memory pressure critical. Forcing garbage collection.")
         gc.collect()
         torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
         # If still critical, try more aggressive measures
-        if psutil.virtual_memory().percent > 80:
+        if psutil.virtual_memory().percent > 75:
             print("EMERGENCY: Memory still critical. Clearing model cache.")
             # Reset global models to force reload when memory is better
             global xm, model, diffusion
@@ -176,21 +176,87 @@ def model_unloader_thread():
     global xm, model, diffusion, last_usage_time
 
     while True:
-        time.sleep(300)  # Check every 5 minutes
+        time.sleep(180)  # Check more frequently: every 3 minutes instead of 5
 
         if last_usage_time is not None:
             idle_time = time.time() - last_usage_time
 
-            # If models have been idle for more than 10 minutes and no active jobs
-            if idle_time > 600 and active_jobs == 0:
-                # Check memory usage
+            # If models have been idle for more than 5 minutes (reduced from 10) and no active jobs
+            if idle_time > 300 and active_jobs == 0:
+                # Check memory usage - more aggressive unloading
                 mem = psutil.virtual_memory()
-                if mem.percent > 50:  # Only unload if memory usage is significant
+                if mem.percent > 40:  # Lowered threshold from 50 to 40
                     print(f"Models idle for {idle_time:.1f} seconds and memory at {mem.percent}%. Unloading...")
                     xm, model, diffusion = None, None, None
                     gc.collect()
                     torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
+def save_trimesh(mesh, filename_base):
+    """Save mesh in multiple formats using trimesh"""
+    # Convert to trimesh format if needed
+    if not isinstance(mesh, trimesh.Trimesh):
+        try:
+            # Try to convert to trimesh
+            vertices = np.array(mesh.vertices)
+            faces = np.array(mesh.faces)
+            trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
+        except Exception as e:
+            print(f"Error converting to trimesh: {e}")
+            raise
+    else:
+        trimesh_obj = mesh
+
+    # Save as GLB
+    glb_path = f"{filename_base}.glb"
+    try:
+        trimesh_obj.export(glb_path, file_type='glb')
+        print(f"Saved GLB file: {glb_path}")
+    except Exception as e:
+        print(f"Error saving GLB: {e}")
+        # Try alternative approach
+        try:
+            scene = trimesh.Scene()
+            scene.add_geometry(trimesh_obj)
+            scene.export(glb_path)
+            print(f"Saved GLB using scene approach: {glb_path}")
+        except Exception as e2:
+            print(f"Alternative GLB export also failed: {e2}")
+            glb_path = None
+
+    # Save as OBJ - always works more reliably
+    obj_path = f"{filename_base}.obj"
+    try:
+        trimesh_obj.export(obj_path, file_type='obj')
+        print(f"Saved OBJ file: {obj_path}")
+    except Exception as e:
+        print(f"Error saving OBJ: {e}")
+        # Try to write directly
+        try:
+            with open(obj_path, 'w') as f:
+                for v in trimesh_obj.vertices:
+                    f.write(f"v {v[0]} {v[1]} {v[2]}\n")
+                for face in trimesh_obj.faces:
+                    f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
+            print(f"Saved OBJ using direct write: {obj_path}")
+        except Exception as e2:
+            print(f"Alternative OBJ export also failed: {e2}")
+            obj_path = None
+
+    # Also save as PLY as a fallback
+    ply_path = f"{filename_base}.ply"
+    try:
+        trimesh_obj.export(ply_path, file_type='ply')
+        print(f"Saved PLY file: {ply_path}")
+    except Exception as e:
+        print(f"Error saving PLY: {e}")
+        ply_path = None
+
+    return {
+        "glb": os.path.basename(glb_path) if glb_path else None,
+        "obj": os.path.basename(obj_path) if obj_path else None,
+        "ply": os.path.basename(ply_path) if ply_path else None
+    }
+
 def process_job(job_id, prompt):
     try:
         # Get adaptive parameters
@@ -209,6 +275,7 @@ def process_job(job_id, prompt):
         print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
 
         # Generate latents
+        latents = None
         with torch.inference_mode():
             latents = sample_latents(
                 batch_size=batch_size,
@@ -227,7 +294,7 @@ def process_job(job_id, prompt):
             )
         print(f"Latent generation complete for job {job_id}!")
 
-        # Optimization: Clear unnecessary memory and load next model
+        # Optimization: Clear unnecessary memory and check pressure
         check_memory_pressure()
 
         # Generate a unique filename
@@ -245,7 +312,7 @@ def process_job(job_id, prompt):
         mem_before = psutil.Process().memory_info().rss / (1024 * 1024)
         print(f"Memory before mesh decoding: {mem_before:.2f} MB")
 
-        # Decode the mesh (fixed: removed 'max_points' parameter from original code)
+        # Decode the mesh
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
 
         print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
@@ -254,50 +321,58 @@ def process_job(job_id, prompt):
 
         # Report mesh complexity if possible
         try:
-            print(f"Mesh complexity: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
-        except:
-            print("Could not determine mesh complexity")
-
-        # Simplify mesh if it's too complex (if supported)
-        try:
-            if hasattr(mesh, 'simplify') and hasattr(mesh, 'faces') and len(mesh.faces) > 5000:
-                target_faces = min(5000, int(len(mesh.faces) * 0.6))
-                print(f"Simplifying mesh to target {target_faces} faces...")
-                t0 = time.time()
-                simplified = mesh.simplify_quadratic_decimation(target_faces)
-                mesh = simplified
-                print(f"Mesh simplified in {time.time() - t0:.2f} seconds")
-                print(f"New complexity: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
+            vertices_count = len(mesh.vertices)
+            faces_count = len(mesh.faces)
+            print(f"Mesh complexity: {vertices_count} vertices, {faces_count} faces")
         except Exception as e:
-            print(f"Mesh simplification not available or failed: {e}")
+            print(f"Could not determine mesh complexity: {e}")
+            vertices_count = 0
+            faces_count = 0
 
         # Clear latents from memory
         del latents
         gc.collect()
 
-        # Save as GLB
-        print(f"Saving job {job_id} as GLB...")
-        glb_path = f"{filename}.glb"
-        mesh.write_glb(glb_path)
-
-        # Save as OBJ
-        print(f"Saving job {job_id} as OBJ...")
-        obj_path = f"{filename}.obj"
-        with open(obj_path, 'w') as f:
-            mesh.write_obj(f)
+        # Convert to trimesh format and save files
+        print(f"Converting and saving mesh for job {job_id}...")
+
+        # Save mesh using the helper function
+        saved_files = save_trimesh(mesh, filename)
 
         # Clear mesh from memory
         del mesh
         gc.collect()
 
-        print(f"Files saved successfully for job {job_id}!")
-
-        return {
+        # Check which files were successfully saved
+        result = {
             "success": True,
             "message": "3D model generated successfully",
-            "glb_url": f"/download/{os.path.basename(glb_path)}",
-            "obj_url": f"/download/{os.path.basename(obj_path)}"
+            "timestamp": time.time(),
+            "stats": {
+                "vertices": vertices_count,
+                "faces": faces_count
+            }
         }
+
+        # Add URLs for the files that were saved
+        if saved_files["glb"]:
+            result["glb_url"] = f"/download/{saved_files['glb']}"
+        if saved_files["obj"]:
+            result["obj_url"] = f"/download/{saved_files['obj']}"
+        if saved_files["ply"]:
+            result["ply_url"] = f"/download/{saved_files['ply']}"
+
+        # If no files were saved, mark as failure
+        if not (saved_files["glb"] or saved_files["obj"] or saved_files["ply"]):
+            result["success"] = False
+            result["message"] = "Failed to save mesh in any format"
+
+        print(f"Files saved successfully for job {job_id}!")
+
+        # Force garbage collection again
+        gc.collect()
+
+        return result
 
     except Exception as e:
         print(f"Error during generation for job {job_id}: {str(e)}")
@@ -305,7 +380,8 @@ def process_job(job_id, prompt):
         traceback.print_exc()
         return {
             "success": False,
-            "error": str(e)
+            "error": str(e),
+            "timestamp": time.time()
         }
 
 def worker_thread():
@@ -326,6 +402,9 @@ def worker_thread():
                 job_results[job_id] = result
                 active_jobs -= 1
 
+                # Explicit cleanup after job
+                gc.collect()
+
             except queue.Empty:
                 # No jobs in queue, continue waiting
                 pass
@@ -337,12 +416,43 @@
         if 'job_id' in locals():
            job_results[job_id] = {
                "success": False,
-                "error": str(e)
+                "error": str(e),
+                "timestamp": time.time()
            }
            active_jobs -= 1
+
+        # Force garbage collection to clean up
+        gc.collect()
     finally:
         is_thread_running = False
 
+def purge_old_results_thread():
+    """Thread that periodically cleans up old job results to manage memory"""
+    while True:
+        try:
+            time.sleep(1800)  # Run every 30 minutes
+
+            # Default threshold: 2 hours
+            threshold_time = time.time() - (2 * 3600)
+
+            # Track jobs to be removed
+            jobs_to_remove = []
+            for job_id, result in job_results.items():
+                # If the job has a timestamp and it's older than threshold
+                if result.get('timestamp', time.time()) < threshold_time:
+                    jobs_to_remove.append(job_id)
+
+            # Remove the old jobs
+            for job_id in jobs_to_remove:
+                job_results.pop(job_id, None)
+
+            if jobs_to_remove:
+                print(f"Auto-purged {len(jobs_to_remove)} old job results")
+                # Force garbage collection
+                gc.collect()
+        except Exception as e:
+            print(f"Error in purge thread: {e}")
+
 def ensure_worker_thread_running():
     global generation_thread, is_thread_running
 
@@ -351,8 +461,13 @@
         generation_thread = threading.Thread(target=worker_thread, daemon=True)
         generation_thread.start()
 
-def start_model_unloader():
+def start_monitoring_threads():
+    """Start all monitoring and maintenance threads"""
+    # Start model unloader thread
     threading.Thread(target=model_unloader_thread, daemon=True).start()
+
+    # Start results purge thread
+    threading.Thread(target=purge_old_results_thread, daemon=True).start()
 
 @app.route('/generate', methods=['POST'])
 def generate_3d():
@@ -394,7 +509,7 @@
 def job_status(job_id):
     if job_id in job_results:
         result = job_results[job_id]
-        # Clean up memory if the job is complete and successful
+        # Return the result
         return jsonify(result)
     else:
         # Job is still in progress
@@ -407,9 +522,13 @@
 @app.route('/download/<filename>', methods=['GET'])
 def download_file(filename):
     try:
-        return send_file(f"{output_dir}/{filename}", as_attachment=True)
+        file_path = os.path.join(output_dir, filename)
+        if not os.path.exists(file_path):
+            return jsonify({"error": "File not found"}), 404
+
+        return send_file(file_path, as_attachment=True)
     except Exception as e:
-        return jsonify({"error": str(e)}), 404
+        return jsonify({"error": str(e)}), 500
 
 @app.route('/health', methods=['GET'])
 def health_check():
@@ -443,6 +562,9 @@
     if last_usage_time is not None:
         model_inactive = f"{(time.time() - last_usage_time) / 60:.1f} minutes"
 
+    # Number of saved jobs
+    saved_jobs = len(job_results)
+
     return jsonify({
         "status": "ok",
         "message": "Service is running",
@@ -451,6 +573,7 @@
         "cpu_usage": cpu_usage,
         "queue_size": queue_size,
         "active_jobs": active_jobs,
+        "saved_jobs": saved_jobs,
         "worker_running": is_thread_running,
         "models_loaded": models_loaded,
         "model_inactive_time": model_inactive
@@ -516,12 +639,13 @@ GET /status/123e4567-e89b-12d3-a456-426614174000
   "success": true,
   "message": "3D model generated successfully",
   "glb_url": "/download/abc123.glb",
-  "obj_url": "/download/abc123.obj"
+  "obj_url": "/download/abc123.obj",
+  "ply_url": "/download/abc123.ply"
 }
 </pre>
 
 <h3>Step 3: Download the files</h3>
-<p>Use the provided URLs to download the GLB and OBJ files.</p>
+<p>Use the provided URLs to download the GLB, OBJ, and PLY files.</p>
 
 <h2>Health Check:</h2>
 <pre>GET /health</pre>
@@ -563,10 +687,40 @@ def purge_old_results():
             "error": str(e)
         }), 500
 
+@app.route('/force-gc', methods=['POST'])
+def force_garbage_collection():
+    """Endpoint to manually trigger garbage collection"""
+    try:
+        # Get current memory usage
+        before_mem = psutil.Process().memory_info().rss / (1024**3)
+
+        # Force garbage collection
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+        # Get memory usage after GC
+        after_mem = psutil.Process().memory_info().rss / (1024**3)
+        freed = before_mem - after_mem
+
+        return jsonify({
+            "success": True,
+            "message": f"Garbage collection completed",
+            "before_memory_gb": round(before_mem, 2),
+            "after_memory_gb": round(after_mem, 2),
+            "freed_memory_gb": round(freed, 2) if freed > 0 else 0
+        })
+    except Exception as e:
+        return jsonify({
+            "success": False,
+            "error": str(e)
+        }), 500
+
 if __name__ == '__main__':
-    # Start the worker thread and model unloader
+    # Start all monitoring threads
+    start_monitoring_threads()
+
+    # Start the worker thread
     ensure_worker_thread_running()
-    start_model_unloader()
 
     # Recommended to run with gunicorn for production with increased timeout:
     # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
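For reference, a minimal client sketch against the endpoints this commit touches (/generate, /status/<job_id>, /download/<filename>). It assumes the service is reachable at localhost:7860 (matching the gunicorn example above); the "prompt" request key, the "job_id" response field, and the polling condition are assumptions, since the /generate request/response schema and the in-progress status payload are not shown in this diff.

import time
import requests

BASE = "http://localhost:7860"  # assumed host/port, matching the gunicorn example above

# Submit a job. The "prompt" key and the "job_id" response field are assumptions;
# the /generate request/response schema is not part of this diff.
job = requests.post(f"{BASE}/generate", json={"prompt": "a small wooden chair"}, timeout=30).json()
job_id = job["job_id"]

# Poll the status endpoint until the worker has stored a final result.
# Finished results carry a "success" flag (see process_job); the in-progress
# payload is not shown in this diff, so the loop simply waits for that flag.
while True:
    status = requests.get(f"{BASE}/status/{job_id}", timeout=30).json()
    if "success" in status:
        break
    time.sleep(5)

# Download whichever formats were saved; glb_url/obj_url/ply_url are only set
# for formats that save_trimesh managed to export.
for key in ("glb_url", "obj_url", "ply_url"):
    if status.get(key):
        name = status[key].rsplit("/", 1)[-1]
        with open(name, "wb") as fh:
            fh.write(requests.get(f"{BASE}{status[key]}", timeout=60).content)
        print("saved", name)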
 
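A quick way to sanity-check the files written by the new save_trimesh() helper is to load them back with trimesh. A sketch, assuming the files downloaded by the previous snippet; "abc123" is the placeholder basename from the docs above, not a real job filename.

import trimesh

# "abc123" is a placeholder; substitute the basename returned in glb_url/obj_url/ply_url.
for path in ("abc123.glb", "abc123.obj", "abc123.ply"):
    try:
        loaded = trimesh.load(path)
        # GLB files often load as a Scene; merge it into a single mesh before inspecting.
        if isinstance(loaded, trimesh.Scene):
            loaded = trimesh.util.concatenate(list(loaded.geometry.values()))
        print(f"{path}: {len(loaded.vertices)} vertices, {len(loaded.faces)} faces")
    except Exception as exc:
        print(f"{path}: could not load ({exc})")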