helblazer811 commited on
Commit
5ba3b25
·
1 Parent(s): 4ac824d
concept_attention/flux/src/flux/util.py CHANGED
@@ -136,8 +136,8 @@ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmb
136
  model_config = AutoConfig.from_pretrained(config_file)
137
  # Load the safe tensors into a single state dict
138
  state_dict = {}
139
- state_dict.update(load_sft(safe_tensor_1, device=device))
140
- state_dict.update(load_sft(safe_tensor_2, device=device))
141
  # Load the state dict
142
  t5_encoder = AutoModel.from_pretrained(configs["flux-schnell"].repo_id, config=model_config, state_dict=state_dict)
143
 
 
136
  model_config = AutoConfig.from_pretrained(config_file)
137
  # Load the safe tensors into a single state dict
138
  state_dict = {}
139
+ state_dict.update(load_sft(safe_tensor_1, device=str(device)))
140
+ state_dict.update(load_sft(safe_tensor_2, device=str(device)))
141
  # Load the state dict
142
  t5_encoder = AutoModel.from_pretrained(configs["flux-schnell"].repo_id, config=model_config, state_dict=state_dict)
143