import torch

from torch import Tensor, cat, nn


class SpanMeanPooler(nn.Module):
    """Pooler that takes the mean hidden state over spans. If the start or end index is
    negative, a learned embedding is used instead. The indices are expected to have the shape
    [batch_size, num_indices].

    The resulting embeddings are concatenated, so the output shape is
    [batch_size, num_indices * input_dim]. Note that this is a slightly modified version of
    pie_modules.models.components.pooler.SpanMaxPooler, i.e. we changed the aggregation
    method from torch.amax to torch.mean.

    Args:
        input_dim: The input dimension of the hidden state.
        num_indices: The number of indices (spans) to pool.

    Returns:
        The pooled hidden states with shape [batch_size, num_indices * input_dim].
    """

    def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_indices = num_indices
        # Learned fallback embeddings, one per pooled span, used whenever a span has a
        # negative start or end index.
        self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
        nn.init.normal_(self.missing_embeddings)

    def forward(
        self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
    ) -> Tensor:
        batch_size, seq_len, hidden_size = hidden_state.shape
        if start_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of start indices [{start_indices.shape[1]}] has to be the same as "
                f"num_indices [{self.num_indices}]"
            )
        if end_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of end indices [{end_indices.shape[1]}] has to be the same as "
                f"num_indices [{self.num_indices}]"
            )

        # Spans with both indices non-negative must satisfy start < end; spans with a
        # negative index are exempt because they fall back to the learned missing embedding.
        mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
        mask_start_before_end = start_indices < end_indices
        mask_valid = mask_start_before_end | ~mask_both_positive
        if not torch.all(mask_valid):
            raise ValueError(
                f"values in start_indices have to be smaller than respective values in "
                f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
            )

        result = torch.zeros(
            batch_size,
            hidden_size * self.num_indices,
            device=hidden_state.device,
            dtype=hidden_state.dtype,
        )
        for batch_idx in range(batch_size):
            current_start_indices = start_indices[batch_idx]
            current_end_indices = end_indices[batch_idx]
            # Mean-pool each span; use the learned embedding for missing spans.
            current_embeddings = [
                (
                    torch.mean(
                        hidden_state[
                            batch_idx, current_start_indices[i] : current_end_indices[i], :
                        ],
                        dim=0,
                    )
                    if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
                    else self.missing_embeddings[i]
                )
                for i in range(self.num_indices)
            ]
            result[batch_idx] = cat(current_embeddings, 0)

        return result

    @property
    def output_dim(self) -> int:
        return self.input_dim * self.num_indices
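

# A minimal usage sketch (illustrative, not part of the original module): it pools two
# spans per example and marks one span in the second example as missing via negative
# indices. The concrete shapes and values are assumptions chosen to match the docstring.
if __name__ == "__main__":
    pooler = SpanMeanPooler(input_dim=8, num_indices=2)
    hidden_state = torch.randn(2, 10, 8)  # [batch_size, seq_len, input_dim]
    start_indices = torch.tensor([[1, 4], [0, -1]])  # negative index -> missing span
    end_indices = torch.tensor([[3, 6], [2, -1]])
    pooled = pooler(hidden_state, start_indices, end_indices)
    assert pooled.shape == (2, pooler.output_dim)  # [batch_size, num_indices * input_dim]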