Spaces:
Runtime error
Runtime error
import base64
import io
import logging
import os
import tempfile

from diffusers import StableDiffusionPipeline  # Placeholder; adjust based on SF3D docs
from flask import Flask, request, jsonify
from PIL import Image
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)

# Load the model once at import time (on CPU), authenticating with the
# HF_TOKEN environment variable. On any failure the error is logged and
# `pipe` stays None; request handlers check `pipe is None` rather than
# the whole app failing to start.
pipe = None
try:
    logger.info("Loading Stable Fast 3D pipeline...")
    token = os.getenv("HF_TOKEN")  # Retrieve token from environment variable
    if not token:
        raise ValueError("HF_TOKEN environment variable not set")
    # NOTE(review): StableDiffusionPipeline is marked as a placeholder at its
    # import. stable-fast-3d appears to ship with its own `sf3d` loader and is
    # probably not loadable through this diffusers class. Confirm against the
    # SF3D docs.
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-fast-3d",
        torch_dtype=torch.float32,
        cache_dir="/tmp/hf_home",
        token=token,  # Use the environment variable token
    )
    pipe.to("cpu")
    logger.info("=== Application Startup at CPU mode =====")
except Exception as e:
    # Startup boundary: log with traceback and degrade to "model not loaded"
    # rather than crashing the server.
    logger.exception("Error loading model: %s", e)
    pipe = None
def pil_to_base64(image):
    """Serialize *image* as PNG and return the bytes base64-encoded.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 text of the PNG bytes as a ``str``, with no ``data:`` URI prefix.
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        raw_png = png_stream.getvalue()
    encoded = base64.b64encode(raw_png)
    return encoded.decode("utf-8")
@app.route("/")
def home():
    """Health-check endpoint: return a fixed plain-text status message.

    Without the route decorator Flask never registers this view, so the
    server would answer every request with 404.
    """
    return "Stable Fast 3D CPU API is running!"
@app.route("/generate", methods=["POST"])
def generate():
    """Convert a base64-encoded input image into a base64 GLB mesh.

    Expects a JSON object body ``{"image": "<base64>"}``. The value may also be
    a ``data:image/...;base64,<payload>`` URI.

    Returns:
        ``({"mesh": "data:model/gltf-binary;base64,..."})`` on success.
        ``({"error": ...}, 400)`` if the body is not a JSON object, the image
        is missing or not a string, or the image cannot be decoded.
        ``({"error": ...}, 500)`` if the model is not loaded or inference or
        mesh export fails.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, so it is reported as a client error rather than a generic 500.
    data = request.get_json(silent=True)
    image_data = data.get("image") if isinstance(data, dict) else None
    if not image_data:
        return jsonify({"error": "No image provided"}), 400
    if not isinstance(image_data, str):
        return jsonify({"error": "Image must be a base64 string"}), 400

    try:
        image = _decode_image(image_data)
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) subclasses ValueError; PIL's
        # UnidentifiedImageError (not an image) subclasses OSError.
        logger.warning("Rejected undecodable image: %s", e)
        return jsonify({"error": f"Invalid image data: {e}"}), 400
    except Exception as e:
        # Anything else raised while decoding is reported as a server error.
        logger.exception("Error generating mesh: %s", e)
        return jsonify({"error": str(e)}), 500

    try:
        logger.info("Processing image with pipeline...")
        result = pipe(image)  # Adjust based on SF3D documentation
        # Hypothetical `.mesh` attribute; check SF3D output format.
        mesh_uri = _mesh_to_data_uri(result.mesh)
    except Exception as e:
        logger.exception("Error generating mesh: %s", e)
        return jsonify({"error": str(e)}), 500

    logger.info("Mesh processed successfully")
    return jsonify({"mesh": mesh_uri})


def _decode_image(image_data):
    """Decode base64 text (optionally a ``data:image`` URI) into an RGB PIL image."""
    if image_data.startswith("data:image"):
        # Keep only the payload after the first comma. partition() avoids an
        # IndexError when the comma is missing; the empty payload then fails
        # image decoding instead.
        image_data = image_data.partition(",")[2]
    return Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")


def _mesh_to_data_uri(mesh):
    """Save *mesh* as GLB in a private temp dir and return it as a base64 data URI."""
    # A per-call temporary directory replaces the old shared /tmp/output.glb.
    # Concurrent requests could overwrite that file and return each other's
    # meshes, and it was never removed.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_path = os.path.join(tmp_dir, "output.glb")
        mesh.save(output_path)
        with open(output_path, "rb") as f:
            mesh_data = base64.b64encode(f.read()).decode("utf-8")
    return f"data:model/gltf-binary;base64,{mesh_data}"
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server, listening on every
    # interface (0.0.0.0) at port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)