Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class LanguageModelCriterion(nn.Module):
    """Masked negative log-likelihood, averaged over the total mask weight.

    For each position, ``forward`` takes the value of ``input`` at the target
    token id (a gather along dim 2), negates it, and weights it by ``mask``.
    It then divides the summed result by the summed mask.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """Return the mask-weighted mean of ``-input[b, t, target[b, t]]``.

        Args:
            input: tensor indexed ``[batch, time, token]`` (dim 2 is gathered
                over). These are presumably log-probabilities, e.g. a
                ``log_softmax`` output; confirm with the caller.
            target: token ids indexed ``[batch, time]``, cast to ``long``.
                Time steps beyond ``input.size(1)`` are ignored.
            mask: per-position weights indexed ``[batch, time]``, truncated
                the same way as ``target``.

        Returns:
            A 0-d tensor. If every mask weight is zero, the result is 0
            (still attached to the autograd graph) rather than NaN.
        """
        # Truncate to the number of time steps the model actually produced.
        seq_len = input.size(1)
        target = target[:, :seq_len]
        mask = mask[:, :seq_len]

        token_scores = input.gather(2, target.long().unsqueeze(2)).squeeze(2)
        total = torch.sum(-token_scores * mask)
        weight = torch.sum(mask)
        # An all-zero mask would give 0/0 = NaN and poison every gradient.
        # Divide by 1 in that case so the (zero) loss stays finite.
        weight = torch.where(weight != 0, weight, torch.ones_like(weight))
        return total / weight
def compute_loss(output, reports_ids, reports_masks):
    """Return the masked language-modelling loss for a batch of reports.

    The first column of ``reports_ids`` and ``reports_masks`` is dropped
    before scoring. As a result, time step ``t`` of ``output`` is scored
    against token ``t + 1`` of each report. Presumably the leading token is a
    BOS marker; confirm this with the data pipeline.

    Args:
        output: scores passed as ``input`` to ``LanguageModelCriterion``.
            These are presumably log-probabilities, since the criterion
            negates the gathered values.
        reports_ids: token ids, sliced as ``[:, 1:]``.
        reports_masks: mask aligned with ``reports_ids``, sliced the same way.

    Returns:
        The scalar loss tensor.
    """
    shifted_targets = reports_ids[:, 1:]
    shifted_masks = reports_masks[:, 1:]
    loss_fn = LanguageModelCriterion()
    return loss_fn(output, shifted_targets, shifted_masks).mean()