File size: 3,293 Bytes
3133b5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
from torch import Tensor, cat, nn
class SpanMeanPooler(nn.Module):
    """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a
    learned embedding is used. The indices are expected to have the shape [batch_size,
    num_indices].

    The resulting embeddings are concatenated, so the output shape is
    [batch_size, num_indices * input_dim].

    Note this is a slightly modified version of the
    pie_modules.models.components.pooler.SpanMaxPooler, i.e. we changed the aggregation
    method from torch.amax to torch.mean.

    Args:
        input_dim: The input dimension of the hidden state.
        num_indices: The number of indices (spans) to pool.

    Returns:
        The pooled hidden states with shape [batch_size, num_indices * input_dim].
    """

    def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_indices = num_indices
        # One learned fallback embedding per span slot, used whenever a span is
        # marked as absent via a negative start or end index.
        self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
        nn.init.normal_(self.missing_embeddings)

    def forward(
        self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
    ) -> Tensor:
        """Mean-pool hidden_state over [start, end) spans, per batch entry.

        Args:
            hidden_state: Encoder output of shape [batch_size, seq_len, input_dim].
            start_indices: Span start positions, shape [batch_size, num_indices].
                A negative value marks the span as missing.
            end_indices: Exclusive span end positions, same shape/semantics.

        Returns:
            Tensor of shape [batch_size, num_indices * input_dim], matching the
            dtype and device of hidden_state.

        Raises:
            ValueError: If the index tensors do not have num_indices columns, or if
                any valid span has start >= end.
        """
        batch_size, seq_len, hidden_size = hidden_state.shape
        if start_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of start indices [{start_indices.shape[1]}] has to be the same as num_indices [{self.num_indices}]"
            )
        if end_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of end indices [{end_indices.shape[1]}] has to be the same as num_indices [{self.num_indices}]"
            )
        # Check that start_indices are before end_indices. Spans where either index
        # is negative are exempt: they are treated as "missing" below.
        mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
        mask_start_before_end = start_indices < end_indices
        mask_valid = mask_start_before_end | ~mask_both_positive
        if not torch.all(mask_valid):
            raise ValueError(
                f"values in start_indices have to be smaller than respective values in "
                f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
            )
        # Times num_indices due to concat. Allocate with the input's dtype so the
        # output matches hidden_state (previously this was implicitly float32).
        result = torch.zeros(
            batch_size,
            hidden_size * self.num_indices,
            device=hidden_state.device,
            dtype=hidden_state.dtype,
        )
        for batch_idx in range(batch_size):
            current_start_indices = start_indices[batch_idx]
            current_end_indices = end_indices[batch_idx]
            current_embeddings = [
                (
                    torch.mean(
                        hidden_state[
                            batch_idx, current_start_indices[i] : current_end_indices[i], :
                        ],
                        dim=0,
                    )
                    if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
                    # Cast the learned fallback to the input dtype: torch.cat requires
                    # uniform dtypes, and the parameter is float32 by default.
                    else self.missing_embeddings[i].to(dtype=hidden_state.dtype)
                )
                for i in range(self.num_indices)
            ]
            result[batch_idx] = cat(current_embeddings, 0)
        return result

    @property
    def output_dim(self) -> int:
        # Output concatenates num_indices pooled vectors of size input_dim each.
        return self.input_dim * self.num_indices
|