jatinmehra's picture
Fix ImportError by replacing MediaMode with WebRtcMode in streamlit_webrtc import and webrtc_streamer configuration
89a037c verified
import os
import tempfile
from collections import deque

import cv2
import numpy as np
import streamlit as st
import torch
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
# ---- Configuration --------------------------------------------------------
# How many frames make up one clip fed to the video classifier (used both for
# uniform sampling of uploaded files and as the webcam sliding-window length).
NUM_FRAMES = 16

# Identifier of the fine-tuned classifier passed to ``from_pretrained``.
MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
@st.cache_resource
def load_model_and_extractor():
    """Load the frame preprocessor and the binary video classifier.

    Decorated with ``st.cache_resource`` so Streamlit can reuse the returned
    objects across script reruns instead of reloading them each time.

    Returns:
        ``(extractor, model)``: the preprocessor from the TimeSformer-K400 repo
        and the ``MODEL_NAME`` classifier, moved to ``DEVICE`` and switched to
        eval mode.
    """
    # The preprocessing config comes from the base TimeSformer checkpoint,
    # not from MODEL_NAME.
    processor_repo = "facebook/timesformer-base-finetuned-k400"
    feature_extractor = AutoFeatureExtractor.from_pretrained(processor_repo)

    # NOTE(review): ignore_mismatched_sizes=True lets loading proceed if the
    # checkpoint's head does not match num_labels=2 (that head would then be
    # freshly initialised) — confirm the checkpoint already has 2 labels.
    classifier = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    classifier = classifier.to(DEVICE)
    classifier.eval()
    return feature_extractor, classifier
# Draw the page header first so the user sees something while the model is
# loaded (the first run may need to download weights; later reruns are served
# from the st.cache_resource cache on load_model_and_extractor).
st.title("Dashcam Accident Predictor")
st.write("**higher score = higher accident probability**")

# Shared preprocessor/model used by both the upload and the webcam paths.
with st.spinner("Loading model..."):
    extractor, model = load_model_and_extractor()
# Function to run inference on a saved video file
def run_inference_on_video(video_path):
cap = cv2.VideoCapture(video_path)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
fps = cap.get(cv2.CAP_PROP_FPS) or 30
if total_frames <= 0:
st.error("Failed to read video frames.")
return None
# Uniform sampling
indices = np.linspace(0, total_frames-1, NUM_FRAMES, dtype=int)
frames = []
for idx in indices:
cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
ret, frame = cap.read()
if not ret:
frames.append(np.zeros((224,224,3), dtype=np.uint8))
else:
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
resized = cv2.resize(rgb, (224,224))
frames.append(resized)
cap.release()
# Preprocess and predict
inputs = extractor(frames, return_tensors="pt")
pixel_values = inputs['pixel_values'].to(DEVICE)
with torch.no_grad():
outputs = model(pixel_values=pixel_values).logits
prob = torch.softmax(outputs, dim=1)[0,1].item()
return prob
# UI Selection
source = st.radio("Choose input source", ("Upload Video", "Webcam"))

if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        # Take the raw bytes once; they feed both the temp file and the player,
        # so neither depends on the upload stream's read position.
        video_bytes = uploaded_file.getvalue()

        # OpenCV needs a real path. Closing the file (leaving the `with`)
        # flushes the bytes to disk before cv2 opens it by name; delete=False
        # keeps it around until we remove it ourselves below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
            tfile.write(video_bytes)
            video_path = tfile.name

        try:
            st.video(video_bytes)
            st.write("Running inference...")
            score = run_inference_on_video(video_path)
        finally:
            try:
                os.remove(video_path)
            except OSError:
                pass  # best-effort cleanup; never fail the page over a temp file

        if score is not None:
            st.success(f"Accident probability: {score:.2f}")
else:
    # Webcam stream processing
    class AcciTransformer(VideoTransformerBase):
        """Overlay a rolling accident probability on each webcam frame.

        Keeps the most recent NUM_FRAMES resized RGB frames in a sliding
        window; once the window is full, every incoming frame triggers a model
        prediction over that window.
        """

        def __init__(self):
            # deque(maxlen=...) drops the oldest frame on each append once full.
            self.buffer = deque(maxlen=NUM_FRAMES)

        def transform(self, frame):
            """Buffer ``frame`` and return it as a BGR ndarray, annotated with the
            current probability once NUM_FRAMES frames have been collected."""
            img = frame.to_ndarray(format="bgr24")
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.buffer.append(cv2.resize(rgb, (224, 224)))
            if len(self.buffer) == NUM_FRAMES:
                # NOTE(review): a full forward pass per frame is likely much
                # slower than the camera frame rate on CPU; consider
                # re-predicting only every few frames.
                inputs = extractor(list(self.buffer), return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(DEVICE)
                with torch.no_grad():
                    logits = model(pixel_values=pixel_values).logits
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                # Red text (BGR order) in the top-left corner of the frame.
                cv2.putText(img, f"Prob: {prob:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            return img

    webrtc_streamer(
        key="dashcam-webcam",
        # SENDRECV: the browser sends its webcam stream to the server and gets
        # the transformed frames back. RECVONLY would never capture the webcam,
        # so AcciTransformer would receive no frames to process.
        mode=WebRtcMode.SENDRECV,
        # Public Google STUN server for WebRTC NAT traversal.
        rtc_configuration=RTCConfiguration({
            "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
        }),
        video_transformer_factory=AcciTransformer,
    )