File size: 1,878 Bytes
bcc039b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
import os
from typing import Optional

from requests.exceptions import HTTPError

# Map of short tokenizer aliases to (Hub repo_id, filename-within-repo).
# Names not present here are treated as tiktoken encoding names by main().
TOKENIZER = {
    "llama2": ("meta-llama/Llama-2-7b", "tokenizer.model"),
    "llama3": ("meta-llama/Meta-Llama-3-8B", "original/tokenizer.model"),
    "gemma": ("google/gemma-2-9b", "tokenizer.model"),
}


def main(tokenizer_name: str, path_to_save: str, api_key: Optional[str] = None):
    """Fetch a tokenizer and store it under *path_to_save*.

    Known aliases in ``TOKENIZER`` are downloaded from the Hugging Face Hub
    (optionally authenticated with *api_key*); any other name is treated as a
    tiktoken encoding name and cached locally via ``TIKTOKEN_CACHE_DIR``.
    Failures are reported via ``print`` rather than raised, except for
    unexpected HTTP errors.
    """
    known = TOKENIZER.get(tokenizer_name)

    if known is not None:
        repo_id, filename = known

        from huggingface_hub import hf_hub_download

        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=path_to_save,
                local_dir_use_symlinks=False,
                token=api_key or None,  # empty string means "no token"
            )
        except HTTPError as e:
            # 401 is an expected, user-fixable condition; anything else is a bug.
            if e.response.status_code != 401:
                raise e
            print(
                "You need to pass a valid `--hf_token=...` to download private checkpoints."
            )
        return

    from tiktoken import get_encoding

    # Point tiktoken's download cache at the requested directory unless the
    # caller already configured one.
    os.environ.setdefault("TIKTOKEN_CACHE_DIR", path_to_save)
    try:
        get_encoding(tokenizer_name)
    except ValueError:
        print(
            f"Tokenizer {tokenizer_name} not found. Please check the name and try again."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Download a tokenizer from the Hugging Face Hub or cache a tiktoken encoding."
    )
    parser.add_argument(
        "tokenizer_name",
        type=str,
        help="A key of TOKENIZER (e.g. 'llama3') or a tiktoken encoding name.",
    )
    # Fix: the original passed ``default=8`` here — an int default on a
    # ``type=str`` argument, and argparse ignores defaults on required
    # positionals anyway, so it was dead and misleading.
    parser.add_argument(
        "tokenizer_dir",
        type=str,
        help="Directory the tokenizer file is saved into.",
    )
    parser.add_argument(
        "--api_key",
        type=str,
        default="",
        help="Hugging Face access token for gated/private repos.",
    )
    args = parser.parse_args()

    main(
        tokenizer_name=args.tokenizer_name,
        path_to_save=args.tokenizer_dir,
        api_key=args.api_key,
    )