|
import torch
|
|
import numpy as np
|
|
import cv2
|
|
from torchvision import transforms, models
|
|
from utils.preprocess import preprocess_frame
|
|
|
|
class GazeEstimationModel(torch.nn.Module):
    """ResNet-50 network whose final fully connected layer emits 3 values.

    The network is created without pretrained weights; parameters are
    expected to be supplied afterwards (for example via ``load_state_dict``).
    """

    def __init__(self):
        """Build the ResNet-50 and swap its classifier for a 3-output layer."""
        super().__init__()

        # Configure the network first, then register it as ``backbone`` so the
        # parameter names keep the ``backbone.`` prefix.
        resnet = models.resnet50(pretrained=False)
        resnet.fc = torch.nn.Linear(resnet.fc.in_features, 3)
        self.backbone = resnet

    def forward(self, x):
        """Return the backbone's raw output for the input batch ``x``."""
        predictions = self.backbone(x)
        return predictions
|
|
|
|
class GazePredictor:
    """Run a ``GazeEstimationModel`` checkpoint on individual frames.

    The network is not built in the constructor. It is created, and its
    weights are loaded, on the first call to ``predict_gaze``.
    """

    def __init__(self, model_path):
        """Remember the checkpoint location and build the input transform.

        Args:
            model_path: Checkpoint location handed to ``torch.load`` when the
                model is first needed.
        """
        self.model_path = model_path
        # Both stay None until _initialize_model() has completed successfully.
        self.model = None
        self.device = None

        # Convert the image to a tensor, then normalise its three channels
        # (the constants match the standard ImageNet channel statistics).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _strip_backbone_prefix(state_dict, prefix="backbone."):
        """Return a copy of ``state_dict`` with a leading ``prefix`` removed.

        Only a prefix at the very start of a key is dropped; the same text
        appearing later in a key is left untouched.
        """
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _initialize_model(self):
        """Create the model and load its weights; do nothing if already done.

        The model and device are stored on the instance only after every step
        has succeeded, so a failed load leaves ``self.model`` as ``None`` and
        the next call retries instead of using an unloaded network.

        Raises:
            Whatever ``torch.load`` or the non-strict ``load_state_dict``
            raises (for example when the checkpoint file is missing).
        """
        if self.model is not None:
            return

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources (or pass
        # weights_only=True on torch versions that support it).
        # Loading comes first so a bad path fails before the network is built.
        state_dict = torch.load(self.model_path, map_location=device)
        backbone_state = self._strip_backbone_prefix(state_dict)

        model = GazeEstimationModel()
        try:
            model.backbone.load_state_dict(backbone_state)
        except RuntimeError as e:
            # Best-effort fallback: accept missing / unexpected keys.
            print("Error loading state dict directly:", e)
            print("Trying to load state dict with strict=False...")
            model.backbone.load_state_dict(backbone_state, strict=False)

        model.to(device)
        model.eval()

        # Commit only now that the model is fully usable.
        self.device = device
        self.model = model

    def predict_gaze(self, frame):
        """Run the model on one frame and return its three outputs.

        Args:
            frame: Input passed unchanged to ``preprocess_frame``.
                NOTE(review): presumably a BGR image from OpenCV; confirm
                against ``utils.preprocess``.

        Returns:
            A ``(head_pose, gaze_h, gaze_v)`` tuple unpacked from the first
            row of the model output. Units and ranges are not established
            here; TODO confirm against the training code.
        """
        self._initialize_model()

        # Only the first element of the helper's result is used.
        preprocessed = preprocess_frame(frame)[0]

        # Add a leading batch dimension of size 1 before moving to the device.
        batch = self.transform(preprocessed).float().unsqueeze(0)
        batch = batch.to(self.device)

        with torch.no_grad():
            outputs = self.model(batch)

        outputs = outputs.cpu().numpy()[0]
        print("Model outputs:", outputs)

        head_pose, gaze_h, gaze_v = outputs
        return head_pose, gaze_h, gaze_v