new
- app.py +12 -8
- scripts/__pycache__/inference.cpython-312.pyc +0 -0
- scripts/inference.py +20 -7
app.py
CHANGED
@@ -69,13 +69,8 @@ GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
 DISTRACTION_MODEL_PATH = "best.pt"
 
 # --- Global Initializations ---
-gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
 blink_detector = BlinkDetector()
 
-# Load Distraction Model
-distraction_model = YOLO(DISTRACTION_MODEL_PATH)
-distraction_model.to('cpu')
-
 # Distraction Class Names
 distraction_class_names = [
     'safe driving', 'drinking', 'eating', 'hair and makeup',
@@ -106,6 +101,7 @@ EYE_CLOSURE_THRESHOLD = 10
 HEAD_STABILITY_THRESHOLD = 0.05
 DISTRACTION_CONF_THRESHOLD = 0.1
 
+@spaces.GPU(duration=30)  # Set duration to 30 seconds for real-time processing
 def analyze_video(input_video):
     cap = cv2.VideoCapture(input_video)
     local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
@@ -247,13 +243,16 @@ def analyze_distraction_video(input_video):
 
     fps = cap.get(cv2.CAP_PROP_FPS) or 30
 
+    local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
+    local_distraction_model.to('cpu')
+
     while True:
         ret, frame = cap.read()
         if not ret:
             break
 
         try:
-            results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
+            results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
 
             display_text = "safe driving"
             alarm_action = None
@@ -312,9 +311,12 @@ def process_distraction_frame(frame):
     if frame is None:
         return np.zeros((480, 640, 3), dtype=np.uint8)
 
+    local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
+    local_distraction_model.to('cpu')
+
     try:
         # Run distraction detection model
-        results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
+        results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
 
         display_text = "safe driving"
         alarm_action = None
@@ -410,8 +412,10 @@ def process_gaze_frame(frame):
     if start_time == 0:
         start_time = current_time
 
+    local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
+
     try:
-        head_pose_gaze, gaze_h, gaze_v = gaze_predictor.predict_gaze(frame)
+        head_pose_gaze, gaze_h, gaze_v = local_gaze_predictor.predict_gaze(frame)
         current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
         smoothed_gaze = smooth_values(gaze_history, current_gaze)
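Taken together, the app.py changes remove every module-level model: the global GazePredictor and the YOLO distraction model are deleted, and each processing function now builds its own local instance, with analyze_video additionally wrapped in @spaces.GPU(duration=30). This is the usual shape for Hugging Face ZeroGPU Spaces, where CUDA may only be touched inside a @spaces.GPU-decorated call and anything created at import time must stay CPU-only. A minimal sketch of the pattern follows; the weights path and function name are illustrative, not taken from this repo:

import spaces
from ultralytics import YOLO

@spaces.GPU(duration=30)  # GPU is attached only for the duration of this call
def run_detection(frame):
    # Construct the model inside the decorated function so no CUDA state
    # exists at import time; ZeroGPU provides the device only in here.
    model = YOLO("weights.pt")  # illustrative path
    return model(frame, conf=0.1, verbose=False)

The trade-off is that process_distraction_frame now reloads the YOLO weights on every frame: the import path stays CUDA-free, but each call pays the load cost, which a cache populated on first use (the approach scripts/inference.py takes below for the gaze model) would avoid.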
scripts/__pycache__/inference.cpython-312.pyc
CHANGED
Binary files a/scripts/__pycache__/inference.cpython-312.pyc and b/scripts/__pycache__/inference.cpython-312.pyc differ
scripts/inference.py
CHANGED
@@ -17,13 +17,29 @@ class GazeEstimationModel(torch.nn.Module):
 
 class GazePredictor:
     def __init__(self, model_path):
+        # Initialize without moving to device - we'll do that during prediction
+        self.model_path = model_path
+        self.model = None
+        self.device = None
+
+        # Define transform outside of initialization
+        self.transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+
+    def _initialize_model(self):
+        # Only initialize model when needed (inside ZeroGPU function)
+        if self.model is not None:
+            return
+
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Initialize the custom model
         self.model = GazeEstimationModel()
 
         # Load the state dictionary
-        state_dict = torch.load(model_path, map_location=self.device)
+        state_dict = torch.load(self.model_path, map_location=self.device)
 
         # Check if state_dict has 'backbone.' prefix and strip it if necessary
         new_state_dict = {}
@@ -42,14 +58,11 @@ class GazePredictor:
         # Move to device and set to evaluation mode
         self.model.to(self.device)
         self.model.eval()
-
-        # Define preprocessing transform
-        self.transform = transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-        ])
 
     def predict_gaze(self, frame):
+        # Initialize model if not already done
+        self._initialize_model()
+
         preprocessed = preprocess_frame(frame)
         preprocessed = preprocessed[0]
         preprocessed = self.transform(preprocessed).float().unsqueeze(0)
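The rewritten GazePredictor defers device selection and weight loading until the first predict_gaze call, so constructing it at import time, or inside a ZeroGPU-decorated function, stays cheap and CUDA-free. A short usage sketch, with the video source purely illustrative:

import cv2
from scripts.inference import GazePredictor

predictor = GazePredictor("models/gaze_estimation_model.pth")  # no weights loaded yet

cap = cv2.VideoCapture("sample.mp4")  # illustrative input
ret, frame = cap.read()
if ret:
    # First call runs _initialize_model(): picks the device, loads the
    # state dict, and puts the model in eval mode; later calls reuse it.
    head_pose_gaze, gaze_h, gaze_v = predictor.predict_gaze(frame)
cap.release()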