from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import torch
import numpy as np
import trimesh
import os
from io import BytesIO
import base64
from PIL import Image
import uuid
import time
import sys
# Import Shap-E
print("Importing Shap-E modules...")
try:
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
    print("Shap-E modules imported successfully!")
except Exception as e:
    print(f"Error importing Shap-E modules: {e}")
    sys.exit(1)
app = Flask(__name__)
CORS(app)

# Create output directory if it doesn't exist
os.makedirs("outputs", exist_ok=True)

# Use lazy loading for models
print("Setting up device...")
device = torch.device('cpu')  # Force CPU for Hugging Face Spaces
print(f"Using device: {device}")

# Global variables for models (will be loaded on first request)
xm = None
model = None
diffusion = None
def load_models_if_needed():
    global xm, model, diffusion
    if xm is None or model is None or diffusion is None:
        print("Loading models for the first time...")
        try:
            xm = load_model('transmitter', device=device)
            model = load_model('text300M', device=device)
            diffusion = diffusion_from_config(load_config('diffusion'))
            print("Models loaded successfully!")
        except Exception as e:
            print(f"Error loading models: {e}")
            raise
@app.route('/generate', methods=['POST'])
def generate_3d():
    try:
        # Load models if not already loaded
        load_models_if_needed()

        # Get the prompt from the request (silent=True returns None instead of
        # raising on a non-JSON body, so the 400 branch below is reachable)
        data = request.get_json(silent=True)
        if not data or 'prompt' not in data:
            return jsonify({"error": "No prompt provided"}), 400

        prompt = data['prompt']
        print(f"Received prompt: {prompt}")

        # Set parameters for CPU performance (reduced steps)
        batch_size = 1
        guidance_scale = 15.0

        # Generate latents with the text-to-3D model
        print("Starting latent generation...")
        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=False,  # CPU doesn't support fp16
            use_karras=True,
            karras_steps=32,  # Reduced steps for CPU
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )
        print("Latent generation complete!")

        # Generate a unique filename
        filename = f"outputs/{uuid.uuid4()}"

        # Convert latent to mesh
        print("Decoding mesh...")
        t0 = time.time()
        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

        # Save as GLB. Shap-E's TriMesh has no GLB writer, so rebuild the
        # geometry with trimesh (imported above) and let it export GLB.
        print("Saving as GLB...")
        glb_path = f"{filename}.glb"
        trimesh.Trimesh(vertices=mesh.verts, faces=mesh.faces).export(glb_path)

        # Save as OBJ
        print("Saving as OBJ...")
        obj_path = f"{filename}.obj"
        with open(obj_path, 'w') as f:
            mesh.write_obj(f)
        print("Files saved successfully!")

        # Return paths to the generated files
        return jsonify({
            "success": True,
            "message": "3D model generated successfully",
            "glb_url": f"/download/{os.path.basename(glb_path)}",
            "obj_url": f"/download/{os.path.basename(obj_path)}"
        })
    except Exception as e:
        print(f"Error during generation: {str(e)}")
        import traceback
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/download/<filename>')
def download_file(filename):
    try:
        # send_from_directory confines lookups to outputs/, blocking path traversal
        return send_from_directory("outputs", filename, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 404
@app.route('/health')  # route path assumed; the original listing omitted the decorator
def health_check():
    """Simple health check endpoint to verify the app is running"""
    return jsonify({"status": "ok", "message": "Service is running"})
@app.route('/')
def home():
    """Landing page with usage instructions"""
    return """
    <html>
    <head><title>Text to 3D API</title></head>
    <body>
        <h1>Text to 3D API</h1>
        <p>This is a simple API that converts text prompts to 3D models.</p>
        <h2>How to use:</h2>
        <pre>
        POST /generate
        Content-Type: application/json

        {
            "prompt": "A futuristic building"
        }
        </pre>
        <p>The response will include URLs to download the generated models in GLB and OBJ formats.</p>
    </body>
    </html>
    """
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860, debug=True)
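# Minimal client sketch (assumptions: the server above is reachable on
# localhost:7860 and `requests` is installed). Posts a prompt, then downloads
# the GLB file the server links back to:
#
#   import requests
#
#   base = "http://localhost:7860"
#   resp = requests.post(f"{base}/generate", json={"prompt": "A futuristic building"})
#   resp.raise_for_status()
#   info = resp.json()
#
#   glb = requests.get(base + info["glb_url"])
#   with open("model.glb", "wb") as f:
#       f.write(glb.content)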