|
|
|
|
|
|
|
import torch |
|
|
|
|
|
# Device handle for host (CPU) memory.
cpu = torch.device('cpu')

# Device handle for the currently selected CUDA device.
# ``torch.cuda.current_device()`` raises when CUDA cannot be initialised,
# which made this module impossible to import on CPU-only hosts. In that case
# keep an unresolved ``cuda:0`` handle so the error surfaces where the GPU is
# actually used rather than at import time.
gpu = torch.device(
    f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cuda:0'
)

# Models registered by ``load_model_as_complete`` and released (moved to
# ``cpu``) by ``unload_complete_models``.
gpu_complete_modules = []
|
|
|
|
|
class DynamicSwapInstaller:
    """Install/remove on-access conversion of a module's parameters and buffers.

    Installing replaces a module's ``__class__`` with a generated subclass
    whose ``__getattr__`` returns each parameter or buffer passed through
    ``Tensor.to(**kwargs)`` at the moment it is read. The tensors stored in
    ``_parameters`` / ``_buffers`` themselves are left untouched. Uninstalling
    restores the original class.
    """

    @staticmethod
    def _install_module(module: torch.nn.Module, **kwargs):
        """Swap ``module``'s class for a proxy that converts tensors on access.

        Args:
            module: The module to patch in place.
            **kwargs: Forwarded verbatim to ``Tensor.to`` on every parameter
                or buffer read (for example ``device=...``).
        """
        # Re-installing must replace the previous proxy, not stack on top of
        # it: otherwise the backup would point at the proxy class and
        # uninstalling could never restore the real class.
        DynamicSwapInstaller._uninstall_module(module)

        original_class = module.__class__
        # Written through __dict__ to bypass nn.Module.__setattr__.
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    # Exact class match on purpose: only plain Parameters are
                    # re-wrapped; Parameter subclasses are returned as
                    # whatever ``.to`` yields.
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
                    return p.to(**kwargs)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    b = _buffers[name]
                    # Buffers may legitimately be registered as None.
                    if b is None:
                        return None
                    return b.to(**kwargs)
            # Delegate to the original class's own attribute fallback. Going
            # through ``super(original_class, self)`` would skip a
            # ``__getattr__`` defined directly on ``original_class`` and would
            # fail when ``original_class`` is ``torch.nn.Module`` itself.
            return original_class.__getattr__(self, name)

        module.__class__ = type(
            'DynamicSwap_' + original_class.__name__,
            (original_class,),
            {'__getattr__': hacked_get_attr},
        )

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        """Restore ``module``'s original class if a proxy is installed; otherwise do nothing."""
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')

    @staticmethod
    def install_model(model: torch.nn.Module, **kwargs):
        """Install the proxy on ``model`` and every sub-module from ``model.modules()``."""
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, **kwargs)

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        """Remove the proxy from ``model`` and every sub-module from ``model.modules()``."""
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
|
|
|
|
|
def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
    """Move a single tensor or sub-module of ``model`` to ``target_device``.

    If ``model`` has a ``scale_shift_table`` attribute, only that tensor's
    ``.data`` is moved. Otherwise the first entry of ``model.named_modules()``
    that exposes a ``weight`` attribute is moved with ``Module.to`` and the
    rest of the model is left where it is.

    NOTE(review): presumably this is meant to influence the device the model
    reports without relocating all of its weights - confirm against callers.
    """
    if not hasattr(model, 'scale_shift_table'):
        first_weighted = next(
            (sub for _, sub in model.named_modules() if hasattr(sub, 'weight')),
            None,
        )
        if first_weighted is not None:
            first_weighted.to(target_device)
        return

    relocated = model.scale_shift_table.data.to(target_device)
    model.scale_shift_table.data = relocated
|
|
|
|
|
def get_cuda_free_memory_gb(device=None):
    """Return the memory available on ``device`` in GiB (bytes / 1024**3).

    The figure is the free byte count from ``torch.cuda.mem_get_info`` plus
    the bytes that ``torch.cuda.memory_stats`` reports as reserved but not
    active.

    Args:
        device: CUDA device to query; the module-level ``gpu`` when ``None``.
    """
    target = gpu if device is None else device

    stats = torch.cuda.memory_stats(target)
    active = stats['active_bytes.all.current']
    reserved = stats['reserved_bytes.all.current']
    free_on_device, _total = torch.cuda.mem_get_info(target)

    # Reserved-but-inactive bytes are counted as available as well.
    available = free_on_device + (reserved - active)
    return available / (1024 ** 3)
|
|
|
|
|
def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
    """Move ``model`` to ``target_device`` piecewise, stopping when memory runs low.

    Sub-modules that have a ``weight`` attribute are moved one at a time.
    Before each sub-module the available memory on ``target_device`` is
    checked; once it is at or below ``preserved_memory_gb`` the transfer stops
    and the remaining sub-modules stay where they are. If every sub-module is
    visited, the whole model is moved with a final ``model.to``. The CUDA
    cache is emptied before returning in both cases.
    """
    print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

    for submodule in model.modules():
        if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
            # Budget exhausted: leave the rest of the model in place.
            break
        if hasattr(submodule, 'weight'):
            submodule.to(device=target_device)
    else:
        # Loop finished without hitting the budget: move whatever is left.
        model.to(device=target_device)

    torch.cuda.empty_cache()
|
|
|
|
|
def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
    """Move ``model`` to the CPU piecewise until enough memory is available.

    Sub-modules that have a ``weight`` attribute are moved to ``cpu`` one at a
    time. Before each sub-module the available memory on ``target_device`` is
    checked; once it is at least ``preserved_memory_gb`` the offload stops. If
    every sub-module is visited, the whole model is moved to ``cpu``. The CUDA
    cache is emptied before returning in both cases.
    """
    print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

    for submodule in model.modules():
        if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
            # Enough memory is available already: stop offloading.
            break
        if hasattr(submodule, 'weight'):
            submodule.to(device=cpu)
    else:
        # Target never reached: offload whatever is left of the model.
        model.to(device=cpu)

    torch.cuda.empty_cache()
|
|
|
|
|
def unload_complete_models(*args):
    """Move every tracked model, plus any passed in ``args``, to the CPU.

    Each model is moved with ``.to(device=cpu)`` and a line is printed for it.
    Afterwards ``gpu_complete_modules`` is cleared and the CUDA cache is
    emptied.
    """
    to_release = [*gpu_complete_modules, *args]
    for resident in to_release:
        resident.to(device=cpu)
        print(f'Unloaded {resident.__class__.__name__} as complete.')

    gpu_complete_modules.clear()
    torch.cuda.empty_cache()
|
|
|
|
|
def load_model_as_complete(model, target_device, unload=True):
    """Move the whole ``model`` to ``target_device`` and track it as loaded.

    Args:
        model: Object exposing ``.to(device=...)``.
        target_device: Destination device.
        unload: When true, ``unload_complete_models()`` is called first so the
            previously tracked models are moved off the device.
    """
    if unload:
        # Release previously tracked models before loading the new one.
        unload_complete_models()

    model.to(device=target_device)
    print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')

    # Track the model so a later unload_complete_models() call releases it.
    gpu_complete_modules.append(model)
|
|