# cricketshot / app.py
# rokmr: "Adding app files" (commit 077d8c0)
import os
import re
import shutil

import ffmpeg
import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

from src.lstm_model import LSTMNetwork
from src.frames import extract_frames, convert_to_mp4
# Shot class names, listed in the order of the classifier's output indices.
_SHOT_CLASSES = (
    'cover',
    'defense',
    'flick',
    'hook',
    'late_cut',
    'lofted',
    'pull',
    'square_cut',
    'straight',
    'sweep',
)

# Output index -> class name.
idx_to_class = dict(enumerate(_SHOT_CLASSES))

# Class name -> output index (the inverse of ``idx_to_class``).
class_label_mapping = {name: index for index, name in idx_to_class.items()}
# Defining the paths
# Checkpoint files holding the trained LSTM classifier weights, one per
# embedding backbone (passed to torch.load in load_model).
CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"
# Pretrained model identifiers handed to AutoProcessor / AutoModel
# .from_pretrained in embeddings_creators.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
def embeddings_creators(MODEL_ID):
    """Load the pretrained processor and embedding model for ``MODEL_ID``.

    Args:
        MODEL_ID: Identifier passed to ``AutoProcessor.from_pretrained`` and
            ``AutoModel.from_pretrained``.

    Returns:
        A ``(processor, model)`` tuple; the model has been moved to the
        module-level ``device``.
    """
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    backbone = AutoModel.from_pretrained(MODEL_ID)
    # Module.to() moves the parameters in place, so the result need not be rebound.
    backbone.to(device)
    return processor, backbone
def load_model(MODEL_PATH):
    """Build the LSTM classifier and load its trained weights for inference.

    Args:
        MODEL_PATH: Checkpoint path; must be ``CLIP_MODEL_PATH`` or
            ``SIGLIP_MODEL_PATH``, because the path selects the embedding
            width the network is built with (512 or 768 respectively).

    Returns:
        The ``LSTMNetwork`` on the module-level ``device``, with weights
        loaded and switched to evaluation mode.

    Raises:
        ValueError: If ``MODEL_PATH`` is not one of the two known checkpoints.
    """
    if MODEL_PATH == CLIP_MODEL_PATH:
        input_size = 512
    elif MODEL_PATH == SIGLIP_MODEL_PATH:
        input_size = 768
    else:
        raise ValueError(f"Invalid model path: {MODEL_PATH}")

    model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
    # map_location remaps the stored tensors onto the active device; without it
    # a checkpoint saved from a GPU fails to load on a CPU-only host.
    state_dict = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(state_dict)
    # The model is only used for prediction here, so disable training-time
    # behaviour (e.g. dropout) to keep the outputs deterministic.
    model.eval()
    return model
# Compute device used for both the embedding model and the LSTM classifier:
# the GPU when CUDA is available, otherwise the CPU. The functions in this
# module read this global at call time.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def _frame_sort_key(filename):
    """Return a natural-order sort key, so that e.g. ``frame_2`` precedes ``frame_10``."""
    # Splitting on a captured digit run always yields text, number, text, ...
    # so keys from different file names compare position by position.
    return [int(part) if part.isdigit() else part.lower()
            for part in re.split(r"(\d+)", filename)]


def _remove_demo_files(frames_dir, video_paths):
    """Best-effort removal of the extracted frames and the temporary video files."""
    try:
        shutil.rmtree(frames_dir)
        print(f"Folder '{frames_dir}' and its contents have been deleted.")
    except FileNotFoundError:
        pass  # Nothing was extracted (an earlier step failed or stopped).
    except Exception as e:
        print(f"Error while deleting folder '{frames_dir}': {e}")

    # A set avoids removing the same file twice when no conversion took place
    # and both paths refer to one file.
    for path in set(video_paths):
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Error while deleting file '{path}': {e}")


