|
|
|
|
|
from typing import Dict |
|
|
|
import torch |
|
|
|
|
|
def get_text_field_mask(text_field_tensors: torch.Tensor, |
|
num_wrapping_dims: int = 0) -> torch.LongTensor: |
|
|
|
tensor_dims = [(tensor.dim(), tensor) for tensor in text_field_tensors.values()] |
|
tensor_dims.sort(key=lambda x: x[0]) |
|
|
|
smallest_dim = tensor_dims[0][0] - num_wrapping_dims |
|
if smallest_dim == 2: |
|
token_tensor = tensor_dims[0][1] |
|
return (token_tensor != 0).long() |
|
elif smallest_dim == 3: |
|
character_tensor = tensor_dims[0][1] |
|
return ((character_tensor > 0).long().sum(dim=-1) > 0).long() |
|
else: |
|
raise ValueError("Expected a tensor with dimension 2 or 3, found {}".format(smallest_dim)) |
|
|
|
|
|
if __name__ == '__main__': |
|
pass |
|
|