
Ruurd committed (verified)
Commit 7d7b6d7 · 1 parent: 09a7f62

Changed to bidirectional

Files changed (1):
  llama_diffusion_model.py  +2 −2
llama_diffusion_model.py CHANGED

@@ -77,7 +77,7 @@ class BidirectionalLlamaAttention(LlamaAttention):
 
 class CustomTransformerConfig(PretrainedConfig):
     def __init__(self, vocab_size=128256, hidden_size=4096, num_layers=32, num_heads=32, prediction_chunk=256, dropout=0,
-                 max_position_embeddings=4096, masking_type="bidirectional_masked", **kwargs):
+                 max_position_embeddings=4096, masking_type="bidirectional", **kwargs):
         super().__init__(**kwargs)
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
@@ -122,7 +122,7 @@ class CustomTransformerModel(PreTrainedModel):
         # Build attention mask
         device = input_ids.device
 
-        masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
+        masking_type = getattr(self.config, "masking_type", "bidirectional")
         if masking_type == 'bidirectional':
             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
         elif masking_type == 'bidirectional_masked':
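
For context on what the new default selects, here is a minimal standalone sketch based only on the 'bidirectional' branch visible in the second hunk. The body of the 'bidirectional_masked' branch is cut off by the diff, so it is not reproduced here; seq_len and the contrasting causal mask below are illustrative assumptions, not code from the commit.

import torch

seq_len = 8
device = torch.device("cpu")

# The construction from the diff's 'bidirectional' branch: an all-True
# mask, so every position may attend to every other position.
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)

# For contrast, a standard causal mask (not what this commit selects):
# position i may only attend to positions j <= i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))

assert base_mask.all()        # bidirectional: nothing is masked out
assert not causal_mask.all()  # causal: the upper triangle is masked

The commit changes both the CustomTransformerConfig default and the getattr fallback in the model, so configs saved without a masking_type field will now also resolve to the fully bidirectional mask.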