Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,12 +9,19 @@ import base64
|
|
9 |
from PIL import Image
|
10 |
import uuid
|
11 |
import time
|
|
|
12 |
|
13 |
# Import Shap-E
|
14 |
-
|
15 |
-
|
16 |
-
from shap_e.
|
17 |
-
from shap_e.
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
app = Flask(__name__)
|
20 |
CORS(app)
|
@@ -22,19 +29,35 @@ CORS(app)
|
|
22 |
# Create output directory if it doesn't exist
|
23 |
os.makedirs("outputs", exist_ok=True)
|
24 |
|
25 |
-
#
|
26 |
-
print("
|
27 |
-
device = torch.device('
|
28 |
print(f"Using device: {device}")
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
@app.route('/generate', methods=['POST'])
|
36 |
def generate_3d():
|
37 |
try:
|
|
|
|
|
|
|
38 |
# Get the prompt from the request
|
39 |
data = request.json
|
40 |
if not data or 'prompt' not in data:
|
@@ -43,11 +66,12 @@ def generate_3d():
|
|
43 |
prompt = data['prompt']
|
44 |
print(f"Received prompt: {prompt}")
|
45 |
|
46 |
-
# Set parameters
|
47 |
batch_size = 1
|
48 |
guidance_scale = 15.0
|
49 |
|
50 |
# Generate latents with the text-to-3D model
|
|
|
51 |
latents = sample_latents(
|
52 |
batch_size=batch_size,
|
53 |
model=model,
|
@@ -56,30 +80,36 @@ def generate_3d():
|
|
56 |
model_kwargs=dict(texts=[prompt] * batch_size),
|
57 |
progress=True,
|
58 |
clip_denoised=True,
|
59 |
-
use_fp16=
|
60 |
use_karras=True,
|
61 |
-
karras_steps=
|
62 |
sigma_min=1e-3,
|
63 |
sigma_max=160,
|
64 |
s_churn=0,
|
65 |
)
|
|
|
66 |
|
67 |
# Generate a unique filename
|
68 |
filename = f"outputs/{uuid.uuid4()}"
|
69 |
|
70 |
# Convert latent to mesh
|
|
|
71 |
t0 = time.time()
|
72 |
mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
|
73 |
print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
|
74 |
|
75 |
# Save as GLB
|
|
|
76 |
glb_path = f"{filename}.glb"
|
77 |
mesh.write_glb(glb_path)
|
78 |
|
79 |
# Save as OBJ
|
|
|
80 |
obj_path = f"{filename}.obj"
|
81 |
with open(obj_path, 'w') as f:
|
82 |
mesh.write_obj(f)
|
|
|
|
|
83 |
|
84 |
# Return paths to the generated files
|
85 |
return jsonify({
|
@@ -90,7 +120,9 @@ def generate_3d():
|
|
90 |
})
|
91 |
|
92 |
except Exception as e:
|
93 |
-
print(f"Error: {str(e)}")
|
|
|
|
|
94 |
return jsonify({"error": str(e)}), 500
|
95 |
|
96 |
@app.route('/download/<filename>', methods=['GET'])
|
@@ -100,5 +132,33 @@ def download_file(filename):
|
|
100 |
except Exception as e:
|
101 |
return jsonify({"error": str(e)}), 404
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
if __name__ == '__main__':
|
104 |
app.run(host='0.0.0.0', port=7860, debug=True)
|
|
|
9 |
from PIL import Image
|
10 |
import uuid
|
11 |
import time
|
12 |
+
import sys
|
13 |
|
14 |
# Import Shap-E
|
15 |
+
print("Importing Shap-E modules...")
|
16 |
+
try:
|
17 |
+
from shap_e.diffusion.sample import sample_latents
|
18 |
+
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
19 |
+
from shap_e.models.download import load_model, load_config
|
20 |
+
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
|
21 |
+
print("Shap-E modules imported successfully!")
|
22 |
+
except Exception as e:
|
23 |
+
print(f"Error importing Shap-E modules: {e}")
|
24 |
+
sys.exit(1)
|
25 |
|
26 |
app = Flask(__name__)
|
27 |
CORS(app)
|
|
|
29 |
# Create output directory if it doesn't exist
|
30 |
os.makedirs("outputs", exist_ok=True)
|
31 |
|
32 |
+
# Use lazy loading for models
|
33 |
+
print("Setting up device...")
|
34 |
+
device = torch.device('cpu') # Force CPU for Hugging Face Spaces
|
35 |
print(f"Using device: {device}")
|
36 |
|
37 |
+
# Global variables for models (will be loaded on first request)
|
38 |
+
xm = None
|
39 |
+
model = None
|
40 |
+
diffusion = None
|
41 |
+
|
42 |
+
def load_models_if_needed():
|
43 |
+
global xm, model, diffusion
|
44 |
+
if xm is None or model is None or diffusion is None:
|
45 |
+
print("Loading models for the first time...")
|
46 |
+
try:
|
47 |
+
xm = load_model('transmitter', device=device)
|
48 |
+
model = load_model('text300M', device=device)
|
49 |
+
diffusion = diffusion_from_config(load_config('diffusion'))
|
50 |
+
print("Models loaded successfully!")
|
51 |
+
except Exception as e:
|
52 |
+
print(f"Error loading models: {e}")
|
53 |
+
raise
|
54 |
|
55 |
@app.route('/generate', methods=['POST'])
|
56 |
def generate_3d():
|
57 |
try:
|
58 |
+
# Load models if not already loaded
|
59 |
+
load_models_if_needed()
|
60 |
+
|
61 |
# Get the prompt from the request
|
62 |
data = request.json
|
63 |
if not data or 'prompt' not in data:
|
|
|
66 |
prompt = data['prompt']
|
67 |
print(f"Received prompt: {prompt}")
|
68 |
|
69 |
+
# Set parameters for CPU performance (reduced steps)
|
70 |
batch_size = 1
|
71 |
guidance_scale = 15.0
|
72 |
|
73 |
# Generate latents with the text-to-3D model
|
74 |
+
print("Starting latent generation...")
|
75 |
latents = sample_latents(
|
76 |
batch_size=batch_size,
|
77 |
model=model,
|
|
|
80 |
model_kwargs=dict(texts=[prompt] * batch_size),
|
81 |
progress=True,
|
82 |
clip_denoised=True,
|
83 |
+
use_fp16=False, # CPU doesn't support fp16
|
84 |
use_karras=True,
|
85 |
+
karras_steps=32, # Reduced steps for CPU
|
86 |
sigma_min=1e-3,
|
87 |
sigma_max=160,
|
88 |
s_churn=0,
|
89 |
)
|
90 |
+
print("Latent generation complete!")
|
91 |
|
92 |
# Generate a unique filename
|
93 |
filename = f"outputs/{uuid.uuid4()}"
|
94 |
|
95 |
# Convert latent to mesh
|
96 |
+
print("Decoding mesh...")
|
97 |
t0 = time.time()
|
98 |
mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
|
99 |
print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
|
100 |
|
101 |
# Save as GLB
|
102 |
+
print("Saving as GLB...")
|
103 |
glb_path = f"{filename}.glb"
|
104 |
mesh.write_glb(glb_path)
|
105 |
|
106 |
# Save as OBJ
|
107 |
+
print("Saving as OBJ...")
|
108 |
obj_path = f"{filename}.obj"
|
109 |
with open(obj_path, 'w') as f:
|
110 |
mesh.write_obj(f)
|
111 |
+
|
112 |
+
print("Files saved successfully!")
|
113 |
|
114 |
# Return paths to the generated files
|
115 |
return jsonify({
|
|
|
120 |
})
|
121 |
|
122 |
except Exception as e:
|
123 |
+
print(f"Error during generation: {str(e)}")
|
124 |
+
import traceback
|
125 |
+
traceback.print_exc()
|
126 |
return jsonify({"error": str(e)}), 500
|
127 |
|
128 |
@app.route('/download/<filename>', methods=['GET'])
|
|
|
132 |
except Exception as e:
|
133 |
return jsonify({"error": str(e)}), 404
|
134 |
|
135 |
+
@app.route('/health', methods=['GET'])
|
136 |
+
def health_check():
|
137 |
+
"""Simple health check endpoint to verify the app is running"""
|
138 |
+
return jsonify({"status": "ok", "message": "Service is running"})
|
139 |
+
|
140 |
+
@app.route('/', methods=['GET'])
|
141 |
+
def home():
|
142 |
+
"""Landing page with usage instructions"""
|
143 |
+
return """
|
144 |
+
<html>
|
145 |
+
<head><title>Text to 3D API</title></head>
|
146 |
+
<body>
|
147 |
+
<h1>Text to 3D API</h1>
|
148 |
+
<p>This is a simple API that converts text prompts to 3D models.</p>
|
149 |
+
<h2>How to use:</h2>
|
150 |
+
<pre>
|
151 |
+
POST /generate
|
152 |
+
Content-Type: application/json
|
153 |
+
|
154 |
+
{
|
155 |
+
"prompt": "A futuristic building"
|
156 |
+
}
|
157 |
+
</pre>
|
158 |
+
<p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
|
159 |
+
</body>
|
160 |
+
</html>
|
161 |
+
"""
|
162 |
+
|
163 |
if __name__ == '__main__':
|
164 |
app.run(host='0.0.0.0', port=7860, debug=True)
|