Spaces:
Sleeping
Sleeping
File size: 1,533 Bytes
6e32a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
def penalty_builder(penalty_config):
    """
    Build a length-penalty function from a config string.

    ``penalty_config`` is either empty (no penalty) or ``"<type>_<alpha>"``,
    where ``<type>`` is ``wu`` (uses ``length_wu``) or ``avg`` (uses
    ``length_average``) and ``<alpha>`` is parsed with ``float``,
    e.g. ``"wu_0.9"``.

    Returns a callable ``f(length, logprobs)`` giving the penalized score.
    For an empty (or ``None``) config, ``f`` returns ``logprobs`` as-is.

    Raises ValueError if the config does not have exactly one ``_``
    separator, if ``<alpha>`` is not a float, or if ``<type>`` is unknown.
    """
    if not penalty_config:
        return lambda x, y: y
    parts = penalty_config.split('_')
    if len(parts) != 2:
        raise ValueError(
            f"Malformed length penalty config {penalty_config!r}; "
            "expected '<type>_<alpha>', e.g. 'wu_0.9'")
    pen_type, alpha_str = parts
    alpha = float(alpha_str)
    # The lambdas close over this call's `alpha`, so each builder result is
    # independent of later calls.
    if pen_type == 'wu':
        return lambda x, y: length_wu(x, y, alpha)
    if pen_type == 'avg':
        return lambda x, y: length_average(x, y, alpha)
    # Fail here rather than returning None, which would only surface later
    # as "'NoneType' object is not callable" at the call site.
    raise ValueError(
        f"Unknown length penalty type {pen_type!r} in {penalty_config!r}; "
        "expected 'wu' or 'avg'")
def length_wu(length, logprobs, alpha=0.):
    """
    Divide ``logprobs`` by the GNMT length penalty from
    "Google's Neural Machine Translation System" :cite:`wu2016google`.

    The penalty is ``(5 + length) ** alpha / (5 + 1) ** alpha``; with the
    default ``alpha=0.`` it equals 1.
    """
    grown = (5 + length) ** alpha
    baseline = (5 + 1) ** alpha
    penalty = grown / baseline
    return logprobs / penalty
def length_average(length, logprobs, alpha=0.):
    """
    Normalize a score by sequence length: returns ``logprobs / length``.

    If ``logprobs`` is a summed token log-probability (presumably, per the
    callers — confirm), this is the per-token average. ``alpha`` is ignored;
    it exists only so the signature matches ``length_wu``.
    """
    per_token = logprobs / length
    return per_token
def split_tensors(n, x):
    """
    Split batch-interleaved data into ``n`` groups (the inverse layout of
    ``repeat_tensors``).

    - Tensor of shape (B*n, ...): reshaped to (B, n, ...) and unbound along
      dim 1, returning a tuple of ``n`` tensors of shape (B, ...). The j-th
      tensor holds rows j, j+n, j+2n, ... of ``x``.
    - ``list`` / ``tuple`` (exact types only; subclasses such as namedtuples
      fall through unchanged): returns a list with each element split
      recursively, so the result is indexed ``[element][chunk]``.
    - ``None``: returns ``[None] * n``.
    - Anything else is returned unchanged.

    Raises ValueError if ``x`` is a tensor and ``n`` is not a positive
    integer dividing ``x.shape[0]``.
    """
    if torch.is_tensor(x):
        # Explicit check rather than `assert`: asserts vanish under `python -O`,
        # and `n == 0` would otherwise crash with ZeroDivisionError here.
        if n < 1 or x.shape[0] % n != 0:
            raise ValueError(
                f"split_tensors: first dimension {x.shape[0]} is not "
                f"divisible by n={n}")
        x = x.reshape(x.shape[0] // n, n, *x.shape[1:]).unbind(1)
    elif type(x) is list or type(x) is tuple:
        x = [split_tensors(n, item) for item in x]
    elif x is None:
        x = [None] * n
    return x
def repeat_tensors(n, x):
    """
    Tile every row of a batch ``n`` times, keeping copies adjacent.

    A tensor of shape (B, ...) becomes (B*n, ...), where rows i*n .. i*n+n-1
    are all copies of row i of ``x``. Exact ``list``/``tuple`` inputs are
    handled recursively and always come back as a ``list``. Any other value
    is returned as-is.
    """
    if torch.is_tensor(x):
        tiled = x.unsqueeze(1)                  # (B, 1, ...)
        tail = tiled.shape[2:]
        tiled = tiled.expand(tiled.shape[0], n, *tail)  # (B, n, ...) view
        return tiled.reshape(tiled.shape[0] * n, *tail)  # (B*n, ...)
    if type(x) in (list, tuple):
        return [repeat_tensors(n, item) for item in x]
    return x
|