# (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) import torch import gguf from .ops import GGMLTensor from .dequant import is_quantized, dequantize_tensor IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "ltxv", "hyvid", "wan"} TXT_ARCH_LIST = {"t5", "t5encoder", "llama"} def get_orig_shape(reader, tensor_name): field_key = f"comfy.gguf.orig_shape.{tensor_name}" field = reader.get_field(field_key) if field is None: return None # Has original shape metadata, so we try to decode it. if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32: raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}") return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data)) def get_field(reader, field_name, field_type): field = reader.get_field(field_name) if field is None: return None elif field_type == str: # extra check here as this is used for checking arch string if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING: raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}") return str(field.parts[field.data[-1]], encoding="utf-8") elif field_type in [int, float, bool]: return field_type(field.parts[field.data[-1]]) else: raise TypeError(f"Unknown field type {field_type}") def get_list_field(reader, field_name, field_type): field = reader.get_field(field_name) if field is None: return None elif field_type == str: return tuple(str(field.parts[part_idx], encoding="utf-8") for part_idx in field.data) elif field_type in [int, float, bool]: return tuple(field_type(field.parts[part_idx][0]) for part_idx in field.data) else: raise TypeError(f"Unknown field type {field_type}") def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", return_arch=False): """ Read state dict as fake tensors """ reader = gguf.GGUFReader(path) # filter and strip prefix has_prefix = False if handle_prefix is not None: prefix_len = len(handle_prefix) tensor_names = set(tensor.name for tensor in reader.tensors) has_prefix = any(s.startswith(handle_prefix) for s in tensor_names) tensors = [] for tensor in reader.tensors: sd_key = tensor_name = tensor.name if has_prefix: if not tensor_name.startswith(handle_prefix): continue sd_key = tensor_name[prefix_len:] tensors.append((sd_key, tensor)) # detect and verify architecture compat = None arch_str = get_field(reader, "general.architecture", str) if arch_str is None: # stable-diffusion.cpp # import here to avoid changes to convert.py breaking regular models from .tools.convert import detect_arch arch_str = detect_arch(set(val[0] for val in tensors)).arch compat = "sd.cpp" elif arch_str in ["pig"]: from .tools.convert import detect_arch arch_str = detect_arch(set(val[0] for val in tensors)).arch compat = "pig" elif arch_str not in IMG_ARCH_LIST and arch_str not in TXT_ARCH_LIST: raise ValueError(f"Unexpected architecture type in GGUF file: {arch_str!r}") if compat: print(f"Warning: This model file is loaded in compatibility mode '{compat}' [arch:{arch_str}]") # main loading loop state_dict = {} qtype_dict = {} for sd_key, tensor in tensors: tensor_name = tensor.name torch_tensor = torch.from_numpy(tensor.data) # mmap shape = get_orig_shape(reader, tensor_name) if shape is None: shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape))) # Workaround for stable-diffusion.cpp SDXL detection. if compat == "sd.cpp" and arch_str == "sdxl": if any([tensor_name.endswith(x) for x in (".proj_in.weight", ".proj_out.weight")]): while len(shape) > 2 and shape[-1] == 1: shape = shape[:-1] # add to state dict if tensor.tensor_type in {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16}: torch_tensor = torch_tensor.view(*shape) state_dict[sd_key] = GGMLTensor(torch_tensor, tensor_type=tensor.tensor_type, tensor_shape=shape) # keep track of loaded tensor types tensor_type_str = getattr(tensor.tensor_type, "name", repr(tensor.tensor_type)) qtype_dict[tensor_type_str] = qtype_dict.get(tensor_type_str, 0) + 1 # print loaded tensor type counts print("gguf qtypes: " + ", ".join(f"{k} ({v})" for k, v in qtype_dict.items())) # mark largest tensor for vram estimation qsd = {k:v for k,v in state_dict.items() if is_quantized(v)} if len(qsd) > 0: max_key = max(qsd.keys(), key=lambda k: qsd[k].numel()) state_dict[max_key].is_largest_weight = True if return_arch: return (state_dict, arch_str) return state_dict # for remapping llama.cpp -> original key names T5_SD_MAP = { "enc.": "encoder.", ".blk.": ".block.", "token_embd": "shared", "output_norm": "final_layer_norm", "attn_q": "layer.0.SelfAttention.q", "attn_k": "layer.0.SelfAttention.k", "attn_v": "layer.0.SelfAttention.v", "attn_o": "layer.0.SelfAttention.o", "attn_norm": "layer.0.layer_norm", "attn_rel_b": "layer.0.SelfAttention.relative_attention_bias", "ffn_up": "layer.1.DenseReluDense.wi_1", "ffn_down": "layer.1.DenseReluDense.wo", "ffn_gate": "layer.1.DenseReluDense.wi_0", "ffn_norm": "layer.1.layer_norm", } LLAMA_SD_MAP = { "blk.": "model.layers.", "attn_norm": "input_layernorm", "attn_q": "self_attn.q_proj", "attn_k": "self_attn.k_proj", "attn_v": "self_attn.v_proj", "attn_output": "self_attn.o_proj", "ffn_up": "mlp.up_proj", "ffn_down": "mlp.down_proj", "ffn_gate": "mlp.gate_proj", "ffn_norm": "post_attention_layernorm", "token_embd": "model.embed_tokens", "output_norm": "model.norm", "output.weight": "lm_head.weight", } def sd_map_replace(raw_sd, key_map): sd = {} for k,v in raw_sd.items(): for s,d in key_map.items(): k = k.replace(s,d) sd[k] = v return sd def llama_permute(raw_sd, n_head, n_head_kv): # Reverse version of LlamaModel.permute in llama.cpp convert script sd = {} permute = lambda x,h: x.reshape(h, x.shape[0] // h // 2, 2, *x.shape[1:]).swapaxes(1, 2).reshape(x.shape) for k,v in raw_sd.items(): if k.endswith(("q_proj.weight", "q_proj.bias")): v.data = permute(v.data, n_head) if k.endswith(("k_proj.weight", "k_proj.bias")): v.data = permute(v.data, n_head_kv) sd[k] = v return sd def gguf_tokenizer_loader(path, temb_shape): # convert gguf tokenizer to spiece print(f"Attempting to recreate sentencepiece tokenizer from GGUF file metadata...") try: from sentencepiece import sentencepiece_model_pb2 as model except ImportError: raise ImportError("Please make sure sentencepiece and protobuf are installed.\npip install sentencepiece protobuf") spm = model.ModelProto() reader = gguf.GGUFReader(path) if get_field(reader, "tokenizer.ggml.model", str) == "t5": if temb_shape == (256384, 4096): # probably UMT5 spm.trainer_spec.model_type == 1 # Unigram (do we have a T5 w/ BPE?) else: raise NotImplementedError(f"Unknown model, can't set tokenizer!") else: raise NotImplementedError(f"Unknown model, can't set tokenizer!") spm.normalizer_spec.add_dummy_prefix = get_field(reader, "tokenizer.ggml.add_space_prefix", bool) spm.normalizer_spec.remove_extra_whitespaces = get_field(reader, "tokenizer.ggml.remove_extra_whitespaces", bool) tokens = get_list_field(reader, "tokenizer.ggml.tokens", str) scores = get_list_field(reader, "tokenizer.ggml.scores", float) toktypes = get_list_field(reader, "tokenizer.ggml.token_type", int) for idx, (token, score, toktype) in enumerate(zip(tokens, scores, toktypes)): # # These aren't present in the original? # if toktype == 5 and idx >= temb_shape[0]%1000): # continue piece = spm.SentencePiece() piece.piece = token piece.score = score piece.type = toktype spm.pieces.append(piece) # unsure if any of these are correct spm.trainer_spec.byte_fallback = True spm.trainer_spec.vocab_size = len(tokens) # split off unused? spm.trainer_spec.max_sentence_length = 4096 spm.trainer_spec.eos_id = get_field(reader, "tokenizer.ggml.eos_token_id", int) spm.trainer_spec.pad_id = get_field(reader, "tokenizer.ggml.padding_token_id", int) print(f"Created tokenizer with vocab size of {len(spm.pieces)}") del reader return torch.ByteTensor(list(spm.SerializeToString())) def gguf_clip_loader(path): sd, arch = gguf_sd_loader(path, return_arch=True) if arch in {"t5", "t5encoder"}: temb_key = "token_embd.weight" if temb_key in sd and sd[temb_key].shape == (256384, 4096): # non-standard Comfy-Org tokenizer sd["spiece_model"] = gguf_tokenizer_loader(path, sd[temb_key].shape) # TODO: dequantizing token embed here is janky but otherwise we OOM due to tensor being massive. print(f"Dequantizing {temb_key} to prevent runtime OOM.") sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16) sd = sd_map_replace(sd, T5_SD_MAP) elif arch in {"llama"}: temb_key = "token_embd.weight" if temb_key in sd and sd[temb_key].shape != (128320, 4096): # This still works. Raise error? print("Warning! token_embd shape may be incorrect for llama 3 model!") sd = sd_map_replace(sd, LLAMA_SD_MAP) sd = llama_permute(sd, 32, 8) # L3 else: pass return sd