Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8fd23f8
1
Parent(s):
3279a4e
Changed state dict
Browse files
concept_attention/flux/src/flux/util.py
CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
|
|
7 |
from imwatermark import WatermarkEncoder
|
8 |
from safetensors.torch import load_file as load_sft
|
9 |
|
10 |
-
from transformers import T5EncoderModel, AutoConfig, AutoModel
|
11 |
|
12 |
from concept_attention.flux.src.flux.model import Flux, FluxParams
|
13 |
from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
@@ -139,7 +139,12 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
|
|
139 |
state_dict.update(load_sft(safe_tensor_1, device=str(device)))
|
140 |
state_dict.update(load_sft(safe_tensor_2, device=str(device)))
|
141 |
# Load the state dict
|
142 |
-
t5_encoder =
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
145 |
# Load the safe tensors model
|
|
|
7 |
from imwatermark import WatermarkEncoder
|
8 |
from safetensors.torch import load_file as load_sft
|
9 |
|
10 |
+
from transformers import T5EncoderModel, AutoConfig, AutoModel, T5Tokenizer
|
11 |
|
12 |
from concept_attention.flux.src.flux.model import Flux, FluxParams
|
13 |
from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
|
|
139 |
state_dict.update(load_sft(safe_tensor_1, device=str(device)))
|
140 |
state_dict.update(load_sft(safe_tensor_2, device=str(device)))
|
141 |
# Load the state dict
|
142 |
+
t5_encoder = T5EncoderModel(config=model_config)
|
143 |
+
t5_encoder.load_state_dict(state_dict, strict=False)
|
144 |
+
|
145 |
+
# Load the tokenizer
|
146 |
+
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
|
147 |
+
t5_encoder.tokenizer = tokenizer
|
148 |
|
149 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
150 |
# Load the safe tensors model
|