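"""Flask API that turns text prompts into 3D models with Shap-E.

Jobs are submitted via POST /generate, processed on a background worker
thread, and the resulting meshes are served as GLB/OBJ/PLY downloads.
The service is tuned for a small CPU-only deployment (e.g. a Hugging Face
Space): models are loaded lazily, unloaded when idle, and generation
parameters adapt to available memory.
"""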
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import torch
import numpy as np
import trimesh
import os
from io import BytesIO
import base64
from PIL import Image
import uuid
import time
import sys
import gc  # For explicit garbage collection
import threading
import queue
import psutil
# Set environment variables before anything else
os.environ['SHAPEE_NO_INTERACTIVE'] = '1'

# Set up the model cache directory with appropriate permissions
cache_dir = os.path.join(os.getcwd(), 'shap_e_model_cache')
os.makedirs(cache_dir, exist_ok=True)
os.environ['XDG_CACHE_HOME'] = os.getcwd()
print(f"Using cache directory: {cache_dir}")
# Import Shap-E
print("Importing Shap-E modules...")
try:
    # Try the direct import approach first
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
    print("Shap-E modules imported successfully!")
except ImportError as e:
    print(f"Error importing Shap-E modules: {e}")
    # Alternative approach if the direct import fails
    try:
        print("Attempting alternative import approach...")
        # Register a mock ipywidgets module in case that is what broke the import
        import types
        if 'ipywidgets' not in sys.modules:
            sys.modules['ipywidgets'] = types.ModuleType('ipywidgets')
            print("Added mock ipywidgets module")
        # Try the imports again
        from shap_e.diffusion.sample import sample_latents
        from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
        from shap_e.models.download import load_model, load_config
        from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
        print("Shap-E modules imported successfully with workaround!")
    except Exception as e2:
        print(f"Alternative import also failed: {e2}")
        sys.exit(1)
except Exception as e:
    print(f"Unexpected error importing Shap-E modules: {e}")
    sys.exit(1)
app = Flask(__name__)
CORS(app)

# Create the output directory if it doesn't exist
output_dir = os.path.join(os.getcwd(), "outputs")
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
# Check write permissions on both directories
try:
    test_file_path = os.path.join(cache_dir, "test_write_permissions.txt")
    with open(test_file_path, 'w') as f:
        f.write("Testing write permissions")
    os.remove(test_file_path)
    print("Cache directory is writable")
except Exception as e:
    print(f"WARNING: Cache directory is not writable: {e}")

try:
    test_file_path = os.path.join(output_dir, "test_write_permissions.txt")
    with open(test_file_path, 'w') as f:
        f.write("Testing write permissions")
    os.remove(test_file_path)
    print("Output directory is writable")
except Exception as e:
    print(f"WARNING: Output directory is not writable: {e}")

print("Setting up device...")
device = torch.device('cpu')  # Force CPU for Hugging Face Spaces
print(f"Using device: {device}")
# Global model handles (loaded lazily on the first request)
xm = None
model = None
diffusion = None

# Job queue and results dictionary
job_queue = queue.Queue()
job_results = {}
generation_thread = None
is_thread_running = False

# Globals for the resource optimizations
last_usage_time = None
active_jobs = 0
max_concurrent_jobs = 1  # Limit concurrent jobs for a 2 vCPU instance
def get_adaptive_parameters():
    """Adjust generation parameters based on current system resources"""
    mem = psutil.virtual_memory()
    # Base parameters: conservative to prevent memory issues
    params = {
        'karras_steps': 6,  # Reduced from 8 to 6 as the default
        'batch_size': 1,
        'guidance_scale': 15.0
    }
    # If memory is tight, reduce the steps further
    if mem.percent > 70:
        params['karras_steps'] = 4  # Even more conservative
    # If there is memory to spare, be slightly more generous
    if mem.percent < 50:
        params['karras_steps'] = 8
    print(f"Adaptive parameters chosen: karras_steps={params['karras_steps']}, mem={mem.percent}%")
    return params
def check_memory_pressure():
    """Check whether memory usage is getting too high and take action if needed"""
    global xm, model, diffusion
    mem = psutil.virtual_memory()
    if mem.percent > 80:  # Reduced threshold from 85 to 80
        print("WARNING: Memory pressure critical. Forcing garbage collection.")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # If memory is still critical, take more aggressive measures
        if psutil.virtual_memory().percent > 75:
            print("EMERGENCY: Memory still critical. Clearing model cache.")
            # Reset the global models to force a reload once memory improves
            xm, model, diffusion = None, None, None
            gc.collect()
        return True
    return False
def load_transmitter_model():
    global xm, last_usage_time
    last_usage_time = time.time()
    if xm is None:
        print("Loading transmitter model...")
        xm = load_model('transmitter', device=device)
        print("Transmitter model loaded!")

def load_primary_model():
    global model, diffusion, last_usage_time
    last_usage_time = time.time()
    if model is None or diffusion is None:
        print("Loading primary models...")
        torch.set_default_dtype(torch.float32)  # Use float32 instead of float64
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))
        print("Primary models loaded!")

def load_models_if_needed():
    """Legacy helper kept for compatibility"""
    load_primary_model()
    load_transmitter_model()
def model_unloader_thread():
    """Thread that periodically unloads the models if they haven't been used"""
    global xm, model, diffusion, last_usage_time
    while True:
        time.sleep(180)  # Check every 3 minutes instead of every 5
        if last_usage_time is not None:
            idle_time = time.time() - last_usage_time
            # If the models have been idle for more than 5 minutes (reduced from 10) and there are no active jobs
            if idle_time > 300 and active_jobs == 0:
                # Check memory usage and unload more aggressively
                mem = psutil.virtual_memory()
                if mem.percent > 40:  # Lowered threshold from 50 to 40
                    print(f"Models idle for {idle_time:.1f} seconds and memory at {mem.percent}%. Unloading...")
                    xm, model, diffusion = None, None, None
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
def save_trimesh(mesh, filename_base):
    """Save a mesh in multiple formats using trimesh"""
    # Convert to a trimesh.Trimesh if needed
    if not isinstance(mesh, trimesh.Trimesh):
        try:
            # Shap-E's TriMesh exposes `verts`/`faces`; fall back to `vertices`
            # for other mesh types
            vertices = np.array(mesh.verts if hasattr(mesh, 'verts') else mesh.vertices)
            faces = np.array(mesh.faces)
            trimesh_obj = trimesh.Trimesh(vertices=vertices, faces=faces)
        except Exception as e:
            print(f"Error converting to trimesh: {e}")
            raise
    else:
        trimesh_obj = mesh

    # Save as GLB
    glb_path = f"{filename_base}.glb"
    try:
        trimesh_obj.export(glb_path, file_type='glb')
        print(f"Saved GLB file: {glb_path}")
    except Exception as e:
        print(f"Error saving GLB: {e}")
        # Try an alternative approach via a scene
        try:
            scene = trimesh.Scene()
            scene.add_geometry(trimesh_obj)
            scene.export(glb_path)
            print(f"Saved GLB using scene approach: {glb_path}")
        except Exception as e2:
            print(f"Alternative GLB export also failed: {e2}")
            glb_path = None

    # Save as OBJ, which tends to be more reliable
    obj_path = f"{filename_base}.obj"
    try:
        trimesh_obj.export(obj_path, file_type='obj')
        print(f"Saved OBJ file: {obj_path}")
    except Exception as e:
        print(f"Error saving OBJ: {e}")
        # Fall back to writing the file directly
        try:
            with open(obj_path, 'w') as f:
                for v in trimesh_obj.vertices:
                    f.write(f"v {v[0]} {v[1]} {v[2]}\n")
                for face in trimesh_obj.faces:
                    f.write(f"f {face[0]+1} {face[1]+1} {face[2]+1}\n")
            print(f"Saved OBJ using direct write: {obj_path}")
        except Exception as e2:
            print(f"Alternative OBJ export also failed: {e2}")
            obj_path = None

    # Also save as PLY as a fallback
    ply_path = f"{filename_base}.ply"
    try:
        trimesh_obj.export(ply_path, file_type='ply')
        print(f"Saved PLY file: {ply_path}")
    except Exception as e:
        print(f"Error saving PLY: {e}")
        ply_path = None

    return {
        "glb": os.path.basename(glb_path) if glb_path else None,
        "obj": os.path.basename(obj_path) if obj_path else None,
        "ply": os.path.basename(ply_path) if ply_path else None
    }
def process_job(job_id, prompt):
    try:
        # Get adaptive parameters
        adaptive_params = get_adaptive_parameters()
        karras_steps = adaptive_params['karras_steps']
        batch_size = adaptive_params['batch_size']
        guidance_scale = adaptive_params['guidance_scale']

        # Load the primary models for generation
        load_primary_model()

        # Run garbage collection before starting the intensive task
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")

        # Generate latents
        with torch.inference_mode():
            latents = sample_latents(
                batch_size=batch_size,
                model=model,
                diffusion=diffusion,
                guidance_scale=guidance_scale,
                model_kwargs=dict(texts=[prompt] * batch_size),
                progress=True,
                clip_denoised=True,
                use_fp16=False,  # CPU doesn't support fp16
                use_karras=True,
                karras_steps=karras_steps,
                sigma_min=1e-3,
                sigma_max=160,
                s_churn=0,
            )
        print(f"Latent generation complete for job {job_id}!")

        # Clear unnecessary memory and check memory pressure
        check_memory_pressure()

        # Generate a unique filename base
        unique_id = str(uuid.uuid4())
        filename = os.path.join(output_dir, unique_id)

        # Load the transmitter model for decoding
        load_transmitter_model()

        # Convert the latent to a mesh
        print(f"Decoding mesh for job {job_id}...")
        t0 = time.time()

        # Monitor memory
        mem_before = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory before mesh decoding: {mem_before:.2f} MB")

        # Decode the mesh
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

        mem_after = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory after decoding: {mem_after:.2f} MB (delta: {mem_after - mem_before:.2f} MB)")

        # Report mesh complexity if possible
        try:
            vertices_count = len(mesh.verts if hasattr(mesh, 'verts') else mesh.vertices)
            faces_count = len(mesh.faces)
            print(f"Mesh complexity: {vertices_count} vertices, {faces_count} faces")
        except Exception as e:
            print(f"Could not determine mesh complexity: {e}")
            vertices_count = 0
            faces_count = 0

        # Clear the latents from memory
        del latents
        gc.collect()

        # Convert to trimesh format and save the files
        print(f"Converting and saving mesh for job {job_id}...")
        saved_files = save_trimesh(mesh, filename)

        # Clear the mesh from memory
        del mesh
        gc.collect()

        # Build the result from whichever files were successfully saved
        result = {
            "success": True,
            "message": "3D model generated successfully",
            "timestamp": time.time(),
            "stats": {
                "vertices": vertices_count,
                "faces": faces_count
            }
        }
        if saved_files["glb"]:
            result["glb_url"] = f"/download/{saved_files['glb']}"
        if saved_files["obj"]:
            result["obj_url"] = f"/download/{saved_files['obj']}"
        if saved_files["ply"]:
            result["ply_url"] = f"/download/{saved_files['ply']}"

        # If no files were saved, mark the job as failed
        if not (saved_files["glb"] or saved_files["obj"] or saved_files["ply"]):
            result["success"] = False
            result["message"] = "Failed to save mesh in any format"

        print(f"Files saved successfully for job {job_id}!")

        # Force garbage collection again
        gc.collect()
        return result
    except Exception as e:
        print(f"Error during generation for job {job_id}: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "timestamp": time.time()
        }
def worker_thread():
    global is_thread_running, active_jobs
    is_thread_running = True
    try:
        while True:
            try:
                # Get a job from the queue, with a timeout so the loop stays responsive
                job_id, prompt = job_queue.get(timeout=1)
                print(f"Processing job {job_id} with prompt: {prompt}")

                # Process the job
                result = process_job(job_id, prompt)

                # Store the result and update the counter
                job_results[job_id] = result
                active_jobs -= 1

                # Explicit cleanup after the job
                gc.collect()
            except queue.Empty:
                # No jobs in the queue; keep waiting
                pass
            except Exception as e:
                print(f"Error in worker thread: {e}")
                import traceback
                traceback.print_exc()
                # If a job was being processed, mark it as failed
                if 'job_id' in locals():
                    job_results[job_id] = {
                        "success": False,
                        "error": str(e),
                        "timestamp": time.time()
                    }
                    active_jobs -= 1
                # Force garbage collection to clean up
                gc.collect()
    finally:
        is_thread_running = False
def purge_old_results_thread():
    """Thread that periodically cleans up old job results to manage memory"""
    while True:
        try:
            time.sleep(1800)  # Run every 30 minutes
            # Default threshold: 2 hours
            threshold_time = time.time() - (2 * 3600)
            # Track the jobs to be removed
            jobs_to_remove = []
            for job_id, result in job_results.items():
                # If the job has a timestamp older than the threshold
                if result.get('timestamp', time.time()) < threshold_time:
                    jobs_to_remove.append(job_id)
            # Remove the old jobs
            for job_id in jobs_to_remove:
                job_results.pop(job_id, None)
            if jobs_to_remove:
                print(f"Auto-purged {len(jobs_to_remove)} old job results")
                # Force garbage collection
                gc.collect()
        except Exception as e:
            print(f"Error in purge thread: {e}")
def ensure_worker_thread_running():
    global generation_thread, is_thread_running
    if generation_thread is None or not generation_thread.is_alive():
        print("Starting worker thread...")
        generation_thread = threading.Thread(target=worker_thread, daemon=True)
        generation_thread.start()

def start_monitoring_threads():
    """Start all monitoring and maintenance threads"""
    # Model unloader thread
    threading.Thread(target=model_unloader_thread, daemon=True).start()
    # Results purge thread
    threading.Thread(target=purge_old_results_thread, daemon=True).start()
@app.route('/generate', methods=['POST'])
def generate_3d():
    global active_jobs
    # Reject the request if we're already at maximum capacity
    if active_jobs >= max_concurrent_jobs:
        return jsonify({
            "success": False,
            "error": "Server is at maximum capacity. Please try again later.",
            "retry_after": 300
        }), 503
    # Get the prompt from the request
    data = request.get_json(silent=True)
    if not data or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400
    prompt = data['prompt']
    print(f"Received prompt: {prompt}")
    # Generate a job ID
    job_id = str(uuid.uuid4())
    # Add the job to the queue
    ensure_worker_thread_running()
    job_queue.put((job_id, prompt))
    active_jobs += 1
    # Return the job ID immediately
    return jsonify({
        "success": True,
        "message": "Job submitted successfully",
        "job_id": job_id,
        "status_url": f"/status/{job_id}"
    })
@app.route('/status/<job_id>')
def job_status(job_id):
    if job_id in job_results:
        # Return the stored result
        return jsonify(job_results[job_id])
    else:
        # Job is still in progress
        return jsonify({
            "success": None,
            "message": "Job is still processing",
            "job_id": job_id
        })
@app.route('/download/<filename>')
def download_file(filename):
    try:
        file_path = os.path.join(output_dir, filename)
        if not os.path.exists(file_path):
            return jsonify({"error": "File not found"}), 404
        return send_file(file_path, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/health')
def health_check():
    """Enhanced health check endpoint to monitor resource usage"""
    try:
        # Memory info
        memory_info = psutil.virtual_memory()
        memory_usage = f"{memory_info.percent}% (Available: {memory_info.available / (1024**3):.2f} GB)"
        # CPU info
        cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"
        # Process-specific info
        process = psutil.Process()
        process_memory = f"{process.memory_info().rss / (1024**3):.2f} GB"
        # Model status
        models_loaded = []
        if model is not None:
            models_loaded.append("text300M")
        if diffusion is not None:
            models_loaded.append("diffusion")
        if xm is not None:
            models_loaded.append("transmitter")
        # Queue status
        queue_size = job_queue.qsize()
        # Model inactivity
        model_inactive = "N/A"
        if last_usage_time is not None:
            model_inactive = f"{(time.time() - last_usage_time) / 60:.1f} minutes"
        # Number of stored job results
        saved_jobs = len(job_results)
        return jsonify({
            "status": "ok",
            "message": "Service is running",
            "memory_usage": memory_usage,
            "process_memory": process_memory,
            "cpu_usage": cpu_usage,
            "queue_size": queue_size,
            "active_jobs": active_jobs,
            "saved_jobs": saved_jobs,
            "worker_running": is_thread_running,
            "models_loaded": models_loaded,
            "model_inactive_time": model_inactive
        })
    except Exception as e:
        return jsonify({
            "status": "warning",
            "error": str(e)
        })
@app.route('/')
def home():
    """Landing page with usage instructions"""
    return """
    <html>
    <head>
        <title>Text to 3D API</title>
        <style>
            body { font-family: Arial, sans-serif; line-height: 1.6; margin: 0 auto; padding: 20px; max-width: 800px; }
            pre { background: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
            code { background: #f4f4f4; padding: 2px 5px; border-radius: 3px; }
            h1, h2 { color: #333; }
        </style>
    </head>
    <body>
        <h1>Text to 3D API</h1>
        <p>This is an optimized API that converts text prompts to 3D models.</p>
        <h2>How to use:</h2>
        <h3>Step 1: Submit a generation job</h3>
        <pre>
POST /generate
Content-Type: application/json

{
    "prompt": "A futuristic building"
}
        </pre>
        <p>Response:</p>
        <pre>
{
    "success": true,
    "message": "Job submitted successfully",
    "job_id": "123e4567-e89b-12d3-a456-426614174000",
    "status_url": "/status/123e4567-e89b-12d3-a456-426614174000"
}
        </pre>
        <h3>Step 2: Check job status</h3>
        <pre>
GET /status/123e4567-e89b-12d3-a456-426614174000
        </pre>
        <p>Response (while processing):</p>
        <pre>
{
    "success": null,
    "message": "Job is still processing",
    "job_id": "123e4567-e89b-12d3-a456-426614174000"
}
        </pre>
        <p>Response (when complete):</p>
        <pre>
{
    "success": true,
    "message": "3D model generated successfully",
    "glb_url": "/download/abc123.glb",
    "obj_url": "/download/abc123.obj",
    "ply_url": "/download/abc123.ply"
}
        </pre>
        <h3>Step 3: Download the files</h3>
        <p>Use the provided URLs to download the GLB, OBJ, and PLY files.</p>
        <h2>Health check:</h2>
        <pre>GET /health</pre>
        <p>Provides information about the service status and resource usage.</p>
    </body>
    </html>
    """
# NOTE: the route path for this maintenance endpoint is not documented in the
# original; '/purge' is an assumption.
@app.route('/purge', methods=['POST'])
def purge_old_results():
    """Endpoint to manually purge old job results and free memory"""
    try:
        # Get the time threshold from the request (default: 1 hour)
        payload = request.get_json(silent=True) or {}
        threshold_hours = payload.get('threshold_hours', 1)
        threshold_time = time.time() - (threshold_hours * 3600)
        # Track the jobs to be removed
        jobs_to_remove = []
        for job_id, result in job_results.items():
            # If the job has a timestamp older than the threshold
            if result.get('timestamp', time.time()) < threshold_time:
                jobs_to_remove.append(job_id)
        # Remove the old jobs
        for job_id in jobs_to_remove:
            job_results.pop(job_id, None)
        # Force garbage collection
        gc.collect()
        return jsonify({
            "success": True,
            "message": f"Purged {len(jobs_to_remove)} old job results",
            "remaining_jobs": len(job_results)
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
# NOTE: the route path for this maintenance endpoint is not documented in the
# original; '/gc' is an assumption.
@app.route('/gc', methods=['POST'])
def force_garbage_collection():
    """Endpoint to manually trigger garbage collection"""
    try:
        # Current memory usage
        before_mem = psutil.Process().memory_info().rss / (1024**3)
        # Force garbage collection
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Memory usage after GC
        after_mem = psutil.Process().memory_info().rss / (1024**3)
        freed = before_mem - after_mem
        return jsonify({
            "success": True,
            "message": "Garbage collection completed",
            "before_memory_gb": round(before_mem, 2),
            "after_memory_gb": round(after_mem, 2),
            "freed_memory_gb": round(freed, 2) if freed > 0 else 0
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
if __name__ == '__main__':
    # Start all monitoring threads
    start_monitoring_threads()
    # Start the worker thread
    ensure_worker_thread_running()
    # For production, run with gunicorn and an increased timeout, e.g.:
    #   gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
    app.run(host='0.0.0.0', port=7860, debug=False)  # Keep debug=False in production
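# Example client usage (a minimal sketch: the host/port assume the default
# app.run settings above, and <job_id>/<file> are placeholders taken from the
# documented responses):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "A futuristic building"}'
#
#   curl http://localhost:7860/status/<job_id>
#
#   curl -o model.glb http://localhost:7860/download/<file>.glb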