Fix ImportError by replacing MediaMode with WebRtcMode in streamlit_webrtc import and webrtc_streamer configuration
89a037c
verified
import streamlit as st | |
import cv2 | |
import tempfile | |
import numpy as np | |
import torch | |
from collections import deque | |
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification | |
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode | |
# Constants
# Clip length: frames sampled from an uploaded video and size of the webcam
# frame buffer.
NUM_FRAMES = 16

# Checkpoint identifier handed to ``from_pretrained`` for the classifier.
MODEL_NAME = "jatinmehra/Accident-Detection-using-Dashcam"

# Device string used for the model and its input tensors: CUDA when torch
# reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
@st.cache_resource
def load_model_and_extractor():
    """Load the frame preprocessor and the video classifier.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` makes repeated calls return the objects created by
    the first call instead of reloading the weights on each rerun.

    Returns:
        tuple: ``(extractor, model)`` where ``extractor`` is the
        ``AutoFeatureExtractor`` loaded from
        ``facebook/timesformer-base-finetuned-k400`` and ``model`` is the
        ``AutoModelForVideoClassification`` loaded from ``MODEL_NAME``,
        moved to ``DEVICE`` and switched to eval mode.
    """
    extractor = AutoFeatureExtractor.from_pretrained(
        "facebook/timesformer-base-finetuned-k400"
    )
    model = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    ).to(DEVICE)
    # Inference only: switch the module to evaluation mode.
    model.eval()
    return extractor, model
# Module-level preprocessor and classifier shared by the upload and webcam
# code paths below.
_loaded = load_model_and_extractor()
extractor, model = _loaded

# Page header.
st.title("Dashcam Accident Predictor")
st.write("**higher score = higher accident probability**")
# Function to run inference on a saved video file
def run_inference_on_video(video_path):
    """Score one video file with the accident classifier.

    ``NUM_FRAMES`` frames are sampled at evenly spaced positions across the
    video, converted to RGB, resized to 224x224, and passed through the
    module-level ``extractor`` and ``model``.

    Args:
        video_path: Path of the video, as accepted by ``cv2.VideoCapture``.

    Returns:
        float | None: Softmax probability of class index 1 (shown in the UI
        as the accident probability), or ``None`` when no frame could be
        read; in that case the failure is reported through ``st.error``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            st.error("Failed to read video frames.")
            return None

        # Uniform sampling
        indices = np.linspace(0, total_frames - 1, NUM_FRAMES, dtype=int)
        frames = []
        decoded_any = False
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                # Keep the clip at NUM_FRAMES entries: substitute a black
                # frame for one that could not be decoded.
                frames.append(np.zeros((224, 224, 3), dtype=np.uint8))
                continue
            decoded_any = True
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(rgb, (224, 224)))
    finally:
        # Release the capture on every exit path, including the early return.
        cap.release()

    if not decoded_any:
        # Every sampled read failed; scoring an all-black clip would be
        # meaningless, so report it like the empty-video case.
        st.error("Failed to read video frames.")
        return None

    # Preprocess and predict
    inputs = extractor(frames, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values).logits
    prob = torch.softmax(outputs, dim=1)[0, 1].item()
    return prob
# UI Selection
source = st.radio("Choose input source", ("Upload Video", "Webcam"))

if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        st.video(uploaded_file)
        st.write("Running inference...")
        # OpenCV needs a real path, so the upload is written to disk first.
        # The temporary directory removes the file again when the block ends.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Leaving the inner ``with`` closes the file, so all bytes are
            # flushed before OpenCV opens it by name. ``delete=False`` keeps
            # the file on disk after it is closed.
            with tempfile.NamedTemporaryFile(
                dir=tmp_dir, suffix=".mp4", delete=False
            ) as tfile:
                # getvalue() returns the full upload without moving the
                # stream position of ``uploaded_file``.
                tfile.write(uploaded_file.getvalue())
            score = run_inference_on_video(tfile.name)
        if score is not None:
            st.success(f"Accident probability: {score:.2f}")
else:
    # Webcam stream processing
    class AcciTransformer(VideoTransformerBase):
        """Overlay the classifier's probability on incoming webcam frames.

        The most recent ``NUM_FRAMES`` frames are kept in a bounded deque.
        Once it is full, every new frame triggers a prediction over the
        buffered clip.
        """

        def __init__(self):
            # maxlen makes the deque drop the oldest frame automatically.
            self.buffer = deque(maxlen=NUM_FRAMES)

        def transform(self, frame):
            """Return the BGR image for ``frame``, annotated once the buffer is full."""
            img = frame.to_ndarray(format="bgr24")
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.buffer.append(cv2.resize(rgb, (224, 224)))

            # Not enough history yet: pass the frame through unannotated.
            if len(self.buffer) < NUM_FRAMES:
                return img

            inputs = extractor(list(self.buffer), return_tensors="pt")
            pixel_values = inputs["pixel_values"].to(DEVICE)
            with torch.no_grad():
                outputs = model(pixel_values=pixel_values).logits
            prob = torch.softmax(outputs, dim=1)[0, 1].item()
            cv2.putText(
                img,
                f"Prob: {prob:.2f}",
                (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 0, 255),
                2,
            )
            return img

    webrtc_streamer(
        key="dashcam-webcam",
        # SENDRECV: the browser sends its webcam video and receives the
        # processed frames back. RECVONLY would not capture the webcam, so
        # the transformer would never get any frames.
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=RTCConfiguration(
            {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
        ),
        video_transformer_factory=AcciTransformer,
    )