# Hugging Face Spaces status: Running
import io
import logging
import os
import subprocess
import sys

import pip
import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image
# Send log records to a stream handler rather than to log files.
# Note: logging.StreamHandler() with no argument writes to sys.stderr,
# not sys.stdout.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger; its records propagate to the root handler set above.
logger = logging.getLogger(__name__)
# Best-effort upgrade of huggingface_hub and diffusers at startup. This runs
# before this file first imports either package (load_model imports them
# lazily), so the upgraded versions are the ones that get loaded.
def _upgrade_packages(packages=('huggingface_hub', 'diffusers')):
    """Upgrade *packages* with pip, logging (never raising) on failure.

    pip has no supported in-process API (``pip.main`` is an internal entry
    point that returns an exit status instead of raising), so pip is run as a
    subprocess of the current interpreter with ``check=True``. Each package is
    upgraded independently, so one failure does not skip the others.

    Args:
        packages: Distribution names passed to ``pip install --upgrade``.
    """
    logger.info(f"Updating {' and '.join(packages)}...")
    for package in packages:
        try:
            subprocess.run(
                [sys.executable, '-m', 'pip', 'install', '--upgrade', package, '--quiet'],
                check=True,
            )
        except (OSError, subprocess.SubprocessError) as e:
            # A non-zero pip exit raises CalledProcessError; keep running with
            # whatever version is already installed.
            logger.warning(f"Failed to update {package}: {e}")


_upgrade_packages()
# Point the Hugging Face and XDG caches at /tmp so downloads go to a writable
# location, then create those directories, plus the dedicated diffusers cache
# that load_model passes as cache_dir, ahead of time. exist_ok=True makes
# re-running this a no-op.
os.environ.update({
    'HF_HOME': '/tmp/hf_home',
    'XDG_CACHE_HOME': '/tmp/cache',
})
os.makedirs(os.environ['HF_HOME'], exist_ok=True)
os.makedirs(os.environ['XDG_CACHE_HOME'], exist_ok=True)
os.makedirs('/tmp/diffusers_cache', exist_ok=True)
# Shared diffusers pipeline instance; stays None until load_model() assigns it.
pipe = None
# Loader for the Zero123Plus pipeline. It is called at import time, and
# generate() calls it again whenever `pipe` is still None.
def load_model():
    """Load the Zero123Plus pipeline on CPU and store it in the ``pipe`` global.

    The model repository is first fetched with ``snapshot_download`` into
    ``/tmp/diffusers_cache`` and loaded from that local path. If the snapshot
    download itself fails, loading falls back to passing the repo id straight
    to ``from_pretrained`` with the same ``cache_dir``.

    ``pipe`` is reassigned only after the pipeline has been fully built and
    moved to CPU. A failed attempt therefore leaves it unchanged (``None``
    before the first success), so a later call can retry.

    Returns:
        True on success; False if loading failed. The error is logged with
        its traceback.
    """
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # Import here to ensure the environment variables are set before import
        from diffusers import DiffusionPipeline
        from huggingface_hub import snapshot_download

        # Zero123Plus uses its own pipeline class, which the
        # AutoPipelineForImage2Image task mapping does not cover. The model
        # card loads it through DiffusionPipeline with the companion
        # custom-pipeline repository.
        load_kwargs = {
            "custom_pipeline": "sudo-ai/zero123plus-pipeline",
            "torch_dtype": torch.float32,
            "safety_checker": None,
            "low_cpu_mem_usage": True,
        }
        try:
            model_source = snapshot_download(
                "sudo-ai/zero123plus-v1.2",
                cache_dir="/tmp/diffusers_cache",
            )
        except Exception as download_error:
            # Only the download is guarded here. Errors while building the
            # pipeline reach the outer handler, instead of being misreported
            # as download failures and retried identically.
            logger.warning(f"Failed to download using snapshot_download: {str(download_error)}")
            model_source = "sudo-ai/zero123plus-v1.2"
            load_kwargs["cache_dir"] = "/tmp/diffusers_cache"

        loaded = DiffusionPipeline.from_pretrained(model_source, **load_kwargs)
        loaded.to("cpu")
        # Publish only a fully initialised pipeline.
        pipe = loaded
        logger.info("Model loaded successfully")
        return True
    except Exception as e:
        logger.exception(f"Error loading model: {str(e)}")
        return False
