# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


class LatencyMetric(object):
    @staticmethod
    def length_from_padding_mask(padding_mask, batch_first: bool = False):
        dim = 1 if batch_first else 0
        return padding_mask.size(dim) - padding_mask.sum(dim=dim, keepdim=True)

    def prepare_latency_metric(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = False,
        start_from_zero: bool = True,
    ):
        assert len(delays.size()) == 2
        assert len(src_lens.size()) == 2

        if start_from_zero:
            delays = delays + 1

        if batch_first:
            # convert to batch_last
            delays = delays.t()
            src_lens = src_lens.t()

        tgt_len, bsz = delays.size()
        _, bsz_1 = src_lens.size()

        if target_padding_mask is not None:
            # padding masks are batch-first in fairseq; convert to batch_last as well
            target_padding_mask = target_padding_mask.t()
            tgt_len_1, bsz_2 = target_padding_mask.size()
            assert tgt_len == tgt_len_1
            assert bsz == bsz_2

        assert bsz == bsz_1

        if target_padding_mask is None:
            tgt_lens = tgt_len * delays.new_ones([1, bsz]).float()
        else:
            # 1, batch_size
            tgt_lens = self.length_from_padding_mask(
                target_padding_mask, False
            ).float()
            delays = delays.masked_fill(target_padding_mask, 0)

        return delays, src_lens, tgt_lens, target_padding_mask

    def __call__(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = False,
        start_from_zero: bool = True,
    ):
        delays, src_lens, tgt_lens, target_padding_mask = self.prepare_latency_metric(
            delays, src_lens, target_padding_mask, batch_first, start_from_zero
        )
        return self.cal_metric(delays, src_lens, tgt_lens, target_padding_mask)

    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        """
        Expected sizes:
        delays: tgt_len, batch_size
        src_lens: 1, batch_size
        target_padding_mask: tgt_len, batch_size
        """
        raise NotImplementedError


class AverageProportion(LatencyMetric):
    """
    Function to calculate Average Proportion (AP) from
    "Can neural machine translation do simultaneous translation?"
    (https://arxiv.org/abs/1606.02012)

    Delays are monotonic steps, ranging from 1 to src_len.
    Given src x and tgt y, AP is calculated as:

        AP = 1 / (|x| |y|) * sum_{i=1}^{|y|} delays_i

    An illustrative usage sketch follows this class.
    """

    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        if target_padding_mask is not None:
            AP = torch.sum(
                delays.masked_fill(target_padding_mask, 0), dim=0, keepdim=True
            )
        else:
            AP = torch.sum(delays, dim=0, keepdim=True)

        AP = AP / (src_lens * tgt_lens)
        return AP
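

# Illustrative usage sketch (a hypothetical helper, not part of the original fairseq
# module): a wait-1 policy on a length-3 pair emits target token i after reading
# source token i, so delays = [1, 2, 3] and AP = (1 + 2 + 3) / (3 * 3) = 2 / 3.
def _demo_average_proportion():
    delays = torch.tensor([[1.0, 2.0, 3.0]])  # bsz=1, tgt_len=3, already 1-based
    src_lens = torch.tensor([[3.0]])  # bsz=1, 1
    return AverageProportion()(
        delays, src_lens, batch_first=True, start_from_zero=False
    )  # tensor([[0.6667]])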


class AverageLagging(LatencyMetric):
    """
    Function to calculate Average Lagging (AL) from
    "STACL: Simultaneous Translation with Implicit Anticipation
    and Controllable Latency using Prefix-to-Prefix Framework"
    (https://arxiv.org/abs/1810.08398)

    Delays are monotonic steps, ranging from 1 to src_len.
    Given src x and tgt y, AL is calculated as:

        AL = 1 / tau * sum_{i=1}^{tau} (delays_i - (i - 1) / gamma)

    Where
        gamma = |y| / |x|
        tau = argmin_i (delays_i = |x|), i.e. the first target step whose delay
        reaches the full source length

    An illustrative usage sketch follows this class.
    """

    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        # tau = argmin_i (delays_i = |x|)
        tgt_len, bsz = delays.size()
        lagging_padding_mask = delays >= src_lens
        lagging_padding_mask = torch.nn.functional.pad(
            lagging_padding_mask.t(), (1, 0)
        ).t()[:-1, :]
        gamma = tgt_lens / src_lens
        lagging = (
            delays
            - torch.arange(delays.size(0))
            .unsqueeze(1)
            .type_as(delays)
            .expand_as(delays)
            / gamma
        )
        lagging.masked_fill_(lagging_padding_mask, 0)
        tau = (1 - lagging_padding_mask.type_as(lagging)).sum(dim=0, keepdim=True)
        AL = lagging.sum(dim=0, keepdim=True) / tau
        return AL
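

# Illustrative usage sketch (a hypothetical helper): a wait-2 policy on a length-3
# pair has delays = [2, 3, 3]; with gamma = 1 the cutoff is tau = 2 (the first step
# whose delay reaches |x|), so AL = ((2 - 0) + (3 - 1)) / 2 = 2.
def _demo_average_lagging():
    delays = torch.tensor([[2.0, 3.0, 3.0]])  # bsz=1, tgt_len=3
    src_lens = torch.tensor([[3.0]])
    return AverageLagging()(
        delays, src_lens, batch_first=True, start_from_zero=False
    )  # tensor([[2.]])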


class DifferentiableAverageLagging(LatencyMetric):
    """
    Function to calculate Differentiable Average Lagging (DAL) from
    "Monotonic Infinite Lookback Attention for Simultaneous Machine Translation"
    (https://arxiv.org/abs/1906.05218)

    Delays are monotonic steps, ranging from 0 to src_len - 1.
    (In the original paper they range from 1 to src_len.)
    Given src x and tgt y, DAL is calculated as:

        DAL = 1 / |y| * sum_{i=1}^{|y|} (delays'_i - (i - 1) / gamma)

    Where
        gamma = |y| / |x|
        delays'_i =
            1. delays_i                                  if i == 1
            2. max(delays_i, delays'_{i-1} + 1 / gamma)  otherwise

    An illustrative usage sketch follows this class.
    """

    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        tgt_len, bsz = delays.size()
        gamma = tgt_lens / src_lens
        new_delays = torch.zeros_like(delays)
        for i in range(delays.size(0)):
            if i == 0:
                new_delays[i] = delays[i]
            else:
                new_delays[i] = torch.cat(
                    [
                        new_delays[i - 1].unsqueeze(0) + 1 / gamma,
                        delays[i].unsqueeze(0),
                    ],
                    dim=0,
                ).max(dim=0)[0]

        DAL = (
            new_delays
            - torch.arange(delays.size(0))
            .unsqueeze(1)
            .type_as(delays)
            .expand_as(delays)
            / gamma
        )
        if target_padding_mask is not None:
            DAL = DAL.masked_fill(target_padding_mask, 0)

        DAL = DAL.sum(dim=0, keepdim=True) / tgt_lens
        return DAL
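

# Illustrative usage sketch (a hypothetical helper): a policy that waits for the full
# source before writing has delays = [3, 3, 3]; with gamma = 1 the recursion gives
# delays' = [3, 4, 5], so DAL = ((3 - 0) + (4 - 1) + (5 - 2)) / 3 = 3.
def _demo_differentiable_average_lagging():
    delays = torch.tensor([[3.0, 3.0, 3.0]])  # bsz=1, tgt_len=3
    src_lens = torch.tensor([[3.0]])
    return DifferentiableAverageLagging()(
        delays, src_lens, batch_first=True, start_from_zero=False
    )  # tensor([[3.]])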


class LatencyMetricVariance(LatencyMetric):
    def prepare_latency_metric(
        self,
        delays,
        src_lens,
        target_padding_mask=None,
        batch_first: bool = True,
        start_from_zero: bool = True,
    ):
        assert batch_first
        assert len(delays.size()) == 3
        assert len(src_lens.size()) == 2

        if start_from_zero:
            delays = delays + 1

        bsz, num_heads_x_layers, tgt_len = delays.size()
        bsz_1, _ = src_lens.size()
        assert bsz == bsz_1

        if target_padding_mask is not None:
            bsz_2, tgt_len_1 = target_padding_mask.size()
            assert tgt_len == tgt_len_1
            assert bsz == bsz_2

        if target_padding_mask is None:
            tgt_lens = tgt_len * delays.new_ones([bsz, tgt_len]).float()
        else:
            # batch_size, 1
            tgt_lens = self.length_from_padding_mask(target_padding_mask, True).float()
            delays = delays.masked_fill(target_padding_mask.unsqueeze(1), 0)

        return delays, src_lens, tgt_lens, target_padding_mask


class VarianceDelay(LatencyMetricVariance):
    @staticmethod
    def cal_metric(delays, src_lens, tgt_lens, target_padding_mask):
        """
        delays: bsz, num_heads_x_layers, tgt_len
        src_lens: bsz, 1
        tgt_lens: bsz, 1
        target_padding_mask: bsz, tgt_len or None
        """
        if delays.size(1) == 1:
            return delays.new_zeros([1])

        variance_delays = delays.var(dim=1)

        if target_padding_mask is not None:
            variance_delays.masked_fill_(target_padding_mask, 0)

        return variance_delays.sum(dim=1, keepdim=True) / tgt_lens
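

# Illustrative usage sketch (a hypothetical helper): heads that disagree on when to
# read give a non-zero variance delay; identical heads would give zero.
def _demo_variance_delay():
    delays = torch.tensor(
        [[[1.0, 2.0, 3.0], [3.0, 3.0, 3.0]]]
    )  # bsz=1, num_heads_x_layers=2, tgt_len=3
    src_lens = torch.tensor([[3.0]])
    target_padding_mask = torch.zeros(1, 3, dtype=torch.bool)
    return VarianceDelay()(
        delays, src_lens, target_padding_mask, batch_first=True, start_from_zero=False
    )  # (2.0 + 0.5 + 0.0) / 3 = tensor([[0.8333]])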


class LatencyInference(object):
    def __init__(self, start_from_zero=True):
        self.metric_calculator = {
            "differentiable_average_lagging": DifferentiableAverageLagging(),
            "average_lagging": AverageLagging(),
            "average_proportion": AverageProportion(),
        }

        self.start_from_zero = start_from_zero

    def __call__(self, monotonic_step, src_lens):
        """
        monotonic_step ranges from 0 to src_len; a value of src_len means EOS.
        delays: bsz, tgt_len
        src_lens: bsz, 1
        """
        if not self.start_from_zero:
            monotonic_step -= 1

        delays = monotonic_step.view(
            monotonic_step.size(0), -1, monotonic_step.size(-1)
        ).max(dim=1)[0]

        # clamp delays that run past the source to the last source position
        delays = delays.masked_fill(delays >= src_lens, 0) + (src_lens - 1).expand_as(
            delays
        ).masked_fill(delays < src_lens, 0)
        return_dict = {}
        for key, func in self.metric_calculator.items():
            return_dict[key] = func(
                delays.float(),
                src_lens.float(),
                target_padding_mask=None,
                batch_first=True,
                start_from_zero=True,
            ).t()

        return return_dict
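

# Illustrative usage sketch (a hypothetical helper): monotonic_step holds, for each
# target position, the 0-based index of the last source token read when that token
# was emitted (src_len marks EOS); the calculator returns one tensor per metric.
def _demo_latency_inference():
    monotonic_step = torch.tensor([[1.0, 2.0, 3.0]])  # bsz=1, tgt_len=3
    src_lens = torch.tensor([[3.0]])  # bsz=1, 1
    return LatencyInference()(monotonic_step, src_lens)
    # {"differentiable_average_lagging": ..., "average_lagging": ..., "average_proportion": ...}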


class LatencyTraining(object):
    def __init__(
        self,
        avg_weight,
        var_weight,
        avg_type,
        var_type,
        stay_on_last_token,
        average_method,
    ):
        self.avg_weight = avg_weight
        self.var_weight = var_weight
        self.avg_type = avg_type
        self.var_type = var_type
        self.stay_on_last_token = stay_on_last_token
        self.average_method = average_method

        self.metric_calculator = {
            "differentiable_average_lagging": DifferentiableAverageLagging(),
            "average_lagging": AverageLagging(),
            "average_proportion": AverageProportion(),
        }

        self.variance_calculator = {
            "variance_delay": VarianceDelay(),
        }

    def expected_delays_from_attention(
        self, attention, source_padding_mask=None, target_padding_mask=None
    ):
        if isinstance(attention, list):
            # bsz, num_heads, tgt_len, src_len
            bsz, num_heads, tgt_len, src_len = attention[0].size()
            attention = torch.cat(attention, dim=1)
            bsz, num_heads_x_layers, tgt_len, src_len = attention.size()
            # bsz * num_heads * num_layers, tgt_len, src_len
            attention = attention.view(-1, tgt_len, src_len)
        else:
            # bsz * num_heads * num_layers, tgt_len, src_len
            bsz, tgt_len, src_len = attention.size()
            num_heads_x_layers = 1
            attention = attention.view(-1, tgt_len, src_len)

        if not self.stay_on_last_token:
            residual_attention = 1 - attention[:, :, :-1].sum(dim=2, keepdim=True)
            attention = torch.cat([attention[:, :, :-1], residual_attention], dim=2)

        # bsz * num_heads_x_num_layers, tgt_len, src_len for MMA
        steps = (
            torch.arange(1, 1 + src_len)
            .unsqueeze(0)
            .unsqueeze(1)
            .expand_as(attention)
            .type_as(attention)
        )

        if source_padding_mask is not None:
            src_offset = (
                source_padding_mask.type_as(attention)
                .sum(dim=1, keepdim=True)
                .expand(bsz, num_heads_x_layers)
                .contiguous()
                .view(-1, 1)
            )
            src_lens = src_len - src_offset
            if source_padding_mask[:, 0].any():
                # Pad left
                src_offset = src_offset.view(-1, 1, 1)
                steps = steps - src_offset
                steps = steps.masked_fill(steps <= 0, 0)
        else:
            src_lens = attention.new_ones([bsz, num_heads_x_layers]) * src_len
            src_lens = src_lens.view(-1, 1)

        # bsz * num_heads_x_num_layers, tgt_len, src_len
        expected_delays = (
            (steps * attention).sum(dim=2).view(bsz, num_heads_x_layers, tgt_len)
        )

        if target_padding_mask is not None:
            expected_delays.masked_fill_(target_padding_mask.unsqueeze(1), 0)

        return expected_delays, src_lens

    def avg_loss(self, expected_delays, src_lens, target_padding_mask):
        bsz, num_heads_x_layers, tgt_len = expected_delays.size()
        target_padding_mask = (
            target_padding_mask.unsqueeze(1)
            .expand_as(expected_delays)
            .contiguous()
            .view(-1, tgt_len)
        )

        if self.average_method == "average":
            # bsz, tgt_len
            expected_delays = expected_delays.mean(dim=1)
        elif self.average_method == "weighted_average":
            weights = torch.nn.functional.softmax(expected_delays, dim=1)
            expected_delays = torch.sum(expected_delays * weights, dim=1)
        elif self.average_method == "max":
            # bsz, tgt_len
            expected_delays = expected_delays.max(dim=1)[0]
        else:
            raise RuntimeError(f"{self.average_method} is not supported")

        src_lens = src_lens.view(bsz, -1)[:, :1]
        target_padding_mask = target_padding_mask.view(bsz, -1, tgt_len)[:, 0]

        if self.avg_weight > 0.0:
            if self.avg_type in self.metric_calculator:
                average_delays = self.metric_calculator[self.avg_type](
                    expected_delays,
                    src_lens,
                    target_padding_mask,
                    batch_first=True,
                    start_from_zero=False,
                )
            else:
                raise RuntimeError(f"{self.avg_type} is not supported.")

            # 1, batch_size
            return self.avg_weight * average_delays.sum()
        else:
            return 0.0

    def var_loss(self, expected_delays, src_lens, target_padding_mask):
        src_lens = src_lens.view(expected_delays.size(0), expected_delays.size(1))[
            :, :1
        ]
        if self.var_weight > 0.0:
            if self.var_type in self.variance_calculator:
                variance_delays = self.variance_calculator[self.var_type](
                    expected_delays,
                    src_lens,
                    target_padding_mask,
                    batch_first=True,
                    start_from_zero=False,
                )
            else:
                raise RuntimeError(f"{self.var_type} is not supported.")

            return self.var_weight * variance_delays.sum()
        else:
            return 0.0

    def loss(self, attention, source_padding_mask=None, target_padding_mask=None):
        expected_delays, src_lens = self.expected_delays_from_attention(
            attention, source_padding_mask, target_padding_mask
        )

        latency_loss = 0

        latency_loss += self.avg_loss(expected_delays, src_lens, target_padding_mask)

        latency_loss += self.var_loss(expected_delays, src_lens, target_padding_mask)

        return latency_loss
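

# Illustrative usage sketch (a hypothetical helper with made-up shapes and settings):
# the latency loss is computed from decoder attention, here a list of per-layer
# tensors of shape (bsz, num_heads, tgt_len, src_len) such as those produced by
# monotonic multihead attention.
def _demo_latency_training_loss():
    torch.manual_seed(0)
    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 7
    attention = [
        torch.softmax(torch.rand(bsz, num_heads, tgt_len, src_len), dim=-1)
        for _ in range(2)  # two decoder layers
    ]
    target_padding_mask = torch.zeros(bsz, tgt_len, dtype=torch.bool)
    latency = LatencyTraining(
        avg_weight=1.0,
        var_weight=1.0,
        avg_type="differentiable_average_lagging",
        var_type="variance_delay",
        stay_on_last_token=True,
        average_method="weighted_average",
    )
    return latency.loss(attention, target_padding_mask=target_padding_mask)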