mac9087 commited on
Commit
ca7e93f
·
verified ·
1 Parent(s): 97be4f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -119
app.py CHANGED
@@ -1,136 +1,104 @@
1
- import os
2
- import tempfile
3
- import numpy as np
4
- import torch
5
  from flask import Flask, request, jsonify, send_file
6
  from flask_cors import CORS
7
- from point_e.diffusion.sampler import PointCloudSampler
8
- from point_e.models.configs import MODEL_CONFIGS, model_from_config
9
- import open3d as o3d
10
- import logging
11
-
12
- # Set up logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
-
16
- # Fix matplotlib and clip cache path issues
17
- os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
18
- os.environ['HOME'] = '/tmp'
19
-
20
- # Set cache directories explicitly
21
- cache_dir = os.environ.get('POINT_E_CACHE_DIR', '/tmp/point_e_model_cache')
22
- clip_cache_dir = os.environ.get('CLIP_MODEL_DIR', '/tmp/clip_models')
23
 
24
- # Ensure cache directories exist and are writable
25
- for directory in [cache_dir, clip_cache_dir, '/tmp/point_e_models']:
26
- if not os.path.exists(directory):
27
- os.makedirs(directory, exist_ok=True)
28
- # Make sure permissions are correct
29
- os.chmod(directory, 0o777)
30
 
31
  app = Flask(__name__)
32
  CORS(app)
33
 
34
- # Determine if CUDA is available (but prepare for CPU operation)
35
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
36
- logger.info(f"Using device: {device}")
37
-
38
- # Set up model directory for downloads
39
- model_path = "/tmp/point_e_models/base40M-textvec.pt"
40
- if not os.path.exists(model_path):
41
- logger.info("Model weights not found. Downloading...")
42
- # Use torch.hub to download the model from original repo
43
- # This is a placeholder. In reality, you'd need to download the weights from the official source
44
- # or include them in your repository
45
- try:
46
- # For demonstration purposes, we'll assume the model is included with the repo
47
- # In a real setup, you'd need to download it or include it
48
- pass
49
- except Exception as e:
50
- logger.error(f"Failed to download model: {e}")
51
- raise
52
-
53
- # Load the model
54
- logger.info("Loading base model...")
55
- try:
56
- base_model = model_from_config(
57
- MODEL_CONFIGS["base40M-textvec"],
58
- device=device,
59
- cache_dir=cache_dir # Pass cache directory explicitly
60
- )
61
- base_model.load_state_dict(torch.load(model_path, map_location=device))
62
- base_model.eval()
63
- logger.info("Base model loaded successfully")
64
- except Exception as e:
65
- logger.error(f"Failed to load model: {e}")
66
- raise
67
-
68
- # Create sampler with memory-efficient settings for 2vCPU/18GB RAM
69
- logger.info("Creating sampler...")
70
- sampler = PointCloudSampler(
71
- device=device,
72
- models=[base_model],
73
- diffusion=None,
74
- num_points=1024, # Lower point count for memory efficiency
75
- aux_channels=[],
76
- guidance_scale=3.0,
77
- )
78
 
79
- def save_point_cloud_open3d(pc_xyz: np.ndarray, out_path: str):
80
- """Save a point cloud to a PLY file using Open3D."""
81
- pcd = o3d.geometry.PointCloud()
82
- pcd.points = o3d.utility.Vector3dVector(pc_xyz)
83
- o3d.io.write_point_cloud(out_path, pcd)
84
- logger.info(f"Point cloud saved to {out_path}")
85
 
86
- @app.route("/", methods=["GET"])
87
- def index():
88
- """Provide a simple interface for the API."""
89
- return """
90
- <html>
91
- <head><title>Point-E Text-to-3D API</title></head>
92
- <body>
93
- <h1>Point-E Text-to-3D Generator</h1>
94
- <form action="/generate" method="post">
95
- <label for="prompt">Enter a description:</label>
96
- <input type="text" id="prompt" name="prompt" value="a toy robot">
97
- <input type="submit" value="Generate 3D Model">
98
- </form>
99
- <p>Or use the API endpoint at POST /generate with JSON {"prompt": "your description"}</p>
100
- </body>
101
- </html>
102
- """
103
 
104
- @app.route("/generate", methods=["POST"])
105
- def generate():
106
- """Generate a 3D point cloud from a text description."""
107
- if request.is_json:
108
- prompt = request.json.get("prompt", "a toy robot")
109
- else:
110
- prompt = request.form.get("prompt", "a toy robot")
111
-
112
- logger.info(f"Generating point cloud for prompt: '{prompt}'")
113
-
114
  try:
