dezzman committed on
Commit
f3f96df
·
verified ·
1 Parent(s): aa913a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -19
app.py CHANGED
@@ -1,15 +1,46 @@
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
- from diffusers import StableDiffusionPipeline
 
 
 
 
5
  from peft import PeftModel, LoraConfig
6
  import os
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def get_lora_sd_pipeline(
9
  ckpt_dir='./lora_logos',
10
  base_model_name_or_path=None,
11
  dtype=torch.float16,
12
- adapter_name="default"
 
13
  ):
14
 
15
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
@@ -22,7 +53,12 @@ def get_lora_sd_pipeline(
22
  if base_model_name_or_path is None:
23
  raise ValueError("Please specify the base model name or path")
24
 
25
- pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
 
 
 
 
 
26
  before_params = pipe.unet.parameters()
27
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
28
  pipe.unet.set_adapter(adapter_name)
@@ -35,7 +71,7 @@ def get_lora_sd_pipeline(
35
  if dtype in (torch.float16, torch.bfloat16):
36
  pipe.unet.half()
37
  pipe.text_encoder.half()
38
-
39
  return pipe
40
 
41
  def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
@@ -52,14 +88,36 @@ def align_embeddings(prompt_embeds, negative_prompt_embeds):
52
  return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
53
  torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
54
 
55
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
- model_id_default = "CompVis/stable-diffusion-v1-4"
57
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- pipe_default = get_lora_sd_pipeline(ckpt_dir='./lora_logos', base_model_name_or_path=model_id_default, dtype=torch_dtype).to(device)
60
 
61
- MAX_SEED = np.iinfo(np.int32).max
62
- MAX_IMAGE_SIZE = 1024
63
 
64
  def infer(
65
  prompt,
@@ -71,24 +129,59 @@ def infer(
71
  seed=42,
72
  guidance_scale=7.0,
73
  lora_scale=0.5,
 
 
 
 
 
 
 
74
  progress=gr.Progress(track_tqdm=True)
75
  ):
76
 
77
  generator = torch.Generator(device).manual_seed(seed)
78
 
79
- if model_id != model_id_default:
80
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
81
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
82
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
83
- prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
 
 
 
 
 
 
 
 
84
  else:
85
- pipe = pipe_default
86
- prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
87
- negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
88
- prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
90
  print(f"LoRA scale applied: {lora_scale}")
91
  pipe.fuse_lora(lora_scale=lora_scale)
 
 
 
 
92
 
93
  params = {
94
  'prompt_embeds': prompt_embeds,
@@ -99,6 +192,23 @@ def infer(
99
  'height': height,
100
  'generator': generator,
101
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  return pipe(**params).images[0]
104
 
@@ -169,6 +279,36 @@ with gr.Blocks(css=css) as demo:
169
  value=20,
170
  )
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  with gr.Accordion("Optional Settings", open=False):
173
  with gr.Row():
174
  width = gr.Slider(
@@ -204,6 +344,13 @@ with gr.Blocks(css=css) as demo:
204
  seed,
205
  guidance_scale,
206
  lora_scale,
 
 
 
 
 
 
 
207
  ],
208
  outputs=[result],
209
  )
 
1
  import gradio as gr
2
  import numpy as np
3
  import torch
4
+ from diffusers import (
5
+ StableDiffusionPipeline,
6
+ StableDiffusionControlNetPipeline,
7
+ ControlNetModel
8
+ )
9
  from peft import PeftModel, LoraConfig
10
  import os
11
 
12
+
13
# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------

# Upper bounds exposed as constants (largest int32 seed, largest image side).
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# IP-Adapter repository id and the weight file loaded from it.
IP_ADAPTER = 'h94/IP-Adapter'
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus_sd15.bin"

# Use the GPU with half precision when CUDA is available, otherwise CPU/float32.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
model_id_default = "CompVis/stable-diffusion-v1-4"
torch_dtype = torch.float16 if _cuda_available else torch.float32

# HED detector instance; stays None until map_scribble() creates it.
hed = None

# ControlNet mode name -> checkpoint id.
# Disabled entries kept for reference:
#   "pose_estimation": "lllyasviel/sd-controlnet-openpose"
#   "depth_map":       "lllyasviel/sd-controlnet-depth"
#   "MLSD":            "lllyasviel/sd-controlnet-mlsd"
dict_controlnet = {
    "edge_detection": "lllyasviel/sd-controlnet-canny",
    "scribble": "lllyasviel/sd-controlnet-scribble",
}

# The edge-detection ControlNet is loaded at import time as the default.
controlnet = ControlNetModel.from_pretrained(
    dict_controlnet["edge_detection"],
    cache_dir="./models_cache",
    torch_dtype=torch_dtype,
)
36
+
37
+
38
  def get_lora_sd_pipeline(
39
  ckpt_dir='./lora_logos',
40
  base_model_name_or_path=None,
41
  dtype=torch.float16,
42
+ adapter_name="default",
43
+ controlnet
44
  ):
45
 
46
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
 
53
  if base_model_name_or_path is None:
54
  raise ValueError("Please specify the base model name or path")
55
 
56
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
57
+ base_model_name_or_path,
58
+ torch_dtype=dtype,
59
+ controlnet=controlnet,
60
+ )
61
+
62
  before_params = pipe.unet.parameters()
63
  pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
64
  pipe.unet.set_adapter(adapter_name)
 
71
  if dtype in (torch.float16, torch.bfloat16):
72
  pipe.unet.half()
73
  pipe.text_encoder.half()
74
+
75
  return pipe
76
 
77
  def process_prompt(prompt, tokenizer, text_encoder, max_length=77):
 
88
  return torch.nn.functional.pad(prompt_embeds, (0, 0, 0, max_length - prompt_embeds.shape[1])), \
89
  torch.nn.functional.pad(negative_prompt_embeds, (0, 0, 0, max_length - negative_prompt_embeds.shape[1]))
90
 
91
def map_edge_detection(image_path: str) -> "Image.Image":
    """Build a ControlNet conditioning image from Canny edges.

    The image is loaded with ``load_image``, converted to RGB, passed through
    the Canny detector (thresholds 80 and 160), and the resulting
    single-channel edge map is repeated into three identical channels.

    Args:
        image_path: Location of the source image, handed to ``load_image``.

    Returns:
        A three-channel PIL image containing the edge map.
    """
    # These names are not provided by the module's import block, so they are
    # imported locally; otherwise the function fails with NameError.
    import cv2 as cv
    from PIL import Image
    from diffusers.utils import load_image

    source_img = load_image(image_path).convert('RGB')
    edges = cv.Canny(np.array(source_img), 80, 160)
    # Canny returns an (H, W) map; replicate it to (H, W, 3).
    edges = np.repeat(edges[:, :, None], 3, axis=2)
    return Image.fromarray(edges)
97
+
98
def map_scribble(image_path: str) -> "Image.Image":
    """Build a ControlNet scribble conditioning image with the HED detector.

    The detector is created on the first call and kept in the module-level
    ``hed`` variable so later calls reuse it. The detector output is smoothed
    with a 3x3 median blur and then scaled by ``alpha=1.5`` via
    ``convertScaleAbs``.

    Args:
        image_path: Location of the source image, handed to ``load_image``.

    Returns:
        A PIL image built from the processed detector output.
    """
    global hed

    # These names are not provided by the module's import block, so they are
    # imported locally; otherwise the function fails with NameError.
    import cv2 as cv
    from PIL import Image
    from controlnet_aux import HEDdetector
    from diffusers.utils import load_image

    # Lazy initialisation: only load the annotator weights when first needed.
    if hed is None:
        hed = HEDdetector.from_pretrained('lllyasviel/Annotators')

    image = load_image(image_path).convert('RGB')
    scribble_image = hed(image)
    image_np = np.array(scribble_image)
    image_np = cv.medianBlur(image_np, 3)
    boosted = cv.convertScaleAbs(image_np, alpha=1.5, beta=0)
    return Image.fromarray(boosted)
110
+
111
+
112
+
113
# Build the default pipeline (base model + LoRA checkpoint from ./lora_logos,
# using the module-level ControlNet) once at import time, then move it to the
# selected device.
pipe = get_lora_sd_pipeline(
    ckpt_dir='./lora_logos',
    base_model_name_or_path=model_id_default,
    dtype=torch_dtype,
    controlnet=controlnet,
)
pipe = pipe.to(device)
119
 
 
120
 
 
 
121
 
122
  def infer(
123
  prompt,
 
129
  seed=42,
130
  guidance_scale=7.0,
131
  lora_scale=0.5,
132
+ cn_enable=False,
133
+ cn_strength=0.0,
134
+ cn_mode='edge_detection',
135
+ cn_image=None,
136
+ ip_enable=False,
137
+ ip_scale=0.5,
138
+ ip_image=None,
139
  progress=gr.Progress(track_tqdm=True)
140
  ):
141
 
142
  generator = torch.Generator(device).manual_seed(seed)
143
 
144
+ global pipe
145
+ global controlnet
146
+
147
+ controlnet_changed = False
148
+
149
+ if cn_enable:
150
+ if dict_controlnet[cn_mode] != pipe.controlnet._name_or_path:
151
+ controlnet = ControlNetModel.from_pretrained(
152
+ dict_controlnet[cn_mode],
153
+ cache_dir="./models_cache",
154
+ torch_dtype=torch_dtype
155
+ )
156
+ controlnet_changed = True
157
  else:
158
+ cn_strength = 0.0 # отключаем контролнет принудительно
159
+
160
+ if model_id != pipe._name_or_path:
161
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
162
+ model_id,
163
+ torch_dtype=torch_dtype,
164
+ controlnet=controlnet,
165
+ controlnet_conditioning_scale=cn_strength,
166
+ ).to(device)
167
+ elif (model_id == pipe._name_or_path) and controlnet_changed:
168
+ pipe = StableDiffusionControlNetPipeline.from_pretrained(
169
+ model_id,
170
+ torch_dtype=torch_dtype,
171
+ controlnet=controlnet,
172
+ controlnet_conditioning_scale=cn_strength,
173
+ ).to(device)
174
+ print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
175
+ print(f"LoRA scale applied: {lora_scale}")
176
+ pipe.fuse_lora(lora_scale=lora_scale)
177
+ elif (model_id == pipe._name_or_path) and not controlnet_changed:
178
  print(f"LoRA adapter loaded: {pipe.unet.active_adapters}")
179
  print(f"LoRA scale applied: {lora_scale}")
180
  pipe.fuse_lora(lora_scale=lora_scale)
181
+
182
+ prompt_embeds = process_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
183
+ negative_prompt_embeds = process_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
184
+ prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
185
 
186
  params = {
187
  'prompt_embeds': prompt_embeds,
 
192
  'height': height,
193
  'generator': generator,
194
  }
195
+
196
+ if cn_enable:
197
+ params['controlnet_conditioning_scale'] = cn_strength
198
+ if cn_mode == 'edge_detection':
199
+ control_image = map_edge_detection(cn_image)
200
+ elif cn_mode == 'scribble':
201
+ control_image = map_scribble(cn_image)
202
+ params['control_image'] = control_image
203
+
204
+ if ip_enable:
205
+ pipe.load_ip_adapter(
206
+ IP_ADAPTER,
207
+ subfolder="models",
208
+ weight_name=IP_ADAPTER_WEIGHT_NAME,
209
+ )
210
+ params['ip_adapter_image'] = load_image(ip_image).convert('RGB')
211
+ pipe.ip_scale(0.6)
212
 
213
  return pipe(**params).images[0]
214
 
 
279
  value=20,
280
  )
