|
import torch |
|
import os |
|
|
|
from diffusers import ( |
|
DDPMScheduler, |
|
StableDiffusionXLImg2ImgPipeline, |
|
LTXPipeline, |
|
AutoencoderKL, |
|
) |
|
|
|
from hidiffusion import apply_hidiffusion |
|
|
|
from mediapipe.tasks import python |
|
from mediapipe.tasks.python import vision |
|
|
|
from image_gen_aux import UpscaleWithModel |
|
|
|
# Hugging Face Hub repository IDs for the image and video pipelines.
BASE_MODEL = "stabilityai/sdxl-turbo"
VIDEO_MODEL = "Lightricks/LTX-Video"

# Target device for every model: CUDA when torch can see a GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
|
|
|
class ModelHandler:
    """Load and hold the models used by this module.

    Constructing an instance immediately calls ``load_models``. That method
    loads the SDXL img2img pipeline, the MediaPipe segmenter and both
    upscalers, and stores them on the instance.

    ``load_models`` does not call ``load_video_pipe``. As a result,
    ``video_pipe`` stays ``None`` unless other code assigns it.
    ``compiled_model`` is likewise never assigned within this class.
    """

    def __init__(self):
        # All handles start empty; load_models() fills in the ones it owns.
        self.base_pipe = None
        self.video_pipe = None  # not populated by load_models()
        self.compiled_model = None  # not assigned anywhere in this class
        self.segmenter = None
        self.upscaler = None
        self.upscaler4SD = None

        self.load_models()

    def load_base(self):
        """Build the SDXL-Turbo img2img pipeline on ``device``.

        The pipeline uses the ``madebyollin/sdxl-vae-fp16-fix`` VAE in fp16.
        Its scheduler is replaced with a ``DDPMScheduler`` built from
        ``BASE_MODEL``'s ``scheduler`` subfolder, and HiDiffusion is applied
        to it.

        Returns:
            The configured ``StableDiffusionXLImg2ImgPipeline``.
        """
        vae = AutoencoderKL.from_pretrained(
            "madebyollin/sdxl-vae-fp16-fix",
            torch_dtype=torch.float16,
        )

        base_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            BASE_MODEL,
            vae=vae,
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
        # silence_dtype_warnings suppresses diffusers' warning when the
        # pipeline dtype does not suit the target device (fp16 on CPU here).
        # NOTE(review): device may be "cpu"; fp16 inference on CPU can be very
        # slow or hit unsupported ops -- confirm CPU is a supported target.
        base_pipe = base_pipe.to(device, silence_dtype_warnings=True)
        base_pipe.scheduler = DDPMScheduler.from_pretrained(
            BASE_MODEL,
            subfolder="scheduler",
        )
        apply_hidiffusion(base_pipe)

        return base_pipe

    def load_video_pipe(self):
        """Build the LTX-Video pipeline in bfloat16 on ``device``.

        ``load_models`` does not call this method.

        Returns:
            The ``LTXPipeline`` instance.
        """
        pipe = LTXPipeline.from_pretrained(VIDEO_MODEL, torch_dtype=torch.bfloat16)
        # Rebind the result of .to(), the same way load_base() does.
        pipe = pipe.to(device)
        return pipe

    def load_segmenter(self):
        """Create a MediaPipe image segmenter that outputs a category mask.

        The model file defaults to
        ``checkpoints/selfie_multiclass_256x256.tflite``. That is a relative
        path, so it depends on the process working directory. The path can be
        overridden with the ``SEGMENT_MODEL`` environment variable, which
        mirrors the env-var overrides used by the upscaler loaders.

        Returns:
            A ``vision.ImageSegmenter``.
        """
        segment_model = os.environ.get(
            "SEGMENT_MODEL",
            "checkpoints/selfie_multiclass_256x256.tflite",
        )
        base_options = python.BaseOptions(model_asset_path=segment_model)
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True,
        )
        return vision.ImageSegmenter.create_from_options(options)

    def _upscaler_from_env(self, env_var, default_repo):
        """Load an ``UpscaleWithModel`` named by ``$env_var`` (else ``default_repo``) onto ``device``."""
        model_name = os.environ.get(env_var, default_repo)
        return UpscaleWithModel.from_pretrained(model_name).to(device)

    def load_upscaler(self):
        """Load the ``upscaler`` model; override the repo with ``UPSCALE_MODEL``."""
        return self._upscaler_from_env(
            "UPSCALE_MODEL", "Phips/4xNomosWebPhoto_RealPLKSR"
        )

    def load_upscaler4SD(self):
        """Load the ``upscaler4SD`` model; override the repo with ``UPSCALE_FOR_SD_MODEL``."""
        return self._upscaler_from_env(
            "UPSCALE_FOR_SD_MODEL", "Phips/1xDeJPG_realplksr_otf"
        )

    def load_models(self):
        """Load the base pipeline, segmenter and both upscalers onto ``self``.

        Every model is loaded before any attribute is assigned. If one loader
        raises, none of the four attributes set here are modified.
        """
        base_pipe = self.load_base()
        segmenter = self.load_segmenter()
        upscaler = self.load_upscaler()
        upscaler4SD = self.load_upscaler4SD()

        self.base_pipe = base_pipe
        self.segmenter = segmenter
        self.upscaler = upscaler
        self.upscaler4SD = upscaler4SD
|
|
|
# Module-level instance: importing this module constructs ModelHandler.
# Construction calls load_models(), so the base pipeline, segmenter and both
# upscalers are downloaded/loaded at import time.
MODELS: ModelHandler = ModelHandler()