import torch
import torch.nn as nn
import cv2
import os
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTModel

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "Vit_LSTM.pth"  # trained ViT-LSTM classifier weights
NUM_CLASSES = 5
MAX_FRAMES = 16  # frames read from the start of each clip
class ViT_LSTM(nn.Module):
    """Bidirectional LSTM classifier over per-frame ViT features."""

    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super(ViT_LSTM, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: (batch, num_frames, feature_dim)
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]  # keep the last time step
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out
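# Shape sanity check (illustrative only, not part of the app):
#   ViT_LSTM()(torch.randn(2, MAX_FRAMES, 768)).shape == torch.Size([2, NUM_CLASSES])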
# Pretrained ViT backbone used as a per-frame feature extractor
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)

# Trained ViT-LSTM classification head
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()

LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    """Read up to max_frames frames and return their ViT embeddings, shape (num_frames, 768)."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while cap.isOpened() and frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        frames.append(Image.fromarray(frame))
        frame_count += 1
    cap.release()
    if not frames:
        return None
    print(f"Extracted {len(frames)} frames from video.")
    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
    with torch.no_grad():
        # Mean-pool the ViT token embeddings to get one 768-dim vector per frame
        features = vit_model(inputs).last_hidden_state.mean(dim=1)
    return features
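# The extractor above keeps only the first MAX_FRAMES frames. If coverage of the whole
# clip mattered, frames could instead be sampled uniformly; the helper below is an
# illustrative sketch (the name sample_frames_evenly is hypothetical, not used by the app).
def sample_frames_evenly(video_path, num_frames=MAX_FRAMES):
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Pick num_frames roughly evenly spaced frame indices across the clip
    wanted = {int(i * total / num_frames) for i in range(num_frames)} if total > 0 else set()
    frames, idx = [], 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        if idx in wanted:
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        idx += 1
    cap.release()
    return frames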
def predict_action(video_file):
    # gr.File may pass a tempfile-like object (older Gradio) or a plain path string
    video_path = video_file if isinstance(video_file, str) else video_file.name
    print(f"Received video path: {video_path}")
    features = extract_vit_features(video_path)
    if features is None:
        return "No frames extracted, please upload a valid video."
    features = features.unsqueeze(0)  # add batch dimension: (1, num_frames, 768)
    with torch.no_grad():
        output = model(features)
    predicted_class = torch.argmax(output, dim=1).item()
    return f"Predicted Action: {LABELS[predicted_class]}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")
    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")
    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)

demo.launch()
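# To run locally (assuming torch, transformers, opencv-python, gradio, and Pillow are
# installed and Vit_LSTM.pth sits in the working directory), save this script as the
# Space entry point (conventionally app.py) and start it with:
#   python app.py
# Gradio serves the demo on http://127.0.0.1:7860 by default.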