Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
4ac824d
1
Parent(s):
c6147d1
Change t5 download code
Browse files
concept_attention/flux/src/flux/util.py
CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
|
|
7 |
from imwatermark import WatermarkEncoder
|
8 |
from safetensors.torch import load_file as load_sft
|
9 |
|
10 |
-
from transformers import T5EncoderModel
|
11 |
|
12 |
from concept_attention.flux.src.flux.model import Flux, FluxParams
|
13 |
from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
@@ -127,16 +127,30 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
|
|
127 |
return model
|
128 |
|
129 |
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
131 |
# Load the safe tensors model
|
132 |
# ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
|
133 |
# return T5Encoder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
134 |
-
t5_encoder = T5EncoderModel.from_pretrained(
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
).to(device)
|
140 |
|
141 |
return t5_encoder
|
142 |
# return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
|
|
7 |
from imwatermark import WatermarkEncoder
|
8 |
from safetensors.torch import load_file as load_sft
|
9 |
|
10 |
+
from transformers import T5EncoderModel, AutoConfig, AutoModel
|
11 |
|
12 |
from concept_attention.flux.src.flux.model import Flux, FluxParams
|
13 |
from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
|
|
127 |
return model
|
128 |
|
129 |
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
|
130 |
+
# Download each of the files
|
131 |
+
config_file = hf_hub_download(configs["flux-schnell"].repo_id, "text_encoder_2/config.json") # File 1: config.json
|
132 |
+
safe_tensor_1 = hf_hub_download(configs["flux-schnell"].repo_id, "text_encoder_2/model-00001-of-00002.safetensors") # File 2: model-00001-of-00002.safetensors
|
133 |
+
safe_tensor_2 = hf_hub_download(configs["flux-schnell"].repo_id, "text_encoder_2/model-00002-of-00002.safetensors") # File 3: model-00002-of-00002.safetensors
|
134 |
+
safetensor_index = hf_hub_download(configs["flux-schnell"].repo_id, "text_encoder_2/model.safetensors.index.json") # File 4: model.safetensors.index.json
|
135 |
+
# Auto config the model from the loaded config
|
136 |
+
model_config = AutoConfig.from_pretrained(config_file)
|
137 |
+
# Load the safe tensors into a single state dict
|
138 |
+
state_dict = {}
|
139 |
+
state_dict.update(load_sft(safe_tensor_1, device=device))
|
140 |
+
state_dict.update(load_sft(safe_tensor_2, device=device))
|
141 |
+
# Load the state dict
|
142 |
+
t5_encoder = AutoModel.from_pretrained(configs["flux-schnell"].repo_id, config=model_config, state_dict=state_dict)
|
143 |
+
|
144 |
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
145 |
# Load the safe tensors model
|
146 |
# ckpt_path = hf_hub_download(configs["name"].repo_id, configs["name"].repo_flow)
|
147 |
# return T5Encoder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
148 |
+
# t5_encoder = T5EncoderModel.from_pretrained(
|
149 |
+
# "black-forest-labs/FLUX.1-schnell",
|
150 |
+
# folder="text_encoder_2",
|
151 |
+
# max_length=max_length,
|
152 |
+
# torch_dtype=torch.bfloat16,
|
153 |
+
# ).to(device)
|
154 |
|
155 |
return t5_encoder
|
156 |
# return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|