# imagen / app.py
# cybergamer0123's picture
# Update app.py
# 7527d1f verified
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from pydantic import BaseModel
from io import BytesIO
from diffusers import OnnxStableDiffusionPipeline
from huggingface_hub import snapshot_download
from PIL import Image
import os
from contextlib import asynccontextmanager
# Hugging Face Hub repository the model files are downloaded from.
# NOTE(review): the repo name refers to FLUX while the loader used in lifespan()
# is OnnxStableDiffusionPipeline — confirm the two are compatible.
repo_id = "black-forest-labs/FLUX.1-dev-onnx" # Hugging Face repo ID
# Directory (relative to the working directory) that receives the downloaded files.
local_dir = "sd_onnx_models" # Subdirectory for models
# Module-level pipeline handle: set by lifespan() on startup, read by get_pipeline().
pipeline = None # Initialize pipeline outside lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Load the ONNX Stable Diffusion pipeline on startup and release it on shutdown.

    On startup, the files matching ``allow_patterns`` are downloaded from
    ``repo_id`` into ``local_dir`` with ``snapshot_download`` and the pipeline is
    built from the downloaded directory. The result is stored in the module-level
    ``pipeline`` global, which ``get_pipeline`` returns to request handlers.
    On shutdown the global is reset to ``None``.

    Args:
        app: The FastAPI application this lifespan handler is attached to (unused).

    Raises:
        RuntimeError: If downloading the files or constructing the pipeline fails.
            The original exception is chained as ``__cause__``. A plain error is
            raised instead of ``HTTPException`` because no HTTP request exists
            during startup for an HTTP error response to be sent to.
    """
    global pipeline
    if pipeline is None:
        # Restrict the download to the files the pipeline needs (adjust as needed).
        allow_patterns = ["*.onnx", "*.json", "vae/*.onnx"]
        try:
            local_model_path = snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir,  # Ensure files are placed in this directory
                allow_patterns=allow_patterns,
            )
            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
                local_model_path,  # Path returned by snapshot_download
                provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if you have GPU
                # NOTE(review): `revision` is passed although the model is loaded
                # from a local directory — confirm it has any effect here.
                revision="onnx",
            )
        except Exception as e:
            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
            raise RuntimeError(
                f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}"
            ) from e
        else:
            print(f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} (ONNX revision) using snapshot_download from: {local_model_path}")
    try:
        yield  # App serves requests while suspended here
    finally:
        # Always drop the reference on the way out, even if the context exits with an error.
        pipeline = None
        print("ONNX Stable Diffusion pipeline unloaded.")
# Create the application; lifespan() runs around the app's lifetime to load/unload the pipeline.
app = FastAPI(lifespan=lifespan) # Register lifespan function
class ImageRequest(BaseModel):
    """Request body accepted by the ``/generate-image/`` endpoint."""

    # Text prompt, passed to the pipeline as its first positional argument.
    prompt: str
    # Forwarded to the pipeline as ``num_inference_steps``.
    num_inference_steps: int = 50
    # Forwarded to the pipeline as ``guidance_scale``.
    guidance_scale: float = 7.5
    # Image format name used when encoding the result and in the response media type.
    format: str = "png" # default format
async def get_pipeline():
    """Dependency that returns the loaded pipeline, or fails with HTTP 500 if there is none."""
    loaded = pipeline
    if loaded is not None:
        return loaded
    # The lifespan handler is expected to have populated the global before any request.
    raise HTTPException(status_code=500, detail="Pipeline not loaded. Startup might have failed.")
def _resolve_image_format(requested):
    """
    Translate a client-supplied format name into ``(pil_format, media_type)``.

    The name is trimmed and lower-cased, and the common alias ``jpg`` is mapped
    to ``jpeg`` (Pillow registers the encoder under ``JPEG`` only). Support is
    verified by encoding a 1x1 RGB image, which exercises the same
    ``Image.save`` path as the real result without running the pipeline.

    Raises:
        HTTPException: 400 if Pillow cannot encode an RGB image in that format.
    """
    name = requested.strip().lower()
    if name == "jpg":
        name = "jpeg"
    pil_format = name.upper()
    try:
        # Cheap probe; presumably the pipeline also yields RGB images — TODO confirm.
        Image.new("RGB", (1, 1)).save(BytesIO(), format=pil_format)
    except (KeyError, ValueError, OSError) as e:
        raise HTTPException(status_code=400, detail=f"Unsupported image format: {requested!r}") from e
    return pil_format, f"image/{name}"


@app.post("/generate-image/")
async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
    """
    Generate an image from the text prompt using the loaded ONNX Stable Diffusion pipeline.

    Args:
        request: Prompt, sampling parameters and output format.
        pipeline_dep: The loaded pipeline, injected by ``get_pipeline``.

    Returns:
        A response whose body is the encoded image and whose media type is
        ``image/<format>`` (lower-cased, with ``jpg`` normalised to ``jpeg``).

    Raises:
        HTTPException: 400 if ``num_inference_steps`` is not positive or the
            requested format cannot be encoded; 500 if generation or encoding fails.
    """
    # Validate cheap inputs first so a bad request never pays for a full inference run.
    if request.num_inference_steps < 1:
        raise HTTPException(status_code=400, detail="num_inference_steps must be at least 1.")
    pil_format, media_type = _resolve_image_format(request.format)

    try:
        # NOTE: this call is synchronous and is not awaited, so it runs on the
        # event loop; other requests wait until the image has been generated.
        image = pipeline_dep(
            request.prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale
        ).images[0]
        # Encode the PIL image into an in-memory buffer.
        buffer = BytesIO()
        image.save(buffer, format=pil_format)
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}") from e
    # The payload is already fully in memory, so a plain Response is sufficient.
    return Response(content=buffer.getvalue(), media_type=media_type)
@app.get("/", response_class=HTMLResponse)
def index():
    """
    Serve the single-page demo UI.

    The page contains a form (prompt, inference steps, guidance scale, format)
    and a script that POSTs those values as JSON to ``/generate-image/`` and
    displays the returned image blob.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>FastAPI Image Generation Demo</title>
