# turbo_fe / model_handler.py
# Author: Sqxww — initial commit (7a6754c)
import torch
import os
from diffusers import (
DDPMScheduler,
StableDiffusionXLImg2ImgPipeline,
LTXPipeline,
AutoencoderKL,
)
from hidiffusion import apply_hidiffusion
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from image_gen_aux import UpscaleWithModel
# Hugging Face Hub repository ids of the image and video base models.
BASE_MODEL = "stabilityai/sdxl-turbo"
VIDEO_MODEL = "Lightricks/LTX-Video"

# Device every model is moved to: the GPU when CUDA is usable, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
class ModelHandler:
    """Load and hold the models used by the app.

    Construction is eager: ``__init__`` calls :meth:`load_models`, which builds
    the SDXL-Turbo img2img pipeline, the MediaPipe selfie segmenter and two
    upscalers, all placed on the module-level ``device``.

    Attributes:
        base_pipe: fp16 ``StableDiffusionXLImg2ImgPipeline`` for ``BASE_MODEL``
            with HiDiffusion applied; ``None`` until loaded.
        video_pipe: Slot for the LTX-Video pipeline. ``load_models`` does not
            fill it; it stays ``None`` unless a caller assigns the result of
            :meth:`load_video_pipe`.
        compiled_model: Placeholder; nothing in this class sets it.
        segmenter: MediaPipe ``ImageSegmenter`` configured to emit a category
            mask.
        upscaler: ``UpscaleWithModel`` named by the ``UPSCALE_MODEL`` env var
            (default ``Phips/4xNomosWebPhoto_RealPLKSR``).
        upscaler4SD: ``UpscaleWithModel`` named by ``UPSCALE_FOR_SD_MODEL``
            (default ``Phips/1xDeJPG_realplksr_otf``).
    """

    def __init__(self):
        self.base_pipe = None
        self.video_pipe = None
        self.compiled_model = None
        self.segmenter = None
        self.upscaler = None
        self.upscaler4SD = None
        self.load_models()

    def load_base(self):
        """Build the fp16 SDXL-Turbo img2img pipeline on ``device``.

        Uses the fp16-safe SDXL VAE, replaces the scheduler with a
        ``DDPMScheduler`` built from ``BASE_MODEL``'s scheduler config, and
        applies HiDiffusion to the pipeline in place.

        Returns:
            The configured ``StableDiffusionXLImg2ImgPipeline``.
        """
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )
        base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # silence_dtype_warnings suppresses diffusers' dtype/device
        # compatibility warning (e.g. an fp16 pipeline moved to CPU when CUDA
        # is unavailable).
        base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
        base_pipe.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        apply_hidiffusion(base_pipe)
        return base_pipe

    def load_video_pipe(self):
        """Build the bfloat16 LTX-Video pipeline on ``device``.

        Not called by :meth:`load_models`; callers must invoke it explicitly.

        Returns:
            The ``LTXPipeline`` for ``VIDEO_MODEL``.
        """
        pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
        pipe = pipe.to(device)
        return pipe

    def load_segmenter(self):
        """Create the MediaPipe multiclass selfie segmenter.

        The ``.tflite`` path is relative, so it is presumably resolved against
        the process's current working directory — run from the project root.

        Returns:
            A ``vision.ImageSegmenter`` with ``output_category_mask`` enabled.
        """
        segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
        base_options = python.BaseOptions(model_asset_path=segment_model)
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True,
        )
        return vision.ImageSegmenter.create_from_options(options)

    @staticmethod
    def _load_upscale_model(env_var, default_model):
        """Load the ``UpscaleWithModel`` named by *env_var* (else *default_model*) onto ``device``."""
        model_name = os.environ.get(env_var, default_model)
        return UpscaleWithModel.from_pretrained(model_name).to(device)

    def load_upscaler(self):
        """Load the general upscaler; override the model id with ``UPSCALE_MODEL``."""
        return self._load_upscale_model(
            "UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR"
        )

    def load_upscaler4SD(self):
        """Load the SD-side upscaler; override the model id with ``UPSCALE_FOR_SD_MODEL``."""
        return self._load_upscale_model(
            "UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf"
        )

    def load_models(self):
        """Load the image-side models and store them on the instance.

        All four models are loaded into locals first and assigned afterwards.
        If any load raises, the instance attributes keep their previous values.
        ``video_pipe`` is not loaded here.
        """
        base_pipe = self.load_base()
        segmenter = self.load_segmenter()
        upscaler = self.load_upscaler()
        upscaler4SD = self.load_upscaler4SD()
        self.base_pipe = base_pipe
        self.segmenter = segmenter
        self.upscaler = upscaler
        self.upscaler4SD = upscaler4SD
# Module-level singleton. Importing this module constructs ModelHandler, whose
# __init__ immediately loads the SDXL pipeline, segmenter and upscalers.
MODELS: ModelHandler = ModelHandler()