svjack committed on
Commit e1f5713 · 1 Parent(s): 7329ede

Delete gradio_canny2image_zh.py

Files changed (1)
  1. gradio_canny2image_zh.py +0 -126
gradio_canny2image_zh.py DELETED
@@ -1,126 +0,0 @@
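- # Monkeypatch: turn diffusers' deprecate() hooks into no-ops so the pinned
- # (older) diffusers version does not emit deprecation warnings at import time.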
- from diffusers import utils
- from diffusers.utils import deprecation_utils
- from diffusers.models import cross_attention
- utils.deprecate = lambda *arg, **kwargs: None
- deprecation_utils.deprecate = lambda *arg, **kwargs: None
- cross_attention.deprecate = lambda *arg, **kwargs: None
-
- import os
- import sys
- '''
- MAIN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
- sys.path.insert(0, MAIN_DIR)
- os.chdir(MAIN_DIR)
- '''
-
- import gradio as gr
- import numpy as np
- import torch
- import random
-
- from annotator.util import resize_image, HWC3
- from annotator.canny import CannyDetector
- from diffusers.models.unet_2d_condition import UNet2DConditionModel
- from diffusers.pipelines import DiffusionPipeline
- from diffusers.schedulers import DPMSolverMultistepScheduler
- from models import ControlLoRA, ControlLoRACrossAttnProcessor
-
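- # NOTE: `annotator` and `models` above are local modules shipped with this
- # repo (ControlNet's Canny annotator utilities and the ControlLoRA
- # implementation), not pip-installable packages.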
- apply_canny = CannyDetector()
-
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- pipeline = DiffusionPipeline.from_pretrained(
-     'IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1', safety_checker=None
- )
- pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
- pipeline = pipeline.to(device)
- unet: UNet2DConditionModel = pipeline.unet
-
- # ckpt_path = "ckpts/sd-diffusiondb-canny-model-control-lora-zh"
- ckpt_path = "svjack/canny-control-lora-zh"
- control_lora = ControlLoRA.from_pretrained(ckpt_path)
- control_lora = control_lora.to(device)
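- # ControlLoRA exposes one list of LoRA attention processors per UNet
- # resolution level (`control_lora.lora_layers`); the loop below hands them
- # out, in order, to the matching down/mid/up attention modules.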
-
- # load control lora attention processors
- lora_attn_procs = {}
- lora_layers_list = [list(layer_list) for layer_list in control_lora.lora_layers]
- n_ch = len(unet.config.block_out_channels)
- control_ids = list(range(n_ch))
- for name in pipeline.unet.attn_processors.keys():
-     cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-     if name.startswith("mid_block"):
-         control_id = control_ids[-1]
-     elif name.startswith("up_blocks"):
-         block_id = int(name[len("up_blocks.")])
-         control_id = list(reversed(control_ids))[block_id]
-     elif name.startswith("down_blocks"):
-         block_id = int(name[len("down_blocks.")])
-         control_id = control_ids[block_id]
-
-     lora_layers = lora_layers_list[control_id]
-     if len(lora_layers) != 0:
-         lora_layer: ControlLoRACrossAttnProcessor = lora_layers.pop(0)
-         lora_attn_procs[name] = lora_layer
-
- unet.set_attn_processor(lora_attn_procs)
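- # From here on, each attention module in the UNet runs through its assigned
- # ControlLoRACrossAttnProcessor, which (per the ControlLoRA design) adds a
- # low-rank, control-conditioned update on top of the frozen attention weights.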
-
-
- def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold):
-     with torch.no_grad():
-         img = resize_image(HWC3(input_image), image_resolution)
-         H, W, C = img.shape
-
-         detected_map = apply_canny(img, low_threshold, high_threshold)
-         detected_map = HWC3(detected_map)
-
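-         # The edge map is converted to a CHW float tensor in [-1, 1] with the
-         # channel order reversed; since the return value is discarded, the
-         # control_lora(control) call appears to exist only to cache the
-         # control states that the attention processors consume during denoising.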
-         control = torch.from_numpy(detected_map[..., ::-1].copy().transpose([2, 0, 1])).float().to(device)[None] / 127.5 - 1
-         _ = control_lora(control).control_states
-
-         if seed == -1:
-             seed = random.randint(0, 65535)
-
-         # run inference
-         generator = torch.Generator(device=device).manual_seed(seed)
-         images = []
-         for i in range(num_samples):
-             _ = control_lora(control).control_states
-             image = pipeline(
-                 prompt + ', ' + a_prompt, negative_prompt=n_prompt,
-                 num_inference_steps=sample_steps, guidance_scale=scale, eta=eta,
-                 generator=generator, height=H, width=W).images[0]
-             images.append(np.asarray(image))
-
-         results = images
-         return [255 - detected_map] + results
-
-
- block = gr.Blocks().queue()
- with block:
-     with gr.Row():
-         gr.Markdown("## Control Stable Diffusion with Canny Edge Maps")
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(source='upload', type="numpy")
-             prompt = gr.Textbox(label="Prompt")
-             run_button = gr.Button("Run")
-             with gr.Accordion("Advanced options", open=False):
-                 num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
-                 image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=256)
-                 low_threshold = gr.Slider(label="Canny low threshold", minimum=1, maximum=255, value=100, step=1)
-                 high_threshold = gr.Slider(label="Canny high threshold", minimum=1, maximum=255, value=200, step=1)
-                 sample_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
-                 scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=9.0, step=0.1)
-                 seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
-                 eta = gr.Number(label="eta", value=0.0)
-                 # "Detailed analog mixed-media collage, contemporary art style with canvas
-                 # texture, punk art, photorealism, sensual body, expressionism, minimalism.
-                 # Masterpiece, perfect composition, photorealistic beautiful face."
-                 a_prompt = gr.Textbox(label="Added Prompt", value='详细的模拟混合媒体拼贴画,帆布质地的当代艺术风格,朋克艺术,逼真主义,感性的身体,表现主义,极简主义。杰作,完美的组成,逼真的美丽的脸')
-                 # "Low quality, blurry, messy."
-                 n_prompt = gr.Textbox(label="Negative Prompt", value='低质量,模糊,混乱')
-         with gr.Column():
-             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
-     ips = [input_image, prompt, a_prompt, n_prompt, num_samples, image_resolution, sample_steps, scale, seed, eta, low_threshold, high_threshold]
-     run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
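-     # Gradio passes `inputs` positionally, so `ips` must stay in the same
-     # order as the parameters of process().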
-
-
- block.launch(server_name='0.0.0.0')
-
- # block.launch(server_name='172.16.202.228', share=True)