PromptNet / modules /loss.py
fenglinliu's picture
Upload 55 files
6e32a75 verified
raw
history blame contribute delete
667 Bytes
import torch
import torch.nn as nn
class LanguageModelCriterion(nn.Module):
    """Masked negative-log-likelihood style loss averaged over the mask.

    ``forward`` gathers ``input`` at the target indices along dim 2, negates
    the gathered values, weights each position by ``mask`` and divides the
    weighted sum by the sum of the mask. This is a true NLL only if
    ``input`` holds log-probabilities (e.g. ``log_softmax`` output). That is
    assumed here, not checked.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        """Return the mask-weighted mean of ``-input`` picked at ``target``.

        Args:
            input: 3-D tensor indexed as ``(batch, step, class)``; presumably
                log-probabilities over the vocabulary -- TODO confirm.
            target: ``(batch, steps)`` class indices (cast to ``long``).
                Truncated to ``input.size(1)`` steps when longer.
            mask: ``(batch, steps)`` weights (bool, int or float), truncated
                like ``target``. Zero entries do not contribute.

        Returns:
            0-dim tensor ``sum(-input[target] * mask) / sum(mask)``. When the
            mask sums to exactly zero, the denominator is taken as 1, so the
            result is the (masked) numerator instead of NaN.
        """
        # The reference sequence may be longer than the prediction; drop the
        # reference steps that have no corresponding prediction.
        steps = input.size(1)
        target = target[:, :steps]
        mask = mask[:, :steps]

        picked = input.gather(2, target.long().unsqueeze(2)).squeeze(2)
        total = torch.sum(-picked * mask)

        # An all-padding batch would otherwise produce 0/0 = NaN, and the NaN
        # would also flow into every gradient. Substituting 1 only for an
        # exactly-zero denominator leaves all other results unchanged and
        # avoids a Python branch on a tensor value (no host sync).
        denom = torch.sum(mask)
        denom = denom.masked_fill(denom == 0, 1)
        return total / denom
def compute_loss(output, reports_ids, reports_masks):
    """Compute the masked language-model loss for a batch of reports.

    The first column of ``reports_ids`` and ``reports_masks`` is dropped
    before scoring (presumably a start-of-sequence token -- verify against
    the tokenizer), so step ``t`` of ``output`` is compared with token
    ``t + 1`` of the report.

    Args:
        output: per-step scores passed straight to ``LanguageModelCriterion``.
        reports_ids: ``(batch, steps)`` token ids of the reference reports.
        reports_masks: ``(batch, steps)`` mask aligned with ``reports_ids``.

    Returns:
        The criterion's value with ``.mean()`` applied.
    """
    shifted_ids = reports_ids[:, 1:]
    shifted_masks = reports_masks[:, 1:]
    raw_loss = LanguageModelCriterion()(output, shifted_ids, shifted_masks)
    return raw_loss.mean()