
Ruurd committed (verified)
Commit 1723639 · 1 parent: 5213031

Implement improved attention masking for bidirectional_masked

Files changed (1):
  1. llama_diffusion_model.py (+9, -3)
llama_diffusion_model.py CHANGED
@@ -47,10 +47,16 @@ class BidirectionalLlamaAttention(LlamaAttention):
         key_states = self.repeat_kv(key, module.num_key_value_groups)
         value_states = self.repeat_kv(value, module.num_key_value_groups)
 
-        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+        # attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+        # if attention_mask is not None:
+        #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        #     attn_weights = attn_weights + causal_mask
+
         if attention_mask is not None:
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
+            # Convert bool -> float with -inf where masked
+            attn_mask = attention_mask.masked_fill(~attention_mask, float('-inf')).to(query.dtype)
+            attn_weights = attn_weights + attn_mask
+
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
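
For context, here is a minimal, self-contained sketch of the masking pattern the new code relies on: a boolean attention mask (True = attend) is turned into an additive float mask with 0 at allowed positions and -inf at masked positions, then added to the attention logits before the softmax. The shapes, values, and variable names below are illustrative assumptions, not taken from the repository.

import torch
import torch.nn.functional as F

# Illustrative shapes only (not from the repo).
batch, heads, seq = 1, 2, 4
attn_weights = torch.randn(batch, heads, seq, seq)            # scaled QK^T logits
attention_mask = torch.tensor([[True, True, True, False]])    # bool: True = attend

# Build an additive mask that broadcasts to (batch, heads, seq, seq):
# 0.0 where attention is allowed, -inf where the key position is masked out.
additive_mask = torch.zeros(batch, 1, 1, seq)
additive_mask = additive_mask.masked_fill(~attention_mask[:, None, None, :], float("-inf"))

attn_probs = F.softmax(attn_weights + additive_mask, dim=-1)  # masked keys get probability 0

Unlike the causal slice in the removed lines, nothing here blocks attention to later positions, so every unmasked token can attend to every other unmasked token, which is the point of the bidirectional_masked variant named in the commit message.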