import contextlib from .. import transformer from .. import bar_distribution import torch import scipy import math from sklearn.preprocessing import power_transform, PowerTransformer def log01(x, eps=.0000001, input_between_zero_and_one=False): logx = torch.log(x + eps) if input_between_zero_and_one: return (logx - math.log(eps)) / (math.log(1 + eps) - math.log(eps)) return (logx - logx.min(0)[0]) / (logx.max(0)[0] - logx.min(0)[0]) def log01_batch(x, eps=.0000001, input_between_zero_and_one=False): x = x.repeat(1, x.shape[-1] + 1, 1) for b in range(x.shape[-1]): x[:, b, b] = log01(x[:, b, b], eps=eps, input_between_zero_and_one=input_between_zero_and_one) return x def lognormed_batch(x, eval_pos, eps=.0000001): x = x.repeat(1, x.shape[-1] + 1, 1) for b in range(x.shape[-1]): logx = torch.log(x[:, b, b]+eps) x[:, b, b] = (logx - logx[:eval_pos].mean(0))/logx[:eval_pos].std(0) return x def _rank_transform(x_train, x): assert len(x_train.shape) == len(x.shape) == 1 relative_to = torch.cat((torch.zeros_like(x_train[:1]),x_train.unique(sorted=True,), torch.ones_like(x_train[-1:])),-1) higher_comparison = (relative_to < x[...,None]).sum(-1).clamp(min=1) pos_inside_interval = (x - relative_to[higher_comparison-1])/(relative_to[higher_comparison] - relative_to[higher_comparison-1]) x_transformed = higher_comparison - 1 + pos_inside_interval return x_transformed/(len(relative_to)-1.) def rank_transform(x_train, x): assert x.shape[1] == x_train.shape[1], f"{x.shape=} and {x_train.shape=}" # make sure everything is between 0 and 1 assert (x_train >= 0.).all() and (x_train <= 1.).all(), f"{x_train=}" assert (x >= 0.).all() and (x <= 1.).all(), f"{x=}" return_x = x.clone() for feature_dim in range(x.shape[1]): return_x[:, feature_dim] = _rank_transform(x_train[:, feature_dim], x[:, feature_dim]) return return_x def general_power_transform(x_train, x_apply, eps, less_safe=False): if eps > 0: try: pt = PowerTransformer(method='box-cox') pt.fit(x_train.cpu()+eps) x_out = torch.tensor(pt.transform(x_apply.cpu()+eps), dtype=x_apply.dtype, device=x_apply.device) except ValueError as e: print(e) x_out = x_apply - x_train.mean(0) else: pt = PowerTransformer(method='yeo-johnson') if not less_safe and (x_train.std() > 1_000 or x_train.mean().abs() > 1_000): x_apply = (x_apply - x_train.mean(0)) / x_train.std(0) x_train = (x_train - x_train.mean(0)) / x_train.std(0) # print('inputs are LAARGEe, normalizing them') try: pt.fit(x_train.cpu().double()) except ValueError as e: print('caught this errrr', e) if less_safe: x_train = (x_train - x_train.mean(0)) / x_train.std(0) x_apply = (x_apply - x_train.mean(0)) / x_train.std(0) else: x_train = x_train - x_train.mean(0) x_apply = x_apply - x_train.mean(0) pt.fit(x_train.cpu().double()) x_out = torch.tensor(pt.transform(x_apply.cpu()), dtype=x_apply.dtype, device=x_apply.device) if torch.isnan(x_out).any() or torch.isinf(x_out).any(): print('WARNING: power transform failed') print(f"{x_train=} and {x_apply=}") x_out = x_apply - x_train.mean(0) return x_out #@torch.inference_mode() def general_acq_function(model: transformer.TransformerModel, x_given, y_given, x_eval, apply_power_transform=True, rand_sample=False, znormalize=False, pre_normalize=False, pre_znormalize=False, predicted_mean_fbest=False, input_znormalize=False, max_dataset_size=10_000, remove_features_with_one_value_only=False, return_actual_ei=False, acq_function='ei', ucb_rest_prob=.05, ensemble_log_dims=False, ensemble_type='mean_probs', # in ('mean_probs', 'max_acq') input_power_transform=False, power_transform_eps=.0, input_power_transform_eps=.0, input_rank_transform=False, ensemble_input_rank_transform=False, ensemble_power_transform=False, ensemble_feature_rotation=False, style=None, outlier_stretching_interval=0.0, verbose=False, unsafe_power_transform=False, ): """ Differences to HEBO: - The noise can't be set in the same way, as it depends on the tuning of HPs via VI. - Log EI and PI are always used directly instead of using the approximation. This is a stochastic function, relying on torch.randn :param model: :param x_given: torch.Tensor of shape (N, D) :param y_given: torch.Tensor of shape (N, 1) or (N,) :param x_eval: torch.Tensor of shape (M, D) :param kappa: :param eps: :return: """ assert ensemble_type in ('mean_probs', 'max_acq') if rand_sample is not False \ and (len(x_given) == 0 or ((1 + x_given.shape[1] if rand_sample is None else max(2, rand_sample)) > x_given.shape[0])): print('rando') return torch.zeros_like(x_eval[:,0]) #torch.randperm(x_eval.shape[0])[0] y_given = y_given.reshape(-1) assert len(y_given) == len(x_given) if apply_power_transform: if pre_normalize: y_normed = y_given / y_given.std() if not torch.isinf(y_normed).any() and not torch.isnan(y_normed).any(): y_given = y_normed elif pre_znormalize: y_znormed = (y_given - y_given.mean()) / y_given.std() if not torch.isinf(y_znormed).any() and not torch.isnan(y_znormed).any(): y_given = y_znormed y_given = general_power_transform(y_given.unsqueeze(1), y_given.unsqueeze(1), power_transform_eps, less_safe=unsafe_power_transform).squeeze(1) if verbose: print(f"{y_given=}") #y_given = torch.tensor(power_transform(y_given.cpu().unsqueeze(1), method='yeo-johnson', standardize=znormalize), device=y_given.device, dtype=y_given.dtype,).squeeze(1) y_given_std = torch.tensor(1., device=y_given.device, dtype=y_given.dtype) if znormalize and not apply_power_transform: if len(y_given) > 1: y_given_std = y_given.std() y_given_mean = y_given.mean() y_given = (y_given - y_given_mean) / y_given_std if remove_features_with_one_value_only: x_all = torch.cat([x_given, x_eval], dim=0) only_one_value_feature = torch.tensor([len(torch.unique(x_all[:,i])) for i in range(x_all.shape[1])]) == 1 x_given = x_given[:,~only_one_value_feature] x_eval = x_eval[:,~only_one_value_feature] if outlier_stretching_interval > 0.: tx = torch.cat([x_given, x_eval], dim=0) m = outlier_stretching_interval eps = 1e-10 small_values = (tx < m) & (tx > 0.) tx[small_values] = m * (torch.log(tx[small_values] + eps) - math.log(eps)) / (math.log(m + eps) - math.log(eps)) large_values = (tx > 1. - m) & (tx < 1.) tx[large_values] = 1. - m * (torch.log(1 - tx[large_values] + eps) - math.log(eps)) / ( math.log(m + eps) - math.log(eps)) x_given = tx[:len(x_given)] x_eval = tx[len(x_given):] if input_znormalize: # implementation that relies on the test set, too... std = x_given.std(dim=0) std[std == 0.] = 1. mean = x_given.mean(dim=0) x_given = (x_given - mean) / std x_eval = (x_eval - mean) / std if input_power_transform: x_given = general_power_transform(x_given, x_given, input_power_transform_eps) x_eval = general_power_transform(x_given, x_eval, input_power_transform_eps) if input_rank_transform is True or input_rank_transform == 'full': # uses test set x statistics... x_all = torch.cat((x_given,x_eval), dim=0) for feature_dim in range(x_all.shape[-1]): uniques = torch.sort(torch.unique(x_all[..., feature_dim])).values x_eval[...,feature_dim] = torch.searchsorted(uniques,x_eval[..., feature_dim]).float() / (len(uniques)-1) x_given[...,feature_dim] = torch.searchsorted(uniques,x_given[..., feature_dim]).float() / (len(uniques)-1) elif input_rank_transform is False: pass elif input_rank_transform == 'train': x_given = rank_transform(x_given, x_given) x_eval = rank_transform(x_given, x_eval) elif input_rank_transform.startswith('train'): likelihood = float(input_rank_transform.split('_')[-1]) if torch.rand(1).item() < likelihood: print('rank transform') x_given = rank_transform(x_given, x_given) x_eval = rank_transform(x_given, x_eval) else: raise NotImplementedError # compute logits criterion: bar_distribution.BarDistribution = model.criterion x_predict = torch.cat([x_given, x_eval], dim=0) logits_list = [] for x_feed in torch.split(x_predict, max_dataset_size, dim=0): x_full_feed = torch.cat([x_given, x_feed], dim=0).unsqueeze(1) y_full_feed = y_given.unsqueeze(1) if ensemble_log_dims == '01': x_full_feed = log01_batch(x_full_feed) elif ensemble_log_dims == 'global01' or ensemble_log_dims is True: x_full_feed = log01_batch(x_full_feed, input_between_zero_and_one=True) elif ensemble_log_dims == '01-10': x_full_feed = torch.cat((log01_batch(x_full_feed)[:, :-1], log01_batch(1. - x_full_feed)), 1) elif ensemble_log_dims == 'norm': x_full_feed = lognormed_batch(x_full_feed, len(x_given)) elif ensemble_log_dims is not False: raise NotImplementedError if ensemble_feature_rotation: x_full_feed = torch.cat([x_full_feed[:, :, (i+torch.arange(x_full_feed.shape[2])) % x_full_feed.shape[2]] for i in range(x_full_feed.shape[2])], dim=1) if ensemble_input_rank_transform == 'train' or ensemble_input_rank_transform is True: x_full_feed = torch.cat([rank_transform(x_given, x_full_feed[:,i,:])[:,None] for i in range(x_full_feed.shape[1])] + [x_full_feed], dim=1) if ensemble_power_transform: assert apply_power_transform is False y_full_feed = torch.cat((general_power_transform(y_full_feed, y_full_feed, power_transform_eps), y_full_feed), dim=1) if style is not None: if callable(style): style = style() if isinstance(style, torch.Tensor): style = style.to(x_full_feed.device) else: style = torch.tensor(style, device=x_full_feed.device).view(1, 1).repeat(x_full_feed.shape[1], 1) logits = model( (style, x_full_feed.repeat_interleave(dim=1, repeats=y_full_feed.shape[1]), y_full_feed.repeat(1,x_full_feed.shape[1])), single_eval_pos=len(x_given) ) if ensemble_type == 'mean_probs': logits = logits.softmax(-1).mean(1, keepdim=True).log_() # (num given + num eval, 1, num buckets) # print('in ensemble_type == mean_probs ') logits_list.append(logits) # (< max_dataset_size, 1 , num_buckets) logits = torch.cat(logits_list, dim=0) # (num given + num eval, 1 or (num_features+1), num buckets) del logits_list, x_full_feed if torch.isnan(logits).any(): print('nan logits') print(f"y_given: {y_given}, x_given: {x_given}, x_eval: {x_eval}") print(f"logits: {logits}") return torch.zeros_like(x_eval[:,0]) #logits = model((torch.cat([x_given, x_given, x_eval], dim=0).unsqueeze(1), # torch.cat([y_given, torch.zeros(len(x_eval)+len(x_given), device=y_given.device)], dim=0).unsqueeze(1)), # single_eval_pos=len(x_given))[:,0] # (N + M, num_buckets) logits_given = logits[:len(x_given)] logits_eval = logits[len(x_given):] #tau = criterion.mean(logits_given)[torch.argmax(y_given)] # predicted mean at the best y if predicted_mean_fbest: tau = criterion.mean(logits_given)[torch.argmax(y_given)].squeeze(0) else: tau = torch.max(y_given) #log_ei = torch.stack([criterion.ei(logits_eval[:,i], noisy_best_f[i]).log() for i in range(len(logits_eval))],0) def acq_ensembling(acq_values): # (points, ensemble dim) return acq_values.max(1).values if isinstance(acq_function, (dict,list)): acq_function = acq_function[style] if acq_function == 'ei': acq_value = acq_ensembling(criterion.ei(logits_eval, tau)) elif acq_function == 'ei_or_rand': if torch.rand(1).item() < 0.5: acq_value = torch.rand(len(x_eval)) else: acq_value = acq_ensembling(criterion.ei(logits_eval, tau)) elif acq_function == 'pi': acq_value = acq_ensembling(criterion.pi(logits_eval, tau)) elif acq_function == 'ucb': acq_function = criterion.ucb if ucb_rest_prob is not None: acq_function = lambda *args: criterion.ucb(*args, rest_prob=ucb_rest_prob) acq_value = acq_ensembling(acq_function(logits_eval, tau)) elif acq_function == 'mean': acq_value = acq_ensembling(criterion.mean(logits_eval)) elif acq_function.startswith('hebo'): noise, upsi, delta, eps = (float(v) for v in acq_function.split('_')[1:]) noise = y_given_std * math.sqrt(2 * noise) kappa = math.sqrt(upsi * 2 * ((2.0 + x_given.shape[1] / 2.0) * math.log(max(1, len(x_given))) + math.log( 3 * math.pi ** 2 / (3 * delta)))) rest_prob = 1. - .5 * (1 + torch.erf(torch.tensor(kappa / math.sqrt(2), device=logits.device))) ucb = acq_ensembling(criterion.ucb(logits_eval, None, rest_prob=rest_prob)) \ + torch.randn(len(logits_eval), device=logits_eval.device) * noise noisy_best_f = tau + eps + \ noise * torch.randn(len(logits_eval), device=logits_eval.device)[:, None].repeat(1, logits_eval.shape[1]) log_pi = acq_ensembling(criterion.pi(logits_eval, noisy_best_f).log()) # log_ei = torch.stack([criterion.ei(logits_eval[:,i], noisy_best_f[i]).log() for i in range(len(logits_eval))],0) log_ei = acq_ensembling(criterion.ei(logits_eval, noisy_best_f).log()) acq_values = torch.stack([ucb, log_ei, log_pi], dim=1) def is_pareto_efficient(costs): """ Find the pareto-efficient points :param costs: An (n_points, n_costs) array :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient """ is_efficient = torch.ones(costs.shape[0], dtype=bool, device=costs.device) for i, c in enumerate(costs): if is_efficient[i]: is_efficient[is_efficient.clone()] = (costs[is_efficient] < c).any( 1) # Keep any point with a lower cost is_efficient[i] = True # And keep self return is_efficient acq_value = is_pareto_efficient(-acq_values) else: raise ValueError(f'Unknown acquisition function: {acq_function}') max_acq = acq_value.max() return acq_value if return_actual_ei else (acq_value == max_acq) def optimize_acq(model, known_x, known_y, num_grad_steps=10, num_random_samples=100, lr=.01, **kwargs): """ intervals are assumed to be between 0 and 1 only works with ei recommended extra kwarg: ensemble_input_rank_transform=='train' :param model: model to optimize, should already handle different num_features with its encoder You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)` :param known_x: (N, num_features) :param known_y: (N,) :param num_grad_steps: int :param num_random_samples: int :param lr: float :param kwargs: will be given to `general_acq_function` :return: """ x_eval = torch.rand(num_random_samples, known_x.shape[1]).requires_grad_(True) opt = torch.optim.Adam(params=[x_eval], lr=lr) best_acq, best_x = -float('inf'), x_eval[0].detach() for grad_step in range(num_grad_steps): acq = general_acq_function(model, known_x, known_y, x_eval, return_actual_ei=True, **kwargs) max_acq = acq.detach().max().item() if max_acq > best_acq: best_x = x_eval[acq.argmax()].detach() best_acq = max_acq (-acq.mean()).backward() assert (x_eval.grad != 0.).any() if torch.isfinite(x_eval.grad).all(): opt.step() opt.zero_grad() with torch.no_grad(): x_eval.clamp_(min=0., max=1.) return best_x def optimize_acq_w_lbfgs(model, known_x, known_y, num_grad_steps=15_000, num_candidates=100, pre_sample_size=100_000, device='cpu', verbose=False, dims_wo_gradient_opt=[], rand_sample_func=None, **kwargs): """ intervals are assumed to be between 0 and 1 only works with deterministic acq recommended extra kwarg: ensemble_input_rank_transform=='train' :param model: model to optimize, should already handle different num_features with its encoder You can add this simply with `model.encoder = encoders.VariableNumFeaturesEncoder(model.encoder, model.encoder.num_features)` :param known_x: (N, num_features) :param known_y: (N,) :param num_grad_steps: int: how many steps to take inside of scipy, can be left high, as it stops most of the time automatically early :param num_candidates: int: how many candidates to optimize with LBFGS, increases costs when higher :param pre_sample_size: int: how many settings to try first with a random search, before optimizing the best with grads :param dims_wo_gradient_opt: int: which dimensions to not optimize with gradients, but with random search only :param rand_sample_func: function: how to sample random points, should be a function that takes a number of samples and returns a tensor For example `lambda n: torch.rand(n, known_x.shape[1])`. :param kwargs: will be given to `general_acq_function` :return: """ num_features = known_x.shape[1] dims_w_gradient_opt = sorted(set(range(num_features)) - set(dims_wo_gradient_opt)) known_x = known_x.to(device) known_y = known_y.to(device) pre_sample_size = max(pre_sample_size, num_candidates) rand_sample_func = rand_sample_func or (lambda n: torch.rand(n, num_features, device=device)) if len(known_x) < pre_sample_size: x_initial = torch.cat((rand_sample_func(pre_sample_size-len(known_x)).to(device), known_x), 0) else: x_initial = rand_sample_func(pre_sample_size) x_initial = x_initial.clamp(min=0., max=1.) x_initial_all = x_initial model.to(device) with torch.no_grad(): acq = general_acq_function(model, known_x, known_y, x_initial.to(device), return_actual_ei=True, **kwargs) if verbose: import matplotlib.pyplot as plt if x_initial.shape[1] == 2: plt.title('initial acq values, red -> blue') plt.scatter(x_initial[:, 0][:100], x_initial[:, 1][:100], c=acq.cpu().numpy()[:100], cmap='RdBu') x_initial = x_initial[acq.argsort(descending=True)[:num_candidates].cpu()].detach() # num_candidates x num_features x_initial_all_ei = acq.cpu().detach() def opt_f(x): x_eval = torch.tensor(x).view(-1, len(dims_w_gradient_opt)).float().to(device).requires_grad_(True) x_eval_new = x_initial.clone().detach().to(device) x_eval_new[:, dims_w_gradient_opt] = x_eval assert x_eval_new.requires_grad assert not torch.isnan(x_eval_new).any() model.requires_grad_(False) acq = general_acq_function(model, known_x, known_y, x_eval_new, return_actual_ei=True, **kwargs) neg_mean_acq = -acq.mean() neg_mean_acq.backward() #print(neg_mean_acq.detach().numpy(), x_eval.grad.detach().view(*x.shape).numpy()) with torch.no_grad(): x_eval.grad[x_eval.grad != x_eval.grad] = 0. return neg_mean_acq.detach().cpu().to(torch.float64).numpy(), \ x_eval.grad.detach().view(*x.shape).cpu().to(torch.float64).numpy() # Optimize best candidates with LBFGS if num_grad_steps > 0 and len(dims_w_gradient_opt) > 0: # the columns not in dims_wo_gradient_opt will be optimized with gradients x_initial_for_gradient_opt = x_initial[:, dims_w_gradient_opt].detach().cpu().flatten().numpy() # x_initial.cpu().flatten().numpy() res = scipy.optimize.minimize(opt_f, x_initial_for_gradient_opt, method='L-BFGS-B', jac=True, bounds=[(0, 1)]*x_initial_for_gradient_opt.size, options={'maxiter': num_grad_steps}) results = x_initial.cpu() results[:, dims_w_gradient_opt] = torch.tensor(res.x).float().view(-1, len(dims_w_gradient_opt)) else: results = x_initial.cpu() results = results.clamp(min=0., max=1.) # Recalculate the acq values for the best candidates with torch.no_grad(): acq = general_acq_function(model, known_x, known_y, results.to(device), return_actual_ei=True, verbose=verbose, **kwargs) #print(acq) if verbose: from scipy.stats import rankdata import matplotlib.pyplot as plt if results.shape[1] == 2: plt.scatter(results[:, 0], results[:, 1], c=rankdata(acq.cpu().numpy()), marker='x', cmap='RdBu') plt.show() best_x = results[acq.argmax().item()].detach() acq_order = acq.argsort(descending=True).cpu() all_order = x_initial_all_ei.argsort(descending=True).cpu() return best_x.detach(), results[acq_order].detach(), acq.cpu()[acq_order].detach(), x_initial_all.cpu()[all_order].detach(), x_initial_all_ei.cpu()[all_order].detach() from ..utils import to_tensor class TransformerBOMethod: def __init__(self, model, acq_f=general_acq_function, device='cpu:0', fit_encoder=None, **kwargs): self.model = model self.device = device self.kwargs = kwargs self.acq_function = acq_f self.fit_encoder = fit_encoder @torch.no_grad() def observe_and_suggest(self, X_obs, y_obs, X_pen, return_actual_ei=False): # assert X_pen is not None # assumptions about X_obs and X_pen: # X_obs is a numpy array of shape (n_samples, n_features) # y_obs is a numpy array of shape (n_samples,), between 0 and 1 # X_pen is a numpy array of shape (n_samples_left, n_features) X_obs = to_tensor(X_obs, device=self.device).to(torch.float32) y_obs = to_tensor(y_obs, device=self.device).to(torch.float32).view(-1) X_pen = to_tensor(X_pen, device=self.device).to(torch.float32) assert len(X_obs) == len(y_obs), "make sure both X_obs and y_obs have the same length." self.model.to(self.device) if self.fit_encoder is not None: w = self.fit_encoder(self.model, X_obs, y_obs) X_obs = w(X_obs) X_pen = w(X_pen) # with (torch.cuda.amp.autocast() if self.device.type != 'cpu' else contextlib.nullcontext()): with (torch.cuda.amp.autocast() if self.device[:3] != 'cpu' else contextlib.nullcontext()): acq_values = self.acq_function(self.model, X_obs, y_obs, X_pen, return_actual_ei=return_actual_ei, **self.kwargs).cpu().clone() # bool array acq_mask = acq_values.max() == acq_values possible_next = torch.arange(len(X_pen))[acq_mask] if len(possible_next) == 0: possible_next = torch.arange(len(X_pen)) r = possible_next[torch.randperm(len(possible_next))[0]].cpu().item() if return_actual_ei: return r, acq_values else: return r