# Rightlight / app.py
# Last commit: 4e31b1a (verified) by mike23415 -- "Update app.py" (2.47 kB)
import base64
import io
import logging
import os
import tempfile

import torch
from diffusers import StableDiffusionPipeline  # Placeholder; adjust based on SF3D docs
from flask import Flask, jsonify, request
from PIL import Image
# Configure root logging at INFO so the startup and per-request logger.info
# messages in this module are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
# The Flask application object; routes are attached with @app.route decorators.
app = Flask(__name__)
# Load the model once at startup (on CPU) with the token from the environment.
# Any failure leaves `pipe` as None so the process still starts; the /generate
# route checks for None and answers "Model not loaded".
# NOTE(review): stabilityai/stable-fast-3d may not be loadable as a diffusers
# StableDiffusionPipeline (SF3D appears to ship its own loader) -- confirm
# against the SF3D documentation.
pipe = None
token = os.getenv("HF_TOKEN")  # Access token passed to from_pretrained below
if not token:
    # A missing token is a configuration problem, not an unexpected crash:
    # log it directly instead of raising an exception just to catch it.
    logger.error("Error loading model: HF_TOKEN environment variable not set")
else:
    try:
        logger.info("Loading Stable Fast 3D pipeline...")
        pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-fast-3d",
            torch_dtype=torch.float32,
            cache_dir="/tmp/hf_home",
            token=token,
        )
        pipe.to("cpu")
        logger.info("=== Application Startup at CPU mode =====")
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception("Error loading model: %s", e)
        # Reset: from_pretrained may have succeeded before pipe.to("cpu") raised,
        # and a half-initialised pipeline must not be served.
        pipe = None
def pil_to_base64(image):
    """Encode *image* as PNG and return those bytes as a base64 text string.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, decoded as UTF-8 text.
    """
    with io.BytesIO() as stream:
        image.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
@app.route("/")
def home():
    """Serve a fixed plain-text liveness message at the root URL.

    The text is constant; it does not report whether the model loaded.
    """
    banner = "Stable Fast 3D CPU API is running!"
    return banner
def _decode_image(image_data):
    """Decode a base64 image payload into an RGB PIL image.

    ``image_data`` may be bare base64 or a ``data:image/...`` URI. For a URI,
    only the text after the first comma (the payload) is decoded.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error`` is a
            ``ValueError`` subclass).
        OSError: the decoded bytes are not an image PIL can read
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    if image_data.startswith("data:image"):
        # partition (unlike split(",")[1]) does not raise IndexError when the
        # comma is missing; an empty payload then fails as an unreadable image.
        image_data = image_data.partition(",")[2]
    return Image.open(io.BytesIO(base64.b64decode(image_data))).convert("RGB")


@app.route("/generate", methods=["POST"])
def generate():
    """Convert a base64-encoded input image into a GLB mesh.

    Expects a JSON object body whose ``"image"`` field holds base64 image data,
    optionally as a ``data:image/...`` URI.

    Returns:
        200 with ``{"mesh": "data:model/gltf-binary;base64,..."}`` on success.
        400 with ``{"error": ...}`` for a missing or malformed body, or an
        image that cannot be decoded.
        500 with ``{"error": ...}`` if the model is not loaded or if inference
        or export fails.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500
    try:
        # silent=True returns None instead of raising for a missing or invalid
        # JSON body, so client mistakes get a 400, not the catch-all 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        image_data = data.get("image")
        if not image_data:
            return jsonify({"error": "No image provided"}), 400
        if not isinstance(image_data, str):
            return jsonify({"error": "Image must be a base64-encoded string"}), 400

        try:
            image = _decode_image(image_data)
        except (ValueError, OSError) as e:
            # Undecodable input is a client error, not a server failure.
            logger.warning("Rejected undecodable image: %s", e)
            return jsonify({"error": f"Invalid image data: {e}"}), 400

        logger.info("Processing image with pipeline...")
        result = pipe(image)  # Adjust based on SF3D documentation
        output_mesh = result.mesh  # Hypothetical; check SF3D output format
        # A per-request temp directory stops concurrent requests from
        # overwriting a shared /tmp/output.glb. The file is deleted once it has
        # been read back.
        with tempfile.TemporaryDirectory() as tmp_dir:
            output_path = os.path.join(tmp_dir, "output.glb")
            output_mesh.save(output_path)
            with open(output_path, "rb") as f:
                mesh_data = base64.b64encode(f.read()).decode("utf-8")
        logger.info("Mesh processed successfully")
        return jsonify({"mesh": f"data:model/gltf-binary;base64,{mesh_data}"})
    except Exception as e:
        # NOTE(review): str(e) is sent to the client and may expose internal
        # details; kept so existing callers see the same error payload.
        logger.exception("Error generating mesh: %s", e)
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Direct-run entry point: Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")