abhishek-kumar committed
Commit 1e8d169 · 1 Parent(s): 20a3309

Add files from Neural-IMage-Assessment repo

.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ snapshots/badpred.png filter=lfs diff=lfs merge=lfs -text
36
+ snapshots/contrast.png filter=lfs diff=lfs merge=lfs -text
37
+ snapshots/goodpred.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ ./data/
LICENSE ADDED
@@ -0,0 +1,8 @@
1
+ The MIT License (MIT) Copyright (c) 2020 Yunxiao Shi
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4
+
5
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6
+
7
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8
+ © 2020 GitHub, Inc.
README.md CHANGED
@@ -1,12 +1,81 @@
1
- ---
2
- title: NIMA
3
- emoji: 👁
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 3.19.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ## NIMA: Neural IMage Assessment
2
+
3
+ [![Python 3.6+](https://img.shields.io/badge/Python-3.6%2B-blue)](https://www.python.org/)
4
+ [![MIT License](https://img.shields.io/badge/MIT-License-brightgreen)](./LICENSE)
5
+
6
+ This is a PyTorch implementation of the paper [NIMA: Neural IMage Assessment](https://arxiv.org/abs/1709.05424) (accepted at [IEEE Transactions on Image Processing](https://ieeexplore.ieee.org/document/8352823)) by Hossein Talebi and Peyman Milanfar. You can learn more from [this post at Google Research Blog](https://research.googleblog.com/2017/12/introducing-nima-neural-image-assessment.html).
7
+
8
+ ## Implementation Details
9
+
10
+ + The model was trained on the [AVA (Aesthetic Visual Analysis) dataset](http://refbase.cvc.uab.es/files/MMP2012a.pdf) containing 255,500+ images. You can get it from [here](https://github.com/mtobeiyf/ava_downloader). ~~**Note: there may be some corrupted images in the dataset, remove them first before you start training**.~~ Use the provided CSVs below, which have already filtered these out for you.
11
+
12
+ + The dataset is split into 229,981 images for training, 12,691 images for validation and 12,818 images for testing.
13
+
14
+ + An ImageNet-pretrained VGG-16 is used as the base network. It should be easy to plug in the other two options from the paper (MobileNet and Inception-v2); see the sketch after this list.
15
+
16
+ + The learning rate setting differs from the original paper: I could not get the model to converge with the original parameters. I also did not do much hyper-parameter tuning, so you can probably get better results. All other settings mirror the paper.
17
+
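A minimal sketch of swapping in a different backbone (not part of this repo), assuming `torchvision`'s `mobilenet_v2`; note that `NIMA` hard-codes VGG-16's 25088-dimensional flattened feature vector, so the classifier head has to be rebuilt to match:

```python
import torch.nn as nn
import torchvision.models as models

from model.model import NIMA

# MobileNetV2 also exposes a `.features` module, which is all NIMA uses.
base = models.mobilenet_v2(pretrained=True)
model = NIMA(base)

# Rebuild the head: MobileNetV2 produces a 1280x7x7 feature map for 224x224
# input, i.e. 62720 flattened features instead of VGG-16's 25088.
model.classifier = nn.Sequential(
    nn.Dropout(p=0.75),
    nn.Linear(in_features=1280 * 7 * 7, out_features=10),
    nn.Softmax(dim=1))
```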
18
+ ## Requirements
19
+
20
+ Code is written using [PyTorch](https://pytorch.org/get-started/locally/) 1.8.1 with [CUDA](https://developer.nvidia.com/cuda-toolkit) 11.1. You can recreate the environment I used with [conda](https://docs.conda.io/en/latest/miniconda.html) by
21
+ ```
22
+ conda env create -f env.yml
23
+ ```
24
+ to install the dependencies.
25
+
26
+ ## Usage
27
+
28
+ To start training on the AVA dataset, first download the dataset from the link above and decompress it, which should create a directory named ```images/```. Then download the curated annotation CSVs below,
29
+ which already split the dataset (you can of course create your own split). Then run
30
+
31
+ ```python
32
+ python main.py --img_path /path/to/images/ --train --train_csv_file /path/to/train_labels.csv --val_csv_file /path/to/val_labels.csv --conv_base_lr 5e-4 --dense_lr 5e-3 --decay --ckpt_path /path/to/ckpts --epochs 100 --early_stopping_patience 10
33
+ ```
34
+
35
+ For inference, do
36
+
37
+ ```python
38
+ python -W ignore test.py --model /path/to/your_model --test_csv /path/to/test_labels.csv --test_images /path/to/images --predictions /path/to/save/predictions
39
+ ```
40
+
41
+ See ```predictions/``` for an example of the dumped predictions.
42
+
43
+ ## Training Statistics
44
+
45
+ Training is done with early stopping. Here I set ```early_stopping_patience=10```.
46
+ <p align="center">
47
+ <img src="./snapshots/[email protected]">
48
+ </p>
49
+
50
+ ## Pretrained Model
51
+
52
+ The pretrained model reaches ~0.069 EMD on validation. It is not fully converged yet (constrained by resources). To continue training, download the pretrained weights and add ```--warm_start --warm_start_epoch 34``` to your args, as in the example command below.
53
+
54
+ [Google Drive](https://drive.google.com/file/d/1w9Ig_d6yZqUZSR63kPjZLrEjJ1n845B_/view?usp=sharing)
55
+
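For example, assuming the downloaded weights are saved as `epoch-34.pth` inside your `--ckpt_path` (the filename pattern `main.py` loads), resuming training might look like:

```
python main.py --img_path /path/to/images/ --train --train_csv_file /path/to/train_labels.csv --val_csv_file /path/to/val_labels.csv --conv_base_lr 5e-4 --dense_lr 5e-3 --decay --ckpt_path /path/to/ckpts --epochs 100 --early_stopping_patience 10 --warm_start --warm_start_epoch 34
```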
56
+ ## Annotation CSV Files
57
+ [Train](https://drive.google.com/file/d/1IBXPXPkCiTz04wWcoReJv4Nk06VsjSkI/view?usp=sharing) [Validation](https://drive.google.com/file/d/1tJfO1zFBoQYzd8kUo5PKeHTcdzBL7115/view?usp=sharing) [Test](https://drive.google.com/file/d/105UGnkglpKuusPhJaPnFSa2JlQV3du9O/view?usp=sharing)
58
+
59
+ ## Example Results
60
+
61
+ + First, here are some good predictions from the test set. Each image title shows the ground-truth rating, followed by the predicted mean and std in parentheses.
62
+
63
+ <p align="center">
64
+ <img src="./snapshots/goodpred.png">
65
+ </p>
66
+
67
+ + There are also some failure cases; the model seems to struggle most on images with very low or very high aesthetic ratings.
68
+
69
+ <p align="center">
70
+ <img src="./snapshots/badpred.png">
71
+ </p>
72
+
73
+ + The predicted aesthetic ratings from training on the AVA dataset are sensitive to contrast adjustments, preferring images with higher contrast. Below, the top row is the reference image with contrast ```c=1.0```, while the bottom row shows the same image with contrast adjusted to ```[0.25, 0.75, 1.25, 1.75]```. Contrast adjustment is done using ```ImageEnhance.Contrast``` from ```PIL``` (in this case [pillow-simd](https://github.com/uploadcare/pillow-simd)); a minimal sketch follows the figure.
74
+
75
+ <p align="center">
76
+ <img src="./snapshots/contrast.png">
77
+ </p>
78
+
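A minimal sketch of the contrast sweep shown above, using `ImageEnhance.Contrast`; the file names and factors here are illustrative:

```python
from PIL import Image, ImageEnhance

img = Image.open('reference.jpg').convert('RGB')  # hypothetical reference image
for c in [0.25, 0.75, 1.0, 1.25, 1.75]:
    # enhance(1.0) returns the image unchanged; smaller factors lower contrast, larger ones raise it
    ImageEnhance.Contrast(img).enhance(c).save('contrast_%.2f.jpg' % c)
```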
79
+ ## License
80
+
81
+ MIT
dataset/__init__.py ADDED
File without changes
dataset/dataset.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ file - dataset.py
3
+ Customized dataset class to loop through the AVA dataset and apply needed image augmentations for training.
4
+
5
+ Copyright (C) Yunxiao Shi 2017 - 2021
6
+ NIMA is released under the MIT license. See LICENSE for the full license text.
7
+ """
8
+
9
+ import os
10
+
11
+ import pandas as pd
12
+ from PIL import Image
13
+
14
+ import torch
15
+ from torch.utils import data
16
+ import torchvision.transforms as transforms
17
+
18
+
19
+ class AVADataset(data.Dataset):
20
+ """AVA dataset
21
+
22
+ Args:
23
+ csv_file: an 11-column CSV file; column 1 contains the names of the image files, columns 2-11 contain the empirical distributions of ratings
24
+ root_dir: directory to the images
25
+ transform: preprocessing and augmentation of the training images
26
+ """
27
+
28
+ def __init__(self, csv_file, root_dir, transform=None):
29
+ self.annotations = pd.read_csv(csv_file)
30
+ self.root_dir = root_dir
31
+ self.transform = transform
32
+
33
+ def __len__(self):
34
+ return len(self.annotations)
35
+
36
+ def __getitem__(self, idx):
37
+ img_name = os.path.join(self.root_dir, str(self.annotations.iloc[idx, 0]) + '.jpg')
38
+ image = Image.open(img_name).convert('RGB')
39
+ annotations = self.annotations.iloc[idx, 1:].to_numpy()
40
+ annotations = annotations.astype('float').reshape(-1, 1)
41
+ sample = {'img_id': img_name, 'image': image, 'annotations': annotations}
42
+
43
+ if self.transform:
44
+ sample['image'] = self.transform(sample['image'])
45
+
46
+ return sample
47
+
48
+
49
+ if __name__ == '__main__':
50
+
51
+ # sanity check
52
+ root = './data/images'
53
+ csv_file = './data/train_labels.csv'
54
+ train_transform = transforms.Compose([
55
+ transforms.Resize(256),
56
+ transforms.RandomCrop(224),
57
+ transforms.RandomHorizontalFlip(),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
60
+ dset = AVADataset(csv_file=csv_file, root_dir=root, transform=train_transform)
61
+ train_loader = data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)
62
+ for i, data in enumerate(train_loader):
63
+ images = data['image']
64
+ print(images.size())
65
+ labels = data['annotations']
66
+ print(labels.size())
env.yml ADDED
@@ -0,0 +1,106 @@
1
+ name: torch1.8.1
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - _libgcc_mutex=0.1=main
8
+ - blas=1.0=mkl
9
+ - bzip2=1.0.8=h7b6447c_0
10
+ - ca-certificates=2021.5.25=h06a4308_1
11
+ - certifi=2021.5.30=py38h06a4308_0
12
+ - cudatoolkit=11.1.74=h6bb024c_0
13
+ - cycler=0.10.0=py38_0
14
+ - dbus=1.13.18=hb2f20db_0
15
+ - expat=2.4.1=h2531618_2
16
+ - ffmpeg=4.3=hf484d3e_0
17
+ - fontconfig=2.13.1=h6c09931_0
18
+ - freetype=2.10.4=h5ab3b9f_0
19
+ - glib=2.68.2=h36276a3_0
20
+ - gmp=6.2.1=h2531618_2
21
+ - gnutls=3.6.15=he1e5248_0
22
+ - gst-plugins-base=1.14.0=h8213a91_2
23
+ - gstreamer=1.14.0=h28cd5cc_2
24
+ - icu=58.2=he6710b0_3
25
+ - intel-openmp=2021.2.0=h06a4308_610
26
+ - jpeg=9b=h024ee3a_2
27
+ - kiwisolver=1.3.1=py38h2531618_0
28
+ - lame=3.100=h7b6447c_0
29
+ - lcms2=2.12=h3be6417_0
30
+ - ld_impl_linux-64=2.33.1=h53a641e_7
31
+ - libffi=3.3=he6710b0_2
32
+ - libgcc-ng=9.1.0=hdf63c60_0
33
+ - libiconv=1.15=h63c8f33_5
34
+ - libidn2=2.3.1=h27cfd23_0
35
+ - libpng=1.6.37=hbc83047_0
36
+ - libstdcxx-ng=9.1.0=hdf63c60_0
37
+ - libtasn1=4.16.0=h27cfd23_0
38
+ - libtiff=4.2.0=h85742a9_0
39
+ - libunistring=0.9.10=h27cfd23_0
40
+ - libuuid=1.0.3=h1bed415_2
41
+ - libuv=1.40.0=h7b6447c_0
42
+ - libwebp-base=1.2.0=h27cfd23_0
43
+ - libxcb=1.14=h7b6447c_0
44
+ - libxml2=2.9.10=hb55368b_3
45
+ - lz4-c=1.9.3=h2531618_0
46
+ - matplotlib=3.3.4=py38h06a4308_0
47
+ - matplotlib-base=3.3.4=py38h62a2d02_0
48
+ - mkl=2021.2.0=h06a4308_296
49
+ - mkl-service=2.3.0=py38h27cfd23_1
50
+ - mkl_fft=1.3.0=py38h42c9631_2
51
+ - mkl_random=1.2.1=py38ha9443f7_2
52
+ - ncurses=6.2=he6710b0_1
53
+ - nettle=3.7.2=hbbd107a_1
54
+ - ninja=1.10.2=hff7bd54_1
55
+ - numpy=1.20.2=py38h2d18471_0
56
+ - numpy-base=1.20.2=py38hfae3a4d_0
57
+ - olefile=0.46=py_0
58
+ - openh264=2.1.0=hd408876_0
59
+ - openssl=1.1.1k=h27cfd23_0
60
+ - pcre=8.44=he6710b0_0
61
+ - pip=21.1.1=py38h06a4308_0
62
+ - pyparsing=2.4.7=pyhd3eb1b0_0
63
+ - pyqt=5.9.2=py38h05f1152_4
64
+ - python=3.8.10=hdb3f193_7
65
+ - python-dateutil=2.8.1=pyhd3eb1b0_0
66
+ - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
67
+ - qt=5.9.7=h5867ecd_1
68
+ - readline=8.1=h27cfd23_0
69
+ - setuptools=52.0.0=py38h06a4308_0
70
+ - sip=4.19.13=py38he6710b0_0
71
+ - six=1.15.0=py38h06a4308_0
72
+ - sqlite=3.35.4=hdfb4753_0
73
+ - tk=8.6.10=hbc83047_0
74
+ - torchaudio=0.8.1=py38
75
+ - torchvision=0.9.1=py38_cu111
76
+ - tornado=6.1=py38h27cfd23_0
77
+ - typing_extensions=3.7.4.3=pyha847dfd_0
78
+ - wheel=0.36.2=pyhd3eb1b0_0
79
+ - xz=5.2.5=h7b6447c_0
80
+ - zlib=1.2.11=h7b6447c_3
81
+ - zstd=1.4.9=haebb681_0
82
+ - pip:
83
+ - absl-py==0.12.0
84
+ - cachetools==4.2.2
85
+ - chardet==4.0.0
86
+ - google-auth==1.30.1
87
+ - google-auth-oauthlib==0.4.4
88
+ - grpcio==1.38.0
89
+ - idna==2.10
90
+ - markdown==3.3.4
91
+ - oauthlib==3.1.1
92
+ - pandas==1.2.4
93
+ - pillow-simd==7.0.0.post3
94
+ - protobuf==3.17.2
95
+ - pyasn1==0.4.8
96
+ - pyasn1-modules==0.2.8
97
+ - pytz==2021.1
98
+ - requests==2.25.1
99
+ - requests-oauthlib==1.3.0
100
+ - rsa==4.7.2
101
+ - tensorboard==2.5.0
102
+ - tensorboard-data-server==0.6.1
103
+ - tensorboard-plugin-wit==1.8.0
104
+ - tqdm==4.61.0
105
+ - urllib3==1.26.5
106
+ - werkzeug==2.0.1
main.py ADDED
@@ -0,0 +1,233 @@
1
+ """
2
+ file - main.py
3
+ Main script to train the aesthetic model on the AVA dataset.
4
+
5
+ Copyright (C) Yunxiao Shi 2017 - 2021
6
+ NIMA is released under the MIT license. See LICENSE for the full license text.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+
12
+ import numpy as np
13
+ import matplotlib
14
+ # matplotlib.use('Agg')
15
+ import matplotlib.pyplot as plt
16
+
17
+ import torch
18
+ import torch.autograd as autograd
19
+ import torch.optim as optim
20
+
21
+ import torchvision.transforms as transforms
22
+ import torchvision.datasets as dsets
23
+ import torchvision.models as models
24
+
25
+ from torch.utils.tensorboard import SummaryWriter
26
+
27
+ from dataset.dataset import AVADataset
28
+
29
+ from model.model import *
30
+
31
+
32
+ def main(config):
33
+
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ writer = SummaryWriter()
36
+
37
+ train_transform = transforms.Compose([
38
+ transforms.Resize(256),
39
+ transforms.RandomCrop(224),
40
+ transforms.RandomHorizontalFlip(),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
43
+ std=[0.229, 0.224, 0.225])])
44
+
45
+ val_transform = transforms.Compose([
46
+ transforms.Resize(256),
47
+ transforms.RandomCrop(224),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
50
+ std=[0.229, 0.224, 0.225])])
51
+
52
+ base_model = models.vgg16(pretrained=True)
53
+ model = NIMA(base_model)
54
+
55
+ if config.warm_start:
56
+ model.load_state_dict(torch.load(os.path.join(config.ckpt_path, 'epoch-%d.pth' % config.warm_start_epoch)))
57
+ print('Successfully loaded model epoch-%d.pth' % config.warm_start_epoch)
58
+
59
+ if config.multi_gpu:
60
+ model.features = torch.nn.DataParallel(model.features, device_ids=config.gpu_ids)
61
+ model = model.to(device)
62
+ else:
63
+ model = model.to(device)
64
+
65
+ conv_base_lr = config.conv_base_lr
66
+ dense_lr = config.dense_lr
67
+ optimizer = optim.SGD([
68
+ {'params': model.features.parameters(), 'lr': conv_base_lr},
69
+ {'params': model.classifier.parameters(), 'lr': dense_lr}],
70
+ momentum=0.9
71
+ )
72
+
73
+ param_num = 0
74
+ for param in model.parameters():
75
+ if param.requires_grad:
76
+ param_num += param.numel()
77
+ print('Trainable params: %.2f million' % (param_num / 1e6))
78
+
79
+ if config.train:
80
+ trainset = AVADataset(csv_file=config.train_csv_file, root_dir=config.img_path, transform=train_transform)
81
+ valset = AVADataset(csv_file=config.val_csv_file, root_dir=config.img_path, transform=val_transform)
82
+
83
+ train_loader = torch.utils.data.DataLoader(trainset, batch_size=config.train_batch_size,
84
+ shuffle=True, num_workers=config.num_workers)
85
+ val_loader = torch.utils.data.DataLoader(valset, batch_size=config.val_batch_size,
86
+ shuffle=False, num_workers=config.num_workers)
87
+ # for early stopping
88
+ count = 0
89
+ init_val_loss = float('inf')
90
+ train_losses = []
91
+ val_losses = []
92
+ for epoch in range(config.warm_start_epoch, config.epochs):
93
+ batch_losses = []
94
+ for i, data in enumerate(train_loader):
95
+ images = data['image'].to(device)
96
+ labels = data['annotations'].to(device).float()
97
+ outputs = model(images)
98
+ outputs = outputs.view(-1, 10, 1)
99
+
100
+ optimizer.zero_grad()
101
+
102
+ loss = emd_loss(labels, outputs)
103
+ batch_losses.append(loss.item())
104
+
105
+ loss.backward()
106
+
107
+ optimizer.step()
108
+
109
+ print('Epoch: %d/%d | Step: %d/%d | Training EMD loss: %.4f' % (epoch + 1, config.epochs, i + 1, len(trainset) // config.train_batch_size + 1, loss.item()))
110
+ writer.add_scalar('batch train loss', loss.item(), i + epoch * (len(trainset) // config.train_batch_size + 1))
111
+
112
+ avg_loss = sum(batch_losses) / (len(trainset) // config.train_batch_size + 1)
113
+ train_losses.append(avg_loss)
114
+ print('Epoch %d mean training EMD loss: %.4f' % (epoch + 1, avg_loss))
115
+
116
+ # exponential learning rate decay
117
+ if config.decay:
118
+ if (epoch + 1) % 10 == 0:
119
+ conv_base_lr = conv_base_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
120
+ dense_lr = dense_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
121
+ optimizer = optim.SGD([
122
+ {'params': model.features.parameters(), 'lr': conv_base_lr},
123
+ {'params': model.classifier.parameters(), 'lr': dense_lr}],
124
+ momentum=0.9
125
+ )
126
+
127
+ # do validation after each epoch
128
+ batch_val_losses = []
129
+ for data in val_loader:
130
+ images = data['image'].to(device)
131
+ labels = data['annotations'].to(device).float()
132
+ with torch.no_grad():
133
+ outputs = model(images)
134
+ outputs = outputs.view(-1, 10, 1)
135
+ val_loss = emd_loss(labels, outputs)
136
+ batch_val_losses.append(val_loss.item())
137
+ avg_val_loss = sum(batch_val_losses) / (len(valset) // config.val_batch_size + 1)
138
+ val_losses.append(avg_val_loss)
139
+ print('Epoch %d completed. Mean EMD loss on val set: %.4f.' % (epoch + 1, avg_val_loss))
140
+ writer.add_scalars('epoch losses', {'epoch train loss': avg_loss, 'epoch val loss': avg_val_loss}, epoch + 1)
141
+
142
+ # Use early stopping to monitor training
143
+ if avg_val_loss < init_val_loss:
144
+ init_val_loss = avg_val_loss
145
+ # save model weights if val loss decreases
146
+ print('Saving model...')
147
+ if not os.path.exists(config.ckpt_path):
148
+ os.makedirs(config.ckpt_path)
149
+ torch.save(model.state_dict(), os.path.join(config.ckpt_path, 'epoch-%d.pth' % (epoch + 1)))
150
+ print('Done.\n')
151
+ # reset count
152
+ count = 0
153
+ elif avg_val_loss >= init_val_loss:
154
+ count += 1
155
+ if count == config.early_stopping_patience:
156
+ print('Val EMD loss has not decreased in %d epochs. Training terminated.' % config.early_stopping_patience)
157
+ break
158
+
159
+ print('Training completed.')
160
+
161
+ '''
162
+ # use tensorboard to log statistics instead
163
+ if config.save_fig:
164
+ # plot train and val loss
165
+ epochs = range(1, epoch + 2)
166
+ plt.plot(epochs, train_losses, 'b-', label='train loss')
167
+ plt.plot(epochs, val_losses, 'g-', label='val loss')
168
+ plt.title('EMD loss')
169
+ plt.legend()
170
+ plt.savefig('./loss.png')
171
+ '''
172
+
173
+ if config.test:
174
+ model.eval()
175
+ # compute mean score
176
+ test_transform = val_transform
177
+ testset = AVADataset(csv_file=config.test_csv_file, root_dir=config.img_path, transform=val_transform)
178
+ test_loader = torch.utils.data.DataLoader(testset, batch_size=config.test_batch_size, shuffle=False, num_workers=config.num_workers)
179
+
180
+ mean_preds = []
181
+ std_preds = []
182
+ for data in test_loader:
183
+ image = data['image'].to(device)
184
+ output = model(image)
185
+ output = output.view(10, 1)
186
+ predicted_mean, predicted_std = 0.0, 0.0
187
+ for i, elem in enumerate(output, 1):
188
+ predicted_mean += i * elem
189
+ for j, elem in enumerate(output, 1):
190
+ predicted_std += elem * (j - predicted_mean) ** 2
191
+ predicted_std = predicted_std ** 0.5
192
+ mean_preds.append(predicted_mean)
193
+ std_preds.append(predicted_std)
194
+ # Do what you want with the predicted mean and std...
195
+
196
+
197
+ if __name__ == '__main__':
198
+
199
+ parser = argparse.ArgumentParser()
200
+
201
+ # input parameters
202
+ parser.add_argument('--img_path', type=str, default='./data/images')
203
+ parser.add_argument('--train_csv_file', type=str, default='./data/train_labels.csv')
204
+ parser.add_argument('--val_csv_file', type=str, default='./data/val_labels.csv')
205
+ parser.add_argument('--test_csv_file', type=str, default='./data/test_labels.csv')
206
+
207
+ # training parameters
208
+ parser.add_argument('--train', action='store_true')
209
+ parser.add_argument('--test', action='store_true')
210
+ parser.add_argument('--decay', action='store_true')
211
+ parser.add_argument('--conv_base_lr', type=float, default=5e-3)
212
+ parser.add_argument('--dense_lr', type=float, default=5e-4)
213
+ parser.add_argument('--lr_decay_rate', type=float, default=0.95)
214
+ parser.add_argument('--lr_decay_freq', type=int, default=10)
215
+ parser.add_argument('--train_batch_size', type=int, default=128)
216
+ parser.add_argument('--val_batch_size', type=int, default=128)
217
+ parser.add_argument('--test_batch_size', type=int, default=1)
218
+ parser.add_argument('--num_workers', type=int, default=2)
219
+ parser.add_argument('--epochs', type=int, default=100)
220
+
221
+ # misc
222
+ parser.add_argument('--ckpt_path', type=str, default='./ckpts')
223
+ parser.add_argument('--multi_gpu', action='store_true')
224
+ parser.add_argument('--gpu_ids', type=int, nargs='+', default=None)
225
+ parser.add_argument('--warm_start', action='store_true')
226
+ parser.add_argument('--warm_start_epoch', type=int, default=0)
227
+ parser.add_argument('--early_stopping_patience', type=int, default=10)
228
+ parser.add_argument('--save_fig', action='store_true')
229
+
230
+ config = parser.parse_args()
231
+
232
+ main(config)
233
+
model/__init__.py ADDED
File without changes
model/model.py ADDED
@@ -0,0 +1,61 @@
1
+ """
2
+ file - model.py
3
+ Implements the aesthetic model and EMD loss used in the paper.
4
+
5
+ Copyright (C) Yunxiao Shi 2017 - 2021
6
+ NIMA is released under the MIT license. See LICENSE for the full license text.
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ class NIMA(nn.Module):
13
+
14
+ """Neural IMage Assessment model by Google"""
15
+ def __init__(self, base_model, num_classes=10):
16
+ super(NIMA, self).__init__()
17
+ self.features = base_model.features
18
+ self.classifier = nn.Sequential(
19
+ nn.Dropout(p=0.75),
20
+ nn.Linear(in_features=25088, out_features=num_classes),
21
+ nn.Softmax(dim=1))
22
+
23
+ def forward(self, x):
24
+ out = self.features(x)
25
+ out = out.view(out.size(0), -1)
26
+ out = self.classifier(out)
27
+ return out
28
+
29
+
30
+ def single_emd_loss(p, q, r=2):
31
+ """
32
+ Earth Mover's Distance of one sample
33
+
34
+ Args:
35
+ p: true distribution of shape num_classes × 1
36
+ q: estimated distribution of shape num_classes × 1
37
+ r: norm parameter
38
+ """
39
+ assert p.shape == q.shape, "Length of the two distributions must be the same"
40
+ length = p.shape[0]
41
+ emd_loss = 0.0
42
+ for i in range(1, length + 1):
43
+ emd_loss += torch.abs(sum(p[:i] - q[:i])) ** r
44
+ return (emd_loss / length) ** (1. / r)
45
+
46
+
47
+ def emd_loss(p, q, r=2):
48
+ """
49
+ Earth Mover's Distance on a batch
50
+
51
+ Args:
52
+ p: true distribution of shape mini_batch_size × num_classes × 1
53
+ q: estimated distribution of shape mini_batch_size × num_classes × 1
54
+ r: norm parameter
55
+ """
56
+ assert p.shape == q.shape, "Shape of the two distribution batches must be the same."
57
+ mini_batch_size = p.shape[0]
58
+ loss_vector = []
59
+ for i in range(mini_batch_size):
60
+ loss_vector.append(single_emd_loss(p[i], q[i], r=r))
61
+ return sum(loss_vector) / mini_batch_size
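A quick sanity check of `emd_loss` (a hypothetical snippet, not part of the commit): it expects batches of normalized distributions shaped `(batch, num_classes, 1)` and returns zero for identical inputs.

```python
import torch

from model.model import emd_loss

# Four random 10-bin distributions, shaped (4, 10, 1) as the model outputs them.
p = torch.softmax(torch.randn(4, 10), dim=1).unsqueeze(-1)
q = torch.softmax(torch.randn(4, 10), dim=1).unsqueeze(-1)

print(emd_loss(p, q))  # small positive value
print(emd_loss(p, p))  # zero
```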
predictions/pred.txt ADDED
The diff for this file is too large to render. See raw diff
 
snapshots/badpred.png ADDED

Git LFS Details

  • SHA256: a2cc2b2886ae204106ac590c798b02cdc69be03716258f60adc45e9b5128acc6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.54 MB
snapshots/contrast.png ADDED

Git LFS Details

  • SHA256: b09aaf3e92865133142753c85a0f09e2a8b88f0e168cd91e5f8d6431d48bfb4d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
snapshots/goodpred.png ADDED

Git LFS Details

  • SHA256: 5c04192f9ce670487bd4d9095dc428fee0afe5ef8b0c1d458d72a23002d71d4c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.35 MB
snapshots/[email protected] ADDED
test.py ADDED
@@ -0,0 +1,85 @@
1
+ """
2
+ file - test.py
3
+ Simple quick script to evaluate model on test images.
4
+
5
+ Copyright (C) Yunxiao Shi 2017 - 2021
6
+ NIMA is released under the MIT license. See LICENSE for the full license text.
7
+ """
8
+
9
+ import argparse
10
+ import os
11
+ import numpy as np
12
+ import matplotlib.pyplot as plt
13
+ from PIL import Image
14
+ import pandas as pd
15
+ from tqdm import tqdm
16
+ import torch
17
+ import torchvision.models as models
18
+ import torchvision.transforms as transforms
19
+
20
+ from model.model import *
21
+
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument('--model', type=str, help='path to pretrained model')
24
+ parser.add_argument('--test_csv', type=str, help='test csv file')
25
+ parser.add_argument('--test_images', type=str, help='path to folder containing images')
26
+ parser.add_argument('--workers', type=int, default=4, help='number of workers')
27
+ parser.add_argument('--predictions', type=str, help='output file to store predictions')
28
+ args = parser.parse_args()
29
+
30
+ base_model = models.vgg16(pretrained=True)
31
+ model = NIMA(base_model)
32
+
33
+ try:
34
+ model.load_state_dict(torch.load(args.model))
35
+ print('successfully loaded model')
36
+ except:
37
+ raise
38
+
39
+ seed = 42
40
+ torch.manual_seed(seed)
41
+
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+ model = model.to(device)
45
+
46
+ model.eval()
47
+
48
+ test_transform = transforms.Compose([
49
+ transforms.Resize(256),
50
+ transforms.RandomCrop(224),
51
+ transforms.ToTensor(),
52
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
53
+ std=[0.229, 0.224, 0.225])
54
+ ])
55
+
56
+ test_df = pd.read_csv(args.test_csv, header=None)
57
+ test_imgs = test_df[0]
58
+ pbar = tqdm(total=len(test_imgs))
59
+
60
+ mean, std = 0.0, 0.0
61
+ for i, img in enumerate(test_imgs):
62
+ im = Image.open(os.path.join(args.test_images, str(img) + '.jpg'))
63
+ im = im.convert('RGB')
64
+ imt = test_transform(im)
65
+ imt = imt.unsqueeze(dim=0)
66
+ imt = imt.to(device)
67
+ with torch.no_grad():
68
+ out = model(imt)
69
+ out = out.view(10, 1)
70
+ for j, e in enumerate(out, 1):
71
+ mean += j * e
72
+ for k, e in enumerate(out, 1):
73
+ std += e * (k - mean) ** 2
74
+ std = std ** 0.5
75
+ gt = test_df[test_df[0] == img].to_numpy()[:, 1:].reshape(10, 1)
76
+ gt_mean = 0.0
77
+ for l, e in enumerate(gt, 1):
78
+ gt_mean += l * e
79
+ # print(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f' % (mean, std, gt_mean))
80
+ if not os.path.exists(args.predictions):
81
+ os.makedirs(args.predictions)
82
+ with open(os.path.join(args.predictions, 'pred.txt'), 'a') as f:
83
+ f.write(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f\n' % (mean, std, gt_mean))
84
+ mean, std = 0.0, 0.0
85
+ pbar.update()
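For reference, the per-image mean/std loop above is equivalent to this vectorized form (a sketch under the same assumptions, not part of the commit), where the input is the model's `(10, 1)` score distribution over ratings 1..10:

```python
import torch

def dist_mean_std(out: torch.Tensor):
    """out: (10, 1) softmax scores over ratings 1..10, as produced by the model."""
    scores = torch.arange(1, 11, dtype=out.dtype, device=out.device).view(10, 1)
    mean = (scores * out).sum()                         # sum_i i * p_i
    std = (((scores - mean) ** 2) * out).sum().sqrt()   # sqrt(sum_i p_i * (i - mean)^2)
    return mean, std
```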