import torch


def penalty_builder(penalty_config):
    """Build a length-penalty function from a config string.

    Args:
        penalty_config: '' for no penalty (logprobs returned unchanged),
            or '<type>_<alpha>' where <type> is 'wu' or 'avg' and
            <alpha> parses as a float, e.g. 'wu_0.5'.

    Returns:
        A callable (length, logprobs) -> penalized logprobs.

    Raises:
        ValueError: if the penalty type is not 'wu' or 'avg'.
    """
    if penalty_config == '':
        return lambda x, y: y
    pen_type, alpha = penalty_config.split('_')
    alpha = float(alpha)
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    # Previously an unknown type fell through and returned None, causing an
    # opaque "'NoneType' object is not callable" far from the bad config.
    # Fail fast with a clear message instead.
    raise ValueError('Unknown penalty type: {}'.format(pen_type))


def length_wu(length, logprobs, alpha=0.):
    """NMT length re-ranking score from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.

    With alpha == 0 the modifier is 1 and logprobs are unchanged.
    """
    modifier = (((5 + length) ** alpha) /
                ((5 + 1) ** alpha))
    return logprobs / modifier


def length_average(length, logprobs, alpha=0.):
    """Return the average log-probability per token of a sequence.

    ``alpha`` is accepted only so the signature matches ``length_wu``;
    it is unused.
    """
    return logprobs / length


def split_tensors(n, x):
    """Inverse of repeat_tensors: split a Bn x ... tensor into n pieces.

    For a tensor of size Bnx..., returns a tuple of n tensors of size Bx...
    such that ``repeat_tensors(n, .)`` followed by ``split_tensors(n, .)``
    round-trips. Lists/tuples are processed recursively; None becomes a
    list of n Nones.
    """
    if torch.is_tensor(x):
        # Rows must interleave as [b0 copies..., b1 copies..., ...],
        # i.e. exactly the layout produced by repeat_tensors.
        assert x.shape[0] % n == 0
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, _) for _ in x]
    elif x is None:
        x = [None] * n
    return x


def repeat_tensors(n, x):
    """
    For a tensor of size Bx..., we repeat it n times, and make it Bnx...
    For collections, do nested repeat
    """
    if torch.is_tensor(x):
        x = x.unsqueeze(1)  # Bx1x...
        # expand (not repeat) avoids copying until the final reshape.
        x = x.expand(-1, n, *([-1] * len(x.shape[2:])))  # Bxnx...
        x = x.reshape(x.shape[0] * n, *x.shape[2:])  # Bnx...
    elif type(x) is list or type(x) is tuple:
        x = [repeat_tensors(n, _) for _ in x]
    return x