Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,230 +1,78 @@
-import
-import
-import time
-import json
-
-import numpy as np
-from fastapi import FastAPI, HTTPException, Body
-from fastapi.responses import StreamingResponse, Response, HTMLResponse
-from fastapi.middleware import Middleware
-from fastapi.middleware.gzip import GZipMiddleware
+from fastapi import FastAPI, HTTPException, Depends
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
-
-from onnxruntime import InferenceSession
-from huggingface_hub import snapshot_download
-from scipy.io.wavfile import write as write_wav
-
+from io import BytesIO
 from diffusers import OnnxStableDiffusionPipeline
+from huggingface_hub import snapshot_download
 from PIL import Image
+import os
+
+app = FastAPI()
+
+# Global variable to hold the loaded pipeline
+pipeline = None
+model_id = "clip.opt/model.onnx"  # ONNX model file within the repo (used for logging)
+repo_id = "black-forest-labs/FLUX.1-dev-onnx"  # Hugging Face repo to download the ONNX model from via snapshot_download
 
 class ImageRequest(BaseModel):
     prompt: str
     num_inference_steps: int = 50
     guidance_scale: float = 7.5
-    format: str = "png"  # or "jpeg"
 
-model_repo = "runwayml/stable-diffusion-v1-5"  # Or any other ONNX-compatible Stable Diffusion model
-model_name = "model_q4.onnx"  # if a specific model file is needed; otherwise the directory is enough
-voice_file_pattern = "*.bin"  # not used; kept for inspiration, remove if not needed
-local_dir = "sd_onnx_models_snapshot"  # separate folder for the SD models
-snapshot_download(
-    repo_id=model_repo,
-    revision="onnx",
-    local_dir=local_dir,
-    local_dir_use_symlinks=False,
-    allow_patterns=["*.onnx", "*.json", "vae/*.onnx"]  # Specify necessary file patterns (adjust as needed)
-)
-
-pipeline = OnnxStableDiffusionPipeline.from_pretrained(
-    local_dir,  # Use the local path from snapshot_download
-    provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if you have a GPU
-)
-
-app = FastAPI(
-    title="FastAPI Image Generation with ONNX",
-    middleware=[Middleware(GZipMiddleware, compresslevel=9)]  # maybe compression is not needed for images? check later
-)
-
-@app.post("/generate-image/streaming", summary="Streaming Image Generation")
-async def generate_image_streaming(request: ImageRequest = Body(...)):
-    prompt = request.prompt
-    num_inference_steps = request.num_inference_steps
-    guidance_scale = request.guidance_scale
-    format = request.format.lower()
-
-    def image_generator():
+async def load_pipeline():
+    """Loads the ONNX Stable Diffusion pipeline from Hugging Face Hub using snapshot_download."""
+    global pipeline
+    if pipeline is None:
         try:
-            img_byte_arr = img_byte_arr.getvalue()
-            yield img_byte_arr
+            local_model_path = snapshot_download(
+                repo_id=repo_id,
+                local_dir=".",  # Specify local_dir to ensure files are placed there
+                allow_patterns=["clip.opt/*.onnx"]  # Specify necessary file patterns (adjust as needed)
+            )
+
+            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
+                local_model_path,  # Use the local path from snapshot_download
+                provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if you have a GPU
+            )
+            print(f"ONNX Stable Diffusion pipeline loaded successfully ({model_id}) via snapshot_download from: {local_model_path}")
 
         except Exception as e:
-            print(f"Error
+            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
+            raise HTTPException(status_code=500, detail=f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}")
+    return pipeline
 
-    media_type = f"image/{format}" if format in ["png", "jpeg"] else "image/png"
-    return StreamingResponse(
-        image_generator(),
-        media_type=media_type,
-        headers={"Cache-Control": "no-cache"},
-    )
-
-@app.post("/generate-image/full", summary="Full Image Generation")
-async def generate_image_full(request: ImageRequest = Body(...)):
-    prompt = request.prompt
-    num_inference_steps = request.num_inference_steps
-    guidance_scale = request.guidance_scale
-    format = request.format.lower()
-
-    start_time = time.time()
-    image = pipeline(
-        prompt,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale
-    ).images[0]
-    print(f"Full Image generation inference time: {time.time()-start_time:.3f}s")
-
-@app.get("/", response_class=HTMLResponse)
-def index():
-    return """
-    <!DOCTYPE html>
-    <html>
-    <head>
-        <title>FastAPI Image Generation Demo</title>
-        <style>
-            body { font-family: Arial, sans-serif; }
-            .container { width: 80%; margin: auto; padding-top: 20px; }
-            h1 { text-align: center; }
-            .form-group { margin-bottom: 15px; }
-            label { display: block; margin-bottom: 5px; font-weight: bold; }
-            input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
-            textarea { height: 100px; }
-            button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
-            button:hover { background-color: #0056b3; }
-            img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
-        </style>
-    </head>
-    <body>
-        <div class="container">
-            <h1>FastAPI Image Generation Demo</h1>
-            <div class="form-group">
-                <label for="prompt">Text Prompt:</label>
-                <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
-            </div>
-            <div class="form-group">
-                <label for="num_inference_steps">Number of Inference Steps:</label>
-                <input type="number" id="num_inference_steps" value="50">
-            </div>
-            <div class="form-group">
-                <label for="guidance_scale">Guidance Scale:</label>
-                <input type="number" step="0.5" id="guidance_scale" value="7.5">
-            </div>
-            <div class="form-group">
-                <label for="format">Format:</label>
-                <select id="format">
-                    <option value="png" selected>PNG</option>
-                    <option value="jpeg">JPEG</option>
-                </select>
-            </div>
-            <div class="form-group">
-                <button onclick="generateStreamingImage()">Generate Streaming Image</button>
-                <button onclick="generateFullImage()">Generate Full Image</button>
-            </div>
-            <div id="image-container">
-                <img id="image" src="#" alt="Generated Image" style="display:none;">
-            </div>
-        </div>
-        <script>
-            function generateStreamingImage() {
-                const prompt = document.getElementById('prompt').value;
-                const num_inference_steps = document.getElementById('num_inference_steps').value;
-                const guidance_scale = document.getElementById('guidance_scale').value;
-                const format = document.getElementById('format').value;
-                const imageElement = document.getElementById('image');
-                const imageContainer = document.getElementById('image-container');
-
-                fetch('/generate-image/streaming', {
-                    method: 'POST',
-                    headers: {
-                        'Content-Type': 'application/json'
-                    },
-                    body: JSON.stringify({
-                        prompt: prompt,
-                        num_inference_steps: parseInt(num_inference_steps),
-                        guidance_scale: parseFloat(guidance_scale),
-                        format: format
-                    })
-                })
-                .then(response => response.blob())
-                .then(blob => {
-                    const imageUrl = URL.createObjectURL(blob);
-                    imageElement.src = imageUrl;
-                    imageElement.style.display = 'block'; // Show the image
-                    imageContainer.style.display = 'block'; // Show the container if hidden
-                });
-            }
-
-            function generateFullImage() {
-                const prompt = document.getElementById('prompt').value;
-                const num_inference_steps = document.getElementById('num_inference_steps').value;
-                const guidance_scale = document.getElementById('guidance_scale').value;
-                const format = document.getElementById('format').value;
-                const imageElement = document.getElementById('image');
-                const imageContainer = document.getElementById('image-container');
-
-                fetch('/generate-image/full', {
-                    method: 'POST',
-                    headers: {
-                        'Content-Type': 'application/json'
-                    },
-                    body: JSON.stringify({
-                        prompt: prompt,
-                        num_inference_steps: parseInt(num_inference_steps),
-                        guidance_scale: parseFloat(guidance_scale),
-                        format: format
-                    })
-                })
-                .then(response => response.blob())
-                .then(blob => {
-                    const imageUrl = URL.createObjectURL(blob);
-                    imageElement.src = imageUrl;
-                    imageElement.style.display = 'block'; // Show the image
-                    imageContainer.style.display = 'block'; // Show the container if hidden
-                });
-            }
-        </script>
-    </body>
-    </html>
-    """
+async def get_pipeline():
+    """Dependency to ensure the pipeline is loaded before the endpoint is called."""
+    return await load_pipeline()
+
+@app.on_event("startup")
+async def startup_event():
+    """Loads the pipeline on startup."""
+    await load_pipeline()
+
+@app.post("/generate-image/")
+async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
+    """
+    Generates an image based on the provided text prompt using the loaded ONNX Stable Diffusion pipeline.
+    """
+    try:
+        image = pipeline_dep(
+            request.prompt,
+            num_inference_steps=request.num_inference_steps,
+            guidance_scale=request.guidance_scale
+        ).images[0]
+
+        # Convert the PIL Image to bytes for the streaming response
+        img_byte_arr = BytesIO()
+        image.save(img_byte_arr, format="PNG")
+        img_byte_arr = img_byte_arr.getvalue()
+
+        return StreamingResponse(content=iter([img_byte_arr]), media_type="image/png")
+
+    except Exception as e:
+        print(f"Error during image generation: {e}")
+        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(
+    uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
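
For a quick check of the new endpoint, here is a minimal client sketch. It assumes the server is running locally on port 8000 (as in the __main__ block of the new app.py); the prompt, step count, and output filename are illustrative, not part of the commit.

# Hypothetical client for the new /generate-image/ endpoint.
import requests

resp = requests.post(
    "http://localhost:8000/generate-image/",
    json={
        "prompt": "an astronaut riding a horse",  # illustrative prompt
        "num_inference_steps": 25,
        "guidance_scale": 7.5,
    },
    timeout=600,  # CPU inference can take several minutes
)
resp.raise_for_status()
with open("output.png", "wb") as f:
    f.write(resp.content)  # the response body is the PNG bytes

Since the endpoint streams the image as a single PNG payload (iter([img_byte_arr])), reading resp.content in one go is sufficient; no chunked handling is needed.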