cybergamer0123 committed
Commit b80eff0 · verified · 1 Parent(s): a7e8e3a

Update app.py

Files changed (1):
  1. app.py +137 -28

app.py CHANGED
@@ -1,54 +1,60 @@
 from fastapi import FastAPI, HTTPException, Depends
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, Response, HTMLResponse
 from pydantic import BaseModel
 from io import BytesIO
 from diffusers import OnnxStableDiffusionPipeline
 from huggingface_hub import snapshot_download
 from PIL import Image
 import os
+from contextlib import asynccontextmanager
 
-app = FastAPI()
-
-# Global variable to hold the loaded pipeline
-pipeline = None
-model_id = "clip.opt/model.onnx"  # Or any other ONNX-compatible Stable Diffusion model
-repo_id = "black-forest-labs/FLUX.1-dev-onnx"  # Hugging Face repo holding the ONNX models for snapshot_download
-
-class ImageRequest(BaseModel):
-    prompt: str
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-
-async def load_pipeline():
-    """Loads the ONNX Stable Diffusion pipeline from Hugging Face Hub using snapshot_download."""
+repo_id = "black-forest-labs/FLUX.1-dev-onnx"  # Hugging Face repo ID
+local_dir = "sd_onnx_models"  # Subdirectory for downloaded models
+pipeline = None  # Initialize pipeline outside the lifespan handler
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """
+    Lifespan handler: loads the ONNX Stable Diffusion pipeline on startup
+    and unloads it on shutdown.
+    """
     global pipeline
     if pipeline is None:
         try:
+            allow_patterns = ["*.onnx", "*.json", "vae/*.onnx"]  # Necessary file patterns (adjust as needed)
             local_model_path = snapshot_download(
                 repo_id=repo_id,
-                local_dir=".",  # Specify local_dir to ensure files are placed there
-                allow_patterns=["clip.opt/*.onnx"]  # Necessary file patterns (adjust as needed)
+                local_dir=local_dir,  # Specify local_dir to ensure files are placed there
+                allow_patterns=allow_patterns
             )
 
             pipeline = OnnxStableDiffusionPipeline.from_pretrained(
-                local_model_path,  # Use the local path from snapshot_download
+                local_dir,  # Use the local path from snapshot_download
                 provider="CPUExecutionProvider",  # Or "CUDAExecutionProvider" if you have a GPU
+                revision="onnx"  # Explicitly select the ONNX revision when loading from a local path
             )
-            print(f"ONNX Stable Diffusion pipeline loaded successfully from {model_id} (ONNX revision) using snapshot_download from: {local_model_path}")
+            print(f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} (ONNX revision) using snapshot_download from: {local_model_path}")
 
         except Exception as e:
             print(f"Error loading ONNX pipeline using snapshot_download: {e}")
             raise HTTPException(status_code=500, detail=f"Failed to load ONNX Stable Diffusion pipeline using snapshot_download: {e}")
-    return pipeline
+    yield  # App starts serving requests after this point
+    pipeline = None  # Optionally unload the pipeline on shutdown (for resource management)
+    print("ONNX Stable Diffusion pipeline unloaded.")
+
+app = FastAPI(lifespan=lifespan)  # Register the lifespan handler
+
+class ImageRequest(BaseModel):
+    prompt: str
+    num_inference_steps: int = 50
+    guidance_scale: float = 7.5
+    format: str = "png"  # Default output format
 
 async def get_pipeline():
     """Dependency to ensure pipeline is loaded before endpoint is called."""
-    return await load_pipeline()
-
-@app.on_event("startup")
-async def startup_event():
-    """Loads the pipeline on startup."""
-    await load_pipeline()
+    if pipeline is None:  # Should have been loaded by the lifespan handler
+        raise HTTPException(status_code=500, detail="Pipeline not loaded. Startup might have failed.")
+    return pipeline
 
 @app.post("/generate-image/")
 async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
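The notable change in this first hunk is replacing the deprecated @app.on_event("startup") hook with FastAPI's lifespan context manager. A minimal sketch of that pattern, separate from this commit (the state dict and /health route are illustrative placeholders, not code from this Space):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    state = {}  # resources loaded once at startup

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        # Everything before `yield` runs once, before the app accepts requests.
        state["pipeline"] = object()  # placeholder for an expensive model load
        yield  # the app serves requests while suspended here
        # Everything after `yield` runs once at shutdown.
        state.clear()

    app = FastAPI(lifespan=lifespan)

    @app.get("/health")
    async def health():
        return {"pipeline_loaded": "pipeline" in state}

Loading before the yield guarantees the model is ready before the first request arrives, which is why the commit can drop the separate load_pipeline()/startup_event() pair.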
@@ -64,15 +70,118 @@ async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusio
 
         # Convert PIL Image to bytes for streaming response
         img_byte_arr = BytesIO()
-        image.save(img_byte_arr, format="PNG")
+        image.save(img_byte_arr, format=request.format.upper())  # Use format from request
         img_byte_arr = img_byte_arr.getvalue()
 
-        return StreamingResponse(content=iter([img_byte_arr]), media_type="image/png")
+        return StreamingResponse(content=iter([img_byte_arr]), media_type=f"image/{request.format}")  # Use format from request
 
     except Exception as e:
         print(f"Error during image generation: {e}")
         raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
 
+@app.get("/", response_class=HTMLResponse)
+def index():
+    return """
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <title>FastAPI Image Generation Demo</title>
+        <style>
+            body { font-family: Arial, sans-serif; }
+            .container { width: 80%; margin: auto; padding-top: 20px; }
+            h1 { text-align: center; }
+            .form-group { margin-bottom: 15px; }
+            label { display: block; margin-bottom: 5px; font-weight: bold; }
+            input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
+            textarea { height: 100px; }
+            button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
+            button:hover { background-color: #0056b3; }
+            img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
+            #image-container { display: none; margin-top: 20px; text-align: center; } /* Initially hidden container */
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <h1>FastAPI Image Generation Demo</h1>
+            <div class="form-group">
+                <label for="prompt">Text Prompt:</label>
+                <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
+            </div>
+            <div class="form-group">
+                <label for="num_inference_steps">Number of Inference Steps:</label>
+                <input type="number" id="num_inference_steps" value="50">
+            </div>
+            <div class="form-group">
+                <label for="guidance_scale">Guidance Scale:</label>
+                <input type="number" step="0.5" id="guidance_scale" value="7.5">
+            </div>
+            <div class="form-group">
+                <label for="format">Format:</label>
+                <select id="format">
+                    <option value="png" selected>PNG</option>
+                    <option value="jpeg">JPEG</option>
+                </select>
+            </div>
+            <div class="form-group">
+                <button onclick="generateImage()">Generate Image</button>
+            </div>
+            <div id="image-container">
+                <img id="image" src="#" alt="Generated Image" style="max-width: 80%; height: auto;">
+            </div>
+        </div>
+        <script>
+            async function generateImage() {
+                const prompt = document.getElementById('prompt').value;
+                const num_inference_steps = document.getElementById('num_inference_steps').value;
+                const guidance_scale = document.getElementById('guidance_scale').value;
+                const format = document.getElementById('format').value;
+                const imageElement = document.getElementById('image');
+                const imageContainer = document.getElementById('image-container');
+                imageElement.style.display = 'none';  // Hide previous image while loading
+                imageContainer.style.display = 'none';  // Hide container too
+
+                try {
+                    const response = await fetch('/generate-image/', {
+                        method: 'POST',
+                        headers: {
+                            'Content-Type': 'application/json'
+                        },
+                        body: JSON.stringify({
+                            prompt: prompt,
+                            num_inference_steps: parseInt(num_inference_steps),
+                            guidance_scale: parseFloat(guidance_scale),
+                            format: format
+                        })
+                    });
+
+                    if (!response.ok) {
+                        throw new Error(`HTTP error! status: ${response.status}`);
+                    }
+
+                    const blob = await response.blob();
+                    const imageUrl = URL.createObjectURL(blob);
+                    imageElement.src = imageUrl;
+                    imageElement.onload = () => {  // Only show when image is loaded
+                        imageContainer.style.display = 'block';  // Show image container
+                        imageElement.style.display = 'block';  // Show the image
+                    };
+                    imageElement.onerror = () => {
+                        imageElement.style.display = 'none';  // Hide image on error
+                        imageContainer.style.display = 'none';  // Hide container on error
+                        alert("Error loading image. Please check console for details.");
+                    };
+
+                } catch (error) {
+                    console.error("Fetch error:", error);
+                    alert("Error generating image. Please check console for details.");
+                }
+            }
+        </script>
+    </body>
+    </html>
+    """
+
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
+    uvicorn.run(app, host="0.0.0.0", port=7860, reload=True)
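
With the format field added to ImageRequest and the server now bound to port 7860, a quick way to exercise the endpoint is a small client script. This is a sketch, not part of the commit; the prompt, output filename, and localhost URL are assumptions:

    import requests

    # Host/port match the uvicorn.run() settings in this commit.
    resp = requests.post(
        "http://localhost:7860/generate-image/",
        json={
            "prompt": "a watercolor fox in a snowy forest",
            "num_inference_steps": 30,
            "guidance_scale": 7.5,
            "format": "png",  # or "jpeg", per the new ImageRequest field
        },
        timeout=600,  # CPU inference can take minutes
    )
    resp.raise_for_status()
    with open("generated.png", "wb") as f:
        f.write(resp.content)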