# Hugging Face Space: Running on Zero (ZeroGPU hardware)
import spaces
import cv2
import torch
import numpy as np
from PIL import Image
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
# Load the CLIP model and processor once at startup.
# Replace "openai/clip-vit-base-patch32" with your LightCLIP checkpoint if available.
MODEL_NAME = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(MODEL_NAME)
model = CLIPModel.from_pretrained(MODEL_NAME)

# Text prompts for the fall and non-fall classes.
fall_prompt = "A person falling on the ground."
nofall_prompt = "A person standing or walking."
def extract_frames(video_path, target_size=(224, 224)):
    """
    Extract all frames from the uploaded video and convert them to PIL Images.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Convert frame from BGR to RGB and resize.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames
def process_window(frames_window):
    """
    Process a window of frames and compute the average fall score.

    Scores are softmax probabilities over the two prompts, so they lie in
    [0, 1] and can be compared directly against the detection threshold.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = processor(
        text=[fall_prompt, nofall_prompt],
        images=frames_window,
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
        # logits_per_image has shape (num_frames, 2); softmax yields per-frame class probabilities.
        probs = outputs.logits_per_image.softmax(dim=-1).cpu().numpy()
    # Index 0 corresponds to the fall prompt.
    fall_scores = probs[:, 0]
    window_score = float(np.mean(fall_scores))
    return window_score, fall_scores
def detect_fall(video_path, window_size=16, stride=8, threshold=0.8, fps=15):
    """
    Process the video file using a sliding window over frames.
    Returns a summary of detected fall timestamps and a representative frame.
    """
    frames = extract_frames(video_path)
    if len(frames) < window_size:
        return "Video too short for inference.", None
    window_scores = []
    window_indices = []
    for start in range(0, len(frames) - window_size + 1, stride):
        window = frames[start:start + window_size]
        score, _ = process_window(window)
        window_scores.append(score)
        window_indices.append(start)
    detected_events = []
    for idx, score in zip(window_indices, window_scores):
        if score > threshold:
            time_sec = idx / fps  # approximate timestamp, assuming a fixed frame rate
            detected_events.append(time_sec)
    if detected_events:
        result_text = "Fall events detected at (sec): " + ", ".join(f"{t:.1f}" for t in detected_events)
    else:
        result_text = "No fall detected."
    # Return the result and a representative frame for visual reference.
    rep_frame = frames[len(frames) // 2]
    return result_text, rep_frame
@spaces.GPU
def process_video(video_file):
    # Request ZeroGPU hardware for the duration of this call (this is why `import spaces` is needed).
    result_text, rep_frame = detect_fall(video_file)
    return result_text, rep_frame
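# Quick local sanity check (a sketch, not part of the Space itself): the file name below
# is a placeholder, assuming a short MP4 clip on disk.
#   text, frame = detect_fall("sample_clip.mp4")
#   print(text)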
# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[gr.Textbox(label="Detection Results"), gr.Image(label="Representative Frame")],
    title="LightCLIP Fall Detection Demo",
    description=(
        "This demo detects human falls in video clips using a lightweight transformer-based model (LightCLIP). "
        "A sliding window approach aggregates results over multiple frames to improve precision in complex scenes."
    ),
)
if __name__ == "__main__":
    demo.launch()
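# Note: a Space like this would typically ship with a requirements.txt; a minimal sketch
# (package names are assumptions, pin versions as needed):
#   spaces
#   gradio
#   transformers
#   torch
#   opencv-python-headless
#   numpy
#   Pillow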