jarif committed on
Commit bcae046 · verified · 1 Parent(s): 7d8befe

Update app.py

Files changed (1)
  1. app.py +86 -102
app.py CHANGED
@@ -1,102 +1,86 @@
- import gradio as gr
- import torch
- from pytorchvideo.data.encoded_video import EncodedVideo
- from pytorchvideo.transforms import UniformTemporalSubsample
- from transformers import VideoMAEForVideoClassification
- import torch.nn.functional as F
- import torchvision.transforms.functional as F_t  # Changed import
-
- # Check device
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- # Load pre-trained model
- model_path = "model"
- loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path)
- loaded_model = loaded_model.to(device)
- loaded_model.eval()
-
- # Label names for prediction
- label_names = [
-     'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
-     'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball'
- ]
-
- def load_video(video_path):
-     try:
-         video = EncodedVideo.from_path(video_path)
-         video_data = video.get_clip(start_sec=0, end_sec=video.duration)
-         return video_data['video']
-     except Exception as e:
-         raise ValueError(f"Error loading video: {str(e)}")
-
- def preprocess_video(video_frames):
-     try:
-         # Temporal subsampling
-         transform_temporal = UniformTemporalSubsample(16)
-         video_frames = transform_temporal(video_frames)
-
-         # Convert to float and normalize to [0, 1]
-         video_frames = video_frames.float() / 255.0
-
-         # Ensure channel order is [T, C, H, W]
-         if video_frames.shape[0] == 3:
-             video_frames = video_frames.permute(1, 0, 2, 3)
-
-         # Normalize using torchvision's functional transform
-         mean = torch.tensor([0.485, 0.456, 0.406])
-         std = torch.tensor([0.229, 0.224, 0.225])
-
-         for t in range(video_frames.shape[0]):
-             video_frames[t] = F_t.normalize(video_frames[t], mean, std)
-
-         # Resize frames
-         video_frames = torch.stack([
-             F_t.resize(frame, [224, 224], antialias=True)
-             for frame in video_frames
-         ])
-
-         # Add batch dimension
-         video_frames = video_frames.unsqueeze(0)
-         return video_frames
-     except Exception as e:
-         raise ValueError(f"Error preprocessing video: {str(e)}")
-
- def predict_video(video):
-     try:
-         # Load and preprocess video
-         video_data = load_video(video)
-         processed_video = preprocess_video(video_data)
-         processed_video = processed_video.to(device)
-
-         # Make predictions
-         with torch.no_grad():
-             outputs = loaded_model(processed_video)
-             logits = outputs.logits
-             probabilities = F.softmax(logits, dim=-1)[0]
-             top_3 = torch.topk(probabilities, 3)
-
-         # Format results
-         results = [
-             f"{label_names[idx.item()]}: {prob.item():.2%}"
-             for idx, prob in zip(top_3.indices, top_3.values)
-         ]
-         return "\n".join(results)
-     except Exception as e:
-         return f"Error processing video: {str(e)}"
-
- # Gradio interface
- iface = gr.Interface(
-     fn=predict_video,
-     inputs=gr.Video(label="Upload Video"),
-     outputs=gr.Textbox(label="Top 3 Predictions"),
-     title="Video Action Recognition",
-     description="Upload a video to classify the action being performed. The model will return the top 3 predictions.",
-     examples=[
-         ["test_video_1.avi"],
-         ["test_video_2.avi"],
-         ["test_video_3.avi"]
-     ]
- )
-
- if __name__ == "__main__":
-     iface.launch(debug=True, share=True)
+ import gradio as gr
+ import torch
+ from pytorchvideo.data.encoded_video import EncodedVideo
+ from torchvision.transforms import Resize
+ from pytorchvideo.transforms import UniformTemporalSubsample
+ from transformers import VideoMAEForVideoClassification
+
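+ # Check device (use GPU if available)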
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
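+ # Load pre-trained model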
+ model_path = "model"
+ loaded_model = VideoMAEForVideoClassification.from_pretrained(model_path)
+ loaded_model = loaded_model.to(device)
+ loaded_model.eval()
+
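+ # Label names for prediction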
+ label_names = [
+     'Archery', 'BalanceBeam', 'BenchPress', 'ApplyEyeMakeup', 'BasketballDunk',
+     'BandMarching', 'BabyCrawling', 'ApplyLipstick', 'BaseballPitch', 'Basketball'
+ ]
+
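+ # Decode the full clip; EncodedVideo.get_clip returns the video as a [C, T, H, W] tensor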
+ def load_video(video_path):
+     try:
+         video = EncodedVideo.from_path(video_path)
+         video_data = video.get_clip(start_sec=0, end_sec=video.duration)
+         return video_data['video']
+     except Exception as e:
+         raise ValueError(f"Error loading video: {str(e)}")
+
+ def preprocess_video(video_frames):
+     try:
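+         # Uniformly subsample 16 frames along the temporal axis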
+         transform_temporal = UniformTemporalSubsample(16)
+         video_frames = transform_temporal(video_frames)
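+         # Scale pixel values to [0, 1]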
+         video_frames = video_frames / 255.0
+
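+         # Ensure channel order is [T, C, H, W]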
+         if video_frames.shape[0] == 3:
+             video_frames = video_frames.permute(1, 0, 2, 3)
+
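+         # Normalize each frame with ImageNet mean/std via broadcasting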
+         mean = torch.tensor([0.485, 0.456, 0.406])
+         std = torch.tensor([0.229, 0.224, 0.225])
+         for t in range(video_frames.shape[0]):
+             video_frames[t] = (video_frames[t] - mean[:, None, None]) / std[:, None, None]
+
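+         # Resize all frames to 224x224 and add a batch dimension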
+         resize_transform = Resize((224, 224))
+         video_frames = resize_transform(video_frames)
+         video_frames = video_frames.unsqueeze(0)
+
+         return video_frames
+     except Exception as e:
+         raise ValueError(f"Error preprocessing video: {str(e)}")
+
+ def predict_video(video):
+     try:
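+         # Gradio may pass the upload as a file path string or a file object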
+         video_path = video if isinstance(video, str) else video.name
+         video_data = load_video(video_path)
+         processed_video = preprocess_video(video_data)
+         processed_video = processed_video.to(device)
+
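+         # Make predictions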
+         with torch.no_grad():
+             outputs = loaded_model(processed_video)
+             logits = outputs.logits
+             probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
+             top_3 = torch.topk(probabilities, 3)
+
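+         # Format results as "label: probability%" lines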
+         results = []
+         for i in range(3):
+             idx = top_3.indices[i].item()
+             prob = top_3.values[i].item()
+             results.append(f"{label_names[idx]}: {prob*100:.2f}%")
+
+         return "\n".join(results)
+     except Exception as e:
+         return f"Error processing video: {str(e)}"
+
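+ # Gradio interface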
+ iface = gr.Interface(
+     fn=predict_video,
+     inputs=gr.Video(label="Upload Video"),
+     outputs=gr.Textbox(label="Top 3 Predictions"),
+     title="Video Action Recognition",
+     description="Upload a video to classify the action being performed. The model will return the top 3 predictions with their probabilities.",
+     examples=[
+         ["test_video_1.mp4"],
+         ["test_video_2.mp4"],
+         ["test_video_3.mp4"]
+     ]
+ )
+
+ iface.launch(debug=True, share=True)