mac9087 committed
Commit 542f872 · verified · 1 Parent(s): d5ed7cc

Update app.py

Files changed (1)
  1. app.py +255 -89
app.py CHANGED
@@ -101,74 +101,114 @@ job_results = {}
 generation_thread = None
 is_thread_running = False
 
+# New global variables for optimizations
+last_usage_time = None
+active_jobs = 0
+max_concurrent_jobs = 1 # Limit concurrent jobs for 2vCPU
+
+def get_adaptive_parameters():
+    """Adjust parameters based on current system resources"""
+    mem = psutil.virtual_memory()
+
+    # Base parameters
+    params = {
+        'karras_steps': 8,
+        'batch_size': 1,
+        'guidance_scale': 15.0
+    }
+
+    # If memory is tight, reduce steps further
+    if mem.percent > 70:
+        params['karras_steps'] = 6
+
+    # If we have more memory to spare, can be slightly more generous
+    if mem.percent < 50:
+        params['karras_steps'] = 10
+
+    print(f"Adaptive parameters chosen: karras_steps={params['karras_steps']}, mem={mem.percent}%")
+    return params
+
+def check_memory_pressure():
+    """Check if memory is getting too high and take action if needed"""
+    mem = psutil.virtual_memory()
+    if mem.percent > 85: # Critical threshold
+        print("WARNING: Memory pressure critical. Forcing garbage collection.")
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
+
+        # If still critical, try more aggressive measures
+        if psutil.virtual_memory().percent > 80:
+            print("EMERGENCY: Memory still critical. Clearing model cache.")
+            # Reset global models to force reload when memory is better
+            global xm, model, diffusion
+            xm, model, diffusion = None, None, None
+            gc.collect()
+        return True
+    return False
+
+def load_transmitter_model():
+    global xm, last_usage_time
+    last_usage_time = time.time()
+
+    if xm is None:
+        print("Loading transmitter model...")
+        xm = load_model('transmitter', device=device)
+        print("Transmitter model loaded!")
+
+def load_primary_model():
+    global model, diffusion, last_usage_time
+    last_usage_time = time.time()
+
+    if model is None or diffusion is None:
+        print("Loading primary models...")
+        torch.set_default_dtype(torch.float32) # Use float32 instead of float64
+        model = load_model('text300M', device=device)
+        diffusion = diffusion_from_config(load_config('diffusion'))
+        print("Primary models loaded!")
+
 def load_models_if_needed():
-    global xm, model, diffusion
-    if xm is None or model is None or diffusion is None:
-        print("Loading models for the first time...")
-        try:
-            # Set lower precision for memory optimization
-            torch.set_default_dtype(torch.float32) # Use float32 instead of float64
-            xm = load_model('transmitter', device=device)
-            model = load_model('text300M', device=device)
-            diffusion = diffusion_from_config(load_config('diffusion'))
-            print("Models loaded successfully!")
-        except Exception as e:
-            print(f"Error loading models: {e}")
-            raise
+    """Legacy function for compatibility"""
+    load_primary_model()
+    load_transmitter_model()
 
-def worker_thread():
-    global is_thread_running
-    is_thread_running = True
-
-    try:
-        while True:
-            try:
-                # Get job from queue with a timeout
-                job_id, prompt = job_queue.get(timeout=1)
-                print(f"Processing job {job_id} with prompt: {prompt}")
-
-                # Process the job
-                result = process_job(job_id, prompt)
-
-                # Store the result
-                job_results[job_id] = result
-
-            except queue.Empty:
-                # No jobs in queue, continue waiting
-                pass
-            except Exception as e:
-                print(f"Error in worker thread: {e}")
-                import traceback
-                traceback.print_exc()
-                # If there was a job being processed, mark it as failed
-                if 'job_id' in locals():
-                    job_results[job_id] = {
-                        "success": False,
-                        "error": str(e)
-                    }
-    finally:
-        is_thread_running = False
+def model_unloader_thread():
+    """Thread that periodically unloads models if they haven't been used"""
+    global xm, model, diffusion, last_usage_time
+
+    while True:
+        time.sleep(300) # Check every 5 minutes
+
+        if last_usage_time is not None:
+            idle_time = time.time() - last_usage_time
+
+            # If models have been idle for more than 10 minutes and no active jobs
+            if idle_time > 600 and active_jobs == 0:
+                # Check memory usage
+                mem = psutil.virtual_memory()
+                if mem.percent > 50: # Only unload if memory usage is significant
+                    print(f"Models idle for {idle_time:.1f} seconds and memory at {mem.percent}%. Unloading...")
+                    xm, model, diffusion = None, None, None
+                    gc.collect()
+                    torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 def process_job(job_id, prompt):
     try:
-        # Load models if not already loaded
-        load_models_if_needed()
+        # Get adaptive parameters
+        adaptive_params = get_adaptive_parameters()
+        karras_steps = adaptive_params['karras_steps']
+        batch_size = adaptive_params['batch_size']
+        guidance_scale = adaptive_params['guidance_scale']
 
-        # Set parameters for CPU performance (reduced steps and other optimizations)
-        batch_size = 1
-        guidance_scale = 15.0
+        # Load primary models for generation
+        load_primary_model()
 
-        # *** EXTREME OPTIMIZATION: Significantly reduce steps for low-memory environments ***
-        karras_steps = 8 # Reduced from 16 to 8 for even better performance
-
-        # *** OPTIMIZATION: Run garbage collection before starting intensive task ***
+        # Optimization: Run garbage collection before starting intensive task
         gc.collect()
         torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
-        # Generate latents with the text-to-3D model
         print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
 
-        # Force lower precision
+        # Generate latents
         with torch.inference_mode():
             latents = sample_latents(
                 batch_size=batch_size,
@@ -187,22 +227,51 @@ def process_job(job_id, prompt):
             )
         print(f"Latent generation complete for job {job_id}!")
 
-        # *** OPTIMIZATION: Run garbage collection after intensive step ***
-        gc.collect()
+        # Optimization: Clear unnecessary memory and load next model
+        check_memory_pressure()
 
         # Generate a unique filename
         unique_id = str(uuid.uuid4())
         filename = f"{output_dir}/{unique_id}"
 
-        # Convert latent to mesh with optimization settings
+        # Load transmitter model for decoding
+        load_transmitter_model()
+
+        # Convert latent to mesh
         print(f"Decoding mesh for job {job_id}...")
         t0 = time.time()
 
-        # *** OPTIMIZATION: Use simplified decoding with lower resolution ***
-        mesh = decode_latent_mesh(xm, latents[0], max_points=4000).tri_mesh() # Reduced point count
+        # Monitor memory
+        mem_before = psutil.Process().memory_info().rss / (1024 * 1024)
+        print(f"Memory before mesh decoding: {mem_before:.2f} MB")
+
+        # Decode the mesh (fixed: removed 'max_points' parameter from original code)
+        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
+
         print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
+        mem_after = psutil.Process().memory_info().rss / (1024 * 1024)
+        print(f"Memory after decoding: {mem_after:.2f} MB (delta: {mem_after - mem_before:.2f} MB)")
 
-        # *** OPTIMIZATION: Clear latents from memory as they're no longer needed ***
+        # Report mesh complexity if possible
+        try:
+            print(f"Mesh complexity: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
+        except:
+            print("Could not determine mesh complexity")
+
+        # Simplify mesh if it's too complex (if supported)
+        try:
+            if hasattr(mesh, 'simplify') and hasattr(mesh, 'faces') and len(mesh.faces) > 5000:
+                target_faces = min(5000, int(len(mesh.faces) * 0.6))
+                print(f"Simplifying mesh to target {target_faces} faces...")
+                t0 = time.time()
+                simplified = mesh.simplify_quadratic_decimation(target_faces)
+                mesh = simplified
+                print(f"Mesh simplified in {time.time() - t0:.2f} seconds")
+                print(f"New complexity: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")
+        except Exception as e:
+            print(f"Mesh simplification not available or failed: {e}")
+
+        # Clear latents from memory
         del latents
         gc.collect()
 
@@ -217,13 +286,12 @@ def process_job(job_id, prompt):
         with open(obj_path, 'w') as f:
             mesh.write_obj(f)
 
-        # *** OPTIMIZATION: Clear mesh from memory ***
+        # Clear mesh from memory
         del mesh
         gc.collect()
 
         print(f"Files saved successfully for job {job_id}!")
 
-        # Return paths to the generated files
         return {
             "success": True,
             "message": "3D model generated successfully",
@@ -240,6 +308,41 @@ def process_job(job_id, prompt):
             "error": str(e)
         }
 
+def worker_thread():
+    global is_thread_running, active_jobs
+    is_thread_running = True
+
+    try:
+        while True:
+            try:
+                # Get job from queue with a timeout
+                job_id, prompt = job_queue.get(timeout=1)
+                print(f"Processing job {job_id} with prompt: {prompt}")
+
+                # Process the job
+                result = process_job(job_id, prompt)
+
+                # Store the result and update counter
+                job_results[job_id] = result
+                active_jobs -= 1
+
+            except queue.Empty:
+                # No jobs in queue, continue waiting
+                pass
+            except Exception as e:
+                print(f"Error in worker thread: {e}")
+                import traceback
+                traceback.print_exc()
+                # If there was a job being processed, mark it as failed
+                if 'job_id' in locals():
+                    job_results[job_id] = {
+                        "success": False,
+                        "error": str(e)
+                    }
+                active_jobs -= 1
+    finally:
+        is_thread_running = False
+
 def ensure_worker_thread_running():
     global generation_thread, is_thread_running
 
@@ -248,8 +351,21 @@ def ensure_worker_thread_running():
         generation_thread = threading.Thread(target=worker_thread, daemon=True)
         generation_thread.start()
 
+def start_model_unloader():
+    threading.Thread(target=model_unloader_thread, daemon=True).start()
+
 @app.route('/generate', methods=['POST'])
 def generate_3d():
+    global active_jobs
+
+    # Check if we're already at max capacity
+    if active_jobs >= max_concurrent_jobs:
+        return jsonify({
+            "success": False,
+            "error": "Server is at maximum capacity. Please try again later.",
+            "retry_after": 300
+        }), 503
+
     # Get the prompt from the request
     data = request.json
     if not data or 'prompt' not in data:
@@ -264,6 +380,7 @@ def generate_3d():
     # Add job to queue
     ensure_worker_thread_running()
     job_queue.put((job_id, prompt))
+    active_jobs += 1
 
     # Return job ID immediately
     return jsonify({
@@ -278,10 +395,7 @@ def job_status(job_id):
     if job_id in job_results:
         result = job_results[job_id]
         # Clean up memory if the job is complete and successful
-        if result.get("success", False):
-            return jsonify(result)
-        else:
-            return jsonify({"error": result.get("error", "Unknown error")}), 500
+        return jsonify(result)
     else:
         # Job is still in progress
         return jsonify({
@@ -299,35 +413,53 @@ def download_file(filename):
 
 @app.route('/health', methods=['GET'])
 def health_check():
-    """Simple health check endpoint to verify the app is running"""
-    # Check available memory
+    """Enhanced health check endpoint to monitor resource usage"""
     try:
+        # Memory info
         memory_info = psutil.virtual_memory()
         memory_usage = f"{memory_info.percent}% (Available: {memory_info.available / (1024**3):.2f} GB)"
 
-        # Check CPU usage
+        # CPU info
         cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"
 
-        # Get queue status
+        # Process specific info
+        process = psutil.Process()
+        process_memory = f"{process.memory_info().rss / (1024**3):.2f} GB"
+
+        # Models status
+        models_loaded = []
+        if model is not None:
+            models_loaded.append("text300M")
+        if diffusion is not None:
+            models_loaded.append("diffusion")
+        if xm is not None:
+            models_loaded.append("transmitter")
+
+        # Queue status
         queue_size = job_queue.qsize()
 
-        # Get active jobs
-        active_jobs = len(job_results)
+        # Check for model inactivity
+        model_inactive = "N/A"
+        if last_usage_time is not None:
+            model_inactive = f"{(time.time() - last_usage_time) / 60:.1f} minutes"
+
+        return jsonify({
+            "status": "ok",
+            "message": "Service is running",
+            "memory_usage": memory_usage,
+            "process_memory": process_memory,
+            "cpu_usage": cpu_usage,
+            "queue_size": queue_size,
+            "active_jobs": active_jobs,
+            "worker_running": is_thread_running,
+            "models_loaded": models_loaded,
+            "model_inactive_time": model_inactive
+        })
     except Exception as e:
-        memory_usage = "Error getting system info"
-        cpu_usage = "Error getting CPU info"
-        queue_size = "Unknown"
-        active_jobs = "Unknown"
-
-    return jsonify({
-        "status": "ok",
-        "message": "Service is running",
-        "memory_usage": memory_usage,
-        "cpu_usage": cpu_usage,
-        "queue_size": queue_size,
-        "active_jobs": active_jobs,
-        "worker_running": is_thread_running
-    })
+        return jsonify({
+            "status": "warning",
+            "error": str(e)
+        })
 
 @app.route('/', methods=['GET'])
 def home():
@@ -399,10 +531,44 @@ GET /status/123e4567-e89b-12d3-a456-426614174000
     </html>
     """
 
+@app.route('/purge-results', methods=['POST'])
+def purge_old_results():
+    """Endpoint to manually purge old job results to free memory"""
+    try:
+        # Get the time threshold from request (default to 1 hour)
+        threshold_hours = request.json.get('threshold_hours', 1) if request.json else 1
+        threshold_time = time.time() - (threshold_hours * 3600)
+
+        # Track jobs to be removed
+        jobs_to_remove = []
+        for job_id, result in job_results.items():
+            # If the job has a timestamp and it's older than threshold
+            if result.get('timestamp', time.time()) < threshold_time:
+                jobs_to_remove.append(job_id)
+
+        # Remove the old jobs
+        for job_id in jobs_to_remove:
+            job_results.pop(job_id, None)
+
+        # Force garbage collection
+        gc.collect()
+
+        return jsonify({
+            "success": True,
+            "message": f"Purged {len(jobs_to_remove)} old job results",
+            "remaining_jobs": len(job_results)
+        })
+    except Exception as e:
+        return jsonify({
+            "success": False,
+            "error": str(e)
+        }), 500
+
 if __name__ == '__main__':
-    # Start the worker thread
+    # Start the worker thread and model unloader
     ensure_worker_thread_running()
+    start_model_unloader()
 
     # Recommended to run with gunicorn for production with increased timeout:
     # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
-    app.run(host='0.0.0.0', port=7860, debug=False) # Set debug=False in production
+    app.run(host='0.0.0.0', port=7860, debug=False) # Set debug=False in production
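
For reference, a minimal client sketch against the endpoints defined in this version of app.py, assuming the service is reachable at http://localhost:7860 as configured in app.run above; the "job_id" key in the /generate response is an assumption and may need adjusting to the actual response shape:

import time
import requests

BASE = "http://localhost:7860"

# Health check: reports memory/CPU usage, queue size, loaded models, and worker status.
print(requests.get(f"{BASE}/health").json())

# Queue a generation job; the server responds with 503 once active_jobs >= max_concurrent_jobs.
resp = requests.post(f"{BASE}/generate", json={"prompt": "a small wooden chair"})
print(resp.status_code, resp.json())

# Poll the job status until a result or an error appears.
job_id = resp.json().get("job_id")  # assumed key name; check the actual /generate response
while job_id:
    status = requests.get(f"{BASE}/status/{job_id}").json()
    print(status)
    if status.get("success") or status.get("error"):
        break
    time.sleep(10)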