Spaces:
Runtime error
Runtime error
import os | |
import sys | |
from flask import Flask, request, jsonify | |
from PIL import Image | |
import torch | |
# Add the current directory to sys.path to allow local import | |
sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
from pipeline import Zero123PlusPipeline | |
app = Flask(__name__)

# Build the pipeline a single time at import so every request reuses it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Running on device: {device}")

# float16 is requested only on CUDA; any other device gets float32.
pipe = Zero123PlusPipeline.from_pretrained(
    "sudo-ai/zero123plus-v1.2",
    torch_dtype={"cuda": torch.float16, "cpu": torch.float32}[device],
).to(device)
@app.route("/")
def index():
    """Return a plain-text message confirming that the API is up.

    Registered on the root URL so the function is actually reachable;
    without a route Flask never dispatches to it.
    """
    return "Zero123Plus API is running!"
# NOTE(review): the URL path is inferred from the function name -- confirm
# against whatever client calls this service.
@app.route("/predict", methods=["POST"])
def predict():
    """Run the Zero123Plus pipeline on an uploaded image and save the results.

    Expects a multipart upload whose file part is named ``image``.

    Returns:
        On success, a JSON body ``{"outputs": [...]}`` listing the paths of
        the PNG files written under ``outputs/``.
        ``({"error": ...}, 400)`` when no ``image`` part is present or the
        upload cannot be identified as an image.
        ``({"error": ...}, 500)`` for any other failure.
    """
    # Local import: only this handler needs the exception type.
    from PIL import UnidentifiedImageError

    if "image" not in request.files:
        return jsonify({"error": "No image uploaded"}), 400
    image = request.files["image"]

    try:
        input_image = Image.open(image).convert("RGB")
        result = pipe(input_image, num_inference_steps=75, num_images_per_prompt=4)
        images = result.images  # List of PIL Images

        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)

        # NOTE: file names are fixed, so each request overwrites the files
        # written by the previous one.
        saved_paths = []
        for i, img in enumerate(images):
            path = os.path.join(output_dir, f"output_{i}.png")
            img.save(path)
            saved_paths.append(path)
        return jsonify({"outputs": saved_paths})
    except UnidentifiedImageError as e:
        # The upload is not a decodable image: a client error, not a server one.
        return jsonify({"error": str(e)}), 400
    except Exception as e:
        # Top-level boundary: log the traceback instead of only echoing the
        # message back to the client.
        app.logger.exception("Prediction failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Script entry point: start Flask's built-in server, bound to all
    # interfaces on port 7860.
    app.run(
        host="0.0.0.0",
        port=7860,
    )