anas-awadalla committed on
Commit 9c16cd6 · 1 Parent(s): 64c725f

Update open_flamingo/open_flamingo/src/factory.py

open_flamingo/open_flamingo/src/factory.py CHANGED

@@ -59,9 +59,7 @@ def create_model_and_transforms(
     lang_encoder = AutoModelForCausalLM.from_pretrained(
         lang_encoder_path,
         local_files_only=use_local_files,
-        trust_remote_code=True,
-
-    ).to(device, dtype=torch.bfloat16) if device > -1 else None
+        trust_remote_code=True).to(device, dtype=torch.bfloat16) if device > -1 else None
 
     # hacks for MPT-1B, which doesn't have a get_input_embeddings method
     if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -92,8 +90,7 @@
             "width"
         ],
         cross_attn_every_n_layers=cross_attn_every_n_layers,
-        **flamingo_kwargs,
-    ).to(device, dtype=torch.bfloat16) if device > -1 else None
+        **flamingo_kwargs).to(device, dtype=torch.bfloat16) if device > -1 else None
 
     # Freeze all parameters
     model.requires_grad_(False)
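
Both hunks make the same kind of change: the trailing keyword argument and the closing `).to(...)` are joined onto a single line, so the whole conditional expression still evaluates to the module cast to bfloat16 on `device` when `device > -1`, and to `None` otherwise. Below is a minimal, self-contained sketch of that pattern for the language-encoder case only; the values of `lang_encoder_path`, `use_local_files`, and `device` are placeholders for illustration and do not come from this commit (in factory.py they are parameters of `create_model_and_transforms`).

import torch
from transformers import AutoModelForCausalLM

# Placeholder inputs, assumed for this sketch only.
lang_encoder_path = "mosaicml/mpt-1b-redpajama-200b"  # hypothetical checkpoint id
use_local_files = False
device = 0  # CUDA device index; -1 means "skip loading", matching the `else None` branch

# The pattern touched by the commit: load the model, move it to the target
# device in bfloat16, and fall back to None when no device is requested.
lang_encoder = (
    AutoModelForCausalLM.from_pretrained(
        lang_encoder_path,
        local_files_only=use_local_files,
        trust_remote_code=True,
    ).to(device, dtype=torch.bfloat16)
    if device > -1
    else None
)

The second hunk applies the same line-joining to the Flamingo wrapper constructed with `**flamingo_kwargs`; the behavior of the expression is unchanged in both places.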