281
 
282
+ # ControlNet section
283
+ cn_enable = gr.Checkbox(label="Enable ControlNet")
284
+ with gr.Column(visible=False) as cn_options:
285
+ with gr.Row():
286
+ cn_strength = gr.Slider(0, 2, value=0.8, step=0.1, label="Control strength", interactive=True)
287
+ cn_mode = gr.Dropdown(
288
+ choices=["edge_detection", "scribble"],
289
+ label="Work regime",
290
+ interactive=True,
291
+ )
292
+ cn_image = gr.Image(type="filepath", label="Control image")
293
+
294
+ cn_enable.change(
295
+ lambda x: gr.update(visible=x),
296
+ inputs=cn_enable,
297
+ outputs=cn_options
298
+ )
299
+
300
+ # IP-Adapter section
301
+ ip_enable = gr.Checkbox(label="Enable IP-Adapter")
302
+ with gr.Column(visible=False) as ip_options:
303
+ ip_scale = gr.Slider(0, 1, value=0.5, step=0.1, label="IP-adapter scale", interactive=True)
304
+ ip_image = gr.Image(type="filepath", label="IP-adapter image", interactive=True)
305
+
306
+ ip_enable.change(
307
+ lambda x: gr.update(visible=x),
308
+ inputs=ip_enable,
309
+ outputs=ip_options
310
+ )
311
+
312
  with gr.Accordion("Optional Settings", open=False):
313
  with gr.Row():
314
  width = gr.Slider(
 
344
  seed,
345
  guidance_scale,
346
  lora_scale,
347
+ cn_enable,
348
+ cn_strength,
349
+ cn_mode,
350
+ cn_image,
351
+ ip_enable,
352
+ ip_scale,
353
+ ip_image
354
  ],
355
  outputs=[result],
356
  )