lucas-ventura commited on
Commit
dd653bd
·
verified ·
1 Parent(s): 0ca274b

Update tools/download/models.py

Browse files
Files changed (1) hide show
  1. tools/download/models.py +36 -2
tools/download/models.py CHANGED
@@ -1,6 +1,6 @@
1
  from pathlib import Path
2
 
3
- from huggingface_hub import hf_hub_download
4
 
5
  REPO_ID = "lucas-ventura/chapter-llama"
6
 
@@ -42,6 +42,40 @@ def download_model(model_id_or_path, overwrite=False, local_dir=None):
42
  return str(Path(cache_path).parent)
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  if __name__ == "__main__":
46
  import argparse
47
 
@@ -67,4 +101,4 @@ if __name__ == "__main__":
67
  model_dir = download_model(
68
  args.model_id, overwrite=args.overwrite, local_dir=args.local_dir
69
  )
70
- print(f"Model directory: {model_dir}")
 
1
  from pathlib import Path
2
 
3
+ from huggingface_hub import hf_hub_download, snapshot_download
4
 
5
  REPO_ID = "lucas-ventura/chapter-llama"
6
 
 
42
  return str(Path(cache_path).parent)
43
 
44
 
45
+ def download_base_model(
46
+ repo_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
47
+ local_dir="./models",
48
+ use_symlinks=False,
49
+ ):
50
+ """
51
+ Downloads the base model from Hugging Face Hub.
52
+
53
+ Args:
54
+ repo_id (str): The repository ID on Hugging Face
55
+ local_dir (str): Directory to save the model to
56
+ use_symlinks (bool): Whether to use symlinks for the downloaded files
57
+
58
+ Returns:
59
+ str: Path to the downloaded model directory
60
+ """
61
+ try:
62
+ print(f"Downloading {repo_id} to {local_dir}...")
63
+ model_path = snapshot_download(
64
+ repo_id=repo_id, local_dir=local_dir, local_dir_use_symlinks=use_symlinks
65
+ )
66
+ print(f"Model downloaded successfully to: {model_path}")
67
+ return model_path
68
+ except Exception as e:
69
+ print(f"Error downloading model {repo_id}: {e}")
70
+ print(
71
+ f"\nYou are downloading `{repo_id}` to `{local_dir}` but failed. "
72
+ f"Please accept the agreement and obtain access at https://huggingface.co/{repo_id}. "
73
+ f"Then, use `huggingface-cli login` and your access tokens at https://huggingface.co/settings/tokens to authenticate. "
74
+ f"After that, run the code again."
75
+ )
76
+ return None
77
+
78
+
79
  if __name__ == "__main__":
80
  import argparse
81
 
 
101
  model_dir = download_model(
102
  args.model_id, overwrite=args.overwrite, local_dir=args.local_dir
103
  )
104
+ print(f"Model directory: {model_dir}")