File size: 4,458 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline
# https://github.com/huggingface/diffusers/blob/main/examples/community/regional_prompting_stable_diffusion.py

import gradio as gr
from diffusers.pipelines import pipeline_utils
from modules import shared, devices, scripts, processing, sd_models, prompt_parser_diffusers


def hijack_register_modules(self, **kwargs):
    for name, module in kwargs.items():
        if module is None or isinstance(module, (tuple, list)) and module[0] is None:
            register_dict = {name: (None, None)}
        elif isinstance(module, bool):
            pass
        else:
            library, class_name = pipeline_utils._fetch_class_library_tuple(module) # pylint: disable=protected-access
            register_dict = {name: (library, class_name)}
        self.register_to_config(**register_dict)
        setattr(self, name, module)


class Script(scripts.Script):
    def title(self):
        return 'Regional prompting'

    def show(self, is_img2img):
        return not is_img2img if shared.backend == shared.Backend.DIFFUSERS else False

    def change(self, mode):
        return [gr.update(visible='Col' in mode or 'Row' in mode), gr.update(visible='Prompt' in mode)]

    def ui(self, _is_img2img):
        with gr.Row():
            gr.HTML('<a href="https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#regional-prompting-pipeline">&nbsp Regional prompting</a><br>')
        with gr.Row():
            mode = gr.Radio(label='Mode', choices=['None', 'Prompt', 'Prompt EX', 'Columns', 'Rows'], value='None')
        with gr.Row():
            power = gr.Slider(label='Power', minimum=0, maximum=1, value=1.0, step=0.01)
            threshold = gr.Textbox('', label='Prompt thresholds:', default='', visible=False)
            grid = gr.Text('', label='Grid sections:', default='', visible=False)
        mode.change(fn=self.change, inputs=[mode], outputs=[grid, threshold])
        return mode, grid, power, threshold

    def run(self, p: processing.StableDiffusionProcessing, mode, grid, power, threshold): # pylint: disable=arguments-differ
        if mode is None or mode == 'None':
            return
        # backup pipeline and params
        orig_pipeline = shared.sd_model
        orig_dtype = devices.dtype
        orig_prompt_attention = shared.opts.prompt_attention
        # create pipeline
        if shared.sd_model_type != 'sd':
            shared.log.error(f'Regional prompting: incorrect base model: {shared.sd_model.__class__.__name__}')
            return

        pipeline_utils.DiffusionPipeline.register_modules = hijack_register_modules
        prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.orig_encode_token_ids_to_embeddings # pylint: disable=protected-access

        shared.sd_model = sd_models.switch_pipe('regional_prompting_stable_diffusion', shared.sd_model)
        if shared.sd_model.__class__.__name__ != 'RegionalPromptingStableDiffusionPipeline': # switch failed
            shared.log.error(f'Regional prompting: not a tiling pipeline: {shared.sd_model.__class__.__name__}')
            shared.sd_model = orig_pipeline
            return
        sd_models.set_diffuser_options(shared.sd_model)
        shared.opts.data['prompt_attention'] = 'Fixed attention' # this pipeline is not compatible with embeds
        processing.fix_seed(p)
        # set pipeline specific params, note that standard params are applied when applicable
        rp_args = {
            'mode': mode.lower(),
            'power': power,
        }
        if 'prompt' in mode.lower():
            rp_args['th'] = threshold
        else:
            rp_args['div'] = grid
        p.task_args = {
            **p.task_args,
            'prompt': p.prompt,
            'rp_args': rp_args,
        }
        # run pipeline
        shared.log.debug(f'Regional: args={p.task_args}')
        processed: processing.Processed = processing.process_images(p) # runs processing using main loop

        # restore pipeline and params
        prompt_parser_diffusers.EmbeddingsProvider._encode_token_ids_to_embeddings = prompt_parser_diffusers.compel_hijack # pylint: disable=protected-access
        shared.opts.data['prompt_attention'] = orig_prompt_attention
        shared.sd_model = orig_pipeline
        shared.sd_model.to(orig_dtype)
        return processed