# Rightlight / app.py
# Author: mike23415 — commit 125b5d0 ("Update app.py")
import os
import logging
from flask import Flask, request, jsonify, send_file
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import io
# Send all log output to a stream handler (stdout/stderr) rather than log
# files — container platforms such as Hugging Face Spaces expect this.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
# Redirect the Hugging Face / XDG caches to writable /tmp paths (the default
# home directory is read-only on Spaces).
# NOTE(review): the original comment says these must be set BEFORE importing
# the model libraries, yet diffusers/torch are imported above — HF_HOME is
# read lazily at download time so this works in practice, but confirm.
_CACHE_ENV = {
    'TRANSFORMERS_CACHE': '/tmp/transformers_cache',  # legacy var, kept for compat
    'HF_HOME': '/tmp/hf_home',
    'XDG_CACHE_HOME': '/tmp/cache',
}
for _var, _path in _CACHE_ENV.items():
    os.environ[_var] = _path
    os.makedirs(_path, exist_ok=True)

# load_model() passes this path as cache_dir; create it up front like the
# others (the original only created three of the four cache directories).
os.makedirs('/tmp/diffusers_cache', exist_ok=True)

# Global handle to the loaded diffusion pipeline; populated by load_model().
pipe = None
# Initialize the model at startup
def load_model():
    """Load the Zero123Plus diffusion pipeline onto the CPU.

    Stores the pipeline in the module-level ``pipe`` global so the Flask
    handlers can reuse a single instance.

    Raises:
        Exception: re-raises whatever ``from_pretrained`` fails with, so the
        caller decides whether startup should abort or a request gets a 500.
    """
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # float32 on CPU: this Space has no GPU, and half precision is
        # unreliable on CPU backends.
        pipe = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            torch_dtype=torch.float32,
            cache_dir="/tmp/diffusers_cache",
        )
        pipe.to("cpu")
        logger.info("Model loaded successfully")
    except Exception:
        # logger.exception records the full traceback, not just str(e).
        logger.exception("Error loading model")
        raise
# Load the model immediately at import time so the first request is fast;
# if loading raises, the process fails fast instead of serving 500s forever.
load_model()
# Flask application serving the HTTP API.
app = Flask(__name__)
@app.route("/", methods=["GET"])
def index():
    """Health-check endpoint: confirms the API process is up."""
    status_payload = {"message": "Zero123Plus API is running."}
    return jsonify(status_payload)
@app.route("/generate", methods=["POST"])
def generate():
    """Generate a novel view of an uploaded image with Zero123Plus.

    Expects multipart/form-data with an ``image`` file and an optional
    ``num_inference_steps`` integer form field (default 25).  Returns the
    first generated view as a PNG, or a JSON error payload with status
    400 (client error) / 500 (model or generation failure).
    """
    global pipe
    # Lazily (re)load the model if startup loading failed or was skipped.
    if pipe is None:
        try:
            load_model()
        except Exception:
            logger.exception("Failed to load model")
            return jsonify({"error": "Failed to initialize model"}), 500

    if 'image' not in request.files:
        logger.warning("No image uploaded")
        return jsonify({"error": "No image uploaded"}), 400

    # A corrupt upload is a client error (400), not a server failure (500) —
    # the original let this raise into the broad handler below.
    try:
        input_image = Image.open(request.files['image']).convert("RGB")
    except Exception:
        logger.warning("Uploaded file is not a valid image")
        return jsonify({"error": "Invalid image file"}), 400
    logger.info("Received image of size %s", input_image.size)

    # Validate the optional parameter explicitly instead of surfacing a
    # ValueError as a 500.
    try:
        num_steps = int(request.form.get('num_inference_steps', 25))
    except (TypeError, ValueError):
        return jsonify({"error": "num_inference_steps must be an integer"}), 400

    logger.info("Starting image generation with %d steps", num_steps)
    try:
        result = pipe(image=input_image, num_inference_steps=num_steps)
        output_image = result.images[0]
        logger.info("Generated image of size %s", output_image.size)

        # Stream the PNG back from memory without touching the filesystem.
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        # logger.exception keeps the traceback in the logs.
        logger.exception("Error during generation")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the port Hugging Face Spaces exposes.
    app.run(host="0.0.0.0", port=7860)