File size: 3,059 Bytes
48056a7
7949d53
48056a7
 
 
 
 
 
7949d53
 
 
 
 
 
 
 
48056a7
125b5d0
 
 
 
7949d53
125b5d0
 
 
 
1087492
60b636a
 
1087492
60b636a
7949d53
 
 
 
 
 
125b5d0
 
7949d53
 
 
 
 
 
1087492
60b636a
 
 
 
 
48056a7
 
 
1087492
48056a7
 
7949d53
60b636a
 
 
 
 
 
 
 
 
7949d53
 
 
1087492
7949d53
 
 
1087492
7949d53
 
 
 
 
 
 
 
1087492
7949d53
 
 
 
1087492
7949d53
 
 
 
 
1087492
48056a7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import io
import logging
import os
import sys

import torch
from diffusers import DiffusionPipeline
from flask import Flask, request, jsonify, send_file
from PIL import Image

# Configure root logging to write to stdout instead of log files.
# logging.StreamHandler() with no argument writes to sys.stderr, so the
# stdout stream is passed explicitly to match that intent. Note that
# basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
# Module logger used by the loader and request handlers in this file.
logger = logging.getLogger(__name__)

# Point the Hugging Face / XDG cache locations at writable paths under /tmp
# and make sure each directory exists.
# NOTE(review): diffusers and torch are imported above this point, so any
# library that reads these variables when it is first imported will not see
# them (huggingface_hub is believed to resolve HF_HOME at import time —
# confirm). Move this block above the third-party imports if the values must
# take effect there. load_model() passes an explicit cache_dir in any case.
_CACHE_ENV = {
    'TRANSFORMERS_CACHE': '/tmp/transformers_cache',
    'HF_HOME': '/tmp/hf_home',
    'XDG_CACHE_HOME': '/tmp/cache',
}

# All variables are set first, then the directories are created, in the same
# order as they are listed above.
os.environ.update(_CACHE_ENV)
for _cache_path in _CACHE_ENV.values():
    os.makedirs(_cache_path, exist_ok=True)

# Module-global handle to the diffusion pipeline; stays None until
# load_model() completes successfully.
pipe = None


def load_model():
    """Load the Zero123Plus pipeline and store it in the module-global ``pipe``.

    Fetches ``sudo-ai/zero123plus-v1.2`` (cached under ``/tmp/diffusers_cache``)
    with ``torch.float32`` weights and moves it to the CPU. ``pipe`` is only
    assigned once every step has succeeded, so a failure never leaves a
    partially configured pipeline behind for callers to use.

    Raises:
        Exception: whatever loading raised; it is logged with its traceback
            and re-raised unchanged.
    """
    global pipe
    logger.info("Loading Zero123Plus model...")
    try:
        # NOTE(review): the upstream Zero123++ README loads this checkpoint
        # with custom_pipeline="sudo-ai/zero123plus-pipeline"; confirm the
        # plain DiffusionPipeline resolves the intended pipeline class.
        loaded = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            torch_dtype=torch.float32,
            cache_dir="/tmp/diffusers_cache",
        )
        loaded.to("cpu")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error loading model: %s", e)
        raise
    # Publish only a fully initialised pipeline; a failure in .to("cpu") used
    # to leave a non-None pipe that generate() would never retry loading.
    pipe = loaded
    logger.info("Model loaded successfully")

# Load the pipeline eagerly at import time. load_model() logs and re-raises
# on failure, so a failed load aborts startup here.
load_model()

# The WSGI application; the route handlers below register on it.
app = Flask(import_name=__name__)

@app.route("/", methods=["GET"])
def index():
    """Liveness endpoint for the service.

    Returns:
        A JSON object with a fixed ``message`` string. It does not check
        whether the model pipeline is currently loaded.
    """
    status = {"message": "Zero123Plus API is running."}
    return jsonify(status)

def _parse_num_steps(raw):
    """Convert a ``num_inference_steps`` value to a positive int, else None.

    Returns None when *raw* is not an integer literal or is less than 1, so
    the request handler can reject it with a 400 rather than failing later.
    """
    try:
        steps = int(raw)
    except (TypeError, ValueError):
        return None
    return steps if steps >= 1 else None


@app.route("/generate", methods=["POST"])
def generate():
    """Run the Zero123Plus pipeline on an uploaded image.

    Expects a multipart/form-data POST with:
      * ``image`` (file, required): the input image, converted to RGB.
      * ``num_inference_steps`` (form field, optional, default 25): a
        positive integer forwarded to the pipeline.

    Returns:
        The pipeline's first output image as ``image/png`` on success.
        A JSON ``{"error": ...}`` body with status 400 if the image is
        missing or unreadable, or the step count is invalid.
        Status 500 if the model cannot be loaded or generation fails.
    """
    try:
        # The startup load normally sets ``pipe``; retry here if it is unset.
        if pipe is None:
            try:
                load_model()
            except Exception as e:
                logger.exception("Failed to load model: %s", e)
                return jsonify({"error": "Failed to initialize model"}), 500

        if 'image' not in request.files:
            logger.warning("No image uploaded")
            return jsonify({"error": "No image uploaded"}), 400

        image_file = request.files['image']
        try:
            input_image = Image.open(image_file).convert("RGB")
        except OSError as e:
            # PIL reports undecodable or truncated input as OSError
            # (UnidentifiedImageError subclasses it): a client error, not a 500.
            logger.warning("Uploaded file is not a readable image: %s", e)
            return jsonify({"error": "Uploaded file is not a valid image"}), 400
        logger.info("Received image of size %s", input_image.size)

        # Optional parameter; submitted form values arrive as strings.
        raw_steps = request.form.get('num_inference_steps', 25)
        num_steps = _parse_num_steps(raw_steps)
        if num_steps is None:
            logger.warning("Invalid num_inference_steps: %r", raw_steps)
            return jsonify(
                {"error": "num_inference_steps must be a positive integer"}
            ), 400
        # NOTE(review): no upper bound is enforced; a very large value can tie
        # up this CPU-bound worker for a long time. Consider a cap.

        logger.info("Starting image generation with %s steps", num_steps)
        result = pipe(image=input_image, num_inference_steps=num_steps)
        output_image = result.images[0]
        logger.info("Generated image of size %s", output_image.size)

        # Encode the PNG in memory and stream it back to the client.
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')

    except Exception as e:
        # Top-level boundary: record the full traceback, report the message.
        logger.exception("Error during generation: %s", e)
        return jsonify({"error": str(e)}), 500

if __name__ == "__main__":
    # Development-server entry point: listen on every interface, port 7860.
    _host, _port = "0.0.0.0", 7860
    app.run(host=_host, port=_port)