show top 5 result
app.py CHANGED
@@ -2,14 +2,16 @@ import spaces # Import spaces immediately for HF ZeroGPU support.
 import os
 import cv2
 import torch
-import
+import gradio as gr
 import numpy as np
+import matplotlib.pyplot as plt
+from io import BytesIO
 from PIL import Image
-
+
 from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
 
 # Specify the model checkpoint for TimeSformer.
-MODEL_NAME = "
+MODEL_NAME = "facebook/timesformer-base-finetuned-k400"
 
 def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
     """
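The body of `extract_frames` (the unchanged lines between this hunk and the next) is outside the diff. For context, here is a minimal sketch of how such a helper is commonly written, assuming uniform frame sampling with OpenCV and the signature shown above; it is an illustration, not the Space's actual implementation:

```python
import cv2
import numpy as np

def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """Uniformly sample `num_frames` RGB frames from a video and resize them."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return []
    # Pick evenly spaced frame indices across the whole clip.
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    frames = []
    for idx in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if not ok:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        frame = cv2.resize(frame, target_size)
        frames.append(frame)
    cap.release()
    return frames
```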
@@ -40,21 +42,23 @@ def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
 def classify_video(video_path):
     """
     Loads the TimeSformer model and feature extractor inside the GPU context,
-    extracts frames from the video, runs inference, and returns
+    extracts frames from the video, runs inference, and returns:
+      1. A text string of the top 5 predicted action labels with their class IDs and probabilities.
+      2. A bar chart image showing the distribution over the top predictions.
     """
     # Load the feature extractor and model inside the GPU context.
     feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
     model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
     model.eval()
-
+
     # Determine the device.
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
-    # Extract frames from the video
+    # Extract frames from the video.
     frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
     if len(frames) == 0:
-        return "No frames extracted from video."
+        return "No frames extracted from video.", None
 
     # Preprocess the frames.
     inputs = feature_extractor(frames, return_tensors="pt")
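The unchanged lines between this hunk and the next are not shown; they presumably move the preprocessed batch onto the model's device before inference, and on a ZeroGPU Space the function as a whole would normally sit behind the `@spaces.GPU` decorator (which the `import spaces` comment at the top of the file suggests). A small sketch of the device-move step, with `move_to_device` as a hypothetical helper name:

```python
import torch

def move_to_device(inputs, device):
    # The feature extractor returns a dict-like batch (e.g. {"pixel_values": tensor}),
    # so each tensor is moved onto the same device as the model before the forward pass.
    return {k: v.to(device) for k, v in inputs.items()}

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = {"pixel_values": torch.randn(1, 16, 3, 224, 224)}  # 16 frames of 3x224x224
    print({k: v.device for k, v in move_to_device(batch, device).items()})
```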
@@ -64,8 +68,8 @@ def classify_video(video_path):
     with torch.no_grad():
         outputs = model(**inputs)
 
-    #
-    logits = outputs.logits  # shape: [batch_size, num_classes] with batch_size=1
+    # Get logits and compute probabilities.
+    logits = outputs.logits  # shape: [batch_size, num_classes] with batch_size=1.
     probs = torch.nn.functional.softmax(logits, dim=-1)[0]
 
     # Get the top 5 predictions.
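The line that actually picks the top 5 sits just outside this hunk; it is presumably a `torch.topk` call over the softmaxed probabilities. A runnable sketch of that step, using a dummy logits tensor with the 400 Kinetics classes:

```python
import torch

# Softmax the (1, num_classes) logits, drop the batch dimension, then take the top 5.
probs = torch.nn.functional.softmax(torch.randn(1, 400), dim=-1)[0]
top_probs, top_indices = torch.topk(probs, k=5)
print(top_probs.cpu().numpy(), top_indices.cpu().numpy())
```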
@@ -73,31 +77,52 @@
     top_probs = top_probs.cpu().numpy()
     top_indices = top_indices.cpu().numpy()
 
-    # Retrieve the label mapping from
+    # Retrieve the label mapping from model config.
     id2label = model.config.id2label if hasattr(model.config, "id2label") else {}
+
+    # Prepare textual results showing both ID and label.
     results = []
+    x_labels = []
     for idx, prob in zip(top_indices, top_probs):
         label = id2label.get(str(idx), f"Class {idx}")
-        results.append(f"{label}: {prob:.3f}")
+        results.append(f"ID {idx} - {label}: {prob:.3f}")
+        x_labels.append(f"ID {idx}\n{label}")
+    results_text = "\n".join(results)
+
+    # Create a bar chart for the distribution.
+    fig, ax = plt.subplots(figsize=(8, 4))
+    ax.bar(x_labels, top_probs, color="skyblue")
+    ax.set_ylabel("Probability")
+    ax.set_title("Top 5 Prediction Distribution")
+    plt.xticks(rotation=45, ha="right")
+    plt.tight_layout()
+
+    buf = BytesIO()
+    plt.savefig(buf, format="png")
+    buf.seek(0)
+    plt.close(fig)
 
-    return
+    return results_text, buf
 
 def process_video(video_file):
     if video_file is None:
-        return "No video provided."
-
-    return
+        return "No video provided.", None
+    result_text, plot_img = classify_video(video_file)
+    return result_text, plot_img
 
 # Gradio interface definition.
 demo = gr.Interface(
     fn=process_video,
-    inputs=gr.Video(
-    outputs=
+    inputs=gr.Video(source="upload", label="Upload Video Clip"),
+    outputs=[
+        gr.Textbox(label="Predicted Actions"),
+        gr.Image(label="Prediction Distribution")
+    ],
     title="Video Human Detection Demo using TimeSformer",
     description=(
         "Upload a video clip to see the top predicted human action labels using the TimeSformer model "
-        "(fine-tuned on Kinetics-400).
-        "
+        "(fine-tuned on Kinetics-400). The output shows each prediction along with its class ID and probability, "
+        "and a bar chart displays the distribution of the top 5 predictions."
    )
 )
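Two robustness notes on the new code in this hunk, both hedged because they depend on library versions. First, `model.config.id2label` usually has integer keys in recent transformers releases, so the `id2label.get(str(idx), ...)` lookup may always fall through to the `Class {idx}` default; trying both key types is safer. Second, `gr.Image` reliably accepts PIL images, numpy arrays, or file paths, and handing it the raw `BytesIO` buffer may not render on every Gradio version. The helpers below (`lookup_label`, `figure_to_pil`) are hypothetical names used only for illustration:

```python
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image

def lookup_label(id2label, idx):
    # Try integer keys first (how transformers normally stores id2label),
    # then string keys, before falling back to a generic class name.
    return id2label.get(int(idx), id2label.get(str(idx), f"Class {int(idx)}"))

def figure_to_pil(fig):
    # Render the Matplotlib figure to an in-memory PNG and decode it with PIL,
    # which gr.Image accepts across Gradio versions.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
```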
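Two things the diff does not settle: the Gradio version and how the app is launched. `gr.Video(source="upload")` is Gradio 3.x syntax; on Gradio 4.x the parameter is `sources` and takes a list. The diff also ends before any `demo.launch()` call, which a Space's app.py normally needs. A standalone stub, assuming Gradio 4.x, with a `lambda` standing in for the real `process_video`:

```python
import gradio as gr

demo = gr.Interface(
    fn=lambda video: "No video provided." if video is None else f"Received: {video}",
    inputs=gr.Video(sources=["upload"], label="Upload Video Clip"),  # Gradio 4.x: `sources`, not `source`
    outputs=gr.Textbox(label="Predicted Actions"),
    title="Video Human Detection Demo using TimeSformer",
)

if __name__ == "__main__":
    demo.launch()
```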