# File: PFNEngineeringConstrainedBO / pfns4bo/priors/hyperparameter_sampling.py
# (Hosting-page residue: "rosenyu's picture", "Upload 529 files", commit 165ee00 verified.)
# this is a wrapper prior that samples hyperparameters which are set to be ConfigSpace parameters
import torch
from .prior import Batch
from ConfigSpace import hyperparameters as CSH
import ConfigSpace as CS
from copy import deepcopy
def list_all_hps_in_nested(config):
    """Collect every ConfigSpace hyperparameter contained in a nested config.

    ``config`` may be a hyperparameter itself, a dict whose values are searched
    recursively (in the dict's iteration order), or any other object, which
    contributes nothing.

    Returns a flat list of the hyperparameter objects found, depth-first.
    """
    if isinstance(config, CSH.Hyperparameter):
        return [config]
    if not isinstance(config, dict):
        # Plain values (numbers, strings, ...) hold no hyperparameters.
        return []
    return [hp for value in config.values() for hp in list_all_hps_in_nested(value)]
def create_configspace_from_hierarchical(config):
    """Build a flat ``CS.ConfigurationSpace`` from a (possibly nested) config.

    Every hyperparameter that ``list_all_hps_in_nested`` finds in ``config`` is
    registered on a freshly created configuration space, in the order found.
    """
    space = CS.ConfigurationSpace()
    found_hps = list_all_hps_in_nested(config)
    for hyperparameter in found_hps:
        space.add_hyperparameter(hyperparameter)
    return space
def fill_in_configsample(config, configsample):
    """Return a copy of ``config`` with hyperparameters replaced by sampled values.

    Args:
        config: dict describing the config distribution; values may be
            ConfigSpace hyperparameters, nested dicts, or plain values.
        configsample: a ``CS.Configuration`` (or any mapping) indexed by
            hyperparameter name.

    Returns:
        A deep copy of ``config`` in which each hyperparameter entry is replaced
        by ``configsample[hp.name]`` and each nested dict is filled in
        recursively. All other entries keep their (deep-copied) value.
    """
    filled = deepcopy(config)
    for key, spec in config.items():
        if isinstance(spec, CSH.Hyperparameter):
            replacement = configsample[spec.name]
        elif isinstance(spec, dict):
            replacement = fill_in_configsample(spec, configsample)
        else:
            # Fixed value: the deep copy already holds it.
            continue
        filled[key] = replacement
    return filled
def sample_configspace_hyperparameters(hyperparameters):
    """Draw one random configuration and return it in the nested layout of ``hyperparameters``.

    A configuration space is built from all hyperparameters in the nested dict,
    a single configuration is sampled from it, and the sampled values are
    substituted back into a copy of ``hyperparameters``.
    """
    drawn = create_configspace_from_hierarchical(hyperparameters).sample_configuration()
    return fill_in_configsample(hyperparameters, drawn)
def get_batch(batch_size, *args, hyperparameters, get_batch, **kwargs):
    """Build a batch from sub-batches that each use freshly sampled hyperparameters.

    Args:
        batch_size: total batch size; it is split evenly across the
            hyperparameter samples.
        *args: forwarded unchanged to the wrapped ``get_batch``.
        hyperparameters: (possibly nested) dict that may contain ConfigSpace
            hyperparameters. The optional key
            ``'num_hyperparameter_samples_per_batch'`` (default 1) sets how many
            independent hyperparameter samples are drawn; it is capped at
            ``batch_size``, and ``-1`` means one sample per batch element.
        get_batch: the wrapped batch function, called as
            ``get_batch(sub_batch_size, *args, hyperparameters=sample, **kwargs)``.
            It must return objects exposing ``x``, ``y``, ``target_y`` and
            ``other_filled_attributes``.
        **kwargs: forwarded unchanged to the wrapped ``get_batch``.

    Returns:
        A ``Batch`` whose ``x``, ``y`` and ``target_y`` are the sub-batch tensors
        concatenated along dim 1.

    Raises:
        ValueError: if the resolved number of hyperparameter samples is not
            positive (e.g. a setting of 0, or a negative value other than -1).
        AssertionError: if ``batch_size`` is not a multiple of the number of
            samples, or a sub-batch reports filled attributes other than
            ``x``, ``y`` and ``target_y``.
    """
    requested = hyperparameters.get('num_hyperparameter_samples_per_batch', 1)
    num_models = min(requested, batch_size)
    if num_models == -1:
        num_models = batch_size
    if num_models < 1:
        # Fail early with a clear message instead of a ZeroDivisionError below
        # or an opaque torch.cat failure on an empty list of sub-batches.
        raise ValueError(
            "'num_hyperparameter_samples_per_batch' must resolve to a positive "
            f"number of samples (or be -1); got {requested!r} with batch_size={batch_size!r}"
        )
    assert batch_size % num_models == 0, 'batch_size must be a multiple of num_models'

    sub_batch_size = batch_size // num_models
    cs = create_configspace_from_hierarchical(hyperparameters)
    sub_batches = []
    for _ in range(num_models):
        cs_sample = cs.sample_configuration()
        hyperparameters_sample = fill_in_configsample(hyperparameters, cs_sample)
        sub_batches.append(
            get_batch(sub_batch_size, *args, hyperparameters=hyperparameters_sample, **kwargs)
        )

    # Only x, y and target_y are merged below, so refuse sub-batches that report
    # any other filled attribute rather than silently dropping it.
    assert all(not b.other_filled_attributes(set_of_attributes=('x', 'y', 'target_y')) for b in sub_batches)
    return Batch(
        x=torch.cat([b.x for b in sub_batches], dim=1),
        y=torch.cat([b.y for b in sub_batches], dim=1),
        target_y=torch.cat([b.target_y for b in sub_batches], dim=1),
    )