Deal with float values
llama_diffusion_model.py (+2 -1)
@@ -31,7 +31,8 @@ class BidirectionalLlamaAttention(LlamaAttention):
         attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

         if attention_mask is not None:
-            attn_mask =
+            attn_mask = (1.0 - attention_mask) * float('-inf')
+            attn_mask = attn_mask.to(dtype=query.dtype)
             attn_weights = attn_weights + attn_mask

         attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)
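For context, a minimal standalone sketch of what the new lines do: a 0/1 attention mask is turned into an additive float mask (0 where attended, a very large negative value where masked), cast to the score dtype, and added to the attention scores before the softmax. Shapes and values below are made up for illustration; torch.finfo(dtype).min stands in for the commit's float('-inf'), since (1.0 - 1.0) * float('-inf') evaluates to NaN at attended positions if the mask holds exact ones.

import torch

# Minimal, self-contained sketch of the additive attention-mask construction in
# this commit. Illustrative shapes/values only; in the model the scores may be
# fp16/bf16, which is why the mask is cast to the query dtype.
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # 1 = attend, 0 = ignore
attn_weights = torch.randn(1, 1, 4, 4)                 # raw attention scores

# 0 where attended, a very large negative value where masked.
# torch.finfo(dtype).min is used here instead of the commit's float('-inf'),
# because (1.0 - 1.0) * float('-inf') would produce NaN at attended positions.
attn_mask = (1.0 - attention_mask) * torch.finfo(attn_weights.dtype).min
attn_mask = attn_mask.to(dtype=attn_weights.dtype)     # match score dtype, as the new line does

attn_weights = attn_weights + attn_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
print(attn_weights[0, 0, 0])  # last (masked) position receives ~0 probability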