jatinmehra commited on
Commit
d453003
·
verified ·
1 Parent(s): 0c42ef7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import tempfile
4
+ import numpy as np
5
+ import torch
6
+ from collections import deque
7
+ from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
8
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration
9
+
10
+ # Constants
11
+ NUM_FRAMES = 16
12
+ MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+
15
+ @st.cache_resource
16
+ def load_model_and_extractor():
17
+ extractor = AutoFeatureExtractor.from_pretrained("facebook/timesformer-base-finetuned-k400")
18
+ model = AutoModelForVideoClassification.from_pretrained(
19
+ MODEL_NAME,
20
+ num_labels=2,
21
+ ignore_mismatched_sizes=True
22
+ ).to(DEVICE)
23
+ model.eval()
24
+ return extractor, model
25
+
26
+ extractor, model = load_model_and_extractor()
27
+
28
+ st.title("Dashcam Accident Predictor")
29
+ st.write("**higher score = higher accident probability**")
30
+
31
+ # Function to run inference on a saved video file
32
+
33
+ def run_inference_on_video(video_path):
34
+ cap = cv2.VideoCapture(video_path)
35
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
36
+ fps = cap.get(cv2.CAP_PROP_FPS) or 30
37
+ if total_frames <= 0:
38
+ st.error("Failed to read video frames.")
39
+ return None
40
+
41
+ # Uniform sampling
42
+ indices = np.linspace(0, total_frames-1, NUM_FRAMES, dtype=int)
43
+ frames = []
44
+ for idx in indices:
45
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
46
+ ret, frame = cap.read()
47
+ if not ret:
48
+ frames.append(np.zeros((224,224,3), dtype=np.uint8))
49
+ else:
50
+ rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
51
+ resized = cv2.resize(rgb, (224,224))
52
+ frames.append(resized)
53
+ cap.release()
54
+
55
+ # Preprocess and predict
56
+ inputs = extractor(frames, return_tensors="pt")
57
+ pixel_values = inputs['pixel_values'].to(DEVICE)
58
+ with torch.no_grad():
59
+ outputs = model(pixel_values=pixel_values).logits
60
+ prob = torch.softmax(outputs, dim=1)[0,1].item()
61
+ return prob
62
+
63
+ # UI Selection
64
+ source = st.radio("Choose input source", ("Upload Video", "Webcam"))
65
+
66
+ if source == "Upload Video":
67
+ uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
68
+ if uploaded_file is not None:
69
+ tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
70
+ tfile.write(uploaded_file.read())
71
+ st.video(uploaded_file)
72
+ st.write("Running inference...")
73
+ score = run_inference_on_video(tfile.name)
74
+ if score is not None:
75
+ st.success(f"Accident probability: {score:.2f}")
76
+
77
+ else:
78
+ # Webcam stream processing
79
+ class AcciTransformer(VideoTransformerBase):
80
+ def __init__(self):
81
+ self.buffer = deque(maxlen=NUM_FRAMES)
82
+
83
+ def transform(self, frame):
84
+ img = frame.to_ndarray(format="bgr24")
85
+ rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
86
+ resized = cv2.resize(rgb, (224,224))
87
+ self.buffer.append(resized)
88
+ if len(self.buffer) == NUM_FRAMES:
89
+ inputs = extractor(list(self.buffer), return_tensors="pt")
90
+ pixel_values = inputs['pixel_values'].to(DEVICE)
91
+ with torch.no_grad():
92
+ outputs = model(pixel_values=pixel_values).logits
93
+ prob = torch.softmax(outputs, dim=1)[0,1].item()
94
+ cv2.putText(img, f"Prob: {prob:.2f}", (10,30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,0,255), 2)
95
+ return img
96
+
97
+ webrtc_streamer(
98
+ key="dashcam-webcam",
99
+ mode="recv",
100
+ rtc_configuration=RTCConfiguration({
101
+ "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
102
+ }),
103
+ video_transformer_factory=AcciTransformer
104
+ )