anas-awadalla commited on
Commit
7bf035c
·
1 Parent(s): 350967e

Update open_flamingo/open_flamingo/src/factory.py

Browse files
open_flamingo/open_flamingo/src/factory.py CHANGED
@@ -16,7 +16,6 @@ def create_model_and_transforms(
16
  use_local_files: bool = False,
17
  decoder_layers_attr_name: str = None,
18
  freeze_lm_embeddings: bool = False,
19
- device: int = 0,
20
  **flamingo_kwargs,
21
  ):
22
  """
@@ -41,7 +40,6 @@ def create_model_and_transforms(
41
  )
42
  # set the vision encoder to output the visual features
43
  vision_encoder.visual.output_tokens = True
44
- vision_encoder.to(device, dtype=torch.bfloat16) if device > -1 else None
45
 
46
  text_tokenizer = AutoTokenizer.from_pretrained(
47
  tokenizer_path,
@@ -60,7 +58,7 @@ def create_model_and_transforms(
60
  lang_encoder = AutoModelForCausalLM.from_pretrained(
61
  lang_encoder_path,
62
  local_files_only=use_local_files,
63
- trust_remote_code=True).to(device, dtype=torch.bfloat16) if device > -1 else None
64
 
65
  # hacks for MPT-1B, which doesn't have a get_input_embeddings method
66
  if "mpt-1b-redpajama-200b" in lang_encoder_path:
@@ -91,7 +89,7 @@ def create_model_and_transforms(
91
  "width"
92
  ],
93
  cross_attn_every_n_layers=cross_attn_every_n_layers,
94
- **flamingo_kwargs).to(device, dtype=torch.bfloat16) if device > -1 else None
95
 
96
  # Freeze all parameters
97
  model.requires_grad_(False)
 
16
  use_local_files: bool = False,
17
  decoder_layers_attr_name: str = None,
18
  freeze_lm_embeddings: bool = False,
 
19
  **flamingo_kwargs,
20
  ):
21
  """
 
40
  )
41
  # set the vision encoder to output the visual features
42
  vision_encoder.visual.output_tokens = True
 
43
 
44
  text_tokenizer = AutoTokenizer.from_pretrained(
45
  tokenizer_path,
 
58
  lang_encoder = AutoModelForCausalLM.from_pretrained(
59
  lang_encoder_path,
60
  local_files_only=use_local_files,
61
+ trust_remote_code=True)
62
 
63
  # hacks for MPT-1B, which doesn't have a get_input_embeddings method
64
  if "mpt-1b-redpajama-200b" in lang_encoder_path:
 
89
  "width"
90
  ],
91
  cross_attn_every_n_layers=cross_attn_every_n_layers,
92
+ **flamingo_kwargs)
93
 
94
  # Freeze all parameters
95
  model.requires_grad_(False)