import torch
import torch.nn as nn
import cv2
import os
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
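
# Configuration: inference device, path to the trained LSTM-head checkpoint,
# number of action classes, and maximum number of frames sampled per video.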
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "Vit_LSTM.pth"
NUM_CLASSES = 5
MAX_FRAMES = 16
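
# Classifier head: a 2-layer bidirectional LSTM over per-frame ViT features,
# followed by dropout and a linear layer over the action classes.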
class ViT_LSTM(nn.Module):
    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super(ViT_LSTM, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # x: (batch, num_frames, feature_dim) sequence of per-frame ViT features
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]  # last time step of the bidirectional LSTM
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out
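
# Pretrained ViT backbone used as a frame feature extractor, plus the
# trained ViT_LSTM head loaded from the local checkpoint.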
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
vit_model.eval()  # inference only; the ViT backbone is not fine-tuned here
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
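
# Read up to MAX_FRAMES frames from the uploaded video and encode each one with ViT.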
def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while cap.isOpened() and frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; convert to RGB for the ViT processor
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
        frame_count += 1
    cap.release()
    if not frames:
        return None
    print(f"Extracted {len(frames)} frames from video.")
    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
    with torch.no_grad():
        # Mean-pool the patch tokens of each frame -> (num_frames, 768)
        features = vit_model(inputs).last_hidden_state.mean(dim=1)
    return features
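
# Run the LSTM head over the extracted frame features and return the predicted label.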
def predict_action(video_file):
    if video_file is None:
        return "Please upload a video."
    # gr.File may pass a file-like object (with .name) or a plain path string
    video_path = video_file.name if hasattr(video_file, "name") else video_file
    print(f"Received video path: {video_path}")
    features = extract_vit_features(video_path)
    if features is None:
        return "No frames extracted, please upload a valid video."
    features = features.unsqueeze(0)  # add batch dimension: (1, num_frames, 768)
    with torch.no_grad():
        output = model(features)
    predicted_class = torch.argmax(output, dim=1).item()
    return f"Predicted Action: {LABELS[predicted_class]}"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")
    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")
    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)

demo.launch()