mac9087 commited on
Commit
cb944d7
·
verified ·
1 Parent(s): 868f530

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from flask import Flask, request, jsonify
3
+ from flask_cors import CORS
4
+ import torch
5
+ from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
6
+ from point_e.models.download import load_checkpoint
7
+ from point_e.models.text_to_image import create_model_and_diffusion as create_t2i
8
+ from point_e.models.image_to_3d import create_model_and_diffusion as create_i2p
9
+ from point_e.util.plotting import save_point_cloud
10
+ import tempfile
11
+ import os
12
+
13
+ app = Flask(__name__)
14
+ CORS(app)
15
+
16
+ # Load models once on startup
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ print(f"Using device: {device}")
19
+
20
+ print("Loading text-to-image model...")
21
+ t2i_model, t2i_diffusion = create_t2i(DIFFUSION_CONFIGS['base40'], device)
22
+ t2i_model.load_state_dict(load_checkpoint('base40', device))
23
+ t2i_model.eval()
24
+
25
+ print("Loading image-to-3d model...")
26
+ i2p_model, i2p_diffusion = create_i2p(DIFFUSION_CONFIGS['base40'], device)
27
+ i2p_model.load_state_dict(load_checkpoint('base40-img2pc', device))
28
+ i2p_model.eval()
29
+
30
+ @app.route('/generate', methods=['POST'])
31
+ def generate():
32
+ data = request.json
33
+ prompt = data.get('prompt', 'a red car')
34
+
35
+ # Generate image from text
36
+ print("Generating image from text...")
37
+ samples = t2i_diffusion.sample_loop(
38
+ model=t2i_model,
39
+ shape=(1, 3, 64, 64),
40
+ device=device,
41
+ clip_denoised=True,
42
+ model_kwargs={"text": [prompt]},
43
+ progress=True
44
+ )
45
+ image = t2i_model.output_to_image(samples[0])
46
+
47
+ # Generate point cloud from image
48
+ print("Generating point cloud...")
49
+ pc = i2p_diffusion.sample_loop(
50
+ model=i2p_model,
51
+ shape=(1, 3, 64, 64),
52
+ device=device,
53
+ clip_denoised=True,
54
+ model_kwargs={"images": [image]},
55
+ progress=True
56
+ )
57
+
58
+ # Save to temporary file
59
+ tmp_dir = tempfile.mkdtemp()
60
+ output_path = os.path.join(tmp_dir, 'output.ply')
61
+ save_point_cloud(pc[0], output_path)
62
+
63
+ return jsonify({'model_path': output_path})
64
+
65
+ if __name__ == '__main__':
66
+ app.run(host='0.0.0.0', port=7860)