Maxime commited on
Commit
a184549
·
unverified ·
1 Parent(s): f311df9

ignore: linter

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +1 -1
src/axolotl/utils/models.py CHANGED
@@ -368,7 +368,7 @@ def load_model(
368
 
369
  # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
370
  # convert them back to fp16/bf16 for flash-attn compatibility.
371
- if (fix_dtype or cfg.adapter == "" or cfg.adapter == None) and (
372
  cfg.flash_attention and cfg.is_llama_derived_model
373
  ):
374
  for name, module in model.named_modules():
 
368
 
369
  # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
370
  # convert them back to fp16/bf16 for flash-attn compatibility.
371
+ if (fix_dtype or cfg.adapter == "" or cfg.adapter is None) and (
372
  cfg.flash_attention and cfg.is_llama_derived_model
373
  ):
374
  for name, module in model.named_modules():