Spaces:
Sleeping
Sleeping
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging | |
import torch | |
import torch.nn.functional as F | |
logger = logging.getLogger(__name__) | |
def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"):
    """Compute cross entropy with plain PyTorch ops.

    Applies ``log_softmax`` over the last dimension of ``logits`` (computed
    in float32) and feeds the result to ``F.nll_loss``.

    Args:
        logits: unnormalized scores; the softmax is taken over the last dim.
        target: class indices, forwarded unchanged to ``F.nll_loss``.
        ignore_index: target value that does not contribute to the loss.
            ``None`` is treated as ``-100``, the default of ``F.nll_loss``.
        reduction: reduction mode forwarded to ``F.nll_loss``
            (e.g. ``"mean"``, ``"sum"`` or ``"none"``).

    Returns:
        The loss tensor produced by ``F.nll_loss``.
    """
    if ignore_index is None:
        # F.nll_loss only accepts an int here; passing None raises TypeError,
        # so fall back to the same default the public wrappers use.
        ignore_index = -100
    # Compute the log-probabilities in float32 regardless of the input dtype.
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    return F.nll_loss(
        lprobs,
        target,
        ignore_index=ignore_index,
        reduction=reduction,
    )
try:
    # xentropy_cuda is not referenced below; importing it means a missing
    # compiled extension raises ImportError and selects the fallback.
    import xentropy_cuda
    from apex.contrib import xentropy

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy that uses apex's fused kernel for non-CPU tensors.

        CPU inputs are delegated to ``_cross_entropy_pytorch``. For any other
        device the per-element losses come from
        ``xentropy.SoftmaxCrossEntropyLoss`` and are reduced here.

        Args:
            logits: unnormalized scores.
            target: class indices.
            ignore_index: target value excluded from the loss.
            reduction: ``"mean"``, ``"sum"`` or ``"none"``.

        Returns:
            The reduced loss, or the unreduced losses for ``"none"``.

        Raises:
            NotImplementedError: on the fused path, for any other
                ``reduction`` value.
        """
        if logits.device == torch.device("cpu"):
            return _cross_entropy_pytorch(logits, target, ignore_index, reduction)

        # Announce the fused path a single time; the flag lives on the
        # function object itself.
        already_announced = getattr(cross_entropy, "_has_logged_once", False)
        if not already_announced:
            logger.info("using fused cross entropy")
            cross_entropy._has_logged_once = True

        is_half_input = logits.dtype == torch.half
        per_element = xentropy.SoftmaxCrossEntropyLoss.apply(
            logits,
            target,
            0.0,  # NOTE(review): presumably the label-smoothing factor; confirm against apex
            ignore_index,
            is_half_input,
        )

        if reduction == "none":
            return per_element
        if reduction == "sum":
            return per_element.sum()
        if reduction != "mean":
            raise NotImplementedError

        if ignore_index >= 0:
            # Average over the targets that are not equal to ignore_index.
            kept = target.ne(ignore_index).sum()
            return per_element.sum() / kept
        return per_element.mean()


except ImportError:

    def cross_entropy(logits, target, ignore_index=-100, reduction="mean"):
        """Cross entropy via ``_cross_entropy_pytorch`` (fused kernel unavailable)."""
        return _cross_entropy_pytorch(logits, target, ignore_index, reduction)