from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import torch
import trimesh
import os
import threading
from io import BytesIO
import uuid
import time
import sys

# Import Shap-E
print("Importing Shap-E modules...")
try:
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
    print("Shap-E modules imported successfully!")
except Exception as e:
    print(f"Error importing Shap-E modules: {e}")
    sys.exit(1)

app = Flask(__name__)
CORS(app)
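
# CORS is enabled for all origins so a browser front-end hosted elsewhere can
# call /generate. If the front-end origin is known, it can be restricted,
# e.g. CORS(app, origins=["https://example.com"])  (hypothetical origin).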

# Create output directory if it doesn't exist
os.makedirs("outputs", exist_ok=True)

# Use lazy loading for models
print("Setting up device...")
device = torch.device('cpu')  # Force CPU for Hugging Face Spaces
print(f"Using device: {device}")

# Global variables for models (will be loaded on first request)
xm = None
model = None
diffusion = None

_model_lock = threading.Lock()

def load_models_if_needed():
    global xm, model, diffusion
    # Serialize first-time loading: Flask's dev server is threaded by default,
    # so concurrent first requests could otherwise load the models twice.
    with _model_lock:
        if xm is None or model is None or diffusion is None:
            print("Loading models for the first time...")
            try:
                xm = load_model('transmitter', device=device)
                model = load_model('text300M', device=device)
                diffusion = diffusion_from_config(load_config('diffusion'))
                print("Models loaded successfully!")
            except Exception as e:
                print(f"Error loading models: {e}")
                raise

@app.route('/generate', methods=['POST'])
def generate_3d():
    try:
        # Load models if not already loaded
        load_models_if_needed()
        
        # Get the prompt from the request body (silent=True returns None on
        # missing or malformed JSON instead of raising)
        data = request.get_json(silent=True)
        if not data or 'prompt' not in data:
            return jsonify({"error": "No prompt provided"}), 400
        
        prompt = data['prompt']
        print(f"Received prompt: {prompt}")
        
        # Sampling parameters, kept modest so CPU-only inference stays tractable
        batch_size = 1
        guidance_scale = 15.0
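        # Tradeoff notes (assumptions, not tuned values): a higher
        # guidance_scale follows the prompt more closely at some cost to
        # diversity, and the upstream Shap-E example uses karras_steps=64;
        # it is halved below to keep CPU runtime tolerable.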
        
        # Generate latents with the text-to-3D model
        print("Starting latent generation...")
        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=False,  # fp16 sampling needs CUDA; stay fp32 on CPU
            use_karras=True,
            karras_steps=32,  # Reduced steps for CPU
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        print("Latent generation complete!")
        
        # Generate a unique filename
        filename = f"outputs/{uuid.uuid4()}"
        
        # Convert latent to mesh
        print("Decoding mesh...")
        t0 = time.time()
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")
        
        # Save as GLB. Shap-E's TriMesh exposes write_ply/write_obj but no GLB
        # writer, so round-trip the PLY bytes through trimesh for the export.
        print("Saving as GLB...")
        glb_path = f"{filename}.glb"
        ply_buffer = BytesIO()
        mesh.write_ply(ply_buffer)
        ply_buffer.seek(0)
        trimesh.load(ply_buffer, file_type='ply').export(glb_path)
        
        # Save as OBJ
        print("Saving as OBJ...")
        obj_path = f"{filename}.obj"
        with open(obj_path, 'w') as f:
            mesh.write_obj(f)
        
        print("Files saved successfully!")
            
        # Return paths to the generated files
        return jsonify({
            "success": True,
            "message": "3D model generated successfully",
            "glb_url": f"/download/{os.path.basename(glb_path)}",
            "obj_url": f"/download/{os.path.basename(obj_path)}"
        })
    
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    try:
        # send_from_directory rejects path-traversal filenames (e.g. "../app.py")
        return send_from_directory("outputs", filename, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 404

@app.route('/health', methods=['GET'])
def health_check():
    """Simple health check endpoint to verify the app is running"""
    return jsonify({"status": "ok", "message": "Service is running"})

@app.route('/', methods=['GET'])
def home():
    """Landing page with usage instructions"""
    return """
    <html>
        <head><title>Text to 3D API</title></head>
        <body>
            <h1>Text to 3D API</h1>
            <p>This is a simple API that converts text prompts to 3D models.</p>
            <h2>How to use:</h2>
            <pre>
            POST /generate
            Content-Type: application/json
            
            {
                "prompt": "A futuristic building"
            }
            </pre>
            <p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
        </body>
    </html>
    """

if __name__ == '__main__':
    # debug disabled: the Werkzeug debugger must not be exposed on 0.0.0.0
    app.run(host='0.0.0.0', port=7860, debug=False)