helblazer811 committed on
Commit
4ac824d
·
1 Parent(s): c6147d1

Change t5 download code

Browse files
concept_attention/flux/src/flux/util.py CHANGED
@@ -7,7 +7,7 @@ from huggingface_hub import hf_hub_download
7
  from imwatermark import WatermarkEncoder
8
  from safetensors.torch import load_file as load_sft
9
 
10
- from transformers import T5EncoderModel
11
 
12
  from concept_attention.flux.src.flux.model import Flux, FluxParams
13
  from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
@@ -127,16 +127,30 @@ def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download:
127
  return model
128
 
129
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Load the FLUX.1-schnell T5 text encoder onto *device* in bfloat16.

    Args:
        device: Target device for the encoder (e.g. ``"cuda"``).
        max_length: Maximum token sequence length (64/128/256/512 all work if
            the input sequence is short enough). NOTE(review): not consumed by
            ``from_pretrained`` — kept for interface compatibility; the caller
            presumably applies it at tokenization time. TODO confirm.

    Returns:
        The T5 encoder module, moved to *device*.
        NOTE(review): annotated ``HFEmbedder`` but actually returns a
        ``T5EncoderModel`` — confirm against callers.
    """
    # Fix: the weights live in the repo's "text_encoder_2" directory; the
    # correct from_pretrained keyword is `subfolder` (the original `folder=`
    # is not a recognized argument), and `max_length` is not a
    # from_pretrained argument at all, so both are removed here.
    t5_encoder = T5EncoderModel.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    ).to(device)
    return t5_encoder
    # return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
 
7
  from imwatermark import WatermarkEncoder
8
  from safetensors.torch import load_file as load_sft
9
 
10
+ from transformers import T5EncoderModel, AutoConfig, AutoModel
11
 
12
  from concept_attention.flux.src.flux.model import Flux, FluxParams
13
  from concept_attention.flux.src.flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
 
127
  return model
128
 
129
def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Load the FLUX.1-schnell T5 text encoder from its sharded safetensors.

    Downloads the ``text_encoder_2`` config and both weight shards from the
    flux-schnell repo, merges the shards into one state dict, and builds the
    encoder from the config.

    Args:
        device: Target device for the encoder (e.g. ``"cuda"``).
        max_length: Maximum token sequence length (64/128/256/512 all work if
            the input sequence is short enough). NOTE(review): unused here —
            presumably applied at tokenization time; TODO confirm.

    Returns:
        The T5 encoder module, moved to *device*.
        NOTE(review): annotated ``HFEmbedder`` but actually returns a
        ``T5EncoderModel`` — confirm against callers.
    """
    repo_id = configs["flux-schnell"].repo_id
    # Download each of the files that make up the encoder checkpoint.
    config_file = hf_hub_download(repo_id, "text_encoder_2/config.json")
    shard_files = [
        hf_hub_download(repo_id, "text_encoder_2/model-00001-of-00002.safetensors"),
        hf_hub_download(repo_id, "text_encoder_2/model-00002-of-00002.safetensors"),
    ]
    # Cached for completeness; the shards are merged manually below, so the
    # index file itself is not read.
    hf_hub_download(repo_id, "text_encoder_2/model.safetensors.index.json")

    # Build the config from the downloaded config.json.
    model_config = AutoConfig.from_pretrained(config_file)

    # Merge both shards into a single state dict, loading tensors directly
    # onto the target device.
    state_dict = {}
    for shard in shard_files:
        state_dict.update(load_sft(shard, device=device))

    # Fix: use T5EncoderModel, not AutoModel — text_encoder_2 holds
    # encoder-only weights, and AutoModel would instantiate the full
    # encoder-decoder T5Model whose parameter names do not match the state
    # dict. Constructing from the config and loading the state dict also
    # avoids from_pretrained re-resolving weights from the repo root.
    t5_encoder = T5EncoderModel(model_config)
    t5_encoder.load_state_dict(state_dict)
    # Fix: the original never moved the model to `device`; the parameter was
    # only used for tensor loading.
    t5_encoder = t5_encoder.to(device)

    return t5_encoder
    # return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)