Commit 22370b2 · Ruurd committed (verified) · Parent(s): 1723639

Fix attention_weights referenced-before-assignment bug

Files changed (1)
  1. llama_diffusion_model.py +2 -1
llama_diffusion_model.py CHANGED
```diff
@@ -51,7 +51,8 @@ class BidirectionalLlamaAttention(LlamaAttention):
         # if attention_mask is not None:
         #     causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         #     attn_weights = attn_weights + causal_mask
-
+
+        attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
         if attention_mask is not None:
             # Convert bool -> float with -inf where masked
             attn_mask = attention_mask.masked_fill(~attention_mask, float('-inf')).to(query.dtype)
```
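Why the change was needed: the original score computation had been commented out (the `# attn_weights = attn_weights + causal_mask` block above), so later code in the forward pass read `attn_weights` before it had ever been assigned, triggering Python's "local variable referenced before assignment" error. The commit computes the raw scores ahead of the masking branch. Below is a minimal, self-contained sketch of the corrected ordering; the function name, tensor shapes, and the simplified boolean-to-additive mask conversion are assumptions for illustration, not the Space's exact `BidirectionalLlamaAttention.forward`.

```python
# Minimal sketch of the corrected ordering, assuming (batch, heads, seq, head_dim)
# tensors; names and the mask handling are illustrative, not the Space's code.
import torch

def bidirectional_attention(query, key_states, value_states, attention_mask=None):
    scaling = query.shape[-1] ** -0.5

    # Compute raw scores first -- this is what the commit adds, so the variable
    # exists before anything reads it.
    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    if attention_mask is not None:
        # Bool mask (True = attend) -> additive float mask (0 where attended, -inf where masked).
        attn_mask = torch.zeros_like(attention_mask, dtype=query.dtype)
        attn_mask = attn_mask.masked_fill(~attention_mask, float("-inf"))
        attn_weights = attn_weights + attn_mask

    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_states)
```

Computing the scores unconditionally, and only adding the mask when one is supplied, keeps both the masked and unmasked code paths valid.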