svsaurav95 committed on
Commit
e189a00
·
verified ·
1 Parent(s): cce6dbe

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import cv2
4
+ import os
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from transformers import ViTImageProcessor, ViTModel
8
+
9
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Location of the trained ViT_LSTM checkpoint. The default is the original
# developer-machine path; set the VIT_LSTM_MODEL_PATH environment variable to
# point at the checkpoint on any other machine without editing this file.
MODEL_PATH = os.environ.get(
    "VIT_LSTM_MODEL_PATH", r"C:\Users\Dell\Downloads\Vit_LSTM.pth"
)
# Number of output classes of the classifier head.
NUM_CLASSES = 5
# Upper bound on the number of frames read from each uploaded video.
MAX_FRAMES = 16
13
+
14
class ViT_LSTM(nn.Module):
    """Sequence classifier: a 2-layer bidirectional LSTM followed by a linear head.

    The LSTM is built with ``batch_first=True``, so ``forward`` expects input
    laid out as (batch, sequence, feature_dim). Only the output at the last
    sequence position is passed (through dropout) to the linear layer.
    """

    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward outputs are concatenated, hence hidden_dim * 2.
        self.fc = nn.Linear(in_features=hidden_dim * 2, out_features=num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return the class scores computed from the last sequence position."""
        sequence_out, _state = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        return self.fc(self.dropout(final_step))
27
+
28
# Pretrained ViT preprocessor and backbone used to turn frames into features.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")
vit_model = vit_model.to(DEVICE)

# NOTE(review): torch.load unpickles the file at MODEL_PATH; only point it at
# checkpoints from a trusted source.
_checkpoint_state = torch.load(MODEL_PATH, map_location=DEVICE)

# Classifier restored from the checkpoint and switched to inference mode.
model = ViT_LSTM()
model.load_state_dict(_checkpoint_state)
model.to(DEVICE)
model.eval()

# Human-readable names, indexed by the predicted class id.
LABELS = [
    "BaseballPitch",
    "Basketball",
    "BenchPress",
    "Biking",
    "Billiards",
]
36
+
37
def _read_frames(video_path, max_frames):
    """Read up to *max_frames* frames from the start of a video as RGB PIL images."""
    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened() and len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                # End of stream or a decode failure: keep what was read so far.
                break
            # OpenCV delivers BGR; convert to RGB before wrapping in a PIL image.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb))
    finally:
        # Release the capture handle even if a frame conversion raises.
        cap.release()
    return frames


def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    """Compute one ViT feature vector per frame for the start of a video.

    Args:
        video_path: Path of the video file handed to ``cv2.VideoCapture``.
        max_frames: Maximum number of leading frames to read.

    Returns:
        A tensor on ``DEVICE`` holding, for each frame read, the ViT's last
        hidden state averaged over dim 1; or ``None`` when no frame could be
        read from the video.
    """
    frames = _read_frames(video_path, max_frames)
    if not frames:
        return None

    print(f"Extracted {len(frames)} frames from video.")

    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)

    # Inference only: no gradients are needed for feature extraction.
    with torch.no_grad():
        features = vit_model(inputs).last_hidden_state.mean(dim=1)

    return features
63
+
64
def predict_action(video_file):
    """Classify the action in an uploaded video and return a display string.

    Args:
        video_file: The upload as delivered by the UI. May be ``None`` (nothing
            uploaded), a filesystem path (``str`` / ``os.PathLike``), or an
            object exposing the path as its ``name`` attribute.

    Returns:
        ``"Predicted Action: <label>"`` on success, or a human-readable message
        when no video was supplied or no frames could be read from it.
    """
    # Clicking the button without an upload passes None; report it instead of
    # crashing on attribute access.
    if video_file is None:
        return "No video uploaded, please upload a valid video."

    # Accept both a plain path and a tempfile-like wrapper with a ``name``.
    if isinstance(video_file, (str, os.PathLike)):
        video_path = os.fspath(video_file)
    else:
        video_path = video_file.name
    print(f"Received video path: {video_path}")

    features = extract_vit_features(video_path)

    if features is None:
        return "No frames extracted, please upload a valid video."

    # Add a leading batch dimension of size 1 for the classifier.
    features = features.unsqueeze(0)

    with torch.no_grad():
        output = model(features)
        predicted_class = torch.argmax(output, dim=1).item()

    return f"Predicted Action: {LABELS[predicted_class]}"
80
+
81
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")

    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")

    # Clicking the button sends the uploaded file to predict_action and writes
    # the returned string into the prediction textbox.
    predict_btn.click(predict_action, video_input, output_text)

# Start the web server for the interface defined above.
demo.launch()