helblazer811 committed on
Commit
8fd23f8
·
1 Parent(s): 3279a4e

Changed state dict

Browse files
concept_attention/flux/src/flux/util.py CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
7
  from imwatermark import WatermarkEncoder
8
  from safetensors.torch import load_file as load_sft
9
 
10
- from transformers import T5EncoderModel, AutoConfig, AutoModel
11
 
12
  from concept_attention.flux.src.flux.model import Flux, FluxParams
13
  from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
@@ -139,7 +139,12 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
139
  state_dict.update(load_sft(safe_tensor_1, device=str(device)))
140
  state_dict.update(load_sft(safe_tensor_2, device=str(device)))
141
  # Load the state dict
142
- t5_encoder = AutoModel.from_config(config=model_config, state_dict=state_dict)
 
 
 
 
 
143
 
144
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
145
  # Load the safe tensors model
 
7
  from imwatermark import WatermarkEncoder
8
  from safetensors.torch import load_file as load_sft
9
 
10
+ from transformers import T5EncoderModel, AutoConfig, AutoModel, T5Tokenizer
11
 
12
  from concept_attention.flux.src.flux.model import Flux, FluxParams
13
  from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 
139
  state_dict.update(load_sft(safe_tensor_1, device=str(device)))
140
  state_dict.update(load_sft(safe_tensor_2, device=str(device)))
141
  # Load the state dict
142
+ t5_encoder = T5EncoderModel(config=model_config)
143
+ t5_encoder.load_state_dict(state_dict, strict=False)
144
+
145
+ # Load the tokenizer
146
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
147
+ t5_encoder.tokenizer = tokenizer
148
 
149
  # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
150
  # Load the safe tensors model