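"""Streamlit demo: classify cricket shots from a sample video.

Frames are extracted from the chosen clip, embedded with CLIP or SigLIP,
and the resulting frame-embedding sequence is classified by an LSTM network.
"""
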
import os
import shutil

from PIL import Image
import ffmpeg
import streamlit as st
import torch
from transformers import AutoProcessor, AutoModel

from src.lstm_model import LSTMNetwork
from src.frames import extract_frames, convert_to_mp4

# Mappings between class indices and cricket shot labels
idx_to_class = {0: 'cover', 1: 'defense', 2: 'flick', 3: 'hook', 4: 'late_cut',
                5: 'lofted', 6: 'pull', 7: 'square_cut', 8: 'straight', 9: 'sweep'}
class_label_mapping = {'cover': 0, 'defense': 1, 'flick': 2, 'hook': 3, 'late_cut': 4,
                       'lofted': 5, 'pull': 6, 'square_cut': 7, 'straight': 8, 'sweep': 9}

# Defining the checkpoint paths and Hugging Face model IDs
CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
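
# Load the Hugging Face processor and encoder used to embed the video frames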
def embeddings_creators(MODEL_ID):
    embedding_processor = AutoProcessor.from_pretrained(MODEL_ID)
    embedding_model = AutoModel.from_pretrained(MODEL_ID)
    embedding_model.to(device)
    return embedding_processor, embedding_model
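
# CLIP ViT-B/32 image embeddings are 512-dimensional and SigLIP base (patch16-224)
# embeddings are 768-dimensional, so the LSTM input size must match the chosen encoder.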
def load_model(MODEL_PATH):
    if MODEL_PATH == CLIP_MODEL_PATH:
        input_size = 512
    elif MODEL_PATH == SIGLIP_MODEL_PATH:
        input_size = 768
    else:
        raise ValueError(f"Invalid model path: {MODEL_PATH}")
    model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
    # map_location keeps CPU-only machines working with checkpoints saved on GPU
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model

# Select GPU if available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def app():
    st.image("assets/banner.png")
    st.title("Cricket Shot Classifier", anchor=False)

    model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])
    if model_choice == "None":
        st.write("Please select a model")
        st.stop()
if model_choice == "CLIP": | |
embedding_processor, embedding_model = embeddings_creators(CLIP_MODEL_ID) | |
model = load_model(CLIP_MODEL_PATH) | |
elif model_choice == "SIGLIP": | |
embedding_processor, embedding_model = embeddings_creators(SIGLIP_MODEL_ID) | |
model = load_model(SIGLIP_MODEL_PATH) | |

    # List sample videos from the assets folder
    sample_videos = [f for f in os.listdir("assets") if f.endswith('.avi')]
    if not sample_videos:
        st.error("No sample videos found in assets folder")
        st.stop()
    selected_video = st.selectbox("Select a sample video", sample_videos)
    video_path = os.path.join("assets", selected_video)

    # Copy the selected video into a scratch directory for processing
    save_directory = './demo'
    os.makedirs(save_directory, exist_ok=True)
    new_video_path = f"{save_directory}/{selected_video}"
    shutil.copy2(video_path, new_video_path)

    # Convert non-MP4 inputs so the browser video player can render them
    final_video_path = f"{save_directory}/{os.path.splitext(os.path.basename(new_video_path))[0]}.mp4"
    if not new_video_path.lower().endswith('.mp4'):
        convert_to_mp4(new_video_path, final_video_path)
    else:
        final_video_path = new_video_path
    st.video(final_video_path)

    frames_dir = f"{save_directory}/frames"
    os.makedirs(frames_dir, exist_ok=True)
    extract_frames(final_video_path, frames_dir)
    st.write("Frames extracted from the video.")

    # Sort so frames are fed to the LSTM in temporal order (assumes extract_frames
    # writes sequentially numbered, zero-padded filenames)
    inference_paths = sorted(os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
                             if f.endswith(('.jpg', '.jpeg', '.png')))
    inference_images = [Image.open(path).convert("RGB") for path in inference_paths]
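
    # Batch-preprocess every frame at once; text inputs are not needed for image features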
    tokens = embedding_processor(
        text=None,
        images=inference_images,
        return_tensors="pt"
    ).to(device)
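
    # Embed the frames and classify the whole sequence without tracking gradients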
    with torch.no_grad():
        inference_embeddings = embedding_model.get_image_features(**tokens)
        output = model(inference_embeddings.unsqueeze(0))
    prob = output.softmax(dim=1)

    _, indices = torch.sort(prob[0], descending=True)
    for idx in indices:
        i = idx.item()
        st.write(f"Prediction: {idx_to_class[i]}")
        st.progress(int(prob[0][i].item() * 100))
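
    # Remove the temporary frames and copied/converted videos for this run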
    try:
        shutil.rmtree(frames_dir)
        os.remove(new_video_path)
        os.remove(final_video_path)
        print(f"Folder '{frames_dir}' and its contents have been deleted.")
    except Exception as e:
        print(f"Error while deleting folder '{frames_dir}': {e}")

if __name__ == "__main__":
    app()