guibegotti committed
Commit 2e1d0be · verified · 1 Parent(s): 78b5e0e

Update app.py

Files changed (1)
  1. app.py +319 -30
app.py CHANGED
@@ -1,51 +1,340 @@
-
  import gradio as gr
- import requests
  from PIL import Image
- import tempfile
-
- def call_api(prompt, structure_image, style_image, depth_strength, style_strength):
-     with tempfile.NamedTemporaryFile(suffix=".png") as structure_tmp, tempfile.NamedTemporaryFile(suffix=".png") as style_tmp:
-         structure_image.save(structure_tmp.name)
-         style_image.save(style_tmp.name)
-
-         files = {
-             'structure_image': open(structure_tmp.name, 'rb'),
-             'style_image': open(style_tmp.name, 'rb'),
-         }
-         data = {
-             'prompt': prompt,
-             'depth_strength': depth_strength,
-             'style_strength': style_strength,
-         }
-         response = requests.post("http://localhost:7860", data=data, files=files)
-         return Image.open(response.raw)
  with gr.Blocks() as app:
      gr.Markdown("# FLUX Style Shaping")
-     gr.Markdown("Interface connected to the API for image generation with style and structure.")
-
      with gr.Row():
          with gr.Column():
              prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
              with gr.Row():
                  with gr.Group():
-                     structure_image = gr.Image(label="Structure Image")
                      depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
                  with gr.Group():
-                     style_image = gr.Image(label="Style Image")
                      style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
              generate_btn = gr.Button("Generate")
-             output_image = gr.Image(label="Generated Image")
-
          with gr.Column():
-             pass
-
      generate_btn.click(
-         fn=call_api,
          inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
          outputs=[output_image]
      )

  if __name__ == "__main__":
-     app.launch(show_api=False)
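A note on the removed client above: Image.open(response.raw) only works when requests is called with stream=True, and the handles passed to files= were never closed. For reference, a minimal corrected sketch of the same request pattern (the localhost endpoint and field names come from the removed code; the rest is illustrative, not part of this commit):

import io
import requests
from PIL import Image

def call_api(prompt, structure_path, style_path, depth_strength, style_strength):
    # Send both images as multipart uploads; the with-block closes the handles.
    with open(structure_path, "rb") as structure_f, open(style_path, "rb") as style_f:
        response = requests.post(
            "http://localhost:7860",
            data={
                "prompt": prompt,
                "depth_strength": depth_strength,
                "style_strength": style_strength,
            },
            files={"structure_image": structure_f, "style_image": style_f},
        )
    response.raise_for_status()
    # Buffer the body in memory; response.raw is only readable when stream=True is set.
    return Image.open(io.BytesIO(response.content))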
 
+ import os
+ import random
+ import sys
+ from typing import Sequence, Mapping, Any, Union
+ import torch
  import gradio as gr
  from PIL import Image
+ from huggingface_hub import hf_hub_download
+ import spaces
+ from comfy import model_management
+
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
+ hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
+ hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
+
+ # Import all the necessary functions from the original script
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
+     try:
+         return obj[index]
+     except KeyError:
+         return obj["result"][index]
+
+ # Add all the necessary setup functions from the original script
+ def find_path(name: str, path: str = None) -> str:
+     if path is None:
+         path = os.getcwd()
+     if name in os.listdir(path):
+         path_name = os.path.join(path, name)
+         print(f"{name} found: {path_name}")
+         return path_name
+     parent_directory = os.path.dirname(path)
+     if parent_directory == path:
+         return None
+     return find_path(name, parent_directory)
+
+ def add_comfyui_directory_to_sys_path() -> None:
+     comfyui_path = find_path("ComfyUI")
+     if comfyui_path is not None and os.path.isdir(comfyui_path):
+         sys.path.append(comfyui_path)
+         print(f"'{comfyui_path}' added to sys.path")
+
+ def add_extra_model_paths() -> None:
+     try:
+         from main import load_extra_path_config
+     except ImportError:
+         from utils.extra_config import load_extra_path_config
+     extra_model_paths = find_path("extra_model_paths.yaml")
+     if extra_model_paths is not None:
+         load_extra_path_config(extra_model_paths)
+     else:
+         print("Could not find the extra_model_paths config file.")
+
+ # Initialize paths
+ add_comfyui_directory_to_sys_path()
+ add_extra_model_paths()
+
+ def import_custom_nodes() -> None:
+     import asyncio
+     import execution
+     from nodes import init_extra_nodes
+     import server
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+     server_instance = server.PromptServer(loop)
+     execution.PromptQueue(server_instance)
+     init_extra_nodes()
+
+ # Import all necessary nodes
+ from nodes import (
+     StyleModelLoader,
+     VAEEncode,
+     NODE_CLASS_MAPPINGS,
+     LoadImage,
+     CLIPVisionLoader,
+     SaveImage,
+     VAELoader,
+     CLIPVisionEncode,
+     DualCLIPLoader,
+     EmptyLatentImage,
+     VAEDecode,
+     UNETLoader,
+     CLIPTextEncode,
+ )
+
+ # Initialize all constant nodes and models in global context
+ import_custom_nodes()
+
+ # Global variables for preloaded models and constants
+ #with torch.inference_mode():
+ # Initialize constants
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
+ CONST_1024 = intconstant.get_value(value=1024)
+
+ # Load CLIP
+ dualcliploader = DualCLIPLoader()
+ CLIP_MODEL = dualcliploader.load_clip(
+     clip_name1="t5/t5xxl_fp16.safetensors",
+     clip_name2="clip_l.safetensors",
+     type="flux",
+ )
+
+ # Load VAE
+ vaeloader = VAELoader()
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
+
+ # Load UNET
+ unetloader = UNETLoader()
+ UNET_MODEL = unetloader.load_unet(
+     unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
+ )
+
+ # Load CLIP Vision
+ clipvisionloader = CLIPVisionLoader()
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
+     clip_name="sigclip_vision_patch14_384.safetensors"
+ )
+
+ # Load Style Model
+ stylemodelloader = StyleModelLoader()
+ STYLE_MODEL = stylemodelloader.load_style_model(
+     style_model_name="flux1-redux-dev.safetensors"
+ )
+
+ # Initialize samplers
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
+
+ # Initialize depth model
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
+     model="depth_anything_v2_vitl_fp32.safetensors"
+ )
+
+ cliptextencode = CLIPTextEncode()
+ loadimage = LoadImage()
+ vaeencode = VAEEncode()
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
+ clipvisionencode = CLIPVisionEncode()
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
+ emptylatentimage = EmptyLatentImage()
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
+ vaedecode = VAEDecode()
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
+ saveimage = SaveImage()
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
+
+ model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]
+
+ model_management.load_models_gpu([
+     loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders
+ ])
+
+ @spaces.GPU
+ def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
+     """Main generation function that processes inputs and returns the path to the generated image."""
+     with torch.inference_mode():
+         # Set up CLIP
+         clip_switch = cr_clip_input_switch.switch(
+             Input=1,
+             clip1=get_value_at_index(CLIP_MODEL, 0),
+             clip2=get_value_at_index(CLIP_MODEL, 0),
+         )
+
+         # Encode text
+         text_encoded = cliptextencode.encode(
+             text=prompt,
+             clip=get_value_at_index(clip_switch, 0),
+         )
+         empty_text = cliptextencode.encode(
+             text="",
+             clip=get_value_at_index(clip_switch, 0),
+         )
+
+         # Process structure image
+         structure_img = loadimage.load_image(image=structure_image)
+
+         # Resize image
+         resized_img = imageresize.execute(
+             width=get_value_at_index(CONST_1024, 0),
+             height=get_value_at_index(CONST_1024, 0),
+             interpolation="bicubic",
+             method="keep proportion",
+             condition="always",
+             multiple_of=16,
+             image=get_value_at_index(structure_img, 0),
+         )
+
+         # Get image size
+         size_info = getimagesizeandcount.getsize(
+             image=get_value_at_index(resized_img, 0)
+         )
+
+         # Encode VAE
+         vae_encoded = vaeencode.encode(
+             pixels=get_value_at_index(size_info, 0),
+             vae=get_value_at_index(VAE_MODEL, 0),
+         )
+
+         # Process depth
+         depth_processed = depthanything_v2.process(
+             da_model=get_value_at_index(DEPTH_MODEL, 0),
+             images=get_value_at_index(size_info, 0),
+         )
+
+         # Apply Flux guidance
+         flux_guided = fluxguidance.append(
+             guidance=depth_strength,
+             conditioning=get_value_at_index(text_encoded, 0),
+         )
+
+         # Process style image
+         style_img = loadimage.load_image(image=style_image)
+
+         # Encode style with CLIP Vision
+         style_encoded = clipvisionencode.encode(
+             crop="center",
+             clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
+             image=get_value_at_index(style_img, 0),
+         )
+
+         # Set up conditioning
+         conditioning = instructpixtopixconditioning.encode(
+             positive=get_value_at_index(flux_guided, 0),
+             negative=get_value_at_index(empty_text, 0),
+             vae=get_value_at_index(VAE_MODEL, 0),
+             pixels=get_value_at_index(depth_processed, 0),
+         )
+
+         # Apply style
+         style_applied = stylemodelapplyadvanced.apply_stylemodel(
+             strength=style_strength,
+             conditioning=get_value_at_index(conditioning, 0),
+             style_model=get_value_at_index(STYLE_MODEL, 0),
+             clip_vision_output=get_value_at_index(style_encoded, 0),
+         )
+
+         # Set up empty latent
+         empty_latent = emptylatentimage.generate(
+             width=get_value_at_index(resized_img, 1),
+             height=get_value_at_index(resized_img, 2),
+             batch_size=1,
+         )
+
+         # Set up guidance
+         guided = basicguider.get_guider(
+             model=get_value_at_index(UNET_MODEL, 0),
+             conditioning=get_value_at_index(style_applied, 0),
+         )
+
+         # Set up scheduler
+         schedule = basicscheduler.get_sigmas(
+             scheduler="simple",
+             steps=28,
+             denoise=1,
+             model=get_value_at_index(UNET_MODEL, 0),
+         )
+
+         # Generate random noise
+         noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
+
+         # Sample
+         sampled = samplercustomadvanced.sample(
+             noise=get_value_at_index(noise, 0),
+             guider=get_value_at_index(guided, 0),
+             sampler=get_value_at_index(SAMPLER, 0),
+             sigmas=get_value_at_index(schedule, 0),
+             latent_image=get_value_at_index(empty_latent, 0),
+         )
+
+         # Decode VAE
+         decoded = vaedecode.decode(
+             samples=get_value_at_index(sampled, 0),
+             vae=get_value_at_index(VAE_MODEL, 0),
+         )
+
+         # Save image
+         prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
+
+         saved = saveimage.save_images(
+             filename_prefix=get_value_at_index(prefix, 0),
+             images=get_value_at_index(decoded, 0),
+         )
+         saved_path = f"output/{saved['ui']['images'][0]['filename']}"
+         return saved_path
+
+ # Create Gradio interface
+
+ examples = [
+     ["", "mona.png", "receita-tacos.webp", 15, 0.6],
+     ["a woman looking at a house catching fire on the background", "disaster_girl.png", "abaporu.jpg", 15, 0.15],
+     ["istanbul aerial, dramatic photography", "natasha.png", "istambul.jpg", 15, 0.5],
+ ]
+
+ output_image = gr.Image(label="Generated Image")

  with gr.Blocks() as app:
      gr.Markdown("# FLUX Style Shaping")
+     gr.Markdown("Flux[dev] Redux + Flux[dev] Depth ComfyUI workflow by [Nathan Shipley](https://x.com/CitizenPlain) running directly on Gradio. [workflow](https://gist.github.com/nathanshipley/7a9ac1901adde76feebe58d558026f68) - [how to convert any comfy workflow to gradio](https://huggingface.co/blog/run-comfyui-workflows-on-spaces)")
      with gr.Row():
          with gr.Column():
              prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
              with gr.Row():
                  with gr.Group():
+                     structure_image = gr.Image(label="Structure Image", type="filepath")
                      depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
                  with gr.Group():
+                     style_image = gr.Image(label="Style Image", type="filepath")
                      style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
              generate_btn = gr.Button("Generate")
+
+             gr.Examples(
+                 examples=examples,
+                 inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
+                 outputs=[output_image],
+                 fn=generate_image,
+                 cache_examples=True,
+                 cache_mode="lazy"
+             )
+
          with gr.Column():
+             output_image.render()

      generate_btn.click(
+         fn=generate_image,
          inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
          outputs=[output_image]
      )

  if __name__ == "__main__":
+     app.launch(share=True)
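Because the new version no longer passes show_api=False, the app's API stays enabled and generate_image can be driven programmatically. A minimal sketch using gradio_client, assuming a locally running instance and that Gradio derives the endpoint name /generate_image from the function (the image paths are placeholders):

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")  # or the share URL printed at launch
result = client.predict(
    prompt="istanbul aerial, dramatic photography",
    structure_image=handle_file("natasha.png"),
    style_image=handle_file("istambul.jpg"),
    depth_strength=15,
    style_strength=0.5,
    api_name="/generate_image",
)
print(result)  # local path to the generated image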