jatinmehra's picture
Create app.py
d453003 verified
raw
history blame
3.73 kB
import os
import tempfile
from collections import deque

import cv2
import numpy as np
import streamlit as st
import torch
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase, RTCConfiguration, WebRtcMode
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
# Constants
# Hugging Face Hub id of the video-classification checkpoint loaded below.
MODEL_NAME: str = "jatinmehra/Accident-Detection-using-Dashcam"
# Frames per clip handed to the model: uniform samples for uploads, and the
# sliding-window length for the webcam stream.
NUM_FRAMES: int = 16
# Run on the GPU whenever PyTorch can see one.
DEVICE: str = (
    "cuda" if torch.cuda.is_available() else "cpu"
)
@st.cache_resource
def load_model_and_extractor():
    """Load the frame preprocessor and the video classifier.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading the weights every time.

    Returns:
        tuple: ``(extractor, model)``. The extractor is built from the base
        ``facebook/timesformer-base-finetuned-k400`` checkpoint. The model is
        ``MODEL_NAME`` with a 2-label head, moved to ``DEVICE`` and switched
        to eval mode.
    """
    # The preprocessing config comes from the base checkpoint, not MODEL_NAME.
    preprocessor = AutoFeatureExtractor.from_pretrained(
        "facebook/timesformer-base-finetuned-k400"
    )
    # NOTE(review): ignore_mismatched_sizes=True makes from_pretrained tolerate
    # checkpoint weights whose shapes differ from the requested num_labels=2
    # head instead of raising. Confirm MODEL_NAME already ships a 2-class head.
    classifier = AutoModelForVideoClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    classifier = classifier.to(DEVICE)
    classifier.eval()
    return preprocessor, classifier
# Module-level handles to the @st.cache_resource-cached preprocessor and model;
# both the upload path (run_inference_on_video) and the webcam transformer use them.
extractor, model = load_model_and_extractor()
# Page header and a hint on how to read the displayed score.
st.title("Dashcam Accident Predictor")
st.write("**higher score = higher accident probability**")
def _sample_frames(video_path, num_frames, size=(224, 224)):
    """Read ``num_frames`` uniformly spaced frames from a video as RGB arrays.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: How many frames to sample across the whole clip.
        size: Output ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A list of ``num_frames`` uint8 arrays of shape ``(height, width, 3)``,
        or ``None`` if the video cannot be opened or reports no frames. A
        frame that fails to decode is replaced by a black frame so the clip
        length stays fixed.
    """
    width, height = size
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None
        frames = []
        for idx in np.linspace(0, total_frames - 1, num_frames, dtype=int):
            # Seek instead of decoding every frame; the reported frame count
            # can overshoot the real one, in which case read() fails below.
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if not ok:
                frames.append(np.zeros((height, width, 3), dtype=np.uint8))
                continue
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(cv2.resize(rgb, size))
        return frames
    finally:
        # Release on every path, including the early returns above.
        cap.release()


def run_inference_on_video(video_path):
    """Estimate the accident probability of a saved video file.

    Samples ``NUM_FRAMES`` frames uniformly across the clip and preprocesses
    them with the module-level ``extractor``. It then returns the softmax
    probability of class index 1 from ``model``.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        The class-1 probability as a Python float. Returns ``None`` (after
        showing an error in the UI) if no frames could be read.
    """
    frames = _sample_frames(video_path, NUM_FRAMES)
    if frames is None:
        st.error("Failed to read video frames.")
        return None

    # Preprocess and predict
    inputs = extractor(frames, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(DEVICE)
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    return torch.softmax(logits, dim=1)[0, 1].item()
# UI Selection
source = st.radio("Choose input source", ("Upload Video", "Webcam"))
if source == "Upload Video":
    uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    if uploaded_file is not None:
        # getvalue() returns the full upload regardless of the read position.
        video_bytes = uploaded_file.getvalue()
        # Keep the upload's own extension so OpenCV/FFmpeg sees the real container.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".mp4"
        # Leaving the `with` closes (and so flushes) the file before OpenCV
        # reopens it by path. delete=False keeps it on disk until removed below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(video_bytes)
            video_path = tfile.name
        try:
            st.video(uploaded_file)
            st.write("Running inference...")
            score = run_inference_on_video(video_path)
        finally:
            os.remove(video_path)
        if score is not None:
            st.success(f"Accident probability: {score:.2f}")
else:
    # Webcam stream processing
    class AcciTransformer(VideoTransformerBase):
        """Overlay an accident probability on live frames.

        The probability is computed over a sliding window of the most recent
        ``NUM_FRAMES`` frames.

        NOTE(review): VideoTransformerBase / transform / video_transformer_factory
        are the older streamlit-webrtc API; newer releases document
        VideoProcessorBase.recv / video_processor_factory instead.
        """

        def __init__(self):
            # Bounded window: appending beyond NUM_FRAMES drops the oldest frame.
            self.buffer = deque(maxlen=NUM_FRAMES)

        def transform(self, frame):
            """Return the BGR frame, annotated once the window is full."""
            img = frame.to_ndarray(format="bgr24")
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            self.buffer.append(cv2.resize(rgb, (224, 224)))
            if len(self.buffer) == NUM_FRAMES:
                inputs = extractor(list(self.buffer), return_tensors="pt")
                pixel_values = inputs["pixel_values"].to(DEVICE)
                with torch.no_grad():
                    logits = model(pixel_values=pixel_values).logits
                prob = torch.softmax(logits, dim=1)[0, 1].item()
                # (0, 0, 255) is red because img is still in BGR order.
                cv2.putText(img, f"Prob: {prob:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            return img

    webrtc_streamer(
        key="dashcam-webcam",
        # The browser must send its webcam feed AND receive the annotated
        # frames back, so the mode is SENDRECV. The string "recv" is not a
        # WebRtcMode value.
        mode=WebRtcMode.SENDRECV,
        rtc_configuration=RTCConfiguration({
            "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]
        }),
        video_transformer_factory=AcciTransformer,
    )