import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from parse_config import ConfigParser
from utils import sigmoid_rampup, sigmoid_rampdown, cosine_rampup, cosine_rampdown, linear_rampup


def cross_entropy(output, target, M=3):
    # Plain cross-entropy; M is accepted but unused.
    return F.cross_entropy(output, target)


class elr_plus_loss(nn.Module):
    """ELR+ loss: cross-entropy plus an early-learning regularizer driven by an
    exponential moving average (EMA) of the model's own softmax predictions."""

    def __init__(self, num_examp, config, device, num_classes=10, beta=0.3):
        super(elr_plus_loss, self).__init__()
        self.config = config
        # Per-example EMA of softmax predictions (the regularization targets).
        self.pred_hist = torch.zeros(num_examp, num_classes).to(device)
        # Mixup-combined targets for the current batch; set by update_hist().
        self.q = 0
        self.beta = beta  # EMA momentum
        self.num_classes = num_classes

    def forward(self, iteration, output, y_labeled):
        y_pred = F.softmax(output, dim=1)
        y_pred = torch.clamp(y_pred, 1e-4, 1.0 - 1e-4)

        # When num_classes == 100 (e.g. CIFAR-100), sharpen the labels with the
        # EMA targets and renormalize.
        if self.num_classes == 100:
            y_labeled = y_labeled * self.q
            y_labeled = y_labeled / y_labeled.sum(dim=1, keepdim=True)

        ce_loss = torch.mean(-torch.sum(y_labeled * F.log_softmax(output, dim=1), dim=-1))
        # ELR regularizer: log(1 - <q, p>) decreases as predictions agree with
        # the EMA targets q, pulling the model back toward its own early
        # (pre-memorization) predictions.
        reg = ((1 - (self.q * y_pred).sum(dim=1)).log()).mean()
        # The regularizer coefficient lambda is ramped up over 'coef_step' iterations.
        final_loss = ce_loss + sigmoid_rampup(iteration, self.config['coef_step']) \
            * (self.config['train_loss']['args']['lambda'] * reg)
        return final_loss, y_pred.cpu().detach()

    def update_hist(self, epoch, out, index=None, mix_index=..., mixup_l=1):
        y_pred_ = F.softmax(out, dim=1)
        # EMA update of the prediction history for the examples in this batch.
        self.pred_hist[index] = self.beta * self.pred_hist[index] \
            + (1 - self.beta) * y_pred_ / y_pred_.sum(dim=1, keepdim=True)
        # Mix the histories the same way mixup mixed the inputs: mix_index
        # permutes the batch and mixup_l is the mixing weight (1 = no mixing).
        self.q = mixup_l * self.pred_hist[index] + (1 - mixup_l) * self.pred_hist[index][mix_index]
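

# A minimal smoke-test sketch, not part of the original file. It assumes a
# plain-dict `config` exposing only the two keys forward() reads ('coef_step'
# and train_loss.args.lambda); the surrounding repo builds this via ConfigParser.
if __name__ == "__main__":
    config = {'coef_step': 0, 'train_loss': {'args': {'lambda': 3.0}}}
    criterion = elr_plus_loss(num_examp=8, config=config, device='cpu', num_classes=10)

    logits = torch.randn(8, 10)
    targets = F.one_hot(torch.randint(0, 10, (8,)), num_classes=10).float()

    # Populate the EMA history (and hence self.q) before the first forward pass.
    criterion.update_hist(epoch=0, out=logits, index=torch.arange(8))
    loss, probs = criterion(iteration=0, output=logits, y_labeled=targets)
    print(loss.item(), probs.shape)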