from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from pydantic import BaseModel
from io import BytesIO
from diffusers import OnnxStableDiffusionPipeline
from huggingface_hub import snapshot_download
from PIL import Image
import os
from contextlib import asynccontextmanager

# Hugging Face repo ID. NOTE: this repo hosts an ONNX export of FLUX.1-dev, which is not a
# Stable Diffusion checkpoint; OnnxStableDiffusionPipeline expects an ONNX Stable Diffusion
# export, so swap in a compatible repo ID here if loading fails.
repo_id = "black-forest-labs/FLUX.1-dev-onnx"
local_dir = "sd_onnx_models"  # Subdirectory for downloaded model files
pipeline = None  # Initialized during the lifespan startup phase


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan event handler to load the ONNX Stable Diffusion pipeline on startup and unload it on shutdown.
    """
    global pipeline
    if pipeline is None:
        try:
            # File patterns to download (adjust as needed; tokenizer files such as
            # merges.txt may also be required depending on the repo layout).
            allow_patterns = ["*.onnx", "*.json", "vae/*.onnx"]
            local_model_path = snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,  # Ensure files are placed in the local model directory
                allow_patterns=allow_patterns,
            )
            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
                local_model_path,  # Use the local path returned by snapshot_download
                provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if a GPU is available
            )
            print(f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} via snapshot_download at: {local_model_path}")
        except Exception as e:
            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
            # Raising here aborts startup; HTTPException is only meaningful inside a request handler.
            raise RuntimeError(f"Failed to load ONNX Stable Diffusion pipeline: {e}") from e
    yield  # The app serves requests after this point
    pipeline = None  # Unload the pipeline on shutdown (if needed for resource management)
    print("ONNX Stable Diffusion pipeline unloaded.")


app = FastAPI(lifespan=lifespan)  # Register the lifespan handler


class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    format: str = "png"  # Default output format
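
# The JSON body for /generate-image/ mirrors this model. An illustrative example
# (field values are placeholders, not recommendations):
# {"prompt": "a watercolor fox", "num_inference_steps": 30, "guidance_scale": 7.0, "format": "png"}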


async def get_pipeline():
    """Dependency to ensure pipeline is loaded before endpoint is called."""
    if pipeline is None:  # Should have been loaded by the lifespan handler
        raise HTTPException(status_code=500, detail="Pipeline not loaded. Startup might have failed.")
    return pipeline


@app.post("/generate-image/")
async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
""" | |
Generates an image based on the provided text prompt using the loaded ONNX Stable Diffusion pipeline. | |
""" | |
try: | |
image = pipeline_dep( | |
request.prompt, | |
num_inference_steps=request.num_inference_steps, | |
guidance_scale=request.guidance_scale | |
).images[0] | |
# Convert PIL Image to bytes for streaming response | |
img_byte_arr = BytesIO() | |
image.save(img_byte_arr, format=request.format.upper()) # Use format from request | |
img_byte_arr = img_byte_arr.getvalue() | |
return StreamingResponse(content=iter([img_byte_arr]), media_type=f"image/{request.format}") # Use format from request | |
except Exception as e: | |
print(f"Error during image generation: {e}") | |
raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") | |
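

# Illustrative call once the server is running (endpoint and port come from this file;
# the prompt and output filename are placeholders):
#   curl -X POST http://localhost:7860/generate-image/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a watercolor fox", "num_inference_steps": 30}' \
#        --output out.png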


@app.get("/", response_class=HTMLResponse)
def index():
return """ | |
<!DOCTYPE html> | |
<html> | |
<head> | |
<title>FastAPI Image Generation Demo</title> | |
<style> | |
body { font-family: Arial, sans-serif; } | |
.container { width: 80%; margin: auto; padding-top: 20px; } | |
h1 { text-align: center; } | |
.form-group { margin-bottom: 15px; } | |
label { display: block; margin-bottom: 5px; font-weight: bold; } | |
input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; } | |
textarea { height: 100px; } | |
button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; } | |
button:hover { background-color: #0056b3; } | |
img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */ | |
#image-container { display: none; margin-top: 20px; text-align: center; } /* Initially hidden container */ | |
</style> | |
</head> | |
<body> | |
<div class="container"> | |
<h1>FastAPI Image Generation Demo</h1> | |
<div class="form-group"> | |
<label for="prompt">Text Prompt:</label> | |
<textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea> | |
</div> | |
<div class="form-group"> | |
<label for="num_inference_steps">Number of Inference Steps:</label> | |
<input type="number" id="num_inference_steps" value="50"> | |
</div> | |
<div class="form-group"> | |
<label for="guidance_scale">Guidance Scale:</label> | |
<input type="number" step="0.5" id="guidance_scale" value="7.5"> | |
</div> | |
<div class="form-group"> | |
<label for="format">Format:</label> | |
<select id="format"> | |
<option value="png" selected>PNG</option> | |
<option value="jpeg">JPEG</option> | |
</select> | |
</div> | |
<div class="form-group"> | |
<button onclick="generateImage()">Generate Image</button> | |
</div> | |
<div id="image-container"> | |
<img id="image" src="#" alt="Generated Image" style="max-width: 80%; height: auto;"> | |
</div> | |
</div> | |
<script> | |
async function generateImage() { | |
const prompt = document.getElementById('prompt').value; | |
const num_inference_steps = document.getElementById('num_inference_steps').value; | |
const guidance_scale = document.getElementById('guidance_scale').value; | |
const format = document.getElementById('format').value; | |
const imageElement = document.getElementById('image'); | |
const imageContainer = document.getElementById('image-container'); | |
imageElement.style.display = 'none'; // Hide previous image while loading | |
imageContainer.style.display = 'none'; // Hide container too | |
try { | |
const response = await fetch('/generate-image/', { | |
method: 'POST', | |
headers: { | |
'Content-Type': 'application/json' | |
}, | |
body: JSON.stringify({ | |
prompt: prompt, | |
num_inference_steps: parseInt(num_inference_steps), | |
guidance_scale: parseFloat(guidance_scale), | |
format: format | |
}) | |
}); | |
if (!response.ok) { | |
throw new Error(`HTTP error! status: ${response.status}`); | |
} | |
const blob = await response.blob(); | |
const imageUrl = URL.createObjectURL(blob); | |
imageElement.src = imageUrl; | |
imageElement.onload = () => { // Only show when image is loaded | |
imageContainer.style.display = 'block'; // Show image container | |
imageElement.style.display = 'block'; // Show the image | |
}; | |
imageElement.onerror = () => { | |
imageElement.style.display = 'none'; // Hide image on error | |
imageContainer.style.display = 'none'; // Hide container on error | |
alert("Error loading image. Please check console for details."); | |
}; | |
} catch (error) { | |
console.error("Fetch error:", error); | |
alert("Error generating image. Please check console for details."); | |
} | |
} | |
</script> | |
</body> | |
</html> | |
""" | |


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
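
# To run locally (assuming this file is saved as app.py):
#   python app.py
# Then open http://localhost:7860/ in a browser; FastAPI also serves interactive API docs at /docs.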