File size: 1,756 Bytes
2e2dac7
 
388cf5c
2e2dac7
 
 
 
81914fc
2e2dac7
 
1087492
2e2dac7
388cf5c
2e2dac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a74686
 
2e2dac7
 
388cf5c
2e2dac7
 
 
 
583e56c
388cf5c
2e2dac7
 
 
 
 
388cf5c
2e2dac7
 
388cf5c
2e2dac7
388cf5c
2e2dac7
 
0a74686
2e2dac7
583e56c
388cf5c
2e2dac7
388cf5c
1087492
48056a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import io
import base64
import torch
from flask import Flask, request, jsonify, send_file
from diffusers import DiffusionPipeline
from PIL import Image
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Load the model once, at import time, on CPU. A failure here is deliberately
# non-fatal: the web server still starts, but `pipe` is left as None, so any
# handler that uses it must check for that first.
try:
    logger.info("Loading Zero123Plus pipeline...")
    # NOTE(review): the Zero123Plus model card loads this checkpoint with
    # custom_pipeline="sudo-ai/zero123plus-pipeline"; confirm the installed
    # diffusers version resolves the correct pipeline class without it.
    pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        torch_dtype=torch.float32,  # CPU needs float32
    )
    pipe.to("cpu")
    logger.info("=== Application Startup at CPU mode =====")
except Exception as e:
    # logger.exception records the full traceback, not just str(e), which is
    # what you need to diagnose download/config/dependency failures.
    logger.exception("Error loading model: %s", e)
    pipe = None

def pil_to_base64(image):
    """Encode *image* as PNG and return those bytes as base64 text.

    Args:
        image: An object with a PIL-style ``save(fp, format=...)`` method,
            normally a ``PIL.Image.Image``.

    Returns:
        The base64 encoding of the PNG bytes as a ``str`` (no ``data:`` URL
        prefix; callers add that themselves).
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return str(encoded, "utf-8")

@app.route("/")
def home():
    """Root endpoint: reply with a fixed plain-text banner so callers can
    check that the service is up."""
    banner = "Zero123Plus CPU API is running!"
    return banner

@app.route("/generate", methods=["POST"])
def generate():
    """Run the loaded pipeline on a base64-encoded input image.

    Expects a JSON object body of the form ``{"image": "<base64>"}``. The
    value may carry a ``data:image/...;base64,`` prefix; everything up to and
    including the first comma is then stripped before decoding.

    Returns:
        ``{"image": "data:image/png;base64,..."}`` holding the first image the
        pipeline produced, on success.
        ``400`` with ``{"error": ...}`` when the body is not a JSON object,
        ``image`` is missing/empty/not a string, or the payload cannot be
        decoded into an image.
        ``500`` with ``{"error": ...}`` when the model failed to load at
        startup or inference/encoding raises.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True yields None (instead of raising) for a missing/incorrect
    # Content-Type or malformed JSON, so those client mistakes become a 400
    # rather than being reported as a server error.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    image_data = data.get("image")
    if not image_data:
        return jsonify({"error": "No image provided"}), 400
    if not isinstance(image_data, str):
        return jsonify({"error": "Image must be a base64-encoded string"}), 400

    # Strip an optional data-URL header. partition() does not raise when the
    # comma is missing (split(",")[1] did); an empty payload then fails to
    # decode below and is rejected with a 400.
    if image_data.startswith("data:image"):
        image_data = image_data.partition(",")[2]

    try:
        raw_bytes = base64.b64decode(image_data)
        # convert() forces Pillow to actually decode the (lazily opened) file,
        # so truncated/corrupt data fails here rather than inside the model.
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        # binascii.Error (bad base64) subclasses ValueError; Pillow's
        # UnidentifiedImageError and truncated-file errors subclass OSError.
        logger.warning("Rejected undecodable image payload: %s", e)
        return jsonify({"error": f"Invalid image data: {e}"}), 400

    try:
        result = pipe(image)
        output_image = result.images[0]
        encoded = pil_to_base64(output_image)
    except Exception as e:
        # Request boundary for genuine server-side failures: keep the full
        # traceback in the log, return the message to the client as before.
        logger.exception("Error generating image: %s", e)
        return jsonify({"error": str(e)}), 500

    return jsonify({"image": f"data:image/png;base64,{encoded}"})

if __name__ == "__main__":
    # Start Flask's built-in server, listening on every interface at port 7860.
    app.run(port=7860, host="0.0.0.0")