cybergamer0123 committed
Commit a7e8e3a · verified · 1 Parent(s): 5f55b78

Update app.py

Files changed (1)
  1. app.py +57 -209
app.py CHANGED
@@ -1,230 +1,78 @@
- import io
- import os
- import time
- import json
-
- import numpy as np
- from fastapi import FastAPI, HTTPException, Body
- from fastapi.responses import StreamingResponse, Response, HTMLResponse
- from fastapi.middleware import Middleware
- from fastapi.middleware.gzip import GZipMiddleware
  from pydantic import BaseModel
-
- from onnxruntime import InferenceSession
- from huggingface_hub import snapshot_download
- from scipy.io.wavfile import write as write_wav
-
  from diffusers import OnnxStableDiffusionPipeline
  from PIL import Image

  class ImageRequest(BaseModel):
      prompt: str
      num_inference_steps: int = 50
      guidance_scale: float = 7.5
-     format: str = "png" # or "jpeg"
-
-
- model_repo = "runwayml/stable-diffusion-v1-5" # Or any other ONNX compatible Stable Diffusion model
- model_name = "model_q4.onnx" # if specific model file needed, otherwise directory is enough
- voice_file_pattern = "*.bin" # not used, keep for inspiration, remove if not needed
- local_dir = "sd_onnx_models_snapshot" # different folder for sd models
- snapshot_download(
-     repo_id=model_repo,
-     revision="onnx",
-     local_dir=local_dir,
-     local_dir_use_symlinks=False,
-     allow_patterns=["*.onnx", "*.json", "vae/*.onnx"] # Specify necessary file patterns (adjust as needed)
- )
-
-
- pipeline = OnnxStableDiffusionPipeline.from_pretrained(
-     local_dir, # Use the local path from snapshot_download
-     provider="CPUExecutionProvider", # Or "CUDAExecutionProvider" if you have GPU
- )
-
-
- app = FastAPI(
-     title="FastAPI Image Generation with ONNX",
-     middleware=[Middleware(GZipMiddleware, compresslevel=9)] # maybe compression is not needed for images? check later
- )
-
-
- @app.post("/generate-image/streaming", summary="Streaming Image Generation")
- async def generate_image_streaming(request: ImageRequest = Body(...)):
-     prompt = request.prompt
-     num_inference_steps = request.num_inference_steps
-     guidance_scale = request.guidance_scale
-     format = request.format.lower()
-
-     def image_generator():

          try:
-             start_time = time.time()
-             image = pipeline(
-                 prompt,
-                 num_inference_steps=num_inference_steps,
-                 guidance_scale=guidance_scale
-             ).images[0]
-             print(f"Image generation inference time: {time.time() - start_time:.3f}s")
-
-             img_byte_arr = io.BytesIO()
-             image_format = format.upper() if format in ["png", "jpeg"] else "PNG" # Default to PNG if format is invalid
-             image.save(img_byte_arr, format=image_format)
-             img_byte_arr = img_byte_arr.getvalue()
-             yield img_byte_arr

          except Exception as e:
-             print(f"Error processing image generation: {e}")
-             # yield error response? or just error out
-
-     media_type = f"image/{format}" if format in ["png", "jpeg"] else "image/png"
-     return StreamingResponse(
-         image_generator(),
-         media_type=media_type,
-         headers={"Cache-Control": "no-cache"},
-     )
-
-
- @app.post("/generate-image/full", summary="Full Image Generation")
- async def generate_image_full(request: ImageRequest = Body(...)):
-     prompt = request.prompt
-     num_inference_steps = request.num_inference_steps
-     guidance_scale = request.guidance_scale
-     format = request.format.lower()
-
-     start_time = time.time()
-     image = pipeline(
-         prompt,
-         num_inference_steps=num_inference_steps,
-         guidance_scale=guidance_scale
-     ).images[0]
-     print(f"Full Image generation inference time: {time.time()-start_time:.3f}s")

-     img_byte_arr = io.BytesIO()
-     image_format = format.upper() if format in ["png", "jpeg"] else "PNG"
-     image.save(img_byte_arr, format=image_format)
-     img_byte_arr.seek(0)

-     media_type = f"image/{format}" if format in ["png", "jpeg"] else "image/png"
-     return Response(content=img_byte_arr.read(), media_type=media_type)
-
-
- @app.get("/", response_class=HTMLResponse)
- def index():
-     return """
-     <!DOCTYPE html>
-     <html>
-     <head>
-         <title>FastAPI Image Generation Demo</title>
-         <style>
-             body { font-family: Arial, sans-serif; }
-             .container { width: 80%; margin: auto; padding-top: 20px; }
-             h1 { text-align: center; }
-             .form-group { margin-bottom: 15px; }
-             label { display: block; margin-bottom: 5px; font-weight: bold; }
-             input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
-             textarea { height: 100px; }
-             button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
-             button:hover { background-color: #0056b3; }
-             img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
-         </style>
-     </head>
-     <body>
-         <div class="container">
-             <h1>FastAPI Image Generation Demo</h1>
-             <div class="form-group">
-                 <label for="prompt">Text Prompt:</label>
-                 <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
-             </div>
-             <div class="form-group">
-                 <label for="num_inference_steps">Number of Inference Steps:</label>
-                 <input type="number" id="num_inference_steps" value="50">
-             </div>
-             <div class="form-group">
-                 <label for="guidance_scale">Guidance Scale:</label>
-                 <input type="number" step="0.5" id="guidance_scale" value="7.5">
-             </div>
-             <div class="form-group">
-                 <label for="format">Format:</label>
-                 <select id="format">
-                     <option value="png" selected>PNG</option>
-                     <option value="jpeg">JPEG</option>
-                 </select>
-             </div>
-             <div class="form-group">
-                 <button onclick="generateStreamingImage()">Generate Streaming Image</button>
-                 <button onclick="generateFullImage()">Generate Full Image</button>
-             </div>
-             <div id="image-container">
-                 <img id="image" src="#" alt="Generated Image" style="display:none;">
-             </div>
-         </div>
-         <script>
-             function generateStreamingImage() {
-                 const prompt = document.getElementById('prompt').value;
-                 const num_inference_steps = document.getElementById('num_inference_steps').value;
-                 const guidance_scale = document.getElementById('guidance_scale').value;
-                 const format = document.getElementById('format').value;
-                 const imageElement = document.getElementById('image');
-                 const imageContainer = document.getElementById('image-container');
-
-                 fetch('/generate-image/streaming', {
-                     method: 'POST',
-                     headers: {
-                         'Content-Type': 'application/json'
-                     },
-                     body: JSON.stringify({
-                         prompt: prompt,
-                         num_inference_steps: parseInt(num_inference_steps),
-                         guidance_scale: parseFloat(guidance_scale),
-                         format: format
-                     })
-                 })
-                 .then(response => response.blob())
-                 .then(blob => {
-                     const imageUrl = URL.createObjectURL(blob);
-                     imageElement.src = imageUrl;
-                     imageElement.style.display = 'block'; // Show the image
-                     imageContainer.style.display = 'block'; // Show the container if hidden
-                 });
-             }
-
-             function generateFullImage() {
-                 const prompt = document.getElementById('prompt').value;
-                 const num_inference_steps = document.getElementById('num_inference_steps').value;
-                 const guidance_scale = document.getElementById('guidance_scale').value;
-                 const format = document.getElementById('format').value;
-                 const imageElement = document.getElementById('image');
-                 const imageContainer = document.getElementById('image-container');
-
-                 fetch('/generate-image/full', {
-                     method: 'POST',
-                     headers: {
-                         'Content-Type': 'application/json'
-                     },
-                     body: JSON.stringify({
-                         prompt: prompt,
-                         num_inference_steps: parseInt(num_inference_steps),
-                         guidance_scale: parseFloat(guidance_scale),
-                         format: format
-                     })
-                 })
-                 .then(response => response.blob())
-                 .then(blob => {
-                     const imageUrl = URL.createObjectURL(blob);
-                     imageElement.src = imageUrl;
-                     imageElement.style.display = 'block'; // Show the image
-                     imageContainer.style.display = 'block'; // Show the container if hidden
-                 });
-             }
-         </script>
-     </body>
-     </html>
-     """

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
+ from fastapi import FastAPI, HTTPException, Depends
+ from fastapi.responses import StreamingResponse
  from pydantic import BaseModel
+ from io import BytesIO
  from diffusers import OnnxStableDiffusionPipeline
+ from huggingface_hub import snapshot_download
  from PIL import Image
+ import os
+
+ app = FastAPI()
+
+ # Global variable to hold the loaded pipeline
+ pipeline = None
+ model_id = "clip.opt/model.onnx" # Path of the ONNX model file inside the repo (kept for reference)
+ repo_id = "black-forest-labs/FLUX.1-dev-onnx" # Hugging Face repo to download the ONNX files from

  class ImageRequest(BaseModel):
      prompt: str
      num_inference_steps: int = 50
      guidance_scale: float = 7.5

+ async def load_pipeline():
+     """Loads the ONNX Stable Diffusion pipeline from Hugging Face Hub using snapshot_download."""
+     global pipeline
+     if pipeline is None:
          try:
+             local_model_path = snapshot_download(
+                 repo_id=repo_id,
+                 local_dir=".", # Specify local_dir to ensure files are placed there
+                 allow_patterns=["clip.opt/*.onnx"] # Necessary file patterns; note the full pipeline also needs model_index.json and the other component exports, so widen these patterns if loading fails
+             )
+
+             pipeline = OnnxStableDiffusionPipeline.from_pretrained(
+                 local_model_path, # Use the local path from snapshot_download
+                 provider="CPUExecutionProvider", # Or "CUDAExecutionProvider" if you have GPU
+             )
+             print(f"ONNX Stable Diffusion pipeline loaded successfully via snapshot_download from: {local_model_path}")

          except Exception as e:
+             print(f"Error loading ONNX pipeline using snapshot_download: {e}")
+             raise HTTPException(status_code=500, detail=f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}")
+     return pipeline

+ async def get_pipeline():
+     """Dependency to ensure the pipeline is loaded before the endpoint is called."""
+     return await load_pipeline()

+ @app.on_event("startup")
+ async def startup_event():
+     """Loads the pipeline on startup."""
+     await load_pipeline()

+ @app.post("/generate-image/")
+ async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
+     """
+     Generates an image based on the provided text prompt using the loaded ONNX Stable Diffusion pipeline.
+     """
+     try:
+         image = pipeline_dep(
+             request.prompt,
+             num_inference_steps=request.num_inference_steps,
+             guidance_scale=request.guidance_scale
+         ).images[0]
+
+         # Convert PIL Image to bytes for streaming response
+         img_byte_arr = BytesIO()
+         image.save(img_byte_arr, format="PNG")
+         img_byte_arr = img_byte_arr.getvalue()
+
+         return StreamingResponse(content=iter([img_byte_arr]), media_type="image/png")
+
+     except Exception as e:
+         print(f"Error during image generation: {e}")
+         raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")

  if __name__ == "__main__":
      import uvicorn
+     uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) # reload=True requires an import string, not the app object
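
For reference, the new endpoint can be exercised with a minimal client. The sketch below is illustrative and not part of the commit: it assumes the server is running locally on port 8000 and that the requests package is installed; the payload fields mirror ImageRequest above, and the output filename is arbitrary.

import requests

# Hypothetical test client for the /generate-image/ endpoint defined above.
# Assumes the app is served at localhost:8000 (see the __main__ block).
payload = {
    "prompt": "a watercolor lighthouse at dusk",
    "num_inference_steps": 50,
    "guidance_scale": 7.5,
}
response = requests.post("http://localhost:8000/generate-image/", json=payload, timeout=600)
response.raise_for_status()

# The endpoint returns a single PNG body; write it to disk.
with open("generated.png", "wb") as f:
    f.write(response.content)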
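One caveat: @app.on_event("startup") is deprecated in newer FastAPI releases in favor of lifespan handlers. A minimal lifespan-based equivalent, assuming FastAPI >= 0.93 and the load_pipeline function defined in app.py above, might look like:

from contextlib import asynccontextmanager
from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs once before the app starts serving requests.
    await load_pipeline()
    yield
    # Shutdown cleanup would go here; this app has none.

app = FastAPI(lifespan=lifespan)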