import streamlit as st
import cv2
import tempfile
import numpy as np
import torch
from collections import deque
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode

# Constants
NUM_FRAMES = 16
MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def load_model_and_extractor():
    # The feature extractor comes from the base TimesFormer checkpoint;
    # the classification weights are loaded from the fine-tuned accident-detection repo.
    extractor = AutoFeatureExtractor.from_pretrained("facebook/timesformer-base-finetuned-k400")
    model = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    model.eval()
    return extractor, model


extractor, model = load_model_and_extractor()

st.title("Dashcam Accident Predictor")
st.write("**Higher score = higher accident probability**")


def run_inference_on_video(video_path):
    """Uniformly sample NUM_FRAMES frames from a video file and return the accident probability."""
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        st.error("Failed to read video frames.")
        cap.release()
        return None

    # Uniformly sample NUM_FRAMES frame indices across the clip
    indices = np.linspace(0, total_frames - 1, NUM_FRAMES, dtype=int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if not ret:
            # Pad with a black frame if a seek/read fails
            frames.append(np.zeros((224, 224, 3), dtype=np.uint8))
        else:
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(rgb, (224, 224)))
    cap.release()

    # Preprocess and predict
    inputs = extractor(frames, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    # Assumes class index 1 corresponds to "accident"
    return torch.softmax(logits, dim=1)[0, 1].item()


# UI: choose between an uploaded clip and a live webcam stream
source = st.radio("Choose input source", ("Upload Video", "Webcam"))

if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        # Persist the upload to disk so OpenCV can open it by path
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tfile.write(uploaded_file.read())
        tfile.close()
        st.video(uploaded_file)
        st.write("Running inference...")
        score = run_inference_on_video(tfile.name)
        if score is not None:
            st.success(f"Accident probability: {score:.2f}")
else:
    # Webcam stream processing: keep a rolling buffer of the last NUM_FRAMES frames
    class AcciTransformer(VideoTransformerBase):
        def __init__(self):
            self.buffer = deque(maxlen=NUM_FRAMES)

        def transform(self, frame):
            img = frame.to_ndarray(format="bgr24")
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.buffer.append(cv2.resize(rgb, (224, 224)))
            if len(self.buffer) == NUM_FRAMES:
                # Note: running the model on every incoming frame can be slow,
                # especially on CPU; the stream may lag until inference keeps up.
                inputs = extractor(list(self.buffer), return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(DEVICE)
                with torch.no_grad():
                    logits = model(pixel_values=pixel_values).logits
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                cv2.putText(img, f"Prob: {prob:.2f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            return img

    webrtc_streamer(
        key="dashcam-webcam",
        # SENDRECV: the browser sends its camera feed and receives the annotated frames back.
        # (RECVONLY would not deliver any webcam video to the transformer.)
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=RTCConfiguration({
            "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
        }),
        video_transformer_factory=AcciTransformer,
    )
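# ---------------------------------------------------------------------------
# Usage sketch (assumptions): the commands below assume this script is saved
# as app.py; package names are inferred from the imports above and versions
# are not pinned here.
#
#   pip install streamlit streamlit-webrtc opencv-python numpy torch transformers
#   streamlit run app.py
#
# "Upload Video" scores a whole clip once via run_inference_on_video();
# "Webcam" scores a rolling window of the last NUM_FRAMES frames and overlays
# the probability on the returned stream.
# ---------------------------------------------------------------------------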