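# Gradio demo: video action recognition using a ViT backbone for per-frame
# features and a BiLSTM head for clip-level classification.
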
import torch
import torch.nn as nn
import cv2
import os
import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
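
# Runtime configuration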
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "Vit_LSTM.pth"
NUM_CLASSES = 5
MAX_FRAMES = 16
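
# BiLSTM head: consumes a sequence of per-frame ViT features (768-dim) and
# classifies the clip from the final timestep's hidden state.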
class ViT_LSTM(nn.Module):
    def __init__(self, feature_dim=768, hidden_dim=512, num_classes=NUM_CLASSES):
        super(ViT_LSTM, self).__init__()
        self.lstm = nn.LSTM(feature_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]
        out = self.dropout(lstm_out)
        out = self.fc(out)
        return out
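
# Load the pretrained ViT backbone (used here only as a feature extractor)
# and the trained LSTM head weights.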
vit_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224").to(DEVICE)
model = ViT_LSTM()
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)
model.eval()
LABELS = ["BaseballPitch", "Basketball", "BenchPress", "Biking", "Billiards"]
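
# Read up to MAX_FRAMES frames from the video, convert BGR -> RGB, and
# mean-pool the ViT token embeddings of each frame into one 768-dim vector,
# yielding a (num_frames, 768) feature sequence.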
def extract_vit_features(video_path, max_frames=MAX_FRAMES):
    cap = cv2.VideoCapture(video_path)
    frames = []
    frame_count = 0
    while cap.isOpened() and frame_count < max_frames:
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
        frame_count += 1
    cap.release()

    if not frames:
        return None

    print(f"Extracted {len(frames)} frames from video.")
    inputs = vit_processor(images=frames, return_tensors="pt")["pixel_values"].to(DEVICE)
    with torch.no_grad():
        features = vit_model(inputs).last_hidden_state.mean(dim=1)
    return features
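
# Run the full pipeline on one uploaded video and return the predicted label.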
def predict_action(video_file):
    # gr.File may pass a path string (newer Gradio) or a temp-file object (older Gradio).
    video_path = video_file if isinstance(video_file, str) else video_file.name
    print(f"Received video path: {video_path}")
    features = extract_vit_features(video_path)
    if features is None:
        return "No frames extracted, please upload a valid video."
    features = features.unsqueeze(0)  # add batch dimension: (1, num_frames, 768)
    with torch.no_grad():
        output = model(features)
    predicted_class = torch.argmax(output, dim=1).item()
    return f"Predicted Action: {LABELS[predicted_class]}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# Action Recognition with ViT-LSTM")
    gr.Markdown("Upload a short video to predict the action.")
    video_input = gr.File(label="Upload a video")
    output_text = gr.Textbox(label="Prediction")
    predict_btn = gr.Button("Predict Action")
    predict_btn.click(fn=predict_action, inputs=video_input, outputs=output_text)

demo.launch()