import torch
import numpy as np
import cv2
from torchvision import transforms, models
from utils.preprocess import preprocess_frame

class GazeEstimationModel(torch.nn.Module):
    """ResNet-50 backbone with a 3-unit regression head.

    The three outputs are interpreted downstream as
    (head_pose, gaze_h, gaze_v).
    """

    def __init__(self):
        super().__init__()
        # Build the backbone without pretrained weights — the real weights
        # are loaded later from a checkpoint by GazePredictor.
        backbone = models.resnet50(pretrained=False)
        # Replace the 1000-class ImageNet classifier with a 3-output head.
        feature_dim = backbone.fc.in_features
        backbone.fc = torch.nn.Linear(feature_dim, 3)
        self.backbone = backbone

    def forward(self, x):
        """Forward a batch of images through the backbone; returns (N, 3)."""
        return self.backbone(x)

class GazePredictor:
    """Loads a trained GazeEstimationModel checkpoint and runs single-frame
    gaze inference on CPU or CUDA (whichever is available)."""

    def __init__(self, model_path):
        """Load weights from *model_path* and prepare the model for inference.

        Args:
            model_path: Path to a ``torch.save``-d state dict. Keys may carry
                a ``backbone.`` prefix (checkpoint saved from
                GazeEstimationModel) or not (saved from the bare ResNet).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = GazeEstimationModel()

        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=self.device)

        # Strip any 'backbone.' prefix so the keys match the bare ResNet
        # inside self.model.backbone; unprefixed keys pass through unchanged.
        stripped = {key.replace("backbone.", ""): value
                    for key, value in state_dict.items()}

        # Prefer a strict load so key/shape mismatches surface immediately;
        # fall back to a best-effort partial load (strict=False) on failure.
        try:
            self.model.backbone.load_state_dict(stripped)
        except RuntimeError as e:
            print("Error loading state dict directly:", e)
            print("Trying to load state dict with strict=False...")
            self.model.backbone.load_state_dict(stripped, strict=False)

        # Inference-only: move to the target device and freeze
        # batch-norm / dropout behavior.
        self.model.to(self.device)
        self.model.eval()

        # ImageNet normalization — must match the statistics the backbone
        # was trained with.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def predict_gaze(self, frame):
        """Predict gaze for a single frame.

        Args:
            frame: Raw image accepted by ``preprocess_frame`` (presumably an
                OpenCV BGR ndarray — confirm against the caller).

        Returns:
            Tuple ``(head_pose, gaze_h, gaze_v)`` of scalars.
        """
        # preprocess_frame returns an indexable batch; use the first item.
        preprocessed = preprocess_frame(frame)[0]
        tensor = self.transform(preprocessed).float().unsqueeze(0).to(self.device)
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model(tensor).cpu().numpy()[0]
        head_pose, gaze_h, gaze_v = outputs
        return head_pose, gaze_h, gaze_v