matjarm committed on
Commit
043d7db
·
1 Parent(s): 805f75d
Files changed (2) hide show
  1. app.py +207 -0
  2. requirement.txt +154 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ import gradio as gr
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
9
+ from typing import Tuple
10
+
11
# --- Static UI assets -------------------------------------------------------

# Stylesheet handed to gr.Blocks: narrows the container, centers h1 headings
# and hides the footer.
css = '''
.gradio-container{max-width: 575px !important}
h1{text-align:center}
footer {
visibility: hidden
}
'''

# Markdown shown at the top of the page.
DESCRIPTION = """
## Text-to-Image Generator 🚀
Create stunning images from text prompts using Stable Diffusion XL. Explore high-quality styles and customizable options.
"""

# Prompts offered as clickable examples.
examples = [
    "A beautiful sunset over the ocean, ultra-realistic, high resolution",
    "A futuristic cityscape with flying cars, cyberpunk theme, vibrant colors",
    "A cozy cabin in the woods during winter, detailed and realistic",
    "A magical forest with glowing plants and creatures, fantasy art",
]

# --- Models and styles ------------------------------------------------------

# Display name -> model identifier passed to the pipeline loader.
MODEL_OPTIONS = {
    "LIGHTNING V5.0": "SG161222/RealVisXL_V5.0_Lightning",
    "LIGHTNING V4.0": "SG161222/RealVisXL_V4.0_Lightning",
}

# Each style wraps the user prompt via the "{prompt}" placeholder and brings
# its own negative prompt.
style_list = [
    {
        "name": "Ultra HD",
        "prompt": "hyper-realistic 8K image of {prompt}. ultra-detailed, lifelike, high-resolution, sharp, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, low resolution, blurry, simplistic, abstract, deformed, ugly",
    },
    {
        "name": "4K Realistic",
        "prompt": "realistic 4K image of {prompt}. sharp, detailed, vibrant colors, photorealistic",
        "negative_prompt": "cartoonish, blurry, low resolution",
    },
    {
        "name": "Minimal Style",
        "prompt": "{prompt}, clean, minimalistic",
        "negative_prompt": "",
    },
]

# Lookup table: style name -> (prompt template, negative prompt).
styles = dict(
    (entry["name"], (entry["prompt"], entry["negative_prompt"]))
    for entry in style_list
)
DEFAULT_STYLE_NAME = "Ultra HD"

# --- Runtime globals --------------------------------------------------------

# First CUDA device when one is available, CPU otherwise.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
MAX_IMAGE_SIZE = 4096
# Largest value representable by a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
66
+ # Load Model Function
67
def load_and_prepare_model(model_id):
    """Load an SDXL pipeline and give it an Euler-ancestral scheduler.

    Args:
        model_id: Identifier passed to
            ``StableDiffusionXLPipeline.from_pretrained``.

    Returns:
        The pipeline, moved to the module-level ``device``.
    """
    # Half precision only when CUDA is present; float32 otherwise.
    if torch.cuda.is_available():
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32

    pipeline = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=weights_dtype,
    )
    pipeline = pipeline.to(device)

    # Build the replacement scheduler from the loaded scheduler's own config.
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config
    )
    return pipeline
74
+
75
# Load Models
# Every entry of MODEL_OPTIONS is loaded eagerly at import time and stored
# under its display name.
models = dict(
    (display_name, load_and_prepare_model(repo_id))
    for display_name, repo_id in MODEL_OPTIONS.items()
)
77
+
78
+ # Generate Function
79
def generate_image(
    model_choice: str,
    prompt: str,
    negative_prompt: str,
    style_name: str,
    width: int,
    height: int,
    guidance_scale: float,
    num_steps: int,
    num_images: int,
    randomize_seed: bool,
    seed: int,
) -> Tuple[list, int]:
    """Generate images for a styled prompt and save them as PNG files.

    The selected style template wraps ``prompt`` and contributes its own
    negative prompt, which is combined with the user-supplied one.

    Args:
        model_choice: Key into the module-level ``models`` mapping.
        prompt: User text substituted for ``{prompt}`` in the style template.
        negative_prompt: Optional extra terms to avoid; may be empty.
        style_name: Key into ``styles``; unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_steps: Number of inference steps.
        num_images: Number of images generated in one batch.
        randomize_seed: When true, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``.
        seed: Seed for the torch generator (ignored when randomizing).

    Returns:
        ``(image_paths, seed)``: the PNG file names written to the current
        working directory, and the seed that was actually used.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``models``.
    """
    # Slider widgets may deliver whole numbers as floats; list repetition and
    # ``manual_seed`` below require real ints, so normalize up front.
    width, height = int(width), int(height)
    num_steps, num_images = int(num_steps), int(num_images)

    # Apply Style
    positive_style, negative_style = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    styled_prompt = positive_style.replace("{prompt}", prompt)
    # Join with ", " so the last style term and the first user term stay
    # separate words (plain concatenation produced e.g. "uglyblurry").
    negative_terms = (negative_style.strip(), (negative_prompt or "").strip())
    styled_negative_prompt = ", ".join(term for term in negative_terms if term)

    # Randomize Seed if Enabled
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate Images
    pipe = models[model_choice]
    images = pipe(
        prompt=[styled_prompt] * num_images,
        negative_prompt=[styled_negative_prompt] * num_images,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        generator=generator,
        output_type="pil",
    ).images

    # Save and Return Images
    image_paths = []
    for img in images:
        # A random UUID keeps names from concurrent requests from colliding.
        unique_name = f"{uuid.uuid4()}.png"
        img.save(unique_name)
        image_paths.append(unique_name)

    return image_paths, seed
