rokmr commited on
Commit
077d8c0
·
1 Parent(s): c57fe1c

Adding app files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.avi filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ *.DS_Store
177
+ data/
178
+ *.gif
179
+ # *.avi
180
+ *.mp4
181
+ *.mp3
182
+ *.wav
183
+ # *.avi
184
+ *.mp4
185
+ /cricketshot-predictor
186
+ /cricketshot
187
+
188
+ *.DS_Store
189
+ logs/
190
+ *.log
README.md CHANGED
@@ -12,3 +12,10 @@ short_description: Classify the cricket shot
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ # Installtion
17
+ ```bash
18
+ conda create -n hf-cricshot python=3.11
19
+ conda activate hf-cricshot
20
+ pip install -r requirements.txt
21
+ ```
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from PIL import Image
4
+ import ffmpeg
5
+ import streamlit as st
6
+ import torch
7
+ from transformers import AutoProcessor, AutoModel
8
+ from src.lstm_model import LSTMNetwork
9
+ from src.frames import extract_frames, convert_to_mp4
10
+
11
+ # Required dictionary
12
+ idx_to_class = {0: 'cover', 1: 'defense', 2: 'flick', 3: 'hook', 4: 'late_cut',
13
+ 5: 'lofted', 6: 'pull', 7: 'square_cut', 8: 'straight', 9: 'sweep'}
14
+
15
+ class_label_mapping = {'cover': 0, 'defense': 1, 'flick': 2, 'hook': 3, 'late_cut': 4,
16
+ 'lofted': 5, 'pull': 6, 'square_cut': 7, 'straight': 8, 'sweep': 9}
17
+
18
+ # Definig the paths
19
+ CLIP_MODEL_PATH = "clip-cricket-classifier.pt"
20
+ SIGLIP_MODEL_PATH = "siglip-cricket-classifier.pt"
21
+
22
+ CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
23
+ SIGLIP_MODEL_ID = "google/siglip-base-patch16-224"
24
+
25
+ def embeddings_creators(MODEL_ID):
26
+ embedding_processor = AutoProcessor.from_pretrained(MODEL_ID)
27
+ embedding_model = AutoModel.from_pretrained(MODEL_ID)
28
+ embedding_model.to(device)
29
+ return embedding_processor, embedding_model
30
+
31
+ def load_model(MODEL_PATH):
32
+ if MODEL_PATH == CLIP_MODEL_PATH:
33
+ input_size = 512
34
+ elif MODEL_PATH == SIGLIP_MODEL_PATH:
35
+ input_size = 768
36
+ else:
37
+ raise ValueError(f"Invalid model path: {MODEL_PATH}")
38
+ model = LSTMNetwork(input_size=input_size, hidden_size=256, num_classes=10).to(device)
39
+ model.load_state_dict(torch.load(MODEL_PATH))
40
+ return model
41
+
42
+ # device
43
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
44
+
45
+ def app():
46
+ st.image("assets/banner.png")
47
+ st.title("Cricket Shot Classifier", anchor=False)
48
+
49
+ model_choice = st.radio("Select a model", ["None", "CLIP", "SIGLIP"])
50
+
51
+ if model_choice == "None":
52
+ st.stop()
53
+ st.write("Please select a model")
54
+
55
+ if model_choice == "CLIP":
56
+ embedding_processor, embedding_model = embeddings_creators(CLIP_MODEL_ID)
57
+ model = load_model(CLIP_MODEL_PATH)
58
+
59
+ elif model_choice == "SIGLIP":
60
+ embedding_processor, embedding_model = embeddings_creators(SIGLIP_MODEL_ID)
61
+ model = load_model(SIGLIP_MODEL_PATH)
62
+
63
+ # List sample videos from assets folder
64
+ sample_videos = [f for f in os.listdir("assets") if f.endswith(('.avi'))]
65
+ if not sample_videos:
66
+ st.error("No sample videos found in assets folder")
67
+ st.stop()
68
+
69
+ selected_video = st.selectbox("Select a sample video", sample_videos)
70
+ video_path = os.path.join("assets", selected_video)
71
+
72
+ save_directory = './demo'
73
+ os.makedirs(save_directory, exist_ok=True)
74
+ new_video_path = f"{save_directory}/{selected_video}"
75
+ shutil.copy2(video_path, new_video_path)
76
+
77
+
78
+ final_video_path = f"{save_directory}/{os.path.splitext(os.path.basename(new_video_path))[0]}.mp4"
79
+
80
+ if not new_video_path.lower().endswith('.mp4'):
81
+ convert_to_mp4(new_video_path, final_video_path)
82
+ else:
83
+ final_video_path = new_video_path
84
+
85
+ st.video(final_video_path)
86
+
87
+ frames_dir = f"{save_directory}/frames"
88
+ os.makedirs(frames_dir, exist_ok=True)
89
+ extract_frames(final_video_path, frames_dir)
90
+ st.write("Frames extracted from the video.")
91
+
92
+ inference_paths = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]
93
+ inference_images = [Image.open(path).convert("RGB") for path in inference_paths]
94
+ tokens = embedding_processor(
95
+ text=None,
96
+ images=inference_images,
97
+ return_tensors="pt"
98
+ ).to(device)
99
+ inference_embeddings = embedding_model.get_image_features(**tokens)
100
+
101
+ with torch.no_grad():
102
+ output = model(inference_embeddings.unsqueeze(0))
103
+ prob = output.softmax(dim=1)
104
+
105
+ _, indices = torch.sort(prob[0], descending=True)
106
+
107
+ for idx in indices:
108
+ i = idx.item()
109
+ st.write(f"Prediction: {idx_to_class[i]}")
110
+ st.progress(int(prob[0][i].item() * 100))
111
+
112
+ try:
113
+ shutil.rmtree(frames_dir)
114
+ os.remove(new_video_path)
115
+ os.remove(final_video_path)
116
+ print(f"Folder '{frames_dir}' and its contents have been deleted.")
117
+ except Exception as e:
118
+ print(f"Error while deleting folder '{frames_dir}': {e}")
119
+
120
+
121
+ if __name__ == "__main__":
122
+ app()
123
+
assets/banner.png ADDED

