File size: 1,682 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from typing import Optional, Dict, Any
import onnxruntime as ort
import optimum.onnxruntime
from modules.onnx_impl.pipelines import CallablePipelineBase
from modules.onnx_impl.pipelines.utils import prepare_latents


class OnnxStableDiffusionXLPipeline(CallablePipelineBase, optimum.onnxruntime.ORTStableDiffusionXLPipeline):
    __module__ = 'optimum.onnxruntime.modeling_diffusion'
    __name__ = 'ORTStableDiffusionXLPipeline'

    def __init__(
        self,
        vae_decoder: ort.InferenceSession,
        text_encoder: ort.InferenceSession,
        unet: ort.InferenceSession,
        config: Dict[str, Any],
        tokenizer: Any,
        scheduler: Any,
        feature_extractor: Any = None,
        vae_encoder: Optional[ort.InferenceSession] = None,
        text_encoder_2: Optional[ort.InferenceSession] = None,
        tokenizer_2: Any = None,
        use_io_binding: Optional[bool] = None,
        model_save_dir = None,
        add_watermarker: Optional[bool] = None
    ):
        optimum.onnxruntime.ORTStableDiffusionXLPipeline.__init__(self, vae_decoder, text_encoder, unet, config, tokenizer, scheduler, feature_extractor, vae_encoder, text_encoder_2, tokenizer_2, use_io_binding, model_save_dir, add_watermarker)
        super().__init__()
        del self.image_processor # This image processor requires np array. In order to share same workflow with non-XL pipelines, delete it.

    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, generator, latents=None):
        return prepare_latents(self.scheduler.init_noise_sigma, batch_size, height, width, dtype, generator, latents, num_channels_latents, self.vae_scale_factor)