File size: 4,522 Bytes
48056a7
7949d53
8a1bb06
48056a7
 
 
 
 
7949d53
 
 
 
 
 
 
 
8a1bb06
 
 
 
 
 
 
 
48056a7
125b5d0
 
7949d53
125b5d0
 
 
c774bcd
1087492
60b636a
 
1087492
60b636a
7949d53
 
 
 
81914fc
 
8a1bb06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81914fc
7949d53
 
81914fc
7949d53
 
81914fc
 
8a1bb06
81914fc
1087492
60b636a
 
 
48056a7
 
8a1bb06
 
 
48056a7
1087492
48056a7
 
7949d53
60b636a
 
 
81914fc
 
60b636a
 
7949d53
 
 
1087492
7949d53
 
 
1087492
7949d53
 
 
 
 
81914fc
 
 
 
 
7949d53
 
1087492
7949d53
 
 
 
1087492
7949d53
 
 
 
 
1087492
48056a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import io
import logging
import os
import subprocess
import sys
import threading

import pip
import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image

# Configure logging to stdout instead of files. A bare StreamHandler() would
# write to sys.stderr, so the stdout stream is passed explicitly.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[_stdout_handler],
)
logger = logging.getLogger(__name__)

# Try to update the required packages before they are first imported
# (diffusers / huggingface_hub are only imported inside load_model()).
# pip has no supported in-process API: ``pip.main`` is internal and reports
# failure only through an integer return code that was previously ignored.
# Run pip as a subprocess of the current interpreter and check its status.
# NOTE(review): upgrading diffusers to the newest release may not match the
# version the Zero123Plus custom pipeline was written for; consider pinning a
# known-good version instead of ``--upgrade``.
def _upgrade_hf_libraries():
    """Best-effort ``pip install --upgrade`` of huggingface_hub and diffusers.

    Each package is attempted independently, so one failure does not skip the
    other. A non-zero pip exit status, a timeout, or pip failing to launch is
    logged as a warning rather than raised.
    """
    logger.info("Updating huggingface_hub and diffusers...")
    for package in ("huggingface_hub", "diffusers"):
        try:
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "--upgrade", package, "--quiet"],
                check=True,  # turn a non-zero pip exit into CalledProcessError
                timeout=600,  # a hung install must not block startup forever
            )
        except (subprocess.SubprocessError, OSError) as e:
            logger.warning(f"Failed to update libraries: {str(e)}")


_upgrade_hf_libraries()

# Redirect the Hugging Face and generic XDG caches to writable paths under
# /tmp. This runs before load_model() performs its deferred diffusers /
# huggingface_hub imports, so those libraries see these values.
os.environ.update(HF_HOME='/tmp/hf_home', XDG_CACHE_HOME='/tmp/cache')

# Make sure every cache location exists (existing directories are fine).
for _cache_dir in ('/tmp/hf_home', '/tmp/cache', '/tmp/diffusers_cache'):
    os.makedirs(_cache_dir, exist_ok=True)

# Global variable for the model; assigned by load_model(), read by the routes.
pipe = None


# Initialize the model at startup
def load_model(model_id="sudo-ai/zero123plus-v1.2",
               custom_pipeline="sudo-ai/zero123plus-pipeline"):
    """Load the Zero123Plus pipeline into the module-level ``pipe``.

    Zero123Plus ships its pipeline class as custom code rather than as a
    built-in diffusers pipeline, so ``AutoPipelineForImage2Image`` cannot map
    it. The Zero123Plus project loads it with
    ``DiffusionPipeline.from_pretrained(..., custom_pipeline=...)``.

    Loading happens in two steps:
    1. The weights are fetched with ``snapshot_download`` into
       ``/tmp/diffusers_cache`` and loaded from that local path.
    2. If step 1 fails, ``from_pretrained`` is tried directly with the repo id.

    Args:
        model_id: Hugging Face Hub repo id of the model weights.
        custom_pipeline: Hub repo id containing the custom pipeline code.

    Returns:
        True if the pipeline was loaded and moved to the CPU. False if any
        step raised; the error is logged with its traceback and ``pipe`` is
        left unchanged.
    """
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # Import here to ensure the environment variables are set before import
        from diffusers import DiffusionPipeline
        from huggingface_hub import snapshot_download

        common_kwargs = {
            "custom_pipeline": custom_pipeline,
            # float32 because the pipeline is run on the CPU (see .to below).
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True,
        }

        try:
            # First try to download the model files, then load from disk.
            model_path = snapshot_download(
                model_id,
                cache_dir="/tmp/diffusers_cache",
                local_files_only=False,
            )
            loaded = DiffusionPipeline.from_pretrained(model_path, **common_kwargs)
        except Exception as download_error:
            logger.warning(f"Failed to download using snapshot_download: {str(download_error)}")

            # Fallback: let from_pretrained resolve and cache the repo itself.
            loaded = DiffusionPipeline.from_pretrained(
                model_id,
                cache_dir="/tmp/diffusers_cache",
                local_files_only=False,
                **common_kwargs,
            )

        loaded.to("cpu")
        # Publish the pipeline only once it is fully initialised, so a failure
        # in .to() cannot leave a half-loaded object in the global.
        pipe = loaded
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.exception(f"Error loading model: {str(e)}")
        return False

