import torch
import numpy as np
import cv2
from torchvision import transforms, models

from utils.preprocess import preprocess_frame


class GazeEstimationModel(torch.nn.Module):
    def __init__(self):
        super(GazeEstimationModel, self).__init__()
        # Initialize ResNet-50 as the backbone
        self.backbone = models.resnet50(pretrained=False)
        # Modify the final fully connected layer for 3 outputs (head_pose, gaze_h, gaze_v)
        self.backbone.fc = torch.nn.Linear(self.backbone.fc.in_features, 3)

    def forward(self, x):
        return self.backbone(x)


class GazePredictor:
    def __init__(self, model_path):
        # Initialize without moving to device - we'll do that during prediction
        self.model_path = model_path
        self.model = None
        self.device = None
        # Define transform outside of initialization
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _initialize_model(self):
        # Only initialize model when needed (inside ZeroGPU function)
        if self.model is not None:
            return

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize the custom model
        self.model = GazeEstimationModel()

        # Load the state dictionary
        state_dict = torch.load(self.model_path, map_location=self.device)

        # Check if state_dict has 'backbone.' prefix and strip it if necessary
        new_state_dict = {}
        for key, value in state_dict.items():
            new_key = key.replace("backbone.", "")  # Remove 'backbone.' prefix
            new_state_dict[new_key] = value

        # Load the adjusted state dictionary into the model
        try:
            self.model.backbone.load_state_dict(new_state_dict)
        except RuntimeError as e:
            print("Error loading state dict directly:", e)
            print("Trying to load state dict with strict=False...")
            self.model.backbone.load_state_dict(new_state_dict, strict=False)

        # Move to device and set to evaluation mode
        self.model.to(self.device)
        self.model.eval()

    def predict_gaze(self, frame):
        # Initialize model if not already done
        self._initialize_model()

        preprocessed = preprocess_frame(frame)
        preprocessed = preprocessed[0]
        preprocessed = self.transform(preprocessed).float().unsqueeze(0)
        preprocessed = preprocessed.to(self.device)

        with torch.no_grad():
            outputs = self.model(preprocessed)
            outputs = outputs.cpu().numpy()[0]
            print("Model outputs:", outputs)  # Debug print
            head_pose, gaze_h, gaze_v = outputs

        return head_pose, gaze_h, gaze_v
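

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal example of how GazePredictor might be driven with a single webcam
# frame. The checkpoint path "models/gaze_resnet50.pth" and camera index 0 are
# assumptions; adjust them to your setup. Requires utils.preprocess.preprocess_frame
# to accept a raw BGR frame as loaded by OpenCV.
if __name__ == "__main__":
    predictor = GazePredictor("models/gaze_resnet50.pth")

    # Grab one frame from the default camera
    cap = cv2.VideoCapture(0)
    ret, frame = cap.read()
    cap.release()

    if ret:
        head_pose, gaze_h, gaze_v = predictor.predict_gaze(frame)
        print(f"head_pose={head_pose:.3f}, gaze_h={gaze_h:.3f}, gaze_v={gaze_v:.3f}")
    else:
        print("Could not read a frame from the camera.")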