# driver / scripts /inference.py
# Guru-25's picture
# new
# d343b30 verified
import torch
import numpy as np
import cv2
from torchvision import transforms, models
from utils.preprocess import preprocess_frame
class GazeEstimationModel(torch.nn.Module):
    """ResNet-50 network whose classification head is replaced by a 3-value output layer.

    The three outputs are interpreted as (head_pose, gaze_h, gaze_v).
    """

    def __init__(self):
        super().__init__()
        # Build the ResNet-50 backbone without pretrained weights.
        resnet = models.resnet50(pretrained=False)
        # Swap the final fully connected layer for one that emits 3 values,
        # keeping the input width of the layer it replaces.
        head_inputs = resnet.fc.in_features
        resnet.fc = torch.nn.Linear(head_inputs, 3)
        self.backbone = resnet

    def forward(self, x):
        """Pass ``x`` through the backbone and return its output unchanged."""
        return self.backbone(x)
class GazePredictor:
    """Lazily initialised wrapper that runs ``GazeEstimationModel`` on one frame at a time.

    The constructor only records the checkpoint path and builds the input
    transform; the network is created, loaded and moved to a device on the
    first call to :meth:`predict_gaze`.
    """

    def __init__(self, model_path):
        """Store ``model_path`` and build the tensor transform.

        Args:
            model_path: Path (or file-like object) handed to ``torch.load``
                when the model is first needed.
        """
        # Initialize without moving to device - we'll do that during prediction
        self.model_path = model_path
        self.model = None
        self.device = None
        # Convert to a tensor, then normalise each of the three channels with
        # the fixed mean/std values below.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _strip_backbone_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading ``"backbone."`` removed from keys.

        Only a true prefix is stripped; keys that merely contain the text
        elsewhere, or do not contain it at all, are kept unchanged.
        """
        prefix = "backbone."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    def _initialize_model(self):
        """Create and load the model on first use; do nothing on later calls.

        ``self.model`` and ``self.device`` are assigned only after the
        checkpoint has been loaded successfully, so a failed attempt leaves
        the predictor uninitialised and the next call retries instead of
        silently using a network with random weights.

        Raises:
            Whatever ``torch.load`` raises for an unreadable checkpoint, or
            ``RuntimeError`` if the weights cannot be loaded even with
            ``strict=False`` (e.g. tensor shape mismatch).
        """
        # Only initialize model when needed (inside ZeroGPU function)
        if self.model is not None:
            return

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # NOTE(security): torch.load is pickle-based; only point model_path at
        # trusted checkpoint files.
        state_dict = torch.load(self.model_path, map_location=device)
        backbone_state = self._strip_backbone_prefix(state_dict)

        model = GazeEstimationModel()
        try:
            model.backbone.load_state_dict(backbone_state)
        except RuntimeError as e:
            print("Error loading state dict directly:", e)
            print("Trying to load state dict with strict=False...")
            result = model.backbone.load_state_dict(backbone_state, strict=False)
            # Report what the non-strict load skipped so partially loaded
            # weights do not go unnoticed.
            print("Missing keys:", list(result.missing_keys))
            print("Unexpected keys:", list(result.unexpected_keys))

        # Move to device and set to evaluation mode
        model.to(device)
        model.eval()

        # Publish only once everything above has succeeded.
        self.device = device
        self.model = model

    def predict_gaze(self, frame):
        """Run the model on ``frame`` and return ``(head_pose, gaze_h, gaze_v)``.

        Args:
            frame: Input passed straight to ``preprocess_frame``.

        Returns:
            The three values of the model output for this frame, unpacked
            from a NumPy array.
        """
        # Initialize model if not already done
        self._initialize_model()

        preprocessed = preprocess_frame(frame)
        # NOTE(review): presumably preprocess_frame returns a batch/sequence and
        # element 0 is the image for this frame - confirm against utils.preprocess.
        preprocessed = preprocessed[0]
        # Transform to a float32 tensor and add a leading batch dimension.
        preprocessed = self.transform(preprocessed).float().unsqueeze(0)
        preprocessed = preprocessed.to(self.device)

        with torch.no_grad():
            outputs = self.model(preprocessed)

        # Drop the batch dimension and bring the result back to host memory.
        outputs = outputs.cpu().numpy()[0]
        print("Model outputs:", outputs)  # Debug print
        head_pose, gaze_h, gaze_v = outputs
        return head_pose, gaze_h, gaze_v