123
+
124
# Gradio Interface
# NOTE(review): which widgets sit inside each Row/Accordion is inferred from
# the statement grouping; confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)

    # Model selector.
    with gr.Row():
        model_choice = gr.Dropdown(
            choices=list(MODEL_OPTIONS),
            value="LIGHTNING V5.0",
            label="Select Model",
        )

    # Prompt inputs and style picker.
    prompt = gr.Textbox(
        placeholder="Enter your creative prompt here...",
        label="Prompt",
    )

    negative_prompt = gr.Textbox(
        value="blurry, deformed, low-quality, cartoonish",
        placeholder="Optional: Add details you want to avoid...",
        label="Negative Prompt",
    )

    style_name = gr.Radio(
        choices=list(styles),
        value=DEFAULT_STYLE_NAME,
        label="Style",
    )

    # Generation parameters, collapsed by default.
    with gr.Accordion("Advanced Options", open=False):
        width = gr.Slider(minimum=512, maximum=2048, step=8, value=1024, label="Width")
        height = gr.Slider(minimum=512, maximum=2048, step=8, value=1024, label="Height")
        guidance_scale = gr.Slider(
            minimum=1,
            maximum=20,
            step=0.5,
            value=7.5,
            label="Guidance Scale",
        )
        num_steps = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=25,
            label="Steps",
        )
        num_images = gr.Slider(
            minimum=1,
            maximum=5,
            step=1,
            value=1,
            label="Number of Images",
        )
        randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
        seed = gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=42, label="Seed")

    with gr.Row():
        run_button = gr.Button("Generate Images")
        result_gallery = gr.Gallery(label="Generated Images", show_label=False)

    # Order must match the positional parameters of generate_image.
    generation_inputs = [
        model_choice,
        prompt,
        negative_prompt,
        style_name,
        width,
        height,
        guidance_scale,
        num_steps,
        num_images,
        randomize_seed,
        seed,
    ]
    # The seed slider is also an output, so it shows the seed actually used.
    run_button.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=[result_gallery, seed],
    )

    gr.Examples(examples=examples, inputs=prompt)

if __name__ == "__main__":
    # Enable request queueing (at most 50 waiting) before starting the server.
    demo.queue(max_size=50)
    demo.launch()
