File size: 2,406 Bytes
165ee00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gzip
import os
import shutil
import zlib

# Sub-directory, relative to this file, that holds the model checkpoints.
model_path = 'final_models'


def _download_file(url, dest, timeout=60):
    """Download ``url`` to ``dest`` via a temporary ``.part`` file.

    The body is streamed to ``dest + '.part'`` and only renamed to ``dest``
    once the transfer finished. An interrupted download therefore never
    leaves a truncated archive behind that later runs would treat as valid.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (previously the error page was silently saved as the archive).
        requests.RequestException: on connection problems or timeouts.
    """
    import requests  # lazy import: only needed when a download is required

    tmp_path = dest + '.part'
    try:
        with requests.get(url, allow_redirects=True, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, dest)
    finally:
        # Remove a partial download if anything above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _gunzip(src, dest):
    """Decompress the gzip file ``src`` into ``dest``, keeping ``src``.

    This replaces the former ``os.system("gzip -dk ...")`` call. No shell is
    involved, so paths with spaces or metacharacters work. No external
    ``gzip`` binary is needed.

    On a corrupt or truncated archive, or an I/O error, a message is printed
    and no partial ``dest`` is left behind. This keeps the original
    best-effort behaviour of continuing with the next model.
    """
    tmp_path = dest + '.part'
    try:
        with gzip.open(src, 'rb') as fin, open(tmp_path, 'wb') as fout:
            shutil.copyfileobj(fin, fout, 1 << 20)
        os.replace(tmp_path, dest)
    except (OSError, EOFError, zlib.error) as err:
        print("Failed to unzip", src, "-", err)
        print("If the archive is corrupt, delete it so it is downloaded again.")
    finally:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def prepare_models():
    """Make sure the bundled model checkpoints exist next to this file.

    Each checkpoint is expected at ``<dir of this file>/<model_path>/<name>``.
    If a checkpoint is missing, two steps run:

    1. Unless ``<name>.gz`` is already present, it is downloaded from the
       PFNs4BO GitHub repository.
    2. It is then decompressed next to the archive. The ``.gz`` file is kept.

    Progress is reported with ``print``; nothing is returned.

    Raises:
        requests.HTTPError / requests.RequestException: if a needed download
            fails (network errors already propagated before this change).
    """
    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt',
                   'model_sampled_warp_simple_mlp_for_hpob_46.pt',
                   'model_hebo_morebudget_9_unused_features_3.pt',]

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                print("Downloading", os.path.abspath(compressed_weights_path))
                url = f'https://github.com/automl/PFNs4BO/raw/main/pfns4bo/final_models/{name + ".gz"}'
                os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
                _download_file(url, compressed_weights_path)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                _gunzip(compressed_weights_path, weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)


# Maps each public model alias (served lazily by the module-level
# ``__getattr__``) to its checkpoint path: <dir of this file>/<model_path>/<file>.
model_dict = {
    alias: os.path.join(os.path.dirname(__file__), model_path, checkpoint_file)
    for alias, checkpoint_file in (
        ('hebo_plus_userprior_model',
         'hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt'),
        ('hebo_plus_model',
         'model_hebo_morebudget_9_unused_features_3.pt'),
        ('bnn_model',
         'model_sampled_warp_simple_mlp_for_hpob_46.pt'),
    )
}


def __getattr__(name):
    """Resolve a model alias from ``model_dict`` to its checkpoint path.

    Python calls this module-level hook (PEP 562) only for attribute names
    the module does not otherwise define. If the checkpoint file is not on
    disk yet, ``prepare_models()`` is run first to fetch/unpack it. The path
    is returned either way.

    Raises:
        AttributeError: if ``name`` is not a known model alias.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    path = model_dict[name]
    if not os.path.exists(path):
        print("Can't find", os.path.abspath(path), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    return path