def app():
    """Render the Streamlit UI and classify the selected sample video.

    Flow: pick an embedding model (CLIP or SIGLIP), pick an ``.avi`` sample
    from ``assets``, copy it to ``./demo`` and convert it to mp4, extract its
    frames, embed every frame, run the LSTM classifier on the frame sequence
    and display all classes ranked by probability. Temporary files under
    ``./demo`` are removed afterwards, including when a step fails.
    """
    st.image("assets/banner.png")
    st.title("Cricket Shot Classifier", anchor=False)

    model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])
    if model_choice == "None":
        # Write the prompt *before* stopping: st.stop() halts the script run,
        # so anything placed after it is never rendered.
        st.write("Please select a model")
        st.stop()

    if model_choice == "CLIP":
        embedding_processor, embedding_model = embeddings_creators(CLIP_MODEL_ID)
        model = load_model(CLIP_MODEL_PATH)
    elif model_choice == "SIGLIP":
        embedding_processor, embedding_model = embeddings_creators(SIGLIP_MODEL_ID)
        model = load_model(SIGLIP_MODEL_PATH)

    # List sample videos from assets folder (sorted: os.listdir order is arbitrary).
    sample_videos = sorted(f for f in os.listdir("assets") if f.endswith('.avi'))
    if not sample_videos:
        st.error("No sample videos found in assets folder")
        st.stop()

    selected_video = st.selectbox("Select a sample video", sample_videos)
    video_path = os.path.join("assets", selected_video)

    save_directory = './demo'
    os.makedirs(save_directory, exist_ok=True)
    new_video_path = f"{save_directory}/{selected_video}"
    frames_dir = f"{save_directory}/frames"

    needs_conversion = not new_video_path.lower().endswith('.mp4')
    if needs_conversion:
        final_video_path = f"{save_directory}/{os.path.splitext(selected_video)[0]}.mp4"
    else:
        final_video_path = new_video_path

    try:
        shutil.copy2(video_path, new_video_path)
        if needs_conversion:
            convert_to_mp4(new_video_path, final_video_path)
        st.video(final_video_path)

        os.makedirs(frames_dir, exist_ok=True)
        extract_frames(final_video_path, frames_dir)
        st.write("Frames extracted from the video.")

        # The classifier is an LSTM, so the frames must be fed in temporal
        # order; os.listdir gives no ordering guarantee.
        # NOTE(review): assumes extract_frames encodes the frame order in the
        # file names - confirm against src/frames.py.
        frame_names = sorted(
            (f for f in os.listdir(frames_dir) if f.endswith(('.jpg', '.jpeg', '.png'))),
            key=_frame_sort_key,
        )
        if not frame_names:
            st.error("No frames could be extracted from the video")
            st.stop()

        inference_images = []
        for name in frame_names:
            # convert() returns a fully loaded copy, so the file handle can be
            # closed right away (open handles can block the cleanup below).
            with Image.open(os.path.join(frames_dir, name)) as frame:
                inference_images.append(frame.convert("RGB"))

        # Pure inference: keep the embedding pass inside no_grad as well so no
        # autograd graph is built for the frame batch.
        with torch.no_grad():
            tokens = embedding_processor(
                text=None,
                images=inference_images,
                return_tensors="pt"
            ).to(device)
            inference_embeddings = embedding_model.get_image_features(**tokens)
            # unsqueeze(0) adds a leading batch dimension of size one.
            output = model(inference_embeddings.unsqueeze(0))
            prob = output.softmax(dim=1)

        # Show every class, most probable first.
        _, indices = torch.sort(prob[0], descending=True)
        for idx in indices:
            i = idx.item()
            st.write(f"Prediction: {idx_to_class[i]}")
            st.progress(int(prob[0][i].item() * 100))
    finally:
        # Runs on success, on errors and on st.stop(), so temporary files do
        # not accumulate under ./demo.
        _remove_demo_files(frames_dir, (new_video_path, final_video_path))
# Build the page only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    app()