mac9087 commited on
Commit
147bbe2
·
verified ·
1 Parent(s): 1e13ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -50
app.py CHANGED
@@ -1,11 +1,9 @@
1
- # app.py
2
  from flask import Flask, request, jsonify
3
  from flask_cors import CORS
4
  import torch
5
- from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
6
- from point_e.models.download import load_checkpoint
7
- from point_e.models.text_to_image import create_model_and_diffusion as create_t2i
8
- from point_e.models.image_to_3d import create_model_and_diffusion as create_i2p
9
  from point_e.util.plotting import save_point_cloud
10
  import tempfile
11
  import os
@@ -13,54 +11,33 @@ import os
13
  app = Flask(__name__)
14
  CORS(app)
15
 
16
- # Load models once on startup
17
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
- print(f"Using device: {device}")
19
 
20
- print("Loading text-to-image model...")
21
- t2i_model, t2i_diffusion = create_t2i(DIFFUSION_CONFIGS['base40'], device)
22
- t2i_model.load_state_dict(load_checkpoint('base40', device))
23
- t2i_model.eval()
24
-
25
- print("Loading image-to-3d model...")
26
- i2p_model, i2p_diffusion = create_i2p(DIFFUSION_CONFIGS['base40'], device)
27
- i2p_model.load_state_dict(load_checkpoint('base40-img2pc', device))
28
- i2p_model.eval()
29
-
30
- @app.route('/generate', methods=['POST'])
 
 
 
 
 
31
  def generate():
32
- data = request.json
33
- prompt = data.get('prompt', 'a red car')
34
-
35
- # Generate image from text
36
- print("Generating image from text...")
37
- samples = t2i_diffusion.sample_loop(
38
- model=t2i_model,
39
- shape=(1, 3, 64, 64),
40
- device=device,
41
- clip_denoised=True,
42
- model_kwargs={"text": [prompt]},
43
- progress=True
44
- )
45
- image = t2i_model.output_to_image(samples[0])
46
 
47
- # Generate point cloud from image
48
- print("Generating point cloud...")
49
- pc = i2p_diffusion.sample_loop(
50
- model=i2p_model,
51
- shape=(1, 3, 64, 64),
52
- device=device,
53
- clip_denoised=True,
54
- model_kwargs={"images": [image]},
55
- progress=True
56
- )
57
-
58
- # Save to temporary file
59
  tmp_dir = tempfile.mkdtemp()
60
- output_path = os.path.join(tmp_dir, 'output.ply')
61
- save_point_cloud(pc[0], output_path)
62
-
63
- return jsonify({'model_path': output_path})
64
 
65
- if __name__ == '__main__':
66
- app.run(host='0.0.0.0', port=7860)
 
1
+ # app.py (simplified)
2
  from flask import Flask, request, jsonify
3
  from flask_cors import CORS
4
  import torch
5
+ from point_e.diffusion.sampler import PointCloudSampler
6
+ from point_e.models.configs import MODEL_CONFIGS, model_from_config
 
 
7
  from point_e.util.plotting import save_point_cloud
8
  import tempfile
9
  import os
 
11
  app = Flask(__name__)
12
  CORS(app)
13
 
 
14
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
15
 
16
+ print("Loading base model...")
17
+ base_model = model_from_config(MODEL_CONFIGS["base40M-textvec"], device)
18
+ base_model.load_state_dict(torch.load("base40M-textvec.pt", map_location=device))
19
+ base_model.eval()
20
+
21
+ print("Loading sampler...")
22
+ sampler = PointCloudSampler(
23
+ device=device,
24
+ models=[base_model],
25
+ diffusion=None,
26
+ num_points=1024,
27
+ aux_channels=[],
28
+ guidance_scale=1.0,
29
+ )
30
+
31
+ @app.route("/generate", methods=["POST"])
32
  def generate():
33
+ prompt = request.json.get("prompt", "a red apple")
34
+ samples = sampler.sample_batch(prompt=[prompt])
35
+ pc = sampler.output_to_point_clouds(samples)[0]
 
 
 
 
 
 
 
 
 
 
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  tmp_dir = tempfile.mkdtemp()
38
+ output_path = os.path.join(tmp_dir, "model.ply")
39
+ save_point_cloud(pc, output_path)
40
+ return jsonify({"model_path": output_path})
 
41
 
42
+ if __name__ == "__main__":
43
+ app.run(host="0.0.0.0", port=7860)