File size: 1,912 Bytes
48056a7
 
583e56c
 
9a14904
48056a7
9a14904
583e56c
1087492
3831488
9a14904
583e56c
81914fc
3831488
81914fc
1087492
3831488
583e56c
9a14904
 
583e56c
 
3831488
583e56c
 
 
9a14904
583e56c
9a14904
 
583e56c
 
 
1087492
48056a7
583e56c
9a14904
583e56c
1087492
9a14904
3831488
 
 
583e56c
 
 
3831488
583e56c
 
3831488
583e56c
 
3831488
583e56c
3831488
583e56c
3831488
 
583e56c
3831488
 
7949d53
3831488
7949d53
1087492
48056a7
3831488
48056a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import base64
import logging
import os
import threading
from io import BytesIO

import torch
from diffusers import DiffusionPipeline
from flask import Flask, jsonify, request
from PIL import Image

# Root logging at INFO so the startup/generation messages below are emitted.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# The Flask application object that the route decorators attach to.
app = Flask(__name__)

# Load the Zero123Plus pipeline once, at import time, for CPU inference.
# On any failure `pipe` is left as None; /generate then answers 500
# "Model not loaded" while the rest of the app (e.g. "/") keeps serving.
logger.info("Loading Zero123Plus pipeline...")
try:
    pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        # Zero123Plus is not a built-in diffusers pipeline; the model card's
        # official example loads its pipeline class from this custom repo.
        custom_pipeline="sudo-ai/zero123plus-pipeline",
        # float32 rather than float16: this service runs on CPU.
        torch_dtype=torch.float32,
        variant=None,  # load the default (non-"fp16") weight files
    )
    pipe.to("cpu")
    logger.info("Pipeline loaded successfully.")
except Exception as e:
    # Deliberate best-effort startup: keep the server alive, but record the
    # full traceback (not just str(e)) so the load failure can be diagnosed.
    logger.exception("Error loading model: %s", e)
    pipe = None

@app.route("/", methods=["GET"])
def index():
    """Liveness endpoint: report that the web process is up (says nothing about the model)."""
    status_message = "Zero123Plus API is running."
    return status_message

# Diffusers pipelines hold mutable per-call state (notably the scheduler's
# timestep schedule), and Flask's development server handles requests on
# multiple threads, so calls into the shared `pipe` are serialized.
_PIPE_LOCK = threading.Lock()


def _decode_image(image_data):
    """Decode a base64 string, optionally a ``data:...;base64,`` URL, into an RGB PIL image.

    Raises:
        ValueError: the payload is not valid base64 (``binascii.Error`` subclasses it).
        OSError: the bytes are not a readable image (includes ``UnidentifiedImageError``).
        Image.DecompressionBombError: the image exceeds Pillow's pixel limit.
    """
    # Everything after the last comma is the payload, so both raw base64 and
    # data URLs are accepted.
    encoded = image_data.rsplit(",", 1)[-1]
    raw = base64.b64decode(encoded)
    return Image.open(BytesIO(raw)).convert("RGB")


def _encode_png_data_url(image):
    """Serialize a PIL image as a ``data:image/png;base64,...`` string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{img_base64}"


@app.route("/generate", methods=["POST"])
def generate():
    """Run Zero123Plus on an uploaded image and return the generated multi-view image.

    Request body (JSON): ``{"image": "<base64 string or data URL>"}``.

    Returns:
        200 with ``{"image": "data:image/png;base64,..."}`` on success;
        400 with ``{"error": ...}`` for a missing, non-string or undecodable image;
        500 with ``{"error": ...}`` if the model is not loaded or inference fails.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing, non-JSON or malformed body yields None instead of
    # raising a 400/415 HTTPException that would otherwise surface as a 500.
    data = request.get_json(silent=True)
    image_data = data.get("image") if isinstance(data, dict) else None

    if not image_data:
        return jsonify({"error": "No image provided"}), 400
    if not isinstance(image_data, str):
        return jsonify({"error": "Image must be a base64-encoded string"}), 400

    # Bad client payloads are a 400, not a server error.
    try:
        image = _decode_image(image_data)
    except (ValueError, OSError, Image.DecompressionBombError) as e:
        logger.warning("Rejected invalid image payload: %s", e)
        return jsonify({"error": f"Invalid image data: {e}"}), 400

    try:
        logger.info("Generating 3D views...")
        with _PIPE_LOCK:
            output = pipe(image)
        generated_image = output.images[0]
        return jsonify({"image": _encode_png_data_url(generated_image)})
    except Exception as e:
        # Top-level request boundary: log the traceback and report a 500.
        logger.exception("Generation failed: %s", e)
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Script entry point: start Flask's built-in server listening on all interfaces.
    bind_host, bind_port = "0.0.0.0", 7860
    logger.info("=== Application Startup at CPU mode =====")
    app.run(host=bind_host, port=bind_port)