File size: 5,323 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import diffusers
import gradio as gr
from modules import scripts, processing, shared, devices, sd_models


class Script(scripts.Script):
    """LEdits++ semantic image editing (https://leditsplusplus-project.static.hf.space).

    Inverts the init image using the source prompt, then regenerates it while
    steering denoising with up to two edit prompts: a positive one (added) and
    a negative one (removed). Diffusers backend only; SD and SD-XL models only.
    """

    def title(self):
        return 'LEdits++'

    def show(self, is_img2img):
        # only meaningful in img2img mode, and only on the diffusers backend
        return is_img2img if shared.backend == shared.Backend.DIFFUSERS else False

    # return signature is array of gradio components
    def ui(self, _is_img2img):
        with gr.Row():
            gr.HTML('<a href="https://leditsplusplus-project.static.hf.space/index.html">&nbsp LEdits++</a><br>')
        with gr.Row():
            # warmup/cooldown window for the edit guidance, as fractions of total steps
            edit_start = gr.Slider(label='Edit start', minimum=0.0, maximum=1.0, step=0.01, value=0.1)
            edit_stop = gr.Slider(label='Edit stop', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
            intersect_mask = gr.Checkbox(label='Smooth mask', value=True)
        with gr.Row():
            # concept to add
            prompt1 = gr.Textbox(show_label=False, placeholder='Positive prompt')
            scale1 = gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=0.5)
            threshold1 = gr.Slider(label='Threshold', minimum=0.0, maximum=1.0, step=0.01, value=0.9)
        with gr.Row():
            # concept to remove
            prompt2 = gr.Textbox(show_label=False, placeholder='Negative prompt')
            scale2 = gr.Slider(label='Scale', minimum=0.0, maximum=1.0, step=0.01, value=0.5)
            threshold2 = gr.Slider(label='Threshold', minimum=0.0, maximum=1.0, step=0.01, value=0.9)
        return [edit_start, edit_stop, intersect_mask, prompt1, scale1, threshold1, prompt2, scale2, threshold2]

    def run(self, p: processing.StableDiffusionProcessing, edit_start, edit_stop, intersect_mask, prompt1, scale1, threshold1, prompt2, scale2, threshold2): # pylint: disable=arguments-differ, unused-argument
        """Run LEdits++ inversion and edited regeneration.

        Returns the Processed result, or None on invalid input (no prompts,
        no init image, or unsupported model type).
        """
        image = getattr(p, 'init_images', None)
        if not prompt1 and not prompt2:
            shared.log.error('LEdits: no prompts')
            return None
        if not image:
            shared.log.error('LEdits: no init_images')
            return None
        if shared.sd_model_type not in ('sd', 'sdxl'):
            shared.log.error(f'LEdits: invalid model type: {shared.sd_model_type}')
            return None

        # snapshot global state before mutating it; restored in the finally
        # block below so a failure mid-run cannot leak the swapped pipeline
        # or the overridden options into subsequent generations
        orig_pipeline = shared.sd_model
        orig_offload = shared.opts.diffusers_model_cpu_offload
        orig_prompt_attention = shared.opts.prompt_attention
        try:
            shared.opts.data['diffusers_model_cpu_offload'] = False # ledits is not compatible with offloading
            shared.opts.data['prompt_attention'] = 'Fixed attention'
            sd_models.move_model(shared.sd_model, devices.device, force=True)
            if shared.sd_model_type == 'sd':
                shared.sd_model = sd_models.switch_pipe(diffusers.LEditsPPPipelineStableDiffusion, shared.sd_model)
            elif shared.sd_model_type == 'sdxl':
                shared.sd_model = sd_models.switch_pipe(diffusers.LEditsPPPipelineStableDiffusionXL, shared.sd_model)
            if str(devices.dtype) == 'torch.float16':
                shared.sd_model.vae.config.force_upcast = False # not compatible

            shared.sd_model.scheduler = diffusers.DPMSolverMultistepScheduler.from_config(shared.sd_model.scheduler.config, algorithm_type="sde-dpmsolver++", solver_order=2) # ledits is very picky
            p.sampler_name = 'Default'
            p.init() # run init early to take care of resizing

            # phase 1: invert the init image so edits can be applied during resampling
            invert_args = {
                'image': p.init_images[0],
                'source_prompt': p.prompt,
                'source_guidance_scale': p.cfg_scale,
                'num_inversion_steps': p.steps,
                'skip': 1.0 - p.denoising_strength, # invert start
                'generator': None, # not supported
            }
            shared.log.info(f'LEdits invert: {invert_args}')
            _output = shared.sd_model.invert(**invert_args)

            # phase 2: regenerate with per-concept edit guidance
            p.task_args = {
                'editing_prompt': [],
                'reverse_editing_direction': [],
                'edit_guidance_scale': [],
                'edit_threshold': [],
                'edit_warmup_steps': int(edit_start * p.steps),
                'edit_cooldown_steps': int((1.0 - edit_stop) * p.steps) if edit_stop < 1.0 else None,
                'use_intersect_mask': intersect_mask, # smoothing?
                'generator': None,
                'guidance_rescale': 0.0, # bug in pipeline if guidance rescale is enabled
            }
            if prompt1:
                p.task_args['editing_prompt'].append(prompt1)
                p.task_args['reverse_editing_direction'].append(False)
                p.task_args['edit_guidance_scale'].append(10.0 * scale1) # slider 0..1 maps to guidance 0..10
                p.task_args['edit_threshold'].append(threshold1)
            if prompt2:
                p.task_args['editing_prompt'].append(prompt2)
                p.task_args['reverse_editing_direction'].append(True) # negative prompt: reversed direction
                p.task_args['edit_guidance_scale'].append(10.0 * scale2)
                p.task_args['edit_threshold'].append(threshold2)

            shared.log.info(f'LEdits: {p.task_args}')
            processed = processing.process_images(p)
        finally:
            # restore pipeline and options regardless of success or failure
            shared.sd_model = orig_pipeline
            shared.opts.data['prompt_attention'] = orig_prompt_attention
            shared.opts.data['diffusers_model_cpu_offload'] = orig_offload
        return processed