mac9087 commited on
Commit
d6ba12d
·
verified ·
1 Parent(s): db5a874

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -15
app.py CHANGED
@@ -9,12 +9,19 @@ import base64
9
  from PIL import Image
10
  import uuid
11
  import time
 
12
 
13
  # Import Shap-E
14
- from shap_e.diffusion.sample import sample_latents
15
- from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
16
- from shap_e.models.download import load_model, load_config
17
- from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
 
 
 
 
 
 
18
 
19
  app = Flask(__name__)
20
  CORS(app)
@@ -22,19 +29,35 @@ CORS(app)
22
  # Create output directory if it doesn't exist
23
  os.makedirs("outputs", exist_ok=True)
24
 
25
- # Load models only once at startup
26
- print("Loading models...")
27
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
  print(f"Using device: {device}")
29
 
30
- xm = load_model('transmitter', device=device)
31
- model = load_model('text300M', device=device)
32
- diffusion = diffusion_from_config(load_config('diffusion'))
33
- print("Models loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  @app.route('/generate', methods=['POST'])
36
  def generate_3d():
37
  try:
 
 
 
38
  # Get the prompt from the request
39
  data = request.json
40
  if not data or 'prompt' not in data:
@@ -43,11 +66,12 @@ def generate_3d():
43
  prompt = data['prompt']
44
  print(f"Received prompt: {prompt}")
45
 
46
- # Set parameters
47
  batch_size = 1
48
  guidance_scale = 15.0
49
 
50
  # Generate latents with the text-to-3D model
 
51
  latents = sample_latents(
52
  batch_size=batch_size,
53
  model=model,
@@ -56,30 +80,36 @@ def generate_3d():
56
  model_kwargs=dict(texts=[prompt] * batch_size),
57
  progress=True,
58
  clip_denoised=True,
59
- use_fp16=True,
60
  use_karras=True,
61
- karras_steps=64,
62
  sigma_min=1e-3,
63
  sigma_max=160,
64
  s_churn=0,
65
  )
 
66
 
67
  # Generate a unique filename
68
  filename = f"outputs/{uuid.uuid4()}"
69
 
70
  # Convert latent to mesh
 
71
  t0 = time.time()
72
  mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
73
  print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
74
 
75
  # Save as GLB
 
76
  glb_path = f"{filename}.glb"
77
  mesh.write_glb(glb_path)
78
 
79
  # Save as OBJ
 
80
  obj_path = f"{filename}.obj"
81
  with open(obj_path, 'w') as f:
82
  mesh.write_obj(f)
 
 
83
 
84
  # Return paths to the generated files
85
  return jsonify({
@@ -90,7 +120,9 @@ def generate_3d():
90
  })
91
 
92
  except Exception as e:
93
- print(f"Error: {str(e)}")
 
 
94
  return jsonify({"error": str(e)}), 500
95
 
96
  @app.route('/download/<filename>', methods=['GET'])
@@ -100,5 +132,33 @@ def download_file(filename):
100
  except Exception as e:
101
  return jsonify({"error": str(e)}), 404
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  if __name__ == '__main__':
104
  app.run(host='0.0.0.0', port=7860, debug=True)
 
9
  from PIL import Image
10
  import uuid
11
  import time
12
+ import sys
13
 
14
  # Import Shap-E
15
+ print("Importing Shap-E modules...")
16
+ try:
17
+ from shap_e.diffusion.sample import sample_latents
18
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
19
+ from shap_e.models.download import load_model, load_config
20
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
21
+ print("Shap-E modules imported successfully!")
22
+ except Exception as e:
23
+ print(f"Error importing Shap-E modules: {e}")
24
+ sys.exit(1)
25
 
26
  app = Flask(__name__)
27
  CORS(app)
 
29
  # Create output directory if it doesn't exist
30
  os.makedirs("outputs", exist_ok=True)
31
 
32
+ # Use lazy loading for models
33
+ print("Setting up device...")
34
+ device = torch.device('cpu') # Force CPU for Hugging Face Spaces
35
  print(f"Using device: {device}")
36
 
37
+ # Global variables for models (will be loaded on first request)
38
+ xm = None
39
+ model = None
40
+ diffusion = None
41
+
42
+ def load_models_if_needed():
43
+ global xm, model, diffusion
44
+ if xm is None or model is None or diffusion is None:
45
+ print("Loading models for the first time...")
46
+ try:
47
+ xm = load_model('transmitter', device=device)
48
+ model = load_model('text300M', device=device)
49
+ diffusion = diffusion_from_config(load_config('diffusion'))
50
+ print("Models loaded successfully!")
51
+ except Exception as e:
52
+ print(f"Error loading models: {e}")
53
+ raise
54
 
55
  @app.route('/generate', methods=['POST'])
56
  def generate_3d():
57
  try:
58
+ # Load models if not already loaded
59
+ load_models_if_needed()
60
+
61
  # Get the prompt from the request
62
  data = request.json
63
  if not data or 'prompt' not in data:
 
66
  prompt = data['prompt']
67
  print(f"Received prompt: {prompt}")
68
 
69
+ # Set parameters for CPU performance (reduced steps)
70
  batch_size = 1
71
  guidance_scale = 15.0
72
 
73
  # Generate latents with the text-to-3D model
74
+ print("Starting latent generation...")
75
  latents = sample_latents(
76
  batch_size=batch_size,
77
  model=model,
 
80
  model_kwargs=dict(texts=[prompt] * batch_size),
81
  progress=True,
82
  clip_denoised=True,
83
+ use_fp16=False, # CPU doesn't support fp16
84
  use_karras=True,
85
+ karras_steps=32, # Reduced steps for CPU
86
  sigma_min=1e-3,
87
  sigma_max=160,
88
  s_churn=0,
89
  )
90
+ print("Latent generation complete!")
91
 
92
  # Generate a unique filename
93
  filename = f"outputs/{uuid.uuid4()}"
94
 
95
  # Convert latent to mesh
96
+ print("Decoding mesh...")
97
  t0 = time.time()
98
  mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
99
  print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
100
 
101
  # Save as GLB
102
+ print("Saving as GLB...")
103
  glb_path = f"{filename}.glb"
104
  mesh.write_glb(glb_path)
105
 
106
  # Save as OBJ
107
+ print("Saving as OBJ...")
108
  obj_path = f"{filename}.obj"
109
  with open(obj_path, 'w') as f:
110
  mesh.write_obj(f)
111
+
112
+ print("Files saved successfully!")
113
 
114
  # Return paths to the generated files
115
  return jsonify({
 
120
  })
121
 
122
  except Exception as e:
123
+ print(f"Error during generation: {str(e)}")
124
+ import traceback
125
+ traceback.print_exc()
126
  return jsonify({"error": str(e)}), 500
127
 
128
  @app.route('/download/<filename>', methods=['GET'])
 
132
  except Exception as e:
133
  return jsonify({"error": str(e)}), 404
134
 
135
+ @app.route('/health', methods=['GET'])
136
+ def health_check():
137
+ """Simple health check endpoint to verify the app is running"""
138
+ return jsonify({"status": "ok", "message": "Service is running"})
139
+
140
+ @app.route('/', methods=['GET'])
141
+ def home():
142
+ """Landing page with usage instructions"""
143
+ return """
144
+ <html>
145
+ <head><title>Text to 3D API</title></head>
146
+ <body>
147
+ <h1>Text to 3D API</h1>
148
+ <p>This is a simple API that converts text prompts to 3D models.</p>
149
+ <h2>How to use:</h2>
150
+ <pre>
151
+ POST /generate
152
+ Content-Type: application/json
153
+
154
+ {
155
+ "prompt": "A futuristic building"
156
+ }
157
+ </pre>
158
+ <p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
159
+ </body>
160
+ </html>
161
+ """
162
+
163
  if __name__ == '__main__':
164
  app.run(host='0.0.0.0', port=7860, debug=True)