svjack committed
Commit 7329ede
1 Parent(s): 6039b8b

Upload with huggingface_hub

Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ # Silence diffusers deprecation warnings triggered by the patched attention modules.
+ from diffusers import utils
+ from diffusers.utils import deprecation_utils
+ from diffusers.models import cross_attention
+ utils.deprecate = lambda *args, **kwargs: None
+ deprecation_utils.deprecate = lambda *args, **kwargs: None
+ cross_attention.deprecate = lambda *args, **kwargs: None
+
+ import os
+ import sys
+ # Disabled path setup (kept from a parent-directory project layout).
+ '''
+ MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
+ sys.path.insert(0, MAIN_DIR)
+ os.chdir(MAIN_DIR)
+ '''
+
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+
+ from annotator.util import resize_image, HWC3
+ from annotator.canny import CannyDetector
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers.pipelines import DiffusionPipeline
+ from diffusers.schedulers import DPMSolverMultistepScheduler
+ from models import ControlLoRA, ControlLoRACrossAttnProcessor
+
+ apply_canny = CannyDetector()
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Taiyi is a Stable Diffusion checkpoint trained on Chinese prompts;
+ # DPM-Solver++ multistep gives good samples in ~20 steps (the UI default).
+ pipeline = DiffusionPipeline.from_pretrained(
+     'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None
+ )
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+ pipeline = pipeline.to(device)
+ unet: UNet2DConditionModel = pipeline.unet
+
+ # ckpt_path = "ckpts/sd-diffusiondb-canny-model-control-lora-zh"
+ ckpt_path = "svjack/canny-control-lora-zh"
+ control_lora = ControlLoRA.from_pretrained(ckpt_path)
+ control_lora = control_lora.to(device)
+
+ # Load the ControlLoRA attention processors. Each processor name is mapped to
+ # a control level: down_blocks.<i> uses control_ids[i], up_blocks.<i> walks the
+ # levels in reverse, and mid_block takes the deepest level. The LoRA layers of
+ # the selected level are then popped in traversal order.
+ lora_attn_procs = {}
+ lora_layers_list = [list(layer_list) for layer_list in control_lora.lora_layers]
+ n_ch = len(unet.config.block_out_channels)
+ control_ids = list(range(n_ch))
+ for name in pipeline.unet.attn_processors.keys():
+     cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+     if name.startswith("mid_block"):
+         control_id = control_ids[-1]
+     elif name.startswith("up_blocks"):
+         block_id = int(name[len("up_blocks.")])
+         control_id = list(reversed(control_ids))[block_id]
+     elif name.startswith("down_blocks"):
+         block_id = int(name[len("down_blocks.")])
+         control_id = control_ids[block_id]
+
+     lora_layers = lora_layers_list[control_id]
+     if len(lora_layers) != 0:
+         lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
+         lora_attn_procs[name] = lora_layer
+
+ unet.set_attn_processor(lora_attn_procs)
+
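+ # Calling control_lora(control) appears to cache .control_states inside the
+ # ControlLoRACrossAttnProcessor layers; process() below invokes it before each
+ # pipeline call so generation is conditioned on the current edge map.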
+
+ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
+     with torch.no_grad():
+         img = resize_image(HWC3(input_image), image_resolution)
+         H, W, C = img.shape
+
+         detected_map = apply_canny(img, low_threshold, high_threshold)
+         detected_map = HWC3(detected_map)
+
+         # BGR -> RGB, HWC -> CHW, add a batch axis, and rescale to [-1, 1].
+         control = torch.from_numpy(detected_map[..., ::-1].copy().transpose([2, 0, 1])).float().to(device)[None] / 127.5 - 1
+         _ = control_lora(control).control_states
+
+         if seed == -1:
+             seed = random.randint(0, 65535)
+
+         # Run inference, refreshing the control states before each sample.
+         generator = torch.Generator(device=device).manual_seed(seed)
+         images = []
+         for i in range(num_samples):
+             _ = control_lora(control).control_states
+             image = pipeline(
+                 prompt + ', ' + a_prompt, negative_prompt=n_prompt,
+                 num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,
+                 generator=generator, height=H, width=W).images[0]
+             images.append(np.asarray(image))
+
+         results = images
+         return [255 - detected_map] + results
+
+
+ block = gr.Blocks().queue()
+ with block:
+     with gr.Row():
+         gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(source='upload', type="numpy")
+             prompt = gr.Textbox(label="Prompt")
+             run_button = gr.Button(label="Run")
+             with gr.Accordion("Advanced options", open=False):
+                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
+                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
+                 low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
+                 high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
+                 sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
+                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
+                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+                 eta = gr.Number(label="eta", value=0.0)
+                 # "Detailed simulated mixed-media collage, contemporary art style with a
+                 # canvas texture, punk art, photorealism, sensual body, expressionism,
+                 # minimalism. Masterpiece, perfect composition, realistic beautiful face."
+                 a_prompt = gr.Textbox(label="Added Prompt", value='详细的模拟混合媒体拼贴画,帆布质地的当代艺术风格,朋克艺术,逼真主义,感性的身体,表现主义,极简主义。杰作,完美的组成,逼真的美丽的脸')
+                 # "Low quality, blurry, messy."
+                 n_prompt = gr.Textbox(label="Negative Prompt", value='低质量,模糊,混乱')
+         with gr.Column():
+             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
+
+ block.launch(server_name='0.0.0.0')
+
+ # block.launch(server_name='172.16.202.228', share=True)
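
For quick verification without the UI, process() can also be called directly in the same Python session as the definitions above (importing app.py as a module would trigger block.launch()). A minimal sketch; the image path and prompts are placeholders:

import numpy as np
from PIL import Image

img = np.asarray(Image.open('example.png').convert('RGB'))  # placeholder input image
outputs = process(
    img,
    prompt='一只戴着草帽的猫',   # "a cat wearing a straw hat" (placeholder)
    a_prompt='杰作',             # "masterpiece"
    n_prompt='低质量,模糊',      # "low quality, blurry"
    num_samples=1, image_resolution=512, sample_steps=20,
    scale=9.0, seed=42, eta=0.0, low_threshold=100, high_threshold=200,
)
# outputs[0] is the inverted Canny edge map; outputs[1:] are the generated images.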