Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8679dba
Add configurable GP optimization criterion component
AdrianSosic May 6, 2026
d76d452
Fix circular import
AdrianSosic May 6, 2026
15e9dbb
Update CHANGELOG.md
AdrianSosic May 6, 2026
920dd88
Fix typing
AdrianSosic May 6, 2026
887c6bc
Extend preset tests
AdrianSosic May 6, 2026
f213893
Make preset factory exports constant instances with UPPER_CASE naming
AdrianSosic May 6, 2026
79f3604
Drop unnecessary isinstance check
AdrianSosic May 6, 2026
2f847b6
Rename Criterion to FitCriterion
AdrianSosic May 6, 2026
19d6b3b
Make FitCriterion a BayBEGPComponent
AdrianSosic May 6, 2026
0538b72
Rename LEAVE_ONE_OUT to LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
AdrianSosic May 6, 2026
4d41104
Rename criterion.py to fit_criterion.py
AdrianSosic May 6, 2026
578459b
Enable multitask mode for surrogate streamlit
AdrianSosic Mar 2, 2026
5ae782b
Add BOTORCH preset
AdrianSosic Mar 2, 2026
105bf4f
Extend BoTorch preset test to multitask case
AdrianSosic Mar 2, 2026
34276c7
Add custom GPyTorch components to replicate BoTorch logic
AdrianSosic Mar 2, 2026
3c9d40a
Extend BoTorch factories to multitask case
AdrianSosic Mar 2, 2026
927e044
Add kernel active dimension validation to ICMKernelFactory
AdrianSosic Mar 2, 2026
5d0c4ff
Fix KernelFactory return types
AdrianSosic Apr 17, 2026
cfaa839
Make BotorchKernelFactory support parameter selection
AdrianSosic Apr 17, 2026
e6f7dbc
Fix active dimensions validation
AdrianSosic Apr 17, 2026
412d3dc
Bypass kernel warning for presets
AdrianSosic May 8, 2026
4072325
Update CHANGELOG.md
AdrianSosic Mar 2, 2026
d8a5a9b
Rename on-task/off-task to target/source in streamlit
AdrianSosic May 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Gaussian process component factories
- Support for GPyTorch objects (kernels, means, likelihood) as Gaussian process
components, enabling full low-level customization
- Configurable fitting criterion for Gaussian process hyperparameter optimization
- Factories for all Gaussian process components
- `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `BOTORCH`, `CHEN`, `EDBO` and `EDBO_SMOOTHED` presets for `GaussianProcessSurrogate`
- `TypeSelector` and `NameSelector` classes for parameter selection in kernel factories
- `parameter_names` attribute to basic kernels for controlling the considered parameters
- `ParameterKind` flag enum for classifying parameters by their role and automatic
Expand Down
9 changes: 9 additions & 0 deletions baybe/surrogates/gaussian_process/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""Gaussian process surrogate components."""

from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
FitCriterionFactoryProtocol,
PlainFitCriterionFactory,
)
from baybe.surrogates.gaussian_process.components.kernel import (
KernelFactoryProtocol,
PlainKernelFactory,
Expand All @@ -15,6 +20,10 @@
)

__all__ = [
# Fit Criterion
"FitCriterion",
"FitCriterionFactoryProtocol",
"PlainFitCriterionFactory",
# Kernel
"KernelFactoryProtocol",
"PlainKernelFactory",
Expand Down
71 changes: 71 additions & 0 deletions baybe/surrogates/gaussian_process/components/_gpytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Custom GPyTorch components."""

import torch
from botorch.models.multitask import _compute_multitask_mean
from botorch.models.utils.gpytorch_modules import MIN_INFERRED_NOISE_LEVEL
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods.hadamard_gaussian_likelihood import HadamardGaussianLikelihood
from gpytorch.means import MultitaskMean
from gpytorch.means.multitask_mean import Mean
from gpytorch.priors import LogNormalPrior
from torch import Tensor
Comment thread
AdrianSosic marked this conversation as resolved.
from torch.nn import Module


class HadamardConstantMean(Mean):
    """A GPyTorch mean function implementing BoTorch's multitask mean logic.

    GPyTorch's :class:`~gpytorch.means.MultitaskMean` evaluates the mean for
    every (input, task) pair, i.e. it intrinsically performs a Cartesian
    expansion over tasks. For the regular transfer learning setting, only the
    pairs that are actually observed/requested are needed. BoTorch handles this
    by subselecting the relevant means inside ``MultiTaskGP.forward`` — a
    class-based approach. BayBE instead composes models from self-contained
    components, so the same subselection logic must live in a standalone
    :class:`~gpytorch.means.Mean` object, which this class provides.

    Note:
        Analogous to GPyTorch's
        https://github.com/cornellius-gp/gpytorch/blob/main/gpytorch/likelihoods/hadamard_gaussian_likelihood.py
        but applied to the mean function, i.e. a separate (constant) mean is
        learned per task.
    """

    def __init__(self, mean_module: Module, num_tasks: int, task_feature: int):
        """Wrap the given mean module into a per-task multitask mean.

        Args:
            mean_module: The base mean module replicated once per task.
            num_tasks: The number of tasks.
            task_feature: The (possibly negative) index of the task column
                in the model input.
        """
        super().__init__()
        # NOTE: attribute names are kept stable since submodule registration
        # determines the keys under which parameters appear in the state dict.
        self.multitask_mean = MultitaskMean(mean_module, num_tasks=num_tasks)
        self.task_feature = task_feature

    def forward(self, x: Tensor) -> Tensor:
        """Compute the mean for each observed (input, task) pair."""
        # Adapted from https://github.com/meta-pytorch/botorch/blob/e0f4f5b941b5949a4a1171bf8d4ee9f74f146f3a/botorch/models/multitask.py#L397

        # Normalize a potentially negative task index to a positive one
        idx = self.task_feature % x.size(-1)

        # Partition the inputs around the task column
        left = x[..., :idx]
        tasks = x[..., idx : idx + 1]
        right = x[..., idx + 1 :]

        # Delegate to BoTorch's helper, which evaluates means only for the
        # observed pairs instead of the full Cartesian expansion
        return _compute_multitask_mean(self.multitask_mean, left, tasks, right)


def make_botorch_multitask_likelihood(
    num_tasks: int, task_feature: int
) -> HadamardGaussianLikelihood:
    """Create a Hadamard likelihood mirroring BoTorch's multitask defaults.

    Adapted from :class:`botorch.models.multitask.MultiTaskGP`.

    Args:
        num_tasks: The number of tasks.
        task_feature: The index of the task column in the model input.

    Returns:
        The configured likelihood object.
    """
    # Same prior BoTorch places on the observation noise of its MultiTaskGP
    prior = LogNormalPrior(loc=-4.0, scale=1.0)

    # Bound the noise from below and initialize optimization at the prior mode
    constraint = GreaterThan(
        MIN_INFERRED_NOISE_LEVEL,
        transform=None,
        initial_value=prior.mode,
    )

    return HadamardGaussianLikelihood(
        num_tasks=num_tasks,
        batch_shape=torch.Size(),
        noise_prior=prior,
        noise_constraint=constraint,
        task_feature_index=task_feature,
    )
46 changes: 46 additions & 0 deletions baybe/surrogates/gaussian_process/components/fit_criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Fitting criteria for the Gaussian process surrogate."""

from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from gpytorch.likelihoods import Likelihood as GPyTorchLikelihood
from gpytorch.mlls import MarginalLogLikelihood
from gpytorch.models import GP as GPyTorchModel


class FitCriterion(Enum):
    """Available fitting criteria for GP hyperparameter optimization."""

    MARGINAL_LOG_LIKELIHOOD = "MARGINAL_LOG_LIKELIHOOD"
    """Exact marginal log-likelihood."""

    LEAVE_ONE_OUT_PSEUDOLIKELIHOOD = "LEAVE_ONE_OUT_PSEUDOLIKELIHOOD"
    """Leave-one-out cross-validation pseudo-likelihood."""

    def to_gpytorch(
        self, likelihood: GPyTorchLikelihood, model: GPyTorchModel
    ) -> MarginalLogLikelihood:
        """Create the corresponding GPyTorch MLL object."""
        # Deferred import: gpytorch is only needed at fitting time
        import gpytorch

        if self is FitCriterion.MARGINAL_LOG_LIKELIHOOD:
            return gpytorch.ExactMarginalLogLikelihood(likelihood, model)

        # Only other member: LEAVE_ONE_OUT_PSEUDOLIKELIHOOD
        return gpytorch.mlls.LeaveOneOutPseudoLikelihood(likelihood, model)


# Delayed import to avoid circular dependency: `generic` itself imports
# `FitCriterion` from this module, so importing it at the top would cycle.
from baybe.surrogates.gaussian_process.components.generic import (  # noqa: E402
    GPComponentFactoryProtocol,
    PlainGPComponentFactory,
)

# Type aliases specializing the generic factory machinery to fit criteria.
FitCriterionFactoryProtocol = GPComponentFactoryProtocol[FitCriterion]
"""A protocol defining the interface for fit criterion factories."""

PlainFitCriterionFactory = PlainGPComponentFactory[FitCriterion]
"""A trivial factory that returns a fixed fit criterion."""
16 changes: 13 additions & 3 deletions baybe/surrogates/gaussian_process/components/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from baybe.searchspace import SearchSpace
from baybe.serialization.core import block_serialization_hook, converter
from baybe.serialization.mixin import SerialMixin
from baybe.surrogates.gaussian_process.components.fit_criterion import FitCriterion

BayBEGPComponent: TypeAlias = Kernel
BayBEGPComponent: TypeAlias = Kernel | FitCriterion

if TYPE_CHECKING:
from gpytorch.kernels import Kernel as GPyTorchKernel
Expand Down Expand Up @@ -44,15 +45,24 @@ class GPComponentType(Enum):
LIKELIHOOD = "LIKELIHOOD"
"""Gaussian process likelihood."""

CRITERION = "CRITERION"
"""Gaussian process fitting criterion."""

def get_types(self) -> tuple[type, ...]:
"""Get the accepted BayBE and GPyTorch types for this component."""
types = []
types: list[type[GPComponent]] = []

# Add BayBE type if applicable
if self is GPComponentType.KERNEL:
from baybe.kernels.base import Kernel

types.append(Kernel)
elif self is GPComponentType.CRITERION:
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
)

types.append(FitCriterion)

# Add GPyTorch type if available
if sys.modules.get("gpytorch") is not None:
Expand Down Expand Up @@ -85,7 +95,7 @@ def _is_gpytorch_component_class(obj: Any, /) -> bool:

def _validate_component(instance: Any, attribute: Attribute, value: Any) -> None:
"""Validate that an object is a BayBE or a GPyTorch GP component."""
if isinstance(value, Kernel) or _is_gpytorch_component_class(type(value)):
if isinstance(value, BayBEGPComponent) or _is_gpytorch_component_class(type(value)):
return

raise TypeError(
Expand Down
37 changes: 35 additions & 2 deletions baybe/surrogates/gaussian_process/components/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _validate_parameter_kinds(self, parameters: Iterable[Parameter]) -> None:
@override
def __call__(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
) -> Kernel | GPyTorchKernel:
"""Construct the kernel, validating parameter kinds before construction."""
if self.parameter_selector is not None:
params = [p for p in searchspace.parameters if self.parameter_selector(p)]
Expand All @@ -115,7 +115,7 @@ def __call__(
@abstractmethod
def _make(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
) -> Kernel | GPyTorchKernel:
"""Construct the kernel."""


Expand Down Expand Up @@ -171,10 +171,43 @@ def _default_task_kernel_factory(self) -> KernelFactoryProtocol:
def __call__(
self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
) -> Kernel:
if searchspace.task_idx is None:
raise IncompatibleSearchSpaceError(
f"'{type(self).__name__}' can only be used with a searchspace that "
f"contains a '{TaskParameter.__name__}'."
)

base_kernel = self.base_kernel_factory(searchspace, train_x, train_y)
task_kernel = self.task_kernel_factory(searchspace, train_x, train_y)
if isinstance(base_kernel, Kernel):
base_kernel = base_kernel.to_gpytorch(searchspace)
if isinstance(task_kernel, Kernel):
task_kernel = task_kernel.to_gpytorch(searchspace)

# Ensure correct partitioning between base and task kernels active dimensions
all_idcs = set(range(len(searchspace.comp_rep_columns)))
allowed_task_idcs = {searchspace.task_idx}
allowed_base_idcs = all_idcs - allowed_task_idcs
base_idcs = (
set(dims)
if (dims := base_kernel.active_dims.tolist()) is not None
else None
)
task_idcs = (
set(dims)
if (dims := task_kernel.active_dims.tolist()) is not None
else None
Comment on lines +191 to +199
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base_kernel.active_dims / task_kernel.active_dims can be None in GPyTorch kernels. Calling .tolist() unconditionally will raise an AttributeError (NoneType has no attribute tolist). Consider guarding with an is not None check (or using getattr(..., None)) before converting to a set, and treat None explicitly in the validation logic.

Suggested change
base_idcs = (
set(dims)
if (dims := base_kernel.active_dims.tolist()) is not None
else None
)
task_idcs = (
set(dims)
if (dims := task_kernel.active_dims.tolist()) is not None
else None
base_active_dims = base_kernel.active_dims
task_active_dims = task_kernel.active_dims
base_idcs = (
all_idcs
if base_active_dims is None
else set(base_active_dims.tolist())
)
task_idcs = (
all_idcs
if task_active_dims is None
else set(task_active_dims.tolist())

Copilot uses AI. Check for mistakes.
)

if base_idcs is not None and (base_idcs > allowed_base_idcs):
raise ValueError(
Comment on lines +202 to +203
Copy link

Copilot AI Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The subset check for base-kernel active_dims is incorrect: base_idcs > allowed_base_idcs checks for a strict superset, not “not a subset”. This will miss invalid cases (e.g. {0, task_idx}) and potentially flag none. Use a proper subset validation (e.g. not base_idcs <= allowed_base_idcs) and consider a clearer error if active_dims is None (meaning “all dims”).

Suggested change
if base_idcs is not None and (base_idcs > allowed_base_idcs):
raise ValueError(
if base_idcs is None:
raise ValueError(
"The base kernel's 'active_dims' must be restricted to the non-task "
f"indices {allowed_base_idcs}; got None, which means all dimensions."
)
if not base_idcs <= allowed_base_idcs:
raise ValueError(

Copilot uses AI. Check for mistakes.
f"The base kernel's 'active_dims' {base_idcs} must be a subset of "
f"the non-task indices {allowed_base_idcs}."
)
if task_idcs != allowed_task_idcs:
raise ValueError(
f"The task kernel's 'active_dims' {task_idcs} does not match "
f"the task index {allowed_task_idcs}."
)

return base_kernel * task_kernel
56 changes: 38 additions & 18 deletions baybe/surrogates/gaussian_process/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from baybe.parameters.categorical import TaskParameter
from baybe.searchspace.core import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.gaussian_process.components.fit_criterion import (
FitCriterion,
FitCriterionFactoryProtocol,
)
from baybe.surrogates.gaussian_process.components.generic import (
GPComponentType,
to_component_factory,
Expand All @@ -35,6 +39,7 @@
GaussianProcessPreset,
)
from baybe.surrogates.gaussian_process.presets.baybe import (
BayBEFitCriterionFactory,
BayBEKernelFactory,
BayBELikelihoodFactory,
BayBEMeanFactory,
Expand Down Expand Up @@ -178,6 +183,21 @@ class GaussianProcessSurrogate(Surrogate):
* :class:`gpytorch.likelihoods.Likelihood`
"""

criterion_factory: FitCriterionFactoryProtocol = field(
alias="criterion_or_factory",
factory=BayBEFitCriterionFactory,
converter=partial( # type: ignore[misc]
to_component_factory, component_type=GPComponentType.CRITERION
),
validator=is_callable(),
)
"""The fitting criterion for Gaussian process hyperparameter optimization.

Accepts:
* :class:`.components.fit_criterion.FitCriterion`
* :class:`.components.fit_criterion.FitCriterionFactoryProtocol`
"""

# TODO: type should be Optional[botorch.models.SingleTaskGP] but is currently
# omitted due to: https://github.com/python-attrs/cattrs/issues/531
_model = field(init=False, default=None, eq=False)
Expand All @@ -195,6 +215,7 @@ def from_preset(
likelihood_or_factory: LikelihoodFactoryProtocol
| GPyTorchLikelihood
| None = None,
criterion_or_factory: FitCriterion | FitCriterionFactoryProtocol | None = None,
) -> Self:
"""Create a Gaussian process surrogate from one of the defined presets."""
preset = GaussianProcessPreset(preset)
Expand All @@ -204,13 +225,18 @@ def from_preset(
)
module = importlib.import_module(module_name)

kernel = kernel_or_factory or getattr(module, "PresetKernelFactory")()
mean = mean_or_factory or getattr(module, "PresetMeanFactory")()
likelihood = (
likelihood_or_factory or getattr(module, "PresetLikelihoodFactory")()
kernel = kernel_or_factory or getattr(module, "PRESET_KERNEL_FACTORY")
mean = mean_or_factory or getattr(module, "PRESET_MEAN_FACTORY")
likelihood = likelihood_or_factory or getattr(
module, "PRESET_LIKELIHOOD_FACTORY"
)
criterion = criterion_or_factory or getattr(
module, "PRESET_FIT_CRITERION_FACTORY"
)

return cls(kernel, mean, likelihood)
gp = cls(kernel, mean, likelihood, criterion)
gp._custom_kernel = False # preset are first-party features
return gp

@override
def to_botorch(self) -> GPyTorchModel:
Expand All @@ -237,7 +263,6 @@ def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
import botorch
import gpytorch
from botorch.models.transforms import Normalize, Standardize

assert self._searchspace is not None # provided by base class
Expand Down Expand Up @@ -281,6 +306,9 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
### Likelihood
likelihood = self.likelihood_factory(context.searchspace, train_x, train_y)

### Criterion
criterion = self.criterion_factory(context.searchspace, train_x, train_y)

### Model construction and fitting
self._model = botorch.models.SingleTaskGP(
train_x,
Expand All @@ -291,18 +319,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
covar_module=kernel,
likelihood=likelihood,
)

# TODO: This is still a temporary workaround to avoid overfitting seen in
# low-dimensional TL cases. More robust settings are being researched.
if context.n_task_dimensions > 0:
mll = gpytorch.mlls.LeaveOneOutPseudoLikelihood(
self._model.likelihood, self._model
)
else:
mll = gpytorch.ExactMarginalLogLikelihood(
self._model.likelihood, self._model
)

mll = criterion.to_gpytorch(self._model.likelihood, self._model)
botorch.fit.fit_gpytorch_mll(mll)

@override
Expand All @@ -311,6 +328,9 @@ def __str__(self) -> str:
to_string("Kernel factory", self.kernel_factory, single_line=True),
to_string("Mean factory", self.mean_factory, single_line=True),
to_string("Likelihood factory", self.likelihood_factory, single_line=True),
to_string(
"Fit criterion factory", self.criterion_factory, single_line=True
),
]
return to_string(super().__str__(), *fields)

Expand Down
Loading
Loading