magical-box / app.py
mac9087's picture
Update app.py
d6ba12d verified
raw
history blame
5.18 kB
import base64
import os
import sys
import threading
import time
import uuid
from io import BytesIO

import numpy as np
import torch
import trimesh
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from PIL import Image
# Import Shap-E. The service cannot work without these modules, so any
# failure while importing them is reported and the process exits with
# status 1 at startup.
print("Importing Shap-E modules...")
try:
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
except Exception as import_error:
    print(f"Error importing Shap-E modules: {import_error}")
    sys.exit(1)
else:
    print("Shap-E modules imported successfully!")
# Flask application; CORS is enabled with flask_cors' default settings.
app = Flask(__name__)
CORS(app)

# Generated meshes are written under this directory (relative to the current
# working directory) and read back by the download route.
os.makedirs("outputs", exist_ok=True)

print("Setting up device...")
# Always run on the CPU (this deployment targets Hugging Face Spaces).
device = torch.device('cpu')
print(f"Using device: {device}")

# The Shap-E models are loaded lazily by load_models_if_needed(); they stay
# None until the first generation request arrives.
xm = model = diffusion = None
# Serialises first-time model loading. The Flask development server started
# in __main__ is threaded by default, so two early requests could otherwise
# both run the expensive loads.
_models_lock = threading.Lock()


def load_models_if_needed():
    """Load the Shap-E models into the module globals on first use.

    Sets ``xm`` (the 'transmitter' model, later passed to
    ``decode_latent_mesh``), ``model`` (the 'text300M' model used by
    ``sample_latents``) and ``diffusion`` (built from the 'diffusion'
    config). Once all three are set, later calls return immediately.

    The three objects are loaded into locals and published to the globals
    together only after every load succeeded, so a failure never leaves a
    partially initialised set behind.

    Raises:
        Exception: whatever ``load_model`` / ``load_config`` /
            ``diffusion_from_config`` raise; it is logged and re-raised.
    """
    global xm, model, diffusion
    if xm is not None and model is not None and diffusion is not None:
        return
    with _models_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting for it.
        if xm is not None and model is not None and diffusion is not None:
            return
        print("Loading models for the first time...")
        try:
            new_xm = load_model('transmitter', device=device)
            new_model = load_model('text300M', device=device)
            new_diffusion = diffusion_from_config(load_config('diffusion'))
        except Exception as e:
            print(f"Error loading models: {e}")
            raise
        xm, model, diffusion = new_xm, new_model, new_diffusion
        print("Models loaded successfully!")
def _extract_prompt(data):
    """Return the prompt from a parsed JSON body, or None if it is unusable.

    A prompt is usable only when *data* is a dict whose ``'prompt'`` value is
    a string containing at least one non-whitespace character.
    """
    if not isinstance(data, dict):
        return None
    prompt = data.get('prompt')
    if not isinstance(prompt, str) or not prompt.strip():
        return None
    return prompt


def _export_mesh(tri_mesh, base_path):
    """Write a Shap-E TriMesh to ``<base_path>.glb`` and ``<base_path>.obj``.

    Returns the ``(glb_path, obj_path)`` tuple.
    """
    glb_path = f"{base_path}.glb"
    obj_path = f"{base_path}.obj"

    print("Saving as GLB...")
    # Shap-E's TriMesh only writes PLY/OBJ itself, so the GLB is produced via
    # trimesh from the raw vertex/face arrays.
    vertex_colors = None
    channels = getattr(tri_mesh, "vertex_channels", None) or {}
    if all(c in channels for c in "RGB"):
        # NOTE(review): assumes Shap-E colour channels are floats in [0, 1];
        # they are clipped and scaled to 8-bit RGB for the GLB.
        rgb = np.stack([channels[c] for c in "RGB"], axis=1)
        vertex_colors = (np.clip(rgb, 0.0, 1.0) * 255).round().astype(np.uint8)
    glb_mesh = trimesh.Trimesh(
        vertices=tri_mesh.verts,
        faces=tri_mesh.faces,
        vertex_colors=vertex_colors,
        process=False,  # keep the geometry exactly as decoded, like the OBJ
    )
    glb_mesh.export(glb_path)

    print("Saving as OBJ...")
    # Shap-E's write_obj emits text, hence text mode.
    with open(obj_path, 'w') as f:
        tri_mesh.write_obj(f)

    return glb_path, obj_path