# The Flask application object.
app = Flask(import_name=__name__)

# Try the (potentially slow) model download and load once at import time.
# load_model() logs its own failures, and generate() retries while `pipe` is
# still None, so the return value is deliberately ignored here.
load_model()
@app.route("/")
def index():
    """Health-check endpoint reporting that the API is up.

    Without a route decorator Flask never dispatched to this view; it is now
    registered at ``/``. When the query string carries ``logs=container``
    (presumably added by the hosting platform's log viewer; confirm), a
    ``status`` field is included as well.

    Returns:
        A JSON response with a ``message`` key, plus ``status`` for
        ``?logs=container``.
    """
    payload = {"message": "Zero123Plus API is running."}
    if request.args.get('logs') == 'container':
        payload["status"] = "logs viewed"
    return jsonify(payload)
# Upper bound on the client-supplied step count. Each step is another
# denoising pass of the model on CPU, so an unbounded value would let one
# request occupy the worker for an arbitrarily long time.
MAX_INFERENCE_STEPS = 150


# NOTE(review): endpoint path assumed to be /generate; confirm against clients.
@app.route("/generate", methods=["POST"])
def generate():
    """Run Zero123Plus on an uploaded image and return the result as PNG.

    Expects a form upload with an ``image`` file field and an optional
    ``num_inference_steps`` form field: an integer, default 25, between 1 and
    ``MAX_INFERENCE_STEPS``. If the model is not loaded yet, loading is
    attempted first.

    Returns:
        The first image produced by the pipeline as an ``image/png``
        response, or a JSON ``{"error": ...}`` body with one of these codes:

        - 400: missing or undecodable image, or an invalid step count.
        - 500: the model could not be loaded, or generation failed.
    """
    global pipe
    try:
        # Lazily (re)load the model if the startup attempt failed.
        if pipe is None:
            success = load_model()
            if not success:
                return jsonify({"error": "Failed to initialize model"}), 500

        if 'image' not in request.files:
            logger.warning("No image uploaded")
            return jsonify({"error": "No image uploaded"}), 400
        image_file = request.files['image']

        # A bad upload is a client error, not a server error. OSError covers
        # unidentifiable and truncated files (UnidentifiedImageError
        # subclasses it).
        try:
            input_image = Image.open(image_file).convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            logger.warning(f"Invalid image upload: {e}")
            return jsonify({"error": "Uploaded file is not a valid image"}), 400
        logger.info(f"Received image of size {input_image.size}")

        # Optional parameter with default; validate before it reaches the
        # pipeline.
        try:
            num_steps = int(request.form.get('num_inference_steps', 25))
        except (TypeError, ValueError):
            return jsonify({"error": "num_inference_steps must be an integer"}), 400
        if not 1 <= num_steps <= MAX_INFERENCE_STEPS:
            return jsonify({
                "error": f"num_inference_steps must be between 1 and {MAX_INFERENCE_STEPS}"
            }), 400

        logger.info(f"Starting image generation with {num_steps} steps")
        result = pipe(
            image=input_image,
            num_inference_steps=num_steps,
        )
        output_image = result.images[0]
        logger.info(f"Generated image of size {output_image.size}")

        # Serialize to an in-memory PNG and rewind so send_file streams it
        # from the start.
        img_io = io.BytesIO()
        output_image.save(img_io, 'PNG')
        img_io.seek(0)
        return send_file(img_io, mimetype='image/png')
    except Exception as e:
        logger.exception(f"Error during generation: {str(e)}")
        return jsonify({"error": str(e)}), 500
# Script entry point: serve on all network interfaces at port 7860
# (presumably the port the hosting Space expects; confirm).
if __name__ == "__main__":
    app.run(port=7860, host="0.0.0.0")