Spaces:
Running on Zero

Ruurd commited on
Commit
238c8f8
·
verified ·
1 Parent(s): f2ca6a6

Update llama_diffusion_model.py

Browse files
Files changed (1) hide show
  1. llama_diffusion_model.py +3 -3
llama_diffusion_model.py CHANGED
@@ -122,12 +122,12 @@ class CustomTransformerModel(PreTrainedModel):
122
  device = input_ids.device
123
 
124
  masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
125
- if self.config.masking_type == 'bidirectional':
126
  base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
127
- elif self.config.masking_type == 'bidirectional_masked':
128
  base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
129
  base_mask.fill_diagonal_(False)
130
- elif self.config.masking_type == 'unidirectional':
131
  base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
132
  else:
133
  raise ValueError(f"Unknown masking type: {self.config.masking_type}")
 
122
  device = input_ids.device
123
 
124
  masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
125
+ if masking_type == 'bidirectional':
126
  base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
127
+ elif masking_type == 'bidirectional_masked':
128
  base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
129
  base_mask.fill_diagonal_(False)
130
+ elif masking_type == 'unidirectional':
131
  base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
132
  else:
133
  raise ValueError(f"Unknown masking type: {self.config.masking_type}")