
Ruurd committed
Commit 7141e39 · verified · 1 Parent(s): 7ec3bd7

New masking implementation

Files changed (1): llama_diffusion_model.py (+6, -5)

llama_diffusion_model.py CHANGED
@@ -97,8 +97,8 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
         self.llama.resize_token_embeddings(config.vocab_size)

-        for i, layer in enumerate(self.llama.model.layers):
-            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
+        # for i, layer in enumerate(self.llama.model.layers):
+        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)

         for param in self.llama.parameters():
             param.requires_grad = False
@@ -113,7 +113,7 @@ class CustomTransformerModel(PreTrainedModel):

         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()
-        self.llama = self.llama.to(torch.float16)
+        # self.llama = self.llama.to(torch.float16)

     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_len = input_ids.shape
@@ -121,8 +121,8 @@ class CustomTransformerModel(PreTrainedModel):

         # Build attention mask
         device = input_ids.device
-
-        masking_type = getattr(self.config, "masking_type", "bidirectional")
+
+        masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
         if masking_type == 'bidirectional':
             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
         elif masking_type == 'bidirectional_masked':
@@ -142,6 +142,7 @@ class CustomTransformerModel(PreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,
+            use_cache=False,
             **kwargs
         )
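
For context, the sketch below mirrors the mask-selection logic this commit touches: a 'bidirectional' mask lets every position attend to every other position, while the new 'bidirectional_masked' default presumably also blocks each position from attending to itself. The diff does not show the body of the 'bidirectional_masked' branch, so the diagonal-masking variant and the standalone build_base_mask helper here are illustrative assumptions, not the file's actual implementation.

import torch

def build_base_mask(seq_len, masking_type="bidirectional_masked", device="cpu"):
    # Hypothetical helper mirroring the branch in forward(); only the
    # 'bidirectional' branch appears verbatim in the diff above.
    if masking_type == "bidirectional":
        # Full attention: every token may attend to every token.
        return torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    if masking_type == "bidirectional_masked":
        # Assumed variant: full attention, except each token is blocked from
        # attending to its own position (diagonal set to False).
        mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
        mask.fill_diagonal_(False)
        return mask
    raise ValueError(f"Unknown masking_type: {masking_type!r}")

if __name__ == "__main__":
    print(build_base_mask(4, "bidirectional").int())
    print(build_base_mask(4, "bidirectional_masked").int())

The other visible change, passing use_cache=False in forward(), fits the same theme: the key/value cache assumes left-to-right decoding, so with a bidirectional mask disabling it avoids carrying past_key_values state that the non-causal attention pattern would not reuse.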