Commit 1e8d169
Parent(s):
20a3309
Add files from Neural-IMage-Assessment repo
- .gitattributes +3 -0
- .gitignore +1 -0
- LICENSE +8 -0
- README.md +81 -12
- dataset/__init__.py +0 -0
- dataset/dataset.py +66 -0
- env.yml +106 -0
- main.py +233 -0
- model/__init__.py +0 -0
- model/model.py +61 -0
- predictions/pred.txt +0 -0
- snapshots/badpred.png +3 -0
- snapshots/contrast.png +3 -0
- snapshots/goodpred.png +3 -0
- snapshots/[email protected] +0 -0
- test.py +85 -0
.gitattributes
CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+snapshots/badpred.png filter=lfs diff=lfs merge=lfs -text
+snapshots/contrast.png filter=lfs diff=lfs merge=lfs -text
+snapshots/goodpred.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+./data/
LICENSE
ADDED
@@ -0,0 +1,8 @@
The MIT License (MIT) Copyright (c) 2020 Yunxiao Shi

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

© 2020 GitHub, Inc.
README.md
CHANGED
@@ -1,12 +1,81 @@
## NIMA: Neural IMage Assessment

[](https://www.python.org/)
[](./LICENSE)

This is a PyTorch implementation of the paper [NIMA: Neural IMage Assessment](https://arxiv.org/abs/1709.05424) (accepted at [IEEE Transactions on Image Processing](https://ieeexplore.ieee.org/document/8352823)) by Hossein Talebi and Peyman Milanfar. You can learn more from [this post on the Google Research Blog](https://research.googleblog.com/2017/12/introducing-nima-neural-image-assessment.html).

## Implementation Details

+ The model was trained on the [AVA (Aesthetic Visual Analysis) dataset](http://refbase.cvc.uab.es/files/MMP2012a.pdf), which contains 255,500+ images. You can get it from [here](https://github.com/mtobeiyf/ava_downloader). ~~**Note: there may be some corrupted images in the dataset; remove them before you start training.**~~ Use the provided CSVs, which have already done this for you.

+ The dataset is split into 229,981 images for training, 12,691 for validation, and 12,818 for testing.

+ An ImageNet-pretrained VGG-16 is used as the base network. It should be easy to plug in the other two options (MobileNet and Inception-v2).

+ The learning rate setting differs from the original paper: the model would not converge with the original parameters. Not much hyper-parameter tuning was done either, so you can probably get better results. All other settings directly mirror the paper.

## Requirements

Code is written using [PyTorch](https://pytorch.org/get-started/locally/) 1.8.1 with [CUDA](https://developer.nvidia.com/cuda-toolkit) 11.1. You can recreate the environment I used with [conda](https://docs.conda.io/en/latest/miniconda.html) by

```
conda env create -f env.yml
```

to install the dependencies.
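If you use the environment as defined in ```env.yml``` (named ```torch1.8.1```), activate it before running any of the scripts:

```
conda activate torch1.8.1
```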
## Usage

To start training on the AVA dataset, first download the dataset from the link above and decompress it, which should create a directory named ```images/```. Then download the curated annotation CSVs below, which already split the dataset (you can of course create your own split), and run

```
python main.py --img_path /path/to/images/ --train --train_csv_file /path/to/train_labels.csv --val_csv_file /path/to/val_labels.csv --conv_base_lr 5e-4 --dense_lr 5e-3 --decay --ckpt_path /path/to/ckpts --epochs 100 --early_stopping_patience 10
```

For inference, run

```
python -W ignore test.py --model /path/to/your_model --test_csv /path/to/test_labels.csv --test_images /path/to/images --predictions /path/to/save/predictions
```

See ```predictions/``` for dumped predictions as an example.
## Training Statistics

Training is done with early stopping. Here I set ```early_stopping_patience=10```.
<p align="center">
<img src="./snapshots/[email protected]">
</p>
## Pretrained Model

The pretrained model reaches ~0.069 EMD loss on validation. It has not fully converged yet (training was constrained by compute resources). To continue training, download the pretrained weights and add ```--warm_start --warm_start_epoch 34``` to your args.

[Google Drive](https://drive.google.com/file/d/1w9Ig_d6yZqUZSR63kPjZLrEjJ1n845B_/view?usp=sharing)
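For example, a warm-start run might combine the training command from the Usage section with these flags (paths are placeholders):

```
python main.py --img_path /path/to/images/ --train --train_csv_file /path/to/train_labels.csv --val_csv_file /path/to/val_labels.csv --conv_base_lr 5e-4 --dense_lr 5e-3 --decay --ckpt_path /path/to/ckpts --epochs 100 --early_stopping_patience 10 --warm_start --warm_start_epoch 34
```

Note that ```main.py``` will look for the downloaded weights at ```<ckpt_path>/epoch-34.pth```.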
## Annotation CSV Files

[Train](https://drive.google.com/file/d/1IBXPXPkCiTz04wWcoReJv4Nk06VsjSkI/view?usp=sharing) [Validation](https://drive.google.com/file/d/1tJfO1zFBoQYzd8kUo5PKeHTcdzBL7115/view?usp=sharing) [Test](https://drive.google.com/file/d/105UGnkglpKuusPhJaPnFSa2JlQV3du9O/view?usp=sharing)
## Example Results

+ First, some good predictions from the test set. Each image title starts with the ground-truth rating, followed by the predicted mean and standard deviation in parentheses.

<p align="center">
<img src="./snapshots/goodpred.png">
</p>

+ And some failure cases; the model seems to fail most often on images with very low or very high aesthetic ratings.

<p align="center">
<img src="./snapshots/badpred.png">
</p>

+ The predicted aesthetic ratings from training on the AVA dataset are sensitive to contrast adjustments and prefer images with higher contrast. In the figure below, the top row is the reference image with contrast ```c=1.0```, while the bottom images are enhanced with contrast ```[0.25, 0.75, 1.25, 1.75]```. Contrast adjustment is done using ```ImageEnhance.Contrast``` from ```PIL``` (in this case [pillow-simd](https://github.com/uploadcare/pillow-simd)).

<p align="center">
<img src="./snapshots/contrast.png">
</p>
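A minimal sketch of such a contrast sweep with PIL; the input path and output filenames here are only illustrative:

```python
from PIL import Image, ImageEnhance

im = Image.open('reference.jpg')  # illustrative path
for c in [0.25, 0.75, 1.25, 1.75]:
    # enhance(1.0) returns the original image; factors below/above 1.0
    # decrease/increase contrast.
    ImageEnhance.Contrast(im).enhance(c).save('contrast_%.2f.jpg' % c)
```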
## License

MIT
dataset/__init__.py
ADDED
File without changes
dataset/dataset.py
ADDED
@@ -0,0 +1,66 @@
```python
"""
file - dataset.py
Customized dataset class to loop through the AVA dataset and apply needed image augmentations for training.

Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
"""

import os

import pandas as pd
from PIL import Image

import torch
from torch.utils import data
import torchvision.transforms as transforms


class AVADataset(data.Dataset):
    """AVA dataset

    Args:
        csv_file: an 11-column csv file; column 1 contains the names of the image files, columns 2-11 contain the empirical distributions of ratings
        root_dir: directory containing the images
        transform: preprocessing and augmentation of the training images
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, str(self.annotations.iloc[idx, 0]) + '.jpg')
        image = Image.open(img_name).convert('RGB')
        annotations = self.annotations.iloc[idx, 1:].to_numpy()
        annotations = annotations.astype('float').reshape(-1, 1)
        sample = {'img_id': img_name, 'image': image, 'annotations': annotations}

        if self.transform:
            sample['image'] = self.transform(sample['image'])

        return sample


if __name__ == '__main__':

    # sanity check
    root = './data/images'
    csv_file = './data/train_labels.csv'
    train_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
    dset = AVADataset(csv_file=csv_file, root_dir=root, transform=train_transform)
    train_loader = data.DataLoader(dset, batch_size=4, shuffle=True, num_workers=4)
    for i, data in enumerate(train_loader):
        images = data['image']
        print(images.size())
        labels = data['annotations']
        print(labels.size())
```
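The ten annotation values per image form an empirical distribution over the rating bins 1-10; the scalar rating quoted in the README (and the GT value written by ```test.py```) is essentially its expectation. A small standalone sketch with made-up numbers:

```python
import numpy as np

# Made-up empirical distribution over the 10 rating bins (sums to 1).
dist = np.array([0.00, 0.01, 0.05, 0.17, 0.38, 0.24, 0.10, 0.04, 0.01, 0.00])
scores = np.arange(1, 11)

mean = (scores * dist).sum()                          # expected rating
std = np.sqrt((((scores - mean) ** 2) * dist).sum())  # spread of the ratings
print('mean %.3f | std %.3f' % (mean, std))
```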
env.yml
ADDED
@@ -0,0 +1,106 @@
```yaml
name: torch1.8.1
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - blas=1.0=mkl
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2021.5.25=h06a4308_1
  - certifi=2021.5.30=py38h06a4308_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - cycler=0.10.0=py38_0
  - dbus=1.13.18=hb2f20db_0
  - expat=2.4.1=h2531618_2
  - ffmpeg=4.3=hf484d3e_0
  - fontconfig=2.13.1=h6c09931_0
  - freetype=2.10.4=h5ab3b9f_0
  - glib=2.68.2=h36276a3_0
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - gst-plugins-base=1.14.0=h8213a91_2
  - gstreamer=1.14.0=h28cd5cc_2
  - icu=58.2=he6710b0_3
  - intel-openmp=2021.2.0=h06a4308_610
  - jpeg=9b=h024ee3a_2
  - kiwisolver=1.3.1=py38h2531618_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libiconv=1.15=h63c8f33_5
  - libidn2=2.3.1=h27cfd23_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.0.3=h1bed415_2
  - libuv=1.40.0=h7b6447c_0
  - libwebp-base=1.2.0=h27cfd23_0
  - libxcb=1.14=h7b6447c_0
  - libxml2=2.9.10=hb55368b_3
  - lz4-c=1.9.3=h2531618_0
  - matplotlib=3.3.4=py38h06a4308_0
  - matplotlib-base=3.3.4=py38h62a2d02_0
  - mkl=2021.2.0=h06a4308_296
  - mkl-service=2.3.0=py38h27cfd23_1
  - mkl_fft=1.3.0=py38h42c9631_2
  - mkl_random=1.2.1=py38ha9443f7_2
  - ncurses=6.2=he6710b0_1
  - nettle=3.7.2=hbbd107a_1
  - ninja=1.10.2=hff7bd54_1
  - numpy=1.20.2=py38h2d18471_0
  - numpy-base=1.20.2=py38hfae3a4d_0
  - olefile=0.46=py_0
  - openh264=2.1.0=hd408876_0
  - openssl=1.1.1k=h27cfd23_0
  - pcre=8.44=he6710b0_0
  - pip=21.1.1=py38h06a4308_0
  - pyparsing=2.4.7=pyhd3eb1b0_0
  - pyqt=5.9.2=py38h05f1152_4
  - python=3.8.10=hdb3f193_7
  - python-dateutil=2.8.1=pyhd3eb1b0_0
  - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
  - qt=5.9.7=h5867ecd_1
  - readline=8.1=h27cfd23_0
  - setuptools=52.0.0=py38h06a4308_0
  - sip=4.19.13=py38he6710b0_0
  - six=1.15.0=py38h06a4308_0
  - sqlite=3.35.4=hdfb4753_0
  - tk=8.6.10=hbc83047_0
  - torchaudio=0.8.1=py38
  - torchvision=0.9.1=py38_cu111
  - tornado=6.1=py38h27cfd23_0
  - typing_extensions=3.7.4.3=pyha847dfd_0
  - wheel=0.36.2=pyhd3eb1b0_0
  - xz=5.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0
  - pip:
    - absl-py==0.12.0
    - cachetools==4.2.2
    - chardet==4.0.0
    - google-auth==1.30.1
    - google-auth-oauthlib==0.4.4
    - grpcio==1.38.0
    - idna==2.10
    - markdown==3.3.4
    - oauthlib==3.1.1
    - pandas==1.2.4
    - pillow-simd==7.0.0.post3
    - protobuf==3.17.2
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pytz==2021.1
    - requests==2.25.1
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - tensorboard==2.5.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tqdm==4.61.0
    - urllib3==1.26.5
    - werkzeug==2.0.1
```
main.py
ADDED
@@ -0,0 +1,233 @@
```python
"""
file - main.py
Main script to train the aesthetic model on the AVA dataset.

Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
"""

import argparse
import os

import numpy as np
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.autograd as autograd
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.datasets as dsets
import torchvision.models as models

from torch.utils.tensorboard import SummaryWriter

from dataset.dataset import AVADataset

from model.model import *


def main(config):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter()

    train_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    val_transform = transforms.Compose([
        transforms.Scale(256),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])])

    base_model = models.vgg16(pretrained=True)
    model = NIMA(base_model)

    if config.warm_start:
        model.load_state_dict(torch.load(os.path.join(config.ckpt_path, 'epoch-%d.pth' % config.warm_start_epoch)))
        print('Successfully loaded model epoch-%d.pth' % config.warm_start_epoch)

    if config.multi_gpu:
        model.features = torch.nn.DataParallel(model.features, device_ids=config.gpu_ids)
        model = model.to(device)
    else:
        model = model.to(device)

    conv_base_lr = config.conv_base_lr
    dense_lr = config.dense_lr
    optimizer = optim.SGD([
        {'params': model.features.parameters(), 'lr': conv_base_lr},
        {'params': model.classifier.parameters(), 'lr': dense_lr}],
        momentum=0.9
        )

    param_num = 0
    for param in model.parameters():
        if param.requires_grad:
            param_num += param.numel()
    print('Trainable params: %.2f million' % (param_num / 1e6))

    if config.train:
        trainset = AVADataset(csv_file=config.train_csv_file, root_dir=config.img_path, transform=train_transform)
        valset = AVADataset(csv_file=config.val_csv_file, root_dir=config.img_path, transform=val_transform)

        train_loader = torch.utils.data.DataLoader(trainset, batch_size=config.train_batch_size,
                                                   shuffle=True, num_workers=config.num_workers)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=config.val_batch_size,
                                                 shuffle=False, num_workers=config.num_workers)
        # for early stopping
        count = 0
        init_val_loss = float('inf')
        train_losses = []
        val_losses = []
        for epoch in range(config.warm_start_epoch, config.epochs):
            batch_losses = []
            for i, data in enumerate(train_loader):
                images = data['image'].to(device)
                labels = data['annotations'].to(device).float()
                outputs = model(images)
                outputs = outputs.view(-1, 10, 1)

                optimizer.zero_grad()

                loss = emd_loss(labels, outputs)
                batch_losses.append(loss.item())

                loss.backward()

                optimizer.step()

                print('Epoch: %d/%d | Step: %d/%d | Training EMD loss: %.4f' % (epoch + 1, config.epochs, i + 1, len(trainset) // config.train_batch_size + 1, loss.item()))
                writer.add_scalar('batch train loss', loss.item(), i + epoch * (len(trainset) // config.train_batch_size + 1))

            avg_loss = sum(batch_losses) / (len(trainset) // config.train_batch_size + 1)
            train_losses.append(avg_loss)
            print('Epoch %d mean training EMD loss: %.4f' % (epoch + 1, avg_loss))

            # exponential learning rate decay
            if config.decay:
                if (epoch + 1) % 10 == 0:
                    conv_base_lr = conv_base_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
                    dense_lr = dense_lr * config.lr_decay_rate ** ((epoch + 1) / config.lr_decay_freq)
                    optimizer = optim.SGD([
                        {'params': model.features.parameters(), 'lr': conv_base_lr},
                        {'params': model.classifier.parameters(), 'lr': dense_lr}],
                        momentum=0.9
                        )

            # do validation after each epoch
            batch_val_losses = []
            for data in val_loader:
                images = data['image'].to(device)
                labels = data['annotations'].to(device).float()
                with torch.no_grad():
                    outputs = model(images)
                outputs = outputs.view(-1, 10, 1)
                val_loss = emd_loss(labels, outputs)
                batch_val_losses.append(val_loss.item())
            avg_val_loss = sum(batch_val_losses) / (len(valset) // config.val_batch_size + 1)
            val_losses.append(avg_val_loss)
            print('Epoch %d completed. Mean EMD loss on val set: %.4f.' % (epoch + 1, avg_val_loss))
            writer.add_scalars('epoch losses', {'epoch train loss': avg_loss, 'epoch val loss': avg_val_loss}, epoch + 1)

            # Use early stopping to monitor training
            if avg_val_loss < init_val_loss:
                init_val_loss = avg_val_loss
                # save model weights if val loss decreases
                print('Saving model...')
                if not os.path.exists(config.ckpt_path):
                    os.makedirs(config.ckpt_path)
                torch.save(model.state_dict(), os.path.join(config.ckpt_path, 'epoch-%d.pth' % (epoch + 1)))
                print('Done.\n')
                # reset count
                count = 0
            elif avg_val_loss >= init_val_loss:
                count += 1
                if count == config.early_stopping_patience:
                    print('Val EMD loss has not decreased in %d epochs. Training terminated.' % config.early_stopping_patience)
                    break

        print('Training completed.')

        '''
        # use tensorboard to log statistics instead
        if config.save_fig:
            # plot train and val loss
            epochs = range(1, epoch + 2)
            plt.plot(epochs, train_losses, 'b-', label='train loss')
            plt.plot(epochs, val_losses, 'g-', label='val loss')
            plt.title('EMD loss')
            plt.legend()
            plt.savefig('./loss.png')
        '''

    if config.test:
        model.eval()
        # compute mean score
        test_transform = val_transform
        testset = AVADataset(csv_file=config.test_csv_file, root_dir=config.img_path, transform=val_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=config.test_batch_size, shuffle=False, num_workers=config.num_workers)

        mean_preds = []
        std_preds = []
        for data in test_loader:
            image = data['image'].to(device)
            output = model(image)
            output = output.view(10, 1)
            predicted_mean, predicted_std = 0.0, 0.0
            for i, elem in enumerate(output, 1):
                predicted_mean += i * elem
            for j, elem in enumerate(output, 1):
                predicted_std += elem * (j - predicted_mean) ** 2
            predicted_std = predicted_std ** 0.5
            mean_preds.append(predicted_mean)
            std_preds.append(predicted_std)
        # Do what you want with the predicted means and stds...


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # input parameters
    parser.add_argument('--img_path', type=str, default='./data/images')
    parser.add_argument('--train_csv_file', type=str, default='./data/train_labels.csv')
    parser.add_argument('--val_csv_file', type=str, default='./data/val_labels.csv')
    parser.add_argument('--test_csv_file', type=str, default='./data/test_labels.csv')

    # training parameters
    parser.add_argument('--train', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--decay', action='store_true')
    parser.add_argument('--conv_base_lr', type=float, default=5e-3)
    parser.add_argument('--dense_lr', type=float, default=5e-4)
    parser.add_argument('--lr_decay_rate', type=float, default=0.95)
    parser.add_argument('--lr_decay_freq', type=int, default=10)
    parser.add_argument('--train_batch_size', type=int, default=128)
    parser.add_argument('--val_batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--epochs', type=int, default=100)

    # misc
    parser.add_argument('--ckpt_path', type=str, default='./ckpts')
    parser.add_argument('--multi_gpu', action='store_true')
    parser.add_argument('--gpu_ids', type=list, default=None)
    parser.add_argument('--warm_start', action='store_true')
    parser.add_argument('--warm_start_epoch', type=int, default=0)
    parser.add_argument('--early_stopping_patience', type=int, default=10)
    parser.add_argument('--save_fig', action='store_true')

    config = parser.parse_args()

    main(config)
```
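For reference, the ```--decay``` branch above multiplies both learning rates by ```lr_decay_rate ** ((epoch + 1) / lr_decay_freq)``` every ten epochs. A standalone sketch that replays this schedule with the argparse defaults (5e-3, 0.95, 10), purely for illustration:

```python
# Replays the conv_base_lr update performed in the --decay branch of main.py.
conv_base_lr = 5e-3
lr_decay_rate = 0.95
lr_decay_freq = 10

for epoch in range(100):
    if (epoch + 1) % 10 == 0:
        conv_base_lr *= lr_decay_rate ** ((epoch + 1) / lr_decay_freq)
        print('after epoch %3d: conv_base_lr = %.6f' % (epoch + 1, conv_base_lr))
```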
model/__init__.py
ADDED
File without changes
model/model.py
ADDED
@@ -0,0 +1,61 @@
```python
"""
file - model.py
Implements the aesthetic model and the EMD loss used in the paper.

Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
"""

import torch
import torch.nn as nn


class NIMA(nn.Module):

    """Neural IMage Assessment model by Google"""
    def __init__(self, base_model, num_classes=10):
        super(NIMA, self).__init__()
        self.features = base_model.features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.75),
            nn.Linear(in_features=25088, out_features=num_classes),
            nn.Softmax())

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out


def single_emd_loss(p, q, r=2):
    """
    Earth Mover's Distance of one sample

    Args:
        p: true distribution of shape num_classes × 1
        q: estimated distribution of shape num_classes × 1
        r: norm parameter
    """
    assert p.shape == q.shape, "Length of the two distributions must be the same"
    length = p.shape[0]
    emd_loss = 0.0
    for i in range(1, length + 1):
        emd_loss += torch.abs(sum(p[:i] - q[:i])) ** r
    return (emd_loss / length) ** (1. / r)


def emd_loss(p, q, r=2):
    """
    Earth Mover's Distance on a batch

    Args:
        p: true distribution of shape mini_batch_size × num_classes × 1
        q: estimated distribution of shape mini_batch_size × num_classes × 1
        r: norm parameter
    """
    assert p.shape == q.shape, "Shape of the two distribution batches must be the same."
    mini_batch_size = p.shape[0]
    loss_vector = []
    for i in range(mini_batch_size):
        loss_vector.append(single_emd_loss(p[i], q[i], r=r))
    return sum(loss_vector) / mini_batch_size
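A minimal usage sketch (not part of the repo): build the model the same way ```main.py``` does, run a dummy batch through it, and score it with ```emd_loss```. Shapes follow the docstrings above; the random inputs are placeholders.

```python
import torch
import torchvision.models as models

from model.model import NIMA, emd_loss

# VGG-16 backbone with a 10-way softmax head, as in main.py.
base_model = models.vgg16(pretrained=True)   # downloads ImageNet weights on first use
model = NIMA(base_model)
model.eval()

# Dummy batch of two 224x224 RGB images and a dummy target distribution.
images = torch.randn(2, 3, 224, 224)
target = torch.softmax(torch.randn(2, 10, 1), dim=1)

with torch.no_grad():
    out = model(images).view(-1, 10, 1)      # reshape as main.py does before the loss
    loss = emd_loss(target, out)

# Mean/std of the predicted rating distribution for the first image,
# the same quantities the test scripts accumulate in loops.
scores = torch.arange(1, 11, dtype=torch.float32).view(10, 1)
p = out[0]
mean = (scores * p).sum()
std = torch.sqrt(((scores - mean) ** 2 * p).sum())
print('EMD %.4f | mean %.3f | std %.3f' % (loss.item(), mean.item(), std.item()))
```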
predictions/pred.txt
ADDED
The diff for this file is too large to render.
snapshots/badpred.png
ADDED
(binary image, stored with Git LFS)
snapshots/contrast.png
ADDED
(binary image, stored with Git LFS)
snapshots/goodpred.png
ADDED
(binary image, stored with Git LFS)
snapshots/[email protected]
ADDED
(binary image)
test.py
ADDED
@@ -0,0 +1,85 @@
```python
"""
file - test.py
Simple quick script to evaluate the model on test images.

Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
"""

import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from tqdm import tqdm
import torch
import torchvision.models as models
import torchvision.transforms as transforms

from model.model import *

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, help='path to pretrained model')
parser.add_argument('--test_csv', type=str, help='test csv file')
parser.add_argument('--test_images', type=str, help='path to folder containing images')
parser.add_argument('--workers', type=int, default=4, help='number of workers')
parser.add_argument('--predictions', type=str, help='output file to store predictions')
args = parser.parse_args()

base_model = models.vgg16(pretrained=True)
model = NIMA(base_model)

try:
    model.load_state_dict(torch.load(args.model))
    print('successfully loaded model')
except:
    raise

seed = 42
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)

model.eval()

test_transform = transforms.Compose([
    transforms.Scale(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
    ])

test_df = pd.read_csv(args.test_csv, header=None)
test_imgs = test_df[0]
pbar = tqdm(total=len(test_imgs))

mean, std = 0.0, 0.0
for i, img in enumerate(test_imgs):
    im = Image.open(os.path.join(args.test_images, str(img) + '.jpg'))
    im = im.convert('RGB')
    imt = test_transform(im)
    imt = imt.unsqueeze(dim=0)
    imt = imt.to(device)
    with torch.no_grad():
        out = model(imt)
    out = out.view(10, 1)
    for j, e in enumerate(out, 1):
        mean += j * e
    for k, e in enumerate(out, 1):
        std += e * (k - mean) ** 2
    std = std ** 0.5
    gt = test_df[test_df[0] == img].to_numpy()[:, 1:].reshape(10, 1)
    gt_mean = 0.0
    for l, e in enumerate(gt, 1):
        gt_mean += l * e
    # print(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f' % (mean, std, gt_mean))
    if not os.path.exists(args.predictions):
        os.makedirs(args.predictions)
    with open(os.path.join(args.predictions, 'pred.txt'), 'a') as f:
        f.write(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f\n' % (mean, std, gt_mean))
    mean, std = 0.0, 0.0
    pbar.update()
```