Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
import torch
|
|
|
|
|
5 |
from point_e.diffusion.sampler import PointCloudSampler
|
6 |
from point_e.models.configs import MODEL_CONFIGS, model_from_config
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
CORS(app)
|
@@ -18,26 +21,32 @@ base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device)
|
|
18 |
base_model.load_state_dict(torch.load("base40M-textvec.pt", map_location=device))
|
19 |
base_model.eval()
|
20 |
|
21 |
-
print("
|
22 |
sampler = PointCloudSampler(
|
23 |
device=device,
|
24 |
models=[base_model],
|
25 |
diffusion=None,
|
26 |
num_points=1024,
|
27 |
aux_channels=[],
|
28 |
-
guidance_scale=
|
29 |
)
|
30 |
|
|
|
|
|
|
|
|
|
|
|
31 |
@app.route("/generate", methods=["POST"])
|
32 |
def generate():
|
33 |
-
prompt = request.json.get("prompt", "a
|
34 |
-
samples = sampler.sample_batch(
|
35 |
pc = sampler.output_to_point_clouds(samples)[0]
|
36 |
|
37 |
tmp_dir = tempfile.mkdtemp()
|
38 |
output_path = os.path.join(tmp_dir, "model.ply")
|
39 |
-
|
40 |
-
|
|
|
41 |
|
42 |
if __name__ == "__main__":
|
43 |
app.run(host="0.0.0.0", port=7860)
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
import numpy as np
|
4 |
import torch
|
5 |
+
from flask import Flask, request, jsonify, send_file
|
6 |
+
from flask_cors import CORS
|
7 |
from point_e.diffusion.sampler import PointCloudSampler
|
8 |
from point_e.models.configs import MODEL_CONFIGS, model_from_config
|
9 |
+
import open3d as o3d
|
10 |
+
|
11 |
+
# Fix matplotlib temp issue
|
12 |
+
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
|
13 |
|
14 |
app = Flask(__name__)
|
15 |
CORS(app)
|
|
|
21 |
base_model.load_state_dict(torch.load("base40M-textvec.pt", map_location=device))
|
22 |
base_model.eval()
|
23 |
|
24 |
+
print("Creating sampler...")
|
25 |
sampler = PointCloudSampler(
|
26 |
device=device,
|
27 |
models=[base_model],
|
28 |
diffusion=None,
|
29 |
num_points=1024,
|
30 |
aux_channels=[],
|
31 |
+
guidance_scale=3.0,
|
32 |
)
|
33 |
|
34 |
+
def save_point_cloud_open3d(pc_xyz: np.ndarray, out_path: str):
|
35 |
+
pcd = o3d.geometry.PointCloud()
|
36 |
+
pcd.points = o3d.utility.Vector3dVector(pc_xyz)
|
37 |
+
o3d.io.write_point_cloud(out_path, pcd)
|
38 |
+
|
39 |
@app.route("/generate", methods=["POST"])
|
40 |
def generate():
|
41 |
+
prompt = request.json.get("prompt", "a toy robot")
|
42 |
+
samples = sampler.sample_batch(batch_size=1, model_kwargs={"texts": [prompt]})
|
43 |
pc = sampler.output_to_point_clouds(samples)[0]
|
44 |
|
45 |
tmp_dir = tempfile.mkdtemp()
|
46 |
output_path = os.path.join(tmp_dir, "model.ply")
|
47 |
+
save_point_cloud_open3d(pc['coords'], output_path)
|
48 |
+
|
49 |
+
return send_file(output_path, as_attachment=True)
|
50 |
|
51 |
if __name__ == "__main__":
|
52 |
app.run(host="0.0.0.0", port=7860)
|