Git LFS Details

  • SHA256: 7e31dcf4c8b06b507a19709755d58291e677189399252032b44e63b6f0cad428
  • Pointer size: 132 Bytes
  • Size of remote file: 3.97 MB
assets/cover_0006.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cda892e8d464a6f78eda93331d65d2f92a8b8de185a404730586d5db5e8de55a
3
+ size 1160320
assets/defense_0007.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1b386f7c57c1031a7b253be4d810ca05764461d8dbb58a7be4ae336f351bfd0
3
+ size 2545010
assets/flick_0008.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b19dc6d3dcef0b153e0b72b7b89f09dae52eada5d87c42ecbfb85e5d83d27a17
3
+ size 1303944
assets/hook_0009.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa88e98e553ecd73b7c6a8d260f0347aceb524ecb591baffc6f61072a2b379bb
3
+ size 1038106
assets/late_cut_0010.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:477e6b970bb1ed9b279a4e10524e77ba5945d1c372aec2aa6d2c077863163196
3
+ size 1222750
assets/lofted_0011.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:665a02766349596622e916616492f6a3abef98808b84bd6540fbc4e6e53b5d7e
3
+ size 3914000
assets/pull_0010.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7f8e5d088e6d863542d057358ec5a8aecda8670380158ece462fc003fb785b7
3
+ size 1239450
assets/square_cut_0011.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5dab8de01051da67ce94d81f5cfedfc1c89dab9cb190af9aa585f9414570d6d9
3
+ size 1535956
assets/straight_0012.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76592c312ced400c1a123e7931d644348f52ec05be63b983929df652c366326d
3
+ size 2034822
assets/sweep_0013.avi ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82fa15ab99021da4e24ccaf26672c79ff66c5990ddc62d44adca6ed939a5a7b8
3
+ size 2164312
clip-cricket-classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eb97f8950a29bd41e8af501cc53a503b50bced86314701fa505554c33d1f66c8
3
+ size 3167016
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ datasets
3
+ evaluate
4
+ imageio
5
+ huggingface-hub
6
+ git+https://github.com/facebookresearch/pytorchvideo
7
+ accelerate>=0.26.0
8
+ scikit-learn
9
+ python-dotenv
10
+ sentencepiece
11
+ protobuf
12
+ torch
13
+ torchvision
14
+ ffmpeg-python
siglip-cricket-classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42d8f0ff10ce0a7a40d860a8d5c3a96f46a07f6bd6fb4f832ede61660d919a97
3
+ size 4215612
src/frames.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import ffmpeg
4
+
5
+ def extract_frames(video_path: str, frames_dir: str) -> None:
6
+ """Extracts frames from a video file using FFmpeg."""
7
+ output_pattern = os.path.join(frames_dir, "video_frame_%04d.jpg")
8
+ ffmpeg.input(video_path).output(output_pattern, vf='fps=5', loglevel='quiet').run()
9
+
10
+ def convert_to_mp4(input_path: str, output_path: str) -> None:
11
+ """Converts a video file to MP4 using FFmpeg."""
12
+ ffmpeg.input(input_path).output(output_path).run()
13
+
14
+
src/lstm_model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class LSTMNetwork(nn.Module):
4
+ def __init__(self, input_size=768, hidden_size=256, num_classes=4):
5
+ super(LSTMNetwork, self).__init__()
6
+ self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1, batch_first=True)
7
+ self.fc = nn.Linear(hidden_size, num_classes)
8
+
9
+ def forward(self, x):
10
+ x, _ = self.lstm(x)
11
+ x = self.fc(x[:, -1, :])
12
+ return x