import os

# Set the Hugging Face cache directories to writable paths BEFORE importing
# diffusers, which reads these environment variables at import time
os.environ['HF_HOME'] = '/tmp/hf_home'
os.environ['XDG_CACHE_HOME'] = '/tmp/cache'

import io
import logging

import torch
from flask import Flask, request, jsonify, send_file
from PIL import Image
from diffusers.pipelines import DiffusionPipeline

try:
    from diffusers.pipelines.zero123plus.pipeline_zero123plus import Zero123PlusPipeline
except ImportError:
    print("Zero123PlusPipeline not found in current diffusers version")
# Configure logging to stdout instead of files
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
# Create the cache directories if they don't exist
os.makedirs('/tmp/hf_home', exist_ok=True)
os.makedirs('/tmp/cache', exist_ok=True)
os.makedirs('/tmp/diffusers_cache', exist_ok=True)
# Global variable for the model
pipe = None

# Initialize the model at startup
def load_model():
    global pipe
    try:
        logger.info("Loading Zero123Plus model...")
        # Let DiffusionPipeline.from_pretrained resolve the concrete pipeline
        # class from the repo's model_index.json
        pipe = DiffusionPipeline.from_pretrained(
            "sudo-ai/zero123plus-v1.2",
            torch_dtype=torch.float32,  # float32 for CPU inference
            cache_dir="/tmp/diffusers_cache",
            local_files_only=False,
            resume_download=True
        )
        pipe.to("cpu")
        logger.info("Model loaded successfully")
    except Exception as e:
        logger.error(f"Error loading model: {str(e)}")
        raise
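
# NOTE (assumption): if the installed diffusers version cannot resolve the
# Zero123Plus pipeline class automatically, the model authors also publish a
# community pipeline that can be requested explicitly. A minimal sketch,
# assuming the "sudo-ai/zero123plus-pipeline" repo is available:
#
#     pipe = DiffusionPipeline.from_pretrained(
#         "sudo-ai/zero123plus-v1.2",
#         custom_pipeline="sudo-ai/zero123plus-pipeline",
#         torch_dtype=torch.float32,
#     )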
# Load the model immediately
load_model()

app = Flask(__name__)

@app.route("/", methods=["GET"])
def index():
    return jsonify({"message": "Zero123Plus API is running."})
@app.route("/generate", methods=["POST"])
def generate():
try:
global pipe
# Check if model is loaded
if pipe is None:
try:
load_model()
except Exception as e:
logger.error(f"Failed to load model: {str(e)}")
return jsonify({"error": "Failed to initialize model"}), 500
if 'image' not in request.files:
logger.warning("No image uploaded")
return jsonify({"error": "No image uploaded"}), 400
image_file = request.files['image']
input_image = Image.open(image_file).convert("RGB")
logger.info(f"Received image of size {input_image.size}")
# Get optional parameters with defaults
num_steps = int(request.form.get('num_inference_steps', 25))
logger.info(f"Starting image generation with {num_steps} steps")
# Generate new views
result = pipe(image=input_image, num_inference_steps=num_steps)
output_image = result.images[0]
logger.info(f"Generated image of size {output_image.size}")
# Save to a BytesIO object
img_io = io.BytesIO()
output_image.save(img_io, 'PNG')
img_io.seek(0)
return send_file(img_io, mimetype='image/png')
except Exception as e:
logger.error(f"Error during generation: {str(e)}")
return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
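
# Example usage (assumes the server is reachable on localhost:7860 and that
# input.png exists; filenames here are placeholders):
#
#   curl -X POST http://localhost:7860/generate \
#        -F "image=@input.png" \
#        -F "num_inference_steps=30" \
#        -o output.png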