import spaces  # Import spaces first for Hugging Face ZeroGPU support.
import os
import cv2
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
import csv
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification

# Specify the model checkpoint for TimeSformer.
MODEL_NAME = "facebook/timesformer-base-finetuned-k400"
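# Note: this checkpoint classifies short clips into the 400 Kinetics action
# classes. The demo below samples frames from the uploaded video, resizes them
# to 224x224, and feeds them to the model as a single clip (batch size 1).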

def load_kinetics_labels(csv_path="kinetics-400-class-names.csv"):
    """
    Load the Kinetics-400 labels from a CSV file.

    Expected CSV format:
        id,name
        0,abseiling
        1,air drumming
        ...
        399,zumba

    Returns a dictionary mapping string IDs to label names.
    """
    labels = {}
    try:
        with open(csv_path, "r", encoding="utf-8") as f:
            reader = csv.reader(f)
            # Skip the header row if one is present.
            header = next(reader)
            if "id" not in header[0].lower():
                # No header: rewind and re-read from the start.
                f.seek(0)
                reader = csv.reader(f)
            for row in reader:
                if len(row) >= 2:
                    labels[row[0].strip()] = row[1].strip()
    except Exception as e:
        print("Error reading CSV mapping:", e)
    return labels
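
# Optional alternative (illustrative sketch, not used by the demo): the same
# mapping can be rebuilt from the checkpoint's own config, assuming the
# checkpoint ships a complete id2label dictionary. `load_labels_from_config`
# is a helper name introduced here only for illustration.
def load_labels_from_config(model_name=MODEL_NAME):
    """Rebuild a {str(id): label} mapping from the model config instead of the CSV."""
    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(model_name)
    # Config keys are ints; normalize to strings to match load_kinetics_labels().
    return {str(idx): name for idx, name in config.id2label.items()}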

def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """
    Extract up to `num_frames` uniformly sampled frames from the video.
    If the video has fewer frames, all available frames are returned.
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames <= 0:
        cap.release()
        return frames
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    current_frame = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, target_size)
            frames.append(Image.fromarray(frame))
        current_frame += 1
    cap.release()
    return frames
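
# Illustrative helper (hypothetical, not called anywhere in this script): for
# very short clips the uniformly spaced indices above repeat, so extract_frames()
# can return fewer than `num_frames` frames. If a fixed clip length were
# required, one simple sketch is to pad by repeating the last frame.
def pad_frames(frames, num_frames=16):
    """Repeat the final frame until the list holds exactly `num_frames` frames."""
    frames = list(frames)
    while frames and len(frames) < num_frames:
        frames.append(frames[-1])
    return frames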

@spaces.GPU  # Run this call inside the ZeroGPU context, as the docstring describes.
def classify_video(video_path):
    """
    Loads the TimeSformer model and feature extractor inside the GPU context,
    extracts frames from the video, runs inference, and returns:
      1. A text string of the top 5 predicted actions (with class ID and
         descriptive label) along with their probabilities.
      2. A bar chart (as a PIL Image) showing the prediction distribution.
    """
    # Load the feature extractor and model.
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
    model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
    model.eval()

    # Load the complete Kinetics-400 mapping from CSV.
    kinetics_id2label = load_kinetics_labels("kinetics-400-class-names.csv")
    if kinetics_id2label:
        print("Loaded complete Kinetics-400 mapping from CSV.")
        model.config.id2label = kinetics_id2label
    else:
        print("Warning: Could not load Kinetics-400 mapping; using default labels.")

    # Determine the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Extract frames from the video.
    frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
    if len(frames) == 0:
        return "No frames extracted from video.", None

    # Preprocess the frames.
    inputs = feature_extractor(frames, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Run inference.
    with torch.no_grad():
        outputs = model(**inputs)

    # Get logits and compute probabilities.
    logits = outputs.logits  # shape: [batch_size, num_classes] with batch_size=1.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Get the top 5 predictions.
    top_probs, top_indices = torch.topk(probs, k=5)
    top_probs = top_probs.cpu().numpy()
    top_indices = top_indices.cpu().numpy()

    # Prepare textual results including both ID and label.
    results = []
    x_labels = []
    for idx, prob in zip(top_indices, top_probs):
        label = kinetics_id2label.get(str(idx), f"Class {idx}")
        results.append(f"ID {idx} - {label}: {prob:.3f}")
        x_labels.append(f"ID {idx}\n{label}")
    results_text = "\n".join(results)

    # Create a bar chart for the distribution.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x_labels, top_probs, color="skyblue")
    ax.set_ylabel("Probability")
    ax.set_title("Top 5 Prediction Distribution")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)

    # Convert the BytesIO plot to a PIL Image.
    chart_image = Image.open(buf)
    return results_text, chart_image
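
# ZeroGPU note (assumption based on the `spaces` import above): on Hugging Face
# ZeroGPU hardware a GPU is only attached while a `spaces.GPU`-decorated call is
# running, which is why the model is loaded and moved to the device inside
# classify_video() rather than once at module import time.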

def process_video(video_file):
    if video_file is None:
        return "No video provided.", None
    result_text, plot_img = classify_video(video_file)
    return result_text, plot_img

# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(sources=["upload"], label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="Predicted Actions"),
        gr.Image(label="Prediction Distribution"),
    ],
    title="Video Human Action Recognition Demo using TimeSformer",
    description=(
        "Upload a video clip to see the top predicted human action labels from the TimeSformer model "
        "(fine-tuned on Kinetics-400). The output shows each prediction's class ID and label, along with "
        "a bar chart of the top 5 predictions. The complete Kinetics-400 mapping is loaded from a CSV file."
    ),
)

if __name__ == "__main__":
    demo.launch()
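
# Local usage note: running `python app.py` serves the demo on Gradio's default
# local address (http://127.0.0.1:7860); passing `share=True` to demo.launch()
# additionally creates a temporary public link.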