# Rightlight / app.py
import os
import torch
import logging
from flask import Flask, request, jsonify
from diffusers import DiffusionPipeline
from PIL import Image
from io import BytesIO
import base64

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask app
app = Flask(__name__)

# Load Zero123Plus pipeline (for CPU)
logger.info("Loading Zero123Plus pipeline...")
try:
    pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="sudo-ai/zero123plus-pipeline",  # Zero123Plus ships its pipeline class as a diffusers custom pipeline
        torch_dtype=torch.float32,
        variant=None,  # avoid fp16 issues
    )
    pipe.to("cpu")
    logger.info("Pipeline loaded successfully.")
except Exception as e:
    logger.error(f"Error loading model: {e}")
    pipe = None
@app.route("/", methods=["GET"])
def index():
return "Zero123Plus API is running."
@app.route("/generate", methods=["POST"])
def generate():
if pipe is None:
return jsonify({"error": "Model not loaded"}), 500
try:
        data = request.get_json(silent=True) or {}  # tolerate a missing or non-JSON body
        image_data = data.get("image")
        if not image_data:
            return jsonify({"error": "No image provided"}), 400

        # Decode base64 to PIL image (split strips any "data:...;base64," prefix)
        image_bytes = base64.b64decode(image_data.split(",")[-1])
        image = Image.open(BytesIO(image_bytes)).convert("RGB")

        # Run inference
        logger.info("Generating 3D views...")
        output = pipe(image)
        generated_image = output.images[0]

        # Convert output to base64
        buffered = BytesIO()
        generated_image.save(buffered, format="PNG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return jsonify({"image": f"data:image/png;base64,{img_base64}"})
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    logger.info("=== Application startup in CPU mode ===")
    app.run(host="0.0.0.0", port=7860)
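
# ---------------------------------------------------------------------------
# Example client call (illustrative sketch only). It assumes the server is
# running locally on port 7860, that the `requests` package is installed, and
# uses placeholder file names ("input.png", "views.png"):
#
#   import base64, requests
#
#   with open("input.png", "rb") as f:
#       b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"image": f"data:image/png;base64,{b64}"},
#   )
#   out = resp.json()["image"]  # "data:image/png;base64,..."
#   with open("views.png", "wb") as f:
#       f.write(base64.b64decode(out.split(",", 1)[1]))
# ---------------------------------------------------------------------------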