import spaces # Import spaces immediately for HF ZeroGPU support.
import os
import cv2
import torch
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
# Specify the model checkpoint for TimeSformer.
MODEL_NAME = "facebook/timesformer-base-finetuned-k400"
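# Note: this checkpoint is configured for 8-frame clips; sampling 16 frames below
# assumes (not verified here) that the Transformers implementation resizes its
# temporal position embeddings to match the input length.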


def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """
    Extract up to `num_frames` uniformly sampled frames from the video.
    If the video has fewer frames, all frames are returned.
    """
    cap = cv2.VideoCapture(video_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total_frames <= 0:
        cap.release()
        return frames
    indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    current_frame = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, target_size)
            frames.append(Image.fromarray(frame))
        current_frame += 1
    cap.release()
    return frames
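

# On Hugging Face ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated
# function is running, which is why the model is loaded and moved to the device inside
# classify_video rather than at import time.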
@spaces.GPU
def classify_video(video_path):
    """
    Loads the TimeSformer model and feature extractor inside the GPU context,
    extracts frames from the video, runs inference, and returns:
      1. A text string of the top 5 predicted action labels with their class IDs and probabilities.
      2. A bar chart image showing the distribution over the top predictions.
    """
    # Load the feature extractor and model inside the GPU context.
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
    model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
    model.eval()
    # Determine the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Extract frames from the video.
    frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
    if len(frames) == 0:
        return "No frames extracted from video.", None
    # Preprocess the frames.
    inputs = feature_extractor(frames, return_tensors="pt")
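    # The list of PIL frames is treated as a single video, so `pixel_values` should
    # come out with shape [1, num_frames, 3, 224, 224].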
    inputs = {key: val.to(device) for key, val in inputs.items()}
    # Run inference.
    with torch.no_grad():
        outputs = model(**inputs)
    # Get logits and compute probabilities.
    logits = outputs.logits  # shape: [batch_size, num_classes] with batch_size=1.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    # Get the top 5 predictions.
    top_probs, top_indices = torch.topk(probs, k=5)
    top_probs = top_probs.cpu().numpy()
    top_indices = top_indices.cpu().numpy()
    # Retrieve the label mapping from the model config (keys are integer class IDs).
    id2label = model.config.id2label if hasattr(model.config, "id2label") else {}
    # Prepare textual results showing both ID and label.
    results = []
    x_labels = []
    for idx, prob in zip(top_indices, top_probs):
        label = id2label.get(int(idx), f"Class {idx}")
        results.append(f"ID {idx} - {label}: {prob:.3f}")
        x_labels.append(f"ID {idx}\n{label}")
    results_text = "\n".join(results)
    # Create a bar chart for the distribution.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x_labels, top_probs, color="skyblue")
    ax.set_ylabel("Probability")
    ax.set_title("Top 5 Prediction Distribution")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format="png")
    buf.seek(0)
    plt.close(fig)
    # Decode the PNG buffer into a PIL image so gr.Image can display it directly.
    chart = Image.open(buf).convert("RGB")
    return results_text, chart


def process_video(video_file):
    if video_file is None:
        return "No video provided.", None
    result_text, plot_img = classify_video(video_file)
    return result_text, plot_img
# Gradio interface definition.
demo = gr.Interface(
    fn=process_video,
    inputs=gr.Video(label="Upload Video Clip"),
    outputs=[
        gr.Textbox(label="Predicted Actions"),
        gr.Image(label="Prediction Distribution"),
    ],
title="Video Human Detection Demo using TimeSformer",
description=(
"Upload a video clip to see the top predicted human action labels using the TimeSformer model "
"(fine-tuned on Kinetics-400). The output shows each prediction along with its class ID and probability, "
"and a bar chart displays the distribution of the top 5 predictions."
)
)
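
# Quick local sanity check (hypothetical names: assumes this file is saved as app.py
# and a short clip named sample.mp4 sits next to it):
#     from app import classify_video
#     text, chart = classify_video("sample.mp4")
#     print(text)
# This exercises the model without starting the Gradio server.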

if __name__ == "__main__":
    demo.launch()