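"""Streamlit demo: classify cricket shots from a short video clip.

A CLIP or SigLIP backbone embeds the extracted video frames, and the
sequence of frame embeddings is fed to a pretrained LSTM classifier
over ten shot classes.
"""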
import os
import shutil

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

from src.lstm_model import LSTMNetwork
from src.frames import extract_frames, convert_to_mp4

# Mappings between class indices and shot names (and the inverse)
idx_to_class = {0: 'cover', 1: 'defense', 2: 'flick', 3: 'hook', 4: 'late_cut', 
                5: 'lofted', 6: 'pull', 7: 'square_cut', 8: 'straight', 9: 'sweep'}

class_label_mapping = {'cover': 0, 'defense': 1, 'flick': 2, 'hook': 3, 'late_cut': 4,
                      'lofted': 5, 'pull': 6, 'square_cut': 7, 'straight': 8, 'sweep': 9}

# Paths to the fine-tuned classifier checkpoints and their backbone model IDs
CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"

CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"

def embeddings_creators(MODEL_ID):
    # Load the pretrained vision-language backbone used to embed frames
    embedding_processor = AutoProcessor.from_pretrained(MODEL_ID)
    embedding_model = AutoModel.from_pretrained(MODEL_ID)
    embedding_model.to(device)
    embedding_model.eval()  # inference only
    return embedding_processor, embedding_model

def load_model(MODEL_PATH):
    # CLIP ViT-B/32 yields 512-d image embeddings; SigLIP base yields 768-d
    if MODEL_PATH == CLIP_MODEL_PATH:
        input_size = 512
    elif MODEL_PATH == SIGLIP_MODEL_PATH:
        input_size = 768
    else:
        raise ValueError(f"Invalid model path: {MODEL_PATH}")
    model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    model.eval()
    return model
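# NOTE: this assumes each checkpoint stores a plain state_dict for the
# LSTMNetwork defined in src/lstm_model.py; adapt load_model if the
# training script saved the full model object instead.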

# Run on GPU when available, otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
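# Optional (untested sketch): wrap the heavy loaders in Streamlit's resource
# cache so the backbone and classifier are not reloaded on every rerun:
#   embeddings_creators = st.cache_resource(embeddings_creators)
#   load_model = st.cache_resource(load_model)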

def app():
    st.image("assets/banner.png")
    st.title("Cricket Shot Classifier", anchor=False)

    model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])

    if model_choice == "None":
        st.write("Please select a model")
        st.stop()

    if model_choice == "CLIP":
        embedding_processor, embedding_model = embeddings_creators(CLIP_MODEL_ID)
        model = load_model(CLIP_MODEL_PATH)

    elif model_choice == "SIGLIP":
        embedding_processor, embedding_model = embeddings_creators(SIGLIP_MODEL_ID)
        model = load_model(SIGLIP_MODEL_PATH)
    
    # List sample videos from the assets folder (sorted for a stable order)
    sample_videos = sorted(f for f in os.listdir("assets") if f.endswith('.avi'))
    if not sample_videos:
        st.error("No sample videos found in assets folder")
        st.stop()
    
    selected_video = st.selectbox("Select a sample video", sample_videos)
    video_path = os.path.join("assets", selected_video)
    
    save_directory = './demo'
    os.makedirs(save_directory, exist_ok=True)
    new_video_path = f"{save_directory}/{selected_video}"
    shutil.copy2(video_path, new_video_path)

    # Streamlit's player needs MP4, so convert non-MP4 inputs (e.g. .avi)
    final_video_path = f"{save_directory}/{os.path.splitext(os.path.basename(new_video_path))[0]}.mp4"
    if not new_video_path.lower().endswith('.mp4'):
        convert_to_mp4(new_video_path, final_video_path)
    else:
        final_video_path = new_video_path

    st.video(final_video_path)

    frames_dir = f"{save_directory}/frames"
    os.makedirs(frames_dir, exist_ok=True)
    extract_frames(final_video_path, frames_dir)
    st.write("Frames extracted from the video.")

    # Sort so the frame sequence fed to the LSTM is in temporal order
    inference_paths = sorted(
        os.path.join(frames_dir, f)
        for f in os.listdir(frames_dir)
        if f.endswith(('.jpg', '.jpeg', '.png'))
    )
    inference_images = [Image.open(path).convert("RGB") for path in inference_paths]
    tokens = embedding_processor(
        text=None,
        images=inference_images,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        inference_embeddings = embedding_model.get_image_features(**tokens)

    with torch.no_grad():
        # Add a batch dimension: (1, num_frames, embedding_dim)
        output = model(inference_embeddings.unsqueeze(0))
        prob = output.softmax(dim=1)

        # Show every class, highest probability first
        _, indices = torch.sort(prob[0], descending=True)
        for idx in indices:
            i = idx.item()
            st.write(f"Prediction: {idx_to_class[i]}")
            st.progress(int(prob[0][i].item() * 100))

    # Clean up the temporary frames and video copies
    try:
        shutil.rmtree(frames_dir)
        os.remove(new_video_path)
        # When the source was already MP4, both paths point at the same file
        if final_video_path != new_video_path:
            os.remove(final_video_path)
        print(f"Folder '{frames_dir}' and its contents have been deleted.")
    except Exception as e:
        print(f"Error while deleting folder '{frames_dir}': {e}")


if __name__ == "__main__":
    app()
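# To run the demo locally (assuming this file is saved as app.py):
#   streamlit run app.py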