# Create the Flask application that the routes below register on.
app = Flask(import_name=__name__)

# Attempt to load the model eagerly at import time; generate() calls
# load_model() again whenever ``pipe`` is still None.
load_model()

@app.route("/", methods=["GET"])
def index():
    """Report that the API is running.

    If the query string contains ``logs=container``, an extra ``status`` field
    is added to the JSON response.
    """
    payload = {"message": "Zero123Plus API is running."}
    if request.args.get('logs') == 'container':
        payload["status"] = "logs viewed"
    return jsonify(payload)

# Serializes model loading and inference. app.run() handles requests on
# multiple threads by default, and every request shares one pipeline object
# (including its scheduler state). Holding the lock also prevents two
# concurrent first requests from both calling load_model().
_pipe_lock = threading.Lock()

# Upper bound on the client-supplied step count, so one request cannot tie up
# the CPU-only model indefinitely.
_MAX_INFERENCE_STEPS = 200


def _parse_num_inference_steps(raw_value):
    """Convert the ``num_inference_steps`` form value to a validated int.

    Raises:
        ValueError: if the value is not an integer in the range
            1.._MAX_INFERENCE_STEPS (inclusive).
    """
    steps = int(raw_value)
    if not 1 <= steps <= _MAX_INFERENCE_STEPS:
        raise ValueError(
            f"num_inference_steps must be between 1 and {_MAX_INFERENCE_STEPS}, got {steps}"
        )
    return steps


@app.route("/generate", methods=["POST"])
def generate():
    """Generate new views for an uploaded image and return the result as PNG.

    Expects a multipart form with:
    - ``image``: the input image file (required).
    - ``num_inference_steps``: optional step count, default 25.

    Responses:
        200: the first generated image, ``image/png``.
        400: the image is missing or undecodable, or ``num_inference_steps``
            is invalid.
        500: the model could not be loaded, or generation raised.
    """
    try:
        # Validate the request before touching the (expensive) model.
        if 'image' not in request.files:
            logger.warning("No image uploaded")
            return jsonify({"error": "No image uploaded"}), 400

        try:
            num_steps = _parse_num_inference_steps(
                request.form.get('num_inference_steps', 25)
            )
        except ValueError as e:
            logger.warning(f"Invalid num_inference_steps: {str(e)}")
            return jsonify({"error": str(e)}), 400

        image_file = request.files['image']
        try:
            input_image = Image.open(image_file).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            # UnidentifiedImageError and truncated-data errors are OSErrors.
            logger.warning(f"Could not decode uploaded image: {str(e)}")
            return jsonify({"error": "Uploaded file is not a valid image"}), 400
        logger.info(f"Received image of size {input_image.size}")

        with _pipe_lock:
            # Check if model is loaded; load_model() assigns the global pipe.
            if pipe is None and not load_model():
                return jsonify({"error": "Failed to initialize model"}), 500

            logger.info(f"Starting image generation with {num_steps} steps")
            # Generate new views
            result = pipe(
                image=input_image,
                num_inference_steps=num_steps,
            )

        output_image = result.images[0]
        logger.info(f"Generated image of size {output_image.size}")

        # Save to a BytesIO object
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)

        return send_file(img_io, mimetype='image/png')

    except Exception as e:
        logger.exception(f"Error during generation: {str(e)}")
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Bind the development server to every network interface on port 7860.
    bind_host, bind_port = "0.0.0.0", 7860
    app.run(host=bind_host, port=bind_port)