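"""Flask API that wraps the Zero123Plus image-to-image pipeline.

Exposes a health-check endpoint at "/" and a "/generate" endpoint that
accepts an uploaded image and returns a generated PNG of new views.
"""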
import io
import logging
import os
import subprocess
import sys

from flask import Flask, request, jsonify, send_file
from PIL import Image
import torch

# Configure logging to stdout instead of files
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Try to upgrade the required packages. pip.main() was removed in pip 10+,
# so invoke pip through the current interpreter instead.
try:
    logger.info("Upgrading huggingface_hub and diffusers...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", "--quiet",
         "huggingface_hub", "diffusers"]
    )
except Exception as e:
    logger.warning(f"Failed to upgrade libraries: {str(e)}")

# Point the Hugging Face cache directories at writable paths
os.environ['HF_HOME'] = '/tmp/hf_home'
os.environ['XDG_CACHE_HOME'] = '/tmp/cache'

# Create the cache directories if they don't exist
os.makedirs('/tmp/hf_home', exist_ok=True)
os.makedirs('/tmp/cache', exist_ok=True)
os.makedirs('/tmp/diffusers_cache', exist_ok=True)

# Global variable for the model
pipe = None


# Initialize the model at startup
def load_model():
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # Import here so the cache environment variables above are set
        # before diffusers and huggingface_hub are imported
        from diffusers import AutoPipelineForImage2Image
        from huggingface_hub import snapshot_download

        try:
            # First try to download the model files
            model_path = snapshot_download(
                "sudo-ai/zero123plus-v1.2",
                cache_dir="/tmp/diffusers_cache",
                local_files_only=False
            )
            # Then load the pipeline from the local path
            pipe = AutoPipelineForImage2Image.from_pretrained(
                model_path,
                torch_dtype=torch.float32,
                safety_checker=None,
                low_cpu_mem_usage=True
            )
        except Exception as download_error:
            logger.warning(f"Failed to download using snapshot_download: {str(download_error)}")
            # Fall back to loading directly from the Hub
            pipe = AutoPipelineForImage2Image.from_pretrained(
                "sudo-ai/zero123plus-v1.2",
                torch_dtype=torch.float32,
                cache_dir="/tmp/diffusers_cache",
                safety_checker=None,
                low_cpu_mem_usage=True,
                local_files_only=False
            )

        pipe.to("cpu")
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        return False


# Initialize the Flask app
app = Flask(__name__)

# Load the model immediately so the first request doesn't pay the cost
load_model()


@app.route("/", methods=["GET"])
def index():
    # Check if the logs query parameter is present
    if request.args.get('logs') == 'container':
        return jsonify({"message": "Zero123Plus API is running.", "status": "logs viewed"})
    return jsonify({"message": "Zero123Plus API is running."})
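
# Example health check (a sketch; assumes the default host/port
# configured at the bottom of this file):
#
#   curl http://localhost:7860/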
@app.route("/generate", methods=["POST"])
def generate():
try:
global pipe
# Check if model is loaded
if pipe is None:
success = load_model()
if not success:
return jsonify({"error": "Failed to initialize model"}), 500
if 'image' not in request.files:
logger.warning("No image uploaded")
return jsonify({"error": "No image uploaded"}), 400
image_file = request.files['image']
input_image = Image.open(image_file).convert("RGB")
logger.info(f"Received image of size {input_image.size}")
# Get optional parameters with defaults
num_steps = int(request.form.get('num_inference_steps', 25))
logger.info(f"Starting image generation with {num_steps} steps")
# Generate new views
result = pipe(
image=input_image,
num_inference_steps=num_steps
)
output_image = result.images[0]
logger.info(f"Generated image of size {output_image.size}")
# Save to a BytesIO object
img_io = io.BytesIO()
output_image.save(img_io, 'PNG')
img_io.seek(0)
return send_file(img_io, mimetype='image/png')
except Exception as e:
logger.error(f"Error during generation: {str(e)}")
return jsonify({"error": str(e)}), 500
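

# Example client call (a sketch; assumes the default host/port below and
# an image file named input.png in the working directory):
#
#   curl -X POST http://localhost:7860/generate \
#        -F "image=@input.png" \
#        -F "num_inference_steps=25" \
#        --output output.png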


if __name__ == "__main__":
    # 7860 is the default port for Hugging Face Spaces
    app.run(host="0.0.0.0", port=7860)