File size: 2,714 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from typing import Union
import time
import diffusers.utils
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from modules.shared import log, opts
from modules.control.units import detect
from modules import sd_models


# Label interpolated into the log messages emitted by this module (see the f-strings in ReferencePipeline).
what = 'Reference'


def list_models():
    """Return a new list holding the single model name 'Reference'."""
    models = ['Reference']
    return models


class ReferencePipeline():
    """Wrapper that builds a reference pipeline out of an already loaded SD 1.5 / SDXL pipeline.

    Attributes:
        orig_pipeline: the pipeline object passed to the constructor; handed back by `restore`.
        pipeline: the pipeline instantiated from the dynamically loaded
            `stable_diffusion_reference` / `stable_diffusion_xl_reference` class, or None when
            no source pipeline was given or its model type is neither SDXL nor SD 1.5.
    """

    def __init__(self, pipeline: Union[StableDiffusionXLPipeline, StableDiffusionPipeline], dtype = None):
        """Create the reference pipeline from the components of `pipeline`.

        Args:
            pipeline: source pipeline whose modules (vae, text encoder(s), tokenizer(s), unet,
                scheduler, optional feature extractor) are reused. If None, an error is logged
                and `self.pipeline` stays None.
            dtype: when not None, the new pipeline is passed through `.to(dtype)`.
        """
        t0 = time.time()
        self.orig_pipeline = pipeline
        self.pipeline = None
        if pipeline is None:
            log.error(f'Control {what} model pipeline: model not loaded')
            return

        # Undo fused qkv projections on the source pipeline before its modules are reused.
        if opts.diffusers_fuse_projections and hasattr(pipeline, 'unfuse_qkv_projections'):
            pipeline.unfuse_qkv_projections()

        # Pick the pipeline class first, then collect the constructor arguments for that model type.
        if detect.is_sdxl(pipeline):
            reference_cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_xl_reference', module_file='pipeline.py')
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'text_encoder_2': pipeline.text_encoder_2,
                'tokenizer': pipeline.tokenizer,
                'tokenizer_2': pipeline.tokenizer_2,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
            }
        elif detect.is_sd15(pipeline):
            reference_cls = diffusers.utils.get_class_from_dynamic_module('stable_diffusion_reference', module_file='pipeline.py')
            components = {
                'vae': pipeline.vae,
                'text_encoder': pipeline.text_encoder,
                'tokenizer': pipeline.tokenizer,
                'unet': pipeline.unet,
                'scheduler': pipeline.scheduler,
                'feature_extractor': getattr(pipeline, 'feature_extractor', None),
                'requires_safety_checker': False,
                'safety_checker': None,
            }
        else:
            log.error(f'Control {what} pipeline: class={pipeline.__class__.__name__} unsupported model type')
            return

        # Shared tail of both supported branches: instantiate, then place on the source pipeline's device.
        self.pipeline = reference_cls(**components)
        sd_models.move_model(self.pipeline, pipeline.device)

        if self.pipeline is not None and dtype is not None:
            self.pipeline = self.pipeline.to(dtype)

        t1 = time.time()
        if self.pipeline is None:
            log.error(f'Control {what} pipeline: not initialized')
        else:
            log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}')

    def restore(self):
        """Drop the reference pipeline and return the pipeline originally given to the constructor."""
        original = self.orig_pipeline
        self.pipeline = None
        return original