File size: 523 Bytes
0077a91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
from torchvision import transforms
from config import DEVICE

def predict_faces(model, faces, *, image_size=299, device=None):
    """Run ``model`` on each face image and return one scalar score per face.

    Each face is converted to RGB if needed, resized to
    ``image_size`` x ``image_size``, turned into a tensor, normalized with
    fixed 3-channel mean/std values, and passed through ``model`` as a batch
    of one. Gradients are disabled for the whole pass.

    Args:
        model: Callable (e.g. an ``nn.Module``) whose output for a single face
            must contain exactly one element, because it is read with
            ``Tensor.item()``. This function does not call ``model.eval()``;
            put the model in the desired mode before calling.
        faces: Iterable of face images accepted by ``transforms.Resize`` and
            ``transforms.ToTensor`` (presumably PIL images — confirm against
            callers).
        image_size: Side length of the square resize. Defaults to 299, the
            value previously hard-coded here.
        device: Device each input tensor is moved to. Defaults to
            ``config.DEVICE`` when ``None``.

    Returns:
        list[float]: One score per face, in input order; empty if ``faces``
        is empty.
    """
    if device is None:
        device = DEVICE

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    predictions = []
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for face in faces:
            # Non-RGB images (grayscale, palette, RGBA, CMYK, YCbCr, ...) would
            # yield the wrong number or meaning of channels for the 3-channel
            # Normalize above, so convert them to RGB first.
            if getattr(face, "mode", "RGB") != "RGB":
                face = face.convert("RGB")
            batch = transform(face).unsqueeze(0).to(device)
            predictions.append(model(batch).item())
    return predictions