Update app.py
app.py CHANGED
@@ -1,136 +1,104 @@
-import os
-import tempfile
-import numpy as np
-import torch
 from flask import Flask, request, jsonify, send_file
 from flask_cors import CORS
-
-
-import
-import
-
-
-
-
-
-# Fix matplotlib and clip cache path issues
-os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
-os.environ['HOME'] = '/tmp'
-
-# Set cache directories explicitly
-cache_dir = os.environ.get('POINT_E_CACHE_DIR', '/tmp/point_e_model_cache')
-clip_cache_dir = os.environ.get('CLIP_MODEL_DIR', '/tmp/clip_models')
+import torch
+import numpy as np
+import trimesh
+import os
+from io import BytesIO
+import base64
+from PIL import Image
+import uuid
+import time
 
-#
-
-
-
-
-    os.chmod(directory, 0o777)
+# Import Shap-E
+from shap_e.diffusion.sample import sample_latents
+from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
+from shap_e.models.download import load_model, load_config
+from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
 
 app = Flask(__name__)
 CORS(app)
 
-#
-
-logger.info(f"Using device: {device}")
-
-# Set up model directory for downloads
-model_path = "/tmp/point_e_models/base40M-textvec.pt"
-if not os.path.exists(model_path):
-    logger.info("Model weights not found. Downloading...")
-    # Use torch.hub to download the model from original repo
-    # This is a placeholder. In reality, you'd need to download the weights from the official source
-    # or include them in your repository
-    try:
-        # For demonstration purposes, we'll assume the model is included with the repo
-        # In a real setup, you'd need to download it or include it
-        pass
-    except Exception as e:
-        logger.error(f"Failed to download model: {e}")
-        raise
-
-# Load the model
-logger.info("Loading base model...")
-try:
-    base_model = model_from_config(
-        MODEL_CONFIGS["base40M-textvec"],
-        device=device,
-        cache_dir=cache_dir  # Pass cache directory explicitly
-    )
-    base_model.load_state_dict(torch.load(model_path, map_location=device))
-    base_model.eval()
-    logger.info("Base model loaded successfully")
-except Exception as e:
-    logger.error(f"Failed to load model: {e}")
-    raise
-
-# Create sampler with memory-efficient settings for 2vCPU/18GB RAM
-logger.info("Creating sampler...")
-sampler = PointCloudSampler(
-    device=device,
-    models=[base_model],
-    diffusion=None,
-    num_points=1024,  # Lower point count for memory efficiency
-    aux_channels=[],
-    guidance_scale=3.0,
-)
+# Create output directory if it doesn't exist
+os.makedirs("outputs", exist_ok=True)
 
-
-
-
-
-    o3d.io.write_point_cloud(out_path, pcd)
-    logger.info(f"Point cloud saved to {out_path}")
+# Load models only once at startup
+print("Loading models...")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print(f"Using device: {device}")
 
-
-
-
-
-    <html>
-    <head><title>Point-E Text-to-3D API</title></head>
-    <body>
-    <h1>Point-E Text-to-3D Generator</h1>
-    <form action="/generate" method="post">
-    <label for="prompt">Enter a description:</label>
-    <input type="text" id="prompt" name="prompt" value="a toy robot">
-    <input type="submit" value="Generate 3D Model">
-    </form>
-    <p>Or use the API endpoint at POST /generate with JSON {"prompt": "your description"}</p>
-    </body>
-    </html>
-    """
+xm = load_model('transmitter', device=device)
+model = load_model('text300M', device=device)
+diffusion = diffusion_from_config(load_config('diffusion'))
+print("Models loaded successfully!")
 
-@app.route(
-def
-    """Generate a 3D point cloud from a text description."""
-    if request.is_json:
-        prompt = request.json.get("prompt", "a toy robot")
-    else:
-        prompt = request.form.get("prompt", "a toy robot")
-
-    logger.info(f"Generating point cloud for prompt: '{prompt}'")
-
+@app.route('/generate', methods=['POST'])
+def generate_3d():
     try:
-        #
-
-
-
+        # Get the prompt from the request
+        data = request.json
+        if not data or 'prompt' not in data:
+            return jsonify({"error": "No prompt provided"}), 400
+
+        prompt = data['prompt']
+        print(f"Received prompt: {prompt}")
+
+        # Set parameters
+        batch_size = 1
+        guidance_scale = 15.0
 
-        #
-
-
-
+        # Generate latents with the text-to-3D model
+        latents = sample_latents(
+            batch_size=batch_size,
+            model=model,
+            diffusion=diffusion,
+            guidance_scale=guidance_scale,
+            model_kwargs=dict(texts=[prompt] * batch_size),
+            progress=True,
+            clip_denoised=True,
+            use_fp16=True,
+            use_karras=True,
+            karras_steps=64,
+            sigma_min=1e-3,
+            sigma_max=160,
+            s_churn=0,
+        )
 
-
+        # Generate a unique filename
+        filename = f"outputs/{uuid.uuid4()}"
+
+        # Convert latent to mesh
+        t0 = time.time()
+        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
+        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
+
+        # Save as GLB
+        glb_path = f"{filename}.glb"
+        mesh.write_glb(glb_path)
+
+        # Save as OBJ
+        obj_path = f"{filename}.obj"
+        with open(obj_path, 'w') as f:
+            mesh.write_obj(f)
+
+        # Return paths to the generated files
+        return jsonify({
+            "success": True,
+            "message": "3D model generated successfully",
+            "glb_url": f"/download/{os.path.basename(glb_path)}",
+            "obj_url": f"/download/{os.path.basename(obj_path)}"
+        })
+
     except Exception as e:
-
+        print(f"Error: {str(e)}")
         return jsonify({"error": str(e)}), 500
 
-@app.route(
-def
-
-
+@app.route('/download/<filename>', methods=['GET'])
+def download_file(filename):
+    try:
+        return send_file(f"outputs/{filename}", as_attachment=True)
+    except Exception as e:
+        return jsonify({"error": str(e)}), 404
 
-if __name__ ==
-app.run(host=
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=7860, debug=True)
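
Review note: decode_latent_mesh(xm, latents[0]).tri_mesh() returns shap-e's TriMesh, which in the public shap-e repository exposes write_ply and write_obj writers; I am not aware of a write_glb method, so the mesh.write_glb(glb_path) call added above may raise AttributeError. A minimal fallback sketch, assuming TriMesh's verts/faces attribute names and using the trimesh package this commit already imports (save_glb is a hypothetical helper, not part of the commit):

import trimesh

def save_glb(mesh, glb_path):
    # Build a trimesh.Trimesh from the shap-e mesh data; trimesh picks the
    # GLB writer from the .glb file extension.
    # Assumption: mesh.verts / mesh.faces are numpy arrays, as in shap-e's TriMesh.
    tm = trimesh.Trimesh(vertices=mesh.verts, faces=mesh.faces)
    tm.export(glb_path)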
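
Review note: sample_latents is called with use_fp16=True even though the startup code can fall back to CPU. The upstream shap-e examples pair fp16 with CUDA; if CPU sampling misbehaves, a one-line guard (an assumption on my part, not part of this commit) keeps half precision for GPU only:

# Hypothetical guard: enable half precision only when running on CUDA,
# then pass use_fp16=use_fp16 to sample_latents instead of the literal True.
use_fp16 = (device.type == 'cuda')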
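
Review note: /download serves f"outputs/{filename}" through send_file. Flask's default <filename> converter will not match slashes, but send_from_directory is the standard Flask API for serving user-named files because it rejects any path that resolves outside the base directory. A drop-in sketch of the same route with that hardening:

from flask import send_from_directory

@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    try:
        # send_from_directory refuses to serve anything outside outputs/
        return send_from_directory("outputs", filename, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 404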
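
Usage sketch: the new handler reads request.json only, so the form-encoded POSTs that the removed index page sent are no longer understood; clients must send JSON. Assuming the app is reachable at http://localhost:7860 (the port hardcoded in app.run; a deployed Space would use its own URL):

import requests

BASE = "http://localhost:7860"

# Ask the service to generate a mesh for a text prompt.
resp = requests.post(f"{BASE}/generate", json={"prompt": "a toy robot"})
resp.raise_for_status()
info = resp.json()  # e.g. {"success": true, "glb_url": "/download/<uuid>.glb", "obj_url": "/download/<uuid>.obj"}

# Fetch the OBJ file through the /download route.
obj = requests.get(BASE + info["obj_url"])
obj.raise_for_status()
with open("model.obj", "wb") as f:
    f.write(obj.content)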