jbilcke-hf (HF Staff) committed
Commit 9846dba · 1 Parent(s): 0431fa9

upgrade finetrainers

finetrainers/args.py CHANGED
@@ -447,7 +447,7 @@ class BaseArgs:
         }
 
         training_arguments = {
-            "training_type":self.training_type,
+            "training_type": self.training_type,
             "seed": self.seed,
             "batch_size": self.batch_size,
             "train_steps": self.train_steps,
finetrainers/patches/__init__.py CHANGED
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
     if parallel_backend.tensor_parallel_enabled:
         patch.patch_apply_rotary_emb_for_tp_compatibility()
 
+    if args.model_name == ModelType.WAN:
+        from .models.wan import patch
+
+        patch.patch_time_text_image_embedding_forward()
+
     if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
-        from dependencies.peft import patch
+        from .dependencies.peft import patch
 
         patch.patch_peft_move_adapter_to_device_of_base_layer()
finetrainers/patches/models/ltx_video/patch.py CHANGED
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:
 
 
 def _perform_ltx_transformer_forward_patch() -> None:
-    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
+    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
 
 
 def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
@@ -35,7 +35,7 @@ def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
     diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
 
 
-def _patched_LTXVideoTransformer3Dforward(
+def _patched_LTXVideoTransformer3D_forward(
     self,
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
finetrainers/patches/models/wan/patch.py ADDED
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import diffusers
+import torch
+
+
+def patch_time_text_image_embedding_forward() -> None:
+    _patch_time_text_image_embedding_forward()
+
+
+def _patch_time_text_image_embedding_forward() -> None:
+    diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
+        _patched_WanTimeTextImageEmbedding_forward
+    )
+
+
+def _patched_WanTimeTextImageEmbedding_forward(
+    self,
+    timestep: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states_image: Optional[torch.Tensor] = None,
+):
+    # Some code has been removed compared to original implementation in Diffusers
+    # Also, timestep is typed as that of encoder_hidden_states
+    timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
+    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+    timestep_proj = self.time_proj(self.act_fn(temb))
+
+    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+    if encoder_hidden_states_image is not None:
+        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
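The new module replaces the class-level forward of Diffusers' WanTimeTextImageEmbedding. A minimal sketch of applying and checking the patch, assuming finetrainers is importable as an installed package and the installed Diffusers build ships transformer_wan (both assumptions, not part of this commit):

import diffusers

from finetrainers.patches.models.wan import patch

# Swap in the simplified forward defined in patch.py above.
patch.patch_time_text_image_embedding_forward()

wan_module = diffusers.models.transformers.transformer_wan
# After patching, the class-level forward is the replacement function.
assert wan_module.WanTimeTextImageEmbedding.forward.__name__ == "_patched_WanTimeTextImageEmbedding_forward"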
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -334,6 +334,7 @@ class SFTTrainer:
         parallel_backend = self.state.parallel_backend
         train_state = self.state.train_state
         device = parallel_backend.device
+        dtype = self.args.transformer_dtype
 
         memory_statistics = utils.get_memory_statistics()
         logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -447,8 +448,8 @@ class SFTTrainer:
 
             logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
 
-            utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
-            utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
+            latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
+            condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
             latent_model_conditions = utils.make_contiguous(latent_model_conditions)
             condition_model_conditions = utils.make_contiguous(condition_model_conditions)
 
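The trainer change binds the return value of utils.align_device_and_dtype instead of discarding it, reusing the transformer dtype captured earlier in the method. A minimal sketch of why the rebinding matters, using a hypothetical stand-in for the finetrainers helper (its real behavior is only inferred from the call sites above):

import torch

def align_device_and_dtype(conditions, device, dtype):
    # Hypothetical stand-in: returns a new dict rather than mutating in place,
    # which is why the caller must rebind the name to the returned value.
    return {
        key: value.to(device=device, dtype=dtype) if torch.is_tensor(value) and value.is_floating_point() else value
        for key, value in conditions.items()
    }

latent_model_conditions = {"latents": torch.randn(1, 4, 8, 8)}
latent_model_conditions = align_device_and_dtype(latent_model_conditions, torch.device("cpu"), torch.bfloat16)
assert latent_model_conditions["latents"].dtype == torch.bfloat16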