@app.route('/generate', methods=['POST'])
def generate_3d():
    """Generate a 3D model from a text prompt.

    Expects a JSON body ``{"prompt": "<text>"}``.

    Returns:
        200 with JSON containing ``glb_url`` and ``obj_url`` download links
        on success; 400 with ``{"error": ...}`` when the body is not JSON or
        lacks a non-empty string prompt; 500 with ``{"error": ...}`` when
        model loading or generation fails.
    """
    # Validate before touching the models so malformed requests are rejected
    # cheaply and never trigger the first-time model load. silent=True makes
    # a missing/invalid JSON body yield None (-> 400) instead of raising.
    prompt = _extract_prompt(request.get_json(silent=True))
    if prompt is None:
        return jsonify({"error": "No prompt provided"}), 400
    print(f"Received prompt: {prompt}")

    try:
        load_models_if_needed()

        # Settings chosen for CPU performance (reduced diffusion steps).
        batch_size = 1
        guidance_scale = 15.0

        print("Starting latent generation...")
        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=False,  # CPU doesn't support fp16
            use_karras=True,
            karras_steps=32,  # Reduced steps for CPU
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        print("Latent generation complete!")

        # Unique base name so concurrent requests never overwrite each other.
        base_path = os.path.join("outputs", str(uuid.uuid4()))

        print("Decoding mesh...")
        t0 = time.perf_counter()
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.perf_counter() - t0:.2f} seconds")

        glb_path, obj_path = _export_mesh(mesh, base_path)
        print("Files saved successfully!")

        return jsonify({
            "success": True,
            "message": "3D model generated successfully",
            "glb_url": f"/download/{os.path.basename(glb_path)}",
            "obj_url": f"/download/{os.path.basename(obj_path)}",
        })
    except Exception as e:
        # Top-level request boundary: log the full traceback, report a 500.
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/download/<filename>', methods=['GET'])
def download_file(filename):
    """Send a previously generated file from ``outputs/`` as an attachment.

    Only regular files located directly inside the ``outputs`` directory
    (resolved against the current working directory, where ``generate_3d``
    writes them) are served. Anything else yields a JSON 404 without leaking
    server-side paths or exception details.
    """
    outputs_dir = os.path.abspath("outputs")
    path = os.path.abspath(os.path.join(outputs_dir, filename))
    # Confine lookups to outputs/ itself: names such as ".." or absolute
    # paths resolve to a different parent directory and are rejected.
    if os.path.dirname(path) != outputs_dir or not os.path.isfile(path):
        return jsonify({"error": f"File not found: {filename}"}), 404
    return send_file(path, as_attachment=True)
@app.route('/health', methods=['GET'])
def health_check():
    """Answer with a fixed "ok" JSON payload so callers can tell the app is up."""
    return jsonify(status="ok", message="Service is running")
# Static HTML for the landing page, describing how to call POST /generate.
_HOME_PAGE_HTML = """
<html>
<head><title>Text to 3D API</title></head>
<body>
<h1>Text to 3D API</h1>
<p>This is a simple API that converts text prompts to 3D models.</p>
<h2>How to use:</h2>
<pre>
POST /generate
Content-Type: application/json
{
"prompt": "A futuristic building"
}
</pre>
<p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
</body>
</html>
"""


@app.route('/', methods=['GET'])
def home():
    """Serve the static landing page with usage instructions."""
    return _HOME_PAGE_HTML
if __name__ == '__main__':
    # Debug mode enables Werkzeug's interactive debugger, which must not be
    # reachable on a public interface such as host 0.0.0.0. It is therefore
    # opt-in: set FLASK_DEBUG=1 (or "true"/"yes") for local development only.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', port=7860, debug=debug_enabled)