File size: 523 Bytes
0077a91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import torch
from torchvision import transforms
from config import DEVICE
def predict_faces(model, faces, *, image_size=299, device=None):
    """Run ``model`` on each face crop and return one scalar score per face.

    Each face is resized to ``image_size`` x ``image_size``, converted to a
    tensor, normalized with fixed per-channel mean/std, given a leading batch
    dimension of 1, moved to the target device and passed through ``model``
    with autograd disabled.

    Args:
        model: Callable (typically a ``torch.nn.Module``) taking a
            ``(1, C, H, W)`` tensor and returning an output with exactly one
            element; ``Tensor.item()`` raises otherwise.
        faces: Iterable of face images. ``Resize`` runs before ``ToTensor``
            and ``Normalize`` uses 3 channel statistics, so these are
            presumably 3-channel (RGB) PIL images -- confirm against callers.
        image_size: Side length of the square resize. Defaults to 299, the
            value previously hard-coded here.
        device: Device the input tensors are moved to. ``None`` (the default)
            uses ``config.DEVICE``.

    Returns:
        list: One Python number (from ``Tensor.item()``) per face, in input
        order. An empty list if ``faces`` is empty.

    NOTE(review): the model's train/eval mode is not changed here; callers
    presumably need to call ``model.eval()`` beforehand so dropout and
    batch-norm layers run in inference mode -- confirm at call sites.
    """
    # Fall back to the project-wide device only when the caller did not pick one.
    target_device = DEVICE if device is None else device

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Standard ImageNet channel mean/std used by torchvision pretrained
        # backbones; presumably matches how ``model`` was trained.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    predictions = []
    # A single no_grad context for the whole loop: no autograd graph is built
    # for any face, and the context is not re-entered per iteration.
    with torch.no_grad():
        for face in faces:
            batch = transform(face).unsqueeze(0).to(target_device)
            predictions.append(model(batch).item())
    return predictions
|