Update app.py

app.py CHANGED
@@ -11,6 +11,9 @@ import uuid
import time
import sys
import gc # For explicit garbage collection
+import threading
+import queue
+import psutil

# Set environment variables before anything else
os.environ['SHAPEE_NO_INTERACTIVE'] = '1'
@@ -92,11 +95,19 @@ xm = None
model = None
diffusion = None

+# Job queue and results dictionary
+job_queue = queue.Queue()
+job_results = {}
+generation_thread = None
+is_thread_running = False
+
def load_models_if_needed():
    global xm, model, diffusion
    if xm is None or model is None or diffusion is None:
        print("Loading models for the first time...")
        try:
+            # Set lower precision for memory optimization
+            torch.set_default_dtype(torch.float32) # Use float32 instead of float64
            xm = load_model('transmitter', device=device)
            model = load_model('text300M', device=device)
            diffusion = diffusion_from_config(load_config('diffusion'))
@@ -105,78 +116,103 @@ def load_models_if_needed():
            print(f"Error loading models: {e}")
            raise

-@app.route('/generate', methods=['POST'])
-def generate_3d():
+def worker_thread():
+    global is_thread_running
+    is_thread_running = True
+
+    try:
+        while True:
+            try:
+                # Get job from queue with a timeout
+                job_id, prompt = job_queue.get(timeout=1)
+                print(f"Processing job {job_id} with prompt: {prompt}")
+
+                # Process the job
+                result = process_job(job_id, prompt)
+
+                # Store the result
+                job_results[job_id] = result
+
+            except queue.Empty:
+                # No jobs in queue, continue waiting
+                pass
+    except Exception as e:
+        print(f"Error in worker thread: {e}")
+        import traceback
+        traceback.print_exc()
+        # If there was a job being processed, mark it as failed
+        if 'job_id' in locals():
+            job_results[job_id] = {
+                "success": False,
+                "error": str(e)
+            }
+    finally:
+        is_thread_running = False
+
+def process_job(job_id, prompt):
    try:
        # Load models if not already loaded
        load_models_if_needed()

-        # Get the prompt from the request
-        data = request.json
-        if not data or 'prompt' not in data:
-            return jsonify({"error": "No prompt provided"}), 400
-
-        prompt = data['prompt']
-        print(f"Received prompt: {prompt}")
-
        # Set parameters for CPU performance (reduced steps and other optimizations)
        batch_size = 1
        guidance_scale = 15.0

-        # *** OPTIMIZATION: Significantly reduce steps for low-memory environments ***
-        karras_steps = 16
+        # *** EXTREME OPTIMIZATION: Significantly reduce steps for low-memory environments ***
+        karras_steps = 8 # Reduced from 16 to 8 for even better performance

        # *** OPTIMIZATION: Run garbage collection before starting intensive task ***
        gc.collect()
        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Generate latents with the text-to-3D model
-        print("Starting latent generation with
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
+
+        # Force lower precision
+        with torch.inference_mode():
+            latents = sample_latents(
+                batch_size=batch_size,
+                model=model,
+                diffusion=diffusion,
+                guidance_scale=guidance_scale,
+                model_kwargs=dict(texts=[prompt] * batch_size),
+                progress=True,
+                clip_denoised=True,
+                use_fp16=False, # CPU doesn't support fp16
+                use_karras=True,
+                karras_steps=karras_steps,
+                sigma_min=1e-3,
+                sigma_max=160,
+                s_churn=0,
+            )
+        print(f"Latent generation complete for job {job_id}!")

        # *** OPTIMIZATION: Run garbage collection after intensive step ***
        gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Generate a unique filename
        unique_id = str(uuid.uuid4())
        filename = f"{output_dir}/{unique_id}"

-        # Convert latent to mesh
-        print("Decoding mesh...")
+        # Convert latent to mesh with optimization settings
+        print(f"Decoding mesh for job {job_id}...")
        t0 = time.time()

-        # *** OPTIMIZATION: Use simplified decoding
-        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
+        # *** OPTIMIZATION: Use simplified decoding with lower resolution ***
+        mesh = decode_latent_mesh(xm, latents[0], max_points=4000).tri_mesh() # Reduced point count
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

        # *** OPTIMIZATION: Clear latents from memory as they're no longer needed ***
        del latents
        gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Save as GLB
-        print("Saving as GLB...")
+        print(f"Saving job {job_id} as GLB...")
        glb_path = f"{filename}.glb"
        mesh.write_glb(glb_path)

        # Save as OBJ
-        print("Saving as OBJ...")
+        print(f"Saving job {job_id} as OBJ...")
        obj_path = f"{filename}.obj"
        with open(obj_path, 'w') as f:
            mesh.write_obj(f)
@@ -185,21 +221,74 @@ def generate_3d():
        del mesh
        gc.collect()

-        print("Files saved successfully!")
+        print(f"Files saved successfully for job {job_id}!")

        # Return paths to the generated files
-        return jsonify({
+        return {
            "success": True,
            "message": "3D model generated successfully",
            "glb_url": f"/download/{os.path.basename(glb_path)}",
            "obj_url": f"/download/{os.path.basename(obj_path)}"
-        })
+        }

    except Exception as e:
-        print(f"Error during generation: {str(e)}")
+        print(f"Error during generation for job {job_id}: {str(e)}")
        import traceback
        traceback.print_exc()
-        return
+        return {
+            "success": False,
+            "error": str(e)
+        }
+
+def ensure_worker_thread_running():
+    global generation_thread, is_thread_running
+
+    if generation_thread is None or not generation_thread.is_alive():
+        print("Starting worker thread...")
+        generation_thread = threading.Thread(target=worker_thread, daemon=True)
+        generation_thread.start()
+
+@app.route('/generate', methods=['POST'])
+def generate_3d():
+    # Get the prompt from the request
+    data = request.json
+    if not data or 'prompt' not in data:
+        return jsonify({"error": "No prompt provided"}), 400
+
+    prompt = data['prompt']
+    print(f"Received prompt: {prompt}")
+
+    # Generate a job ID
+    job_id = str(uuid.uuid4())
+
+    # Add job to queue
+    ensure_worker_thread_running()
+    job_queue.put((job_id, prompt))
+
+    # Return job ID immediately
+    return jsonify({
+        "success": True,
+        "message": "Job submitted successfully",
+        "job_id": job_id,
+        "status_url": f"/status/{job_id}"
+    })
+
+@app.route('/status/<job_id>', methods=['GET'])
+def job_status(job_id):
+    if job_id in job_results:
+        result = job_results[job_id]
+        # Clean up memory if the job is complete and successful
+        if result.get("success", False):
+            return jsonify(result)
+        else:
+            return jsonify({"error": result.get("error", "Unknown error")}), 500
+    else:
+        # Job is still in progress
+        return jsonify({
+            "success": None,
+            "message": "Job is still processing",
+            "job_id": job_id
+        })

@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
@@ -213,16 +302,31 @@ def health_check():
    """Simple health check endpoint to verify the app is running"""
    # Check available memory
    try:
-        import psutil
        memory_info = psutil.virtual_memory()
        memory_usage = f"{memory_info.percent}% (Available: {memory_info.available / (1024**3):.2f} GB)"
-
-
+
+        # Check CPU usage
+        cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"
+
+        # Get queue status
+        queue_size = job_queue.qsize()
+
+        # Get active jobs
+        active_jobs = len(job_results)
+    except Exception as e:
+        memory_usage = "Error getting system info"
+        cpu_usage = "Error getting CPU info"
+        queue_size = "Unknown"
+        active_jobs = "Unknown"

    return jsonify({
        "status": "ok",
        "message": "Service is running",
-        "memory_usage": memory_usage
+        "memory_usage": memory_usage,
+        "cpu_usage": cpu_usage,
+        "queue_size": queue_size,
+        "active_jobs": active_jobs,
+        "worker_running": is_thread_running
    })

@app.route('/', methods=['GET'])
@@ -230,25 +334,75 @@ def home():
    """Landing page with usage instructions"""
    return """
    <html>
-    <head
+    <head>
+        <title>Text to 3D API</title>
+        <style>
+            body { font-family: Arial, sans-serif; line-height: 1.6; margin: 0; padding: 20px; max-width: 800px; margin: 0 auto; }
+            pre { background: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
+            code { background: #f4f4f4; padding: 2px 5px; border-radius: 3px; }
+            h1, h2 { color: #333; }
+        </style>
+    </head>
    <body>
        <h1>Text to 3D API</h1>
-        <p>This is
+        <p>This is an optimized API that converts text prompts to 3D models.</p>
+
        <h2>How to use:</h2>
+        <h3>Step 1: Submit a generation job</h3>
+        <pre>
+POST /generate
+Content-Type: application/json
+
+{
+    "prompt": "A futuristic building"
+}
+        </pre>
+        <p>Response:</p>
        <pre>
-
-
+{
+    "success": true,
+    "message": "Job submitted successfully",
+    "job_id": "123e4567-e89b-12d3-a456-426614174000",
+    "status_url": "/status/123e4567-e89b-12d3-a456-426614174000"
+}
+        </pre>

-
-
-
+        <h3>Step 2: Check job status</h3>
+        <pre>
+GET /status/123e4567-e89b-12d3-a456-426614174000
+        </pre>
+        <p>Response (while processing):</p>
+        <pre>
+{
+    "success": null,
+    "message": "Job is still processing",
+    "job_id": "123e4567-e89b-12d3-a456-426614174000"
+}
        </pre>
-        <p>
+        <p>Response (when complete):</p>
+        <pre>
+{
+    "success": true,
+    "message": "3D model generated successfully",
+    "glb_url": "/download/abc123.glb",
+    "obj_url": "/download/abc123.obj"
+}
+        </pre>
+
+        <h3>Step 3: Download the files</h3>
+        <p>Use the provided URLs to download the GLB and OBJ files.</p>
+
+        <h2>Health Check:</h2>
+        <pre>GET /health</pre>
+        <p>Provides information about the service status and resource usage.</p>
    </body>
    </html>
    """

if __name__ == '__main__':
+    # Start the worker thread
+    ensure_worker_thread_running()
+
    # Recommended to run with gunicorn for production with increased timeout:
    # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
-    app.run(host='0.0.0.0', port=7860, debug=
+    app.run(host='0.0.0.0', port=7860, debug=False) # Set debug=False in production
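For reference, a minimal client-side sketch of the asynchronous flow this commit introduces: POST a prompt to /generate, poll the returned status_url until it reports success, then fetch the GLB/OBJ files from the /download URLs. The base URL and prompt are illustrative assumptions (a local run as started by app.run above, with the requests package installed); the endpoints and response fields are the ones defined in app.py.

# Minimal client sketch for the new asynchronous API.
# Assumptions: the app is reachable at http://localhost:7860 and `requests` is installed.
import time
import requests

BASE = "http://localhost:7860"  # assumed local base URL

# Step 1: submit a generation job; the server returns a job_id immediately
job = requests.post(f"{BASE}/generate", json={"prompt": "A futuristic building"}).json()
print("Submitted job:", job["job_id"])

# Step 2: poll the status endpoint until the job finishes
while True:
    resp = requests.get(f"{BASE}{job['status_url']}")
    body = resp.json()
    if resp.status_code == 500:      # the worker recorded a failure for this job
        raise RuntimeError(body.get("error", "generation failed"))
    if body.get("success"):          # complete: download URLs are present
        break
    time.sleep(10)                   # still processing ("success": null)

# Step 3: download the generated meshes
for key in ("glb_url", "obj_url"):
    name = body[key].rsplit("/", 1)[-1]
    with open(name, "wb") as f:
        f.write(requests.get(f"{BASE}{body[key]}").content)
print("Saved", body["glb_url"], "and", body["obj_url"])

Because generation now runs on the single background worker thread, the POST returns immediately and a long generation no longer has to complete within one HTTP request timeout.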