import torch
from torch import Tensor, cat, nn


class SpanMeanPooler(nn.Module):
    """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a
    learned embedding is used. The indices are expected to have the shape [batch_size,
    num_indices].

    The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim].
    Note this a slightly modified version of the pie_modules.models.components.pooler.SpanMaxPooler,
    i.e. we changed the aggregation method from torch.amax to torch.mean.

    Args:
        input_dim: The input dimension of the hidden state.
        num_indices: The number of indices to pool.

    Returns:
        The pooled hidden states with shape [batch_size, num_indices * input_dim].
    """

    def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_indices = num_indices
        self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
        nn.init.normal_(self.missing_embeddings)

    def forward(
        self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
    ) -> Tensor:
        batch_size, seq_len, hidden_size = hidden_state.shape
        if start_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
            )

        if end_indices.shape[1] != self.num_indices:
            raise ValueError(
                f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
            )

        # check that start_indices come before end_indices wherever both are non-negative;
        # spans with a negative index fall back to the learned missing embedding instead
        mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
        mask_start_before_end = start_indices < end_indices
        mask_valid = mask_start_before_end | ~mask_both_positive
        if not torch.all(mask_valid):
            raise ValueError(
                f"values in start_indices have to be smaller than respective values in "
                f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
            )

        # the output holds num_indices concatenated span embeddings per batch entry;
        # match the input dtype so the module also works under mixed precision
        result = torch.zeros(
            batch_size,
            hidden_size * self.num_indices,
            device=hidden_state.device,
            dtype=hidden_state.dtype,
        )
        for batch_idx in range(batch_size):
            current_start_indices = start_indices[batch_idx]
            current_end_indices = end_indices[batch_idx]
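            # mean-pool each span [start, end); fall back to the learned
            # missing embedding when either index is negative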
            current_embeddings = [
                (
                    torch.mean(
                        hidden_state[
                            batch_idx, current_start_indices[i] : current_end_indices[i], :
                        ],
                        dim=0,
                    )
                    if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
                    else self.missing_embeddings[i]
                )
                for i in range(self.num_indices)
            ]
            result[batch_idx] = cat(current_embeddings, dim=0)

        return result

    @property
    def output_dim(self) -> int:
        return self.input_dim * self.num_indices
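

# Minimal usage sketch (illustrative only, not part of the module): the shapes
# and index values below are made up to demonstrate both the mean-pooling path
# and the learned fallback embedding for negative indices.
if __name__ == "__main__":
    torch.manual_seed(0)
    pooler = SpanMeanPooler(input_dim=8, num_indices=2)
    hidden = torch.randn(2, 10, 8)  # [batch_size, seq_len, input_dim]
    starts = torch.tensor([[1, 4], [0, -1]])  # second span of example 2 is "missing"
    ends = torch.tensor([[3, 7], [5, -1]])
    pooled = pooler(hidden, start_indices=starts, end_indices=ends)
    print(pooled.shape)  # torch.Size([2, 16]) == [batch_size, num_indices * input_dim]
    assert pooled.shape == (2, pooler.output_dim)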