Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
import torch
|
@@ -17,15 +19,11 @@ def get_lora_sd_pipeline(
|
|
17 |
base_model_name_or_path=None,
|
18 |
dtype=torch.float16,
|
19 |
adapter_name="default"
|
20 |
-
):
|
|
|
21 |
unet_sub_dir = os.path.join(lora_dir, "unet")
|
22 |
text_encoder_sub_dir = os.path.join(lora_dir, "text_encoder")
|
23 |
|
24 |
-
# Проверка существования директорий LoRA
|
25 |
-
print(f"LoRA directory exists: {os.path.exists(lora_dir)}")
|
26 |
-
print(f"UNet LoRA exists: {os.path.exists(unet_sub_dir)}")
|
27 |
-
print(f"Text encoder LoRA exists: {os.path.exists(text_encoder_sub_dir)}")
|
28 |
-
|
29 |
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
30 |
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
31 |
base_model_name_or_path = config.base_model_name_or_path
|
@@ -34,30 +32,14 @@ def get_lora_sd_pipeline(
|
|
34 |
raise ValueError("Укажите название базовой модели или путь к ней")
|
35 |
|
36 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
# Логирование параметров до применения LoRA
|
39 |
-
before_params = list(pipe.unet.parameters())
|
40 |
-
|
41 |
-
# Применение LoRA к UNet
|
42 |
-
if os.path.exists(unet_sub_dir):
|
43 |
-
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
|
44 |
-
pipe.unet.set_adapter(adapter_name)
|
45 |
-
|
46 |
-
# Применение LoRA к текстовому энкодеру (если есть)
|
47 |
if os.path.exists(text_encoder_sub_dir):
|
48 |
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
|
49 |
|
50 |
-
# Логирование параметров после применения LoRA
|
51 |
-
after_params = list(pipe.unet.parameters())
|
52 |
-
print(f"Parameters changed: {before_params != after_params}")
|
53 |
-
|
54 |
-
# Детальное сравнение параметров
|
55 |
-
for i, (param1, param2) in enumerate(zip(before_params, after_params)):
|
56 |
-
if not torch.equal(param1, param2):
|
57 |
-
print(f"Parameter {i} changed.")
|
58 |
-
else:
|
59 |
-
print(f"Parameter {i} did not change.")
|
60 |
-
|
61 |
if dtype in (torch.float16, torch.bfloat16):
|
62 |
pipe.unet.half()
|
63 |
pipe.text_encoder.half()
|
@@ -89,8 +71,8 @@ def infer(
|
|
89 |
guidance_scale=7.5,
|
90 |
lora_scale=0.5,
|
91 |
progress=gr.Progress(track_tqdm=True)
|
92 |
-
):
|
93 |
-
|
94 |
generator = torch.Generator(device).manual_seed(seed)
|
95 |
|
96 |
if model != model_default:
|
@@ -103,13 +85,7 @@ def infer(
|
|
103 |
prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
|
104 |
negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
|
105 |
prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
|
106 |
-
|
107 |
-
# Логирование параметров до и после применения LoRA
|
108 |
-
before_params = list(pipe.unet.parameters())
|
109 |
-
print(f"Applying LoRA with scale: {lora_scale}")
|
110 |
-
pipe.fuse_lora(lora_scale=lora_scale)
|
111 |
-
after_params = list(pipe.unet.parameters())
|
112 |
-
print(f"Parameters changed: {before_params != after_params}")
|
113 |
|
114 |
params = {
|
115 |
'prompt_embeds': prompt_embeds,
|
@@ -222,6 +198,81 @@ with gr.Blocks(css=css) as demo:
|
|
222 |
value=512,
|
223 |
)
|
224 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
gr.Examples(examples=examples, inputs=[prompt])
|
226 |
gr.Examples(examples=examples_negative, inputs=[negative_prompt])
|
227 |
|
@@ -248,4 +299,4 @@ with gr.Blocks(css=css) as demo:
|
|
248 |
|
249 |
if __name__ == "__main__":
|
250 |
demo.launch()
|
251 |
-
|
|
|
1 |
+
# app.py 07.02.25
|
2 |
+
|
3 |
import gradio as gr
|
4 |
import numpy as np
|
5 |
import torch
|
|
|
19 |
base_model_name_or_path=None,
|
20 |
dtype=torch.float16,
|
21 |
adapter_name="default"
|
22 |
+
):
|
23 |
+
|
24 |
unet_sub_dir = os.path.join(lora_dir, "unet")
|
25 |
text_encoder_sub_dir = os.path.join(lora_dir, "text_encoder")
|
26 |
|
|
|
|
|
|
|
|
|
|
|
27 |
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
28 |
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
29 |
base_model_name_or_path = config.base_model_name_or_path
|
|
|
32 |
raise ValueError("Укажите название базовой модели или путь к ней")
|
33 |
|
34 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype)
|
35 |
+
before_params = pipe.unet.parameters()
|
36 |
+
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
|
37 |
+
pipe.unet.set_adapter(adapter_name)
|
38 |
+
after_params = pipe.unet.parameters()
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
if os.path.exists(text_encoder_sub_dir):
|
41 |
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name)
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
if dtype in (torch.float16, torch.bfloat16):
|
44 |
pipe.unet.half()
|
45 |
pipe.text_encoder.half()
|
|
|
71 |
guidance_scale=7.5,
|
72 |
lora_scale=0.5,
|
73 |
progress=gr.Progress(track_tqdm=True)
|
74 |
+
):
|
75 |
+
|
76 |
generator = torch.Generator(device).manual_seed(seed)
|
77 |
|
78 |
if model != model_default:
|
|
|
85 |
prompt_embeds = long_prompt_encoder(prompt, pipe.tokenizer, pipe.text_encoder)
|
86 |
negative_prompt_embeds = long_prompt_encoder(negative_prompt, pipe.tokenizer, pipe.text_encoder)
|
87 |
prompt_embeds, negative_prompt_embeds = align_embeddings(prompt_embeds, negative_prompt_embeds)
|
88 |
+
pipe.fuse_lora(lora_scale=lora_scale) # Коэфф. добавления lora
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
params = {
|
91 |
'prompt_embeds': prompt_embeds,
|
|
|
198 |
value=512,
|
199 |
)
|
200 |
|
201 |
+
# Функция для работы с ControlNet ---------------------------------------------------------------------
|
202 |
+
def process_input_ControlNet(image, use_control_net, control_strength, control_mode):
|
203 |
+
if use_control_net:
|
204 |
+
# Логика для обработки с использованием ControlNet
|
205 |
+
result = f"ControlNet активен! Режим: {control_mode}, Интенсивность: {control_strength}"
|
206 |
+
else:
|
207 |
+
# Логика для обработки без ControlNet
|
208 |
+
result = "ControlNet отключен."
|
209 |
+
return result
|
210 |
+
|
211 |
+
with gr.Blocks():
|
212 |
+
with gr.Row():
|
213 |
+
# Чекбокс для включения/отключения ControlNet
|
214 |
+
use_control_net = gr.Checkbox(
|
215 |
+
label="Use ControlNet",
|
216 |
+
value=False,
|
217 |
+
)
|
218 |
+
|
219 |
+
# Дополнительные опции для ControlNet
|
220 |
+
with gr.Column(visible=False) as control_net_options:
|
221 |
+
# Слайдер для настройки интенсивности
|
222 |
+
control_strength = gr.Slider(
|
223 |
+
label="Control Strength",
|
224 |
+
minimum=0.0,
|
225 |
+
maximum=1.0,
|
226 |
+
value=0.5,
|
227 |
+
step=0.05,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Выпадающий список для выбора режима
|
231 |
+
control_mode = gr.Dropdown(
|
232 |
+
label="Control Mode",
|
233 |
+
choices=[
|
234 |
+
"edge_detection",
|
235 |
+
"canny_edge_detection",
|
236 |
+
"pose_estimation",
|
237 |
+
"depth_map",
|
238 |
+
"segmentation_map",
|
239 |
+
"scribble_sketch",
|
240 |
+
"normal_map",
|
241 |
+
"hed_edge_detection",
|
242 |
+
"openpose",
|
243 |
+
"mlsd_line_detection",
|
244 |
+
"scribble_diffusion",
|
245 |
+
"semantic_segmentation",
|
246 |
+
"style_transfer",
|
247 |
+
"colorization",
|
248 |
+
"custom_map"
|
249 |
+
],
|
250 |
+
value="pose_estimation",
|
251 |
+
)
|
252 |
+
|
253 |
+
# Окно для загрузки изображений
|
254 |
+
control_image = gr.Image(label="Upload Control Image")
|
255 |
+
|
256 |
+
# Кнопка для запуска работы ControlNet
|
257 |
+
run_button = gr.Button("Run")
|
258 |
+
|
259 |
+
# Текстовое поле для вывода результата
|
260 |
+
output = gr.Textbox(label="Output")
|
261 |
+
|
262 |
+
# Логика для отображения/скрытия дополнительных опций
|
263 |
+
use_control_net.change(
|
264 |
+
fn=lambda x: gr.Row.update(visible=x),
|
265 |
+
inputs=use_control_net,
|
266 |
+
outputs=control_net_options,
|
267 |
+
)
|
268 |
+
|
269 |
+
# Привязка кнопки Run к функции работы с ControlNet
|
270 |
+
run_button.click(
|
271 |
+
fn=process_input_ControlNet,
|
272 |
+
inputs=[control_image, use_control_net, control_strength, control_mode],
|
273 |
+
outputs=output,
|
274 |
+
)
|
275 |
+
|
276 |
gr.Examples(examples=examples, inputs=[prompt])
|
277 |
gr.Examples(examples=examples_negative, inputs=[negative_prompt])
|
278 |
|
|
|
299 |
|
300 |
if __name__ == "__main__":
|
301 |
demo.launch()
|
302 |
+
|