-
Notifications
You must be signed in to change notification settings - Fork 77
Botorch preset #757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: dev/gp
Are you sure you want to change the base?
Botorch preset #757
Changes from all commits
8679dba
d76d452
15e9dbb
920dd88
887c6bc
f213893
79f3604
2f847b6
19d6b3b
0538b72
4d41104
578459b
5ae782b
105bf4f
34276c7
3c9d40a
927e044
5d0c4ff
cfaa839
e6f7dbc
412d3dc
4072325
d8a5a9b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| """Custom GPyTorch components.""" | ||
|
|
||
| import torch | ||
| from botorch.models.multitask import _compute_multitask_mean | ||
| from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL | ||
| from gpytorch.constraints import GreaterThan | ||
| from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood | ||
| from gpytorch.means import MultitaskMean | ||
| from gpytorch.means.multitask_mean import Mean | ||
| from gpytorch.priors import LogNormalPrior | ||
| from torch import Tensor | ||
| from torch.nn import Module | ||
|
|
||
|
|
||
class HadamardConstantMean(Mean):
    """A GPyTorch mean function implementing BoTorch's multitask mean logic.

    While GPyTorch already provides a :class:`~gpytorch.means.MultitaskMean` class, it
    computes mean values for all (input, task)-pairs (where input means all parameters
    except the task parameter), i.e. it intrinsically applies a Cartesian expansion.
    However, for the regular transfer learning setting, we only need the means for the
    pairs that are actually observed/requested. BoTorch subselects the relevant means
    from the GPyTorch output in `MultiTaskGP.forward`, i.e. it uses a class-based
    approach to define its special logic for the multitask case. In contrast, BayBE uses
    a composition approach, which is more flexible but requires that the logic is
    injected via a self-contained `Mean` object, which is what this class provides.

    Note:
        Analogous to GPyTorch's
        https://github.com/cornellius-gp/gpytorch/blob/main/gpytorch/likelihoods/hadamard_gaussian_likelihood.py
        but where the logic is applied to the mean function, i.e. we learn a different
        (constant) mean for each task.
    """

    def __init__(self, mean_module: Module, num_tasks: int, task_feature: int):
        """Wrap the given mean module into a per-task multitask mean.

        Args:
            mean_module: The mean module replicated for each task.
            num_tasks: The number of tasks (i.e. mean copies to learn).
            task_feature: The (possibly negative) input column holding the task index.
        """
        super().__init__()
        self.multitask_mean = MultitaskMean(mean_module, num_tasks=num_tasks)
        self.task_feature = task_feature

    def forward(self, x: Tensor) -> Tensor:
        """Compute the mean for each input row using that row's task index."""
        # Adapted from https://github.com/meta-pytorch/botorch/blob/e0f4f5b941b5949a4a1171bf8d4ee9f74f146f3a/botorch/models/multitask.py#L397

        # Convert task feature to positive index
        task_col = self.task_feature % x.shape[-1]

        # Split input into task and non-task components
        non_task_before = x[..., :task_col]
        task_column = x[..., task_col : task_col + 1]
        non_task_after = x[..., task_col + 1 :]

        return _compute_multitask_mean(
            self.multitask_mean, non_task_before, task_column, non_task_after
        )
