File size: 2,714 Bytes
7a6754c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torch
import os

from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    LTXPipeline,
    AutoencoderKL,
)

from hidiffusion import apply_hidiffusion

from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from image_gen_aux import UpscaleWithModel

BASE_MODEL = "stabilityai/sdxl-turbo"  # Hugging Face repo id of the SDXL-Turbo base model
VIDEO_MODEL = "Lightricks/LTX-Video"  # Hugging Face repo id of the LTX text-to-video model
device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when one is present

class ModelHandler:
    """Loads and holds all models used by this service.

    Instantiating the class eagerly loads the SDXL img2img pipeline, the
    MediaPipe selfie segmenter, and the two upscale models (see
    ``load_models``).  Loaded objects are exposed as plain attributes.
    """

    def __init__(self):
        # Attribute set declared up front; load_models() fills most of them.
        self.base_pipe = None
        self.video_pipe = None
        # NOTE(review): compiled_model and video_pipe are never assigned in
        # this file — presumably loaded lazily or elsewhere; confirm.
        self.compiled_model = None
        self.segmenter = None
        self.upscaler = None
        self.upscaler4SD = None
        self.load_models()

    def load_base(self):
        """Build and return the SDXL-Turbo img2img pipeline in fp16.

        Returns:
            A ``StableDiffusionXLImg2ImgPipeline`` on ``device`` with a
            DDPM scheduler and HiDiffusion applied.
        """
        # Swap in the fp16-safe VAE (the repo name suggests the stock SDXL
        # VAE misbehaves in fp16 — see madebyollin/sdxl-vae-fp16-fix).
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

        base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
        # Replace the default scheduler with the model's own DDPM config.
        base_pipe.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        # Patch the pipeline for higher-resolution generation.
        apply_hidiffusion(base_pipe)

        return base_pipe

    def load_video_pipe(self):
        """Return the LTX text-to-video pipeline in bf16 on ``device``.

        NOTE(review): not called by load_models() — callers appear to be
        expected to invoke this themselves; confirm intended laziness.
        """
        pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
        pipe.to(device)
        return pipe

    def load_segmenter(self):
        """Return a MediaPipe multiclass selfie segmenter.

        Loads the .tflite checkpoint from a local ``checkpoints/`` path and
        enables the category mask output.
        """
        segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
        base_options = python.BaseOptions(model_asset_path=segment_model)
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True,
        )
        segmenter = vision.ImageSegmenter.create_from_options(options)
        return segmenter

    def _load_upscale_model(self, env_var, default_model):
        """Load an upscale model onto ``device``.

        Args:
            env_var: Environment variable that may override the model id.
            default_model: Hugging Face repo id used when the variable is unset.

        Returns:
            An ``UpscaleWithModel`` instance moved to ``device``.
        """
        model_name = os.environ.get(env_var, default_model)
        return UpscaleWithModel.from_pretrained(model_name).to(device)

    def load_upscaler(self):
        """Return the general-purpose 4x upscaler (overridable via UPSCALE_MODEL)."""
        return self._load_upscale_model("UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR")

    def load_upscaler4SD(self):
        """Return the 1x de-JPEG model (overridable via UPSCALE_FOR_SD_MODEL).

        Presumably used to clean inputs before SD img2img — confirm with callers.
        """
        return self._load_upscale_model("UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf")

    def load_models(self):
        """Eagerly load the base pipeline, segmenter, and both upscalers.

        Intentionally leaves ``video_pipe`` untouched (see load_video_pipe).
        """
        base_pipe = self.load_base()
        segmenter = self.load_segmenter()
        upscaler = self.load_upscaler()
        upscaler4SD = self.load_upscaler4SD()

        self.base_pipe = base_pipe
        self.segmenter = segmenter
        self.upscaler = upscaler
        self.upscaler4SD = upscaler4SD

MODELS = ModelHandler()