
Ruurd committed · Commit 0daaccf · 1 Parent(s): a7ab71d

Try out bidirectional_masked prediction

Files changed (1)
  1. llama_diffusion_model.py +2 -2
llama_diffusion_model.py CHANGED
@@ -126,7 +126,7 @@ class BidirectionalLlamaAttention(LlamaAttention):
             attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()  # ✅ Copy for each batch
         elif self.masking == 'bidirectional_masked':
             base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
-            base_mask[:, 1:].fill_diagonal_(False)  # ✅ Apply diagonal masking only in 2D
+            base_mask[:, :].fill_diagonal_(False)  # ✅ Apply diagonal masking only in 2D
             attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()  # ✅ Copy for each batch
         else:  # unidirectional
             # 🚀 Standard autoregressive (causal) mask
@@ -192,7 +192,7 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama.resize_token_embeddings(config.vocab_size)

         for i, layer in enumerate(self.llama.model.layers):
-            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional')
+            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional_masked')

         # Freeze Llama to retain pre-trained knowledge
         for param in self.llama.parameters():
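
The one-line mask change swaps which diagonal of the boolean attention mask is cleared: the old base_mask[:, 1:].fill_diagonal_(False) operates on a (seq_len, seq_len - 1) view and therefore blocks each position from attending to the token immediately after it, while the new base_mask[:, :].fill_diagonal_(False) clears the main diagonal, so every position attends to all others but not to itself. The second change then switches the wrapped layers from 'bidirectional' to this 'bidirectional_masked' mode. Below is a minimal standalone sketch (not part of the repository; seq_len = 4 is an arbitrary illustrative value) that prints both masks for comparison:

import torch

seq_len = 4  # small, illustrative sequence length (assumption, not from the repo)

# Old behaviour: base_mask[:, 1:] is a (seq_len, seq_len - 1) view, so
# fill_diagonal_(False) clears positions (i, i + 1) of the full mask —
# each token is prevented from attending to the token right after it.
old_mask = torch.ones((seq_len, seq_len), dtype=torch.bool)
old_mask[:, 1:].fill_diagonal_(False)

# New behaviour from this commit: the main diagonal (i, i) is cleared,
# so each token attends to every position except itself.
new_mask = torch.ones((seq_len, seq_len), dtype=torch.bool)
new_mask[:, :].fill_diagonal_(False)

print(old_mask.long())
# tensor([[1, 0, 1, 1],
#         [1, 1, 0, 1],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])

print(new_mask.long())
# tensor([[0, 1, 1, 1],
#         [1, 0, 1, 1],
#         [1, 1, 0, 1],
#         [1, 1, 1, 0]])

As in the diff, True marks positions a query may attend to; the mask is later expanded per batch and head before being passed to the attention call.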