# Rightlight / app.py
# mike23415's picture
# Update app.py
# 31cc64d verified
# raw
# history blame
# 2.34 kB
import base64
import io
import logging
import os
import tempfile

import torch
from diffusers import StableDiffusionPipeline  # Placeholder; adjust based on InstantMesh docs
from flask import Flask, request, jsonify
from PIL import Image
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)


def _load_pipeline():
    """Load the pipeline on CPU; return it, or ``None`` if loading fails.

    No auth token is passed to ``from_pretrained`` (test configuration).
    Any exception raised while loading is logged with its traceback and
    swallowed, so the module still imports and the app can start.

    NOTE(review): ``StableDiffusionPipeline`` is marked as a placeholder at
    its import -- confirm the correct loader against the InstantMesh docs.
    """
    try:
        logger.info("Loading TencentARC InstantMesh pipeline...")
        loaded = StableDiffusionPipeline.from_pretrained(
            "TencentARC/InstantMesh",
            torch_dtype=torch.float32,
            cache_dir="/tmp/hf_home",
        )
        loaded.to("cpu")
        logger.info("=== Application Startup at CPU mode =====")
        return loaded
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Error loading model: {e}")
        return None


# Loaded once at import time; stays None when loading failed.
pipe = _load_pipeline()
def pil_to_base64(image, *, image_format="PNG"):
    """Serialize *image* into memory and return the bytes as base64 text.

    Args:
        image: Object exposing ``save(fp, format=...)``, e.g. a PIL image.
        image_format: Format name forwarded to ``image.save``.  Defaults to
            ``"PNG"``, which is what this helper always used before the
            parameter existed.

    Returns:
        The encoded image bytes as a base64 ``str`` (no data-URL prefix).
    """
    # The context manager releases the in-memory buffer as soon as the
    # bytes have been copied out.
    with io.BytesIO() as buffer:
        image.save(buffer, format=image_format)
        encoded = base64.b64encode(buffer.getvalue())
    return encoded.decode("utf-8")
@app.route("/")
def home():
    """Root endpoint: return a fixed plain-text banner saying the API is up."""
    banner = "TencentARC InstantMesh CPU API is running!"
    return banner
@app.route("/generate", methods=["POST"])
def generate():
    """Run the pipeline on an uploaded image and return the mesh as base64.

    Expects a JSON object body whose ``"image"`` key holds either raw base64
    text or a ``data:image...,<payload>`` URL.

    Returns:
        * 200 -- ``{"mesh": "data:model/gltf-binary;base64,<payload>"}``
        * 400 -- ``{"error": ...}`` when the body is not a JSON object, the
          image is missing, is not a string, or cannot be decoded.
        * 500 -- ``{"error": ...}`` when the model is not loaded or the
          pipeline / mesh export raises.
    """
    if pipe is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True returns None for a missing, non-JSON or malformed body
    # instead of raising, so client mistakes are answered with 400, not 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    image_data = data.get("image")
    if not image_data:
        return jsonify({"error": "No image provided"}), 400
    if not isinstance(image_data, str):
        return jsonify({"error": "Image must be a base64-encoded string"}), 400

    if image_data.startswith("data:image"):
        # Keep only what follows the first comma.  partition() yields an
        # empty payload (rejected below) when the comma is missing, where
        # split(",")[1] would raise IndexError.
        image_data = image_data.partition(",")[2]

    try:
        raw_bytes = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; Pillow reports
        # unreadable image bytes with OSError subclasses.
        logger.warning(f"Rejected invalid image payload: {e}")
        return jsonify({"error": "Invalid image data"}), 400

    try:
        logger.info("Processing image with pipeline...")
        result = pipe(image)  # Adjust based on InstantMesh documentation
        output_mesh = result.mesh  # Hypothetical; check InstantMesh output format

        # A per-request temporary directory keeps concurrent requests from
        # overwriting each other's file and removes it afterwards.  The
        # "output.glb" file name is kept as in the original code.
        with tempfile.TemporaryDirectory() as work_dir:
            output_path = os.path.join(work_dir, "output.glb")
            output_mesh.save(output_path)
            with open(output_path, "rb") as f:
                mesh_data = base64.b64encode(f.read()).decode("utf-8")

        logger.info("Mesh processed successfully")
        return jsonify({"mesh": f"data:model/gltf-binary;base64,{mesh_data}"})
    except Exception as e:
        # Top-level boundary for inference/export failures: log the full
        # traceback and report the message to the caller.
        logger.error(f"Error generating mesh: {e}", exc_info=True)
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Start Flask's built-in server only when this file is run as a script.
    listen_host, listen_port = "0.0.0.0", 7860
    app.run(host=listen_host, port=listen_port)