Spaces:
Running
on
Zero
Running
on
Zero
Update llama_diffusion_model.py
Browse files
- llama_diffusion_model.py +3 -3
llama_diffusion_model.py
CHANGED
@@ -122,12 +122,12 @@ class CustomTransformerModel(PreTrainedModel):
|
|
122 |
device = input_ids.device
|
123 |
|
124 |
masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
|
125 |
-
if
|
126 |
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
|
127 |
-
elif
|
128 |
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
|
129 |
base_mask.fill_diagonal_(False)
|
130 |
-
elif
|
131 |
base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
|
132 |
else:
|
133 |
raise ValueError(f"Unknown masking type: {self.config.masking_type}")
|
|
|
122 |
device = input_ids.device
|
123 |
|
124 |
masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
|
125 |
+
if masking_type == 'bidirectional':
|
126 |
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
|
127 |
+
elif masking_type == 'bidirectional_masked':
|
128 |
base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
|
129 |
base_mask.fill_diagonal_(False)
|
130 |
+
elif masking_type == 'unidirectional':
|
131 |
base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
|
132 |
else:
|
133 |
raise ValueError(f"Unknown masking type: {self.config.masking_type}")
|