svsaurav95 committed
Commit ed4a2b8 · verified · 1 Parent(s): e189a00

Update app.py

Files changed (1)
  1. app.py +92 -92
app.py CHANGED
@@ -1,92 +1,92 @@
-import torch
-import torch.nn as nn
-import cv2
-import os
-import gradio as gr
-from PIL import Image
-from transformers import ViTImageProcessor, ViTModel
-
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-MODEL_PATH = r"C:\Users\Dell\Downloads\Vit_LSTM.pth"
-NUM_CLASSES = 5
-MAX_FRAMES = 16
-
-class ViT_LSTM(nn.Module):
-    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
-        super(ViT_LSTM, self).__init__()
-        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
-        self.fc = nn.Linear(hidden_dim * 2, num_classes)
-        self.dropout = nn.Dropout(0.3)
-
-    def forward(self, x):
-        lstm_out, _ = self.lstm(x)
-        lstm_out = lstm_out[:, -1, :]
-        out = self.dropout(lstm_out)
-        out = self.fc(out)
-        return out
-
-vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
-vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
-
-model = ViT_LSTM()
-model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
-model.to(DEVICE)
-model.eval()
-LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
-
-def extract_vit_features(video_path, max_frames=MAX_FRAMES):
-    cap = cv2.VideoCapture(video_path)
-    frames = []
-    frame_count = 0
-
-    while cap.isOpened() and frame_count < max_frames:
-        ret, frame = cap.read()
-        if not ret:
-            break
-        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        frames.append(Image.fromarray(frame))
-        frame_count += 1
-
-    cap.release()
-
-    if not frames:
-        return None
-
-    print(f"Extracted {len(frames)} frames from video.")
-
-    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
-
-    with torch.no_grad():
-        features = vit_model(inputs).last_hidden_state.mean(dim=1)
-
-    return features
-
-def predict_action(video_file):
-    video_path = video_file.name
-    print(f"Received video path: {video_path}")
-
-    features = extract_vit_features(video_path)
-
-    if features is None:
-        return "No frames extracted, please upload a valid video."
-
-    features = features.unsqueeze(0)
-
-    with torch.no_grad():
-        output = model(features)
-        predicted_class = torch.argmax(output, dim=1).item()
-
-    return f"Predicted Action: {LABELS[predicted_class]}"
-
-# Gradio Interface
-with gr.Blocks() as demo:
-    gr.Markdown("# Action Recognition with ViT-LSTM")
-    gr.Markdown("Upload a short video to predict the action.")
-
-    video_input = gr.File(label="Upload a video")
-    output_text = gr.Textbox(label="Prediction")
-
-    predict_btn = gr.Button("Predict Action")
-    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)
-
-demo.launch()
+import torch
+import torch.nn as nn
+import cv2
+import os
+import gradio as gr
+from PIL import Image
+from transformers import ViTImageProcessor, ViTModel
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+MODEL_PATH = "Vit_LSTM.pth"
+NUM_CLASSES = 5
+MAX_FRAMES = 16
+
+class ViT_LSTM(nn.Module):
+    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
+        super(ViT_LSTM, self).__init__()
+        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
+        self.fc = nn.Linear(hidden_dim * 2, num_classes)
+        self.dropout = nn.Dropout(0.3)
+
+    def forward(self, x):
+        lstm_out, _ = self.lstm(x)
+        lstm_out = lstm_out[:, -1, :]
+        out = self.dropout(lstm_out)
+        out = self.fc(out)
+        return out
+
+vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
+vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
+
+model = ViT_LSTM()
+model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
+model.to(DEVICE)
+model.eval()
+LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
+
+def extract_vit_features(video_path, max_frames=MAX_FRAMES):
+    cap = cv2.VideoCapture(video_path)
+    frames = []
+    frame_count = 0
+
+    while cap.isOpened() and frame_count < max_frames:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        frames.append(Image.fromarray(frame))
+        frame_count += 1
+
+    cap.release()
+
+    if not frames:
+        return None
+
+    print(f"Extracted {len(frames)} frames from video.")
+
+    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
+
+    with torch.no_grad():
+        features = vit_model(inputs).last_hidden_state.mean(dim=1)
+
+    return features
+
+def predict_action(video_file):
+    video_path = video_file.name
+    print(f"Received video path: {video_path}")
+
+    features = extract_vit_features(video_path)
+
+    if features is None:
+        return "No frames extracted, please upload a valid video."
+
+    features = features.unsqueeze(0)
+
+    with torch.no_grad():
+        output = model(features)
+        predicted_class = torch.argmax(output, dim=1).item()
+
+    return f"Predicted Action: {LABELS[predicted_class]}"
+
+# Gradio Interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Action Recognition with ViT-LSTM")
+    gr.Markdown("Upload a short video to predict the action.")
+
+    video_input = gr.File(label="Upload a video")
+    output_text = gr.Textbox(label="Prediction")
+
+    predict_btn = gr.Button("Predict Action")
+    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)
+
+demo.launch()
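
Net of the whole-file churn (+92 -92), the visible change in this commit is `MODEL_PATH`: the hardcoded absolute Windows path `C:\Users\Dell\Downloads\Vit_LSTM.pth` becomes the relative `"Vit_LSTM.pth"`, so the checkpoint is loaded from the repository alongside `app.py`. A relative path still resolves against the process working directory at launch; one way to pin it to the app directory instead is to resolve against the script's own location. A minimal sketch using the already-imported `os` (an alternative, not part of the commit):

```python
import os

# Hypothetical alternative to the committed relative path: resolve the
# checkpoint next to app.py itself, independent of the working directory.
MODEL_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "Vit_LSTM.pth")
```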
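For readers checking the tensor shapes the inference path relies on: `extract_vit_features` returns one mean-pooled 768-dim ViT feature per frame, `unsqueeze(0)` adds the batch dimension, and the bidirectional LSTM doubles the hidden size before the classifier. A self-contained sketch with a dummy feature tensor standing in for the ViT output (the dummy shapes are an assumption for illustration; the layer sizes mirror the committed `ViT_LSTM` defaults):

```python
import torch
import torch.nn as nn

feats = torch.randn(16, 768)   # dummy stand-in: 16 frames x 768-dim ViT features
feats = feats.unsqueeze(0)     # (1, 16, 768): one clip, batch_first layout

lstm = nn.LSTM(768, 512, batch_first=True, num_layers=2, bidirectional=True)
fc = nn.Linear(512 * 2, 5)     # bidirectional doubles the hidden size

out, _ = lstm(feats)           # (1, 16, 1024)
logits = fc(out[:, -1, :])     # last timestep -> (1, 5) class logits
print(logits.shape)            # torch.Size([1, 5])
```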
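One caveat untouched by the commit: `predict_action` reads `video_file.name`, which assumes `gr.File` delivers a tempfile-like object, as older Gradio releases do; newer releases pass a filepath string by default, so relying on `.name` is version-dependent. A defensive sketch covering both cases (a hypothetical helper, not in the committed code):

```python
def resolve_video_path(video_file):
    # gr.File may hand the handler a str filepath (newer Gradio) or an
    # object exposing .name (older tempfile-wrapper behaviour).
    return video_file if isinstance(video_file, str) else video_file.name
```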