Spaces:
Sleeping
Sleeping
File size: 5,177 Bytes
2e7de92 ca7e93f d6ba12d ab52342 ca7e93f d6ba12d ab52342 cb944d7 ca7e93f 147bbe2 d6ba12d ca7e93f cef8b58 d6ba12d 2e7de92 ca7e93f cef8b58 d6ba12d ca7e93f d6ba12d ca7e93f cef8b58 ca7e93f d6ba12d ca7e93f d6ba12d ca7e93f d6ba12d ca7e93f d6ba12d cef8b58 ca7e93f d6ba12d ca7e93f d6ba12d ca7e93f d6ba12d ca7e93f d6ba12d ca7e93f cef8b58 d6ba12d cef8b58 2e7de92 ca7e93f cb944d7 d6ba12d ca7e93f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import base64
import os
import sys
import time
import uuid
from io import BytesIO

import numpy as np
import torch
import trimesh
from flask import Flask, jsonify, request, send_file, send_from_directory
from flask_cors import CORS
from PIL import Image
# Import Shap-E up front and stop the process if that fails: none of the
# endpoints that generate models can work without these functions.
print("Importing Shap-E modules...")
try:
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
except Exception as import_error:
    print(f"Error importing Shap-E modules: {import_error}")
    sys.exit(1)
else:
    print("Shap-E modules imported successfully!")
app = Flask(__name__)
CORS(app)  # enable CORS on the app using flask_cors defaults

# Generated meshes are written here (by /generate) and served back by /download.
os.makedirs("outputs", exist_ok=True)

# Only the device is chosen at import time; the models themselves are loaded
# lazily by load_models_if_needed().
print("Setting up device...")
device = torch.device("cpu")  # force CPU for Hugging Face Spaces
print(f"Using device: {device}")

# Shap-E models, populated on the first /generate request.
xm = model = diffusion = None
def load_models_if_needed():
    """Load the Shap-E models into the module globals on first use.

    Sets ``xm`` (the 'transmitter' model, later used to decode latents into a
    mesh), ``model`` ('text300M') and ``diffusion`` (built from the
    'diffusion' config). If any of the three is still ``None``, all three are
    loaded again. Does nothing once all three are set.

    Raises:
        Whatever ``load_model`` / ``load_config`` / ``diffusion_from_config``
        raise; the error is printed before being re-raised.
    """
    global xm, model, diffusion
    if all(loaded is not None for loaded in (xm, model, diffusion)):
        return
    print("Loading models for the first time...")
    try:
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))
        print("Models loaded successfully!")
    except Exception as load_error:
        print(f"Error loading models: {load_error}")
        raise
def _read_prompt():
    """Return the non-empty string ``prompt`` from the JSON request body, or None."""
    # get_json(silent=True) returns None instead of raising when the body is
    # missing, is not valid JSON, or is sent without a JSON content type.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return None
    prompt = data.get('prompt')
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    return prompt


def _write_glb(mesh, glb_path):
    """Export a Shap-E ``TriMesh`` to ``glb_path`` as binary glTF."""
    # Shap-E's TriMesh can write PLY and OBJ but has no GLB writer, so
    # round-trip through an in-memory PLY and let trimesh produce the GLB.
    # NOTE(review): Shap-E meshes appear to be Z-up while glTF is Y-up; some
    # viewers may show the model lying down. Confirm before adding a rotation.
    ply_buffer = BytesIO()
    mesh.write_ply(ply_buffer)
    ply_buffer.seek(0)
    trimesh.load(ply_buffer, file_type='ply').export(glb_path, file_type='glb')


@app.route('/generate', methods=['POST'])
def generate_3d():
    """Generate a 3D mesh from a text prompt and save it as GLB and OBJ.

    Expects a JSON body ``{"prompt": "<text>"}``. Both files are written under
    ``outputs/`` with a shared random basename.

    Returns:
        JSON with ``success``, ``message``, ``glb_url`` and ``obj_url`` (paths
        served by ``/download/<filename>``) on success; ``{"error": ...}``
        with status 400 when no usable prompt is supplied, or 500 when model
        loading, sampling, decoding or export fails.
    """
    # Validate the request before touching the slow, lazily loaded models so
    # that malformed requests are rejected cheaply.
    prompt = _read_prompt()
    if prompt is None:
        return jsonify({"error": "No prompt provided"}), 400

    try:
        load_models_if_needed()
        print(f"Received prompt: {prompt}")

        # Parameters chosen for CPU performance (reduced sampling steps).
        batch_size = 1
        guidance_scale = 15.0

        print("Starting latent generation...")
        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=False,  # CPU doesn't support fp16
            use_karras=True,
            karras_steps=32,  # Reduced steps for CPU
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        print("Latent generation complete!")

        # Unique basename shared by the .glb and .obj outputs.
        filename = f"outputs/{uuid.uuid4()}"

        print("Decoding mesh...")
        t0 = time.time()
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

        print("Saving as GLB...")
        glb_path = f"{filename}.glb"
        _write_glb(mesh, glb_path)

        print("Saving as OBJ...")
        obj_path = f"{filename}.obj"
        with open(obj_path, 'w') as f:
            mesh.write_obj(f)
        print("Files saved successfully!")

        return jsonify({
            "success": True,
            "message": "3D model generated successfully",
            "glb_url": f"/download/{os.path.basename(glb_path)}",
            "obj_url": f"/download/{os.path.basename(obj_path)}",
        })
    except Exception as e:
        # Top-level request boundary: log the full traceback server-side and
        # report the error to the client instead of crashing the worker.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    """Send a generated file from ``outputs/`` as an attachment.

    ``send_from_directory`` safely joins ``filename`` onto ``outputs``, so
    names such as ``..`` or backslash-separated paths cannot reach files
    outside that directory. Missing files raise NotFound.

    Returns:
        The file response, or ``{"error": ...}`` with status 404 if the file
        cannot be served.
    """
    try:
        return send_from_directory("outputs", filename, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 404
@app.route('/health', methods=['GET'])
def health_check():
    """Report that the web service is responding (does not check model state)."""
    payload = {"status": "ok", "message": "Service is running"}
    return jsonify(payload)
@app.route('/', methods=['GET'])
def home():
    """Serve a static HTML page explaining how to call POST /generate."""
    landing_html = """
<html>
<head><title>Text to 3D API</title></head>
<body>
<h1>Text to 3D API</h1>
<p>This is a simple API that converts text prompts to 3D models.</p>
<h2>How to use:</h2>
<pre>
POST /generate
Content-Type: application/json
{
"prompt": "A futuristic building"
}
</pre>
<p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
</body>
</html>
"""
    return landing_html
if __name__ == '__main__':
    # Listen on all interfaces on port 7860. The Werkzeug interactive debugger
    # lets anyone who can reach the server execute code, so debug mode is
    # opt-in (FLASK_DEBUG=1) rather than always on for a publicly bound host.
    app.run(
        host='0.0.0.0',
        port=7860,
        debug=os.environ.get("FLASK_DEBUG", "0") == "1",
    )