Nicolas-BZRD committed
Commit f4dc0e6 · verified · 1 Parent(s): 919fbb3

Fix: flash_attention_2 mask

Files changed (1)
  1. modeling_eurobert.py +11 -11
modeling_eurobert.py CHANGED
@@ -26,15 +26,15 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from transformers.activations import ACT2FN
-from transformers.cache_utils import Cache, StaticCache
-from transformers.modeling_attn_mask_utils import AttentionMaskConverter
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
-from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from transformers.processing_utils import Unpack
-from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
+from ...activations import ACT2FN
+from ...cache_utils import Cache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_eurobert import EuroBertConfig

@@ -224,7 +224,7 @@ EUROBERT_START_DOCSTRING = r"""

 @add_start_docstrings(
-    "The bare ModernBert Model outputting raw hidden-states without any specific head on top.",
+    "The bare EuroBERT Model outputting raw hidden-states without any specific head on top.",
     EUROBERT_START_DOCSTRING,
 )
 class EuroBertPreTrainedModel(PreTrainedModel):

@@ -523,7 +523,7 @@ class EuroBertModel(EuroBertPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None:
+        if attention_mask is not None and self.config._attn_implementation != "flash_attention_2":
             mask = self.mask_converter.to_4d(attention_mask, attention_mask.shape[1], inputs_embeds.dtype)
         else:
             mask = None
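Why the mask change matters: transformers' eager and SDPA attention paths consume a 4D additive float mask of shape (batch, 1, query_len, key_len), which AttentionMaskConverter.to_4d builds from the 2D padding mask. The Flash Attention 2 integration instead un-pads sequences directly from the raw 2D mask and cannot use a 4D float mask, so building one under attn_implementation="flash_attention_2" breaks that path. Below is a minimal sketch of the fixed branch; build_encoder_mask is a hypothetical standalone helper (the real logic sits inline in EuroBertModel.forward), and the is_causal=False converter is an assumption consistent with a bidirectional encoder.

import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

def build_encoder_mask(attention_mask, inputs_embeds, attn_implementation):
    # Hypothetical helper mirroring the fixed branch in EuroBertModel.forward.
    # Bidirectional converter: EuroBERT is an encoder, so no causal masking.
    mask_converter = AttentionMaskConverter(is_causal=False)
    if attention_mask is not None and attn_implementation != "flash_attention_2":
        # Eager/SDPA expect an additive 4D mask of shape
        # (batch, 1, seq_len, seq_len): 0.0 where attended, dtype-min where padded.
        return mask_converter.to_4d(
            attention_mask, attention_mask.shape[1], inputs_embeds.dtype
        )
    # Flash Attention 2 receives the raw 2D padding mask downstream instead,
    # so no 4D mask is built here.
    return None

# Quick check of the non-flash branch:
attention_mask = torch.tensor([[1, 1, 1, 0]])     # last position is padding
inputs_embeds = torch.zeros(1, 4, 8)              # (batch, seq, hidden)
mask = build_encoder_mask(attention_mask, inputs_embeds, "sdpa")
print(mask.shape)                                 # torch.Size([1, 1, 4, 4])

To exercise the flash path end to end, one would load the model with the flag that populates self.config._attn_implementation (assuming the EuroBERT/EuroBERT-210m checkpoint, flash-attn installed, and a CUDA device):

from transformers import AutoModel

model = AutoModel.from_pretrained(
    "EuroBERT/EuroBERT-210m",
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)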