File size: 3,829 Bytes
4d9e196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
'''Utilities for the VGGFace models for Keras: input preprocessing and
prediction decoding.
# Notes:
- The utility functions are modified versions of the corresponding Keras
  functions, see [Keras](https://keras.io).
'''
import numpy as np
from keras import backend as K
from keras.utils.data_utils import get_file
# Download URLs for assets attached to the v2.0 release of the
# rcmalli/keras-vggface GitHub repository.
#
# Class-label arrays (.npy). `decode_predictions` loads the V1 file for
# 2622-column predictions and the V2 file for 8631-column predictions.
V1_LABELS_PATH = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_labels_v1.npy'
V2_LABELS_PATH = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_labels_v2.npy'
# Weight files (.h5) for the VGG16, ResNet50 and SENet50 variants. None of them
# is referenced by the functions in this file.
# NOTE(review): presumably consumed by the model builders elsewhere in the
# package, with the *_NO_TOP files omitting the classification head -- confirm.
VGG16_WEIGHTS_PATH = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_vgg16.h5'
VGG16_WEIGHTS_PATH_NO_TOP = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_vgg16.h5'
RESNET50_WEIGHTS_PATH = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_resnet50.h5'
RESNET50_WEIGHTS_PATH_NO_TOP = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_resnet50.h5'
SENET50_WEIGHTS_PATH = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_senet50.h5'
SENET50_WEIGHTS_PATH_NO_TOP = 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_notop_senet50.h5'
# Passed to `get_file` as `cache_subdir` when `decode_predictions` downloads
# the label files.
VGGFACE_DIR = 'models/vggface'
def preprocess_input(x, data_format=None, version=1):
    '''Reverse the channel order of `x` and subtract per-channel offsets.

    The input is never modified; a processed copy is returned.

    # Arguments
        x: array-like of images. With `'channels_last'` the channels are on
            the last axis. With `'channels_first'` they are on axis 0 for a
            single 3D image and on axis 1 otherwise (batched input).
        data_format: `'channels_last'`, `'channels_first'` or `None`.
            `None` falls back to `K.image_data_format()`.
        version: `1` or `2`; selects which triple of offsets is subtracted
            from channels 0, 1 and 2 *after* the channel order was reversed
            (e.g. RGB input becomes BGR before the subtraction).

    # Returns
        A new numpy array with the same shape as `x`. Floating-point input
        keeps its dtype; integer and boolean input is converted to float32
        (the Keras default float type), because the offsets are fractional
        and cannot be subtracted in place from an integer array.

    # Raises
        AssertionError: if `data_format` is not a supported value.
        NotImplementedError: if `version` is neither 1 nor 2.
    '''
    if data_format is None:
        data_format = K.image_data_format()
    # Raised explicitly instead of via `assert` so the check survives
    # `python -O`; the AssertionError type is kept for backward compatibility.
    if data_format not in {'channels_last', 'channels_first'}:
        raise AssertionError('Unknown data_format: ' + str(data_format))

    if version == 1:
        channel_means = (93.5940, 104.7624, 129.1863)
    elif version == 2:
        channel_means = (91.4953, 103.8827, 131.0912)
    else:
        raise NotImplementedError

    source = np.asarray(x)
    if source.dtype.kind in 'biu':
        # bool / signed / unsigned integers (e.g. uint8 images): the in-place
        # subtraction below would fail the same-kind cast, so work on a
        # float32 copy instead (`astype` always copies here).
        x_temp = source.astype(np.float32)
    else:
        x_temp = np.copy(source)

    # Python scalars are subtracted one channel at a time so the arithmetic
    # stays in the array's own dtype.
    if data_format == 'channels_first':
        if x_temp.ndim == 3:
            # Single image laid out as (channels, rows, cols).
            x_temp = x_temp[::-1, ...]
            for channel, mean in enumerate(channel_means):
                x_temp[channel, :, :] -= mean
        else:
            # Batched input: channels sit on axis 1.
            x_temp = x_temp[:, ::-1, ...]
            for channel, mean in enumerate(channel_means):
                x_temp[:, channel, :, :] -= mean
    else:
        x_temp = x_temp[..., ::-1]
        for channel, mean in enumerate(channel_means):
            x_temp[..., channel] -= mean

    return x_temp
def _label_to_text(label):
    '''Return `label` as a native `str` on both Python 2 and Python 3.'''
    if isinstance(label, str):
        # Python 3 text (or a Python 2 byte string): already the native type.
        return str(label)
    if isinstance(label, bytes):
        # Python 3 bytes: decode rather than producing "b'...'" via str().
        return label.decode('utf8')
    # Python 2 `unicode`: encode to the native byte string.
    return str(label.encode('utf8'))


def decode_predictions(preds, top=5):
    '''Map a batch of class scores to their `top` highest-scoring labels.

    The label file matching the number of columns in `preds` is fetched with
    `get_file` (cached under `VGGFACE_DIR`) and loaded with `np.load`.

    # Arguments
        preds: 2D array-like of shape `(samples, 2622)` (V1 labels) or
            `(samples, 8631)` (V2 labels).
        top: number of entries to return per sample. `0` yields empty lists.

    # Returns
        A list with one entry per sample. Each entry is a list of
        `[label, score]` pairs ordered by descending score, where `label`
        is a plain `str`.

    # Raises
        ValueError: if `preds` is not 2D or its second dimension is neither
            2622 nor 8631.
    '''
    preds = np.asarray(preds)
    num_classes = preds.shape[1] if preds.ndim == 2 else None

    if num_classes == 2622:
        fname, origin = 'rcmalli_vggface_labels_v1.npy', V1_LABELS_PATH
    elif num_classes == 8631:
        fname, origin = 'rcmalli_vggface_labels_v2.npy', V2_LABELS_PATH
    else:
        raise ValueError('`decode_predictions` expects '
                         'a batch of predictions '
                         '(i.e. a 2D array of shape (samples, 2622)) for V1 or '
                         '(samples, 8631) for V2. '
                         'Found array with shape: ' + str(preds.shape))

    labels = np.load(get_file(fname, origin, cache_subdir=VGGFACE_DIR))

    results = []
    for pred in preds:
        # Reverse first, then slice: same selection as `[-top:][::-1]` for any
        # non-zero `top`, but `top=0` yields nothing instead of every class.
        # The indices are already in descending score order, so no extra sort
        # is needed.
        top_indices = pred.argsort()[::-1][:top]
        results.append([[_label_to_text(labels[i]), pred[i]]
                        for i in top_indices])
    return results
|