Fix attention_weights referenced before assigned bug
llama_diffusion_model.py  CHANGED  (+2, -1)
@@ -51,7 +51,8 @@ class BidirectionalLlamaAttention(LlamaAttention):
         # if attention_mask is not None:
         #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         #     attn_weights = attn_weights + causal_mask
-
+
+        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
         if attention_mask is not None:
             # Convert bool -> float with -inf where masked
             attn_mask = attention_mask.masked_fill(~attention_mask, float('-inf')).to(query.dtype)
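
For context, a minimal sketch of the pattern the fix restores: the attention scores are computed before the masking branch, so attn_weights is bound on every path and can no longer be "referenced before assigned" when attention_mask is None. Only the torch.matmul(query, key_states.transpose(2, 3)) * scaling line mirrors the added code; the function name, tensor shapes, boolean-mask handling, and the softmax/value step below are illustrative assumptions, not the Space's implementation.

import torch

def attention_sketch(query, key_states, value_states, attention_mask=None):
    # query, key_states, value_states: (batch, heads, seq_len, head_dim); hypothetical shapes.
    scaling = query.shape[-1] ** -0.5
    # Scores are computed unconditionally, so attn_weights exists even when
    # attention_mask is None; this is the path the commit repairs.
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Bool mask (True = attend): masked positions get -inf so softmax zeroes them out.
        attn_weights = attn_weights.masked_fill(~attention_mask, float('-inf'))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_states)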