# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0)
import torch
import gguf

from .ops import GGMLTensor
from .dequant import is_quantized, dequantize_tensor

IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "ltxv", "hyvid", "wan"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}

def get_orig_shape(reader, tensor_name):
    field_key = f"comfy.gguf.orig_shape.{tensor_name}"
    field = reader.get_field(field_key)
    if field is None:
        return None
    # Has original shape metadata, so we try to decode it.
    if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
        raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
    return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
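
# Illustrative note: the converter can store a tensor's true shape under the metadata
# key "comfy.gguf.orig_shape.<tensor name>" (typically for tensors with more dimensions
# than GGUF stores natively); the helper above returns it as a torch.Size, or None
# when no such key exists.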

def get_field(reader, field_name, field_type):
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type == str:
        # extra check here as this is used for checking arch string
        if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
            raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
        return str(field.parts[field.data[-1]], encoding="utf-8")
    elif field_type in [int, float, bool]:
        return field_type(field.parts[field.data[-1]])
    else:
        raise TypeError(f"Unknown field type {field_type}")

def get_list_field(reader, field_name, field_type):
    field = reader.get_field(field_name)
    if field is None:
        return None
    elif field_type == str:
        return tuple(str(field.parts[part_idx], encoding="utf-8") for part_idx in field.data)
    elif field_type in [int, float, bool]:
        return tuple(field_type(field.parts[part_idx][0]) for part_idx in field.data)
    else:
        raise TypeError(f"Unknown field type {field_type}")
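
# Example (illustrative only, the file name is hypothetical): the helpers above read
# GGUF metadata fields without touching tensor data, e.g.
#
#   reader = gguf.GGUFReader("model-Q4_K_S.gguf")
#   arch = get_field(reader, "general.architecture", str)       # e.g. "flux" or "t5"
#   tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)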

def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False):
    """
    Read state dict as fake tensors
    """
    reader = gguf.GGUFReader(path)

    # filter and strip prefix
    has_prefix = False
    if handle_prefix is not None:
        prefix_len = len(handle_prefix)
        tensor_names = set(tensor.name for tensor in reader.tensors)
        has_prefix = any(s.startswith(handle_prefix) for s in tensor_names)

    tensors = []
    for tensor in reader.tensors:
        sd_key = tensor_name = tensor.name
        if has_prefix:
            if not tensor_name.startswith(handle_prefix):
                continue
            sd_key = tensor_name[prefix_len:]
        tensors.append((sd_key, tensor))

    # detect and verify architecture
    compat = None
    arch_str = get_field(reader, "general.architecture", str)
    if arch_str is None:  # stable-diffusion.cpp
        # import here to avoid changes to convert.py breaking regular models
        from .tools.convert import detect_arch
        arch_str = detect_arch(set(val[0] for val in tensors)).arch
        compat = "sd.cpp"
    elif arch_str in ["pig"]:
        from .tools.convert import detect_arch
        arch_str = detect_arch(set(val[0] for val in tensors)).arch
        compat = "pig"
    elif arch_str not in IMG_ARCH_LIST and arch_str not in TXT_ARCH_LIST:
        raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}")

    if compat:
        print(f"Warning: This model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]")

    # main loading loop
    state_dict = {}
    qtype_dict = {}
    for sd_key, tensor in tensors:
        tensor_name = tensor.name
        torch_tensor = torch.from_numpy(tensor.data)  # mmap

        shape = get_orig_shape(reader, tensor_name)
        if shape is None:
            shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
            # Workaround for stable-diffusion.cpp SDXL detection.
            if compat == "sd.cpp" and arch_str == "sdxl":
                if any(tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")):
                    while len(shape) > 2 and shape[-1] == 1:
                        shape = shape[:-1]

        # add to state dict
        if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}:
            torch_tensor = torch_tensor.view(*shape)
        state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape)

        # keep track of loaded tensor types
        tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type))
        qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1

    # print loaded tensor type counts
    print("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items()))

    # mark largest tensor for vram estimation
    qsd = {k: v for k, v in state_dict.items() if is_quantized(v)}
    if len(qsd) > 0:
        max_key = max(qsd.keys(), key=lambda k: qsd[k].numel())
        state_dict[max_key].is_largest_weight = True

    if return_arch:
        return (state_dict, arch_str)
    return state_dict
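
# Example (illustrative, the file name is hypothetical): loading a quantized UNet
# checkpoint and checking how many tensors stayed quantized as GGMLTensors.
#
#   sd = gguf_sd_loader("flux1-dev-Q4_K_S.gguf")
#   n_quant = sum(1 for v in sd.values() if is_quantized(v))
#   print(f"{n_quant}/{len(sd)} tensors are quantized")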

# for remapping llama.cpp -> original key names
T5_SD_MAP = {
    "enc.": "encoder.",
    ".blk.": ".block.",
    "token_embd": "shared",
    "output_norm": "final_layer_norm",
    "attn_q": "layer.0.SelfAttention.q",
    "attn_k": "layer.0.SelfAttention.k",
    "attn_v": "layer.0.SelfAttention.v",
    "attn_o": "layer.0.SelfAttention.o",
    "attn_norm": "layer.0.layer_norm",
    "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias",
    "ffn_up": "layer.1.DenseReluDense.wi_1",
    "ffn_down": "layer.1.DenseReluDense.wo",
    "ffn_gate": "layer.1.DenseReluDense.wi_0",
    "ffn_norm": "layer.1.layer_norm",
}

LLAMA_SD_MAP = {
    "blk.": "model.layers.",
    "attn_norm": "input_layernorm",
    "attn_q": "self_attn.q_proj",
    "attn_k": "self_attn.k_proj",
    "attn_v": "self_attn.v_proj",
    "attn_output": "self_attn.o_proj",
    "ffn_up": "mlp.up_proj",
    "ffn_down": "mlp.down_proj",
    "ffn_gate": "mlp.gate_proj",
    "ffn_norm": "post_attention_layernorm",
    "token_embd": "model.embed_tokens",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}

def sd_map_replace(raw_sd, key_map):
    sd = {}
    for k, v in raw_sd.items():
        for s, d in key_map.items():
            k = k.replace(s, d)
        sd[k] = v
    return sd
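
# Example (illustrative): keys are rewritten by plain substring replacement, so a
# llama.cpp style key such as "blk.0.attn_q.weight" becomes
# "model.layers.0.self_attn.q_proj.weight" when run through LLAMA_SD_MAP above.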

def llama_permute(raw_sd, n_head, n_head_kv):
    # Reverse version of LlamaModel.permute in llama.cpp convert script
    sd = {}
    permute = lambda x, h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
    for k, v in raw_sd.items():
        if k.endswith(("q_proj.weight", "q_proj.bias")):
            v.data = permute(v.data, n_head)
        if k.endswith(("k_proj.weight", "k_proj.bias")):
            v.data = permute(v.data, n_head_kv)
        sd[k] = v
    return sd
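
# Sanity-check sketch (not executed here; assumes a 4096-dim q_proj with 32 heads, as
# in LLaMA 3 8B, and that llama.cpp's forward permute is the reshape(h, 2, ...) variant):
# applying the forward permute and then the reverse permute above should round-trip.
#
#   w = torch.randn(4096, 4096)
#   fwd = lambda x, h: x.reshape(h, 2, x.shape[0] // h // 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
#   rev = lambda x, h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape)
#   assert torch.equal(rev(fwd(w, 32), 32), w)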

def gguf_tokenizer_loader(path, temb_shape):
    # convert gguf tokenizer to spiece
    print("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    try:
        from sentencepiece import sentencepiece_model_pb2 as model
    except ImportError:
        raise ImportError("Please make sure sentencepiece and protobuf are installed.\npip install sentencepiece protobuf")
    spm = model.ModelProto()

    reader = gguf.GGUFReader(path)

    if get_field(reader, "tokenizer.ggml.model", str) == "t5":
        if temb_shape == (256384, 4096):  # probably UMT5
            spm.trainer_spec.model_type = 1  # Unigram (do we have a T5 w/ BPE?)
        else:
            raise NotImplementedError("Unknown model, can't set tokenizer!")
    else:
        raise NotImplementedError("Unknown model, can't set tokenizer!")

    spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool)
    spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool)

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int)

    for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)):
        # # These aren't present in the original?
        # if toktype == 5 and idx >= temb_shape[0] % 1000:
        #     continue
        piece = spm.SentencePiece()
        piece.piece = token
        piece.score = score
        piece.type = toktype
        spm.pieces.append(piece)

    # unsure if any of these are correct
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.vocab_size = len(tokens)  # split off unused?
    spm.trainer_spec.max_sentence_length = 4096
    spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int)
    spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int)

    print(f"Created tokenizer with vocab size of {len(spm.pieces)}")
    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))
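
# Example (illustrative): the returned ByteTensor holds a serialized sentencepiece
# ModelProto; downstream code could rebuild a tokenizer from it, assuming a
# sentencepiece version that accepts model_proto bytes directly.
#
#   import sentencepiece
#   proto_bytes = bytes(sd["spiece_model"].tolist())
#   tokenizer = sentencepiece.SentencePieceProcessor(model_proto=proto_bytes)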

def gguf_clip_loader(path):
    sd, arch = gguf_sd_loader(path, return_arch=True)
    if arch in {"t5", "t5encoder"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape == (256384, 4096):
            # non-standard Comfy-Org tokenizer
            sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape)
            # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive.
            print(f"Dequantizing {temb_key} to prevent runtime OOM.")
            sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
        sd = sd_map_replace(sd, T5_SD_MAP)
    elif arch in {"llama"}:
        temb_key = "token_embd.weight"
        if temb_key in sd and sd[temb_key].shape != (128320, 4096):
            # This still works. Raise error?
            print("Warning! token_embd shape may be incorrect for llama 3 model!")
        sd = sd_map_replace(sd, LLAMA_SD_MAP)
        sd = llama_permute(sd, 32, 8)  # L3
    else:
        pass
    return sd
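
# Example (illustrative, the file name is hypothetical): loading a quantized T5 text
# encoder yields HF-style keys (via T5_SD_MAP) plus an optional "spiece_model" entry.
#
#   sd = gguf_clip_loader("t5-v1_1-xxl-encoder-Q5_K_M.gguf")
#   print(any(k.startswith("encoder.block.") for k in sd))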