requirement.txt ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.6.2.post1
4
+ appnope==0.1.4
5
+ argon2-cffi==23.1.0
6
+ argon2-cffi-bindings==21.2.0
7
+ arrow==1.3.0
8
+ asttokens==2.4.1
9
+ async-lru==2.0.4
10
+ attrs==24.2.0
11
+ babel==2.16.0
12
+ beautifulsoup4==4.12.3
13
+ bleach==6.2.0
14
+ blinker==1.9.0
15
+ certifi==2024.8.30
16
+ cffi==1.17.1
17
+ charset-normalizer==3.4.0
18
+ click==8.1.7
19
+ comm==0.2.2
20
+ contourpy==1.3.0
21
+ cycler==0.12.1
22
+ debugpy==1.8.8
23
+ decorator==5.1.1
24
+ defusedxml==0.7.1
25
+ diffusers==0.31.0
26
+ exceptiongroup==1.2.2
27
+ executing==2.1.0
28
+ fastapi==0.115.6
29
+ fastjsonschema==2.20.0
30
+ ffmpy==0.4.0
31
+ filelock==3.16.1
32
+ Flask==3.1.0
33
+ fonttools==4.55.2
34
+ fqdn==1.5.1
35
+ fsspec==2024.10.0
36
+ gradio==4.44.1
37
+ gradio_client==1.3.0
38
+ h11==0.14.0
39
+ httpcore==1.0.6
40
+ httpx==0.27.2
41
+ huggingface-hub==0.26.3
42
+ idna==3.10
43
+ importlib_metadata==8.5.0
44
+ importlib_resources==6.4.5
45
+ ipykernel==6.29.5
46
+ ipython==8.18.1
47
+ isoduration==20.11.0
48
+ itsdangerous==2.2.0
49
+ jedi==0.19.2
50
+ Jinja2==3.1.4
51
+ joblib==1.4.2
52
+ json5==0.9.28
53
+ jsonpointer==3.0.0
54
+ jsonschema==4.23.0
55
+ jsonschema-specifications==2024.10.1
56
+ jupyter-events==0.10.0
57
+ jupyter-lsp==2.2.5
58
+ jupyter_client==8.6.3
59
+ jupyter_core==5.7.2
60
+ jupyter_server==2.14.2
61
+ jupyter_server_terminals==0.5.3
62
+ jupyterlab==4.3.0
63
+ jupyterlab_pygments==0.3.0
64
+ jupyterlab_server==2.27.3
65
+ kiwisolver==1.4.7
66
+ markdown-it-py==3.0.0
67
+ MarkupSafe==2.1.5
68
+ matplotlib==3.9.3
69
+ matplotlib-inline==0.1.7
70
+ mdurl==0.1.2
71
+ mistune==3.0.2
72
+ mpmath==1.3.0
73
+ nbclient==0.10.0
74
+ nbconvert==7.16.4
75
+ nbformat==5.10.4
76
+ nest-asyncio==1.6.0
77
+ networkx==3.2.1
78
+ nltk==3.9.1
79
+ notebook_shim==0.2.4
80
+ numpy==2.0.2
81
+ opencv-python==4.10.0.84
82
+ orjson==3.10.12
83
+ overrides==7.7.0
84
+ packaging==24.2
85
+ pandas==2.2.3
86
+ pandocfilters==1.5.1
87
+ parso==0.8.4
88
+ pexpect==4.9.0
89
+ pillow==10.4.0
90
+ pipeline==0.1.0
91
+ platformdirs==4.3.6
92
+ prometheus_client==0.21.0
93
+ prompt_toolkit==3.0.48
94
+ psutil==6.1.0
95
+ ptyprocess==0.7.0
96
+ pure_eval==0.2.3
97
+ pycparser==2.22
98
+ pydantic==2.10.3
99
+ pydantic_core==2.27.1
100
+ pydub==0.25.1
101
+ Pygments==2.18.0
102
+ pyparsing==3.2.0
103
+ python-dateutil==2.9.0.post0
104
+ python-json-logger==2.0.7
105
+ python-multipart==0.0.19
106
+ pytz==2024.2
107
+ PyYAML==6.0.2
108
+ pyzmq==26.2.0
109
+ referencing==0.35.1
110
+ regex==2024.11.6
111
+ requests==2.32.3
112
+ rfc3339-validator==0.1.4
113
+ rfc3986-validator==0.1.1
114
+ rich==13.9.4
115
+ rpds-py==0.21.0
116
+ ruff==0.8.2
117
+ safetensors==0.4.5
118
+ scikit-learn==1.5.2
119
+ scipy==1.13.1
120
+ semantic-version==2.10.0
121
+ Send2Trash==1.8.3
122
+ shellingham==1.5.4
123
+ six==1.16.0
124
+ sklearn==0.0
125
+ sniffio==1.3.1
126
+ soupsieve==2.6
127
+ stack-data==0.6.3
128
+ starlette==0.41.3
129
+ sympy==1.13.1
130
+ terminado==0.18.1
131
+ threadpoolctl==3.5.0
132
+ tinycss2==1.4.0
133
+ tokenizers==0.21.0
134
+ tomli==2.1.0
135
+ tomlkit==0.12.0
136
+ torch==2.5.1
137
+ tornado==6.4.1
138
+ tqdm==4.67.0
139
+ traitlets==5.14.3
140
+ transformers==4.47.0
141
+ typer==0.15.1
142
+ types-python-dateutil==2.9.0.20241003
143
+ typing_extensions==4.12.2
144
+ tzdata==2024.2
145
+ uri-template==1.3.0
146
+ urllib3==2.2.3
147
+ uvicorn==0.32.1
148
+ wcwidth==0.2.13
149
+ webcolors==24.11.1
150
+ webencodings==0.5.1
151
+ websocket-client==1.8.0
152
+ websockets==12.0
153
+ Werkzeug==3.1.3
154
+ zipp==3.21.0