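"""Flask API that turns text prompts into 3D meshes with OpenAI's Shap-E.

Generation jobs are queued and processed one at a time by a background
worker; the Shap-E models are loaded lazily and unloaded when idle so the
service can run on small CPU-only instances (e.g. Hugging Face Spaces).
"""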
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import torch
import numpy as np
import trimesh
import os
from io import BytesIO
import base64
from PIL import Image
import uuid
import time
import sys
import gc  # For explicit garbage collection
import threading
import queue
import psutil

# Set environment variables before anything else
os.environ['SHAPEE_NO_INTERACTIVE'] = '1'

# Setup cache directory with appropriate permissions
cache_dir = os.path.join(os.getcwd(), 'shap_e_model_cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['XDG_CACHE_HOME'] = os.getcwd()
print(f"Using cache directory: {cache_dir}")

# Import Shap-E
print("Importing Shap-E modules...")
try:
    # Try the direct import approach first
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
    print("Shap-E modules imported successfully!")
except ImportError as e:
    print(f"Error importing Shap-E modules: {e}")
    # Alternative approach if direct import fails
    try:
        print("Attempting alternative import approach...")
        # Try monkey patching the ipywidgets module if that's the issue
        import sys
        import types
        
        if 'ipywidgets' not in sys.modules:
            sys.modules['ipywidgets'] = types.ModuleType('ipywidgets')
            print("Added mock ipywidgets module")
        
        # Try imports again
        from shap_e.diffusion.sample import sample_latents
        from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
        from shap_e.models.download import load_model, load_config
        from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
        print("Shap-E modules imported successfully with workaround!")
    except Exception as e2:
        print(f"Alternative import also failed: {e2}")
        sys.exit(1)
except Exception as e:
    print(f"Unexpected error importing Shap-E modules: {e}")
    sys.exit(1)

app = Flask(__name__)
CORS(app)

# Create output directory if it doesn't exist
output_dir = os.path.join(os.getcwd(), "outputs")
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

# Check permissions on directories
try:
    test_file_path = os.path.join(cache_dir, "test_write_permissions.txt")
    with open(test_file_path, 'w') as f:
        f.write("Testing write permissions")
    os.remove(test_file_path)
    print("Cache directory is writable")
except Exception as e:
    print(f"WARNING: Cache directory is not writable: {e}")

try:
    test_file_path = os.path.join(output_dir, "test_write_permissions.txt")
    with open(test_file_path, 'w') as f:
        f.write("Testing write permissions")
    os.remove(test_file_path)
    print("Output directory is writable")
except Exception as e:
    print(f"WARNING: Output directory is not writable: {e}")

print("Setting up device...")
device = torch.device('cpu')  # Force CPU for Hugging Face Spaces
print(f"Using device: {device}")

# Global variables for models (will be loaded on first request)
xm = None
model = None
diffusion = None

# Job queue and results dictionary
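# Jobs are enqueued by /generate, processed one at a time by the worker
# thread, and results are kept in job_results (keyed by job_id) until
# purged by purge_old_results_thread or the /purge-results endpoint.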
job_queue = queue.Queue()
job_results = {}
generation_thread = None
is_thread_running = False

# New global variables for optimizations
last_usage_time = None
active_jobs = 0
max_concurrent_jobs = 1  # Limit concurrent jobs for a 2 vCPU instance

def get_adaptive_parameters():
    """Adjust parameters based on current system resources"""
    mem = psutil.virtual_memory()
    
    # Base parameters - conservative defaults to limit memory use on CPU
    params = {
        'karras_steps': 6,
        'batch_size': 1,
        'guidance_scale': 15.0
    }
    
    # If memory is tight, reduce steps further
    if mem.percent > 70:
        params['karras_steps'] = 4  # Even more conservative
    
    # If we have more memory to spare, can be slightly more generous
    if mem.percent < 50:
        params['karras_steps'] = 8
    
    print(f"Adaptive parameters chosen: karras_steps={params['karras_steps']}, mem={mem.percent}%")
    return params

def check_memory_pressure():
    """Check if memory is getting too high and take action if needed"""
    mem = psutil.virtual_memory()
    if mem.percent > 80:  # Critical memory threshold
        print("WARNING: Memory pressure critical. Forcing garbage collection.")
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # If still critical, try more aggressive measures
        if psutil.virtual_memory().percent > 75:
            print("EMERGENCY: Memory still critical. Clearing model cache.")
            # Reset global models to force reload when memory is better
            global xm, model, diffusion
            xm, model, diffusion = None, None, None
            gc.collect()
            return True
    return False

def load_transmitter_model():
    global xm, last_usage_time
    last_usage_time = time.time()
    
    if xm is None:
        print("Loading transmitter model...")
        xm = load_model('transmitter', device=device)
        print("Transmitter model loaded!")

def load_primary_model():
    global model, diffusion, last_usage_time
    last_usage_time = time.time()
    
    if model is None or diffusion is None:
        print("Loading primary models...")
        torch.set_default_dtype(torch.float32)  # Use float32 instead of float64
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))
        print("Primary models loaded!")

def load_models_if_needed():
    """Legacy function for compatibility"""
    load_primary_model()
    load_transmitter_model()

def model_unloader_thread():
    """Thread that periodically unloads models if they haven't been used"""
    global xm, model, diffusion, last_usage_time
    
    while True:
        time.sleep(180)  # Check every 3 minutes
        
        if last_usage_time is not None:
            idle_time = time.time() - last_usage_time
            
            # If models have been idle for more than 5 minutes and no jobs are active
            if idle_time > 300 and active_jobs == 0:
                # Only unload when memory usage is high enough to be worth reclaiming
                mem = psutil.virtual_memory()
                if mem.percent > 40:
                    print(f"Models idle for {idle_time:.1f} seconds and memory at {mem.percent}%. Unloading...")
                    xm, model, diffusion = None, None, None
                    gc.collect()
                    torch.cuda.empty_cache() if torch.cuda.is_available() else None

def save_trimesh(mesh, filename_base):
    """Save mesh in multiple formats using trimesh"""
    # Convert to trimesh format if needed
    if not isinstance(mesh, trimesh.Trimesh):
        try:
            # Try to convert to trimesh. Shap-E's TriMesh typically exposes its
            # vertex array as `verts` rather than `vertices`, so accept either.
            vertices = np.array(mesh.verts if hasattr(mesh, 'verts') else mesh.vertices)
            faces = np.array(mesh.faces)
            trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
        except Exception as e:
            print(f"Error converting to trimesh: {e}")
            raise
    else:
        trimesh_obj = mesh
    
    # Save as GLB
    glb_path = f"{filename_base}.glb"
    try:
        trimesh_obj.export(glb_path, file_type='glb')
        print(f"Saved GLB file: {glb_path}")
    except Exception as e:
        print(f"Error saving GLB: {e}")
        # Try alternative approach
        try:
            scene = trimesh.Scene()
            scene.add_geometry(trimesh_obj)
            scene.export(glb_path)
            print(f"Saved GLB using scene approach: {glb_path}")
        except Exception as e2:
            print(f"Alternative GLB export also failed: {e2}")
            glb_path = None
    
    # Save as OBJ - usually the most reliable export
    obj_path = f"{filename_base}.obj"
    try:
        trimesh_obj.export(obj_path, file_type='obj')
        print(f"Saved OBJ file: {obj_path}")
    except Exception as e:
        print(f"Error saving OBJ: {e}")
        # Try to write directly
        try:
            with open(obj_path, 'w') as f:
                for v in trimesh_obj.vertices:
                    f.write(f"v {v[0]} {v[1]} {v[2]}\n")
                for face in trimesh_obj.faces:
                    f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
            print(f"Saved OBJ using direct write: {obj_path}")
        except Exception as e2:
            print(f"Alternative OBJ export also failed: {e2}")
            obj_path = None
    
    # Also save as PLY as a fallback
    ply_path = f"{filename_base}.ply"
    try:
        trimesh_obj.export(ply_path, file_type='ply')
        print(f"Saved PLY file: {ply_path}")
    except Exception as e:
        print(f"Error saving PLY: {e}")
        ply_path = None
    
    return {
        "glb": os.path.basename(glb_path) if glb_path else None,
        "obj": os.path.basename(obj_path) if obj_path else None,
        "ply": os.path.basename(ply_path) if ply_path else None
    }

def process_job(job_id, prompt):
    try:
        # Get adaptive parameters
        adaptive_params = get_adaptive_parameters()
        karras_steps = adaptive_params['karras_steps']
        batch_size = adaptive_params['batch_size']
        guidance_scale = adaptive_params['guidance_scale']
        
        # Load primary models for generation
        load_primary_model()
        
        # Optimization: Run garbage collection before starting intensive task
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
        
        # Generate latents
        latents = None
        with torch.inference_mode():
            latents = sample_latents(
                batch_size=batch_size,
                model=model,
                diffusion=diffusion,
                guidance_scale=guidance_scale,
                model_kwargs=dict(texts=[prompt] * batch_size),
                progress=True,
                clip_denoised=True,
                use_fp16=False,  # CPU doesn't support fp16
                use_karras=True,
                karras_steps=karras_steps,
                sigma_min=1e-3,
                sigma_max=160,
                s_churn=0,
            )
        print(f"Latent generation complete for job {job_id}!")
        
        # Optimization: Clear unnecessary memory and check pressure
        check_memory_pressure()
        
        # Generate a unique filename
        unique_id = str(uuid.uuid4())
        filename = f"{output_dir}/{unique_id}"
        
        # Load transmitter model for decoding
        load_transmitter_model()
        
        # Convert latent to mesh
        print(f"Decoding mesh for job {job_id}...")
        t0 = time.time()
        
        # Monitor memory
        mem_before = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory before mesh decoding: {mem_before:.2f} MB")
        
        # Decode the mesh
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
        mem_after = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory after decoding: {mem_after:.2f} MB (delta: {mem_after - mem_before:.2f} MB)")
        
        # Report mesh complexity if possible
        try:
            # Shap-E's TriMesh typically uses `verts`; trimesh objects use `vertices`
            vertices_count = len(mesh.verts if hasattr(mesh, 'verts') else mesh.vertices)
            faces_count = len(mesh.faces)
            print(f"Mesh complexity: {vertices_count} vertices, {faces_count} faces")
        except Exception as e:
            print(f"Could not determine mesh complexity: {e}")
            vertices_count = 0
            faces_count = 0
        
        # Clear latents from memory
        del latents
        gc.collect()
        
        # Convert to trimesh format and save files
        print(f"Converting and saving mesh for job {job_id}...")
        
        # Save mesh using the helper function
        saved_files = save_trimesh(mesh, filename)
        
        # Clear mesh from memory
        del mesh
        gc.collect()
        
        # Check which files were successfully saved
        result = {
            "success": True,
            "message": "3D model generated successfully",
            "timestamp": time.time(),
            "stats": {
                "vertices": vertices_count,
                "faces": faces_count
            }
        }
        
        # Add URLs for the files that were saved
        if saved_files["glb"]:
            result["glb_url"] = f"/download/{saved_files['glb']}"
        if saved_files["obj"]:
            result["obj_url"] = f"/download/{saved_files['obj']}"
        if saved_files["ply"]:
            result["ply_url"] = f"/download/{saved_files['ply']}"
        
        # If no files were saved, mark as failure
        if not (saved_files["glb"] or saved_files["obj"] or saved_files["ply"]):
            result["success"] = False
            result["message"] = "Failed to save mesh in any format"
        
        print(f"Files saved successfully for job {job_id}!")
        
        # Force garbage collection again
        gc.collect()
        
        return result
    
    except Exception as e:
        print(f"Error during generation for job {job_id}: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "timestamp": time.time()
        }

def worker_thread():
    global is_thread_running, active_jobs
    is_thread_running = True
    
    try:
        while True:
            try:
                # Get job from queue with a timeout
                job_id, prompt = job_queue.get(timeout=1)
                print(f"Processing job {job_id} with prompt: {prompt}")
                
                # Process the job
                result = process_job(job_id, prompt)
                
                # Store the result and update counter
                job_results[job_id] = result
                active_jobs -= 1
                
                # Explicit cleanup after job
                gc.collect()
                
            except queue.Empty:
                # No jobs in queue, continue waiting
                pass
            except Exception as e:
                print(f"Error in worker thread: {e}")
                import traceback
                traceback.print_exc()
                # If there was a job being processed, mark it as failed
                if 'job_id' in locals():
                    job_results[job_id] = {
                        "success": False,
                        "error": str(e),
                        "timestamp": time.time()
                    }
                    active_jobs -= 1
                    
                # Force garbage collection to clean up
                gc.collect()
    finally:
        is_thread_running = False

def purge_old_results_thread():
    """Thread that periodically cleans up old job results to manage memory"""
    while True:
        try:
            time.sleep(1800)  # Run every 30 minutes
            
            # Default threshold: 2 hours
            threshold_time = time.time() - (2 * 3600)
            
            # Track jobs to be removed
            jobs_to_remove = []
            for job_id, result in job_results.items():
                # If the job has a timestamp and it's older than threshold
                if result.get('timestamp', time.time()) < threshold_time:
                    jobs_to_remove.append(job_id)
            
            # Remove the old jobs
            for job_id in jobs_to_remove:
                job_results.pop(job_id, None)
                
            if jobs_to_remove:
                print(f"Auto-purged {len(jobs_to_remove)} old job results")
                # Force garbage collection
                gc.collect()
        except Exception as e:
            print(f"Error in purge thread: {e}")

def ensure_worker_thread_running():
    global generation_thread, is_thread_running
    
    if generation_thread is None or not generation_thread.is_alive():
        print("Starting worker thread...")
        generation_thread = threading.Thread(target=worker_thread, daemon=True)
        generation_thread.start()

def start_monitoring_threads():
    """Start all monitoring and maintenance threads"""
    # Start model unloader thread
    threading.Thread(target=model_unloader_thread, daemon=True).start()
    
    # Start results purge thread
    threading.Thread(target=purge_old_results_thread, daemon=True).start()

@app.route('/generate', methods=['POST'])
def generate_3d():
    global active_jobs
    
    # Check if we're already at max capacity
    if active_jobs >= max_concurrent_jobs:
        return jsonify({
            "success": False,
            "error": "Server is at maximum capacity. Please try again later.",
            "retry_after": 300
        }), 503
    
    # Get the prompt from the request
    # Parse JSON silently so a missing or invalid body falls through to the 400 below
    data = request.get_json(silent=True)
    if not data or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400
    
    prompt = data['prompt']
    print(f"Received prompt: {prompt}")
    
    # Generate a job ID
    job_id = str(uuid.uuid4())
    
    # Add job to queue
    ensure_worker_thread_running()
    job_queue.put((job_id, prompt))
    active_jobs += 1
    
    # Return job ID immediately
    return jsonify({
        "success": True,
        "message": "Job submitted successfully",
        "job_id": job_id,
        "status_url": f"/status/{job_id}"
    })

@app.route('/status/<job_id>', methods=['GET'])
def job_status(job_id):
    if job_id in job_results:
        result = job_results[job_id]
        # Return the result
        return jsonify(result)
    else:
        # Job is still in progress
        return jsonify({
            "success": None,
            "message": "Job is still processing",
            "job_id": job_id
        })

@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    try:
        file_path = os.path.join(output_dir, filename)
        if not os.path.exists(file_path):
            return jsonify({"error": "File not found"}), 404
            
        return send_file(file_path, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Enhanced health check endpoint to monitor resource usage"""
    try:
        # Memory info
        memory_info = psutil.virtual_memory()
        memory_usage = f"{memory_info.percent}% (Available: {memory_info.available / (1024**3):.2f} GB)"
        
        # CPU info
        cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"
        
        # Process specific info
        process = psutil.Process()
        process_memory = f"{process.memory_info().rss / (1024**3):.2f} GB"
        
        # Models status
        models_loaded = []
        if model is not None:
            models_loaded.append("text300M")
        if diffusion is not None:
            models_loaded.append("diffusion")
        if xm is not None:
            models_loaded.append("transmitter")
        
        # Queue status
        queue_size = job_queue.qsize()
        
        # Check for model inactivity
        model_inactive = "N/A"
        if last_usage_time is not None:
            model_inactive = f"{(time.time() - last_usage_time) / 60:.1f} minutes"
        
        # Number of saved jobs
        saved_jobs = len(job_results)
        
        return jsonify({
            "status": "ok",
            "message": "Service is running",
            "memory_usage": memory_usage,
            "process_memory": process_memory,
            "cpu_usage": cpu_usage,
            "queue_size": queue_size,
            "active_jobs": active_jobs,
            "saved_jobs": saved_jobs,
            "worker_running": is_thread_running,
            "models_loaded": models_loaded,
            "model_inactive_time": model_inactive
        })
    except Exception as e:
        return jsonify({
            "status": "warning",
            "error": str(e)
        })

@app.route('/', methods=['GET'])
def home():
    """Landing page with usage instructions"""
    return """
    <html>
        <head>
            <title>Text to 3D API</title>
            <style>
                body { font-family: Arial, sans-serif; line-height: 1.6; margin: 0 auto; padding: 20px; max-width: 800px; }
                pre { background: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
                code { background: #f4f4f4; padding: 2px 5px; border-radius: 3px; }
                h1, h2 { color: #333; }
            </style>
        </head>
        <body>
            <h1>Text to 3D API</h1>
            <p>This is an optimized API that converts text prompts to 3D models.</p>
            
            <h2>How to use:</h2>
            <h3>Step 1: Submit a generation job</h3>
            <pre>
POST /generate
Content-Type: application/json
{
    "prompt": "A futuristic building"
}
            </pre>
            <p>Response:</p>
            <pre>
{
    "success": true,
    "message": "Job submitted successfully",
    "job_id": "123e4567-e89b-12d3-a456-426614174000",
    "status_url": "/status/123e4567-e89b-12d3-a456-426614174000"
}
            </pre>
            
            <h3>Step 2: Check job status</h3>
            <pre>
GET /status/123e4567-e89b-12d3-a456-426614174000
            </pre>
            <p>Response (while processing):</p>
            <pre>
{
    "success": null,
    "message": "Job is still processing",
    "job_id": "123e4567-e89b-12d3-a456-426614174000"
}
            </pre>
            <p>Response (when complete):</p>
            <pre>
{
    "success": true,
    "message": "3D model generated successfully",
    "glb_url": "/download/abc123.glb",
    "obj_url": "/download/abc123.obj",
    "ply_url": "/download/abc123.ply"
}
            </pre>
            
            <h3>Step 3: Download the files</h3>
            <p>Use the provided URLs to download the GLB, OBJ, and PLY files.</p>
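            
            <h3>Example client (Python)</h3>
            <p>A minimal polling client sketch. It assumes the service is reachable at
            <code>http://localhost:7860</code> and that the <code>requests</code> package is
            installed on the client (it is not a dependency of this server):</p>
            <pre>
import requests, time

base = "http://localhost:7860"  # adjust to your deployment URL

# Submit a generation job
job = requests.post(f"{base}/generate", json={"prompt": "A futuristic building"}).json()

# Poll until the job finishes (success becomes true or false)
while True:
    status = requests.get(base + job["status_url"]).json()
    if status.get("success") is not None:
        break
    time.sleep(15)

# Download the GLB file if it was produced
if status.get("success") and "glb_url" in status:
    with open("model.glb", "wb") as f:
        f.write(requests.get(base + status["glb_url"]).content)
            </pre>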
            
            <h2>Health Check:</h2>
            <pre>GET /health</pre>
            <p>Provides information about the service status and resource usage.</p>
        </body>
    </html>
    """

@app.route('/purge-results', methods=['POST'])
def purge_old_results():
    """Endpoint to manually purge old job results to free memory"""
    try:
        # Get the time threshold from the request body (default to 1 hour)
        data = request.get_json(silent=True) or {}
        threshold_hours = data.get('threshold_hours', 1)
        threshold_time = time.time() - (threshold_hours * 3600)
        
        # Track jobs to be removed
        jobs_to_remove = []
        for job_id, result in job_results.items():
            # If the job has a timestamp and it's older than threshold
            if result.get('timestamp', time.time()) < threshold_time:
                jobs_to_remove.append(job_id)
        
        # Remove the old jobs
        for job_id in jobs_to_remove:
            job_results.pop(job_id, None)
            
        # Force garbage collection
        gc.collect()
        
        return jsonify({
            "success": True,
            "message": f"Purged {len(jobs_to_remove)} old job results",
            "remaining_jobs": len(job_results)
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500

@app.route('/force-gc', methods=['POST'])
def force_garbage_collection():
    """Endpoint to manually trigger garbage collection"""
    try:
        # Get current memory usage
        before_mem = psutil.Process().memory_info().rss / (1024**3)
        
        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # Get memory usage after GC
        after_mem = psutil.Process().memory_info().rss / (1024**3)
        freed = before_mem - after_mem
        
        return jsonify({
            "success": True,
            "message": f"Garbage collection completed",
            "before_memory_gb": round(before_mem, 2),
            "after_memory_gb": round(after_mem, 2),
            "freed_memory_gb": round(freed, 2) if freed > 0 else 0
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500

if __name__ == '__main__':
    # Start all monitoring threads
    start_monitoring_threads()
    
    # Start the worker thread
    ensure_worker_thread_running()
    
    # Recommended to run with gunicorn for production with increased timeout:
    # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
    app.run(host='0.0.0.0', port=7860, debug=False)  # Set debug=False in production