from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import Response, HTMLResponse
from pydantic import BaseModel
from io import BytesIO
from diffusers import OnnxStableDiffusionPipeline
from huggingface_hub import snapshot_download
from contextlib import asynccontextmanager
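# Assumed runtime dependencies (not pinned here; adjust versions as needed):
#   pip install fastapi uvicorn diffusers transformers onnxruntime huggingface_hub pillow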

# NOTE: OnnxStableDiffusionPipeline expects a diffusers-format ONNX Stable Diffusion
# export (unet/vae/text_encoder components). FLUX repos use a different layout, so
# point repo_id at a compatible ONNX SD export before running this as-is.
repo_id = "black-forest-labs/FLUX.1-dev-onnx" # Hugging Face repo ID
local_dir = "sd_onnx_models" # Local directory to download the model files into
pipeline = None # Populated by the lifespan handler on startup

@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Lifespan event handler to load the ONNX Stable Diffusion pipeline on startup and unload on shutdown.
    """
    global pipeline
    if pipeline is None:
        try:
            # File patterns to download; "*.txt" picks up tokenizer files such as
            # merges.txt (adjust as needed, e.g. add "*.onnx_data" if the repo
            # stores weights as external data files)
            allow_patterns = ["*.onnx", "*.json", "*.txt"]
            local_model_path = snapshot_download(
                repo_id=repo_id,
                local_dir=local_dir, # Specify local_dir to ensure files are placed there
                allow_patterns=allow_patterns
            )

            pipeline = OnnxStableDiffusionPipeline.from_pretrained(
                local_model_path,  # Load from the path returned by snapshot_download
                provider="CPUExecutionProvider" # Or "CUDAExecutionProvider" if you have a GPU
                # No revision argument here: revision selects a Hub branch and does not apply to local paths
            )
            print(f"ONNX Stable Diffusion pipeline loaded successfully from {repo_id} via snapshot_download at: {local_model_path}")

        except Exception as e:
            print(f"Error loading ONNX pipeline using snapshot_download: {e}")
            # HTTPException only makes sense inside request handlers; raising any
            # exception here aborts app startup, so use a plain RuntimeError
            raise RuntimeError(f"Failed to load ONNX Stable Diffusion pipeline: {e}") from e
    yield # App starts up after this point
    pipeline = None # Optionally unload pipeline on shutdown (if needed for resource management)
    print("ONNX Stable Diffusion pipeline unloaded.")


app = FastAPI(lifespan=lifespan) # Register lifespan function

class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    format: str = "png" # default format

async def get_pipeline():
    """Dependency to ensure pipeline is loaded before endpoint is called."""
    if pipeline is None: # Should have been loaded by the lifespan handler
        raise HTTPException(status_code=503, detail="Pipeline not loaded. Startup might have failed.")
    return pipeline
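
# Note: the pipeline call in the endpoint below is synchronous and CPU-bound, so it
# blocks the event loop for the duration of generation. One hedged alternative (using
# the names defined in this file) is to offload it to a worker thread, e.g.:
#   import anyio
#   result = await anyio.to_thread.run_sync(lambda: pipeline_dep(request.prompt))
# (anyio is already a Starlette/FastAPI dependency.)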

@app.post("/generate-image/")
async def generate_image(request: ImageRequest, pipeline_dep: OnnxStableDiffusionPipeline = Depends(get_pipeline)):
    """
    Generates an image based on the provided text prompt using the loaded ONNX Stable Diffusion pipeline.
    """
    try:
        image = pipeline_dep(
            request.prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale
        ).images[0]

        # Convert the PIL image to bytes for the HTTP response
        buffer = BytesIO()
        image.save(buffer, format=request.format.upper()) # PIL expects upper-case format names ("PNG", "JPEG")
        image_bytes = buffer.getvalue()

        # The image is already fully in memory, so a plain Response is simpler than streaming
        return Response(content=image_bytes, media_type=f"image/{request.format}")

    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {e}")
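
# Example request for /generate-image/ against a local server (port 7860 per the
# __main__ block below); the prompt and output filename are illustrative:
#   curl -X POST http://localhost:7860/generate-image/ \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a watercolor fox", "num_inference_steps": 25, "guidance_scale": 7.5, "format": "png"}' \
#        --output fox.png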

@app.get("/", response_class=HTMLResponse)
def index():
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>FastAPI Image Generation Demo</title>
        <style>
            body { font-family: Arial, sans-serif; }
            .container { width: 80%; margin: auto; padding-top: 20px; }
            h1 { text-align: center; }
            .form-group { margin-bottom: 15px; }
            label { display: block; margin-bottom: 5px; font-weight: bold; }
            input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
            textarea { height: 100px; }
            button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
            button:hover { background-color: #0056b3; }
            img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
            #image-container { display: none; margin-top: 20px; text-align: center; } /* Initially hidden container */
        </style>
    </head>
    <body>
        <div class="container">
            <h1>FastAPI Image Generation Demo</h1>
            <div class="form-group">
                <label for="prompt">Text Prompt:</label>
                <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
            </div>
            <div class="form-group">
                <label for="num_inference_steps">Number of Inference Steps:</label>
                <input type="number" id="num_inference_steps" value="50">
            </div>
            <div class="form-group">
                <label for="guidance_scale">Guidance Scale:</label>
                <input type="number" step="0.5" id="guidance_scale" value="7.5">
            </div>
            <div class="form-group">
                <label for="format">Format:</label>
                <select id="format">
                    <option value="png" selected>PNG</option>
                    <option value="jpeg">JPEG</option>
                </select>
            </div>
            <div class="form-group">
                <button onclick="generateImage()">Generate Image</button>
            </div>
             <div id="image-container">
                <img id="image" src="#" alt="Generated Image" style="max-width: 80%; height: auto;">
            </div>
        </div>
        <script>
            async function generateImage() {
                const prompt = document.getElementById('prompt').value;
                const num_inference_steps = document.getElementById('num_inference_steps').value;
                const guidance_scale = document.getElementById('guidance_scale').value;
                const format = document.getElementById('format').value;
                const imageElement = document.getElementById('image');
                const imageContainer = document.getElementById('image-container');
                imageElement.style.display = 'none'; // Hide previous image while loading
                imageContainer.style.display = 'none'; // Hide container too

                try {
                    const response = await fetch('/generate-image/', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json'
                        },
                        body: JSON.stringify({
                            prompt: prompt,
                            num_inference_steps: parseInt(num_inference_steps),
                            guidance_scale: parseFloat(guidance_scale),
                            format: format
                        })
                    });

                    if (!response.ok) {
                        throw new Error(`HTTP error! status: ${response.status}`);
                    }

                    const blob = await response.blob();
                    const imageUrl = URL.createObjectURL(blob);
                    imageElement.src = imageUrl;
                    imageElement.onload = () => { // Only show once the image has loaded
                        imageContainer.style.display = 'block'; // Show image container
                        imageElement.style.display = 'block'; // Show the image
                        URL.revokeObjectURL(imageUrl); // Release the blob URL so repeated generations don't leak memory
                    };
                    imageElement.onerror = () => {
                        imageElement.style.display = 'none'; // Hide image on error
                        imageContainer.style.display = 'none'; // Hide container on error
                        alert("Error loading image. Please check console for details.");
                    };

                } catch (error) {
                    console.error("Fetch error:", error);
                    alert("Error generating image. Please check console for details.");
                }
            }
        </script>
    </body>
    </html>
    """

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
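
# Quick smoke-test sketch (an untested illustration, assuming the model loads
# successfully): FastAPI's TestClient runs the lifespan handler when used as a
# context manager, so this exercises the full download/load/generate path:
#   from fastapi.testclient import TestClient
#   with TestClient(app) as client:
#       r = client.post("/generate-image/", json={"prompt": "test", "num_inference_steps": 1})
#       assert r.status_code == 200 and r.headers["content-type"] == "image/png"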