Spaces: Running on Zero
Try out bidirectional_masked prediction
Browse files
llama_diffusion_model.py +2 -2
llama_diffusion_model.py
CHANGED
@@ -126,7 +126,7 @@ class BidirectionalLlamaAttention(LlamaAttention):
|
|
126 |
attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone() # ✅ Copy for each batch
|
127 |
elif self.masking == 'bidirectional_masked':
|
128 |
base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
|
129 |
-
base_mask[:, …  [removed line truncated in this page capture; the replacement shown on the new side of the diff (line 129) is `base_mask[:, :].fill_diagonal_(False)`]
|
130 |
attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone() # ✅ Copy for each batch
|
131 |
else: # unidirectional
|
132 |
# 🚀 Standard autoregressive (causal) mask
|
@@ -192,7 +192,7 @@ class CustomTransformerModel(PreTrainedModel):
|
|
192 |
self.llama.resize_token_embeddings(config.vocab_size)
|
193 |
|
194 |
for i, layer in enumerate(self.llama.model.layers):
|
195 |
-
layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='…  [removed line truncated in this page capture; the replacement shown on the new side of the diff (line 195) passes masking='bidirectional_masked']
|
196 |
|
197 |
# Freeze Llama to retain pre-trained knowledge
|
198 |
for param in self.llama.parameters():
|
|
|
126 |
attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone() # ✅ Copy for each batch
|
127 |
elif self.masking == 'bidirectional_masked':
|
128 |
base_mask = torch.ones((seq_len, seq_len), device=hidden_states.device, dtype=torch.bool)
|
129 |
+
base_mask[:, :].fill_diagonal_(False) # ✅ Apply diagonal masking only in 2D
|
130 |
attn_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone() # ✅ Copy for each batch
|
131 |
else: # unidirectional
|
132 |
# 🚀 Standard autoregressive (causal) mask
|
|
|
192 |
self.llama.resize_token_embeddings(config.vocab_size)
|
193 |
|
194 |
for i, layer in enumerate(self.llama.model.layers):
|
195 |
+
layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking='bidirectional_masked')
|
196 |
|
197 |
# Freeze Llama to retain pre-trained knowledge
|
198 |
for param in self.llama.parameters():
|