Spaces:
Runtime error
Runtime error
""" | |
file - test.py | |
Simple quick script to evaluate model on test images. | |
Copyright (C) Yunxiao Shi 2017 - 2021 | |
NIMA is released under the MIT license. See LICENSE for the full license text. | |
""" | |
import argparse | |
import os | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import pandas as pd | |
from tqdm import tqdm | |
import torch | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from model.model import * | |
# Command-line interface of the evaluation script.
# The four path arguments are mandatory: the rest of the script passes them
# straight to torch.load / pd.read_csv / os.path.join, so a missing value would
# otherwise surface later as an unrelated error on None.
parser = argparse.ArgumentParser(
    description='Simple quick script to evaluate model on test images.')
parser.add_argument('--model', type=str, required=True,
                    help='path to pretrained model')
parser.add_argument('--test_csv', type=str, required=True,
                    help='test csv file')
parser.add_argument('--test_images', type=str, required=True,
                    help='path to folder containing images')
parser.add_argument('--workers', type=int, default=4,
                    help='number of workers')
parser.add_argument('--predictions', type=str, required=True,
                    help='output file to store predictions')
args = parser.parse_args()
# Pick the device first so the checkpoint can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the network: NIMA wraps a VGG16 backbone.
# NOTE(review): the ImageNet weights requested here are presumably overwritten
# by the checkpoint loaded below -- confirm against NIMA's state_dict before
# switching the download off.
base_model = models.vgg16(pretrained=True)
model = NIMA(base_model)

# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine. Any loading error propagates unchanged and aborts the script.
model.load_state_dict(torch.load(args.model, map_location=device))
print('successfully loaded model')

# Fixed seed so that repeated runs draw the same random numbers from torch.
seed = 42
torch.manual_seed(seed)

model = model.to(device)
model.eval()
# Preprocessing applied to every test image before it is fed to the model.
test_transform = transforms.Compose([
    # Resize replaces transforms.Scale, which was deprecated and later removed
    # from torchvision; both resize the shorter image side to 256 pixels.
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# Evaluate every image listed in the test CSV and append one line per image
# ("<id> mean: ... | std: ... | GT: ...") to <predictions>/pred.txt.
test_df = pd.read_csv(args.test_csv, header=None)
test_imgs = test_df[0]

# Column 0 holds the image id; the remaining columns are treated as the
# ground-truth weights of the ten score buckets 1..10 (a column count other
# than ten raises a shape error here). Computing all ground-truth means in one
# pass avoids re-filtering the whole DataFrame for every image and also copes
# with an image id that appears on more than one row.
score_values = np.arange(1, 11, dtype=np.float64)
gt_means = test_df.iloc[:, 1:].to_numpy(dtype=np.float64) @ score_values

# Create the output directory once instead of checking on every iteration.
os.makedirs(args.predictions, exist_ok=True)
pred_path = os.path.join(args.predictions, 'pred.txt')

# The context managers guarantee that the predictions file is flushed and the
# progress bar is closed even if an image fails to load.
with open(pred_path, 'a') as f, tqdm(total=len(test_imgs)) as pbar:
    for img, gt_mean in zip(test_imgs, gt_means):
        img_path = os.path.join(args.test_images, str(img) + '.jpg')
        # convert() returns a new image, so the file handle can be released
        # right away.
        with Image.open(img_path) as raw:
            im = raw.convert('RGB')

        imt = test_transform(im).unsqueeze(dim=0).to(device)
        with torch.no_grad():
            out = model(imt)

        # The model output is read as ten weights, one per score bucket.
        probs = out.view(10).cpu().numpy().astype(np.float64)
        mean = float(probs @ score_values)
        std = float((probs @ ((score_values - mean) ** 2)) ** 0.5)

        f.write(str(img) + ' mean: %.3f | std: %.3f | GT: %.3f\n' % (mean, std, gt_mean))
        pbar.update()