|
import os
import tempfile
from collections import deque

import cv2
import numpy as np
import streamlit as st
import torch
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
|
|
|
|
|
# How many frames make up one clip handed to the video classifier.
NUM_FRAMES = 16

# Checkpoint identifier passed to `from_pretrained` for the classifier.
MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
@st.cache_resource
def load_model_and_extractor():
    """Load the frame preprocessor and the accident classifier.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the returned objects
    on later script reruns instead of reloading the weights each time.

    Returns:
        A ``(extractor, model)`` pair: the feature extractor built from the
        ``facebook/timesformer-base-finetuned-k400`` checkpoint, and the
        ``MODEL_NAME`` video classifier (2 labels) moved to ``DEVICE`` and
        switched to eval mode.
    """
    # NOTE(review): preprocessing settings come from the base Kinetics-400
    # TimeSformer checkpoint, not from MODEL_NAME -- presumably the fine-tuned
    # repo ships no preprocessor config; confirm the two actually match.
    base_checkpoint = "facebook/timesformer-base-finetuned-k400"
    feature_extractor = AutoFeatureExtractor.from_pretrained(base_checkpoint)

    # NOTE(review): ignore_mismatched_sizes=True allows loading even when the
    # stored classification head does not have 2 outputs (a fresh head would
    # then be used); verify MODEL_NAME already stores a 2-label head.
    classifier = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    classifier = classifier.to(DEVICE)
    classifier.eval()

    return feature_extractor, classifier
|
|
|
# Shared preprocessing pipeline and classifier, used by both input modes
# below (served from Streamlit's resource cache after the first run).
extractor, model = load_model_and_extractor()

# Page header plus a one-line hint on how to read the score.
st.title("Dashcam Accident Predictor")
st.write("**higher score = higher accident probability**")
|
|
|
|
|
|
|
def _sample_frames(video_path):
    """Return NUM_FRAMES evenly spaced 224x224 RGB frames, or None if unreadable."""
    cap = cv2.VideoCapture(video_path)
    try:
        # An unopened capture reports 0 frames; treat both cases as unreadable.
        if not cap.isOpened():
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None

        # Spread the sample positions across the whole clip, first to last frame.
        indices = np.linspace(0, total_frames - 1, NUM_FRAMES, dtype=int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok:
                # OpenCV decodes to BGR; convert to RGB before resizing.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(cv2.resize(rgb, (224, 224)))
            else:
                # Keep the clip length fixed: pad unreadable positions with black.
                frames.append(np.zeros((224, 224, 3), dtype=np.uint8))
        return frames
    finally:
        # Release the decoder on every path, including the early returns above.
        cap.release()


def _predict_accident_prob(frames):
    """Run one clip of frames through the classifier; return P(class index 1)."""
    inputs = extractor(frames, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    return torch.softmax(logits, dim=1)[0, 1].item()


def run_inference_on_video(video_path):
    """Estimate the accident probability for a video file on disk.

    Samples ``NUM_FRAMES`` frames evenly across the clip, converts them to
    224x224 RGB, and classifies them together as a single clip.

    Args:
        video_path: Path to a video file that OpenCV can open.

    Returns:
        The softmax probability of class index 1 (shown in the UI as the
        accident probability) as a Python float. Returns ``None`` if the video
        cannot be opened or reports no frames; in that case an error is shown
        on the Streamlit page.
    """
    frames = _sample_frames(video_path)
    if frames is None:
        st.error("Failed to read video frames.")
        return None
    return _predict_accident_prob(frames)
|
|
|
|
|
source = st.radio("Choose input source", ("Upload Video", "Webcam"))

if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])

    if uploaded_file is not None:
        # Keep the uploaded container's extension so OpenCV/FFmpeg is not
        # misled by a forced ".mp4"; fall back to ".mp4" if the name has none.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"

        # cv2.VideoCapture needs a real path, so spill the upload to disk.
        # Leaving the `with` closes (and therefore flushes) the file before
        # OpenCV reopens it by name; delete=False allows that reopen on Windows.
        # getvalue() copies the whole buffer without moving the stream
        # position, so st.video below still sees the full upload.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(uploaded_file.getvalue())
            temp_path = tfile.name

        try:
            st.video(uploaded_file)
            with st.spinner("Running inference..."):
                score = run_inference_on_video(temp_path)
        finally:
            # Streamlit reruns this script on every interaction; remove the
            # copy so reruns do not pile up temp files. Best-effort only.
            try:
                os.remove(temp_path)
            except OSError:
                pass

        if score is not None:
            st.success(f"Accident probability: {score:.2f}")

else:

    class AcciTransformer(VideoTransformerBase):
        """Overlay a rolling accident probability on incoming webcam frames."""

        def __init__(self):
            # Sliding window holding the most recent NUM_FRAMES preprocessed frames.
            self.buffer = deque(maxlen=NUM_FRAMES)

        def transform(self, frame):
            """Append ``frame`` to the window and return it with the score drawn on.

            Once the window is full, the whole window is classified and the
            probability of class index 1 is written in the top-left corner.
            """
            img = frame.to_ndarray(format="bgr24")
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.buffer.append(cv2.resize(rgb, (224, 224)))

            if len(self.buffer) == NUM_FRAMES:
                inputs = extractor(list(self.buffer), return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(DEVICE)
                with torch.no_grad():
                    logits = model(pixel_values=pixel_values).logits
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                cv2.putText(img, f"Prob: {prob:.2f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            return img

    # `mode` is omitted: "recv" is not a valid WebRtcMode. The default
    # (SENDRECV) is what this flow needs, because the browser sends webcam
    # frames and displays the annotated frames returned by the transformer.
    webrtc_streamer(
        key="dashcam-webcam",
        rtc_configuration=RTCConfiguration({
            "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
        }),
        video_transformer_factory=AcciTransformer,
    )