Luigi committed on
Commit
9651249
·
1 Parent(s): 413c57e

show top 5 results

Files changed (1): app.py (+44 -19)
app.py CHANGED
@@ -2,14 +2,16 @@ import spaces # Import spaces immediately for HF ZeroGPU support.
 import os
 import cv2
 import torch
-import yt_dlp # (Retained in requirements for potential video fetching use)
+import gradio as gr
 import numpy as np
+import matplotlib.pyplot as plt
+from io import BytesIO
 from PIL import Image
+
 from transformers import AutoFeatureExtractor, AutoModelForVideoClassification
 
 # Specify the model checkpoint for TimeSformer.
-MODEL_NAME = "facebook/timesformer-base-finetuned-k400"
+MODEL_NAME = "microsoft/timesformer-base-finetuned-k400"
 
 def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
     """
@@ -40,21 +42,23 @@ def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
 def classify_video(video_path):
     """
     Loads the TimeSformer model and feature extractor inside the GPU context,
-    extracts frames from the video, runs inference, and returns the top 5 predicted actions.
+    extracts frames from the video, runs inference, and returns:
+      1. A text string of the top 5 predicted action labels with their class IDs and probabilities.
+      2. A bar chart image showing the distribution over the top predictions.
     """
     # Load the feature extractor and model inside the GPU context.
     feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
     model = AutoModelForVideoClassification.from_pretrained(MODEL_NAME)
     model.eval()
-
+
     # Determine the device.
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
 
-    # Extract frames from the video (sampling 16 frames).
+    # Extract frames from the video.
     frames = extract_frames(video_path, num_frames=16, target_size=(224, 224))
     if len(frames) == 0:
-        return "No frames extracted from video."
+        return "No frames extracted from video.", None
 
     # Preprocess the frames.
     inputs = feature_extractor(frames, return_tensors="pt")
@@ -64,8 +68,8 @@ def classify_video(video_path):
     with torch.no_grad():
         outputs = model(**inputs)
 
-    # Compute softmax probabilities from logits.
-    logits = outputs.logits # shape: [batch_size, num_classes] with batch_size=1
+    # Get logits and compute probabilities.
+    logits = outputs.logits # shape: [batch_size, num_classes] with batch_size=1.
     probs = torch.nn.functional.softmax(logits, dim=-1)[0]
 
     # Get the top 5 predictions.
@@ -73,31 +77,52 @@ def classify_video(video_path):
     top_probs = top_probs.cpu().numpy()
     top_indices = top_indices.cpu().numpy()
 
-    # Retrieve the label mapping from the model config.
+    # Retrieve the label mapping from model config.
     id2label = model.config.id2label if hasattr(model.config, "id2label") else {}
+
+    # Prepare textual results showing both ID and label.
     results = []
+    x_labels = []
     for idx, prob in zip(top_indices, top_probs):
         label = id2label.get(str(idx), f"Class {idx}")
-        results.append(f"{label}: {prob:.3f}")
+        results.append(f"ID {idx} - {label}: {prob:.3f}")
+        x_labels.append(f"ID {idx}\n{label}")
+    results_text = "\n".join(results)
+
+    # Create a bar chart for the distribution.
+    fig, ax = plt.subplots(figsize=(8, 4))
+    ax.bar(x_labels, top_probs, color="skyblue")
+    ax.set_ylabel("Probability")
+    ax.set_title("Top 5 Prediction Distribution")
+    plt.xticks(rotation=45, ha="right")
+    plt.tight_layout()
+
+    buf = BytesIO()
+    plt.savefig(buf, format="png")
+    buf.seek(0)
+    plt.close(fig)
 
-    return "\n".join(results)
+    return results_text, buf
 
 def process_video(video_file):
     if video_file is None:
-        return "No video provided."
-    result = classify_video(video_file)
-    return result
+        return "No video provided.", None
+    result_text, plot_img = classify_video(video_file)
+    return result_text, plot_img
 
 # Gradio interface definition.
 demo = gr.Interface(
     fn=process_video,
-    inputs=gr.Video(sources=["upload"], label="Upload Video Clip"),
-    outputs=gr.Textbox(label="Predicted Actions"),
+    inputs=gr.Video(source="upload", label="Upload Video Clip"),
+    outputs=[
+        gr.Textbox(label="Predicted Actions"),
+        gr.Image(label="Prediction Distribution")
+    ],
     title="Video Human Detection Demo using TimeSformer",
     description=(
         "Upload a video clip to see the top predicted human action labels using the TimeSformer model "
-        "(fine-tuned on Kinetics-400). This demo loads the model and feature extractor within the GPU context "
-        "for optimized inference in Hugging Face ZeroGPU Spaces while also supporting CPU-only environments."
+        "(fine-tuned on Kinetics-400). The output shows each prediction along with its class ID and probability, "
+        "and a bar chart displays the distribution of the top 5 predictions."
     )
 )
 
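For context, extract_frames is called throughout this diff but its body is unchanged and therefore not shown above. A minimal sketch of what such an OpenCV-based helper typically looks like (an illustrative assumption, not the verbatim code in app.py):

import cv2
import numpy as np
from PIL import Image

def extract_frames(video_path, num_frames=16, target_size=(224, 224)):
    """Uniformly sample num_frames RGB frames from the video and resize them."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total <= 0:
        cap.release()
        return []
    indices = np.linspace(0, total - 1, num_frames, dtype=int)
    frames = []
    for i in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
        ok, frame = cap.read()
        if not ok:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes BGR; the model expects RGB
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(frame))
    cap.release()
    return frames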
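Likewise, the line that produces top_probs and top_indices sits just outside the shown hunk context; it is presumably a torch.topk call along these lines (a sketch, not the verbatim line from app.py):

# Top 5 probabilities and their class indices from the softmax output.
top_probs, top_indices = torch.topk(probs, k=5)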
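One caveat: whether gr.Image accepts the raw BytesIO buffer returned by classify_video depends on the Gradio version pinned in the Space. If it does not, a small adapter (a sketch assuming PIL, not part of this commit) converts the saved PNG buffer into a PIL image, which gr.Image handles natively:

from io import BytesIO
from PIL import Image

def buffer_to_pil(buf: BytesIO) -> Image.Image:
    # Decode the PNG bytes written by plt.savefig into a PIL image.
    buf.seek(0)
    return Image.open(buf).convert("RGB")

# Usage inside classify_video:
#     return results_text, buffer_to_pil(buf)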