from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
from io import BytesIO
from diffusers import OnnxStableDiffusionPipeline
from huggingface_hub import snapshot_download
from contextlib import asynccontextmanager

# NOTE: OnnxStableDiffusionPipeline can only load ONNX exports of Stable
# Diffusion checkpoints; a FLUX repo such as "black-forest-labs/FLUX.1-dev-onnx"
# will not load with this pipeline class, so we point at a repo that publishes
# an "onnx" revision.
repo_id = "CompVis/stable-diffusion-v1-4"  # Hugging Face repo ID
revision = "onnx"  # Branch holding the ONNX export
local_dir = "sd_onnx_models"  # Local subdirectory for the downloaded model

pipeline = None  # Populated by the lifespan handler at startup


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan handler: load the ONNX Stable Diffusion pipeline on startup
    and release it on shutdown.
    """
    global pipeline
    if pipeline is None:
        try:
            # Download only the files the pipeline needs; "*" matches across
            # directories, so "*.onnx" also covers e.g. unet/model.onnx, and
            # "*.txt" picks up the tokenizer's merges.txt. Adjust as needed
            # (e.g. add "*.pb" for repos that store external ONNX weight data).
            allow_patterns = ["*.onnx", "*.json", "*.txt"]
            local_model_path = snapshot_download(
                repo_id=repo_id,
                revision=revision,
                local_dir=local_dir,  # Ensure files are placed here
                allow_patterns=allow_patterns,
            )
            # The snapshot on disk is already the ONNX export, so no revision
            # argument is needed (or honored) when loading from a local path.
            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
                local_model_path,
                provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" on GPU
            )
            print(
                f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} "
                f"(revision '{revision}') via snapshot_download at: {local_model_path}"
            )
        except Exception as e:
            print(f"Error loading ONNX pipeline via snapshot_download: {e}")
            # HTTPException belongs in request handlers; raising a plain error
            # here aborts startup so the app never serves with a broken pipeline.
            raise RuntimeError(
                f"Failed to load ONNX Stable Diffusion pipeline: {e}"
            ) from e
    yield  # The app serves requests while suspended here
    pipeline = None  # Unload the pipeline on shutdown to free resources
    print("ONNX Stable Diffusion pipeline unloaded.")


app = FastAPI(lifespan=lifespan)  # Register the lifespan handler


class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    format: str = "png"  # Output format: "png" or "jpeg"


async def get_pipeline():
    """Dependency that ensures the pipeline is loaded before an endpoint runs."""
    if pipeline is None:  # Should have been set by the lifespan handler
        raise HTTPException(status_code=500, detail="Pipeline not loaded. Startup might have failed.")
    return pipeline


@app.post("/generate-image/")
async def generate_image(
    request: ImageRequest,
    pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline),
):
    """
    Generate an image from the given text prompt using the loaded
    ONNX Stable Diffusion pipeline.
    """
    try:
        image = pipeline_dep(
            request.prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
        ).images[0]

        # Normalize the requested format: Pillow expects "JPEG" rather than
        # "JPG", and the matching media type is "image/jpeg".
        fmt = "JPEG" if request.format.lower() in ("jpg", "jpeg") else request.format.upper()

        # Serialize the PIL image into an in-memory buffer for the response
        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format=fmt)
        img_bytes = img_byte_arr.getvalue()

        return StreamingResponse(content=iter([img_bytes]), media_type=f"image/{fmt.lower()}")
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")


@app.get("/", response_class=HTMLResponse)
def index():
    return """