New masking implementation
llama_diffusion_model.py (+6 -5)
@@ -97,8 +97,8 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B", torch_dtype=torch.float16, device_map="auto", token=hf_token)
         self.llama.resize_token_embeddings(config.vocab_size)
 
-        for i, layer in enumerate(self.llama.model.layers):
-            layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
+        # for i, layer in enumerate(self.llama.model.layers):
+        #     layer.self_attn = BidirectionalLlamaAttention(layer.self_attn, masking=config.masking_type)
 
         for param in self.llama.parameters():
             param.requires_grad = False
@@ -113,7 +113,7 @@ class CustomTransformerModel(PreTrainedModel):
 
         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()
-        self.llama = self.llama.to(torch.float16)
+        # self.llama = self.llama.to(torch.float16)
 
     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_len = input_ids.shape
@@ -121,8 +121,8 @@ class CustomTransformerModel(PreTrainedModel):
 
         # Build attention mask
         device = input_ids.device
-
-        masking_type = getattr(self.config, "masking_type", "
+
+        masking_type = getattr(self.config, "masking_type", "bidirectional_masked")
         if masking_type == 'bidirectional':
             base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
         elif masking_type == 'bidirectional_masked':
@@ -142,6 +142,7 @@ class CustomTransformerModel(PreTrainedModel):
             input_ids,
             attention_mask=attention_mask,
             output_hidden_states=True,
+            use_cache=False,
             **kwargs
         )
 
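For context on what the new forward pass is doing: the mask is now selected at call time from config.masking_type instead of by swapping each layer's self-attention module for BidirectionalLlamaAttention. The repository's full mask-building code is not visible in these hunks, so the following is only a minimal sketch of how the boolean (seq_len, seq_len) base mask from the 'bidirectional' branch above could be expanded into the additive 4D mask a Hugging Face Llama forward pass accepts; the helper name build_attention_mask, the expansion step, and the fallback for 'bidirectional_masked' are assumptions, not code from this commit.

import torch

def build_attention_mask(input_ids, masking_type="bidirectional_masked", dtype=torch.float16):
    # Hypothetical helper (not from this repo): turns a boolean (seq_len, seq_len)
    # base mask into the additive (batch, 1, seq_len, seq_len) mask that the
    # Hugging Face Llama forward pass accepts.
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    if masking_type == "bidirectional":
        # Every position may attend to every other position, as in the diff above.
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    else:
        # The 'bidirectional_masked' branch is not shown in these hunks;
        # a fully bidirectional mask stands in for it here as a placeholder.
        base_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)

    # 0.0 where attention is allowed, a large negative value where it is blocked.
    additive = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    additive = additive.masked_fill(~base_mask, torch.finfo(dtype).min)
    return additive.expand(batch_size, 1, seq_len, seq_len)

# Usage matching the forward call in the diff:
# attention_mask = build_attention_mask(input_ids, getattr(config, "masking_type", "bidirectional_masked"))
# outputs = model.llama(input_ids, attention_mask=attention_mask, output_hidden_states=True, use_cache=False)

With a fully or partially bidirectional mask there is nothing for a KV cache to reuse across decoding steps, which is consistent with the forward call now passing use_cache=False alongside output_hidden_states=True.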