Spaces:
Runtime error
Runtime error
""" | |
file - model.py
Implements the aesthetic model and the EMD loss used in the paper.
Copyright (C) Yunxiao Shi 2017 - 2021
NIMA is released under the MIT license. See LICENSE for the full license text.
""" | |
import torch | |
import torch.nn as nn | |
class NIMA(nn.Module):
    """Neural IMage Assessment model by Google.

    Wraps the ``features`` sub-module of a backbone network and replaces its
    head with dropout, a single linear layer and a softmax, so the output of
    each sample is a probability distribution over ``num_classes`` score
    buckets.

    Args:
        base_model: object exposing a ``features`` module that maps the input
            batch to a feature tensor whose first dimension is the batch.
        num_classes: number of score buckets in the predicted distribution.
        in_features: flattened size of one sample's feature map. The default
            25088 equals 512 * 7 * 7 (presumably a VGG16-style backbone on
            224x224 inputs -- confirm against the caller).
    """

    def __init__(self, base_model, num_classes=10, in_features=25088):
        super().__init__()
        self.features = base_model.features
        # Layer order inside the Sequential is kept as-is so parameter names
        # (e.g. ``classifier.1.weight``) in existing checkpoints still match.
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.75),
            nn.Linear(in_features=in_features, out_features=num_classes),
            # The classifier input is (batch, in_features), so the class axis
            # is dim 1. Stating it explicitly avoids PyTorch's deprecated
            # implicit-dimension choice (and the warning it emits).
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return the predicted score distribution, shape (batch, num_classes)."""
        out = self.features(x)
        # reshape (rather than view) also accepts non-contiguous feature maps.
        out = out.reshape(out.size(0), -1)
        return self.classifier(out)
def single_emd_loss(p, q, r=2):
    """Earth Mover's Distance between two distributions of one sample.

    Computes ``(mean_k |CDF_p(k) - CDF_q(k)| ** r) ** (1 / r)``, where the
    cumulative sums run along the first dimension.

    Args:
        p: true distribution of shape num_classes × 1
        q: estimated distribution of shape num_classes × 1
        r: norm parameter

    Returns:
        Tensor holding the distance; the first (class) dimension is reduced
        away, so a num_classes × 1 input yields a tensor of shape (1,).

    Raises:
        AssertionError: if ``p`` and ``q`` have different shapes.
        ValueError: if the distributions are empty.
    """
    assert p.shape == q.shape, "Length of the two distribution must be the same"
    length = p.shape[0]
    if length == 0:
        raise ValueError("Distributions must contain at least one class")
    # The cumulative sum of the difference equals the difference of the two
    # CDFs; one cumsum replaces re-summing every prefix in a Python loop.
    cdf_diff = torch.cumsum(p - q, dim=0)
    return ((torch.abs(cdf_diff) ** r).sum(dim=0) / length) ** (1. / r)
def emd_loss(p, q, r=2):
    """Earth Mover's Distance averaged over a batch.

    For every sample the loss is ``(mean_k |CDF_p(k) - CDF_q(k)| ** r) ** (1 / r)``
    with the cumulative sums taken along the class axis (dim 1); the
    per-sample losses are then averaged over the batch axis (dim 0).

    Args:
        p: true distribution of shape mini_batch_size × num_classes × 1
        q: estimated distribution of shape mini_batch_size × num_classes × 1
        r: norm parameters

    Returns:
        Tensor holding the mean loss; the batch and class dimensions are
        reduced away, so the documented input shape yields a tensor of
        shape (1,).

    Raises:
        AssertionError: if ``p`` and ``q`` have different shapes.
        ValueError: if the batch or the class axis is empty.
    """
    assert p.shape == q.shape, "Shape of the two distribution batches must be the same."
    if p.shape[0] == 0 or p.shape[1] == 0:
        raise ValueError("Batches must contain at least one sample and one class")
    mini_batch_size, num_classes = p.shape[0], p.shape[1]
    # One batched cumsum along the class axis replaces the per-sample Python
    # loop and the repeated prefix sums underneath it.
    cdf_diff = torch.cumsum(p - q, dim=1)
    per_sample = ((torch.abs(cdf_diff) ** r).sum(dim=1) / num_classes) ** (1. / r)
    return per_sample.sum(dim=0) / mini_batch_size