Spaces:
Paused
Paused
import torch | |
import gc | |
from collections import OrderedDict | |
from typing import TYPE_CHECKING | |
from jobs.process import BaseExtensionProcess | |
from toolkit.config_modules import ModelConfig | |
from toolkit.stable_diffusion_model import StableDiffusion | |
from toolkit.train_tools import get_torch_dtype | |
from tqdm import tqdm | |
# Type check imports. Prevents circular imports | |
if TYPE_CHECKING: | |
from jobs import ExtensionJob | |
# Extend the standard ModelConfig with a per-model merge weight.
class ModelInputConfig(ModelConfig):
    """ModelConfig with an extra ``weight`` option used when merging models.

    All keyword arguments are forwarded to ``ModelConfig`` first. Afterwards
    two attributes are (re)assigned from the same keyword arguments:

    * ``weight`` -- this model's relative contribution to the merge
      (``1.0`` when not given).
    * ``dtype`` -- the value given by the user, otherwise ``'float32'``.
      This replaces whatever default ``ModelConfig`` applied, because float32
      keeps more precision during the merge arithmetic.
    """

    # Fallbacks used when the config entry omits the corresponding key.
    _default_weight = 1.0
    _default_dtype = 'float32'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.weight = kwargs.get('weight', self._default_weight)
        self.dtype: str = kwargs.get('dtype', self._default_dtype)
def flush():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory.

    ``gc.collect()`` runs first so that tensors kept alive only by reference
    cycles are actually freed. Only then does ``torch.cuda.empty_cache()``
    return the allocator's cached, now-unused blocks to the driver. In the
    reverse order, memory freed by the collector would stay cached.
    ``torch.cuda.empty_cache()`` does nothing when CUDA was never initialized,
    so this is safe on CPU-only runs.
    """
    gc.collect()
    torch.cuda.empty_cache()
# this is our main class process
class ExampleMergeModels(BaseExtensionProcess):
    """Merge several Stable Diffusion models into one weighted average.

    Each model listed in ``models_to_merge`` is loaded in turn. The
    floating-point tensors of its text encoder(s) and UNet are scaled by the
    model's normalized weight and summed into the first model loaded. That
    model is finally saved to ``save_path``.

    Config keys read in ``__init__``:
        save_path (required): destination passed to ``StableDiffusion.save``.
        save_dtype (default 'float16'): converted with ``get_torch_dtype``.
        device (default 'cpu'): converted with ``torch.device``.
        models_to_merge (required): list of dicts, each expanded into a
            ``ModelInputConfig`` (so entries may carry ``weight``/``dtype``).
    """

    def __init__(
            self,
            process_id: int,
            job: 'ExtensionJob',
            config: OrderedDict
    ):
        super().__init__(process_id, job, config)
        # This is the setup step: only read config values and check
        # requirements here. It is called before run(), so model loading and
        # other expensive work belong in run(), not here.
        #
        # `config` holds everything from this process's entry in the config
        # file. get_conf() (from the base process) reads a value, throws when
        # `required=True` and the value is missing, accepts a `default`, and
        # can cast the value with `as_type`.
        self.save_path = self.get_conf('save_path', required=True)
        self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
        self.device = self.get_conf('device', default='cpu', as_type=torch.device)

        # Wrap each entry in a ModelInputConfig so per-model options (including
        # `weight`) are attributes rather than raw dict lookups.
        models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
        self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]

    def run(self):
        """Load, weight, sum and save the configured models.

        Raises:
            ValueError: if ``models_to_merge`` is empty, its weights sum to
                zero, or two models have different numbers of text encoders.
            KeyError: if a tensor of the first model is missing from a later
                model (the architectures differ).
        """
        # always call first
        super().run()
        print(f"Running process: {self.__class__.__name__}")

        self._normalize_weights()

        # Becomes the first loaded model; every later model is summed into it.
        output_model: "StableDiffusion | None" = None
        # tqdm shows progress across the (slow) per-model loads.
        for model_config in tqdm(self.models_to_merge, desc="Merging models"):
            sd_model = StableDiffusion(
                device=self.device,
                model_config=model_config,
                # ModelInputConfig defaults dtype to float32 for precise merging
                # but lets each model entry override it.
                dtype=model_config.dtype
            )
            sd_model.load_model()
            self._scale_model(sd_model, model_config.weight)

            if output_model is None:
                # use this one as the base
                output_model = sd_model
            else:
                self._merge_into(output_model, sd_model)

            # Drop our reference so a merged-in model's memory can be reclaimed
            # (the base model stays alive through output_model).
            del sd_model
            flush()

        print(f"Saving merged model to {self.save_path}")
        output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
        print(f"Saved merged model to {self.save_path}")

        # do cleanup here
        del output_model
        flush()

    def _normalize_weights(self):
        """Rescale every model's ``weight`` in place so the weights sum to 1.0."""
        if not self.models_to_merge:
            raise ValueError("models_to_merge is empty; there is nothing to merge")
        total_weight = sum(model.weight for model in self.models_to_merge)
        if total_weight == 0:
            raise ValueError("the weights in models_to_merge sum to 0, so they cannot be normalized")
        weight_adjust = 1.0 / total_weight
        for model in self.models_to_merge:
            model.weight *= weight_adjust

    @staticmethod
    def _text_encoders(sd_model):
        """Return a NEW list holding the model's text encoder module(s).

        ``sd_model.text_encoder`` is a list for SDXL-style models and a single
        module otherwise. A fresh list is returned so callers may append to it
        without mutating the model.
        """
        text_encoder = sd_model.text_encoder
        return list(text_encoder) if isinstance(text_encoder, list) else [text_encoder]

    @staticmethod
    def _float_tensors(module):
        """Return ``{key: tensor}`` for the unique floating-point entries of ``module.state_dict()``.

        The tensors returned by ``state_dict()`` share storage with the module,
        so in-place operations on them update the module itself.

        Non-floating entries (e.g. integer index buffers) are skipped. Scaling
        them by a float weight in place raises, and summing them would corrupt
        them. Entries whose storage starts at the same address as an earlier
        entry (tied tensors) are skipped so each storage is updated only once.
        """
        tensors = {}
        seen_ptrs = set()
        for key, value in module.state_dict().items():
            if not value.is_floating_point():
                continue
            ptr = value.data_ptr()
            if ptr in seen_ptrs:
                continue
            seen_ptrs.add(ptr)
            tensors[key] = value
        return tensors

    def _scale_model(self, sd_model, weight):
        """Multiply sd_model's text encoder and UNet float tensors by ``weight``, in place."""
        for module in self._text_encoders(sd_model) + [sd_model.unet]:
            for tensor in self._float_tensors(module).values():
                tensor.mul_(weight)

    def _merge_into(self, output_model, sd_model):
        """Add sd_model's text encoder and UNet float tensors into output_model, in place."""
        targets = self._text_encoders(output_model)
        sources = self._text_encoders(sd_model)
        if len(targets) != len(sources):
            raise ValueError(
                f"cannot merge a model with {len(sources)} text encoder(s) into one "
                f"with {len(targets)}; all models must share the same architecture"
            )
        targets.append(output_model.unet)
        sources.append(sd_model.unet)
        for target, source in zip(targets, sources):
            # Build the source state dict once per module, not once per key.
            source_state = source.state_dict()
            for key, tensor in self._float_tensors(target).items():
                if key not in source_state:
                    raise KeyError(
                        f"{key!r} is missing from the model being merged; the architectures differ"
                    )
                tensor.add_(source_state[key].to(device=tensor.device, dtype=tensor.dtype))