<style>
body { font-family: Arial, sans-serif; }
.container { width: 80%; margin: auto; padding-top: 20px; }
h1 { text-align: center; }
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; font-weight: bold; }
input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
textarea { height: 100px; }
button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
button:hover { background-color: #0056b3; }
img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
#image-container { display: none; margin-top: 20px; text-align: center; } /* Initially hidden container */
</style>
</head>
<body>
<div class="container">
<h1>FastAPI Image Generation Demo</h1>
<div class="form-group">
<label for="prompt">Text Prompt:</label>
<textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
</div>
<div class="form-group">
<label for="num_inference_steps">Number of Inference Steps:</label>
<input type="number" id="num_inference_steps" value="50">
</div>
<div class="form-group">
<label for="guidance_scale">Guidance Scale:</label>
<input type="number" step="0.5" id="guidance_scale" value="7.5">
</div>
<div class="form-group">
<label for="format">Format:</label>
<select id="format">
<option value="png" selected>PNG</option>
<option value="jpeg">JPEG</option>
</select>
</div>
<div class="form-group">
<button onclick="generateImage()">Generate Image</button>
</div>
<div id="image-container">
<img id="image" src="#" alt="Generated Image" style="max-width: 80%; height: auto;">
</div>
</div>
<script>
async function generateImage() {
const prompt = document.getElementById('prompt').value;
const num_inference_steps = document.getElementById('num_inference_steps').value;
const guidance_scale = document.getElementById('guidance_scale').value;
const format = document.getElementById('format').value;
const imageElement = document.getElementById('image');
const imageContainer = document.getElementById('image-container');
imageElement.style.display = 'none'; // Hide previous image while loading
imageContainer.style.display = 'none'; // Hide container too
try {
const response = await fetch('/generate-image/', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
prompt: prompt,
num_inference_steps: parseInt(num_inference_steps),
guidance_scale: parseFloat(guidance_scale),
format: format
})
});
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
const blob = await response.blob();
const imageUrl = URL.createObjectURL(blob);
imageElement.src = imageUrl;
imageElement.onload = () => { // Only show when image is loaded
imageContainer.style.display = 'block'; // Show image container
imageElement.style.display = 'block'; // Show the image
};
imageElement.onerror = () => {
imageElement.style.display = 'none'; // Hide image on error
imageContainer.style.display = 'none'; // Hide container on error
alert("Error loading image. Please check console for details.");
};
} catch (error) {
console.error("Fetch error:", error);
alert("Error generating image. Please check console for details.");
}
}
</script>
</body>
</html>
"""
if __name__ == "__main__":
    # Imported inside the guard, so uvicorn is only required when this file is run as a script.
    import uvicorn

    # Serve the "app" object of module "app" on all interfaces, port 7860, with auto-reload on.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)