# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import gguf
import torch

import comfy.ops
import comfy.model_management
from .dequant import dequantize_tensor, is_quantized

# to avoid breaking really old pytorch versions
if hasattr(torch, "compiler") and hasattr(torch.compiler, "disable"):
    torch_compiler_disable = torch.compiler.disable
else:
    def torch_compiler_disable(*args, **kwargs):
        def noop(x):
            return x
        return noop
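
# Editorial note: on recent PyTorch, `@torch_compiler_disable()` (used below)
# resolves to `torch.compiler.disable`, keeping torch.compile from tracing the
# decorated method; on older versions the fallback above returns the function
# unchanged, so the decorator is a no-op.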

class GGMLTensor(torch.Tensor):
    """
    Main tensor-like class for storing quantized weights
    """
    def __init__(self, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        super().__init__()
        self.tensor_type = tensor_type
        self.tensor_shape = tensor_shape
        self.patches = patches

    def __new__(cls, *args, tensor_type, tensor_shape, patches=[], **kwargs):
        return super().__new__(cls, *args, **kwargs)

    def to(self, *args, **kwargs):
        new = super().to(*args, **kwargs)
        new.tensor_type = getattr(self, "tensor_type", None)
        new.tensor_shape = getattr(self, "tensor_shape", new.data.shape)
        new.patches = getattr(self, "patches", []).copy()
        return new

    def clone(self, *args, **kwargs):
        return self

    def detach(self, *args, **kwargs):
        return self

    def copy_(self, *args, **kwargs):
        # fixes .weight.copy_ in comfy/clip_model/CLIPTextModel
        try:
            return super().copy_(*args, **kwargs)
        except Exception as e:
            print(f"ignoring 'copy_' on tensor: {e}")

    def new_empty(self, size, *args, **kwargs):
        # Intel Arc fix, ref#50
        new_tensor = super().new_empty(size, *args, **kwargs)
        return GGMLTensor(
            new_tensor,
            tensor_type=getattr(self, "tensor_type", None),
            tensor_shape=size,
            patches=getattr(self, "patches", []).copy()
        )

    @property
    def shape(self):
        if not hasattr(self, "tensor_shape"):
            self.tensor_shape = self.size()
        return self.tensor_shape
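
# Illustrative sketch (editorial, not part of the original module): a loader can
# wrap the still-quantized GGUF data in a GGMLTensor so the layers below can
# dequantize it on demand. The reader handling here is an assumption about the
# calling code, not something defined in this file.
#
#   reader = gguf.GGUFReader("model.gguf")
#   t = reader.tensors[0]
#   raw = torch.from_numpy(t.data)  # packed/quantized blocks, not final weights
#   weight = GGMLTensor(
#       raw,
#       tensor_type=t.tensor_type,
#       tensor_shape=torch.Size(tuple(int(v) for v in reversed(t.shape))),
#   )
#   print(weight.shape)  # reports the logical (dequantized) shape, not the packed one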

class GGMLLayer(torch.nn.Module):
    """
    This (should) be responsible for de-quantizing on the fly
    """
    comfy_cast_weights = True
    dequant_dtype = None
    patch_dtype = None
    largest_layer = False
    torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}

    def is_ggml_quantized(self, *, weight=None, bias=None):
        if weight is None:
            weight = self.weight
        if bias is None:
            bias = self.bias
        return is_quantized(weight) or is_quantized(bias)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        weight, bias = state_dict.get(f"{prefix}weight"), state_dict.get(f"{prefix}bias")
        # NOTE: using modified load for linear due to not initializing on creation, see GGMLOps todo
        if self.is_ggml_quantized(weight=weight, bias=bias) or isinstance(self, torch.nn.Linear):
            return self.ggml_load_from_state_dict(state_dict, prefix, *args, **kwargs)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def ggml_load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        prefix_len = len(prefix)
        for k, v in state_dict.items():
            if k[prefix_len:] == "weight":
                self.weight = torch.nn.Parameter(v, requires_grad=False)
            elif k[prefix_len:] == "bias" and v is not None:
                self.bias = torch.nn.Parameter(v, requires_grad=False)
            else:
                unexpected_keys.append(k)

        # For Linear layer with missing weight
        if self.weight is None and isinstance(self, torch.nn.Linear):
            v = torch.zeros(self.in_features, self.out_features)
            self.weight = torch.nn.Parameter(v, requires_grad=False)
            missing_keys.append(prefix + "weight")

        # for vram estimation (TODO: less fragile logic?)
        if getattr(self.weight, "is_largest_weight", False):
            self.largest_layer = True

    def _save_to_state_dict(self, *args, **kwargs):
        if self.is_ggml_quantized():
            return self.ggml_save_to_state_dict(*args, **kwargs)
        return super()._save_to_state_dict(*args, **kwargs)
    def ggml_save_to_state_dict(self, destination, prefix, keep_vars):
        # This is a fake state dict for vram estimation
        weight = torch.zeros_like(self.weight, device=torch.device("meta"))
        destination[prefix + "weight"] = weight
        if self.bias is not None:
            bias = torch.zeros_like(self.bias, device=torch.device("meta"))
            destination[prefix + "bias"] = bias

        # Take into account space required for dequantizing the largest tensor
        if self.largest_layer:
            shape = getattr(self.weight, "tensor_shape", self.weight.shape)
            dtype = self.dequant_dtype or torch.float16
            temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype)
            destination[prefix + "temp.weight"] = temp

        return
        # This would return the dequantized state dict (unreachable, kept for reference)
        destination[prefix + "weight"] = self.get_weight(self.weight)
        if bias is not None:
            destination[prefix + "bias"] = self.get_weight(self.bias)
    def get_weight(self, tensor, dtype):
        if tensor is None:
            return

        # consolidate and load patches to GPU in async
        patch_list = []
        device = tensor.device
        for function, patches, key in getattr(tensor, "patches", []):
            patch_list += move_patch_to_device(patches, device)

        # dequantize tensor while patches load
        weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)

        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight.__class__ = torch.Tensor

        # apply patches
        if patch_list:
            if self.patch_dtype is None:
                weight = function(patch_list, weight, key)
            else:
                # for testing, may degrade image quality
                patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype
                weight = function(patch_list, weight, key, patch_dtype)
        return weight
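
    # Editorial note (an assumption about how ComfyUI's model patcher supplies
    # patches, not code from this file): each entry in `tensor.patches` is
    # expected to look like `(calculate_weight_fn, patches, key)`, appended when
    # e.g. a LoRA is applied, so the dequantize-then-patch path above reduces to:
    #
    #   weight = dequantize_tensor(tensor, dtype, self.dequant_dtype)
    #   weight = calculate_weight_fn(patch_list, weight, key)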
    @torch_compiler_disable()
    def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
        if input is not None:
            if dtype is None:
                dtype = getattr(input, "dtype", torch.float32)
            if bias_dtype is None:
                bias_dtype = dtype
            if device is None:
                device = input.device

        bias = None
        non_blocking = comfy.model_management.device_supports_non_blocking(device)
        if s.bias is not None:
            bias = s.get_weight(s.bias.to(device), dtype)
            bias = comfy.ops.cast_to(bias, bias_dtype, device, non_blocking=non_blocking, copy=False)

        weight = s.get_weight(s.weight.to(device), dtype)
        weight = comfy.ops.cast_to(weight, dtype, device, non_blocking=non_blocking, copy=False)
        return weight, bias

    def forward_comfy_cast_weights(self, input, *args, **kwargs):
        if self.is_ggml_quantized():
            out = self.forward_ggml_cast_weights(input, *args, **kwargs)
        else:
            out = super().forward_comfy_cast_weights(input, *args, **kwargs)

        # non-ggml forward might still propagate custom tensor class
        if isinstance(out, GGMLTensor):
            out.__class__ = torch.Tensor
        return out

    def forward_ggml_cast_weights(self, input):
        raise NotImplementedError

class GGMLOps(comfy.ops.manual_cast):
    """
    Dequantize weights on the fly before doing the compute
    """
    class Linear(GGMLLayer, comfy.ops.manual_cast.Linear):
        def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
            torch.nn.Module.__init__(self)
            # TODO: better workaround for reserved memory spike on windows
            # Issue is with `torch.empty` still reserving the full memory for the layer
            # Windows doesn't over-commit memory so without this 24GB+ of pagefile is used
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.bias = None

        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.linear(input, weight, bias)

    class Conv2d(GGMLLayer, comfy.ops.manual_cast.Conv2d):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return self._conv_forward(input, weight, bias)

    class Embedding(GGMLLayer, comfy.ops.manual_cast.Embedding):
        def forward_ggml_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, _bias = self.cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(
                input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
            ).to(dtype=output_dtype)

    class LayerNorm(GGMLLayer, comfy.ops.manual_cast.LayerNorm):
        def forward_ggml_cast_weights(self, input):
            if self.weight is None:
                return super().forward_comfy_cast_weights(input)
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

    class GroupNorm(GGMLLayer, comfy.ops.manual_cast.GroupNorm):
        def forward_ggml_cast_weights(self, input):
            weight, bias = self.cast_bias_weight(input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
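
# Usage sketch (editorial): the loader nodes of this project hand GGMLOps to
# ComfyUI as the custom operations class so that Linear/Conv2d/Embedding/etc.
# are built from the subclasses above. The exact call shown is an assumption
# about the calling code, not something defined in this file.
#
#   ops = GGMLOps()
#   ops.Linear.dequant_dtype = None  # or e.g. torch.float16 to force a dequant dtype
#   model = comfy.sd.load_diffusion_model_state_dict(
#       sd, model_options={"custom_operations": ops}
#   )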

def move_patch_to_device(item, device):
    if isinstance(item, torch.Tensor):
        return item.to(device, non_blocking=True)
    elif isinstance(item, tuple):
        return tuple(move_patch_to_device(x, device) for x in item)
    elif isinstance(item, list):
        return [move_patch_to_device(x, device) for x in item]
    else:
        return item