Spaces:
Sleeping
Sleeping
File size: 667 Bytes
6e32a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
import torch.nn as nn
class LanguageModelCriterion(nn.Module):
    """Masked, negated gather-based sequence criterion.

    For every (batch, step) position the value of ``input`` at the index
    given by ``target`` is picked along the last dimension, negated,
    weighted by ``mask`` and then normalised by the sum of ``mask``.
    When ``input`` holds log-probabilities (presumably the case -- confirm
    against the caller) this is the masked mean negative log-likelihood.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """Return the mask-normalised loss as a 0-dim tensor.

        Args:
            input: 3-D tensor; ``target`` indexes its last dimension.
            target: 2-D tensor of indices; cast to ``long`` before use and
                cut down to ``input.size(1)`` columns.
            mask: 2-D tensor of per-position weights, cut down to
                ``input.size(1)`` columns.

        Returns:
            ``sum(-picked * mask) / sum(mask)``. If the (truncated) mask
            sums to zero the division is skipped, so an all-zero mask
            yields ``0`` instead of ``NaN``.
        """
        # truncate to the same size
        seq_len = input.size(1)
        target = target[:, :seq_len]
        mask = mask[:, :seq_len]

        # Pick, per position, the entry of the last dim addressed by target.
        picked = input.gather(2, target.long().unsqueeze(2)).squeeze(2)
        masked = -picked * mask

        # Guard against 0/0: a batch whose mask is entirely zero would
        # otherwise produce NaN and poison every gradient downstream.
        total = torch.sum(mask)
        safe_total = torch.where(total == 0, torch.ones_like(total), total)
        return torch.sum(masked) / safe_total
def compute_loss(output, reports_ids, reports_masks):
    """Evaluate ``LanguageModelCriterion`` on ``output`` against shifted targets.

    The first column of both ``reports_ids`` and ``reports_masks`` is dropped
    before they are handed to the criterion (presumably to skip a start
    token -- confirm against the data pipeline). The criterion result is
    reduced with ``.mean()`` and returned.
    """
    # Skip column 0 of the ids and of the matching mask.
    shifted_ids = reports_ids[:, 1:]
    shifted_masks = reports_masks[:, 1:]

    criterion_value = LanguageModelCriterion()(output, shifted_ids, shifted_masks)
    return criterion_value.mean()
|