mac9087 committed
Commit d5ed7cc · verified · 1 Parent(s): 9c8ecc3

Update app.py
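
This update converts 3D generation from a blocking request into an asynchronous job queue. /generate now validates the prompt, enqueues a job, and returns a job_id right away; a daemon worker thread executes the new process_job(), and a new /status/<job_id> endpoint serves the result. Sampling runs under torch.inference_mode() with karras_steps reduced from 16 to 8, trading mesh quality for speed on CPU (note that inference_mode() disables autograd tracking but, despite the "# Force lower precision" comment, does not change numeric precision). The psutil import moves to module scope, /health additionally reports CPU load, queue size, and worker state, the landing page documents the new three-step flow, and app.run() now sets debug=False.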

Files changed (1):
  1. app.py +209 -55
app.py CHANGED
@@ -11,6 +11,9 @@ import uuid
 import time
 import sys
 import gc  # For explicit garbage collection
+import threading
+import queue
+import psutil

 # Set environment variables before anything else
 os.environ['SHAPEE_NO_INTERACTIVE'] = '1'
@@ -92,11 +95,19 @@ xm = None
 model = None
 diffusion = None

+# Job queue and results dictionary
+job_queue = queue.Queue()
+job_results = {}
+generation_thread = None
+is_thread_running = False
+
 def load_models_if_needed():
     global xm, model, diffusion
     if xm is None or model is None or diffusion is None:
         print("Loading models for the first time...")
         try:
+            # Set lower precision for memory optimization
+            torch.set_default_dtype(torch.float32)  # Use float32 instead of float64
             xm = load_model('transmitter', device=device)
             model = load_model('text300M', device=device)
             diffusion = diffusion_from_config(load_config('diffusion'))
@@ -105,78 +116,103 @@ def load_models_if_needed():
             print(f"Error loading models: {e}")
             raise

-@app.route('/generate', methods=['POST'])
-def generate_3d():
+def worker_thread():
+    global is_thread_running
+    is_thread_running = True
+
+    try:
+        while True:
+            try:
+                # Get job from queue with a timeout
+                job_id, prompt = job_queue.get(timeout=1)
+                print(f"Processing job {job_id} with prompt: {prompt}")
+
+                # Process the job
+                result = process_job(job_id, prompt)
+
+                # Store the result
+                job_results[job_id] = result
+
+            except queue.Empty:
+                # No jobs in queue, continue waiting
+                pass
+            except Exception as e:
+                print(f"Error in worker thread: {e}")
+                import traceback
+                traceback.print_exc()
+                # If there was a job being processed, mark it as failed
+                if 'job_id' in locals():
+                    job_results[job_id] = {
+                        "success": False,
+                        "error": str(e)
+                    }
+    finally:
+        is_thread_running = False
+
+def process_job(job_id, prompt):
     try:
         # Load models if not already loaded
         load_models_if_needed()

-        # Get the prompt from the request
-        data = request.json
-        if not data or 'prompt' not in data:
-            return jsonify({"error": "No prompt provided"}), 400
-
-        prompt = data['prompt']
-        print(f"Received prompt: {prompt}")
-
         # Set parameters for CPU performance (reduced steps and other optimizations)
         batch_size = 1
         guidance_scale = 15.0

-        # *** OPTIMIZATION: Significantly reduce steps for low-memory environments ***
-        karras_steps = 16  # Reduced from 32 to 16 for better performance
+        # *** EXTREME OPTIMIZATION: Significantly reduce steps for low-memory environments ***
+        karras_steps = 8  # Reduced from 16 to 8 for even better performance

         # *** OPTIMIZATION: Run garbage collection before starting intensive task ***
         gc.collect()
         torch.cuda.empty_cache() if torch.cuda.is_available() else None

         # Generate latents with the text-to-3D model
-        print("Starting latent generation with reduced steps...")
-        latents = sample_latents(
-            batch_size=batch_size,
-            model=model,
-            diffusion=diffusion,
-            guidance_scale=guidance_scale,
-            model_kwargs=dict(texts=[prompt] * batch_size),
-            progress=True,
-            clip_denoised=True,
-            use_fp16=False,  # CPU doesn't support fp16
-            use_karras=True,
-            karras_steps=karras_steps,  # *** OPTIMIZATION: Reduced steps ***
-            sigma_min=1e-3,
-            sigma_max=160,
-            s_churn=0,
-        )
-        print("Latent generation complete!")
+        print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
+
+        # Force lower precision
+        with torch.inference_mode():
+            latents = sample_latents(
+                batch_size=batch_size,
+                model=model,
+                diffusion=diffusion,
+                guidance_scale=guidance_scale,
+                model_kwargs=dict(texts=[prompt] * batch_size),
+                progress=True,
+                clip_denoised=True,
+                use_fp16=False,  # CPU doesn't support fp16
+                use_karras=True,
+                karras_steps=karras_steps,
+                sigma_min=1e-3,
+                sigma_max=160,
+                s_churn=0,
+            )
+        print(f"Latent generation complete for job {job_id}!")

         # *** OPTIMIZATION: Run garbage collection after intensive step ***
         gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None

         # Generate a unique filename
         unique_id = str(uuid.uuid4())
         filename = f"{output_dir}/{unique_id}"

-        # Convert latent to mesh
-        print("Decoding mesh...")
+        # Convert latent to mesh with optimization settings
+        print(f"Decoding mesh for job {job_id}...")
         t0 = time.time()

-        # *** OPTIMIZATION: Use simplified decoding for memory constraints ***
-        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
+        # *** OPTIMIZATION: Use simplified decoding with lower resolution ***
+        mesh = decode_latent_mesh(xm, latents[0], max_points=4000).tri_mesh()  # Reduced point count
         print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

         # *** OPTIMIZATION: Clear latents from memory as they're no longer needed ***
         del latents
         gc.collect()
-        torch.cuda.empty_cache() if torch.cuda.is_available() else None

         # Save as GLB
-        print("Saving as GLB...")
+        print(f"Saving job {job_id} as GLB...")
         glb_path = f"{filename}.glb"
         mesh.write_glb(glb_path)

         # Save as OBJ
-        print("Saving as OBJ...")
+        print(f"Saving job {job_id} as OBJ...")
         obj_path = f"{filename}.obj"
         with open(obj_path, 'w') as f:
             mesh.write_obj(f)
@@ -185,21 +221,74 @@ def generate_3d():
         del mesh
         gc.collect()

-        print("Files saved successfully!")
+        print(f"Files saved successfully for job {job_id}!")

         # Return paths to the generated files
-        return jsonify({
+        return {
             "success": True,
             "message": "3D model generated successfully",
             "glb_url": f"/download/{os.path.basename(glb_path)}",
             "obj_url": f"/download/{os.path.basename(obj_path)}"
-        })
+        }

     except Exception as e:
-        print(f"Error during generation: {str(e)}")
+        print(f"Error during generation for job {job_id}: {str(e)}")
         import traceback
         traceback.print_exc()
-        return jsonify({"error": str(e)}), 500
+        return {
+            "success": False,
+            "error": str(e)
+        }
+
+def ensure_worker_thread_running():
+    global generation_thread, is_thread_running
+
+    if generation_thread is None or not generation_thread.is_alive():
+        print("Starting worker thread...")
+        generation_thread = threading.Thread(target=worker_thread, daemon=True)
+        generation_thread.start()
+
+@app.route('/generate', methods=['POST'])
+def generate_3d():
+    # Get the prompt from the request
+    data = request.json
+    if not data or 'prompt' not in data:
+        return jsonify({"error": "No prompt provided"}), 400
+
+    prompt = data['prompt']
+    print(f"Received prompt: {prompt}")
+
+    # Generate a job ID
+    job_id = str(uuid.uuid4())
+
+    # Add job to queue
+    ensure_worker_thread_running()
+    job_queue.put((job_id, prompt))
+
+    # Return job ID immediately
+    return jsonify({
+        "success": True,
+        "message": "Job submitted successfully",
+        "job_id": job_id,
+        "status_url": f"/status/{job_id}"
+    })
+
+@app.route('/status/<job_id>', methods=['GET'])
+def job_status(job_id):
+    if job_id in job_results:
+        result = job_results[job_id]
+        # Clean up memory if the job is complete and successful
+        if result.get("success", False):
+            return jsonify(result)
+        else:
+            return jsonify({"error": result.get("error", "Unknown error")}), 500
+    else:
+        # Job is still in progress
+        return jsonify({
+            "success": None,
+            "message": "Job is still processing",
+            "job_id": job_id
+        })

 @app.route('/download/<filename>', methods=['GET'])
 def download_file(filename):
@@ -213,16 +302,31 @@ def health_check():
     """Simple health check endpoint to verify the app is running"""
     # Check available memory
     try:
-        import psutil
         memory_info = psutil.virtual_memory()
         memory_usage = f"{memory_info.percent}% (Available: {memory_info.available / (1024**3):.2f} GB)"
-    except ImportError:
-        memory_usage = "psutil not installed"
+
+        # Check CPU usage
+        cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"
+
+        # Get queue status
+        queue_size = job_queue.qsize()
+
+        # Get active jobs
+        active_jobs = len(job_results)
+    except Exception as e:
+        memory_usage = "Error getting system info"
+        cpu_usage = "Error getting CPU info"
+        queue_size = "Unknown"
+        active_jobs = "Unknown"

     return jsonify({
         "status": "ok",
         "message": "Service is running",
-        "memory_usage": memory_usage
+        "memory_usage": memory_usage,
+        "cpu_usage": cpu_usage,
+        "queue_size": queue_size,
+        "active_jobs": active_jobs,
+        "worker_running": is_thread_running
     })

 @app.route('/', methods=['GET'])
@@ -230,25 +334,75 @@ def home():
     """Landing page with usage instructions"""
     return """
     <html>
-    <head><title>Text to 3D API</title></head>
+    <head>
+        <title>Text to 3D API</title>
+        <style>
+            body { font-family: Arial, sans-serif; line-height: 1.6; margin: 0; padding: 20px; max-width: 800px; margin: 0 auto; }
+            pre { background: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
+            code { background: #f4f4f4; padding: 2px 5px; border-radius: 3px; }
+            h1, h2 { color: #333; }
+        </style>
+    </head>
    <body>
        <h1>Text to 3D API</h1>
-       <p>This is a simple API that converts text prompts to 3D models.</p>
+       <p>This is an optimized API that converts text prompts to 3D models.</p>
+
        <h2>How to use:</h2>
+       <h3>Step 1: Submit a generation job</h3>
+       <pre>
+POST /generate
+Content-Type: application/json
+
+{
+    "prompt": "A futuristic building"
+}
+       </pre>
+       <p>Response:</p>
        <pre>
-POST /generate
-Content-Type: application/json
+{
+    "success": true,
+    "message": "Job submitted successfully",
+    "job_id": "123e4567-e89b-12d3-a456-426614174000",
+    "status_url": "/status/123e4567-e89b-12d3-a456-426614174000"
+}
+       </pre>

-{
-    "prompt": "A futuristic building"
-}
+       <h3>Step 2: Check job status</h3>
+       <pre>
+GET /status/123e4567-e89b-12d3-a456-426614174000
+       </pre>
+       <p>Response (while processing):</p>
+       <pre>
+{
+    "success": null,
+    "message": "Job is still processing",
+    "job_id": "123e4567-e89b-12d3-a456-426614174000"
+}
        </pre>
-       <p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
+       <p>Response (when complete):</p>
+       <pre>
+{
+    "success": true,
+    "message": "3D model generated successfully",
+    "glb_url": "/download/abc123.glb",
+    "obj_url": "/download/abc123.obj"
+}
+       </pre>
+
+       <h3>Step 3: Download the files</h3>
+       <p>Use the provided URLs to download the GLB and OBJ files.</p>
+
+       <h2>Health Check:</h2>
+       <pre>GET /health</pre>
+       <p>Provides information about the service status and resource usage.</p>
    </body>
    </html>
    """

 if __name__ == '__main__':
+    # Start the worker thread
+    ensure_worker_thread_running()
+
     # Recommended to run with gunicorn for production with increased timeout:
     # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
-    app.run(host='0.0.0.0', port=7860, debug=True)
+    app.run(host='0.0.0.0', port=7860, debug=False)  # Set debug=False in production
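
For reference, a minimal client for the new asynchronous flow could look like the sketch below. It only uses the endpoints and response fields shown in the diff; the base URL matches the app.run() defaults, while the polling interval and the third-party requests dependency are assumptions.

import time
import requests  # third-party HTTP client, assumed installed

BASE = "http://localhost:7860"  # host/port from app.run() above

# Step 1: submit a generation job
resp = requests.post(f"{BASE}/generate", json={"prompt": "A futuristic building"})
resp.raise_for_status()
job = resp.json()
print("Submitted job:", job["job_id"])

# Step 2: poll the status URL until "success" flips from null to true
while True:
    status = requests.get(f"{BASE}{job['status_url']}")
    if status.status_code == 500:  # failed jobs come back as {"error": ...}
        raise RuntimeError(status.json().get("error", "generation failed"))
    body = status.json()
    if body.get("success"):
        break
    time.sleep(5)  # arbitrary polling interval

# Step 3: download the GLB and OBJ files from the returned URLs
for key in ("glb_url", "obj_url"):
    data = requests.get(f"{BASE}{body[key]}")
    data.raise_for_status()
    name = body[key].rsplit("/", 1)[-1]
    with open(name, "wb") as f:
        f.write(data.content)
    print("Saved", name)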
 
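
One caveat in the new /status handler: despite the "Clean up memory if the job is complete and successful" comment, nothing is ever removed from job_results, so the dictionary (and the "active_jobs" count in /health) grows until the process restarts. A hypothetical variant that evicts an entry once its result has been delivered, assuming the same app, job_results, and jsonify names as in the diff:

@app.route('/status/<job_id>', methods=['GET'])
def job_status(job_id):
    if job_id in job_results:
        # pop() removes the finished entry, so completed jobs
        # no longer accumulate in memory after first retrieval
        result = job_results.pop(job_id)
        if result.get("success", False):
            return jsonify(result)
        return jsonify({"error": result.get("error", "Unknown error")}), 500
    # Not finished (or unknown job_id): report as still processing
    return jsonify({
        "success": None,
        "message": "Job is still processing",
        "job_id": job_id
    })

The tradeoff is that each result can be read only once; a second poll after completion would report the job as still processing, so a client should keep the download URLs from the first successful response.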