lucas-ventura committed on
Commit
ba5b89f
·
verified ·
1 Parent(s): 69c046b

Upload models.py

Browse files
Files changed (1) hide show
  1. models.py +70 -0
models.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
# Hub repository that download_model passes to hf_hub_download as `repo_id`.
REPO_ID = "lucas-ventura/chapter-llama"

# Dictionary mapping short identifiers to full model paths.
# download_model looks its argument up here and, when the key is absent,
# uses the argument itself as the path.
MODEL_PATHS = {
    "asr-1k": "outputs/chapterize/Meta-Llama-3.1-8B-Instruct/asr/default/sml1k_train/default/model_checkpoints/",
    "asr-10k": "outputs/chapterize/Meta-Llama-3.1-8B-Instruct/asr/default/s10k-2_train/default/model_checkpoints/",
    "captions_asr-1k": "outputs/chapterize/Meta-Llama-3.1-8B-Instruct/captions_asr/asr_s10k-2_train_preds+no-asr-10s/sml1k_train/default/model_checkpoints/",
    "captions_asr-10k": "outputs/chapterize/Meta-Llama-3.1-8B-Instruct/captions_asr/asr_s10k-2_train_preds+no-asr-10s/sml10k_train/default/model_checkpoints/",
}

# File names that download_model appends to the model path and fetches, in
# this order.
FILES = ["adapter_model.safetensors", "adapter_config.json"]
16
+
17
+
18
def download_model(model_id_or_path, overwrite=False, local_dir=None):
    """Fetch every file listed in ``FILES`` for one model from the Hub.

    Args:
        model_id_or_path: A key of ``MODEL_PATHS``, or a directory path inside
            the ``REPO_ID`` repository. Values that are not a known key are
            used as the path unchanged.
        overwrite: Forwarded to ``hf_hub_download`` as ``force_download``.
        local_dir: Forwarded to ``hf_hub_download`` as ``local_dir``.

    Returns:
        The parent directory (as ``str``) of the path returned for the last
        file in ``FILES``, or ``None`` if a download raised or ``FILES`` is
        empty. Errors are printed, not propagated.
    """
    # Resolve a short alias to its repository path; fall back to the input.
    model_path = MODEL_PATHS.get(model_id_or_path, model_id_or_path)

    cache_path = None
    for file in FILES:
        # The Hub addresses files with forward slashes. str(Path) would use
        # the OS separator (backslashes on Windows), so build the remote name
        # with as_posix() instead.
        remote_name = (Path(model_path) / file).as_posix()

        # Keep the try body to the one call that is expected to fail.
        try:
            cache_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=remote_name,
                force_download=overwrite,
                local_dir=local_dir,
            )
        except Exception as e:
            # Best-effort contract: report the failure and signal it with None.
            print(f"Error downloading {file}: {e}")
            return None

        # NOTE(review): the wording is chosen from the `overwrite` flag alone,
        # regardless of whether the file was actually cached beforehand.
        if not overwrite:
            print(f"File {file} found in cache at: {cache_path}")
        else:
            print(f"File {file} downloaded to: {cache_path}")

    if cache_path is None:
        # Only reachable when FILES is empty: there is no directory to report.
        print("Error: no files configured for download")
        return None

    print("All files loaded successfully")
    return str(Path(cache_path).parent)
43
+
44
+
45
if __name__ == "__main__":
    # Command-line entry point: download one model and print its directory.
    import argparse

    parser = argparse.ArgumentParser(
        description="Download models from Hugging Face Hub"
    )
    parser.add_argument(
        "model_id", type=str, help="ID or full path of the model to download"
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Force re-download even if the model exists in cache",
    )
    parser.add_argument(
        "--local_dir",
        type=str,
        default=None,
        help="Download to local directory instead of cache",
    )
    args = parser.parse_args()

    model_dir = download_model(
        args.model_id, overwrite=args.overwrite, local_dir=args.local_dir
    )
    if model_dir is None:
        # download_model has already printed the error; exit with a non-zero
        # status so callers and shell scripts can detect the failure instead
        # of seeing "Model directory: None" with exit code 0.
        raise SystemExit(1)
    print(f"Model directory: {model_dir}")