|
|
||
|
|
||
def make_botorch_multitask_likelihood(
    num_tasks: int, task_feature: int
) -> HadamardGaussianLikelihood:
    """Adapted from :class:`botorch.models.multitask.MultiTaskGP`."""
    # Same noise prior/constraint settings that BoTorch installs on its MultiTaskGP
    prior = LogNormalPrior(loc=-4.0, scale=1.0)
    constraint = GreaterThan(
        MIN_INFERRED_NOISE_LEVEL,
        transform=None,
        initial_value=prior.mode,
    )
    return HadamardGaussianLikelihood(
        num_tasks=num_tasks,
        batch_shape=torch.Size(),
        noise_prior=prior,
        noise_constraint=constraint,
        task_feature_index=task_feature,
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| """Fitting criteria for the Gaussian process surrogate.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from enum import Enum | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood | ||
| from gpytorch.mlls import MarginalLogLikelihood | ||
| from gpytorch.models import GP as GPyTorchModel | ||
|
|
||
|
|
||
class FitCriterion(Enum):
    """Available fitting criteria for GP hyperparameter optimization."""

    MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
    """Exact marginal log-likelihood."""

    LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD"
    """Leave-one-out cross-validation pseudo-likelihood."""

    def to_gpytorch(
        self, likelihood: GPyTorchLikelihood, model: GPyTorchModel
    ) -> MarginalLogLikelihood:
        """Create the corresponding GPyTorch MLL object.

        Args:
            likelihood: The likelihood of the model to be fitted.
            model: The GP model whose hyperparameters are to be fitted.

        Returns:
            The marginal log-likelihood object to be maximized during fitting.
        """
        # Deferred import: gpytorch is only required when a criterion is materialized
        import gpytorch

        # Address both classes via the canonical `gpytorch.mlls` submodule for
        # consistency (the top-level namespace is not guaranteed to re-export
        # `ExactMarginalLogLikelihood`)
        mll_class = {
            FitCriterion.MARGINAL_LOG_LIKELIHOOD: gpytorch.mlls.ExactMarginalLogLikelihood,  # noqa: E501
            FitCriterion.LEAVE_ONE_OUT_PSEUDOLIKELIHOOD: gpytorch.mlls.LeaveOneOutPseudoLikelihood,  # noqa: E501
        }[self]
        return mll_class(likelihood, model)
|
|
||
|
|
||
# Delayed import to avoid circular dependency
# NOTE(review): presumably `generic` (transitively) imports from this module,
# which is why the import must run after `FitCriterion` is defined — confirm.
from baybe.surrogates.gaussian_process.components.generic import (  # noqa: E402
    GPComponentFactoryProtocol,
    PlainGPComponentFactory,
)

# Specializations of the generic GP component factory machinery for fit criteria
FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
"""A protocol defining the interface for fit criterion factories."""

PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
"""A trivial factory that returns a fixed fit criterion."""
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -102,7 +102,7 @@ def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None: | |||||||||||||||||||||||||||||||||||||||||
| @override | ||||||||||||||||||||||||||||||||||||||||||
| def __call__( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel: | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel | GPyTorchKernel: | ||||||||||||||||||||||||||||||||||||||||||
| """Construct the kernel, validating parameter kinds before construction.""" | ||||||||||||||||||||||||||||||||||||||||||
| if self.parameter_selector is not None: | ||||||||||||||||||||||||||||||||||||||||||
| params = [p for p in searchspace.parameters if self.parameter_selector(p)] | ||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -115,7 +115,7 @@ def __call__( | |||||||||||||||||||||||||||||||||||||||||
| @abstractmethod | ||||||||||||||||||||||||||||||||||||||||||
| def _make( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel: | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel | GPyTorchKernel: | ||||||||||||||||||||||||||||||||||||||||||
| """Construct the kernel.""" | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -171,10 +171,43 @@ def _default_task_kernel_factory(self) -> KernelFactoryProtocol: | |||||||||||||||||||||||||||||||||||||||||
| def __call__( | ||||||||||||||||||||||||||||||||||||||||||
| self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor | ||||||||||||||||||||||||||||||||||||||||||
| ) -> Kernel: | ||||||||||||||||||||||||||||||||||||||||||
| if searchspace.task_idx is None: | ||||||||||||||||||||||||||||||||||||||||||
| raise IncompatibleSearchSpaceError( | ||||||||||||||||||||||||||||||||||||||||||
| f"'{type(self).__name__}' can only be used with a searchspace that " | ||||||||||||||||||||||||||||||||||||||||||
| f"contains a '{TaskParameter.__name__}'." | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| base_kernel = self.base_kernel_factory(searchspace, train_x, train_y) | ||||||||||||||||||||||||||||||||||||||||||
| task_kernel = self.task_kernel_factory(searchspace, train_x, train_y) | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(base_kernel, Kernel): | ||||||||||||||||||||||||||||||||||||||||||
| base_kernel = base_kernel.to_gpytorch(searchspace) | ||||||||||||||||||||||||||||||||||||||||||
| if isinstance(task_kernel, Kernel): | ||||||||||||||||||||||||||||||||||||||||||
| task_kernel = task_kernel.to_gpytorch(searchspace) | ||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||
| # Ensure correct partitioning between base and task kernels active dimensions | ||||||||||||||||||||||||||||||||||||||||||
| all_idcs = set(range(len(searchspace.comp_rep_columns))) | ||||||||||||||||||||||||||||||||||||||||||
| allowed_task_idcs = {searchspace.task_idx} | ||||||||||||||||||||||||||||||||||||||||||
| allowed_base_idcs = all_idcs - allowed_task_idcs | ||||||||||||||||||||||||||||||||||||||||||
| base_idcs = ( | ||||||||||||||||||||||||||||||||||||||||||
| set(dims) | ||||||||||||||||||||||||||||||||||||||||||
| if (dims := base_kernel.active_dims.tolist()) is not None | ||||||||||||||||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||
| task_idcs = ( | ||||||||||||||||||||||||||||||||||||||||||
| set(dims) | ||||||||||||||||||||||||||||||||||||||||||
| if (dims := task_kernel.active_dims.tolist()) is not None | ||||||||||||||||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+191
to
+199
|
||||||||||||||||||||||||||||||||||||||||||
Suggested change:
```diff
-        base_idcs = (
-            set(dims)
-            if (dims := base_kernel.active_dims.tolist()) is not None
-            else None
-        )
-        task_idcs = (
-            set(dims)
-            if (dims := task_kernel.active_dims.tolist()) is not None
-            else None
-        )
+        base_active_dims = base_kernel.active_dims
+        task_active_dims = task_kernel.active_dims
+        base_idcs = (
+            all_idcs
+            if base_active_dims is None
+            else set(base_active_dims.tolist())
+        )
+        task_idcs = (
+            all_idcs
+            if task_active_dims is None
+            else set(task_active_dims.tolist())
+        )
```
Copilot
AI
Apr 17, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The subset check for the base kernel's `active_dims` is incorrect: `base_idcs > allowed_base_idcs` tests for a *strict superset*, not for "is not a subset". It therefore misses invalid configurations (e.g. `{0, task_idx}`) and may flag nothing at all. Use a proper subset validation (e.g. `not base_idcs <= allowed_base_idcs`), and consider raising a clearer error when `active_dims` is `None` (which means "all dimensions").
Suggested change:
```diff
-        if base_idcs is not None and (base_idcs > allowed_base_idcs):
-            raise ValueError(
+        if base_idcs is None:
+            raise ValueError(
+                "The base kernel's 'active_dims' must be restricted to the non-task "
+                f"indices {allowed_base_idcs}; got None, which means all dimensions."
+            )
+        if not base_idcs <= allowed_base_idcs:
+            raise ValueError(
```
Uh oh!
There was an error while loading. Please reload this page.