File size: 3,343 Bytes
48056a7
7949d53
48056a7
c774bcd
 
 
 
 
48056a7
 
 
 
7949d53
 
 
 
 
 
 
 
48056a7
125b5d0
 
 
7949d53
125b5d0
 
 
c774bcd
1087492
60b636a
 
1087492
60b636a
7949d53
 
 
 
c774bcd
7949d53
 
125b5d0
c774bcd
 
 
7949d53
 
 
 
 
 
1087492
60b636a
 
 
 
 
48056a7
 
 
1087492
48056a7
 
7949d53
60b636a
 
 
 
 
 
 
 
 
7949d53
 
 
1087492
7949d53
 
 
1087492
7949d53
 
 
 
 
 
 
 
1087492
7949d53
 
 
 
1087492
7949d53
 
 
 
 
1087492
48056a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import io
import logging
import os
import sys

# Configure logging to stdout instead of files. logging.StreamHandler() with no
# argument writes to sys.stderr, so the stream is passed explicitly.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache locations to writable paths. huggingface_hub
# (imported transitively by diffusers) resolves HF_HOME / XDG_CACHE_HOME once,
# when it is first imported, so these must be set BEFORE the third-party
# imports below or they are silently ignored.
os.environ['HF_HOME'] = '/tmp/hf_home'
os.environ['XDG_CACHE_HOME'] = '/tmp/cache'

# Create cache directories if they don't exist
os.makedirs('/tmp/hf_home', exist_ok=True)
os.makedirs('/tmp/cache', exist_ok=True)
os.makedirs('/tmp/diffusers_cache', exist_ok=True)

# Third-party imports intentionally follow the environment setup above.
from flask import Flask, request, jsonify, send_file
from diffusers.pipelines import DiffusionPipeline
try:
    # Optional explicit import; the model itself is loaded through
    # DiffusionPipeline.from_pretrained in load_model().
    from diffusers.pipelines.zero123plus.pipeline_zero123plus import Zero123PlusPipeline
except ImportError:
    logger.warning("Zero123PlusPipeline not found in current diffusers version")
import torch
from PIL import Image

# Global variable for the model; populated by load_model(), None until then.
pipe = None

# Initialize the model at startup
def load_model():
    """Load the Zero123Plus pipeline into the module-level ``pipe`` global.

    Fetches ``sudo-ai/zero123plus-v1.2`` in float32 (using
    ``/tmp/diffusers_cache`` as the download cache) and moves it to the CPU.
    ``pipe`` is assigned only after the pipeline is fully initialised, so a
    failure part-way through leaves the previous value (``None`` at startup)
    in place and ``generate`` can retry the load.

    Raises:
        Exception: any error raised while downloading, constructing or moving
            the pipeline is logged with its traceback and re-raised.
    """
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # DiffusionPipeline.from_pretrained selects the concrete pipeline class
        # from the repository's model_index.json.
        # NOTE(review): the upstream model card loads this repo with
        # custom_pipeline="sudo-ai/zero123plus-pipeline"; and recent
        # diffusers/huggingface_hub deprecate resume_download. Confirm both
        # against the pinned diffusers version.
        loaded = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            torch_dtype=torch.float32,
            cache_dir="/tmp/diffusers_cache",
            local_files_only=False,
            resume_download=True
        )
        loaded.to("cpu")
        # Publish only a fully-initialised pipeline.
        pipe = loaded
        logger.info("Model loaded successfully")
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"Error loading model: {str(e)}")
        raise

app = Flask(__name__)

# Eagerly load the model while the module is imported; if loading raises,
# the exception propagates and module import fails.
load_model()

@app.route("/", methods=["GET"])
def index():
    """Return a fixed JSON message confirming the server is reachable.

    This does not check whether the model has been loaded.
    """
    status = {"message": "Zero123Plus API is running."}
    return jsonify(status)

@app.route("/generate", methods=["POST"])
def generate():
    """Run the Zero123Plus pipeline on an uploaded image and return a PNG.

    Request (multipart/form-data):
        image: required file field; decoded with PIL and converted to RGB.
        num_inference_steps: optional form field (default 25); must parse as
            a positive integer.

    Returns:
        The first image of the pipeline result as ``image/png``; otherwise a
        JSON ``{"error": ...}`` body with status 400 for bad client input
        (missing or undecodable image, invalid ``num_inference_steps``) or
        500 when the model cannot be loaded or generation fails.
    """
    global pipe
    try:
        # Retry loading if no pipeline is available yet.
        if pipe is None:
            try:
                load_model()
            except Exception as e:
                logger.error(f"Failed to load model: {str(e)}")
                return jsonify({"error": "Failed to initialize model"}), 500

        if 'image' not in request.files:
            logger.warning("No image uploaded")
            return jsonify({"error": "No image uploaded"}), 400

        # An undecodable upload is a client error, not a server failure.
        # PIL.UnidentifiedImageError and truncated-file errors are OSErrors.
        try:
            input_image = Image.open(request.files['image']).convert("RGB")
        except OSError as e:
            logger.warning(f"Invalid image upload: {e}")
            return jsonify({"error": "Uploaded file is not a valid image"}), 400
        logger.info(f"Received image of size {input_image.size}")

        # Get optional parameters with defaults; reject malformed values
        # with 400 instead of letting int() surface as a 500.
        raw_steps = request.form.get('num_inference_steps', 25)
        try:
            num_steps = int(raw_steps)
            if num_steps < 1:
                raise ValueError(num_steps)
        except (TypeError, ValueError):
            logger.warning(f"Invalid num_inference_steps: {raw_steps!r}")
            return jsonify({"error": "num_inference_steps must be a positive integer"}), 400

        logger.info(f"Starting image generation with {num_steps} steps")
        # Generate new views
        result = pipe(image=input_image, num_inference_steps=num_steps)
        output_image = result.images[0]
        logger.info(f"Generated image of size {output_image.size}")

        # Encode to an in-memory PNG and rewind so send_file reads from the start.
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)

        return send_file(img_io, mimetype='image/png')

    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Error during generation: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Start Flask's built-in server on every interface, port 7860.
    app.run(port=7860, host="0.0.0.0")