"""Load and hold the diffusion, segmentation and upscaling models used by the app.

Importing this module instantiates ``MODELS``, which loads every model eagerly.
"""

import os

import torch
from diffusers import (
    DDPMScheduler,
    StableDiffusionXLImg2ImgPipeline,
    LTXPipeline,
    AutoencoderKL,
)
from hidiffusion import apply_hidiffusion
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from image_gen_aux import UpscaleWithModel

BASE_MODEL = "stabilityai/sdxl-turbo"
VIDEO_MODEL = "Lightricks/LTX-Video"

device = "cuda" if torch.cuda.is_available() else "cpu"


class ModelHandler:
    """Eagerly load and hold the models used by the application.

    Attributes:
        base_pipe: SDXL-Turbo img2img pipeline with HiDiffusion applied.
        video_pipe: LTX-Video pipeline. ``load_models`` does not populate it,
            so it stays ``None`` unless a caller assigns ``load_video_pipe()``.
        compiled_model: Not assigned anywhere in this class; it stays ``None``
            unless set externally.
        segmenter: MediaPipe image segmenter built from
            ``checkpoints/selfie_multiclass_256x256.tflite``.
        upscaler: Upscaler selected by the ``UPSCALE_MODEL`` env var.
        upscaler4SD: Upscaler selected by the ``UPSCALE_FOR_SD_MODEL`` env var
            (presumably used alongside the SD pipeline -- confirm with callers).
    """

    def __init__(self):
        self.base_pipe = None
        self.video_pipe = None
        self.compiled_model = None
        self.segmenter = None
        self.upscaler = None
        self.upscaler4SD = None
        self.load_models()

    def load_base(self):
        """Build the SDXL-Turbo img2img pipeline on ``device``.

        Uses the fp16-fix SDXL VAE, replaces the scheduler with a
        ``DDPMScheduler`` built from the base model's scheduler config, and
        applies HiDiffusion to the pipeline.

        Returns:
            The configured ``StableDiffusionXLImg2ImgPipeline``.
        """
        # float16 is only usable on an accelerator: diffusers warns that fp16
        # pipelines "cannot run with `cpu` device". Use float32 on the CPU
        # fallback so that branch yields a pipeline that can actually run.
        dtype = torch.float16 if device == "cuda" else torch.float32
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=dtype,
        )
        base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=vae,
            torch_dtype=dtype,
            # Still fetch the fp16 weight files; ``torch_dtype`` sets the
            # dtype the loaded modules end up in.
            variant="fp16",
            use_safetensors=True,
        )
        base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
        base_pipe.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        apply_hidiffusion(base_pipe)
        return base_pipe

    def load_video_pipe(self):
        """Load the LTX-Video pipeline in bfloat16 and move it to ``device``.

        Not called by ``load_models``.
        """
        pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
        pipe.to(device)
        return pipe

    def load_segmenter(self):
        """Create the MediaPipe image segmenter with category-mask output enabled.

        The model path is relative, so it resolves against the current
        working directory.
        """
        segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
        base_options = python.BaseOptions(model_asset_path=segment_model)
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True,
        )
        segmenter = vision.ImageSegmenter.create_from_options(options)
        return segmenter

    @staticmethod
    def _load_upscaler_from_env(env_var, default_model):
        """Load the ``UpscaleWithModel`` named by ``env_var`` (else ``default_model``) onto ``device``."""
        model_name = os.environ.get(env_var, default_model)
        return UpscaleWithModel.from_pretrained(model_name).to(device)

    def load_upscaler(self):
        """Load the upscaler named by ``UPSCALE_MODEL`` (default 4xNomosWebPhoto_RealPLKSR)."""
        return self._load_upscaler_from_env(
            "UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR"
        )

    def load_upscaler4SD(self):
        """Load the upscaler named by ``UPSCALE_FOR_SD_MODEL`` (default 1xDeJPG_realplksr_otf)."""
        return self._load_upscaler_from_env(
            "UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf"
        )

    def load_models(self):
        """Load the base pipeline, segmenter and both upscalers.

        Every model is loaded into a local first and assigned to ``self`` only
        once all loads succeed. A failure part-way therefore leaves the
        previously held models untouched. ``video_pipe`` is not loaded here.
        """
        base_pipe = self.load_base()
        segmenter = self.load_segmenter()
        upscaler = self.load_upscaler()
        upscaler4SD = self.load_upscaler4SD()

        self.base_pipe = base_pipe
        self.segmenter = segmenter
        self.upscaler = upscaler
        self.upscaler4SD = upscaler4SD


# Module-level singleton: importing this module loads all models immediately.
MODELS = ModelHandler()