diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..a4ac8416ef1eb85474259bcd10552ed395713b08 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+images/clean_label_simplexheatmap2.gif filter=lfs diff=lfs merge=lfs -text
+images/false_label_simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
+images/illustration_of_ELR.png filter=lfs diff=lfs merge=lfs -text
+images/simplexheatmap.gif filter=lfs diff=lfs merge=lfs -text
diff --git a/ELR/README.md b/ELR/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6f16cb083c568865bfd9719b2d526fb82d6a6805
--- /dev/null
+++ b/ELR/README.md
@@ -0,0 +1,44 @@
+# ELR
+This is an official PyTorch implementation of the ELR method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).
+
+
+## Usage
+Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.8):
+
+```
+python train.py -c config_cifar10.json --percent 0.8
+```
+Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4):
+
+```
+python train.py -c config_cifar10_asym.json --percent 0.4 --asym 1
+```
+
+Train the network on the Asymmetric Noise CIFAR-100 dataset (noise rate = 0.4):
+
+```
+python train.py -c config_cifar100.json --percent 0.4 --asym 1
+```
+
+The config files can be modified to adjust hyperparameters and optimization settings.
+
+## Results
+### CIFAR10
+
+
+| Method | 20% | 40% | 60% | 80% | 40% Asym |
+| ---------------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
+| ELR | 91.16% | 89.15% | 86.12% | 73.86% | 90.12% |
+| ELR (cosine annealing) | 91.12% | 91.43% | 88.87% | 80.69% | 90.35% |
+
+### CIFAR100
+
+| Method | 20% | 40% | 60% | 80% | 40% Asym |
+| ---------------------- | ----------- | ----------- | ----------- | ----------- | ----------- |
+| ELR | 74.21% | 68.28% | 59.28% | 29.78% | 73.71% |
+| ELR (cosine annealing) | 74.68% | 68.43% | 60.05% | 30.27% | 73.96% |
+
+
+
+## References
+- S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
diff --git a/ELR/base/__init__.py b/ELR/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28cb1d23b331d44d7ad6e26dd5906f1f18fb657
--- /dev/null
+++ b/ELR/base/__init__.py
@@ -0,0 +1,3 @@
+from .base_data_loader import *
+from .base_model import *
+from .base_trainer import *
diff --git a/ELR/base/base_data_loader.py b/ELR/base/base_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb7a8b410fc1e3c813db7790251ec23eb51f1a65
--- /dev/null
+++ b/ELR/base/base_data_loader.py
@@ -0,0 +1,83 @@
+from typing import Tuple, Union, Optional
+
+import numpy as np
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.sampler import SubsetRandomSampler
+
+
+class BaseDataLoader(DataLoader):
+    """
+    Base class for all data loaders.
+
+    Extends ``torch.utils.data.DataLoader`` with a validation loader obtained
+    through :meth:`split_validation`. Validation data comes either from an
+    explicit ``val_dataset`` or, when none is given, from a random subset of
+    ``train_dataset`` whose size is set by ``validation_split``.
+    """
+    valid_sampler: Optional[SubsetRandomSampler]
+    sampler: Optional[SubsetRandomSampler]
+
+    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
+                 collate_fn=default_collate, val_dataset=None):
+        """
+        :param train_dataset: dataset iterated by this loader
+        :param batch_size: batch size of the training loader
+        :param shuffle: shuffle flag; forced to False when samplers are created
+        :param validation_split: fraction (float) or absolute number (int) of
+            samples held out for validation; ignored when ``val_dataset`` is given
+        :param num_workers: forwarded to ``DataLoader``
+        :param pin_memory: forwarded to ``DataLoader``
+        :param collate_fn: forwarded to ``DataLoader``
+        :param val_dataset: optional separate validation dataset
+        """
+        self.collate_fn = collate_fn
+        self.validation_split = validation_split
+        self.shuffle = shuffle
+        self.val_dataset = val_dataset
+
+        self.batch_idx = 0
+        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
+
+        # Create the samplers BEFORE building init_kwargs: _split_sampler()
+        # switches self.shuffle off, and DataLoader rejects a sampler that is
+        # combined with shuffle=True.
+        train_sampler = None
+        self.valid_sampler = None
+        if val_dataset is None:
+            train_sampler, self.valid_sampler = self._split_sampler(self.validation_split)
+
+        self.init_kwargs = {
+            'dataset': train_dataset,
+            'batch_size': batch_size,
+            'shuffle': self.shuffle,
+            'collate_fn': collate_fn,
+            'num_workers': num_workers,
+            'pin_memory': pin_memory
+        }
+        super().__init__(sampler=train_sampler, **self.init_kwargs)
+
+    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
+        """
+        Create disjoint train/validation samplers over ``range(self.n_samples)``.
+
+        Reseeds the global NumPy RNG with 0 so the split is the same on every
+        run. On a non-zero split, ``self.shuffle`` is set to False and
+        ``self.n_samples`` is reduced to the training subset size.
+
+        :param split: 0.0 for no split, an int for an absolute validation
+            size, or a float fraction of ``self.n_samples``
+        :return: ``(train_sampler, valid_sampler)``, or ``(None, None)`` when
+            ``split`` equals 0.0
+        """
+        if split == 0.0:
+            return None, None
+
+        idx_full = np.arange(self.n_samples)
+
+        np.random.seed(0)
+        np.random.shuffle(idx_full)
+
+        if isinstance(split, int):
+            assert split > 0
+            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
+            len_valid = split
+        else:
+            len_valid = int(self.n_samples * split)
+
+        valid_idx = idx_full[0:len_valid]
+        train_idx = np.delete(idx_full, np.arange(0, len_valid))
+
+        train_sampler = SubsetRandomSampler(train_idx)
+        valid_sampler = SubsetRandomSampler(valid_idx)
+        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")
+
+        # turn off shuffle option which is mutually exclusive with sampler
+        self.shuffle = False
+        self.n_samples = len(train_idx)
+
+        return train_sampler, valid_sampler
+
+    def split_validation(self, bs = 1000):
+        """
+        Build the validation loader.
+
+        :param bs: batch size, used only when a separate ``val_dataset`` exists
+        :return: a ``DataLoader`` over ``val_dataset`` if one was given;
+            otherwise a loader over the held-out subset of the training
+            dataset, or ``None`` when no validation split was configured
+        """
+        if self.val_dataset is not None:
+            kwargs = {
+                'dataset': self.val_dataset,
+                'batch_size': bs,
+                'shuffle': False,
+                'collate_fn': self.collate_fn,
+                'num_workers': self.num_workers
+            }
+            return DataLoader(**kwargs)
+
+        if self.valid_sampler is None:
+            # No held-out subset exists; returning a loader here would serve
+            # the training data itself as "validation" data.
+            return None
+
+        print('Using sampler to split!')
+        return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
+
+
+
diff --git a/ELR/base/base_model.py b/ELR/base/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8caaade05c698e328b3df8af1de24ec9942992e1
--- /dev/null
+++ b/ELR/base/base_model.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import numpy as np
+from abc import abstractmethod
+
+
+class BaseModel(nn.Module):
+    """
+    Common parent of every model in the project.
+    """
+    @abstractmethod
+    def forward(self, *inputs):
+        """
+        Run the forward pass; concrete models must override this.
+
+        :return: Model output
+        """
+        raise NotImplementedError
+
+    def __str__(self):
+        """
+        Return the module description followed by the trainable parameter count.
+        """
+        trainable = (p for p in self.parameters() if p.requires_grad)
+        params = sum([np.prod(p.size()) for p in trainable])
+        description = super().__str__()
+        return description + '\nTrainable parameters: {}'.format(params)
diff --git a/ELR/base/base_trainer.py b/ELR/base/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3af3990fbf1439cd637177e87594fd6287dee234
--- /dev/null
+++ b/ELR/base/base_trainer.py
@@ -0,0 +1,195 @@
+from typing import TypeVar, List, Tuple
+import torch
+from tqdm import tqdm
+from abc import abstractmethod
+from numpy import inf
+from logger import TensorboardWriter
+import numpy as np
+
+class BaseTrainer:
+    """
+    Base class for all trainers.
+
+    Owns the epoch loop, best-metric monitoring with early stopping, and
+    checkpoint save/resume. Subclasses implement ``_train_epoch`` and, when
+    ``config['trainer']['warmup']`` is at least 1, ``_warmup_epoch``.
+    """
+    def __init__(self, model, train_criterion, metrics, optimizer, config, val_criterion):
+        """
+        :param model: network to train; moved to the selected device
+        :param train_criterion: training loss module; moved to the selected device
+        :param metrics: metric functions; their ``__name__`` is used as log key
+        :param optimizer: optimizer whose state is stored in checkpoints
+        :param config: project configuration object
+        :param val_criterion: validation loss, stored as given
+        """
+        self.config = config
+        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
+
+        # setup GPU device if available, move model into configured device
+        self.device, device_ids = self._prepare_device(config['n_gpu'])
+        self.model = model.to(self.device)
+
+        if len(device_ids) > 1:
+            self.model = torch.nn.DataParallel(model, device_ids=device_ids)
+
+        self.train_criterion = train_criterion.to(self.device)
+        self.val_criterion = val_criterion
+        self.metrics = metrics
+        self.optimizer = optimizer
+
+        cfg_trainer = config['trainer']
+        self.epochs = cfg_trainer['epochs']
+        self.save_period = cfg_trainer['save_period']
+        self.monitor = cfg_trainer.get('monitor', 'off')
+
+        # configuration to monitor model performance and save best
+        if self.monitor == 'off':
+            self.mnt_mode = 'off'
+            self.mnt_best = 0
+        else:
+            # expected form: "<min|max> <metric name>"
+            self.mnt_mode, self.mnt_metric = self.monitor.split()
+            assert self.mnt_mode in ['min', 'max']
+
+            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
+            self.early_stop = cfg_trainer.get('early_stop', inf)
+
+        self.start_epoch = 1
+
+        self.checkpoint_dir = config.save_dir
+
+        # setup visualization writer instance
+        self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
+
+        if config.resume is not None:
+            self._resume_checkpoint(config.resume)
+
+    @abstractmethod
+    def _train_epoch(self, epoch):
+        """
+        Training logic for an epoch
+
+        :param epoch: Current epochs number
+        """
+        raise NotImplementedError
+
+    def train(self):
+        """
+        Full training logic: run every epoch, log its results, track the
+        monitored metric, stop early when it stalls, and save checkpoints.
+        """
+        not_improved_count = 0
+
+        for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
+            if epoch <= self.config['trainer']['warmup']:
+                result = self._warmup_epoch(epoch)
+            else:
+                result = self._train_epoch(epoch)
+
+            # save logged informations into log dict
+            log = {'epoch': epoch}
+            for key, value in result.items():
+                if key == 'metrics':
+                    log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                elif key == 'val_metrics':
+                    log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                elif key == 'test_metrics':
+                    log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+                else:
+                    log[key] = value
+
+            # print logged informations to the screen
+            for key, value in log.items():
+                self.logger.info('    {:15s}: {}'.format(str(key), value))
+
+            # evaluate model performance according to configured metric, save best checkpoint as model_best
+            best = False
+            if self.mnt_mode != 'off':
+                try:
+                    # check whether model performance improved or not, according to specified metric(mnt_metric)
+                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
+                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
+                except KeyError:
+                    self.logger.warning("Warning: Metric '{}' is not found. "
+                                        "Model performance monitoring is disabled.".format(self.mnt_metric))
+                    self.mnt_mode = 'off'
+                    improved = False
+
+                if improved:
+                    self.mnt_best = log[self.mnt_metric]
+                    not_improved_count = 0
+                    best = True
+                else:
+                    not_improved_count += 1
+
+                if not_improved_count > self.early_stop:
+                    self.logger.info("Validation performance didn\'t improve for {} epochs. "
+                                     "Training stops.".format(self.early_stop))
+                    break
+
+            if epoch % self.save_period == 0:
+                self._save_checkpoint(epoch, save_best=best)
+
+    def _prepare_device(self, n_gpu_use):
+        """
+        Select the torch device and the GPU ids to use.
+
+        :param n_gpu_use: number of GPUs requested by the configuration
+        :return: ``(device, list_ids)``; ``device`` is ``cuda:0`` when at
+            least one GPU will be used and ``cpu`` otherwise, ``list_ids`` is
+            ``list(range(n))`` for the number of GPUs actually granted
+        """
+        n_gpu = torch.cuda.device_count()
+        if n_gpu_use > 0 and n_gpu == 0:
+            self.logger.warning("Warning: There\'s no GPU available on this machine, "
+                                "training will be performed on CPU.")
+            n_gpu_use = 0
+        if n_gpu_use > n_gpu:
+            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
+                                "on this machine.".format(n_gpu_use, n_gpu))
+            n_gpu_use = n_gpu
+        device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
+        list_ids = list(range(n_gpu_use))
+        return device, list_ids
+
+    def _save_checkpoint(self, epoch, save_best=False):
+        """
+        Saving checkpoints
+
+        Only the best model is written; nothing is saved when ``save_best``
+        is False.
+
+        :param epoch: current epoch number
+        :param save_best: if True, write the state to 'model_best.pth' in the checkpoint directory
+        """
+        arch = type(self.model).__name__
+
+        state = {
+            'arch': arch,
+            'epoch': epoch,
+            'state_dict': self.model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'monitor_best': self.mnt_best,
+            # _resume_checkpoint() compares these two sections against the
+            # live configuration; without them resuming raises KeyError.
+            'config': {
+                'arch': self.config['arch'],
+                'optimizer': self.config['optimizer']
+            }
+        }
+        if save_best:
+            best_path = str(self.checkpoint_dir / 'model_best.pth')
+            torch.save(state, best_path)
+            self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))
+
+    def _resume_checkpoint(self, resume_path):
+        """
+        Resume from saved checkpoints
+
+        Restores the start epoch, the best monitored value, the model weights
+        and, when the optimizer type is unchanged, the optimizer state.
+
+        :param resume_path: Checkpoint path to be resumed
+        """
+        resume_path = str(resume_path)
+        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
+        # map_location lets a checkpoint written on a GPU be restored on a
+        # machine where training runs on a different device.
+        checkpoint = torch.load(resume_path, map_location=self.device)
+        self.start_epoch = checkpoint['epoch'] + 1
+        self.mnt_best = checkpoint['monitor_best']
+
+        # load architecture params from checkpoint.
+        if checkpoint['config']['arch'] != self.config['arch']:
+            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
+                                "checkpoint. This may yield an exception while state_dict is being loaded.")
+        self.model.load_state_dict(checkpoint['state_dict'])
+
+        # load optimizer state from checkpoint only when optimizer type is not changed.
+        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
+            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
+                                "Optimizer parameters not being resumed.")
+        else:
+            self.optimizer.load_state_dict(checkpoint['optimizer'])
+
+        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
+
+
diff --git a/ELR/config_cifar10.json b/ELR/config_cifar10.json
new file mode 100644
index 0000000000000000000000000000000000000000..579f5ecafe15a085ad3fe85ca909a839cb82a749
--- /dev/null
+++ b/ELR/config_cifar10.json
@@ -0,0 +1,75 @@
+{
+ "name": "cifar10_resnet34_cosine",
+ "n_gpu": 1,
+ "seed": 123,
+
+ "arch": {
+ "type": "resnet34",
+ "args": {"num_classes":10}
+ },
+
+ "num_classes": 10,
+
+ "data_loader": {
+ "type": "CIFAR10DataLoader",
+ "args":{
+ "data_dir": "/dir_to_data",
+ "batch_size": 128,
+ "shuffle": true,
+ "num_batches": 0,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+
+ "optimizer": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 1e-3
+ }
+ },
+
+ "train_loss": {
+ "type": "elr_loss",
+ "args":{
+ "beta": 0.7,
+ "lambda": 3
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "CosineAnnealingWarmRestarts",
+ "args": {
+ "T_0": 10,
+ "eta_min": 0.001
+ }
+ },
+
+ "trainer": {
+ "epochs": 150,
+ "warmup": 0,
+ "save_dir": "saved/",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+ "tensorboard": false,
+ "mlflow": true,
+ "_percent": "Percentage of noise",
+ "percent": 0.8,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
diff --git a/ELR/config_cifar100.json b/ELR/config_cifar100.json
new file mode 100644
index 0000000000000000000000000000000000000000..4bdde0f94986f282d8780bcc3138cc265bae46cb
--- /dev/null
+++ b/ELR/config_cifar100.json
@@ -0,0 +1,75 @@
+{
+ "name": "cifar100_sy_60_resnet34",
+ "n_gpu": 1,
+ "seed": 123,
+
+ "arch": {
+ "type": "resnet34",
+ "args": {"num_classes":100}
+ },
+
+ "num_classes": 100,
+
+ "data_loader": {
+ "type": "CIFAR100DataLoader",
+ "args":{
+ "data_dir": "/dir_to_data",
+ "batch_size": 128,
+ "shuffle": true,
+ "num_batches": 0,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+
+ "optimizer": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 1e-3
+ }
+ },
+
+ "train_loss": {
+ "type": "elr_loss",
+ "args":{
+ "beta": 0.9,
+ "lambda": 7
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [80,120],
+ "gamma": 0.01
+ }
+ },
+
+ "trainer": {
+ "epochs": 150,
+ "warmup": 0,
+ "save_dir": "saved/",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+ "tensorboard": false,
+ "mlflow": true,
+ "_percent": "Percentage of noise",
+ "percent": 0.6,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
diff --git a/ELR/config_cifar10_asym.json b/ELR/config_cifar10_asym.json
new file mode 100644
index 0000000000000000000000000000000000000000..daa585ccdf404edcd9887ccd937007d539540e18
--- /dev/null
+++ b/ELR/config_cifar10_asym.json
@@ -0,0 +1,75 @@
+{
+ "name": "cifar10_resnet34_cosine",
+ "n_gpu": 1,
+ "seed": 123,
+
+ "arch": {
+ "type": "resnet34",
+ "args": {"num_classes":10}
+ },
+
+ "num_classes": 10,
+
+ "data_loader": {
+ "type": "CIFAR10DataLoader",
+ "args":{
+ "data_dir": "/dir_to_data",
+ "batch_size": 128,
+ "shuffle": true,
+ "num_batches": 0,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+
+ "optimizer": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 1e-3
+ }
+ },
+
+ "train_loss": {
+ "type": "elr_loss",
+ "args":{
+ "beta": 0.9,
+ "lambda": 1
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [40,80],
+ "gamma": 0.01
+ }
+ },
+
+ "trainer": {
+ "epochs": 120,
+ "warmup": 0,
+ "save_dir": "saved/",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+ "tensorboard": false,
+ "mlflow": true,
+ "_percent": "Percentage of noise",
+ "percent": 0.4,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": true
+ }
+}
diff --git a/ELR/data_loader/__pycache__/cifar10.cpython-36.pyc b/ELR/data_loader/__pycache__/cifar10.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d797ab1c7839168a90d43f22eec06c4137d1e2b8
Binary files /dev/null and b/ELR/data_loader/__pycache__/cifar10.cpython-36.pyc differ
diff --git a/ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc b/ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14372f82eb66a38a5002938fafd66bb038e04091
Binary files /dev/null and b/ELR/data_loader/__pycache__/clothing1m.cpython-36.pyc differ
diff --git a/ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc b/ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39d529ed4163398feaa8526387e387138417f0c5
Binary files /dev/null and b/ELR/data_loader/__pycache__/data_loaders.cpython-36.pyc differ
diff --git a/ELR/data_loader/cifar10.py b/ELR/data_loader/cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e0ff510d543e952984f08995ab1361a8a71fa98
--- /dev/null
+++ b/ELR/data_loader/cifar10.py
@@ -0,0 +1,212 @@
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+import json
+import os
+
+def get_cifar10(root, cfg_trainer, train=True,
+                transform_train=None, transform_val=None,
+                download=False, noise_file = ''):
+    """
+    Build the CIFAR-10 datasets handed to the data loader.
+
+    With ``train=True`` the training set is split into train/validation
+    subsets and label noise is injected into both: asymmetric when
+    ``cfg_trainer['asym']`` is truthy, symmetric otherwise. With
+    ``train=False`` the first element is an empty list and the second wraps
+    the whole test set without added noise.
+
+    ``noise_file`` is accepted but not used by this function.
+
+    :return: ``(train_dataset, val_dataset)``
+    """
+    base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)
+
+    if not train:
+        train_dataset = []
+        val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
+        print(f"Test: {len(val_dataset)}")
+        return train_dataset, val_dataset
+
+    train_idxs, val_idxs = train_val_split(base_dataset.targets)
+    train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+    val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+
+    # Same noise model for both subsets, training subset first.
+    noise_method = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
+    for dataset in (train_dataset, val_dataset):
+        getattr(dataset, noise_method)()
+
+    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
+    return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset):
+    """
+    Split sample indices into class-stratified train and validation lists.
+
+    For each of the 10 classes the matching indices are shuffled with the
+    global NumPy RNG; the first ``int(len(labels) * 0.9 / 10)`` go to the
+    training list and the remainder to the validation list. Both lists are
+    shuffled again before being returned.
+
+    :param base_dataset: sequence of integer class labels, one per sample
+        (the caller passes ``dataset.targets``, not a dataset object)
+    :return: ``(train_idxs, val_idxs)`` lists of sample indices
+    """
+    num_classes = 10
+    labels = np.array(base_dataset)
+    train_n = int(len(labels) * 0.9 / num_classes)
+    train_idxs = []
+    val_idxs = []
+
+    for i in range(num_classes):
+        idxs = np.where(labels == i)[0]
+        np.random.shuffle(idxs)
+        train_idxs.extend(idxs[:train_n])
+        val_idxs.extend(idxs[train_n:])
+    np.random.shuffle(train_idxs)
+    np.random.shuffle(val_idxs)
+
+    return train_idxs, val_idxs
+
+
+class CIFAR10_train(torchvision.datasets.CIFAR10):
+    """
+    Subset of CIFAR-10 selected by ``indexs`` whose labels can be corrupted
+    with symmetric or asymmetric noise.
+
+    Items are ``(img, target, index, target_gt)``: the possibly noisy label,
+    the position inside this subset, and the label held before noise.
+    """
+    # Class flips applied by asymmetric_noise(): source label -> noisy label.
+    _ASYM_FLIPS = {
+        9: 1,  # truck -> automobile
+        2: 0,  # bird -> airplane
+        3: 5,  # cat -> dog
+        5: 3,  # dog -> cat
+        4: 7,  # deer -> horse
+    }
+
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR10_train, self).__init__(root, train=train,
+                                            transform=transform, target_transform=target_transform,
+                                            download=download)
+        self.num_classes = 10
+        self.cfg_trainer = cfg_trainer
+        self.train_data = self.data[indexs]
+        self.train_labels = np.array(self.targets)[indexs]
+        # Defined up front so __getitem__ works even before a noise method
+        # has been called (CIFAR10_val does the same).
+        self.train_labels_gt = self.train_labels.copy()
+        self.indexs = indexs
+        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+        self.noise_indx = []
+
+    def symmetric_noise(self):
+        """
+        Replace the labels of the first ``percent`` share of a random
+        permutation with labels drawn uniformly from all classes, recording
+        the affected positions in ``noise_indx``. A drawn label may equal the
+        original one.
+        """
+        self.train_labels_gt = self.train_labels.copy()
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.noise_indx.append(idx)
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def asymmetric_noise(self):
+        """
+        Flip a ``percent`` share of each class listed in ``_ASYM_FLIPS`` to
+        its paired class and record the flipped positions in ``noise_indx``.
+        """
+        self.train_labels_gt = self.train_labels.copy()
+        for i in range(self.num_classes):
+            # NOTE(review): membership is read from the current, already
+            # modified labels, so samples flipped cat -> dog at i == 3 are
+            # candidates for dog -> cat at i == 5. Kept as is.
+            indices = np.where(self.train_labels == i)[0]
+            # Shuffle for every class, flipped or not, so the global RNG
+            # stream stays the same as before.
+            np.random.shuffle(indices)
+            if i not in self._ASYM_FLIPS:
+                # Labels of this class never change, so none of its samples
+                # belongs in noise_indx.
+                continue
+            noisy_label = self._ASYM_FLIPS[i]
+            for j, idx in enumerate(indices):
+                if j < self.cfg_trainer['percent'] * len(indices):
+                    self.noise_indx.append(idx)
+                    self.train_labels[idx] = noisy_label
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the
+            possibly noisy class index and target_gt the label before noise.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
+    def __len__(self):
+        return len(self.train_data)
+
+
+
+class CIFAR10_val(torchvision.datasets.CIFAR10):
+    """
+    Validation/test view of CIFAR-10.
+
+    With ``train=True`` only the samples selected by ``indexs`` are kept;
+    otherwise the whole loaded split is used. The labels held at construction
+    time are kept in ``train_labels_gt``.
+    """
+
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR10_val, self).__init__(root, train=train,
+                                          transform=transform, target_transform=target_transform,
+                                          download=download)
+
+        self.num_classes = 10
+        self.cfg_trainer = cfg_trainer
+        if train:
+            self.train_data = self.data[indexs]
+            self.train_labels = np.array(self.targets)[indexs]
+        else:
+            self.train_data = self.data
+            self.train_labels = np.array(self.targets)
+        self.train_labels_gt = self.train_labels.copy()
+
+    def symmetric_noise(self):
+        """Give the first ``percent`` share of a random permutation uniformly drawn labels."""
+        total = len(self.train_data)
+        for rank, idx in enumerate(np.random.permutation(total)):
+            if not rank < self.cfg_trainer['percent'] * total:
+                break
+            self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def asymmetric_noise(self):
+        """Flip a ``percent`` share of selected classes to their paired class."""
+        flips = {
+            9: 1,  # truck -> automobile
+            2: 0,  # bird -> airplane
+            3: 5,  # cat -> dog
+            5: 3,  # dog -> cat
+            4: 7,  # deer -> horse
+        }
+        for cls in range(self.num_classes):
+            members = np.where(self.train_labels == cls)[0]
+            np.random.shuffle(members)
+            replacement = flips.get(cls)
+            if replacement is None:
+                continue
+            limit = self.cfg_trainer['percent'] * len(members)
+            for rank, idx in enumerate(members):
+                if rank < limit:
+                    self.train_labels[idx] = replacement
+
+    def __len__(self):
+        return len(self.train_data)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the
+            current class index and target_gt the label kept at construction.
+        """
+        img = self.train_data[index]
+        target = self.train_labels[index]
+        target_gt = self.train_labels_gt[index]
+
+        # return a PIL Image, consistent with all other datasets
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
\ No newline at end of file
diff --git a/ELR/data_loader/cifar100.py b/ELR/data_loader/cifar100.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0bac72dcddd142236dc4564aea06f6b5b86fab7
--- /dev/null
+++ b/ELR/data_loader/cifar100.py
@@ -0,0 +1,317 @@
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+import os
+import json
+from numpy.testing import assert_array_almost_equal
+
+
+
+def get_cifar100(root, cfg_trainer, train=True,
+                 transform_train=None, transform_val=None,
+                 download=False, noise_file = ''):
+    """
+    Build the CIFAR-100 datasets handed to the data loader.
+
+    With ``train=True`` the training set is split into train/validation
+    subsets and label noise is injected into both: asymmetric when
+    ``cfg_trainer['asym']`` is truthy, symmetric otherwise. With
+    ``train=False`` the first element is an empty list and the second wraps
+    the whole test set without added noise.
+
+    ``noise_file`` is accepted but not used by this function.
+
+    :return: ``(train_dataset, val_dataset)``
+    """
+    base_dataset = torchvision.datasets.CIFAR100(root, train=train, download=download)
+
+    if not train:
+        train_dataset = []
+        val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
+        print(f"Test: {len(val_dataset)}")
+        return train_dataset, val_dataset
+
+    train_idxs, val_idxs = train_val_split(base_dataset.targets)
+    train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+    val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+
+    # Same noise model for both subsets, training subset first.
+    noise_method = 'asymmetric_noise' if cfg_trainer['asym'] else 'symmetric_noise'
+    for dataset in (train_dataset, val_dataset):
+        getattr(dataset, noise_method)()
+
+    print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")  # Train: 45000 Val: 5000
+    return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset):
+    """
+    Split sample indices into class-stratified train and validation lists.
+
+    For each of the 100 classes the matching indices are shuffled with the
+    global NumPy RNG; the first ``int(len(labels) * 0.9 / 100)`` go to the
+    training list and the remainder to the validation list. Both lists are
+    shuffled again before being returned.
+
+    :param base_dataset: sequence of integer class labels, one per sample
+        (the caller passes ``dataset.targets``, not a dataset object)
+    :return: ``(train_idxs, val_idxs)`` lists of sample indices
+    """
+    num_classes = 100
+    labels = np.array(base_dataset)
+    train_n = int(len(labels) * 0.9 / num_classes)
+    train_idxs = []
+    val_idxs = []
+
+    for i in range(num_classes):
+        idxs = np.where(labels == i)[0]
+        np.random.shuffle(idxs)
+        train_idxs.extend(idxs[:train_n])
+        val_idxs.extend(idxs[train_n:])
+    np.random.shuffle(train_idxs)
+    np.random.shuffle(val_idxs)
+
+    return train_idxs, val_idxs
+
+
+class CIFAR100_train(torchvision.datasets.CIFAR100):
+    """
+    Subset of CIFAR-100 selected by ``indexs`` whose labels can be corrupted
+    with symmetric or asymmetric noise.
+
+    Items are ``(img, target, index, target_gt)``: the possibly noisy label,
+    the position inside this subset, and the label held before noise.
+    """
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR100_train, self).__init__(root, train=train,
+                                             transform=transform, target_transform=target_transform,
+                                             download=download)
+        self.num_classes = 100
+        self.cfg_trainer = cfg_trainer
+        self.train_data = self.data[indexs]
+        self.train_labels = np.array(self.targets)[indexs]
+        # Defined up front so __getitem__ works even before a noise method
+        # has been called (CIFAR100_val does the same).
+        self.train_labels_gt = self.train_labels.copy()
+        self.indexs = indexs
+        self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+        self.noise_indx = []
+
+        self.count = 0
+
+    def symmetric_noise(self):
+        """
+        Replace the labels of the first ``percent`` share of a random
+        permutation with labels drawn uniformly from all classes, recording
+        the affected positions in ``noise_indx``. A drawn label may equal the
+        original one.
+        """
+        self.train_labels_gt = self.train_labels.copy()
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.noise_indx.append(idx)
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def multiclass_noisify(self, y, P, random_state=0):
+        """ Flip classes according to transition probability matrix P.
+        It expects labels between 0 and the number of classes - 1.
+
+        :param y: 1-D array of integer labels
+        :param P: square row-stochastic matrix; row ``i`` is the distribution
+            of the new label for samples currently labelled ``i``
+        :param random_state: seed of the private ``RandomState`` used for sampling
+        :return: new label array of the same shape as ``y``
+        """
+
+        assert P.shape[0] == P.shape[1]
+        assert np.max(y) < P.shape[0]
+
+        # row stochastic matrix
+        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+        assert (P >= 0.0).all()
+
+        m = y.shape[0]
+        new_y = y.copy()
+        flipper = np.random.RandomState(random_state)
+
+        for idx in np.arange(m):
+            i = y[idx]
+            # draw a vector with only an 1
+            flipped = flipper.multinomial(1, P[i, :], 1)[0]
+            # Take the position of that single 1 as a scalar; assigning the
+            # size-1 index array itself to one element is deprecated in NumPy.
+            new_y[idx] = np.where(flipped == 1)[0][0]
+
+        return new_y
+
+    def build_for_cifar100(self, size, noise):
+        """ The noise matrix flips to the "next" class with probability 'noise'.
+
+        :param size: number of classes covered by the matrix
+        :param noise: flip probability in [0, 1]
+        :return: ``size x size`` matrix with ``1 - noise`` on the diagonal and
+            ``noise`` on the cyclic next-class entry of every row
+        """
+
+        assert(noise >= 0.) and (noise <= 1.)
+
+        P = (1. - noise) * np.eye(size)
+        for i in np.arange(size - 1):
+            P[i, i + 1] = noise
+
+        # adjust last row
+        P[size - 1, 0] = noise
+
+        assert_array_almost_equal(P.sum(axis=1), 1, 1)
+        return P
+
+    def asymmetric_noise(self, asym=False, random_shuffle=False):
+        """
+        Apply next-class flips inside 20 consecutive groups of 5 labels with
+        probability ``cfg_trainer['percent']``. Nothing changes when that
+        probability is not positive. ``asym`` and ``random_shuffle`` are
+        accepted but not used.
+        """
+        self.train_labels_gt = self.train_labels.copy()
+        P = np.eye(self.num_classes)
+        n = self.cfg_trainer['percent']
+        nb_superclasses = 20
+        nb_subclasses = 5
+
+        if n > 0.0:
+            for i in np.arange(nb_superclasses):
+                init, end = i * nb_subclasses, (i+1) * nb_subclasses
+                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+                                                    random_state=0)
+            actual_noise = (y_train_noisy != self.train_labels).mean()
+            assert actual_noise > 0.0
+            self.train_labels = y_train_noisy
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the
+            possibly noisy class index and target_gt the label before noise.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
+    def __len__(self):
+        return len(self.train_data)
+
+
+class CIFAR100_val(torchvision.datasets.CIFAR100):
+    """
+    Validation/test view of CIFAR-100.
+
+    With ``train=True`` only the samples selected by ``indexs`` are kept;
+    otherwise the whole loaded split is used. The labels held at construction
+    time are kept in ``train_labels_gt``.
+    """
+
+    def __init__(self, root, cfg_trainer, indexs, train=True,
+                 transform=None, target_transform=None,
+                 download=False):
+        super(CIFAR100_val, self).__init__(root, train=train,
+                                           transform=transform, target_transform=target_transform,
+                                           download=download)
+
+        self.num_classes = 100
+        self.cfg_trainer = cfg_trainer
+        if train:
+            self.train_data = self.data[indexs]
+            self.train_labels = np.array(self.targets)[indexs]
+        else:
+            self.train_data = self.data
+            self.train_labels = np.array(self.targets)
+        self.train_labels_gt = self.train_labels.copy()
+
+    def symmetric_noise(self):
+        """Give the first ``percent`` share of a random permutation uniformly drawn labels."""
+        indices = np.random.permutation(len(self.train_data))
+        for i, idx in enumerate(indices):
+            if i < self.cfg_trainer['percent'] * len(self.train_data):
+                self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+    def multiclass_noisify(self, y, P, random_state=0):
+        """ Flip classes according to transition probability matrix P.
+        It expects labels between 0 and the number of classes - 1.
+
+        :param y: 1-D array of integer labels
+        :param P: square row-stochastic matrix; row ``i`` is the distribution
+            of the new label for samples currently labelled ``i``
+        :param random_state: seed of the private ``RandomState`` used for sampling
+        :return: new label array of the same shape as ``y``
+        """
+
+        assert P.shape[0] == P.shape[1]
+        assert np.max(y) < P.shape[0]
+
+        # row stochastic matrix
+        assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+        assert (P >= 0.0).all()
+
+        m = y.shape[0]
+        new_y = y.copy()
+        flipper = np.random.RandomState(random_state)
+
+        for idx in np.arange(m):
+            i = y[idx]
+            # draw a vector with only an 1
+            flipped = flipper.multinomial(1, P[i, :], 1)[0]
+            # Take the position of that single 1 as a scalar; assigning the
+            # size-1 index array itself to one element is deprecated in NumPy.
+            new_y[idx] = np.where(flipped == 1)[0][0]
+
+        return new_y
+
+    def build_for_cifar100(self, size, noise):
+        """ The noise matrix flips to the "next" class with probability 'noise'.
+
+        :param size: number of classes covered by the matrix
+        :param noise: flip probability in [0, 1]
+        :return: ``size x size`` matrix with ``1 - noise`` on the diagonal and
+            ``noise`` on the cyclic next-class entry of every row
+        """
+
+        assert(noise >= 0.) and (noise <= 1.)
+
+        P = (1. - noise) * np.eye(size)
+        for i in np.arange(size - 1):
+            P[i, i + 1] = noise
+
+        # adjust last row
+        P[size - 1, 0] = noise
+
+        assert_array_almost_equal(P.sum(axis=1), 1, 1)
+        return P
+
+    def asymmetric_noise(self, asym=False, random_shuffle=False):
+        """
+        Apply next-class flips inside 20 consecutive groups of 5 labels with
+        probability ``cfg_trainer['percent']``. Nothing changes when that
+        probability is not positive. ``asym`` and ``random_shuffle`` are
+        accepted but not used.
+        """
+        P = np.eye(self.num_classes)
+        n = self.cfg_trainer['percent']
+        nb_superclasses = 20
+        nb_subclasses = 5
+
+        if n > 0.0:
+            for i in np.arange(nb_superclasses):
+                init, end = i * nb_subclasses, (i+1) * nb_subclasses
+                P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+            y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+                                                    random_state=0)
+            actual_noise = (y_train_noisy != self.train_labels).mean()
+            assert actual_noise > 0.0
+            self.train_labels = y_train_noisy
+
+    def __len__(self):
+        return len(self.train_data)
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target, index, target_gt) where target is the
+            current class index and target_gt the label kept at construction.
+        """
+        img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img)
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target, index, target_gt
+
+
+
diff --git a/ELR/data_loader/data_loaders.py b/ELR/data_loader/data_loaders.py
new file mode 100644
index 0000000000000000000000000000000000000000..eecf40ea026847f3bae963177855e096c03a189a
--- /dev/null
+++ b/ELR/data_loader/data_loaders.py
@@ -0,0 +1,70 @@
+import sys
+
+from torchvision import datasets, transforms
+from base import BaseDataLoader
+from data_loader.cifar10 import get_cifar10
+from data_loader.cifar100 import get_cifar100
+from parse_config import ConfigParser
+from PIL import Image
+
+
+class CIFAR10DataLoader(BaseDataLoader):
+ def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True, num_workers=4, pin_memory=True):
+ config = ConfigParser.get_instance()
+ cfg_trainer = config['trainer']
+
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ transform_val = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+ ])
+ self.data_dir = data_dir
+
+ noise_file='%sCIFAR10_%.1f_Asym_%s.json'%(config['data_loader']['args']['data_dir'],cfg_trainer['percent'],cfg_trainer['asym'])
+
+ self.train_dataset, self.val_dataset = get_cifar10(config['data_loader']['args']['data_dir'], cfg_trainer, train=training,
+ transform_train=transform_train, transform_val=transform_val, noise_file = noise_file)
+
+ super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+ val_dataset = self.val_dataset)
+ def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
+ super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+ val_dataset = self.val_dataset)
+
+
+
+class CIFAR100DataLoader(BaseDataLoader):
+ def __init__(self, data_dir, batch_size, shuffle=True, validation_split=0.0, num_batches=0, training=True,num_workers=4, pin_memory=True):
+ config = ConfigParser.get_instance()
+ cfg_trainer = config['trainer']
+
+ transform_train = transforms.Compose([
+ #transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ transform_val = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+ ])
+ self.data_dir = data_dir
+ config = ConfigParser.get_instance()
+ cfg_trainer = config['trainer']
+
+ noise_file='%sCIFAR100_%.1f_Asym_%s.json'%(config['data_loader']['args']['data_dir'],cfg_trainer['percent'],cfg_trainer['asym'])
+
+ self.train_dataset, self.val_dataset = get_cifar100(config['data_loader']['args']['data_dir'], cfg_trainer, train=training,
+ transform_train=transform_train, transform_val=transform_val, noise_file = noise_file)
+
+ super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+ val_dataset = self.val_dataset)
+ def run_loader(self, batch_size, shuffle, validation_split, num_workers, pin_memory):
+ super().__init__(self.train_dataset, batch_size, shuffle, validation_split, num_workers, pin_memory,
+ val_dataset = self.val_dataset)
diff --git a/ELR/logger/__init__.py b/ELR/logger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..092cad0f26c7ad447026eff419ed6a988b291eff
--- /dev/null
+++ b/ELR/logger/__init__.py
@@ -0,0 +1,2 @@
+from .logger import *
+from .visualization import *
\ No newline at end of file
diff --git a/ELR/logger/logger.py b/ELR/logger/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8a1ef28b7c58fb5f42f576fbcada548f01a789
--- /dev/null
+++ b/ELR/logger/logger.py
@@ -0,0 +1,22 @@
+import logging
+import logging.config
+from pathlib import Path
+from utils import read_json
+
+
+def setup_logging(save_dir, log_config='logger/logger_config.json', default_level=logging.INFO):
+ """
+ Setup logging configuration
+ """
+ log_config = Path(log_config)
+ if log_config.is_file():
+ config = read_json(log_config)
+ # modify logging paths based on run config
+ for _, handler in config['handlers'].items():
+ if 'filename' in handler:
+ handler['filename'] = str(save_dir / handler['filename'])
+
+ logging.config.dictConfig(config)
+ else:
+ print("Warning: logging configuration file is not found in {}.".format(log_config))
+ logging.basicConfig(level=default_level)
diff --git a/ELR/logger/logger_config.json b/ELR/logger/logger_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5afd588205c0be4b6dcc0dd65a5df2827d74e23a
--- /dev/null
+++ b/ELR/logger/logger_config.json
@@ -0,0 +1,32 @@
+
+{
+ "version": 1,
+ "disable_existing_loggers": false,
+ "formatters": {
+ "simple": {"format": "%(message)s"},
+ "datetime": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
+ },
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "level": "DEBUG",
+ "formatter": "simple",
+ "stream": "ext://sys.stdout"
+ },
+ "info_file_handler": {
+ "class": "logging.handlers.RotatingFileHandler",
+ "level": "INFO",
+ "formatter": "datetime",
+ "filename": "info.log",
+ "maxBytes": 10485760,
+ "backupCount": 20, "encoding": "utf8"
+ }
+ },
+ "root": {
+ "level": "INFO",
+ "handlers": [
+ "console",
+ "info_file_handler"
+ ]
+ }
+}
\ No newline at end of file
diff --git a/ELR/logger/visualization.py b/ELR/logger/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..28969a26f1ec2e6785e35f8a4ba247ab50fc56f9
--- /dev/null
+++ b/ELR/logger/visualization.py
@@ -0,0 +1,154 @@
+import importlib
+from utils import Timer
+
+
+class MLFlow:
+ def __init__(self, log_dir, logger, enabled):
+ self.mlflow = None
+
+ if enabled:
+ log_dir = str(log_dir)
+
+ # Retrieve visualization writer.
+ try:
+ self.mlflow = importlib.import_module("mlflow")
+ succeeded = True
+ except ImportError:
+ succeeded = False
+
+ if not succeeded:
+ message = "Warning: visualization (mlflow) is configured to use, but currently not installed on " \
+ "this machine. Please install mlflow with 'pip install mlflow or turn off the option in " \
+ "the 'config.json' file."
+ logger.warning(message)
+
+ self.step = 0
+ self.mode = ''
+
+ self.mlflow_ftns_with_tag_and_value = {
+ 'log_param', 'log_metric'
+ }
+ self.mlflow_ftns = {
+ 'start_run'
+ }
+ # self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
+
+ # self.timer = Timer()
+
+ # def set_step(self, step, mode='train'):
+ # self.mode = mode
+ # self.step = step
+ # if step == 0:
+ # self.timer.reset()
+ # else:
+ # duration = self.timer.check()
+ # self.add_scalar('steps_per_sec', 1 / duration)
+
+ def __getattr__(self, name):
+ """
+ If visualization is configured to use:
+ return add_data() methods of tensorboard with additional information (step, tag) added.
+ Otherwise:
+ return a blank function handle that does nothing
+ """
+ if name in self.mlflow_ftns_with_tag_and_value:
+ add_data = getattr(self.mlflow, name, None)
+
+ def wrapper(tag, data, *args, **kwargs):
+ if add_data is not None:
+ # add mode(train/valid) tag
+ if name not in self.tag_mode_exceptions:
+ tag = '{}/{}'.format(tag, self.mode)
+ add_data(tag, data, *args, **kwargs)
+
+ return wrapper
+ elif name in self.mlflow_ftns:
+ add_data = getattr(self.mlflow, name, None)
+
+ def wrapper(*args, **kwargs):
+ if add_data is not None:
+ # add mode(train/valid) tag
+ # if name not in self.tag_mode_exceptions:
+ # tag = '{}/{}'.format(tag, self.mode)
+ add_data(*args, **kwargs)
+
+ return wrapper
+ else:
+ # default action for returning methods defined in this class, set_step() for instance.
+ try:
+ attr = object.__getattr__(name)
+ except AttributeError:
+ raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
+ return attr
+
+
+class TensorboardWriter:
+ def __init__(self, log_dir, logger, enabled):
+ self.writer = None
+ self.selected_module = ""
+
+ if enabled:
+ log_dir = str(log_dir)
+
+ # Retrieve vizualization writer.
+ succeeded = False
+ for module in ["torch.utils.tensorboard", "tensorboardX"]:
+ try:
+ self.writer = importlib.import_module(module).SummaryWriter(log_dir)
+ succeeded = True
+ break
+ except ImportError:
+ succeeded = False
+ self.selected_module = module
+
+ if not succeeded:
+ message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \
+ "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade " \
+ "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
+ "the 'config.json' file."
+ logger.warning(message)
+
+ self.step = 0
+ self.mode = ''
+
+ self.tb_writer_ftns = {
+ 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
+ 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
+ }
+ self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
+
+ self.timer = Timer()
+
+ def set_step(self, step, mode='train'):
+ self.mode = mode
+ self.step = step
+ if step == 0:
+ self.timer.reset()
+ else:
+ duration = self.timer.check()
+ self.add_scalar('steps_per_sec', 1 / duration)
+
+ def __getattr__(self, name):
+ """
+ If visualization is configured to use:
+ return add_data() methods of tensorboard with additional information (step, tag) added.
+ Otherwise:
+ return a blank function handle that does nothing
+ """
+ if name in self.tb_writer_ftns:
+ add_data = getattr(self.writer, name, None)
+
+ def wrapper(tag, data, *args, **kwargs):
+ if add_data is not None:
+ # add mode(train/valid) tag
+ if name not in self.tag_mode_exceptions:
+ tag = '{}/{}'.format(tag, self.mode)
+ add_data(tag, data, self.step, *args, **kwargs)
+ return wrapper
+ else:
+ # default action for returning methods defined in this class, set_step() for instance.
+ try:
+ attr = object.__getattr__(name)
+ except AttributeError:
+ raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
+ return attr
diff --git a/ELR/model/ResNet_Zoo.py b/ELR/model/ResNet_Zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..6149ae2b811bcb15baeb40b5525682788112351b
--- /dev/null
+++ b/ELR/model/ResNet_Zoo.py
@@ -0,0 +1,133 @@
+'''ResNet in PyTorch.
+For Pre-activation ResNet, see 'preact_resnet.py'.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion*planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(self.expansion*planes)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.bn2(self.conv2(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, in_planes, planes, stride=1):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(self.expansion*planes)
+
+ self.shortcut = nn.Sequential()
+ if stride != 1 or in_planes != self.expansion*planes:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(self.expansion*planes)
+ )
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = F.relu(self.bn2(self.conv2(out)))
+ out = self.bn3(self.conv3(out))
+ out += self.shortcut(x)
+ out = F.relu(out)
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(self, block, num_blocks, num_classes=10):
+ super(ResNet, self).__init__()
+ self.in_planes = 64
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+ self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+ self.linear = nn.Linear(512*block.expansion, num_classes)
+
+
+ self.gradients = None
+
+
+
+ def _make_layer(self, block, planes, num_blocks, stride):
+ strides = [stride] + [1]*(num_blocks-1)
+ layers = []
+ for stride in strides:
+ layers.append(block(self.in_planes, planes, stride))
+ self.in_planes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+ def activations_hook(self, grad):
+ self.gradients = grad
+
+ def forward(self, x):
+ out = F.relu(self.bn1(self.conv1(x)))
+ out = self.layer1(out)
+ out = self.layer2(out)
+ out = self.layer3(out)
+ out = self.layer4(out)
+ out = F.avg_pool2d(out, 4)
+ y = out.view(out.size(0), -1)
+ out = self.linear(y)
+ if out.requires_grad:
+ out.register_hook(self.activations_hook)
+ return out
+
+ def get_activations_gradient(self):
+ return self.gradients
+
+
+def ResNet18():
+ return ResNet(BasicBlock, [2,2,2,2])
+
+def ResNet34():
+ return ResNet(BasicBlock, [3,4,6,3])
+
+def ResNet50():
+ return ResNet(Bottleneck, [3,4,6,3])
+
+def ResNet101():
+ return ResNet(Bottleneck, [3,4,23,3])
+
+def ResNet152():
+ return ResNet(Bottleneck, [3,8,36,3])
+
+
+def test():
+    """Smoke test: run ResNet18 on one random 1x3x32x32 input and print the output size."""
+    net = ResNet18()
+    y = net(torch.randn(1,3,32,32))
+    print(y.size())
+
+# test()
\ No newline at end of file
diff --git a/ELR/model/loss.py b/ELR/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e59b2ea9efed56f351ea4d6d03570e0a09c0d38f
--- /dev/null
+++ b/ELR/model/loss.py
@@ -0,0 +1,30 @@
+import torch.nn.functional as F
+import torch
+from parse_config import ConfigParser
+import torch.nn as nn
+
+
+def cross_entropy(output, target):
+    """Return ``F.cross_entropy(output, target)`` with its default arguments."""
+    return F.cross_entropy(output, target)
+
+
+class elr_loss(nn.Module):
+ def __init__(self, num_examp, num_classes=10, beta=0.3):
+ super(elr_loss, self).__init__()
+ self.num_classes = num_classes
+ self.config = ConfigParser.get_instance()
+ self.USE_CUDA = torch.cuda.is_available()
+ self.target = torch.zeros(num_examp, self.num_classes).cuda() if self.USE_CUDA else torch.zeros(num_examp, self.num_classes)
+ self.beta = beta
+
+
+ def forward(self, index, output, label):
+ y_pred = F.softmax(output,dim=1)
+ y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
+ y_pred_ = y_pred.data.detach()
+ self.target[index] = self.beta * self.target[index] + (1-self.beta) * ((y_pred_)/(y_pred_).sum(dim=1,keepdim=True))
+ ce_loss = F.cross_entropy(output, label)
+ elr_reg = ((1-(self.target[index] * y_pred).sum(dim=1)).log()).mean()
+ final_loss = ce_loss + self.config['train_loss']['args']['lambda']*elr_reg
+ return final_loss
+
diff --git a/ELR/model/metric.py b/ELR/model/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..c47ca20094ebd52e22cb69e3ce1f147f47a4b40d
--- /dev/null
+++ b/ELR/model/metric.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def my_metric(output, target):
+ with torch.no_grad():
+ pred = torch.argmax(output, dim=1)
+ assert pred.shape[0] == len(target)
+ correct = 0
+ correct += torch.sum(pred == target).item()
+ return correct / len(target)
+
+
+def my_metric2(output, target, k=5):
+ with torch.no_grad():
+ pred = torch.topk(output, k, dim=1)[1]
+ assert pred.shape[0] == len(target)
+ correct = 0
+ for i in range(k):
+ correct += torch.sum(pred[:, i] == target).item()
+ return correct / len(target)
diff --git a/ELR/model/model.py b/ELR/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4bdbd1eeacaf7bcac55f06b87eec689afaebd6
--- /dev/null
+++ b/ELR/model/model.py
@@ -0,0 +1,13 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from base import BaseModel
+from .ResNet_Zoo import ResNet, BasicBlock
+
+
+
+def resnet34(num_classes=10):
+    """Build a ``ResNet`` from ``BasicBlock`` with stage depths [3, 4, 6, 3] and ``num_classes`` outputs."""
+    return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
+
+
+
+
diff --git a/ELR/parse_config.py b/ELR/parse_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..69ec1b201bd4ae12af004b519b99a35d9b956910
--- /dev/null
+++ b/ELR/parse_config.py
@@ -0,0 +1,146 @@
+import os
+import logging
+from pathlib import Path
+from functools import reduce
+from operator import getitem
+from datetime import datetime
+from logger import setup_logging
+from utils import read_json, write_json
+
+
+class ConfigParser:
+
+ __instance = None
+
+ def __new__(cls, args, options='', timestamp=True):
+ raise NotImplementedError('Cannot initialize via Constructor')
+
+ @classmethod
+ def __internal_new__(cls):
+ return super().__new__(cls)
+
+ @classmethod
+ def get_instance(cls, args=None, options='', timestamp=True):
+ if not cls.__instance:
+ if args is None:
+ NotImplementedError('Cannot initialize without args')
+ cls.__instance = cls.__internal_new__()
+ cls.__instance.__init__(args, options)
+
+ return cls.__instance
+
+ def __init__(self, args, options='', timestamp=True):
+ # parse default and custom cli options
+ for opt in options:
+ args.add_argument(*opt.flags, default=None, type=opt.type)
+ args = args.parse_args()
+
+ if args.device:
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+ if args.resume is None:
+ msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
+ assert args.config is not None, msg_no_cfg
+ self.cfg_fname = Path(args.config)
+ config = read_json(self.cfg_fname)
+ self.resume = None
+ else:
+ self.resume = Path(args.resume)
+ resume_cfg_fname = self.resume.parent / 'config.json'
+ config = read_json(resume_cfg_fname)
+ if args.config is not None:
+ config.update(read_json(Path(args.config)))
+
+ # load config file and apply custom cli options
+ self._config = _update_config(config, options, args)
+
+ # set save_dir where trained model and log will be saved.
+ save_dir = Path(self.config['trainer']['save_dir'])
+ timestamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
+
+
+ if self.config['trainer']['asym']:
+ exper_name = self.config['name'] + '_asym_' + str(int(self.config['trainer']['percent']*100))
+ else:
+ exper_name = self.config['name'] + '_sym_' + str(int(self.config['trainer']['percent']*100))
+ self._save_dir = save_dir / 'models' / exper_name / timestamp
+ self._log_dir = save_dir / 'log' / exper_name / timestamp
+
+ self.save_dir.mkdir(parents=True, exist_ok=True)
+ self.log_dir.mkdir(parents=True, exist_ok=True)
+
+ # save updated config file to the checkpoint dir
+ write_json(self.config, self.save_dir / 'config.json')
+
+ # configure logging module
+ setup_logging(self.log_dir)
+ self.log_levels = {
+ 0: logging.WARNING,
+ 1: logging.INFO,
+ 2: logging.DEBUG
+ }
+
+ def initialize(self, name, module, *args, **kwargs):
+ """
+ finds a function handle with the name given as 'type' in config, and returns the
+ instance initialized with corresponding keyword args given as 'args'.
+ """
+ module_name = self[name]['type']
+ module_args = dict(self[name]['args'])
+ assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
+ module_args.update(kwargs)
+ return getattr(module, module_name)(*args, **module_args)
+
+ def __getitem__(self, name):
+ return self.config[name]
+
+ def get_logger(self, name, verbosity=2):
+ msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
+ self.log_levels.keys())
+ assert verbosity in self.log_levels, msg_verbosity
+ logger = logging.getLogger(name)
+ logger.setLevel(self.log_levels[verbosity])
+ return logger
+
+ # setting read-only attributes
+ @property
+ def config(self):
+ return self._config
+
+ @property
+ def save_dir(self):
+ return self._save_dir
+
+ @property
+ def log_dir(self):
+ return self._log_dir
+
+
+# helper functions used to update config dict with custom cli options
+def _update_config(config, options, args):
+ for opt in options:
+ value = getattr(args, _get_opt_name(opt.flags))
+ if value is not None:
+ _set_by_path(config, opt.target, value)
+ if 'target2' in opt._fields:
+ _set_by_path(config, opt.target2, value)
+ if 'target3' in opt._fields:
+ _set_by_path(config, opt.target3, value)
+
+ return config
+
+
+def _get_opt_name(flags):
+ for flg in flags:
+ if flg.startswith('--'):
+ return flg.replace('--', '')
+ return flags[0].replace('--', '')
+
+
+def _set_by_path(tree, keys, value):
+ """Set a value in a nested object in tree by sequence of keys."""
+ _get_by_path(tree, keys[:-1])[keys[-1]] = value
+
+
+def _get_by_path(tree, keys):
+ """Access a nested object in tree by sequence of keys."""
+ return reduce(getitem, keys, tree)
diff --git a/ELR/test.py b/ELR/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3084bdb3a59d756db1c8914962108a44550293de
--- /dev/null
+++ b/ELR/test.py
@@ -0,0 +1,82 @@
+import argparse
+import torch
+from tqdm import tqdm
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+
+
+def main(config):
+ logger = config.get_logger('test')
+
+ # setup data_loader instances
+ data_loader = getattr(module_data, config['data_loader']['type'])(
+ config['data_loader']['args']['data_dir'],
+ batch_size=512,
+ shuffle=False,
+ validation_split=0.0,
+ training=False,
+ num_workers=2
+ ).split_validation()
+
+ # build model architecture
+ model = config.initialize('arch', module_arch)
+ logger.info(model)
+
+ # get function handles of loss and metrics
+ loss_fn = getattr(module_loss, config['val_loss'])
+ metric_fns = [getattr(module_metric, met) for met in config['metrics']]
+
+ logger.info('Loading checkpoint: {} ...'.format(config.resume))
+ checkpoint = torch.load(config.resume,map_location='cpu')
+ state_dict = checkpoint['state_dict']
+ if config['n_gpu'] > 1:
+ model = torch.nn.DataParallel(model)
+ model.load_state_dict(state_dict)
+
+ # prepare model for testing
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = model.to(device)
+ model.eval()
+
+ total_loss = 0.0
+ total_metrics = torch.zeros(len(metric_fns))
+
+ with torch.no_grad():
+ for i, (data, target,_,_) in enumerate(tqdm(data_loader)):
+ data, target = data.to(device), target.to(device)
+ output = model(data)
+
+ #
+ # save sample images, or do something with output here
+ #
+
+ # computing loss, metrics on test set
+ loss = loss_fn(output, target)
+ batch_size = data.shape[0]
+ total_loss += loss.item() * batch_size
+ for i, metric in enumerate(metric_fns):
+ total_metrics[i] += metric(output, target) * batch_size
+
+ n_samples = len(data_loader.sampler)
+ log = {'loss': total_loss / n_samples}
+ log.update({
+ met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
+ })
+ logger.info(log)
+
+
+if __name__ == '__main__':
+    # Script entry point: declare the CLI, build the ConfigParser singleton, run the evaluation.
+    args = argparse.ArgumentParser(description='PyTorch Template')
+
+    args.add_argument('-c', '--config', default=None, type=str,
+                      help='config file path (default: None)')
+    args.add_argument('-r', '--resume', default=None, type=str,
+                      help='path to latest checkpoint (default: None)')
+    args.add_argument('-d', '--device', default=None, type=str,
+                      help='indices of GPUs to enable (default: all)')
+    # No custom options are registered here, hence the empty second argument.
+    config = ConfigParser.get_instance(args, '')
+    #config = ConfigParser(args)
+    main(config)
diff --git a/ELR/train.py b/ELR/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5688b8681b8fe2ed6d8298c4e9049f88cf0fa4a
--- /dev/null
+++ b/ELR/train.py
@@ -0,0 +1,125 @@
+import argparse
+import collections
+import sys
+import requests
+import socket
+import torch
+import mlflow
+import mlflow.pytorch
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+from trainer import Trainer
+from collections import OrderedDict
+import random
+
+
+
+def log_params(conf: OrderedDict, parent_key: str = None):
+ for key, value in conf.items():
+ if parent_key is not None:
+ combined_key = f'{parent_key}-{key}'
+ else:
+ combined_key = key
+
+ if not isinstance(value, OrderedDict):
+ mlflow.log_param(combined_key, value)
+ else:
+ log_params(value, combined_key)
+
+
+def main(config: ConfigParser):
+
+ logger = config.get_logger('train')
+
+ data_loader = getattr(module_data, config['data_loader']['type'])(
+ config['data_loader']['args']['data_dir'],
+ batch_size= config['data_loader']['args']['batch_size'],
+ shuffle=config['data_loader']['args']['shuffle'],
+ validation_split=config['data_loader']['args']['validation_split'],
+ num_batches=config['data_loader']['args']['num_batches'],
+ training=True,
+ num_workers=config['data_loader']['args']['num_workers'],
+ pin_memory=config['data_loader']['args']['pin_memory']
+ )
+
+
+ valid_data_loader = data_loader.split_validation()
+
+ # test_data_loader = None
+
+ test_data_loader = getattr(module_data, config['data_loader']['type'])(
+ config['data_loader']['args']['data_dir'],
+ batch_size=128,
+ shuffle=False,
+ validation_split=0.0,
+ training=False,
+ num_workers=2
+ ).split_validation()
+
+
+ # build model architecture, then print to console
+ model = config.initialize('arch', module_arch)
+
+ # get function handles of loss and metrics
+ logger.info(config.config)
+ if hasattr(data_loader.dataset, 'num_raw_example'):
+ num_examp = data_loader.dataset.num_raw_example
+ else:
+ num_examp = len(data_loader.dataset)
+
+ train_loss = getattr(module_loss, config['train_loss']['type'])(num_examp=num_examp, num_classes=config['num_classes'],
+ beta=config['train_loss']['args']['beta'])
+
+ val_loss = getattr(module_loss, config['val_loss'])
+ metrics = [getattr(module_metric, met) for met in config['metrics']]
+
+ # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
+ trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+
+ optimizer = config.initialize('optimizer', torch.optim, [{'params': trainable_params}])
+
+ lr_scheduler = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer)
+
+ trainer = Trainer(model, train_loss, metrics, optimizer,
+ config=config,
+ data_loader=data_loader,
+ valid_data_loader=valid_data_loader,
+ test_data_loader=test_data_loader,
+ lr_scheduler=lr_scheduler,
+ val_criterion=val_loss)
+
+ trainer.train()
+ logger = config.get_logger('trainer', config['trainer']['verbosity'])
+ cfg_trainer = config['trainer']
+
+
+if __name__ == '__main__':
+ args = argparse.ArgumentParser(description='PyTorch Template')
+ args.add_argument('-c', '--config', default=None, type=str,
+ help='config file path (default: None)')
+ args.add_argument('-r', '--resume', default=None, type=str,
+ help='path to latest checkpoint (default: None)')
+ args.add_argument('-d', '--device', default=None, type=str,
+ help='indices of GPUs to enable (default: all)')
+
+ # custom cli options to modify configuration from default values given in json file.
+ CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+ options = [
+ CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
+ CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
+ CustomArgs(['--lamb', '--lamb'], type=float, target=('train_loss', 'args', 'lambda')),
+ CustomArgs(['--beta', '--beta'], type=float, target=('train_loss', 'args', 'beta')),
+ CustomArgs(['--percent', '--percent'], type=float, target=('trainer', 'percent')),
+ CustomArgs(['--asym', '--asym'], type=bool, target=('trainer', 'asym')),
+ CustomArgs(['--name', '--exp_name'], type=str, target=('name',)),
+ CustomArgs(['--seed', '--seed'], type=int, target=('seed',))
+ ]
+ config = ConfigParser.get_instance(args, options)
+
+ random.seed(config['seed'])
+ torch.manual_seed(config['seed'])
+ torch.cuda.manual_seed_all(config['seed'])
+ main(config)
diff --git a/ELR/trainer/__init__.py b/ELR/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fe21f802159f55665fcdebcd9b2c32bb4acf5b9
--- /dev/null
+++ b/ELR/trainer/__init__.py
@@ -0,0 +1 @@
+from .trainer import *
diff --git a/ELR/trainer/trainer.py b/ELR/trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fb0342d1551a55f3d9c958678d1a6be2afe298d
--- /dev/null
+++ b/ELR/trainer/trainer.py
@@ -0,0 +1,278 @@
+import numpy as np
+import torch
+from tqdm import tqdm
+from typing import List
+from torchvision.utils import make_grid
+from base import BaseTrainer
+from utils import inf_loop
+import sys
+from sklearn.mixture import GaussianMixture
+
+class Trainer(BaseTrainer):
+ """
+ Trainer class
+
+ Note:
+ Inherited from BaseTrainer.
+ """
+    def __init__(self, model, train_criterion, metrics, optimizer, config, data_loader,
+                 valid_data_loader=None, test_data_loader=None, lr_scheduler=None, len_epoch=None, val_criterion=None):
+        """
+        Set up loaders, scheduler and the per-split loss histories.
+
+        :param data_loader: training loader; wrapped in ``inf_loop`` when ``len_epoch`` is given.
+        :param valid_data_loader: optional loader; validation runs only when it is not None.
+        :param test_data_loader: optional loader; testing runs only when it is not None.
+        :param lr_scheduler: optional learning-rate scheduler.
+        :param len_epoch: number of batches per epoch for iteration-based training,
+                          or None to use ``len(data_loader)``.
+        """
+        super().__init__(model, train_criterion, metrics, optimizer, config, val_criterion)
+        self.config = config
+
+        epoch_based = len_epoch is None
+        # epoch-based training iterates the loader once per epoch;
+        # iteration-based training draws from an endless wrapper instead
+        self.data_loader = data_loader if epoch_based else inf_loop(data_loader)
+        self.len_epoch = len(data_loader) if epoch_based else len_epoch
+
+        self.valid_data_loader = valid_data_loader
+        self.test_data_loader = test_data_loader
+        self.do_validation = valid_data_loader is not None
+        self.do_test = test_data_loader is not None
+
+        self.lr_scheduler = lr_scheduler
+        # progress is refreshed every sqrt(batch_size) batches
+        self.log_step = int(np.sqrt(data_loader.batch_size))
+
+        # per-batch loss values accumulated over the whole run
+        self.train_loss_list: List[float] = []
+        self.val_loss_list: List[float] = []
+        self.test_loss_list: List[float] = []
+
+
+    def _eval_metrics(self, output, label):
+        """
+        Evaluate every configured metric on one batch and log each to the writer.
+
+        :return: numpy array with one value per entry of ``self.metrics``.
+        """
+        scores = np.zeros(len(self.metrics))
+        for pos, metric_fn in enumerate(self.metrics):
+            value = metric_fn(output, label)
+            scores[pos] += value
+            # the scalar is logged under the metric function's own name
+            self.writer.add_scalar('{}'.format(metric_fn.__name__), scores[pos])
+        return scores
+
+    def _train_epoch(self, epoch):
+        """
+        Training logic for an epoch
+
+        Iterates ``self.data_loader``, optimises ``self.model`` with
+        ``self.train_criterion``, then runs validation / test when the matching
+        loader was supplied, and finally steps the LR scheduler if there is one.
+
+        :param epoch: Current training epoch.
+        :return: A log dict with keys 'loss', 'metrics' and 'learning rate',
+                 updated with the dicts returned by ``_valid_epoch`` / ``_test_epoch``.
+
+        Note:
+            If you have additional information to record, for example:
+                > additional_log = {"x": x, "y": y}
+            merge it with log before return. i.e.
+                > log = {**log, **additional_log}
+                > return log
+
+            The metrics in log must have the key 'metrics'.
+        """
+        self.model.train()
+
+        total_loss = 0
+        total_metrics = np.zeros(len(self.metrics))
+
+        with tqdm(self.data_loader) as progress:
+            for batch_idx, (data, label, indexs, _) in enumerate(progress):
+                progress.set_description_str(f'Train epoch {epoch}')
+
+                data, label = data.to(self.device), label.long().to(self.device)
+
+                output = self.model(data)
+
+                # the criterion receives the batch's sample indices as a plain Python list
+                loss = self.train_criterion(indexs.cpu().detach().numpy().tolist(), output, label)
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+                batch_loss = loss.item()
+                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+                self.writer.add_scalar('loss', batch_loss)
+                self.train_loss_list.append(batch_loss)
+                total_loss += batch_loss
+                total_metrics += self._eval_metrics(output, label)
+
+                if batch_idx % self.log_step == 0:
+                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                        self._progress(batch_idx),
+                        batch_loss))
+                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+                # NOTE(review): the break fires after batch index len_epoch has been
+                # processed, so iteration-based training runs len_epoch + 1 batches
+                # while the averages below divide by len_epoch -- confirm intent.
+                if batch_idx == self.len_epoch:
+                    break
+
+        # lr_scheduler is optional (defaults to None in __init__); fall back to the
+        # optimizer's current rates instead of crashing on None.get_lr()
+        if self.lr_scheduler is not None:
+            current_lr = self.lr_scheduler.get_lr()
+        else:
+            current_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+        log = {
+            'loss': total_loss / self.len_epoch,
+            'metrics': (total_metrics / self.len_epoch).tolist(),
+            'learning rate': current_lr
+        }
+
+        if self.do_validation:
+            val_log = self._valid_epoch(epoch)
+            log.update(val_log)
+        if self.do_test:
+            # _test_epoch also returns [results, targets]; they are not used here
+            test_log, _ = self._test_epoch(epoch)
+            log.update(test_log)
+
+        if self.lr_scheduler is not None:
+            self.lr_scheduler.step()
+
+        return log
+
+
+ def _valid_epoch(self, epoch):
+ """
+ Validate after training an epoch
+
+ Puts the model in eval mode, evaluates ``self.val_criterion`` and every
+ metric on each batch of ``self.valid_data_loader`` under ``torch.no_grad()``,
+ and logs per-batch loss, metrics and input images to the writer.
+
+ :param epoch: Current epoch number; used for the progress label and writer step.
+ :return: A log that contains information about validation
+ ('val_loss' and 'val_metrics', both averaged over the number of batches).
+
+ Note:
+ The validation metrics in log must have the key 'val_metrics'.
+ """
+ self.model.eval()
+
+ total_val_loss = 0
+ total_val_metrics = np.zeros(len(self.metrics))
+ with torch.no_grad():
+ with tqdm(self.valid_data_loader) as progress:
+ # each batch is a 4-tuple; only the first two fields (data, label) are used here
+ for batch_idx, (data, label, _, _) in enumerate(progress):
+ progress.set_description_str(f'Valid epoch {epoch}')
+ data, label = data.to(self.device), label.to(self.device)
+ output = self.model(data)
+ loss = self.val_criterion(output, label)
+
+ # writer step is global across epochs, in the 'valid' mode
+ self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
+ self.writer.add_scalar('loss', loss.item())
+ self.val_loss_list.append(loss.item())
+ total_val_loss += loss.item()
+ total_val_metrics += self._eval_metrics(output, label)
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+ # add histogram of model parameters to the tensorboard
+ for name, p in self.model.named_parameters():
+ self.writer.add_histogram(name, p, bins='auto')
+
+ # averages are per batch, not per sample
+ return {
+ 'val_loss': total_val_loss / len(self.valid_data_loader),
+ 'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
+ }
+
+ def _test_epoch(self, epoch):
+ """
+ Test after training an epoch
+
+ Puts the model in eval mode and evaluates ``self.val_criterion`` and every
+ metric on each batch of ``self.test_data_loader`` under ``torch.no_grad()``.
+ Raw model outputs and labels are also collected per sample.
+
+ :param epoch: Current epoch number; used for the progress label and writer step.
+ :return: A 2-tuple ``(log, [results, tar_])`` where ``log`` holds 'test_loss'
+ and 'test_metrics' (averaged over the number of batches), ``results`` is a
+ float32 array of shape (len(dataset), config['num_classes']) with the model
+ outputs, and ``tar_`` is a float32 array of the labels.
+
+ Note:
+ The Test metrics in log must have the key 'test_metrics'.
+ """
+ self.model.eval()
+ total_test_loss = 0
+ total_test_metrics = np.zeros(len(self.metrics))
+ # per-sample buffers, filled at the positions given by each batch's indices
+ results = np.zeros((len(self.test_data_loader.dataset), self.config['num_classes']), dtype=np.float32)
+ tar_ = np.zeros((len(self.test_data_loader.dataset),), dtype=np.float32)
+ with torch.no_grad():
+ with tqdm(self.test_data_loader) as progress:
+ # each batch is a 4-tuple; the third field is used as row indices into the buffers
+ for batch_idx, (data, label,indexs,_) in enumerate(progress):
+ progress.set_description_str(f'Test epoch {epoch}')
+ data, label = data.to(self.device), label.to(self.device)
+ output = self.model(data)
+
+ loss = self.val_criterion(output, label)
+
+ # writer step is global across epochs, in the 'test' mode
+ self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
+ self.writer.add_scalar('loss', loss.item())
+ self.test_loss_list.append(loss.item())
+ total_test_loss += loss.item()
+ total_test_metrics += self._eval_metrics(output, label)
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+ results[indexs.cpu().detach().numpy().tolist()] = output.cpu().detach().numpy().tolist()
+ tar_[indexs.cpu().detach().numpy().tolist()] = label.cpu().detach().numpy().tolist()
+
+ # add histogram of model parameters to the tensorboard
+ for name, p in self.model.named_parameters():
+ self.writer.add_histogram(name, p, bins='auto')
+
+ # averages are per batch, not per sample
+ return {
+ 'test_loss': total_test_loss / len(self.test_data_loader),
+ 'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
+ },[results,tar_]
+
+
+    def _warmup_epoch(self, epoch):
+        """
+        Warm-up logic for an epoch: plain cross-entropy training.
+
+        For every batch the softmax of the model output is handed to
+        ``self.train_criterion.update_hist`` together with the sample indices,
+        while the optimisation step itself uses ``cross_entropy`` only.
+
+        :param epoch: Current training epoch.
+        :return: A log dict with keys 'loss', 'noise detection rate' (always 0.0),
+                 'metrics' and 'learning rate', updated with the validation / test logs.
+        """
+        total_loss = 0
+        total_metrics = np.zeros(len(self.metrics))
+        self.model.train()
+
+        data_loader = self.data_loader
+
+        with tqdm(data_loader) as progress:
+            # batches are unpacked as 4-tuples, matching _train_epoch which
+            # iterates the very same self.data_loader
+            for batch_idx, (data, label, indexs, _) in enumerate(progress):
+                progress.set_description_str(f'Warm up epoch {epoch}')
+
+                data, label = data.to(self.device), label.long().to(self.device)
+
+                self.optimizer.zero_grad()
+                output = self.model(data)
+                # explicit class dimension; calling softmax without ``dim`` is deprecated
+                out_prob = torch.nn.functional.softmax(output, dim=1).data.detach()
+
+                self.train_criterion.update_hist(indexs.cpu().detach().numpy().tolist(), out_prob)
+
+                loss = torch.nn.functional.cross_entropy(output, label)
+
+                loss.backward()
+                self.optimizer.step()
+
+                batch_loss = loss.item()
+                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+                self.writer.add_scalar('loss', batch_loss)
+                self.train_loss_list.append(batch_loss)
+                total_loss += batch_loss
+                total_metrics += self._eval_metrics(output, label)
+
+                if batch_idx % self.log_step == 0:
+                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                        self._progress(batch_idx),
+                        batch_loss))
+                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+                if batch_idx == self.len_epoch:
+                    break
+        if hasattr(self.data_loader, 'run'):
+            self.data_loader.run()
+
+        # lr_scheduler is optional (defaults to None in __init__)
+        if self.lr_scheduler is not None:
+            current_lr = self.lr_scheduler.get_lr()
+        else:
+            current_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+        log = {
+            'loss': total_loss / self.len_epoch,
+            'noise detection rate' : 0.0,
+            'metrics': (total_metrics / self.len_epoch).tolist(),
+            'learning rate': current_lr
+        }
+
+        if self.do_validation:
+            val_log = self._valid_epoch(epoch)
+            log.update(val_log)
+        if self.do_test:
+            # _test_epoch also returns [results, targets]; they are not used here
+            test_log, _ = self._test_epoch(epoch)
+            log.update(test_log)
+
+        return log
+
+
+    def _progress(self, batch_idx):
+        """
+        Format a '[current/total (pct%)]' progress string for the given batch index.
+
+        Counts samples when the loader exposes ``n_samples``, otherwise counts batches
+        against ``self.len_epoch``.
+        """
+        loader = self.data_loader
+        if hasattr(loader, 'n_samples'):
+            done, whole = batch_idx * loader.batch_size, loader.n_samples
+        else:
+            done, whole = batch_idx, self.len_epoch
+        return '[{}/{} ({:.0f}%)]'.format(done, whole, 100.0 * done / whole)
\ No newline at end of file
diff --git a/ELR/utils/__init__.py b/ELR/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ace79a2790b01f71138e46a6b4243d416113ab
--- /dev/null
+++ b/ELR/utils/__init__.py
@@ -0,0 +1 @@
+from .util import *
diff --git a/ELR/utils/util.py b/ELR/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cac31876b100ebbb5e7f96152adf149048a6cf1
--- /dev/null
+++ b/ELR/utils/util.py
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+from datetime import datetime
+from itertools import repeat
+from collections import OrderedDict
+import numpy as np
+
+def ensure_dir(dirname):
+    """Create directory ``dirname`` (with any missing parents) unless it already exists.
+
+    :param dirname: path of the directory, as ``str`` or ``Path``.
+    :raises FileExistsError: if the path exists but is not a directory.
+    """
+    # exist_ok=True removes the check-then-create race: another process creating
+    # the directory between an is_dir() test and mkdir() no longer raises.
+    Path(dirname).mkdir(parents=True, exist_ok=True)
+
+
+def read_json(fname):
+    """Parse the JSON file at ``fname`` (an object with ``.open``), keeping key order via OrderedDict."""
+    with fname.open('rt') as stream:
+        parsed = json.load(stream, object_hook=OrderedDict)
+    return parsed
+
+
+def write_json(content, fname):
+    """Serialise ``content`` as 4-space-indented JSON into ``fname``, preserving key order."""
+    text = json.dumps(content, indent=4, sort_keys=False)
+    with fname.open('wt') as stream:
+        stream.write(text)
+
+
+def inf_loop(data_loader):
+    ''' Generator that re-iterates ``data_loader`` forever, yielding its items in order. '''
+    while True:
+        yield from data_loader
+
+
+class Timer:
+    """Stopwatch measuring wall-clock seconds between successive ``check`` calls."""
+
+    def __init__(self):
+        # timestamp of construction, last check() or last reset()
+        self.cache = datetime.now()
+
+    def check(self):
+        """Return seconds elapsed since the stored timestamp and restart from now."""
+        previous, self.cache = self.cache, datetime.now()
+        return (self.cache - previous).total_seconds()
+
+    def reset(self):
+        """Restart the stopwatch from the current time."""
+        self.cache = datetime.now()
+
+
+
+def sigmoid_rampup(current, rampup_length):
+    """Ramp-up weight exp(-5 * (1 - t)**2) with t = clip(current, 0, rampup_length) / rampup_length.
+
+    Returns 1.0 immediately when ``rampup_length`` is 0.
+    """
+    if rampup_length != 0:
+        clipped = np.clip(current, 0.0, rampup_length)
+        remaining = 1.0 - clipped / rampup_length
+        return float(np.exp(-5.0 * remaining * remaining))
+    return 1.0
+
+
+def linear_rampup(current, rampup_length):
+    """Linear ramp from 0 to 1 over ``rampup_length`` steps; 1.0 once ``current`` reaches it.
+
+    Both arguments must be non-negative (checked with ``assert``).
+    """
+    assert current >= 0 and rampup_length >= 0
+    return 1.0 if current >= rampup_length else current / rampup_length
+
+
+def cosine_rampdown(current, rampdown_length):
+    """Cosine rampdown from https://arxiv.org/abs/1608.03983
+
+    Goes from 1 at ``current == 0`` to 0 at ``current == rampdown_length``;
+    ``current`` is clipped into that range first.
+    """
+    clipped = np.clip(current, 0.0, rampdown_length)
+    angle = np.pi * clipped / rampdown_length
+    return float(.5 * (np.cos(angle) + 1))
+
+
+def cosine_rampup(current, rampup_length):
+    """Cosine rampup
+
+    Goes from 0 at ``current == 0`` to 1 at ``current == rampup_length``;
+    ``current`` is clipped into that range first.  A zero-length ramp is
+    treated as already complete and returns 1.0, like ``sigmoid_rampup``.
+    """
+    if rampup_length == 0:
+        # avoid 0/0 -> NaN; no ramp means full weight from the start
+        return 1.0
+    current = np.clip(current, 0.0, rampup_length)
+    return float(-.5 * (np.cos(np.pi * current / rampup_length) - 1))
+
+
diff --git a/ELR_plus/README.md b/ELR_plus/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c78c725a1107c3c3f91bc551ecd88635fe237e84
--- /dev/null
+++ b/ELR_plus/README.md
@@ -0,0 +1,27 @@
+# ELR+
+This is an official PyTorch implementation of ELR+ method proposed in [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151).
+
+
+## Usage
+Train the network on the Symmetric Noise CIFAR-10 dataset (noise rate = 0.8):
+
+```
+python train.py -c config_cifar10.json --percent 0.8
+```
+Train the network on the Asymmetric Noise CIFAR-10 dataset (noise rate = 0.4):
+
+```
+python train.py -c config_cifar10_asym.json --percent 0.4
+```
+
+Train the network on the Asymmetric Noise CIFAR-100 dataset (noise rate = 0.4):
+
+```
+python train.py -c config_cifar100.json --percent 0.4 --asym 1
+```
+
+The config files can be modified to adjust hyperparameters and optimization settings.
+
+
+## References
+- S. Liu, J. Niles-Weed, N. Razavian and C. Fernandez-Granda "Early-Learning Regularization Prevents Memorization of Noisy Labels", 2020
diff --git a/ELR_plus/base/__init__.py b/ELR_plus/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e28cb1d23b331d44d7ad6e26dd5906f1f18fb657
--- /dev/null
+++ b/ELR_plus/base/__init__.py
@@ -0,0 +1,3 @@
+from .base_data_loader import *
+from .base_model import *
+from .base_trainer import *
diff --git a/ELR_plus/base/base_data_loader.py b/ELR_plus/base/base_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb7a8b410fc1e3c813db7790251ec23eb51f1a65
--- /dev/null
+++ b/ELR_plus/base/base_data_loader.py
@@ -0,0 +1,83 @@
+from typing import Tuple, Union, Optional
+
+import numpy as np
+from torch.utils.data import DataLoader
+from torch.utils.data.dataloader import default_collate
+from torch.utils.data.sampler import SubsetRandomSampler
+
+
+class BaseDataLoader(DataLoader):
+    """
+    Base class for all data loaders
+
+    Wraps ``torch.utils.data.DataLoader`` and optionally carves a validation
+    subset out of ``train_dataset`` with a fixed-seed random split, unless an
+    explicit ``val_dataset`` is supplied.
+    """
+    valid_sampler: Optional[SubsetRandomSampler]
+    sampler: Optional[SubsetRandomSampler]
+
+    def __init__(self, train_dataset, batch_size, shuffle, validation_split: float, num_workers, pin_memory,
+                 collate_fn=default_collate, val_dataset=None):
+        """
+        :param train_dataset: dataset served by this loader.
+        :param batch_size: batch size of the training loader.
+        :param shuffle: whether to shuffle; turned off automatically when a split sampler is used.
+        :param validation_split: 0 for no split, an int for an absolute number of
+                                 validation samples, or a float fraction of the dataset.
+        :param num_workers: forwarded to ``DataLoader``.
+        :param pin_memory: forwarded to ``DataLoader``.
+        :param collate_fn: forwarded to ``DataLoader``.
+        :param val_dataset: optional separate validation dataset; disables the split.
+        """
+        self.collate_fn = collate_fn
+        self.validation_split = validation_split
+        self.shuffle = shuffle
+        self.val_dataset = val_dataset
+
+        self.batch_idx = 0
+        self.n_samples = len(train_dataset) if val_dataset is None else len(train_dataset) + len(val_dataset)
+        self.init_kwargs = {
+            'dataset': train_dataset,
+            'batch_size': batch_size,
+            'shuffle': self.shuffle,
+            'collate_fn': collate_fn,
+            'num_workers': num_workers,
+            'pin_memory': pin_memory
+        }
+        if val_dataset is None:
+            self.sampler, self.valid_sampler = self._split_sampler(self.validation_split)
+            # _split_sampler switches self.shuffle off when it creates samplers.
+            # Refresh the stored kwargs, otherwise DataLoader receives both
+            # sampler=... and shuffle=True and raises ValueError.
+            self.init_kwargs['shuffle'] = self.shuffle
+            super().__init__(sampler=self.sampler, **self.init_kwargs)
+        else:
+            super().__init__(**self.init_kwargs)
+
+    def _split_sampler(self, split) -> Union[Tuple[None, None], Tuple[SubsetRandomSampler, SubsetRandomSampler]]:
+        """
+        Build train / validation samplers from a random permutation of the indices.
+
+        :param split: 0 for no split, int for a sample count, float for a fraction.
+        :return: ``(None, None)`` when ``split`` is 0, else ``(train_sampler, valid_sampler)``.
+        """
+        if split == 0.0:
+            return None, None
+
+        idx_full = np.arange(self.n_samples)
+
+        # NOTE(review): this reseeds numpy's *global* RNG to 0 as a side effect,
+        # which keeps the split reproducible but overrides any earlier seeding.
+        np.random.seed(0)
+        np.random.shuffle(idx_full)
+
+        if isinstance(split, int):
+            assert split > 0
+            assert split < self.n_samples, "validation set size is configured to be larger than entire dataset."
+            len_valid = split
+        else:
+            len_valid = int(self.n_samples * split)
+
+        valid_idx = idx_full[0:len_valid]
+        train_idx = np.delete(idx_full, np.arange(0, len_valid))
+
+        train_sampler = SubsetRandomSampler(train_idx)
+        valid_sampler = SubsetRandomSampler(valid_idx)
+        print(f"Train: {len(train_sampler)} Val: {len(valid_sampler)}")
+
+        # turn off shuffle option which is mutually exclusive with sampler
+        self.shuffle = False
+        self.n_samples = len(train_idx)
+
+        return train_sampler, valid_sampler
+
+    def split_validation(self, bs = 1000):
+        """
+        Return the validation ``DataLoader``.
+
+        With an explicit ``val_dataset`` the loader iterates it unshuffled in
+        batches of ``bs``; otherwise the training dataset is reused with the
+        validation sampler and the training loader's settings (``bs`` is ignored).
+        """
+        if self.val_dataset is not None:
+            kwargs = {
+                'dataset': self.val_dataset,
+                'batch_size': bs,
+                'shuffle': False,
+                'collate_fn': self.collate_fn,
+                'num_workers': self.num_workers
+            }
+            return DataLoader(**kwargs)
+        else:
+            print('Using sampler to split!')
+            return DataLoader(sampler=self.valid_sampler, **self.init_kwargs)
+
+
+
diff --git a/ELR_plus/base/base_model.py b/ELR_plus/base/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8caaade05c698e328b3df8af1de24ec9942992e1
--- /dev/null
+++ b/ELR_plus/base/base_model.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+import numpy as np
+from abc import abstractmethod
+
+
+class BaseModel(nn.Module):
+    """
+    Common parent of all models; adds a parameter count to the printed summary.
+    """
+    @abstractmethod
+    def forward(self, *inputs):
+        """
+        Forward pass; subclasses must override.
+
+        :return: Model output
+        """
+        raise NotImplementedError
+
+    def __str__(self):
+        """
+        Return the standard module description followed by the number of
+        parameters that have ``requires_grad`` set.
+        """
+        trainable = [p for p in self.parameters() if p.requires_grad]
+        n_params = sum(np.prod(p.size()) for p in trainable)
+        return '{}\nTrainable parameters: {}'.format(super().__str__(), n_params)
diff --git a/ELR_plus/base/base_trainer.py b/ELR_plus/base/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1aad45e59ade0d2c5d34fae475f3ce7d51b402b
--- /dev/null
+++ b/ELR_plus/base/base_trainer.py
@@ -0,0 +1,341 @@
+from typing import TypeVar, List, Tuple
+import torch
+from tqdm import tqdm
+from abc import abstractmethod
+from numpy import inf
+from logger import TensorboardWriter
+import numpy as np
+
+
+class BaseTrainer:
+ """
+ Base class for all trainers
+ """
+ def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1,
+ train_criterion2, metrics, optimizer1, optimizer2, config, val_criterion,
+ model_ema1_copy, model_ema2_copy):
+ """
+ Place the two networks, their EMA models and criteria on devices and
+ read the trainer settings (epochs, save period, monitoring) from ``config``.
+
+ :param model1, model2: the two networks; moved to the first / last configured device id.
+ :param model_ema1, model_ema2: optional EMA models (may be None).
+ :param model_ema1_copy, model_ema2_copy: copies of the EMA models placed on
+ the *other* network's device.
+ :param train_criterion1, train_criterion2: training losses, moved next to model1 / model2.
+ :param metrics: metric callables (their ``__name__`` is used in logs).
+ :param optimizer1, optimizer2: optimizers for model1 / model2.
+ :param config: object providing ``config``, ``get_logger``, item access,
+ ``save_dir``, ``log_dir`` and ``resume``.
+ :param val_criterion: loss used for validation / test.
+ """
+ # raw configuration mapping; later code indexes it as self.config['trainer'][...]
+ self.config = config.config
+
+ self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
+
+
+ # setup GPU device if available, move model into configured device
+ self.device, self.device_ids = self._prepare_device(config['n_gpu'])
+
+ if len(self.device_ids) > 1:
+ print('Using Multi-Processing!')
+
+ # self.device is the string prefix 'cuda:'; the device id is appended to it.
+ # NOTE(review): with an empty device_ids list (no GPU) device_ids[0] raises IndexError.
+ self.model1 = model1.to(self.device+str(self.device_ids[0]))
+ self.model2 = model2.to(self.device+str(self.device_ids[-1]))
+
+ # model_ema2_copy lives on model1's device; it is handled whenever model_ema1 is given
+ if model_ema1 is not None:
+ self.model_ema1 = model_ema1.to(self.device+str(self.device_ids[0]))
+ self.model_ema2_copy = model_ema2_copy.to(self.device+str(self.device_ids[0]))
+ else:
+ self.model_ema1 = None
+ self.model_ema2_copy = None
+
+ # model_ema1_copy lives on model2's device; it is handled whenever model_ema2 is given
+ if model_ema2 is not None:
+ self.model_ema2 = model_ema2.to(self.device+str(self.device_ids[-1]))
+ self.model_ema1_copy = model_ema1_copy.to(self.device+str(self.device_ids[-1]))
+ else:
+ self.model_ema2 = None
+ self.model_ema1_copy = None
+
+ # detach EMA parameters in place so they take no part in autograd
+ if self.model_ema1 is not None:
+ for param in self.model_ema1.parameters():
+ param.detach_()
+
+ for param in self.model_ema2_copy.parameters():
+ param.detach_()
+
+ if self.model_ema2 is not None:
+ for param in self.model_ema2.parameters():
+ param.detach_()
+
+ for param in self.model_ema1_copy.parameters():
+ param.detach_()
+
+
+ self.train_criterion1 = train_criterion1.to(self.device+str(self.device_ids[0]))
+ self.train_criterion2 = train_criterion2.to(self.device+str(self.device_ids[-1]))
+
+ self.val_criterion = val_criterion
+
+ self.metrics = metrics
+
+ self.optimizer1 = optimizer1
+ self.optimizer2 = optimizer2
+
+ cfg_trainer = config['trainer']
+ self.epochs = cfg_trainer['epochs']
+ self.save_period = cfg_trainer['save_period']
+ self.monitor = cfg_trainer.get('monitor', 'off')
+
+ # configuration to monitor model performance and save best
+ if self.monitor == 'off':
+ self.mnt_mode = 'off'
+ self.mnt_best = 0
+ else:
+ # monitor is '<min|max> <metric name>', e.g. 'max val_my_metric'
+ self.mnt_mode, self.mnt_metric = self.monitor.split()
+ assert self.mnt_mode in ['min', 'max']
+
+ self.mnt_best = inf if self.mnt_mode == 'min' else -inf
+ self.early_stop = cfg_trainer.get('early_stop', inf)
+
+ self.start_epoch = 1
+
+ self.global_step = 0
+
+ self.checkpoint_dir = config.save_dir
+
+ # setup visualization writer instance
+ self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
+
+ # resuming happens last, after models and optimizers have been stored on self
+ if config.resume is not None:
+ self._resume_checkpoint(config.resume)
+
+
+
+    @abstractmethod
+    def _train_epoch(self, epoch):
+        """
+        Run the training logic of a single epoch; concrete trainers override this.
+
+        :param epoch: Current epoch number
+        """
+        raise NotImplementedError()
+
+
+
+ def train(self):
+ """
+ Full training logic
+
+ For every epoch from ``self.start_epoch`` to ``self.epochs`` this runs either
+ ``_warmup_epoch`` (while ``epoch <= config['trainer']['warmup']``) or
+ ``_train_epoch`` for both networks, then validation / test, merges the
+ results into one log dict, applies the monitor / early-stop rule and saves
+ a checkpoint every ``self.save_period`` epochs.
+
+ With more than one device id, the per-network work is started in separate
+ ``torch.multiprocessing`` processes and results are read back from queues.
+
+ NOTE(review): relies on attributes not set in this class's ``__init__``
+ (``data_loader1/2``, ``lr_scheduler1/2``, ``do_validation``, ``do_test``) and on
+ ``_warmup_epoch`` / ``_valid_epoch`` / ``_test_epoch`` -- presumably provided by
+ the subclass; verify.
+ """
+
+ if len(self.device_ids) > 1:
+ import torch.multiprocessing as mp
+ mp.set_start_method('spawn', force =True)
+
+ not_improved_count = 0
+
+ for epoch in tqdm(range(self.start_epoch, self.epochs + 1), desc='Total progress: '):
+ # ---- warm-up phase ----
+ if epoch <= self.config['trainer']['warmup']:
+ if len(self.device_ids) > 1:
+ q1 = mp.Queue()
+ q2 = mp.Queue()
+ p1 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device+str(self.device_ids[0]), q1 ))
+ p2 = mp.Process(target=self._warmup_epoch, args=(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device+str(self.device_ids[-1]), q2))
+ p1.start()
+ p2.start()
+ result1 = q1.get()
+ result2 = q2.get()
+ p1.join()
+ p2.join()
+ else:
+ result1 = self._warmup_epoch(epoch, self.model1, self.data_loader1, self.optimizer1, self.train_criterion1, self.lr_scheduler1, self.device+str(self.device_ids[0]))
+ result2 = self._warmup_epoch(epoch, self.model2, self.data_loader2, self.optimizer2, self.train_criterion2, self.lr_scheduler2, self.device+str(self.device_ids[-1]))
+
+ if len(self.device_ids) > 1:
+ # refresh the cross-device EMA copies before evaluating
+ self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
+ self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
+ if self.do_validation:
+ q1 = mp.Queue()
+ p1 = mp.Process(target=self._valid_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device+str(self.device_ids[0]),q1))
+
+ if self.do_test:
+ q2 = mp.Queue()
+ p2 = mp.Process(target=self._test_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device+str(self.device_ids[0]),q2))
+ # NOTE(review): p1/p2 are (re)created only when do_validation/do_test are set,
+ # yet both are started and read here -- confirm both flags are always on in this mode.
+ p1.start()
+ p2.start()
+ val_log = q1.get()
+ test_log, test_meta = q2.get()
+ result1.update(val_log)
+ result2.update(val_log)
+ result1.update(test_log)
+ result2.update(test_log)
+ p1.join()
+ p2.join()
+ else:
+ if self.do_validation:
+ val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
+ result1.update(val_log)
+ result2.update(val_log)
+ if self.do_test:
+ test_log, test_meta = self._test_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
+ result1.update(test_log)
+ result2.update(test_log)
+ else:
+ test_meta = [0,0]
+
+ # ---- regular training phase ----
+ else:
+ if len(self.device_ids) > 1:
+ q1 = mp.Queue()
+ q2 = mp.Queue()
+ # each network trains against the *other* network's EMA copy on its own device
+ p1 = mp.Process(target=self._train_epoch, args=(epoch, self.model1, self.model_ema1, self.model_ema2_copy, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device+str(self.device_ids[0]), q1 ))
+ p2 = mp.Process(target=self._train_epoch, args=(epoch, self.model2, self.model_ema2, self.model_ema1_copy, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device+str(self.device_ids[-1]), q2 ))
+ p1.start()
+ p2.start()
+ result1 = q1.get()
+ result2 = q2.get()
+ p1.join()
+ p2.join()
+ else:
+ result1 = self._train_epoch(epoch, self.model1, self.model_ema1, self.model_ema2, self.data_loader1, self.train_criterion1, self.optimizer1, self.lr_scheduler1, self.device+str(self.device_ids[0]))
+ result2 = self._train_epoch(epoch, self.model2, self.model_ema2, self.model_ema1, self.data_loader2, self.train_criterion2, self.optimizer2, self.lr_scheduler2, self.device+str(self.device_ids[-1]))
+
+
+ # the epoch result is expected to carry the number of steps taken under 'local_step'
+ self.global_step += result1['local_step']
+ if len(self.device_ids) > 1:
+ self.model_ema1_copy.load_state_dict(self.model_ema1.state_dict())
+ self.model_ema2_copy.load_state_dict(self.model_ema2.state_dict())
+ if self.do_validation:
+ q1 = mp.Queue()
+ p1 = mp.Process(target=self._valid_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device+str(self.device_ids[0]),q1))
+
+ if self.do_test:
+ q2 = mp.Queue()
+ p2 = mp.Process(target=self._test_epoch, args=(epoch, self.model1, self.model_ema2_copy, self.device+str(self.device_ids[0]),q2))
+ p1.start()
+ p2.start()
+ val_log = q1.get()
+ # NOTE(review): the warm-up branch unpacks the test result as (test_log, test_meta)
+ # but this branch uses it as a single dict -- confirm what _test_epoch returns.
+ test_log = q2.get()
+ result1.update(val_log)
+ result2.update(val_log)
+ result1.update(test_log)
+ result2.update(test_log)
+ p1.join()
+ p2.join()
+ else:
+ if self.do_validation:
+ val_log = self._valid_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
+ result1.update(val_log)
+ result2.update(val_log)
+ if self.do_test:
+ test_log = self._test_epoch(epoch, self.model1, self.model2, self.device+str(self.device_ids[0]))
+ result1.update(test_log)
+ result2.update(test_log)
+
+
+
+ # save logged informations into log dict
+ # training metrics are reported per network; val/test metrics come from result1 only
+ log = {'epoch': epoch}
+ for key, value in result1.items():
+ if key == 'metrics':
+ log.update({'Net1' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+ log.update({'Net2' + mtr.__name__: result2[key][i] for i, mtr in enumerate(self.metrics)})
+ elif key == 'val_metrics':
+ log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+ elif key == 'test_metrics':
+ log.update({'test_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)})
+ else:
+ log['Net1'+key] = value
+ log['Net2'+key] = result2[key]
+
+ # print logged informations to the screen
+ for key, value in log.items():
+ self.logger.info(' {:15s}: {}'.format(str(key), value))
+
+ # evaluate model performance according to configured metric, save best checkpoint as model_best
+ best = False
+ if self.mnt_mode != 'off':
+ try:
+ # check whether model performance improved or not, according to specified metric(mnt_metric)
+ improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
+ (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
+ except KeyError:
+ # the monitored key is missing from the log: switch monitoring off for the rest of the run
+ self.logger.warning("Warning: Metric '{}' is not found. "
+ "Model performance monitoring is disabled.".format(self.mnt_metric))
+ self.mnt_mode = 'off'
+ improved = False
+
+ if improved:
+ self.mnt_best = log[self.mnt_metric]
+ not_improved_count = 0
+ best = True
+ else:
+ not_improved_count += 1
+
+ # stop once the metric has failed to improve for more than early_stop epochs in a row
+ if not_improved_count > self.early_stop:
+ self.logger.info("Validation performance didn\'t improve for {} epochs. "
+ "Training stops.".format(self.early_stop))
+ break
+
+ if epoch % self.save_period == 0:
+ self._save_checkpoint(epoch, save_best=best)
+
+
+    def _prepare_device(self, n_gpu_use):
+        """
+        Clamp the requested GPU count to what is available and list the device ids.
+
+        :param n_gpu_use: number of GPUs requested in the configuration.
+        :return: ``(device_prefix, ids)`` where ``device_prefix`` is always the
+                 string 'cuda:' and ``ids`` is ``list(range(n))`` for the usable count ``n``.
+        """
+        available = torch.cuda.device_count()
+        if available == 0 and n_gpu_use > 0:
+            self.logger.warning("Warning: There\'s no GPU available on this machine,"
+                                "training will be performed on CPU.")
+            n_gpu_use = 0
+        if n_gpu_use > available:
+            self.logger.warning("Warning: The number of GPU\'s configured to use is {}, but only {} are available "
+                                "on this machine.".format(n_gpu_use, available))
+            n_gpu_use = available
+        # NOTE(review): the prefix stays 'cuda:' even when no GPU is usable
+        return 'cuda:', list(range(n_gpu_use))
+
+    def _save_checkpoint(self, epoch, save_best=False):
+        """
+        Write the state of both networks and both optimizers to ``self.checkpoint_dir``.
+
+        The file is named 'checkpoint-epoch{epoch}.pth'.  EMA models and the
+        configuration are not part of the saved state.
+
+        :param epoch: current epoch number
+        :param save_best: if True, also save the same state as 'model_best.pth'
+        """
+        snapshot = {
+            'arch': type(self.model1).__name__,
+            'epoch': epoch,
+            'state_dict1': self.model1.state_dict(),
+            'state_dict2': self.model2.state_dict(),
+            'optimizer1': self.optimizer1.state_dict(),
+            'optimizer2': self.optimizer2.state_dict(),
+            'monitor_best': self.mnt_best
+        }
+        ckpt_path = str(self.checkpoint_dir / 'checkpoint-epoch{}.pth'.format(epoch))
+        torch.save(snapshot, ckpt_path)
+        self.logger.info("Saving checkpoint: {} ...".format(ckpt_path))
+        if not save_best:
+            return
+        best_path = str(self.checkpoint_dir / 'model_best.pth')
+        torch.save(snapshot, best_path)
+        self.logger.info("Saving current best: model_best.pth at: {} ...".format(best_path))
+
+
+
+    def _resume_checkpoint(self, resume_path):
+        """
+        Resume from saved checkpoints
+
+        Restores the epoch counter, the best monitored value, both networks and
+        both optimizers, using the keys written by ``_save_checkpoint``
+        ('state_dict1', 'state_dict2', 'optimizer1', 'optimizer2').  EMA models
+        are not stored in checkpoints and are therefore left untouched.
+
+        :param resume_path: Checkpoint path to be resumed
+        """
+        resume_path = str(resume_path)
+        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
+        checkpoint = torch.load(resume_path)
+        self.start_epoch = checkpoint['epoch'] + 1
+        self.mnt_best = checkpoint['monitor_best']
+
+        # the checkpoint records only the class name of model1 under 'arch'
+        # (the full config is not saved), so compare against that
+        saved_arch = checkpoint.get('arch')
+        if saved_arch is not None and saved_arch != type(self.model1).__name__:
+            self.logger.warning("Warning: Architecture configuration given in config file is different from that of "
+                                "checkpoint. This may yield an exception while state_dict is being loaded.")
+        self.model1.load_state_dict(checkpoint['state_dict1'])
+        self.model2.load_state_dict(checkpoint['state_dict2'])
+
+        # an optimizer of a different type / parameter layout rejects the saved
+        # state with ValueError; keep the freshly built optimizers in that case
+        try:
+            self.optimizer1.load_state_dict(checkpoint['optimizer1'])
+            self.optimizer2.load_state_dict(checkpoint['optimizer2'])
+        except ValueError:
+            self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. "
+                                "Optimizer parameters not being resumed.")
+
+        self.logger.info("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
+
diff --git a/ELR_plus/config_cifar10.json b/ELR_plus/config_cifar10.json
new file mode 100644
index 0000000000000000000000000000000000000000..256c25806cf5d50f8029c640d6382614966e41d3
--- /dev/null
+++ b/ELR_plus/config_cifar10.json
@@ -0,0 +1,105 @@
+{
+ "name": "cifar10_ELR_plus_PreActResNet18",
+ "n_gpu": 1,
+ "seed":123,
+
+ "arch": {
+ "args": {"num_classes":10}
+ },
+
+ "arch1": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":10}
+ },
+
+ "arch2": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":10}
+ },
+
+ "mixup_alpha": 1,
+ "coef_step": 0,
+ "num_classes": 10,
+ "ema_alpha": 0.997,
+ "ema_update": true,
+ "ema_step": 40000,
+
+
+ "data_loader": {
+ "type": "CIFAR10DataLoader",
+ "args":{
+ "data_dir": "/dir/to/data",
+ "batch_size": 128,
+ "batch_size2": 128,
+ "num_batches": 0,
+ "shuffle": true,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+
+ "optimizer1": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+ "optimizer2": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+
+
+ "train_loss": {
+ "type": "elr_plus_loss",
+ "args":{
+ "beta": 0.7,
+ "lambda": 3
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [150],
+ "gamma": 0.1
+ }
+ },
+
+ "trainer": {
+ "epochs": 200,
+ "warmup": 0,
+ "save_dir": "dir/to/model",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+
+ "tensorboard": false,
+ "mlflow": true,
+
+ "_percent": "Percentage of noise",
+ "percent": 0.8,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
diff --git a/ELR_plus/config_cifar100.json b/ELR_plus/config_cifar100.json
new file mode 100644
index 0000000000000000000000000000000000000000..09033d360473c11975cc3396b5c003de9924f2da
--- /dev/null
+++ b/ELR_plus/config_cifar100.json
@@ -0,0 +1,104 @@
+{
+ "name": "cifar100_ELR_plus_PreActResNet18",
+ "n_gpu": 1,
+ "seed":123,
+
+ "arch": {
+ "args": {"num_classes":100}
+ },
+
+ "arch1": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":100}
+ },
+
+ "arch2": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":100}
+ },
+
+ "mixup_alpha": 1,
+ "coef_step": 40000,
+ "num_classes": 100,
+ "ema_alpha": 0.997,
+ "ema_update": true,
+ "ema_step": 40000,
+
+
+ "data_loader": {
+ "type": "CIFAR100DataLoader",
+ "args":{
+ "data_dir": "/gpfs/scratch/sl5924/noisy/data/",
+ "batch_size": 128,
+ "batch_size2": 128,
+ "num_batches": 0,
+ "shuffle": true,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+ "optimizer1": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+ "optimizer2": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+
+
+ "train_loss": {
+ "type": "elr_plus_loss",
+ "args":{
+ "beta": 0.9,
+ "lambda": 7
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [200],
+ "gamma": 0.1
+ }
+ },
+
+ "trainer": {
+ "epochs": 250,
+ "warmup": 0,
+ "save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+
+ "tensorboard": false,
+ "mlflow": true,
+
+ "_percent": "Percentage of noise",
+ "percent": 0.8,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
diff --git a/ELR_plus/config_cifar10_asym.json b/ELR_plus/config_cifar10_asym.json
new file mode 100644
index 0000000000000000000000000000000000000000..66b953090783f6458a94fb2fb5ab947a4cc2be09
--- /dev/null
+++ b/ELR_plus/config_cifar10_asym.json
@@ -0,0 +1,105 @@
+{
+ "name": "cifar10_ELR_plus_PreActResNet18",
+ "n_gpu": 1,
+ "seed":123,
+
+ "arch": {
+ "args": {"num_classes":10}
+ },
+
+ "arch1": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":10}
+ },
+
+ "arch2": {
+ "type": "PreActResNet18",
+ "args": {"num_classes":10}
+ },
+
+ "mixup_alpha": 1,
+ "coef_step": 0,
+ "num_classes": 10,
+ "ema_alpha": 0.997,
+ "ema_update": true,
+ "ema_step": 40000,
+
+
+ "data_loader": {
+ "type": "CIFAR10DataLoader",
+ "args":{
+ "data_dir": "dir/to/data",
+ "batch_size": 128,
+ "batch_size2": 128,
+ "num_batches": 0,
+ "shuffle": true,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+
+ "optimizer1": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+ "optimizer2": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+
+
+ "train_loss": {
+ "type": "elr_plus_loss",
+ "args":{
+ "beta": 0.9,
+ "lambda": 1
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [150],
+ "gamma": 0.1
+ }
+ },
+
+ "trainer": {
+ "epochs": 200,
+ "warmup": 0,
+ "save_dir": "dir/to/model",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+
+ "tensorboard": false,
+ "mlflow": true,
+
+ "_percent": "Percentage of noise",
+ "percent": 0.4,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": true
+ }
+}
diff --git a/ELR_plus/config_clothing1m.json b/ELR_plus/config_clothing1m.json
new file mode 100644
index 0000000000000000000000000000000000000000..e057d299394f8809e50f684e60b673d7272ac9e0
--- /dev/null
+++ b/ELR_plus/config_clothing1m.json
@@ -0,0 +1,102 @@
+{
+ "name": "clothing1M_ELR_plus_resnet50",
+ "n_gpu": 1,
+ "seed":123,
+
+
+ "arch1": {
+ "type": "resnet50",
+ "args": {"num_classes":14}
+ },
+
+ "arch2": {
+ "type": "resnet50",
+ "args": {"num_classes":14}
+ },
+
+ "mixup_alpha": 1,
+ "coef_step": 0,
+ "num_classes": 14,
+ "ema_alpha": 0.9999,
+ "ema_update": false,
+ "ema_step": -1,
+
+
+ "data_loader": {
+ "type": "Clothing1MDataLoader",
+ "args":{
+ "data_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/clothing1M/images",
+ "batch_size": 64,
+ "batch_size2": 64,
+ "num_batches": 3000,
+ "shuffle": true,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+ "optimizer1": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.002,
+ "momentum": 0.9,
+ "weight_decay": 1e-3
+ }
+ },
+
+ "optimizer2": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.002,
+ "momentum": 0.9,
+ "weight_decay": 1e-3
+ }
+ },
+
+
+
+ "train_loss": {
+ "type": "elr_plus_loss",
+ "args":{
+ "beta": 0.7,
+ "lambda": 3
+ }
+ },
+
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [7],
+ "gamma": 0.1
+ }
+ },
+
+ "trainer": {
+ "epochs": 15,
+ "warmup": 0,
+ "save_dir": "/gpfs/data/razavianlab/skynet/alzheimers/noisy_label/saved/",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+
+ "tensorboard": true,
+ "mlflow": true,
+
+ "_percent": "Percentage of noise",
+ "percent": 0.8,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
+
diff --git a/ELR_plus/config_webvision.json b/ELR_plus/config_webvision.json
new file mode 100644
index 0000000000000000000000000000000000000000..3f26c9bf4a98c98e5ed6e10eeb2221bb5eff3fe7
--- /dev/null
+++ b/ELR_plus/config_webvision.json
@@ -0,0 +1,103 @@
+{
+ "name": "Webvision_ELR_plus_InceptionResNetV2",
+ "n_gpu": 2,
+ "seed": 123,
+
+
+ "arch": {
+ "args": {"num_classes":50}
+ },
+
+ "arch1": {
+ "type": "InceptionResNetV2",
+ "args": {"num_classes":50}
+ },
+
+ "arch2": {
+ "type": "InceptionResNetV2",
+ "args": {"num_classes":50}
+ },
+
+
+ "mixup_alpha": 1.5,
+ "mixup_ramp": false,
+ "num_classes": 50,
+ "ema_alpha": 0.997,
+ "ema_update": false,
+ "ema_step": 40000,
+
+
+
+ "data_loader": {
+ "type": "WebvisionDataLoader",
+ "args":{
+ "data_dir": "/dir/to/data",
+ "batch_size": 32,
+ "batch_size2": 32,
+ "shuffle": true,
+ "num_batches": 0,
+ "validation_split": 0,
+ "num_workers": 8,
+ "pin_memory": true
+ }
+ },
+
+ "optimizer1": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+ "optimizer2": {
+ "type": "SGD",
+ "args":{
+ "lr": 0.02,
+ "momentum": 0.9,
+ "weight_decay": 5e-4
+ }
+ },
+
+
+ "train_loss": {
+ "type": "elr_plus_loss",
+ "args":{
+ "beta": 0.7,
+ "lambda": 3
+ }
+ },
+ "val_loss": "cross_entropy",
+ "metrics": [
+ "my_metric", "my_metric2"
+ ],
+ "lr_scheduler": {
+ "type": "MultiStepLR",
+ "args": {
+ "milestones": [50],
+ "gamma": 0.1
+ }
+ },
+ "trainer": {
+ "epochs": 100,
+ "warmup": 0,
+ "save_dir": "/dir/to/data",
+ "save_period": 1,
+ "verbosity": 2,
+ "label_dir": "saved/",
+
+ "monitor": "max val_my_metric",
+ "early_stop": 2000,
+
+ "tensorboard": false,
+ "mlflow": true,
+
+ "_percent": "Percentage of noise",
+ "percent": 0.9,
+ "_begin": "When to begin updating labels",
+ "begin": 0,
+ "_asym": "symmetric noise if false",
+ "asym": false
+ }
+}
diff --git a/ELR_plus/data_loader/cifar10.py b/ELR_plus/data_loader/cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..618846eb2f9cf6d1cc37c932bf05118387659967
--- /dev/null
+++ b/ELR_plus/data_loader/cifar10.py
@@ -0,0 +1,214 @@
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+import json
+import os
+
+
+def get_cifar10(root, cfg_trainer, train=True,
+ transform_train=None, transform_val=None,
+ download=False, noise_file = ''):
+ base_dataset = torchvision.datasets.CIFAR10(root, train=train, download=download)
+ if train:
+ train_idxs, val_idxs = train_val_split(base_dataset.targets)
+ train_dataset = CIFAR10_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+ val_dataset = CIFAR10_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+ if cfg_trainer['asym']:
+ train_dataset.asymmetric_noise()
+ val_dataset.asymmetric_noise()
+ else:
+ train_dataset.symmetric_noise()
+ val_dataset.symmetric_noise()
+
+ print(f"Train: {len(train_idxs)} Val: {len(val_idxs)}") # Train: 45000 Val: 5000
+ else:
+ train_dataset = []
+ val_dataset = CIFAR10_val(root, cfg_trainer, None, train=train, transform=transform_val)
+ print(f"Test: {len(val_dataset)}")
+
+ return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset: torchvision.datasets.CIFAR10):
+ num_classes = 10
+ base_dataset = np.array(base_dataset)
+ train_n = int(len(base_dataset) * 0.9 / num_classes)
+ train_idxs = []
+ val_idxs = []
+
+ for i in range(num_classes):
+ idxs = np.where(base_dataset == i)[0]
+ np.random.shuffle(idxs)
+ train_idxs.extend(idxs[:train_n])
+ val_idxs.extend(idxs[train_n:])
+ np.random.shuffle(train_idxs)
+ np.random.shuffle(val_idxs)
+
+ return train_idxs, val_idxs
+
+
+class CIFAR10_train(torchvision.datasets.CIFAR10):
+ def __init__(self, root, cfg_trainer, indexs, train=True,
+ transform=None, target_transform=None,
+ download=False):
+ super(CIFAR10_train, self).__init__(root, train=train,
+ transform=transform, target_transform=target_transform,
+ download=download)
+ self.num_classes = 10
+ self.cfg_trainer = cfg_trainer
+ self.train_data = self.data[indexs]#self.train_data[indexs]
+ self.train_labels = np.array(self.targets)[indexs]#np.array(self.train_labels)[indexs]
+ self.indexs = indexs
+ self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+ self.noise_indx = []
+ #self.all_refs_encoded = torch.zeros(self.num_classes,self.num_ref,1024, dtype=np.float32)
+
+ def symmetric_noise(self):
+ self.train_labels_gt = self.train_labels.copy()
+ #np.random.seed(seed=888)
+ indices = np.random.permutation(len(self.train_data))
+ for i, idx in enumerate(indices):
+ if i < self.cfg_trainer['percent'] * len(self.train_data):
+ self.noise_indx.append(idx)
+ self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+ def asymmetric_noise(self):
+ self.train_labels_gt = self.train_labels.copy()
+ for i in range(self.num_classes):
+ indices = np.where(self.train_labels == i)[0]
+ np.random.shuffle(indices)
+ for j, idx in enumerate(indices):
+ if j < self.cfg_trainer['percent'] * len(indices):
+ self.noise_indx.append(idx)
+ # truck -> automobile
+ if i == 9:
+ self.train_labels[idx] = 1
+ # bird -> airplane
+ elif i == 2:
+ self.train_labels[idx] = 0
+ # cat -> dog
+ elif i == 3:
+ self.train_labels[idx] = 5
+ # dog -> cat
+ elif i == 5:
+ self.train_labels[idx] = 3
+ # deer -> horse
+ elif i == 4:
+ self.train_labels[idx] = 7
+
+
+
+
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img)
+
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img,target, index, target_gt
+
+ def __len__(self):
+ return len(self.train_data)
+
+
+
+class CIFAR10_val(torchvision.datasets.CIFAR10):
+
+ def __init__(self, root, cfg_trainer, indexs, train=True,
+ transform=None, target_transform=None,
+ download=False):
+ super(CIFAR10_val, self).__init__(root, train=train,
+ transform=transform, target_transform=target_transform,
+ download=download)
+
+ # self.train_data = self.data[indexs]
+ # self.train_labels = np.array(self.targets)[indexs]
+ self.num_classes = 10
+ self.cfg_trainer = cfg_trainer
+ if train:
+ self.train_data = self.data[indexs]
+ self.train_labels = np.array(self.targets)[indexs]
+ else:
+ self.train_data = self.data
+ self.train_labels = np.array(self.targets)
+ self.train_labels_gt = self.train_labels.copy()
+ def symmetric_noise(self):
+
+ indices = np.random.permutation(len(self.train_data))
+ for i, idx in enumerate(indices):
+ if i < self.cfg_trainer['percent'] * len(self.train_data):
+ self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+ def asymmetric_noise(self):
+ for i in range(self.num_classes):
+ indices = np.where(self.train_labels == i)[0]
+ np.random.shuffle(indices)
+ for j, idx in enumerate(indices):
+ if j < self.cfg_trainer['percent'] * len(indices):
+ # truck -> automobile
+ if i == 9:
+ self.train_labels[idx] = 1
+ # bird -> airplane
+ elif i == 2:
+ self.train_labels[idx] = 0
+ # cat -> dog
+ elif i == 3:
+ self.train_labels[idx] = 5
+ # dog -> cat
+ elif i == 5:
+ self.train_labels[idx] = 3
+ # deer -> horse
+ elif i == 4:
+ self.train_labels[idx] = 7
+ def __len__(self):
+ return len(self.train_data)
+
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img)
+
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target, index, target_gt
+
\ No newline at end of file
diff --git a/ELR_plus/data_loader/cifar100.py b/ELR_plus/data_loader/cifar100.py
new file mode 100644
index 0000000000000000000000000000000000000000..60c916a3951d36de7d30825356392ce479cb5e73
--- /dev/null
+++ b/ELR_plus/data_loader/cifar100.py
@@ -0,0 +1,307 @@
+import sys
+
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+from numpy.testing import assert_array_almost_equal
+import os
+import json
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+def get_cifar100(root, cfg_trainer, train=True,
+ transform_train=None, transform_val=None,
+ download=False, noise_file = ''):
+ base_dataset = torchvision.datasets.CIFAR100(root, train=train, download=download)
+ if train:
+ train_idxs, val_idxs = train_val_split(base_dataset.targets)
+ train_dataset = CIFAR100_train(root, cfg_trainer, train_idxs, train=True, transform=transform_train)
+ val_dataset = CIFAR100_val(root, cfg_trainer, val_idxs, train=train, transform=transform_val)
+ if cfg_trainer['asym']:
+ train_dataset.asymmetric_noise()
+ val_dataset.asymmetric_noise()
+ else:
+ train_dataset.symmetric_noise()
+ val_dataset.symmetric_noise()
+
+ print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}\n") # Train: 45000 Val: 5000
+ else:
+ train_dataset = []
+ val_dataset = CIFAR100_val(root, cfg_trainer, None, train=train, transform=transform_val)
+ print(f"Test: {len(val_dataset)}\n")
+
+
+
+
+ return train_dataset, val_dataset
+
+
+def train_val_split(base_dataset: torchvision.datasets.CIFAR100):
+ num_classes = 100
+ base_dataset = np.array(base_dataset)
+ train_n = int(len(base_dataset) * 0.9 / num_classes)
+ train_idxs = []
+ val_idxs = []
+
+ for i in range(num_classes):
+ idxs = np.where(base_dataset == i)[0]
+ np.random.shuffle(idxs)
+ train_idxs.extend(idxs[:train_n])
+ val_idxs.extend(idxs[train_n:])
+ np.random.shuffle(train_idxs)
+ np.random.shuffle(val_idxs)
+
+ return train_idxs, val_idxs
+
+
+class CIFAR100_train(torchvision.datasets.CIFAR100):
+ def __init__(self, root, cfg_trainer, indexs, train=True,
+ transform=None, target_transform=None,
+ download=False):
+ super(CIFAR100_train, self).__init__(root, train=train,
+ transform=transform, target_transform=target_transform,
+ download=download)
+ self.num_classes = 100
+ self.cfg_trainer = cfg_trainer
+ self.train_data = self.data[indexs]
+ self.train_labels = np.array(self.targets)[indexs]
+ self.indexs = indexs
+ self.soft_labels = np.zeros((len(self.train_data), self.num_classes), dtype=np.float32)
+ self.prediction = np.zeros((len(self.train_data), self.num_classes, self.num_classes), dtype=np.float32)
+ self.noise_indx = []
+ #self.all_refs_encoded = torch.zeros(self.num_classes,self.num_ref,1024, dtype=np.float32)
+
+ self.count = 0
+
+ def symmetric_noise(self):
+ self.train_labels_gt = self.train_labels.copy()
+ np.random.seed(seed=888)
+ indices = np.random.permutation(len(self.train_data))
+ for i, idx in enumerate(indices):
+ if i < self.cfg_trainer['percent'] * len(self.train_data):
+ self.noise_indx.append(idx)
+ self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+ self.soft_labels[idx][self.train_labels[idx]] = 1.
+
+ def multiclass_noisify(self, y, P, random_state=0):
+ """ Flip classes according to transition probability matrix T.
+ It expects a number between 0 and the number of classes - 1.
+ """
+
+ assert P.shape[0] == P.shape[1]
+ assert np.max(y) < P.shape[0]
+
+ # row stochastic matrix
+ assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+ assert (P >= 0.0).all()
+
+ m = y.shape[0]
+ new_y = y.copy()
+ flipper = np.random.RandomState(random_state)
+
+ for idx in np.arange(m):
+ i = y[idx]
+ # draw a vector with only an 1
+ flipped = flipper.multinomial(1, P[i, :], 1)[0]
+ new_y[idx] = np.where(flipped == 1)[0]
+
+ return new_y
+
+ def build_for_cifar100(self, size, noise):
+ """ The noise matrix flips to the "next" class with probability 'noise'.
+ """
+
+ assert(noise >= 0.) and (noise <= 1.)
+
+ P = (1. - noise) * np.eye(size)
+ for i in np.arange(size - 1):
+ P[i, i + 1] = noise
+
+ # adjust last row
+ P[size - 1, 0] = noise
+
+ assert_array_almost_equal(P.sum(axis=1), 1, 1)
+ return P
+
+
+ def asymmetric_noise(self, asym=False, random_shuffle=False):
+ self.train_labels_gt = self.train_labels.copy()
+ P = np.eye(self.num_classes)
+ n = self.cfg_trainer['percent']
+ nb_superclasses = 20
+ nb_subclasses = 5
+
+ if n > 0.0:
+ for i in np.arange(nb_superclasses):
+ init, end = i * nb_subclasses, (i+1) * nb_subclasses
+ P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+ y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+ random_state=0)
+ actual_noise = (y_train_noisy != self.train_labels).mean()
+ assert actual_noise > 0.0
+ self.train_labels = y_train_noisy
+ #np.save(P_file, P)
+
+
+
+
+
+
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+
+ img = Image.fromarray(img)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ return img, target, index, target_gt
+
+ def __len__(self):
+ return len(self.train_data)
+
+ def rotate_img(self, img, rot):
+ if rot == 0: # 0 degrees rotation
+ return img
+ elif rot == 90: # 90 degrees rotation
+ return np.flipud(np.transpose(img, (1,0,2)))
+ elif rot == 180: # 90 degrees rotation
+ return np.fliplr(np.flipud(img))
+ elif rot == 270: # 270 degrees rotation / or -90
+ return np.transpose(np.flipud(img), (1,0,2))
+ else:
+ raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
+
+ def __len__(self):
+ return len(self.train_data)
+
+
+class CIFAR100_val(torchvision.datasets.CIFAR100):
+
+ def __init__(self, root, cfg_trainer, indexs, train=True,
+ transform=None, target_transform=None,
+ download=False):
+ super(CIFAR100_val, self).__init__(root, train=train,
+ transform=transform, target_transform=target_transform,
+ download=download)
+
+ # self.train_data = self.data[indexs]
+ # self.train_labels = np.array(self.targets)[indexs]
+ self.num_classes = 100
+ self.cfg_trainer = cfg_trainer
+ if train:
+ self.train_data = self.data[indexs]
+ self.train_labels = np.array(self.targets)[indexs]
+ else:
+ self.train_data = self.data
+ self.train_labels = np.array(self.targets)
+ self.train_labels_gt = self.train_labels.copy()
+ def symmetric_noise(self):
+ indices = np.random.permutation(len(self.train_data))
+ for i, idx in enumerate(indices):
+ if i < self.cfg_trainer['percent'] * len(self.train_data):
+ self.train_labels[idx] = np.random.randint(self.num_classes, dtype=np.int32)
+
+ def multiclass_noisify(self, y, P, random_state=0):
+ """ Flip classes according to transition probability matrix T.
+ It expects a number between 0 and the number of classes - 1.
+ """
+
+ assert P.shape[0] == P.shape[1]
+ assert np.max(y) < P.shape[0]
+
+ # row stochastic matrix
+ assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
+ assert (P >= 0.0).all()
+
+ m = y.shape[0]
+ new_y = y.copy()
+ flipper = np.random.RandomState(random_state)
+
+ for idx in np.arange(m):
+ i = y[idx]
+ # draw a vector with only an 1
+ flipped = flipper.multinomial(1, P[i, :], 1)[0]
+ new_y[idx] = np.where(flipped == 1)[0]
+
+ return new_y
+
+ def build_for_cifar100(self, size, noise):
+ """ The noise matrix flips to the "next" class with probability 'noise'.
+ """
+
+ assert(noise >= 0.) and (noise <= 1.)
+
+ P = (1. - noise) * np.eye(size)
+ for i in np.arange(size - 1):
+ P[i, i + 1] = noise
+
+ # adjust last row
+ P[size - 1, 0] = noise
+
+ assert_array_almost_equal(P.sum(axis=1), 1, 1)
+ return P
+
+
+ def asymmetric_noise(self, asym=False, random_shuffle=False):
+ P = np.eye(self.num_classes)
+ n = self.cfg_trainer['percent']
+ nb_superclasses = 20
+ nb_subclasses = 5
+
+ if n > 0.0:
+ for i in np.arange(nb_superclasses):
+ init, end = i * nb_subclasses, (i+1) * nb_subclasses
+ P[init:end, init:end] = self.build_for_cifar100(nb_subclasses, n)
+
+ y_train_noisy = self.multiclass_noisify(self.train_labels, P=P,
+ random_state=0)
+ actual_noise = (y_train_noisy != self.train_labels).mean()
+ assert actual_noise > 0.0
+ self.train_labels = y_train_noisy
+ def __len__(self):
+ return len(self.train_data)
+
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target, target_gt = self.train_data[index], self.train_labels[index], self.train_labels_gt[index]
+
+
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ img = Image.fromarray(img)
+
+
+ if self.transform is not None:
+ img = self.transform(img)
+
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target, index, target_gt
diff --git a/ELR_plus/data_loader/clothing1m.py b/ELR_plus/data_loader/clothing1m.py
new file mode 100644
index 0000000000000000000000000000000000000000..e04c85894ba241c0176ca260246a2ff991eb67a0
--- /dev/null
+++ b/ELR_plus/data_loader/clothing1m.py
@@ -0,0 +1,128 @@
+import sys
+import os
+import numpy as np
+from PIL import Image
+import torchvision
+from torch.utils.data.dataset import Subset
+from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
+import torch
+import torch.nn.functional as F
+import random
+
+def get_clothing(root, cfg_trainer, num_samples=0, train=True,
+ transform_train=None, transform_val=None):
+
+ if train:
+ train_dataset = Clothing(root, cfg_trainer, num_samples=num_samples, train=train, transform=transform_train)
+ val_dataset = Clothing(root, cfg_trainer, val=train, transform=transform_val)
+ print(f"Train: {len(train_dataset)} Val: {len(val_dataset)}")
+
+ else:
+ train_dataset = []
+ val_dataset = Clothing(root, cfg_trainer, test= (not train), transform=transform_val)
+ print(f"Test: {len(val_dataset)}")
+
+ return train_dataset, val_dataset
+
+class Clothing(torch.utils.data.Dataset):
+
+ def __init__(self, root, cfg_trainer, num_samples=0, train=False, val=False, test=False, transform=None, num_class = 14):
+ self.cfg_trainer = cfg_trainer
+ self.root = root
+ self.transform = transform
+ self.train_labels = {}
+ self.test_labels = {}
+ self.val_labels = {}
+
+ self.train = train
+ self.val = val
+ self.test = test
+
+ with open('%s/noisy_label_kv.txt'%self.root,'r') as f:
+ lines = f.read().splitlines()
+ for l in lines:
+ entry = l.split()
+ img_path = '%s/'%self.root+entry[0][7:]
+ self.train_labels[img_path] = int(entry[1])
+ with open('%s/clean_label_kv.txt'%self.root,'r') as f:
+ lines = f.read().splitlines()
+ for l in lines:
+ entry = l.split()
+ img_path = '%s/'%self.root+entry[0][7:]
+ self.test_labels[img_path] = int(entry[1])
+
+ if train:
+ train_imgs=[]
+ with open('%s/noisy_train_key_list.txt'%self.root,'r') as f:
+ lines = f.read().splitlines()
+ for i , l in enumerate(lines):
+ img_path = '%s/'%self.root+l[7:]
+ train_imgs.append((i,img_path))
+ self.num_raw_example = len(train_imgs)
+ random.shuffle(train_imgs)
+ class_num = torch.zeros(num_class)
+ self.train_imgs = []
+ for id_raw, impath in train_imgs:
+ label = self.train_labels[impath]
+ if class_num[label]<(num_samples/14) and len(self.train_imgs)= 1.1 for using 'torch.utils.tensorboard' or turn off the option in " \
+ "the 'config.json' file."
+ logger.warning(message)
+
+ self.step = 0
+ self.mode = ''
+
+ self.tb_writer_ftns = {
+ 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
+ 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
+ }
+ self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
+
+ self.timer = Timer()
+
+ def set_step(self, step, mode='train'):
+ self.mode = mode
+ self.step = step
+ if step == 0:
+ self.timer.reset()
+ else:
+ duration = self.timer.check()
+ self.add_scalar('steps_per_sec', 1 / duration)
+
+ def __getattr__(self, name):
+ """
+ If visualization is configured to use:
+ return add_data() methods of tensorboard with additional information (step, tag) added.
+ Otherwise:
+ return a blank function handle that does nothing
+ """
+ if name in self.tb_writer_ftns:
+ add_data = getattr(self.writer, name, None)
+
+ def wrapper(tag, data, *args, **kwargs):
+ if add_data is not None:
+ # add mode(train/valid) tag
+ if name not in self.tag_mode_exceptions:
+ tag = '{}/{}'.format(tag, self.mode)
+ add_data(tag, data, self.step, *args, **kwargs)
+ return wrapper
+ else:
+ # default action for returning methods defined in this class, set_step() for instance.
+ try:
+ attr = object.__getattr__(name)
+ except AttributeError:
+ raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name))
+ return attr
diff --git a/ELR_plus/model/InceptionResNetV2.py b/ELR_plus/model/InceptionResNetV2.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0dbcb93ef6132fafcae4e7fd0dcece0d24d8f89
--- /dev/null
+++ b/ELR_plus/model/InceptionResNetV2.py
@@ -0,0 +1,314 @@
+from __future__ import print_function, division, absolute_import
+import torch
+import torch.nn as nn
+import os
+import sys
+
+
+class BasicConv2d(nn.Module):
+
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_planes, out_planes,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, bias=False) # verify bias false
+ self.bn = nn.BatchNorm2d(out_planes,
+ eps=0.001, # value found in tensorflow
+ momentum=0.1, # default pytorch value
+ affine=True)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class Mixed_5b(nn.Module):
+
+ def __init__(self):
+ super(Mixed_5b, self).__init__()
+
+ self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(192, 48, kernel_size=1, stride=1),
+ BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(192, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+ BasicConv2d(192, 64, kernel_size=1, stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class Block35(nn.Module):
+
+ def __init__(self, scale=1.0):
+ super(Block35, self).__init__()
+
+ self.scale = scale
+
+ self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
+ BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ out = self.relu(out)
+ return out
+
+
+class Mixed_6a(nn.Module):
+
+ def __init__(self):
+ super(Mixed_6a, self).__init__()
+
+ self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(320, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ return out
+
+
+class Block17(nn.Module):
+
+ def __init__(self, scale=1.0):
+ super(Block17, self).__init__()
+
+ self.scale = scale
+
+ self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1088, 128, kernel_size=1, stride=1),
+ BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
+ BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
+ )
+
+ self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ out = torch.cat((x0, x1), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ out = self.relu(out)
+ return out
+
+
+class Mixed_7a(nn.Module):
+
+ def __init__(self):
+ super(Mixed_7a, self).__init__()
+
+ self.branch0 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
+ )
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 288, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(288, 320, kernel_size=3, stride=2)
+ )
+
+ self.branch3 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class Block8(nn.Module):
+
+ def __init__(self, scale=1.0, noReLU=False):
+ super(Block8, self).__init__()
+
+ self.scale = scale
+ self.noReLU = noReLU
+
+ self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(2080, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
+ BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
+ )
+
+ self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
+ if not self.noReLU:
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ out = torch.cat((x0, x1), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ if not self.noReLU:
+ out = self.relu(out)
+ return out
+
+
+class InceptionResNetV2(nn.Module):
+
+ def __init__(self, num_classes=50):
+ super(InceptionResNetV2, self).__init__()
+ # Special attributs
+ self.input_space = None
+ self.input_size = (299, 299, 3)
+ self.mean = None
+ self.std = None
+ # Modules
+ self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
+ self.maxpool_5a = nn.MaxPool2d(3, stride=2)
+ self.mixed_5b = Mixed_5b()
+ self.repeat = nn.Sequential(
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17),
+ Block35(scale=0.17)
+ )
+ self.mixed_6a = Mixed_6a()
+ self.repeat_1 = nn.Sequential(
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10),
+ Block17(scale=0.10)
+ )
+ self.mixed_7a = Mixed_7a()
+ self.repeat_2 = nn.Sequential(
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20),
+ Block8(scale=0.20)
+ )
+ self.block8 = Block8(noReLU=True)
+ self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
+ self.avgpool_1a = nn.AdaptiveAvgPool2d((1, 1))#nn.AvgPool2d(8, count_include_pad=False)
+ self.last_linear = nn.Linear(1536, num_classes)
+
+
+ def features(self, input):
+ x = self.conv2d_1a(input)
+ x = self.conv2d_2a(x)
+ x = self.conv2d_2b(x)
+ x = self.maxpool_3a(x)
+ x = self.conv2d_3b(x)
+ x = self.conv2d_4a(x)
+ x = self.maxpool_5a(x)
+ x = self.mixed_5b(x)
+ x = self.repeat(x)
+ x = self.mixed_6a(x)
+ x = self.repeat_1(x)
+ x = self.mixed_7a(x)
+ x = self.repeat_2(x)
+ x = self.block8(x)
+ x = self.conv2d_7b(x)
+ return x
+
+ def logits(self, features):
+ x = self.avgpool_1a(features)
+ x = x.view(x.size(0), -1)
+ out = self.last_linear(x)
+ return out
+
+
+ def forward(self, input):
+ x = self.features(input)
+ out = self.logits(x)
+ return out
+
+
+def test():
+ net = InceptionResNetV2().cuda()
+ y = net(torch.randn(1,3,227,227).cuda())
+ print(y.size())
+#test()
\ No newline at end of file
diff --git a/ELR_plus/model/PreResNet.py b/ELR_plus/model/PreResNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..117c4b428bd510b33b619d7a9cc5ec42e96db3d6
--- /dev/null
+++ b/ELR_plus/model/PreResNet.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.autograd import Variable
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """Return a bias-free 3x3 convolution with padding 1 and the given stride."""
+    return nn.Conv2d(
+        in_channels=in_planes,
+        out_channels=out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False,
+    )
+
+
+class BasicBlock(nn.Module):
+    """Two-convolution residual block (conv -> batch norm, ReLU after the sum)."""
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        out_planes = self.expansion * planes
+
+        self.conv1 = conv3x3(in_planes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        # The shortcut is an empty Sequential (identity) unless the block
+        # changes stride or channel count; then a 1x1 conv + batch norm
+        # brings the input to the output shape.
+        if stride == 1 and in_planes == out_planes:
+            self.shortcut = nn.Sequential()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_planes)
+            )
+
+    def forward(self, x):
+        """Return ``relu(residual_branch(x) + shortcut(x))``."""
+        hidden = F.relu(self.bn1(self.conv1(x)))
+        hidden = self.bn2(self.conv2(hidden))
+        hidden += self.shortcut(x)
+        return F.relu(hidden)
+
+
+class PreActBlock(nn.Module):
+    """Residual block that applies batch norm + ReLU before each convolution."""
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        out_planes = self.expansion * planes
+
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = conv3x3(in_planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes)
+
+        # Identity shortcut by default; a bare strided 1x1 convolution is
+        # used when stride or channel count changes.
+        if stride == 1 and in_planes == out_planes:
+            self.shortcut = nn.Sequential()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+            )
+
+    def forward(self, x):
+        """Return the block output; no activation is applied after the sum."""
+        activated = F.relu(self.bn1(x))
+        # The shortcut branches off the pre-activated input, not from ``x``.
+        shortcut = self.shortcut(activated)
+        hidden = self.conv1(activated)
+        hidden = self.conv2(F.relu(self.bn2(hidden)))
+        hidden += shortcut
+        return hidden
+
+
+class Bottleneck(nn.Module):
+    """1x1 -> 3x3 -> 1x1 residual block whose output has ``4 * planes`` channels."""
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        out_planes = self.expansion * planes
+
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        # Only the 3x3 convolution carries the stride.
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(out_planes)
+
+        # Identity shortcut unless stride or channel count changes.
+        if stride == 1 and in_planes == out_planes:
+            self.shortcut = nn.Sequential()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_planes)
+            )
+
+    def forward(self, x):
+        """Return ``relu(residual_branch(x) + shortcut(x))``."""
+        hidden = F.relu(self.bn1(self.conv1(x)))
+        hidden = F.relu(self.bn2(self.conv2(hidden)))
+        hidden = self.bn3(self.conv3(hidden))
+        hidden += self.shortcut(x)
+        return F.relu(hidden)
+
+
+class PreActBottleneck(nn.Module):
+    """Bottleneck block with batch norm + ReLU placed before each convolution."""
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        out_planes = self.expansion * planes
+
+        self.bn1 = nn.BatchNorm2d(in_planes)
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        # Only the 3x3 convolution carries the stride.
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
+
+        # Identity shortcut by default; a bare strided 1x1 convolution is
+        # used when stride or channel count changes.
+        if stride == 1 and in_planes == out_planes:
+            self.shortcut = nn.Sequential()
+        else:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+            )
+
+    def forward(self, x):
+        """Return the block output; no activation is applied after the sum."""
+        activated = F.relu(self.bn1(x))
+        # The shortcut branches off the pre-activated input, not from ``x``.
+        shortcut = self.shortcut(activated)
+        hidden = self.conv1(activated)
+        hidden = self.conv2(F.relu(self.bn2(hidden)))
+        hidden = self.conv3(F.relu(self.bn3(hidden)))
+        hidden += shortcut
+        return hidden
+
+
+class PreActResNet(nn.Module):
+    """Four-stage residual network whose forward pass can start or stop mid-way.
+
+    Args:
+        block: Block class; must expose an ``expansion`` attribute and accept
+            ``(in_planes, planes, stride)``.
+        num_blocks: Number of blocks in each of the four stages.
+        num_classes: Output size of the final linear layer.
+    """
+
+    def __init__(self, block, num_blocks, num_classes=10):
+        super(PreActResNet, self).__init__()
+        # Running input width; ``_make_layer`` updates it as stages are built.
+        self.in_planes = 64
+
+        self.conv1 = conv3x3(3, 64)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512*block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
+        layers = []
+        # A separate loop name keeps the ``stride`` argument from being shadowed.
+        for block_stride in [stride] + [1]*(num_blocks-1):
+            layers.append(block(self.in_planes, planes, block_stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*layers)
+
+    def forward(self, x, lin=0, lout=5):
+        """Run the part of the network selected by ``lin`` and ``lout``.
+
+        The stem (conv1/bn1/ReLU) runs when ``lin < 1 and lout > -1``;
+        ``layer<k>`` (k = 1..4) runs when ``lin < k + 1 and lout > k - 1``;
+        pooling plus the linear classifier run when ``lout > 4``.
+
+        Returns:
+            The classifier output when ``lout > 4``; otherwise the output of
+            the last stage that ran. Previously the early-exit case did not
+            return those features because the result name was only bound
+            inside the classifier branch.
+        """
+        out = x
+        if lin < 1 and lout > -1:
+            out = self.conv1(out)
+            out = self.bn1(out)
+            out = F.relu(out)
+        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
+        for depth, stage in enumerate(stages, start=1):
+            if lin < depth + 1 and lout > depth - 1:
+                out = stage(out)
+        if lout > 4:
+            out = F.avg_pool2d(out, 4)
+            out = out.view(out.size(0), -1)
+            out = self.linear(out)
+        return out
+
+
+def PreActResNet18(num_classes=10):
+    """Build a ``PreActResNet`` from ``PreActBlock`` with two blocks per stage."""
+    blocks_per_stage = [2, 2, 2, 2]
+    return PreActResNet(PreActBlock, blocks_per_stage, num_classes=num_classes)
+
+def PreActResNet34(num_classes=10):
+    """Build a ``PreActResNet`` from ``BasicBlock`` with [3, 4, 6, 3] blocks per stage."""
+    # NOTE(review): despite the name this uses the post-activation BasicBlock,
+    # not PreActBlock -- confirm that this is intended.
+    blocks_per_stage = [3, 4, 6, 3]
+    return PreActResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
+
+def PreActResNet50(num_classes=10):
+    """Build a ``PreActResNet`` from ``Bottleneck`` with [3, 4, 6, 3] blocks per stage."""
+    # NOTE(review): uses Bottleneck rather than PreActBottleneck -- confirm.
+    blocks_per_stage = [3, 4, 6, 3]
+    return PreActResNet(Bottleneck, blocks_per_stage, num_classes=num_classes)
+
+def PreActResNet101(num_classes=10):
+    """Build a ``PreActResNet`` from ``Bottleneck`` with [3, 4, 23, 3] blocks per stage."""
+    # NOTE(review): uses Bottleneck rather than PreActBottleneck -- confirm.
+    blocks_per_stage = [3, 4, 23, 3]
+    return PreActResNet(Bottleneck, blocks_per_stage, num_classes=num_classes)
+
+def PreActResNet152(num_classes=10):
+    """Build a ``PreActResNet`` from ``Bottleneck`` with [3, 8, 36, 3] blocks per stage."""
+    # NOTE(review): uses Bottleneck rather than PreActBottleneck -- confirm.
+    blocks_per_stage = [3, 8, 36, 3]
+    return PreActResNet(Bottleneck, blocks_per_stage, num_classes=num_classes)
+
+
+def test():
+    """Smoke-test ``PreActResNet18`` on one random 32x32 RGB image and print the output size."""
+    net = PreActResNet18()
+    # Tensors are passed directly; the deprecated ``Variable`` wrapper is not needed.
+    y = net(torch.randn(1, 3, 32, 32))
+    print(y.size())
diff --git a/ELR_plus/model/ResNet_Zoo.py b/ELR_plus/model/ResNet_Zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..34cf692ea9d57423a21f11a9e0cdcb91d1f5f122
--- /dev/null
+++ b/ELR_plus/model/ResNet_Zoo.py
@@ -0,0 +1,121 @@
+'''ResNet in PyTorch.
+For Pre-activation ResNet, see 'PreResNet.py'.
+Reference:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+ Deep Residual Learning for Image Recognition. arXiv:1512.03385
+'''
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BasicBlock(nn.Module):
+    """Residual block made of two 3x3 convolutions, each followed by batch norm."""
+    expansion = 1
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        widened = self.expansion * planes
+
+        # Only the first convolution carries the stride.
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        # A 1x1 conv + batch norm projection is needed only when the block
+        # changes stride or channel count; otherwise the shortcut is identity.
+        needs_projection = stride != 1 or in_planes != widened
+        if needs_projection:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, widened, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(widened)
+            )
+        else:
+            self.shortcut = nn.Sequential()
+
+    def forward(self, x):
+        """Return ``relu(residual_branch(x) + shortcut(x))``."""
+        branch = F.relu(self.bn1(self.conv1(x)))
+        branch = self.bn2(self.conv2(branch))
+        branch += self.shortcut(x)
+        return F.relu(branch)
+
+
+
+class Bottleneck(nn.Module):
+    """1x1 -> 3x3 -> 1x1 residual block whose output has ``4 * planes`` channels."""
+    expansion = 4
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        widened = self.expansion * planes
+
+        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        # Only the 3x3 convolution carries the stride.
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(widened)
+
+        # Projection shortcut only when stride or channel count changes.
+        needs_projection = stride != 1 or in_planes != widened
+        if needs_projection:
+            self.shortcut = nn.Sequential(
+                nn.Conv2d(in_planes, widened, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(widened)
+            )
+        else:
+            self.shortcut = nn.Sequential()
+
+    def forward(self, x):
+        """Return ``relu(residual_branch(x) + shortcut(x))``."""
+        branch = F.relu(self.bn1(self.conv1(x)))
+        branch = F.relu(self.bn2(self.conv2(branch)))
+        branch = self.bn3(self.conv3(branch))
+        branch += self.shortcut(x)
+        return F.relu(branch)
+
+
+class ResNet(nn.Module):
+    """3x3 stem, four residual stages, 4x4 average pooling and a linear classifier.
+
+    Args:
+        block: Block class; must expose ``expansion`` and accept
+            ``(in_planes, planes, stride)``.
+        num_blocks: Number of blocks in each of the four stages.
+        num_classes: Output size of the final linear layer.
+    """
+
+    def __init__(self, block, num_blocks, num_classes=10):
+        super().__init__()
+        # Running input width; ``_make_layer`` updates it as stages are built.
+        self.in_planes = 64
+
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
+        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
+        self.linear = nn.Linear(512*block.expansion, num_classes)
+
+    def _make_layer(self, block, planes, num_blocks, stride):
+        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
+        stage = []
+        for block_stride in [stride] + [1]*(num_blocks-1):
+            stage.append(block(self.in_planes, planes, block_stride))
+            self.in_planes = planes * block.expansion
+        return nn.Sequential(*stage)
+
+    def forward(self, x):
+        """Return the classifier output for the image batch ``x``."""
+        hidden = F.relu(self.bn1(self.conv1(x)))
+        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
+            hidden = stage(hidden)
+        pooled = F.avg_pool2d(hidden, 4)
+        flat = pooled.view(pooled.size(0), -1)
+        return self.linear(flat)
+
+
+def ResNet18(num_classes=10):
+    """Build a ``ResNet`` from ``BasicBlock`` with [2, 2, 2, 2] blocks per stage.
+
+    Args:
+        num_classes: Output size of the classifier. Defaults to 10, the value
+            that was previously fixed, so existing calls are unaffected.
+    """
+    return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes)
+
+def ResNet34(num_classes=10):
+    """Build a ``ResNet`` from ``BasicBlock`` with [3, 4, 6, 3] blocks per stage.
+
+    Args:
+        num_classes: Output size of the classifier. Defaults to 10, the value
+            that was previously fixed, so existing calls are unaffected.
+    """
+    return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes)
+
+def ResNet50(num_classes=10):
+    """Build a ``ResNet`` from ``Bottleneck`` with [3, 4, 6, 3] blocks per stage.
+
+    Args:
+        num_classes: Output size of the classifier. Defaults to 10, the value
+            that was previously fixed, so existing calls are unaffected.
+    """
+    return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes)
+
+def ResNet101(num_classes=10):
+    """Build a ``ResNet`` from ``Bottleneck`` with [3, 4, 23, 3] blocks per stage.
+
+    Args:
+        num_classes: Output size of the classifier. Defaults to 10, the value
+            that was previously fixed, so existing calls are unaffected.
+    """
+    return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes)
+
+def ResNet152(num_classes=10):
+    """Build a ``ResNet`` from ``Bottleneck`` with [3, 8, 36, 3] blocks per stage.
+
+    Args:
+        num_classes: Output size of the classifier. Defaults to 10, the value
+            that was previously fixed, so existing calls are unaffected.
+    """
+    return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes)
+
+
+def test():
+    """Smoke-test ``ResNet18`` on one random 32x32 RGB image and print the output size."""
+    sample = torch.randn(1,3,32,32)
+    print(ResNet18()(sample).size())
+
+# test()
\ No newline at end of file
diff --git a/ELR_plus/model/loss.py b/ELR_plus/model/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..07bd9b92d904f79ba574492bc9f210685456a42f
--- /dev/null
+++ b/ELR_plus/model/loss.py
@@ -0,0 +1,41 @@
+import torch.nn.functional as F
+import torch
+import numpy as np
+from parse_config import ConfigParser
+import torch.nn as nn
+from torch.autograd import Variable
+import math
+from utils import sigmoid_rampup, sigmoid_rampdown, cosine_rampup, cosine_rampdown, linear_rampup
+
+
+def cross_entropy(output, target, M=3):
+    """Return ``F.cross_entropy(output, target)``.
+
+    ``M`` is accepted but not used by this function.
+    """
+    loss = F.cross_entropy(input=output, target=target)
+    return loss
+
+class elr_plus_loss(nn.Module):
+    """Cross-entropy plus a regulariser driven by a running history of predictions.
+
+    ``pred_hist`` keeps one row of class probabilities per training example;
+    ``update_hist`` refreshes the rows of the current batch and derives the
+    per-batch target ``q`` that ``forward`` uses in the regularisation term.
+
+    Args:
+        num_examp: Number of rows allocated in the prediction history.
+        config: Mapping read for ``coef_step`` and ``train_loss.args.lambda``.
+        device: Device on which the history tensor is stored.
+        num_classes: Number of columns in the prediction history.
+        beta: Weight given to the previous history in the running average.
+    """
+
+    def __init__(self, num_examp, config, device, num_classes=10, beta=0.3):
+        super(elr_plus_loss, self).__init__()
+        self.config = config
+        self.pred_hist = (torch.zeros(num_examp, num_classes)).to(device)
+        # Stays 0 until ``update_hist`` has been called for a batch.
+        self.q = 0
+        self.beta = beta
+        self.num_classes = num_classes
+
+    def forward(self, iteration, output, y_labeled):
+        """Return ``(loss, detached CPU copy of the clamped predictions)``.
+
+        ``output`` is treated as logits with classes on dim 1. ``y_labeled``
+        is multiplied element-wise with the log-probabilities, so it is
+        presumably a one-hot or soft label matrix of the same shape -- confirm
+        against the caller.
+        """
+        y_pred = F.softmax(output, dim=1)
+
+        # Keep probabilities away from 0 and 1 so the log below stays finite.
+        y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
+
+        if self.num_classes == 100:
+            # Re-weight the labels by the history target and renormalise each row.
+            y_labeled = y_labeled*self.q
+            y_labeled = y_labeled/(y_labeled).sum(dim=1, keepdim=True)
+
+        ce_loss = torch.mean(-torch.sum(y_labeled * F.log_softmax(output, dim=1), dim=-1))
+        reg = ((1-(self.q * y_pred).sum(dim=1)).log()).mean()
+        # The regulariser weight is scaled by a ramp-up over ``coef_step``.
+        final_loss = ce_loss + sigmoid_rampup(iteration, self.config['coef_step'])*(self.config['train_loss']['args']['lambda']*reg)
+
+        return final_loss, y_pred.cpu().detach()
+
+    def update_hist(self, epoch, out, index=None, mix_index=..., mixup_l=1):
+        """Update the history rows ``index`` from logits ``out`` and recompute ``q``.
+
+        The update runs under ``torch.no_grad()`` so that neither the history
+        nor ``q`` ever becomes part of an autograd graph, even when ``out``
+        requires grad. ``epoch`` is accepted but not used.
+
+        Args:
+            out: Logits for the current batch (softmax is taken over dim 1).
+            index: Rows of ``pred_hist`` that belong to this batch.
+            mix_index: Permutation applied to the batch rows for the mixed
+                term; the default ``...`` leaves them in place.
+            mixup_l: Weight of the unpermuted rows in ``q``.
+        """
+        with torch.no_grad():
+            y_pred_ = F.softmax(out, dim=1)
+            self.pred_hist[index] = self.beta * self.pred_hist[index] + (1-self.beta) * y_pred_/(y_pred_).sum(dim=1, keepdim=True)
+            self.q = mixup_l * self.pred_hist[index] + (1-mixup_l) * self.pred_hist[index][mix_index]
diff --git a/ELR_plus/model/metric.py b/ELR_plus/model/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ed18f5e9e33b67d297692e8a987f5237e35f7e
--- /dev/null
+++ b/ELR_plus/model/metric.py
@@ -0,0 +1,20 @@
+import torch
+
+
+def my_metric(output, target):
+    """Return the fraction of rows whose arg-max over dim 1 equals ``target``."""
+    with torch.no_grad():
+        predicted = torch.argmax(output, dim=1)
+        assert predicted.shape[0] == len(target)
+        hits = 0 + torch.sum(predicted == target).item()
+    return hits / len(target)
+
+
+def my_metric2(output, target, k=3):
+    """Return the fraction of rows whose ``target`` is among the ``k`` largest scores."""
+    with torch.no_grad():
+        top_indices = torch.topk(output, k, dim=1)[1]
+        assert top_indices.shape[0] == len(target)
+        # Count matches rank by rank; a target can match at most one rank.
+        hits = sum(
+            torch.sum(top_indices[:, rank] == target).item() for rank in range(k)
+        )
+    return hits / len(target)
diff --git a/ELR_plus/model/model.py b/ELR_plus/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..410d193d796fa603e45dfaf91d845732c80b8d82
--- /dev/null
+++ b/ELR_plus/model/model.py
@@ -0,0 +1,26 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from base import BaseModel
+from .ResNet_Zoo import ResNet, BasicBlock
+from .PreResNet import PreActResNet, PreActBlock
+import torchvision.models as models
+from .InceptionResNetV2 import InceptionResNetV2
+
+
+def resnet34(num_classes=10):
+    """Build the project ``ResNet`` from ``BasicBlock`` with [3, 4, 6, 3] blocks per stage."""
+    blocks_per_stage = [3,4,6,3]
+    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
+
+
+def resnet50(num_classes=14):
+    """Return torchvision's pretrained ResNet-50 with a new ``num_classes``-way ``fc`` layer."""
+    import torchvision.models as models
+    backbone = models.resnet50(pretrained=True)
+    # Replace the classification head; its input width is taken from the old head.
+    backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
+    return backbone
+
+
+def PreActResNet34(num_classes=10) -> PreActResNet:
+    """Build a ``PreActResNet`` from ``PreActBlock`` with [3, 4, 6, 3] blocks per stage."""
+    blocks_per_stage = [3, 4, 6, 3]
+    return PreActResNet(PreActBlock, blocks_per_stage, num_classes=num_classes)
+def PreActResNet18(num_classes=10) -> PreActResNet:
+    """Build a ``PreActResNet`` from ``PreActBlock`` with two blocks per stage."""
+    blocks_per_stage = [2, 2, 2, 2]
+    return PreActResNet(PreActBlock, blocks_per_stage, num_classes=num_classes)
diff --git a/ELR_plus/parse_config.py b/ELR_plus/parse_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8434f5cd529118d5764052ad395d230d47fe793c
--- /dev/null
+++ b/ELR_plus/parse_config.py
@@ -0,0 +1,145 @@
+import os
+import logging
+from pathlib import Path
+from functools import reduce
+from operator import getitem
+from datetime import datetime
+from logger import setup_logging
+from utils import read_json, write_json
+
+
+class ConfigParser:
+    """Process-wide configuration holder, obtainable only via ``get_instance``.
+
+    The first ``get_instance`` call parses the command line, loads the JSON
+    configuration, applies CLI overrides, creates the save/log directories
+    and configures logging. Later calls return the same object.
+    """
+
+    __instance = None
+
+    def __new__(cls, args, options='', timestamp=True):
+        # Direct construction is disabled on purpose; use ``get_instance``.
+        raise NotImplementedError('Cannot initialize via Constructor')
+
+    @classmethod
+    def __internal_new__(cls):
+        """Allocate a bare instance, bypassing the blocked ``__new__``."""
+        return super().__new__(cls)
+
+    @classmethod
+    def get_instance(cls, args=None, options='', timestamp=True):
+        """Return the shared instance, creating it on the first call.
+
+        Args:
+            args: ``argparse.ArgumentParser`` to parse; required on the first call.
+            options: Iterable of custom CLI options exposing ``flags``,
+                ``type`` and ``target``.
+            timestamp: Whether a time-stamped sub-directory is appended to the
+                save/log paths.
+
+        Raises:
+            NotImplementedError: If no instance exists yet and ``args`` is None.
+        """
+        if cls.__instance is None:
+            if args is None:
+                raise NotImplementedError('Cannot initialize without args')
+            instance = cls.__internal_new__()
+            instance.__init__(args, options, timestamp)
+            # Publish only after a successful init so that a failure cannot
+            # leave a half-built object cached as the singleton.
+            cls.__instance = instance
+
+        return cls.__instance
+
+    def __init__(self, args, options='', timestamp=True):
+        # parse default and custom cli options
+        for opt in options:
+            args.add_argument(*opt.flags, default=None, type=opt.type)
+        args = args.parse_args()
+
+        if args.device:
+            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
+        if args.resume is None:
+            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
+            assert args.config is not None, msg_no_cfg
+            self.cfg_fname = Path(args.config)
+            config = read_json(self.cfg_fname)
+            self.resume = None
+        else:
+            # When resuming, start from the config stored next to the
+            # checkpoint and let an explicitly given config override it.
+            self.resume = Path(args.resume)
+            resume_cfg_fname = self.resume.parent / 'config.json'
+            config = read_json(resume_cfg_fname)
+            if args.config is not None:
+                config.update(read_json(Path(args.config)))
+
+        # load config file and apply custom cli options
+        self._config = _update_config(config, options, args)
+
+        # set save_dir where trained model and log will be saved.
+        save_dir = Path(self.config['trainer']['save_dir'])
+        run_stamp = datetime.now().strftime(r'%m%d_%H%M%S') if timestamp else ''
+
+        # Experiment name encodes the noise type and the noise percentage.
+        noise_tag = '_asym_' if self.config['trainer']['asym'] else '_sym_'
+        exper_name = self.config['name'] + noise_tag + str(int(self.config['trainer']['percent']*100))
+
+        self._save_dir = save_dir / 'models' / exper_name / run_stamp
+        self._log_dir = save_dir / 'log' / exper_name / run_stamp
+
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+        # save updated config file to the checkpoint dir
+        write_json(self.config, self.save_dir / 'config.json')
+
+        # configure logging module
+        setup_logging(self.log_dir)
+        self.log_levels = {
+            0: logging.WARNING,
+            1: logging.INFO,
+            2: logging.DEBUG
+        }
+
+    def initialize(self, name, module, *args, **kwargs):
+        """
+        finds a function handle with the name given as 'type' in config, and returns the
+        instance initialized with corresponding keyword args given as 'args'.
+        """
+        module_name = self[name]['type']
+        module_args = dict(self[name]['args'])
+        assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
+        module_args.update(kwargs)
+        return getattr(module, module_name)(*args, **module_args)
+
+    def __getitem__(self, name):
+        """Return the top-level config entry ``name``."""
+        return self.config[name]
+
+    def get_logger(self, name, verbosity=2):
+        """Return the logger ``name`` with its level set from ``verbosity`` (0, 1 or 2)."""
+        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity,
+                                                                                       self.log_levels.keys())
+        assert verbosity in self.log_levels, msg_verbosity
+        logger = logging.getLogger(name)
+        logger.setLevel(self.log_levels[verbosity])
+        return logger
+
+    # setting read-only attributes
+    @property
+    def config(self):
+        return self._config
+
+    @property
+    def save_dir(self):
+        return self._save_dir
+
+    @property
+    def log_dir(self):
+        return self._log_dir
+
+
+# helper functions used to update config dict with custom cli options
+def _update_config(config, options, args):
+ # Write the parsed value of every custom CLI option into ``config`` and return ``config``.
+ # Each option is read through ``.flags``, ``.target`` and ``._fields``; it is presumably a
+ # namedtuple -- confirm against the caller.
+ for opt in options:
+ # The parsed value is looked up on ``args`` under the name derived from the option's flags.
+ value = getattr(args, _get_opt_name(opt.flags))
+ if value is not None:
+ _set_by_path(config, opt.target, value)
+ # Optional extra destinations, consulted only when the option type declares these fields.
+ # NOTE(review): indentation is not recoverable in this view, so it is unclear whether the
+ # two checks below sit under the ``value is not None`` guard -- confirm before editing.
+ if 'target2' in opt._fields:
+ _set_by_path(config, opt.target2, value)
+ if 'target3' in opt._fields:
+ _set_by_path(config, opt.target3, value)
+ return config
+
+
+def _get_opt_name(flags):
+    """Return the first ``--`` flag with every ``--`` removed, else ``flags[0]`` treated likewise."""
+    chosen = next((flag for flag in flags if flag.startswith('--')), None)
+    if chosen is None:
+        chosen = flags[0]
+    return chosen.replace('--', '')
+
+
+def _set_by_path(tree, keys, value):
+    """Assign ``value`` at the nested position in ``tree`` addressed by ``keys``."""
+    *parents, leaf = keys[:-1], keys[-1]
+    # Walk down to the container that owns the last key, then assign into it.
+    container = reduce(getitem, parents[0], tree)
+    container[leaf] = value
+
+
+def _get_by_path(tree, keys):
+    """Return the nested object in ``tree`` reached by indexing with each key in turn."""
+    node = tree
+    for key in keys:
+        node = getitem(node, key)
+    return node
diff --git a/ELR_plus/test.py b/ELR_plus/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..3084bdb3a59d756db1c8914962108a44550293de
--- /dev/null
+++ b/ELR_plus/test.py
@@ -0,0 +1,82 @@
+import argparse
+import torch
+from tqdm import tqdm
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+
+
+def main(config):
+    """Evaluate the checkpoint in ``config.resume`` and log average loss and metrics."""
+    logger = config.get_logger('test')
+
+    # setup data_loader instances
+    loader_cls = getattr(module_data, config['data_loader']['type'])
+    data_loader = loader_cls(
+        config['data_loader']['args']['data_dir'],
+        batch_size=512,
+        shuffle=False,
+        validation_split=0.0,
+        training=False,
+        num_workers=2
+    ).split_validation()
+
+    # build model architecture
+    model = config.initialize('arch', module_arch)
+    logger.info(model)
+
+    # get function handles of loss and metrics
+    loss_fn = getattr(module_loss, config['val_loss'])
+    metric_fns = [getattr(module_metric, met) for met in config['metrics']]
+
+    logger.info('Loading checkpoint: {} ...'.format(config.resume))
+    checkpoint = torch.load(config.resume, map_location='cpu')
+    weights = checkpoint['state_dict']
+    if config['n_gpu'] > 1:
+        model = torch.nn.DataParallel(model)
+    model.load_state_dict(weights)
+
+    # prepare model for testing
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    model = model.to(device)
+    model.eval()
+
+    total_loss = 0.0
+    total_metrics = torch.zeros(len(metric_fns))
+
+    with torch.no_grad():
+        # Each batch yields four items; only the first two are used here.
+        for data, target, _, _ in tqdm(data_loader):
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+
+            # Accumulate sums weighted by batch size so that the final
+            # division gives per-sample averages.
+            loss = loss_fn(output, target)
+            batch_size = data.shape[0]
+            total_loss += loss.item() * batch_size
+            for slot, metric in enumerate(metric_fns):
+                total_metrics[slot] += metric(output, target) * batch_size
+
+    n_samples = len(data_loader.sampler)
+    log = {'loss': total_loss / n_samples}
+    log.update({
+        met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
+    })
+    logger.info(log)
+
+
+if __name__ == '__main__':
+    # Script entry point: register the three standard flags, build the shared
+    # configuration from them and run the evaluation.
+    args = argparse.ArgumentParser(description='PyTorch Template')
+
+    cli_flags = (
+        ('-c', '--config', 'config file path (default: None)'),
+        ('-r', '--resume', 'path to latest checkpoint (default: None)'),
+        ('-d', '--device', 'indices of GPUs to enable (default: all)'),
+    )
+    for short_flag, long_flag, help_text in cli_flags:
+        args.add_argument(short_flag, long_flag, default=None, type=str,
+                          help=help_text)
+
+    config = ConfigParser.get_instance(args, '')
+    main(config)
diff --git a/ELR_plus/train.py b/ELR_plus/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..855c7dea7809447452f81f65054a0b10216e6b4a
--- /dev/null
+++ b/ELR_plus/train.py
@@ -0,0 +1,163 @@
+import argparse
+import collections
+import sys
+import requests
+import socket
+import torch
+import mlflow
+import mlflow.pytorch
+import data_loader.data_loaders as module_data
+import model.loss as module_loss
+import model.metric as module_metric
+import model.model as module_arch
+from parse_config import ConfigParser
+from trainer import Trainer
+from collections import OrderedDict
+import random
+
+
+def log_params(conf: OrderedDict, parent_key: str = None):
+    """Log a nested configuration to MLflow as flat parameters.
+
+    Nested ``OrderedDict`` values are descended into; their keys are joined
+    to the parent key with ``-``. Every other value is logged as-is.
+    """
+    for key, value in conf.items():
+        combined_key = key if parent_key is None else f'{parent_key}-{key}'
+
+        if isinstance(value, OrderedDict):
+            log_params(value, combined_key)
+            continue
+        mlflow.log_param(combined_key, value)
+
+
+def main(config: ConfigParser):
+    """Build the two-network training setup described by ``config`` and run it.
+
+    Creates two training loaders (differing only in batch size), validation
+    and test loaders, two models with two EMA copies each, one loss object
+    per model, optimizers and LR schedulers, then hands everything to
+    ``Trainer`` and starts training.
+    """
+    logger = config.get_logger('train')
+    logger.info(config.config)
+
+    loader_cls = getattr(module_data, config['data_loader']['type'])
+    loader_args = config['data_loader']['args']
+
+    def build_train_loader(batch_size_key):
+        """Create a training loader whose batch size is read from ``batch_size_key``."""
+        return loader_cls(
+            loader_args['data_dir'],
+            batch_size=loader_args[batch_size_key],
+            shuffle=loader_args['shuffle'],
+            validation_split=loader_args['validation_split'],
+            num_batches=loader_args['num_batches'],
+            training=True,
+            num_workers=loader_args['num_workers'],
+            pin_memory=loader_args['pin_memory']
+        )
+
+    # setup data_loader instances
+    data_loader1 = build_train_loader('batch_size')
+    data_loader2 = build_train_loader('batch_size2')
+
+    valid_data_loader = data_loader1.split_validation()
+
+    test_data_loader = loader_cls(
+        loader_args['data_dir'],
+        batch_size=128,
+        shuffle=False,
+        validation_split=0.0,
+        training=False,
+        num_workers=2
+    ).split_validation()
+
+    # build model architecture
+    model1 = config.initialize('arch1', module_arch)
+    model_ema1 = config.initialize('arch1', module_arch)
+    model_ema1_copy = config.initialize('arch1', module_arch)
+    model2 = config.initialize('arch2', module_arch)
+    model_ema2 = config.initialize('arch2', module_arch)
+    model_ema2_copy = config.initialize('arch2', module_arch)
+
+    # The first loss lives on the first selected GPU and the second on the
+    # last one. With no usable GPU the list is empty, so fall back to the CPU
+    # instead of failing on an empty-list index.
+    device_id = list(range(min(torch.cuda.device_count(), config['n_gpu'])))
+    first_device = f'cuda:{device_id[0]}' if device_id else 'cpu'
+    last_device = f'cuda:{device_id[-1]}' if device_id else 'cpu'
+
+    # Prefer the dataset's own raw-example count when both datasets expose it.
+    if hasattr(data_loader1.dataset, 'num_raw_example') and hasattr(data_loader2.dataset, 'num_raw_example'):
+        num_examp1 = data_loader1.dataset.num_raw_example
+        num_examp2 = data_loader2.dataset.num_raw_example
+    else:
+        num_examp1 = len(data_loader1.dataset)
+        num_examp2 = len(data_loader2.dataset)
+
+    # get function handles of loss and metrics
+    loss_cls = getattr(module_loss, config['train_loss']['type'])
+    beta = config['train_loss']['args']['beta']
+    train_loss1 = loss_cls(num_examp=num_examp1, num_classes=config['num_classes'],
+                           device=first_device, config=config.config, beta=beta)
+    train_loss2 = loss_cls(num_examp=num_examp2, num_classes=config['num_classes'],
+                           device=last_device, config=config.config, beta=beta)
+
+    val_loss = getattr(module_loss, config['val_loss'])
+    metrics = [getattr(module_metric, met) for met in config['metrics']]
+
+    # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
+    trainable_params1 = filter(lambda p: p.requires_grad, model1.parameters())
+    trainable_params2 = filter(lambda p: p.requires_grad, model2.parameters())
+
+    optimizer1 = config.initialize('optimizer1', torch.optim, [{'params': trainable_params1}])
+    optimizer2 = config.initialize('optimizer2', torch.optim, [{'params': trainable_params2}])
+
+    lr_scheduler1 = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer1)
+    lr_scheduler2 = config.initialize('lr_scheduler', torch.optim.lr_scheduler, optimizer2)
+
+    trainer = Trainer(model1, model2, model_ema1, model_ema2, train_loss1, train_loss2,
+                      metrics,
+                      optimizer1, optimizer2,
+                      config=config,
+                      data_loader1=data_loader1,
+                      data_loader2=data_loader2,
+                      valid_data_loader=valid_data_loader,
+                      test_data_loader=test_data_loader,
+                      lr_scheduler1=lr_scheduler1,
+                      lr_scheduler2=lr_scheduler2,
+                      val_criterion=val_loss,
+                      model_ema1_copy=model_ema1_copy,
+                      model_ema2_copy=model_ema2_copy)
+
+    trainer.train()
+
+
+
+if __name__ == '__main__':
+    def _str_to_bool(text):
+        """Parse a CLI boolean; ``type=bool`` would turn any non-empty string, even '0', into True."""
+        lowered = text.strip().lower()
+        if lowered in ('1', 'true', 't', 'yes', 'y'):
+            return True
+        if lowered in ('0', 'false', 'f', 'no', 'n'):
+            return False
+        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))
+
+    args = argparse.ArgumentParser(description='PyTorch Template')
+    args.add_argument('-c', '--config', default=None, type=str,
+                      help='config file path (default: None)')
+    args.add_argument('-r', '--resume', default=None, type=str,
+                      help='path to latest checkpoint (default: None)')
+    args.add_argument('-d', '--device', default=None, type=str,
+                      help='indices of GPUs to enable (default: all)')
+
+    # custom cli options to modify configuration from default values given in json file.
+    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
+    options = [
+        # NOTE(review): main() initialises 'optimizer1'/'optimizer2'; confirm that the
+        # config also has an 'optimizer' section for this override to land in.
+        CustomArgs(['--lr', '--learning_rate'], type=float, target=('optimizer', 'args', 'lr')),
+        CustomArgs(['--bs', '--batch_size'], type=int, target=('data_loader', 'args', 'batch_size')),
+        CustomArgs(['--beta', '--beta'], type=float, target=('train_loss', 'args', 'beta')),
+        CustomArgs(['--lambda', '--lambda'], type=float, target=('train_loss', 'args', 'lambda')),
+        CustomArgs(['--percent', '--percent'], type=float, target=('trainer', 'percent')),
+        CustomArgs(['--asym', '--asym'], type=_str_to_bool, target=('trainer', 'asym')),
+        CustomArgs(['--name', '--exp_name'], type=str, target=('name',)),
+        CustomArgs(['--malpha', '--mixup_alpha'], type=float, target=('mixup_alpha',)),
+        CustomArgs(['--ealpha', '--ema_alpha'], type=float, target=('ema_alpha',)),
+        CustomArgs(['--nb', '--num_batches'], type=float, target=('data_loader', 'args', 'num_batches')),
+        CustomArgs(['--warm', '--warmup'], type=int, target=('trainer', 'warmup')),
+        CustomArgs(['--seed', '--seed'], type=int, target=('seed',)),
+        # Optimizer keyword arguments are read from the 'args' sub-section, so the
+        # weight-decay overrides must be written there to take effect.
+        CustomArgs(['--wc1', '--weight_decay1'], type=float, target=('optimizer1', 'args', 'weight_decay')),
+        CustomArgs(['--wc2', '--weight_decay2'], type=float, target=('optimizer2', 'args', 'weight_decay')),
+        CustomArgs(['--estep', '--ema_step'], type=float, target=('ema_step',)),
+    ]
+    config = ConfigParser.get_instance(args, options)
+    # Seed Python and PyTorch (CPU and every CUDA device) from the config.
+    random.seed(config['seed'])
+    torch.manual_seed(config['seed'])
+    torch.cuda.manual_seed_all(config['seed'])
+    main(config)
diff --git a/ELR_plus/trainer/__init__.py b/ELR_plus/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fe21f802159f55665fcdebcd9b2c32bb4acf5b9
--- /dev/null
+++ b/ELR_plus/trainer/__init__.py
@@ -0,0 +1 @@
+from .trainer import *
diff --git a/ELR_plus/trainer/trainer.py b/ELR_plus/trainer/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2fa43ccdea4ffdb67d4443e2b506bee3b3f05e
--- /dev/null
+++ b/ELR_plus/trainer/trainer.py
@@ -0,0 +1,365 @@
+import numpy as np
+import torch
+from tqdm import tqdm
+from typing import List
+from torchvision.utils import make_grid
+from base import BaseTrainer
+from utils import inf_loop, linear_rampup, sigmoid_rampup, linear_rampdown
+import sys
+from sklearn.mixture import GaussianMixture
+import torch.nn.functional as F
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+class Trainer(BaseTrainer):
+ """
+ Trainer class
+
+ Note:
+ Inherited from BaseTrainer.
+ """
+    def __init__(self, model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2, metrics, optimizer1, optimizer2, config,
+                 data_loader1, data_loader2,
+                 valid_data_loader=None,
+                 test_data_loader=None,
+                 lr_scheduler1=None, lr_scheduler2=None,
+                 len_epoch=None, val_criterion=None,
+                 model_ema1_copy=None, model_ema2_copy=None):
+        """
+        Store the data loaders, schedulers and bookkeeping lists of the trainer.
+
+        The models, criteria, metrics, optimizers and ``config`` are forwarded
+        to ``BaseTrainer.__init__``.
+
+        :param data_loader1: Training loader; must expose ``batch_size``.
+        :param data_loader2: Second training loader.
+        :param valid_data_loader: Optional loader; validation runs only when given.
+        :param test_data_loader: Optional loader; testing runs only when given.
+        :param lr_scheduler1: Optional scheduler, stored as ``self.lr_scheduler1``.
+        :param lr_scheduler2: Optional scheduler, stored as ``self.lr_scheduler2``.
+        :param len_epoch: Number of batches per epoch. ``None`` selects
+            epoch-based training (one pass over ``data_loader1``); an integer
+            selects iteration-based training on endlessly repeating loaders.
+        """
+        super().__init__(model1, model2, model_ema1, model_ema2, train_criterion1, train_criterion2,
+                         metrics, optimizer1, optimizer2, config, val_criterion, model_ema1_copy, model_ema2_copy)
+        self.config = config.config
+
+        # Read the batch size from the loader that was passed in: in
+        # iteration-based mode the stored loader is an ``inf_loop`` generator,
+        # which has no ``batch_size`` attribute.
+        self.log_step = int(np.sqrt(data_loader1.batch_size))
+
+        if len_epoch is None:
+            # epoch-based training
+            self.data_loader1 = data_loader1
+            self.data_loader2 = data_loader2
+            self.len_epoch = len(data_loader1)
+        else:
+            # iteration-based training
+            self.data_loader1 = inf_loop(data_loader1)
+            self.data_loader2 = inf_loop(data_loader2)
+            self.len_epoch = len_epoch
+
+        self.valid_data_loader = valid_data_loader
+        self.test_data_loader = test_data_loader
+        self.do_validation = self.valid_data_loader is not None
+        self.do_test = self.test_data_loader is not None
+        self.lr_scheduler1 = lr_scheduler1
+        self.lr_scheduler2 = lr_scheduler2
+
+        # Per-batch loss histories, appended to by the epoch methods.
+        self.train_loss_list: List[float] = []
+        self.val_loss_list: List[float] = []
+        self.test_loss_list: List[float] = []
+
+
+    def _eval_metrics(self, output, target):
+        """
+        Evaluate every configured metric on one batch.
+
+        Each value is also written to ``self.writer`` under the metric
+        function's ``__name__``.
+
+        :return: Array with one entry per metric, in ``self.metrics`` order.
+        """
+        scores = np.zeros(len(self.metrics))
+        for position, metric_fn in enumerate(self.metrics):
+            scores[position] += metric_fn(output, target)
+            tag = '{}'.format(metric_fn.__name__)
+            self.writer.add_scalar(tag, scores[position])
+        return scores
+
+    def _train_epoch(self, epoch, model, model_ema, model_ema2, data_loader, train_criterion, optimizer, lr_scheduler, device = 'cpu', queue = None):
+        """
+        Training logic for an epoch (mixup inputs, EMA weight update per step).
+
+        :param epoch: Current training epoch.
+        :param model: Network optimised in this call.
+        :param model_ema: Weight-averaged copy of ``model``, refreshed after
+            every optimizer step via ``update_ema_variables``.
+        :param model_ema2: Network whose outputs on the un-mixed batch are
+            passed to ``train_criterion.update_hist``.
+        :param data_loader: Iterable yielding ``(data, target, indexs, _)``.
+        :param train_criterion: Loss object exposing ``update_hist`` and
+            returning ``(loss, probs)`` when called.
+        :param optimizer: Optimizer for ``model``.
+        :param lr_scheduler: Scheduler stepped once at the end of the epoch,
+            or None.
+        :param device: Device the batches are moved to.
+        :param queue: Optional queue; when given, the log is put on it and
+            nothing is returned.
+        :return: A log that contains all information you want to save.
+
+        Note:
+            If you have additional information to record, for example:
+                > additional_log = {"x": x, "y": y}
+            merge it with log before return. i.e.
+                > log = {**log, **additional_log}
+                > return log
+
+            The metrics in log must have the key 'metrics'.
+        """
+        model.train()
+        model_ema.train()
+
+        total_loss = 0
+        total_metrics = np.zeros(len(self.metrics))
+        total_metrics_ema = np.zeros(len(self.metrics))
+        local_step = 0
+
+        with tqdm(data_loader) as progress:
+            for batch_idx, (data, target, indexs, _) in enumerate(progress):
+                progress.set_description_str(f'Train epoch {epoch}')
+
+                # Keep the un-mixed inputs for the history update below.
+                data_original = data
+
+                # One-hot encode the labels so that they can be mixed.
+                target = torch.zeros(len(target), self.config['num_classes']).scatter_(1, target.view(-1,1), 1)
+                data, target = data.to(device), target.float().to(device)
+
+                data, target, mixup_l, mix_index = self._mixup_data(data, target, alpha = self.config['mixup_alpha'], device = device)
+
+                output = model(data)
+
+                data_original = data_original.to(device)
+                output_original = model_ema2(data_original)
+                output_original = output_original.data.detach()
+                train_criterion.update_hist(epoch, output_original, indexs.numpy().tolist(), mix_index = mix_index, mixup_l = mixup_l)
+
+                local_step += 1
+                loss, _ = train_criterion(self.global_step + local_step, output, target)
+
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+                self.update_ema_variables(model, model_ema, self.global_step + local_step, self.config['ema_alpha'])
+
+                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
+                self.writer.add_scalar('loss', loss.item())
+                self.train_loss_list.append(loss.item())
+                total_loss += loss.item()
+                total_metrics += self._eval_metrics(output, target.argmax(dim=1))
+                total_metrics_ema += self._eval_metrics(output_original, target.argmax(dim=1))
+
+                if batch_idx % self.log_step == 0:
+                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                        self._progress(batch_idx),
+                        loss.item()))
+
+                # NOTE(review): the check runs after the batch is processed,
+                # so an endless loader yields len_epoch + 1 batches per epoch
+                # while the averages below divide by len_epoch - confirm.
+                if batch_idx == self.len_epoch:
+                    break
+
+        # NOTE(review): placed after the batch loop; the original nesting
+        # could not be recovered from the collapsed indentation - confirm.
+        if hasattr(data_loader, 'run'):
+            data_loader.run()
+
+        # Without a scheduler, report the optimizer's current learning rates
+        # instead of failing on ``None.get_lr()``.
+        if lr_scheduler is not None:
+            learning_rate = lr_scheduler.get_lr()
+        else:
+            learning_rate = [group['lr'] for group in optimizer.param_groups]
+
+        log = {
+            'global step': self.global_step,
+            'local_step': local_step,
+            'loss': total_loss / self.len_epoch,
+            'metrics': (total_metrics / self.len_epoch).tolist(),
+            'metrics_ema': (total_metrics_ema / self.len_epoch).tolist(),
+            'learning rate': learning_rate
+        }
+
+        if lr_scheduler is not None:
+            lr_scheduler.step()
+
+        if queue is None:
+            return log
+        else:
+            queue.put(log)
+
+
+    def _valid_epoch(self, epoch, model1, model2, device = 'cpu', queue = None):
+        """
+        Validate after training an epoch, scoring the average of both models.
+
+        Every batch loss is written to ``self.writer`` and appended to
+        ``self.val_loss_list``; the input batch is logged as an image grid.
+
+        :return: A log with the mean 'val_loss' and 'val_metrics' over the
+            batches of ``self.valid_data_loader``. When ``queue`` is given the
+            log is put on the queue and nothing is returned.
+
+        Note:
+            The validation metrics in log must have the key 'val_metrics'.
+        """
+        for net in (model1, model2):
+            net.eval()
+
+        loss_sum = 0
+        metric_sum = np.zeros(len(self.metrics))
+        with torch.no_grad():
+            with tqdm(self.valid_data_loader) as progress:
+                for batch_idx, (data, target, _, _) in enumerate(progress):
+                    progress.set_description_str(f'Valid epoch {epoch}')
+                    data = data.to(device)
+                    target = target.to(device)
+
+                    # Ensemble prediction: mean of the two networks' outputs.
+                    output1 = model1(data)
+                    output2 = model2(data)
+                    output = 0.5*(output1 + output2)
+
+                    loss = self.val_criterion(output, target)
+
+                    step = (epoch - 1) * len(self.valid_data_loader) + batch_idx
+                    self.writer.set_step(step, 'valid')
+                    self.writer.add_scalar('loss', loss.item())
+                    self.val_loss_list.append(loss.item())
+                    loss_sum += loss.item()
+                    metric_sum += self._eval_metrics(output, target)
+                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+        summary = {
+            'val_loss': loss_sum / len(self.valid_data_loader),
+            'val_metrics': (metric_sum / len(self.valid_data_loader)).tolist()
+        }
+        if queue is not None:
+            queue.put(summary)
+            return None
+        return summary
+
+ def _test_epoch(self, epoch, model1, model2, device = 'cpu', queue = None):
+ """
+ Test after training an epoch
+
+ Both models are put in eval mode and every batch of
+ ``self.test_data_loader`` is scored with the average of their outputs.
+
+ :return: A log that contains information about test (put on ``queue``
+ instead of being returned when ``queue`` is given)
+
+ Note:
+ The Test metrics in log must have the key 'test_metrics'.
+ """
+ model1.eval()
+ model2.eval()
+
+ total_test_loss = 0
+ total_test_metrics = np.zeros(len(self.metrics))
+ with torch.no_grad():
+ with tqdm(self.test_data_loader) as progress:
+ for batch_idx, (data, target,indexs,_) in enumerate(progress):
+ progress.set_description_str(f'Test epoch {epoch}')
+ data, target = data.to(device), target.to(device)
+
+ output1 = model1(data)
+ output2 = model2(data)
+
+ # Ensemble prediction: mean of the two networks' outputs.
+ output = 0.5*(output1 + output2)
+ loss = self.val_criterion(output, target)
+ self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
+ self.writer.add_scalar('loss', loss.item())
+ self.test_loss_list.append(loss.item())
+ total_test_loss += loss.item()
+ total_test_metrics += self._eval_metrics(output, target)
+ self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
+
+
+
+ #add histogram of model parameters to the tensorboard
+ # NOTE(review): only model1's parameters are logged, not model2's - confirm this is intended.
+ for name, p in model1.named_parameters():
+ self.writer.add_histogram(name, p, bins='auto')
+ # Totals are averaged over the number of batches in the test loader.
+ if queue is None:
+ return {
+ 'test_loss': total_test_loss / len(self.test_data_loader),
+ 'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
+ }
+ else:
+ queue.put({
+ 'test_loss': total_test_loss / len(self.test_data_loader),
+ 'test_metrics': (total_test_metrics / len(self.test_data_loader)).tolist()
+ })
+
+
+    def _warmup_epoch(self, epoch, model, data_loader, optimizer, train_criterion, lr_scheduler, device = 'cpu', queue = None):
+        """
+        Train ``model`` for one epoch with plain cross entropy.
+
+        The detached model outputs of every batch are still passed to
+        ``train_criterion.update_hist``; the criterion itself is not used as
+        the loss here.
+
+        :param epoch: Current training epoch.
+        :param model: Network to train.
+        :param data_loader: Iterable yielding ``(data, target, indexs, _)``.
+        :param optimizer: Optimizer for ``model``.
+        :param train_criterion: Object exposing ``update_hist``.
+        :param lr_scheduler: Scheduler queried for the logged learning rate, or
+            None. It is not stepped by this method.
+        :param device: Device the batches are moved to.
+        :param queue: Optional queue; when given, the log is put on it and
+            nothing is returned.
+        :return: A log with 'loss', 'noise detection rate', 'metrics' and
+            'learning rate'.
+        """
+        total_loss = 0
+        total_metrics = np.zeros(len(self.metrics))
+        model.train()
+
+        with tqdm(data_loader) as progress:
+            for batch_idx, (data, target, indexs , _) in enumerate(progress):
+                progress.set_description_str(f'Train epoch {epoch}')
+
+                data, target = data.to(device), target.long().to(device)
+                optimizer.zero_grad()
+                output = model(data)
+                out_prob = output.data.detach()
+
+                train_criterion.update_hist(epoch, out_prob ,indexs.cpu().detach().numpy().tolist())
+
+                loss = torch.nn.functional.cross_entropy(output, target)
+
+                loss.backward()
+                optimizer.step()
+
+                self.train_loss_list.append(loss.item())
+                total_loss += loss.item()
+                total_metrics += self._eval_metrics(output, target)
+
+                if batch_idx % self.log_step == 0:
+                    progress.set_postfix_str(' {} Loss: {:.6f}'.format(
+                        self._progress(batch_idx),
+                        loss.item()))
+
+                if batch_idx == self.len_epoch:
+                    break
+
+        # Without a scheduler, report the optimizer's current learning rates
+        # instead of failing on ``None.get_lr()``.
+        if lr_scheduler is not None:
+            learning_rate = lr_scheduler.get_lr()
+        else:
+            learning_rate = [group['lr'] for group in optimizer.param_groups]
+
+        log = {
+            'loss': total_loss / self.len_epoch,
+            'noise detection rate' : 0.0,
+            'metrics': (total_metrics / self.len_epoch).tolist(),
+            'learning rate': learning_rate
+        }
+        if queue is None:
+            return log
+        else:
+            queue.put(log)
+
+ def eval_train(self, epoch, model_ema2, train_criterion):
+ """
+ Feed ``model_ema2`` outputs on an evaluation loader to ``train_criterion.update_hist``.
+
+ NOTE(review): this method cannot run as written. ``args``,
+ ``eval_loader``, ``indexs``, ``mix_index`` and ``mixup_l`` are not
+ parameters, locals or module-level names of this file, so the first
+ statement raises NameError. The ``update_hist`` call also passes more
+ positional arguments than the calls in ``_train_epoch`` and
+ ``_warmup_epoch`` do. Looks like leftover code - confirm before use.
+ """
+ #model.eval()
+ # ``args`` is undefined here (see note above).
+ num_samples = args.num_batches*args.batch_size
+ # ``losses`` is allocated but never read afterwards.
+ losses = torch.zeros(num_samples)
+ with torch.no_grad():
+ # ``eval_loader`` is undefined here (see note above).
+ for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
+ inputs, targets = inputs.cuda(), targets.cuda()
+ output0 = model_ema2(inputs)
+ output0 = output0.data.detach()
+ output1, output2, output3 = None, None, None
+ # ``indexs``, ``mix_index`` and ``mixup_l`` are undefined here (see note above).
+ train_criterion.update_hist(epoch, output0, output1, output2, output3, indexs.numpy().tolist(),mix_index = mix_index, mixup_l = mixup_l)
+
+
+    def update_ema_variables(self, model, model_ema, global_step, alpha_=0.997):
+        """
+        Update ``model_ema`` in place as an exponential moving average of ``model``.
+
+        :param model: Network providing the current weights.
+        :param model_ema: Network holding the averaged weights (modified in place).
+        :param global_step: Number of optimisation steps taken so far.
+        :param alpha_: Target EMA decay. ``0`` disables averaging: the EMA
+            weights become a copy of the current weights.
+        """
+        if alpha_ == 0:
+            # No averaging requested. Copy values instead of rebinding
+            # ``.data`` so the two networks never share storage.
+            for ema_param, param in zip(model_ema.parameters(), model.parameters()):
+                ema_param.data.copy_(param.data)
+            return
+
+        if self.config['ema_update']:
+            alpha = sigmoid_rampup(global_step + 1, self.config['ema_step'])*alpha_
+        else:
+            # Use the true average until the exponential average is more correct
+            alpha = min(1 - 1 / (global_step + 1), alpha_)
+        for ema_param, param in zip(model_ema.parameters(), model.parameters()):
+            # ema = alpha * ema + (1 - alpha) * param
+            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
+
+
+    def _progress(self, batch_idx):
+        """
+        Format the position within the epoch as ``[current/total (percent%)]``.
+
+        Counts samples when ``self.data_loader1`` exposes ``n_samples``,
+        otherwise counts batches against ``self.len_epoch``.
+        """
+        loader = self.data_loader1
+        if hasattr(loader, 'n_samples'):
+            done, total = batch_idx * loader.batch_size, loader.n_samples
+        else:
+            done, total = batch_idx, self.len_epoch
+        return '[{}/{} ({:.0f}%)]'.format(done, total, 100.0 * done / total)
+
+    def _mixup_data(self, x, y, alpha=1.0, device = 'cpu'):
+        '''Mix a batch with a shuffled copy of itself.
+
+        Returns the mixed inputs, the mixed targets, the mixing weight and the
+        permutation used. When ``alpha`` is not positive the batch is returned
+        untouched with weight 1 and ``...`` in place of the permutation.
+        '''
+        if not alpha > 0:
+            # presumably Ellipsis is returned so that callers indexing with it
+            # get their tensor back unchanged - verify against update_hist
+            return x, y, 1, ...
+
+        weight = np.random.beta(alpha, alpha)
+        # Keep the larger share on the un-shuffled sample.
+        weight = max(weight, 1-weight)
+        n_items = x.size()[0]
+        mix_index = torch.randperm(n_items).to(device)
+
+        shuffled_x = x[mix_index, :]
+        shuffled_y = y[mix_index, :]
+        mixed_x = weight * x + (1 - weight) * shuffled_x
+        mixed_target = weight * y + (1 - weight) * shuffled_y
+        return mixed_x, mixed_target, weight, mix_index
+
+
+    def _mixup_criterion(self, pred, y_a, y_b, lam, *args):
+        """
+        Blend ``self.train_criterion`` results for two target sets.
+
+        The criterion is evaluated once per target set and must return three
+        values; each is combined as ``lam * a + (1 - lam) * b``.
+        """
+        first = self.train_criterion(pred, y_a, *args)
+        second = self.train_criterion(pred, y_b, *args)
+        loss_a, prob_a, entropy_a = first
+        loss_b, prob_b, entropy_b = second
+        mixed_loss = lam * loss_a + (1 - lam) * loss_b
+        mixed_prob = lam * prob_a + (1-lam) * prob_b
+        mixed_entropy = lam * entropy_a + (1-lam) * entropy_b
+        return mixed_loss, mixed_prob, mixed_entropy
diff --git a/ELR_plus/utils/__init__.py b/ELR_plus/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74ace79a2790b01f71138e46a6b4243d416113ab
--- /dev/null
+++ b/ELR_plus/utils/__init__.py
@@ -0,0 +1 @@
+from .util import *
diff --git a/ELR_plus/utils/util.py b/ELR_plus/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c01e3eaee5123ba9cd757e0713b871bd0db9d5e
--- /dev/null
+++ b/ELR_plus/utils/util.py
@@ -0,0 +1,91 @@
+import json
+from pathlib import Path
+from datetime import datetime
+from itertools import repeat
+from collections import OrderedDict
+import numpy as np
+
+def ensure_dir(dirname):
+    """Create ``dirname`` (and any missing parents) if it does not exist yet.
+
+    Raises FileExistsError when the path exists but is not a directory.
+    """
+    # exist_ok=True removes the window between an is_dir() check and mkdir()
+    # in which another process could create the directory and make us fail.
+    Path(dirname).mkdir(parents=True, exist_ok=True)
+
+
+def read_json(fname):
+    """Load the JSON file ``fname`` (str or Path), keeping key order.
+
+    JSON objects are returned as ``OrderedDict`` instances.
+    """
+    # Wrapping in Path lets callers pass plain strings as well as Path objects.
+    with Path(fname).open('rt') as handle:
+        return json.load(handle, object_hook=OrderedDict)
+
+
+def write_json(content, fname):
+    """Write ``content`` as indented JSON to ``fname`` (str or Path).
+
+    Keys are written in their existing order (no sorting).
+    """
+    # Wrapping in Path lets callers pass plain strings as well as Path objects.
+    with Path(fname).open('wt') as handle:
+        json.dump(content, handle, indent=4, sort_keys=False)
+
+
+def inf_loop(data_loader):
+    ''' Yield the items of ``data_loader`` over and over, restarting it each time it is exhausted. '''
+    while True:
+        yield from data_loader
+
+
+class Timer:
+    """Stopwatch reporting the wall-clock seconds elapsed between checks."""
+
+    def __init__(self):
+        # Reference point for the next check().
+        self.cache = datetime.now()
+
+    def check(self):
+        """Return the seconds since the last check/reset and restart the clock."""
+        previous, self.cache = self.cache, datetime.now()
+        return (self.cache - previous).total_seconds()
+
+    def reset(self):
+        """Restart the clock without reporting anything."""
+        self.cache = datetime.now()
+
+
+
+def sigmoid_rampup(current, rampup_length):
+    """Ramp-up factor ``exp(-5 * (1 - t)**2)`` with ``t = current / rampup_length`` clipped to [0, 1].
+
+    Returns 1.0 when ``rampup_length`` is 0.
+    """
+    if rampup_length == 0:
+        return 1.0
+    clipped = np.clip(current, 0.0, rampup_length)
+    remaining = 1.0 - clipped / rampup_length
+    return float(np.exp(-5.0 * remaining * remaining))
+
+def sigmoid_rampdown(current, rampdown_length):
+    """Ramp-down factor ``exp(-12.5 * t**2)`` with ``t = current / rampdown_length`` clipped to [0, 1].
+
+    Returns 1.0 when ``rampdown_length`` is 0.
+    """
+    if rampdown_length == 0:
+        return 1.0
+    clipped = np.clip(current, 0.0, rampdown_length)
+    elapsed = 1.0 - (rampdown_length - clipped) / rampdown_length
+    return float(np.exp(-12.5 * elapsed * elapsed))
+
+def linear_rampup(current, rampup_length):
+    """Linear ramp from 0 to 1 over ``rampup_length`` steps; 1.0 once the ramp is complete."""
+    assert current >= 0 and rampup_length >= 0
+    return 1.0 if current >= rampup_length else current / rampup_length
+
+def linear_rampdown(current, rampdown_length):
+    """Linear rampdown from 1 to 0 over ``rampdown_length`` steps.
+
+    Stays at 0.0 once ``current`` reaches ``rampdown_length``. A zero-length
+    rampdown returns 1.0, matching ``sigmoid_rampdown``.
+    """
+    assert current >= 0 and rampdown_length >= 0
+    if rampdown_length == 0:
+        return 1.0
+    if current >= rampdown_length:
+        # The ramp is finished; do not jump back up to 1.0.
+        return 0.0
+    return 1.0 - current / rampdown_length
+
+
+def cosine_rampdown(current, rampdown_length):
+    """Cosine rampdown from https://arxiv.org/abs/1608.03983
+
+    Half-cosine from 1 (at 0) down to 0 (at ``rampdown_length``); ``current``
+    is clipped to that range.
+    """
+    clipped = np.clip(current, 0.0, rampdown_length)
+    angle = np.pi * clipped / rampdown_length
+    return float(.5 * (np.cos(angle) + 1))
+
+
+def cosine_rampup(current, rampup_length):
+    """Cosine rampup: half-cosine from 0 (at 0) up to 1 (at ``rampup_length``).
+
+    ``current`` is clipped to ``[0, rampup_length]``. A zero-length rampup
+    returns 1.0, matching ``sigmoid_rampup`` and ``linear_rampup``, instead of
+    evaluating 0/0.
+    """
+    if rampup_length == 0:
+        return 1.0
+    current = np.clip(current, 0.0, rampup_length)
+    return float(-.5 * (np.cos(np.pi * current / rampup_length) - 1))
+
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a727e13f64bdcd8be6f62accd85a6c23fdd567aa
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Sheng Liu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index 22aa4bf935316e49699a08ba4c5e6d9394578555..000506a60dd2c58470a504b2b70afbfaab9a229c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,141 @@
----
-title: Assi1
-emoji: 📉
-colorFrom: blue
-colorTo: purple
-sdk: gradio
-sdk_version: 5.20.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+# Early-Learning Regularization Prevents Memorization of Noisy Labels
+[](https://arxiv.org/abs/2007.00151)
+
+
+
+This repository is the official implementation of [Early-Learning Regularization Prevents Memorization of Noisy Labels](https://arxiv.org/abs/2007.00151) (NeurIPS 2020).
+
+We propose a novel framework to perform classification via deep learning in the presence of **noisy annotations**. When trained on noisy labels, deep neural networks have been observed to first fit the training data with clean labels during an **early learning** phase, before eventually **memorizing** the examples with false labels. Our technique exploits the progress of the early learning phase via **regularization** to perform classification from noisy labels. There are two key elements to our approach. First, we leverage semi-supervised learning techniques to produce target probabilities based on the model outputs. Second, we design a regularization term that steers the model towards these targets, implicitly preventing memorization of the false labels. The resulting framework is shown to provide robustness to noisy annotations on several standard benchmarks and real-world datasets, where it achieves results comparable to the state of the art.
+
+
+
+
+These graphs show the results of training a ResNet-34 with a traditional cross entropy loss (top row) and our proposed method (bottom row) to perform classification on the CIFAR-10 dataset where 40% of the labels are flipped at random. The left column shows the fraction of examples with clean labels that are predicted correctly (green) and incorrectly (blue). The right column shows the fraction of examples with wrong labels that are predicted correctly (green), memorized (the prediction equals the wrong label, shown in red), and incorrectly predicted as neither the true nor the labeled class (blue). The model trained with cross entropy begins by learning to predict the true labels, even for many of the examples with wrong labels, but eventually memorizes the wrong labels. Our proposed method based on early-learning regularization prevents memorization, allowing the model to continue learning on the examples with clean labels to attain high accuracy on examples with both clean and wrong labels.
+
+
+
+
+
+
+
+Learning path of a sample with a correct label (left) and a sample with a wrong label (right). Corners correspond to one-hot
+vectors. Bright green represents the model's prediction: when the example is wrongly labeled, the clean label is predicted during early learning, and then the wrong label is predicted at the end of training. The model is trained on the first 3 classes of CIFAR-10.
+
+
+
+## Early-learning modeled for each example
+Early learning may not happen simultaneously for all examples, e.g. when the label noise is instance-dependent. Therefore, in [SOP](https://github.com/shengliu66/SOP), we model the early-learning phenomenon for each example using an overparameterization term learned per instance. We further impose sparsity on it via the implicit bias of stochastic gradient descent. This method achieved SoTA results for instance-dependent label noise. If you are interested, take a look at our [paper](https://proceedings.mlr.press/v162/liu22w/liu22w.pdf) published at ICML 2022!
+
+
+## Requirements
+- This codebase is written for `python3`.
+- To install necessary python packages, run `pip install -r requirements.txt`.
+
+
+## Training
+### Basics
+- ELR loss is implemented in the file [`loss.py`](./ELR/model/loss.py)
+- All functions used for training the basic version of our technique (**ELR**) can be found in the `ELR` folder.
+- All functions used for training the more advanced version (**ELR+**) can be found in the `ELR_plus` folder.
+- Experiment settings and configurations used for the different datasets are in the corresponding config JSON files.
+### Data
+- Please download the data before running the code, and add the path to the downloaded data to `data_loader.args.data_dir` in the corresponding config file.
+### Training
+- Code for training ELR is in the following file: [`train.py`](./ELR/train.py), code for training ELR+ is in the following file: [`train.py`](./ELR_plus/train.py)
+```
+usage: train.py [-c] [-r] [-d] [--lr learning_rate] [--bs batch_size] [--beta beta] [--lambda lambda] [--malpha mixup_alpha]
+ [--percent percent] [--asym asym] [--ealpha ema_alpha] [--name exp_name]
+
+ arguments:
+ -c, --config config file path (default: None)
+ -r, --resume path to latest checkpoint (default: None)
+ -d, --device indices of GPUs to enable (default: all)
+
+ options:
+ --lr learning_rate learning rate (default value is the value in the config file)
+ --bs batch_size batch size (default value is the value in the config file)
+ --beta beta temporal ensembling momentum beta for target estimation
+ --lambda lambda regularization coefficient
+ --malpha mixup_alpha mixup parameter alpha
+ --percent percent noise level (e.g. 0.4 for 40%)
+ --asym asym asymmetric noise is used when set to True
+ --ealpha ema_alpha weight averaging momentum for target estimation
+ --name exp_name experiment name
+```
+Configuration file is **required** to be specified. Default option values, if not reset, will be the values in the configuration file.
+Examples for ELR and ELR+ are shown in the *readme.md* of `ELR` and `ELR_plus` subfolders respectively.
+
+### Example
+In order to use our proposed early learning regularization (ELR), you can simply replace your loss function by the following loss function. Usually, **lambda** which is used to control the strength of the regularization term needs to be tuned more carefully, and the value of **beta** is quite robust (can be 0.7, 0.9 or 0.99, etc.)
+```
+class elr_loss(nn.Module):
+    def __init__(self, num_examp, num_classes=10, lambda_=3, beta=0.7):
+        r"""Early Learning Regularization.
+        Parameters
+        * `num_examp` Total number of training examples.
+        * `num_classes` Number of classes in the classification problem.
+        * `lambda_` Regularization strength; must be a positive float, controlling the strength of the ELR (named `lambda_` because `lambda` is a reserved word in Python).
+        * `beta` Temporal ensembling momentum for target estimation.
+        """
+
+        super(elr_loss, self).__init__()
+        self.num_classes = num_classes
+        self.USE_CUDA = torch.cuda.is_available()
+        # Running per-example target estimates, one row per training example.
+        self.target = torch.zeros(num_examp, self.num_classes).cuda() if self.USE_CUDA else torch.zeros(num_examp, self.num_classes)
+        self.beta = beta
+        self.lambda_ = lambda_
+
+
+    def forward(self, index, output, label):
+        r"""Early Learning Regularization.
+        Args
+        * `index` Training sample index, due to training set shuffling, index is used to track training examples in different iterations.
+        * `output` Model's logits, same as PyTorch provided loss functions.
+        * `label` Labels, same as PyTorch provided loss functions.
+        """
+
+        y_pred = F.softmax(output,dim=1)
+        # Clamp away from 0 and 1 so the logarithm below stays finite.
+        y_pred = torch.clamp(y_pred, 1e-4, 1.0-1e-4)
+        y_pred_ = y_pred.data.detach()
+        # Temporal ensembling: moving average of the re-normalized predictions.
+        self.target[index] = self.beta * self.target[index] + (1-self.beta) * ((y_pred_)/(y_pred_).sum(dim=1,keepdim=True))
+        ce_loss = F.cross_entropy(output, label)
+        elr_reg = ((1-(self.target[index] * y_pred).sum(dim=1)).log()).mean()
+        final_loss = ce_loss + self.lambda_ *elr_reg
+        return final_loss
+
+```
+## Identify Wrong Labels
+- After training finishes, obtain `self.target` from the ELR loss and compare it to the original labels `y`.
+- The mislabeled examples are identified as those for which `argmax(self.target) != y`.
+
+
+## License and Contributing
+- This README is formatted based on [paperswithcode](https://github.com/paperswithcode/releasing-research-code).
+- Feel free to post issues via Github.
+
+## Reference
+For technical details and full experimental results, please check [our paper](https://arxiv.org/abs/2007.00151).
+```
+@InProceedings{pmlr-v162-liu22w,
+ title = {Robust Training under Label Noise by Over-parameterization},
+ author = {Liu, Sheng and Zhu, Zhihui and Qu, Qing and You, Chong},
+ journal = {Proceedings of the 39th International Conference on Machine Learning},
+ volume = {162},
+ year = {2022}
+}
+```
+```
+@article{liu2020early,
+ title={Early-Learning Regularization Prevents Memorization of Noisy Labels},
+ author={Liu, Sheng and Niles-Weed, Jonathan and Razavian, Narges and Fernandez-Granda, Carlos},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ year={2020}
+}
+```
+A similar early-learning and memorization phenomenon is observed in semantic segmentation; a related paper addresses WSSS by adaptive correction:
+[Adaptive Early-Learning Correction for Segmentation from Noisy Annotations](https://arxiv.org/abs/2110.03740) (CVPR2022 **Oral**).
+## Contact
+Please contact shengliu@nyu.edu if you have any questions about the code.
diff --git a/images/ELR.png b/images/ELR.png
new file mode 100644
index 0000000000000000000000000000000000000000..c2c85844e7a5e9f46c39c77909bfcd8c7bf9ff14
Binary files /dev/null and b/images/ELR.png differ
diff --git a/images/clean_label_simplexheatmap2.gif b/images/clean_label_simplexheatmap2.gif
new file mode 100644
index 0000000000000000000000000000000000000000..bbbe036528f932c4b870c6b4c5d4eb635598dd79
--- /dev/null
+++ b/images/clean_label_simplexheatmap2.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6915d13dc449834481ef5246e02f09160fbe989f769d619ef6b8b15a191e500
+size 833882
diff --git a/images/false_label_simplexheatmap.gif b/images/false_label_simplexheatmap.gif
new file mode 100644
index 0000000000000000000000000000000000000000..566dfbdf750c39b7b1c92e43bc54b6f5ed07dc9a
--- /dev/null
+++ b/images/false_label_simplexheatmap.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3e5cc6928ebc26966b161a9308021195df98b418139b724cd1f6c20748649db
+size 812027
diff --git a/images/illustration_of_ELR.png b/images/illustration_of_ELR.png
new file mode 100644
index 0000000000000000000000000000000000000000..8c72436047a51c4ce447403ca379b3f9241249c5
--- /dev/null
+++ b/images/illustration_of_ELR.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94bc2ee99bbc79324750cf458d48051cf4cd41d4f3c8174c2eebb8605b83425a
+size 1120626
diff --git a/images/simplexheatmap.gif b/images/simplexheatmap.gif
new file mode 100644
index 0000000000000000000000000000000000000000..00d1bd2ccf541d4be1c3b6072aa16d2913ca44a9
--- /dev/null
+++ b/images/simplexheatmap.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:18e68ee9843819bd4101c47e0695c2333e6df372cfd5cf847b2db4441f53e0d9
+size 3917605
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..64de891b7a7b4720930323da97a77fb5f4078f77
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,8 @@
+torchvision==0.4.0a0
+requests==2.22.0
+tqdm==4.32.1
+torch==1.2.0
+numpy==1.16.4
+mlflow==1.9.1
+Pillow==8.1.1
+scikit_learn==0.23.1