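# Gradio demo: video action recognition.
# A frozen ViT-base backbone extracts one feature vector per frame; a
# bidirectional LSTM head (weights loaded from Vit_LSTM.pth) classifies
# the frame sequence into one of five action labels.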
import torch
import torch.nn as nn
import cv2
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTModel

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "Vit_LSTM.pth"   # trained LSTM classifier checkpoint
NUM_CLASSES = 5               # number of action labels
MAX_FRAMES = 16               # frames read from the start of each video

class ViT_LSTM(nn.Module):
    """Bidirectional LSTM classifier over a sequence of per-frame ViT features."""

    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super(ViT_LSTM, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # *2 for the two LSTM directions
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: (batch, num_frames, feature_dim)
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]  # keep only the last timestep's output
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out

# Frozen ViT backbone, used only as a frame-level feature extractor.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
vit_model.eval()

# Trained LSTM head on top of the ViT features.
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]

def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    """Read up to `max_frames` frames and return one ViT feature vector per frame."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0

    while cap.isOpened() and frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; ViT expects RGB
        frames.append(Image.fromarray(frame))
        frame_count += 1

    cap.release()

    if not frames:
        return None

    print(f"Extracted {len(frames)} frames from video.")

    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)

    with torch.no_grad():
        # Mean-pool the patch tokens to get one 768-d vector per frame.
        features = vit_model(inputs).last_hidden_state.mean(dim=1)

    return features  # shape: (num_frames, 768)

def predict_action(video_file):
    # gr.File may pass a file path string (newer Gradio) or a tempfile-like
    # object with a .name attribute (older Gradio); handle both.
    video_path = video_file if isinstance(video_file, str) else video_file.name
    print(f"Received video path: {video_path}")

    features = extract_vit_features(video_path)

    if features is None:
        return "No frames extracted, please upload a valid video."

    features = features.unsqueeze(0)  # add batch dimension: (1, num_frames, 768)

    with torch.no_grad():
        output = model(features)
        predicted_class = torch.argmax(output, dim=1).item()

    return f"Predicted Action: {LABELS[predicted_class]}"

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")

    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")

    predict_btn = gr.Button("Predict Action")
    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)

demo.launch()