import torch
import torch.nn as nn
import cv2
import os
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTModel

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "Vit_LSTM.pth"
NUM_CLASSES = 5
MAX_FRAMES = 16


class ViT_LSTM(nn.Module):
    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super(ViT_LSTM, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True,
                            num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]  # keep only the last time step
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out


# Pretrained ViT backbone used as the per-frame feature extractor.
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)

# Trained ViT-LSTM classification head.
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]


def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while cap.isOpened() and frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
        frame_count += 1
    cap.release()

    if not frames:
        return None

    print(f"Extracted {len(frames)} frames from video.")
    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
    with torch.no_grad():
        # Mean-pool the patch embeddings so each frame becomes one 768-d vector.
        features = vit_model(inputs).last_hidden_state.mean(dim=1)
    return features


def predict_action(video_file):
    # gr.File may hand back a tempfile wrapper (with .name) or a plain path
    # string, depending on the Gradio version.
    video_path = video_file.name if hasattr(video_file, "name") else video_file
    print(f"Received video path: {video_path}")
    features = extract_vit_features(video_path)
    if features is None:
        return "No frames extracted, please upload a valid video."
    features = features.unsqueeze(0)  # add batch dimension: (1, num_frames, 768)
    with torch.no_grad():
        output = model(features)
    predicted_class = torch.argmax(output, dim=1).item()
    return f"Predicted Action: {LABELS[predicted_class]}"


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")
    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")
    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)

demo.launch()