mac9087 committed on
Commit cd1cc5d · verified · 1 Parent(s): 3a58e1b

Update app.py

Files changed (1):
  1. app.py +231 -55

app.py CHANGED
@@ -1,10 +1,9 @@
-# app.py
-
 import os
 import torch
 import time
 import threading
 import json
+import gc
 from flask import Flask, request, jsonify, send_file, Response, stream_with_context
 from werkzeug.utils import secure_filename
 from PIL import Image
@@ -15,7 +14,9 @@ import traceback
 from diffusers import ShapEImg2ImgPipeline
 from diffusers.utils import export_to_obj
 from huggingface_hub import snapshot_download
-from flask_cors import CORS # Import CORS
+from flask_cors import CORS
+import signal
+import functools
 
 app = Flask(__name__)
 CORS(app) # Enable CORS for all routes
@@ -42,43 +43,130 @@ app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max
 # Job tracking dictionary
 processing_jobs = {}
 
-# Lazy loading for the model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+# Global model variable
 pipe = None
+model_loaded = False
+model_loading = False
 
-def load_model():
-    global pipe
-    if pipe is None:
-        try:
-            model_name = "openai/shap-e-img2img"
-
-            # Download model
-            snapshot_download(
-                repo_id=model_name,
-                cache_dir=CACHE_DIR,
-                resume_download=True,
-            )
-
-            # Initialize pipeline
-            pipe = ShapEImg2ImgPipeline.from_pretrained(
-                model_name,
-                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-                cache_dir=CACHE_DIR,
-            )
-            pipe = pipe.to(device)
-            print(f"Model loaded successfully on {device}")
-        except Exception as e:
-            print(f"Error loading model: {str(e)}")
-            print(traceback.format_exc())
-            raise
-    return pipe
+# Configuration for processing
+TIMEOUT_SECONDS = 300 # 5 minutes max for processing
+MAX_DIMENSION = 512 # Max image dimension to process
+
+# Timeout handler for long-running processes
+class TimeoutError(Exception):
+    pass
+
+def timeout_handler(signum, frame):
+    raise TimeoutError("Processing timed out")
+
+def with_timeout(timeout):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # Set the timeout handler
+            signal.signal(signal.SIGALRM, timeout_handler)
+            signal.alarm(timeout)
+            try:
+                result = func(*args, **kwargs)
+            finally:
+                # Disable the alarm
+                signal.alarm(0)
+            return result
+        return wrapper
+    return decorator
 
 def allowed_file(filename):
     return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
 
+# Function to preprocess image - resize if needed
+def preprocess_image(image_path):
+    with Image.open(image_path) as img:
+        img = img.convert("RGB")
+        # Resize if the image is too large
+        if img.width > MAX_DIMENSION or img.height > MAX_DIMENSION:
+            # Calculate new dimensions while preserving aspect ratio
+            if img.width > img.height:
+                new_width = MAX_DIMENSION
+                new_height = int(img.height * (MAX_DIMENSION / img.width))
+            else:
+                new_height = MAX_DIMENSION
+                new_width = int(img.width * (MAX_DIMENSION / img.height))
+            img = img.resize((new_width, new_height), Image.LANCZOS)
+
+        # Convert to RGB and return
+        return img
+
+def load_model():
+    global pipe, model_loaded, model_loading
+
+    if model_loaded:
+        return pipe
+
+    if model_loading:
+        # Wait for model to load if it's already in progress
+        while model_loading and not model_loaded:
+            time.sleep(0.5)
+        return pipe
+
+    try:
+        model_loading = True
+        print("Starting model loading...")
+
+        model_name = "openai/shap-e-img2img"
+
+        # Download model with retry mechanism
+        max_retries = 3
+        retry_delay = 5
+
+        for attempt in range(max_retries):
+            try:
+                snapshot_download(
+                    repo_id=model_name,
+                    cache_dir=CACHE_DIR,
+                    resume_download=True,
+                )
+                break
+            except Exception as e:
+                if attempt < max_retries - 1:
+                    print(f"Download attempt {attempt+1} failed: {str(e)}. Retrying in {retry_delay} seconds...")
+                    time.sleep(retry_delay)
+                    retry_delay *= 2
+                else:
+                    raise
+
+        # Initialize pipeline with lower precision to save memory
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        pipe = ShapEImg2ImgPipeline.from_pretrained(
+            model_name,
+            torch_dtype=dtype,
+            cache_dir=CACHE_DIR,
+        )
+        pipe = pipe.to(device)
+
+        # Optimize for inference
+        if device == "cuda":
+            pipe.enable_model_cpu_offload()
+
+        model_loaded = True
+        print(f"Model loaded successfully on {device}")
+        return pipe
+
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        print(traceback.format_exc())
+        raise
+    finally:
+        model_loading = False
+
 @app.route('/health', methods=['GET'])
 def health_check():
-    return jsonify({"status": "healthy", "model": "Shap-E Image to 3D"}), 200
+    return jsonify({
+        "status": "healthy",
+        "model": "Shap-E Image to 3D",
+        "device": "cuda" if torch.cuda.is_available() else "cpu"
+    }), 200
 
 @app.route('/progress/<job_id>', methods=['GET'])
 def progress(job_id):
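Note on the timeout decorator added above: signal.signal(signal.SIGALRM, ...) may only be called from the main thread (CPython raises ValueError otherwise), and SIGALRM does not exist on Windows, yet with_timeout is applied further down to process_with_timeout, which runs inside a worker thread. A minimal sketch of a thread-safe alternative using only the standard library; the name with_timeout_threadsafe is illustrative and not part of this commit:

import concurrent.futures
import functools

def with_timeout_threadsafe(timeout):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Run the call in a helper thread and bound the wait; unlike
            # signal.alarm(), this works from any thread and on any OS.
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            future = executor.submit(func, *args, **kwargs)
            try:
                return future.result(timeout=timeout)
            except concurrent.futures.TimeoutError:
                raise TimeoutError("Processing timed out")
            finally:
                # Do not block on the worker; the wrapped call itself keeps
                # running to completion, since Python threads cannot be killed.
                executor.shutdown(wait=False)
        return wrapper
    return decorator

Unlike the SIGALRM version, this bounds only the caller's wait; the underlying pipeline call is not interrupted.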
@@ -94,11 +182,22 @@ def progress(job_id):
 
         # Wait for job to complete or update
         last_progress = job['progress']
+        check_count = 0
         while job['status'] == 'processing':
             if job['progress'] != last_progress:
                 yield f"data: {json.dumps({'status': 'processing', 'progress': job['progress']})}\n\n"
                 last_progress = job['progress']
+
             time.sleep(0.5)
+            check_count += 1
+
+            # If client hasn't received updates for a while, check if job is still running
+            if check_count > 60: # 30 seconds with no updates
+                if 'thread_alive' in job and not job['thread_alive']():
+                    job['status'] = 'error'
+                    job['error'] = 'Processing thread died unexpectedly'
+                    break
+                check_count = 0
 
         # Send final status
         if job['status'] == 'completed':
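For reference, the /progress/<job_id> route above emits standard server-sent events, so any streaming HTTP client can consume it. A minimal sketch using requests; the base URL is a placeholder, not part of this commit:

import json
import requests

BASE_URL = "http://localhost:7860"  # placeholder deployment URL

def watch_progress(job_id):
    # SSE frames arrive as "data: {...}" lines on a chunked response.
    with requests.get(f"{BASE_URL}/progress/{job_id}", stream=True) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                print(event)
                if event.get("status") in ("completed", "error"):
                    break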
@@ -121,10 +220,20 @@ def convert_image_to_3d():
     if not allowed_file(file.filename):
         return jsonify({"error": f"File type not allowed. Supported types: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
 
-    # Get optional parameters
-    guidance_scale = float(request.form.get('guidance_scale', 3.0))
-    num_inference_steps = int(request.form.get('num_inference_steps', 64))
-    output_format = request.form.get('output_format', 'obj').lower()
+    # Get optional parameters with defaults
+    try:
+        guidance_scale = float(request.form.get('guidance_scale', 3.0))
+        num_inference_steps = int(request.form.get('num_inference_steps', 64))
+        output_format = request.form.get('output_format', 'obj').lower()
+    except ValueError:
+        return jsonify({"error": "Invalid parameter values"}), 400
+
+    # Validate parameters
+    if guidance_scale < 1.0 or guidance_scale > 5.0:
+        return jsonify({"error": "Guidance scale must be between 1.0 and 5.0"}), 400
+
+    if num_inference_steps < 32 or num_inference_steps > 128:
+        return jsonify({"error": "Number of inference steps must be between 32 and 128"}), 400
 
     # Validate output format
     if output_format not in ['obj', 'glb']:
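The validation above implies a request shaped roughly as follows. A hedged usage sketch with requests: the upload field name "file" is assumed (the hunk shows file.filename but not the request.files lookup), and the URL and image path are placeholders:

import requests

BASE_URL = "http://localhost:7860"  # placeholder

with open("input.png", "rb") as f:  # placeholder image
    resp = requests.post(
        f"{BASE_URL}/convert",
        files={"file": f},  # assumed field name
        data={
            "guidance_scale": "3.0",  # accepted range: 1.0 to 5.0
            "num_inference_steps": "64",  # accepted range: 32 to 128
            "output_format": "obj",  # 'obj' or 'glb'
        },
    )
print(resp.status_code, resp.json())  # expect 202 with {"job_id": ...}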
@@ -137,7 +246,7 @@ def convert_image_to_3d():
 
     # Save the uploaded file
     filename = secure_filename(file.filename)
-    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+    filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{job_id}_{filename}")
     file.save(filepath)
 
     # Initialize job tracking
@@ -147,28 +256,44 @@ def convert_image_to_3d():
         'result_url': None,
         'preview_url': None,
         'error': None,
-        'output_format': output_format
+        'output_format': output_format,
+        'created_at': time.time()
     }
 
+    # Process function with timeout
+    @with_timeout(TIMEOUT_SECONDS)
+    def process_with_timeout(image, steps, scale, format):
+        # Load model
+        pipe = load_model()
+        processing_jobs[job_id]['progress'] = 30
+
+        # Generate 3D model
+        return pipe(
+            image,
+            guidance_scale=scale,
+            num_inference_steps=steps,
+            output_type="mesh",
+        ).images
+
     # Start processing in a separate thread
     def process_image():
+        thread = threading.current_thread()
+        processing_jobs[job_id]['thread_alive'] = lambda: thread.is_alive()
+
         try:
-            # Open image
-            image = Image.open(filepath).convert("RGB")
+            # Preprocess image (resize if needed)
+            processing_jobs[job_id]['progress'] = 5
+            image = preprocess_image(filepath)
             processing_jobs[job_id]['progress'] = 10
 
-            # Load model
-            pipe = load_model()
-            processing_jobs[job_id]['progress'] = 30
-
-            # Generate 3D model
-            images = pipe(
-                image,
-                guidance_scale=guidance_scale,
-                num_inference_steps=num_inference_steps,
-                output_type="mesh",
-            ).images
-            processing_jobs[job_id]['progress'] = 80
+            # Process image with timeout
+            try:
+                images = process_with_timeout(image, num_inference_steps, guidance_scale, output_format)
+                processing_jobs[job_id]['progress'] = 80
+            except TimeoutError:
+                processing_jobs[job_id]['status'] = 'error'
+                processing_jobs[job_id]['error'] = f"Processing timed out after {TIMEOUT_SECONDS} seconds"
+                return
 
             # Export based on requested format
             if output_format == 'obj':
@@ -206,6 +331,15 @@ def convert_image_to_3d():
             processing_jobs[job_id]['status'] = 'completed'
             processing_jobs[job_id]['progress'] = 100
 
+            # Clean up temporary file
+            if os.path.exists(filepath):
+                os.remove(filepath)
+
+            # Force garbage collection to free memory
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
         except Exception as e:
             # Handle errors
             error_details = traceback.format_exc()
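One detail on the memory cleanup added above: torch.cuda.empty_cache() only releases cached blocks that are no longer referenced, so dropping references to large intermediates first reclaims more. A small sketch, assuming images is no longer needed at that point:

# Release references first, then collect, then empty the CUDA cache;
# empty_cache() cannot reclaim memory still held by live tensors.
del images
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()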
@@ -213,9 +347,15 @@ def convert_image_to_3d():
             processing_jobs[job_id]['error'] = f"{str(e)}\n{error_details}"
             print(f"Error processing job {job_id}: {str(e)}")
             print(error_details)
+
+            # Clean up on error
+            if os.path.exists(filepath):
+                os.remove(filepath)
 
     # Start processing thread
-    threading.Thread(target=process_image).start()
+    processing_thread = threading.Thread(target=process_image)
+    processing_thread.daemon = True
+    processing_thread.start()
 
     # Return job ID immediately
     return jsonify({"job_id": job_id}), 202 # 202 Accepted
@@ -262,11 +402,47 @@ def preview_model(job_id):
 
     return jsonify({"error": "Model file not found"}), 404
 
+# Cleanup old jobs periodically
+def cleanup_old_jobs():
+    current_time = time.time()
+    job_ids_to_remove = []
+
+    for job_id, job_data in processing_jobs.items():
+        # Remove completed jobs after 1 hour
+        if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600:
+            job_ids_to_remove.append(job_id)
+        # Remove error jobs after 30 minutes
+        elif job_data['status'] == 'error' and (current_time - job_data.get('created_at', 0)) > 1800:
+            job_ids_to_remove.append(job_id)
+
+    # Remove the jobs
+    for job_id in job_ids_to_remove:
+        output_dir = os.path.join(RESULTS_FOLDER, job_id)
+        try:
+            import shutil
+            if os.path.exists(output_dir):
+                shutil.rmtree(output_dir)
+        except Exception as e:
+            print(f"Error cleaning up job {job_id}: {str(e)}")
+
+        # Remove from tracking dictionary
+        if job_id in processing_jobs:
+            del processing_jobs[job_id]
+
+    # Schedule the next cleanup
+    threading.Timer(300, cleanup_old_jobs).start() # Run every 5 minutes
+
 @app.route('/', methods=['GET'])
 def index():
-    return jsonify({"message": "Image to 3D API is running", "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]}), 200
+    return jsonify({
+        "message": "Image to 3D API is running",
+        "endpoints": ["/convert", "/progress/<job_id>", "/download/<job_id>", "/preview/<job_id>"]
+    }), 200
 
 if __name__ == '__main__':
+    # Start the cleanup thread
+    cleanup_old_jobs()
+
     # Use port 7860 which is standard for Hugging Face Spaces
     port = int(os.environ.get('PORT', 7860))
-    app.run(host='0.0.0.0', port=port)
+    app.run(host='0.0.0.0', port=port)
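A last note on cleanup_old_jobs: it iterates processing_jobs.items() while request handlers and worker threads may be inserting entries concurrently, which CPython can abort with "RuntimeError: dictionary changed size during iteration". A minimal defensive variant iterates over a snapshot:

# Snapshot the items so concurrent inserts from other threads
# cannot invalidate the iterator mid-loop.
for job_id, job_data in list(processing_jobs.items()):
    if job_data['status'] == 'completed' and (current_time - job_data.get('created_at', 0)) > 3600:
        job_ids_to_remove.append(job_id)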
 
 
 