
Ruurd committed (verified)
Commit a721355 · 1 Parent(s): 8851563

Deal with float values

Files changed (1)
  1. llama_diffusion_model.py +2 -1
llama_diffusion_model.py CHANGED
@@ -31,7 +31,8 @@ class BidirectionalLlamaAttention(LlamaAttention):
         attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

         if attention_mask is not None:
-            attn_mask = attention_mask.masked_fill(~attention_mask, float('-inf')).to(query.dtype)
+            attn_mask = (1.0 - attention_mask) * float('-inf')
+            attn_mask = attn_mask.to(dtype=query.dtype)
             attn_weights = attn_weights + attn_mask

         attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
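For context, below is a minimal standalone sketch of the additive attention-mask pattern this diff moves toward: a float mask holding 1.0 for positions to attend and 0.0 for positions to ignore is converted into an additive bias and added to the raw attention scores before softmax. The function name, tensor shapes, and variable names are illustrative and not taken from llama_diffusion_model.py; the sketch also uses torch.finfo(dtype).min as the masked-out value rather than float('-inf'), since a finite value keeps the 0.0 * value product at 0.0 for attended positions (0.0 * float('-inf') evaluates to NaN).

import torch
import torch.nn.functional as F

def masked_attention_weights(scores: torch.Tensor, float_mask: torch.Tensor) -> torch.Tensor:
    # scores:     (batch, heads, q_len, k_len) raw attention scores
    # float_mask: (batch, 1, 1, k_len), 1.0 = attend, 0.0 = ignore
    # Finite large-negative fill value in the scores' dtype.
    min_value = torch.finfo(scores.dtype).min
    # Additive bias: 0.0 where the mask is 1.0, min_value where it is 0.0.
    bias = (1.0 - float_mask.to(scores.dtype)) * min_value
    return F.softmax(scores + bias, dim=-1)

# Usage: mask out the last two key positions of a toy batch.
scores = torch.randn(2, 4, 5, 5)
mask = torch.ones(2, 1, 1, 5)
mask[..., -2:] = 0.0
probs = masked_attention_weights(scores, mask)
print(probs[0, 0, 0])  # last two entries are ~0, and the row sums to 1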