115
- # Generate with torch.no_grad() to save memory
116
- with torch.no_grad():
117
- samples = sampler.sample_batch(batch_size=1, model_kwargs={"texts": [prompt]})
118
- pc = sampler.output_to_point_clouds(samples)[0]
 
 
 
 
 
 
 
119
 
120
- # Save the point cloud
121
- tmp_dir = tempfile.mkdtemp()
122
- output_path = os.path.join(tmp_dir, f"{prompt.replace(' ', '_')}.ply")
123
- save_point_cloud_open3d(pc['coords'], output_path)
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
- return send_file(output_path, as_attachment=True, download_name=f"{prompt.replace(' ', '_')}.ply")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  except Exception as e:
127
- logger.error(f"Error generating point cloud: {e}")
128
  return jsonify({"error": str(e)}), 500
129
 
130
- @app.route("/health", methods=["GET"])
131
- def health_check():
132
- """Simple health check endpoint."""
133
- return jsonify({"status": "healthy"}), 200
 
 
134
 
135
- if __name__ == "__main__":
136
- app.run(host="0.0.0.0", port=7860)
 
 
 
 
 
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
+ import torch
4
+ import numpy as np
5
+ import trimesh
6
+ import os
7
+ from io import BytesIO
8
+ import base64
9
+ from PIL import Image
10
+ import uuid
11
+ import time
 
 
 
 
 
 
 
12
 
13
+ # Import Shap-E
14
+ from shap_e.diffusion.sample import sample_latents
15
+ from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
16
+ from shap_e.models.download import load_model, load_config
17
+ from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
 
18
 
19
  app = Flask(__name__)
20
  CORS(app)
21
 
22
+ # Create output directory if it doesn't exist
23
+ os.makedirs("outputs", exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # Load models only once at startup
26
+ print("Loading models...")
27
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28
+ print(f"Using device: {device}")
 
 
29
 
30
+ xm = load_model('transmitter', device=device)
31
+ model = load_model('text300M', device=device)
32
+ diffusion = diffusion_from_config(load_config('diffusion'))
33
+ print("Models loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ @app.route('/generate', methods=['POST'])
36
+ def generate_3d():
 
 
 
 
 
 
 
 
37
  try:
38
+ # Get the prompt from the request
39
+ data = request.json
40
+ if not data or 'prompt' not in data:
41
+ return jsonify({"error": "No prompt provided"}), 400
42
+
43
+ prompt = data['prompt']
44
+ print(f"Received prompt: {prompt}")
45
+
46
+ # Set parameters
47
+ batch_size = 1
48
+ guidance_scale = 15.0
49
 
50
+ # Generate latents with the text-to-3D model
51
+ latents = sample_latents(
52
+ batch_size=batch_size,
53
+ model=model,
54
+ diffusion=diffusion,
55
+ guidance_scale=guidance_scale,
56
+ model_kwargs=dict(texts=[prompt] * batch_size),
57
+ progress=True,
58
+ clip_denoised=True,
59
+ use_fp16=True,
60
+ use_karras=True,
61
+ karras_steps=64,
62
+ sigma_min=1e-3,
63
+ sigma_max=160,
64
+ s_churn=0,
65
+ )
66
 
67
+ # Generate a unique filename
68
+ filename = f"outputs/{uuid.uuid4()}"
69
+
70
+ # Convert latent to mesh
71
+ t0 = time.time()
72
+ mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
73
+ print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
74
+
75
+ # Save as GLB
76
+ glb_path = f"{filename}.glb"
77
+ mesh.write_glb(glb_path)
78
+
79
+ # Save as OBJ
80
+ obj_path = f"{filename}.obj"
81
+ with open(obj_path, 'w') as f:
82
+ mesh.write_obj(f)
83
+
84
+ # Return paths to the generated files
85
+ return jsonify({
86
+ "success": True,
87
+ "message": "3D model generated successfully",
88
+ "glb_url": f"/download/{os.path.basename(glb_path)}",
89
+ "obj_url": f"/download/{os.path.basename(obj_path)}"
90
+ })
91
+
92
  except Exception as e:
93
+ print(f"Error: {str(e)}")
94
  return jsonify({"error": str(e)}), 500
95
 
96
+ @app.route('/download/<filename>', methods=['GET'])
97
+ def download_file(filename):
98
+ try:
99
+ return send_file(f"outputs/{filename}", as_attachment=True)
100
+ except Exception as e:
101
+ return jsonify({"error": str(e)}), 404
102
 
103
+ if __name__ == '__main__':
104
+ app.run(host='0.0.0.0', port=7860, debug=True)