mac9087 commited on
Commit
2e7de92
·
verified ·
1 Parent(s): 147bbe2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -1,12 +1,15 @@
1
- # app.py (simplified)
2
- from flask import Flask, request, jsonify
3
- from flask_cors import CORS
4
  import torch
 
 
5
  from point_e.diffusion.sampler import PointCloudSampler
6
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
7
- from point_e.util.plotting import save_point_cloud
8
- import tempfile
9
- import os
 
10
 
11
  app = Flask(__name__)
12
  CORS(app)
@@ -18,26 +21,32 @@ base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device)
18
  base_model.load_state_dict(torch.load("base40M-textvec.pt", map_location=device))
19
  base_model.eval()
20
 
21
- print("Loading sampler...")
22
  sampler = PointCloudSampler(
23
  device=device,
24
  models=[base_model],
25
  diffusion=None,
26
  num_points=1024,
27
  aux_channels=[],
28
- guidance_scale=1.0,
29
  )
30
 
 
 
 
 
 
31
  @app.route("/generate", methods=["POST"])
32
  def generate():
33
- prompt = request.json.get("prompt", "a red apple")
34
- samples = sampler.sample_batch(prompt=[prompt])
35
  pc = sampler.output_to_point_clouds(samples)[0]
36
 
37
  tmp_dir = tempfile.mkdtemp()
38
  output_path = os.path.join(tmp_dir, "model.ply")
39
- save_point_cloud(pc, output_path)
40
- return jsonify({"model_path": output_path})
 
41
 
42
  if __name__ == "__main__":
43
  app.run(host="0.0.0.0", port=7860)
 
1
+ import os
2
+ import tempfile
3
+ import numpy as np
4
  import torch
5
+ from flask import Flask, request, jsonify, send_file
6
+ from flask_cors import CORS
7
  from point_e.diffusion.sampler import PointCloudSampler
8
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
9
+ import open3d as o3d
10
+
11
+ # Fix matplotlib temp issue
12
+ os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
13
 
14
  app = Flask(__name__)
15
  CORS(app)
 
21
  base_model.load_state_dict(torch.load("base40M-textvec.pt", map_location=device))
22
  base_model.eval()
23
 
24
+ print("Creating sampler...")
25
  sampler = PointCloudSampler(
26
  device=device,
27
  models=[base_model],
28
  diffusion=None,
29
  num_points=1024,
30
  aux_channels=[],
31
+ guidance_scale=3.0,
32
  )
33
 
34
+ def save_point_cloud_open3d(pc_xyz: np.ndarray, out_path: str):
35
+ pcd = o3d.geometry.PointCloud()
36
+ pcd.points = o3d.utility.Vector3dVector(pc_xyz)
37
+ o3d.io.write_point_cloud(out_path, pcd)
38
+
39
  @app.route("/generate", methods=["POST"])
40
  def generate():
41
+ prompt = request.json.get("prompt", "a toy robot")
42
+ samples = sampler.sample_batch(batch_size=1, model_kwargs={"texts": [prompt]})
43
  pc = sampler.output_to_point_clouds(samples)[0]
44
 
45
  tmp_dir = tempfile.mkdtemp()
46
  output_path = os.path.join(tmp_dir, "model.ply")
47
+ save_point_cloud_open3d(pc['coords'], output_path)
48
+
49
+ return send_file(output_path, as_attachment=True)
50
 
51
  if __name__ == "__main__":
52
  app.run(host="0.0.0.0", port=7860)