Rightlight / app.py
mike23415's picture
Update app.py
388cf5c verified
raw
history blame
1.53 kB
import os
import sys
import threading
import uuid

from flask import Flask, request, jsonify
from PIL import Image
import torch

# Add the current directory to sys.path to allow local import
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from pipeline import Zero123PlusPipeline
app = Flask(__name__)

# Choose the compute device once, at import time. Half precision is used only
# when running on CUDA; the CPU path keeps full float32 precision.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on device: {device}")

# Build the pipeline a single time and move it to the chosen device, so every
# request handled by this process reuses the same loaded model.
pipe = Zero123PlusPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
).to(device)
@app.route("/")
def index():
    """Liveness check: respond with a fixed plain-text status message."""
    status_message = "Zero123Plus API is running!"
    return status_message
# Serializes calls into the single shared `pipe`. Flask's built-in server
# handles requests on multiple threads by default, and this one pipeline object
# is used by all of them. NOTE(review): presumably, like diffusers pipelines,
# its scheduler keeps mutable per-run state, so overlapping calls are unsafe --
# confirm against pipeline.py. Serializing also avoids two generations
# competing for device memory at once.
_PIPE_LOCK = threading.Lock()


def _save_images(images, output_dir):
    """Write each image to ``output_dir/output_<i>.png`` and return the paths.

    ``output_dir`` (and any missing parents) is created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved_paths = []
    for i, img in enumerate(images):
        path = os.path.join(output_dir, f"output_{i}.png")
        img.save(path)
        saved_paths.append(path)
    return saved_paths


@app.route("/predict", methods=["POST"])
def predict():
    """Run the Zero123Plus pipeline on an uploaded picture and save the results.

    Expects a multipart/form-data POST with the input picture in the ``image``
    file field. The picture is converted to RGB and passed to the pipeline with
    ``num_inference_steps=75`` and ``num_images_per_prompt=4``.

    Returns:
        200 with ``{"outputs": [path, ...]}``: server-side paths of the PNG
            files written, one per generated image.
        400 with ``{"error": ...}``: no file was uploaded, or it could not be
            decoded as an image.
        500 with ``{"error": ...}``: generation or saving failed.
    """
    if "image" not in request.files:
        return jsonify({"error": "No image uploaded"}), 400

    try:
        input_image = Image.open(request.files["image"]).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # Unrecognized, truncated or oversized uploads are client errors, not
        # server failures. (PIL's UnidentifiedImageError subclasses OSError.)
        return jsonify({"error": f"Invalid image: {e}"}), 400

    # A fresh directory per request: fixed file names would be shared by all
    # concurrent requests, letting one request's outputs overwrite another's.
    # NOTE(review): generated files are never deleted; add a retention policy
    # if disk usage matters.
    output_dir = os.path.join("outputs", uuid.uuid4().hex)
    try:
        with _PIPE_LOCK:
            result = pipe(input_image, num_inference_steps=75, num_images_per_prompt=4)
        saved_paths = _save_images(result.images, output_dir)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # return the same JSON error shape to the client.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"outputs": saved_paths})
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server on every network
    # interface, port 7860.
    app.run(port=7860, host="0.0.0.0")