76 changes: 66 additions & 10 deletions baybe/surrogates/gaussian_process/core.py
@@ -24,9 +24,11 @@
DefaultKernelFactory,
_default_noise_factory,
)
from baybe.surrogates.gaussian_process.prior_modules import PriorMean
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
from botorch.models import SingleTaskGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
@@ -113,11 +115,57 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
Collaborator:

Given that we already have this _model attribute, can you explain why we need to introduce yet another attribute like _prior_gp? Naively, I would suspect the first contains the latter. Or at least we should strive to avoid putting a lot of additional attributes in this class (because they will essentially be irrelevant for non-TL cases).

Collaborator Author:

I see your point and I thought the same before, but unfortunately I couldn't find a good solution to this: The problem is that there is a gap between creation via from_prior and fitting the model via fit. The instance must somehow remember it should use transfer learning and its prior to be able to create the _model. I'd be happy to change this and will give it another thought. Maybe the logic could be moved to some KernelFactory or MeanFactory. Do you have any suggestions how to get rid of this attribute?

Collaborator Author:

@AVHopp, do you maybe have an idea?

Collaborator Author:

How about introducing a new mean factory similar to the kernel factories in BayBE?

  # New default
  class ConstantMeanFactory(MeanFactory):
      def __call__(self, ...):
          return gpytorch.means.ConstantMean()

  class PriorMeanFactory(MeanFactory):
      def __init__(self, prior_gp: GPSurrogate):
          self.prior_gp = deepcopy(prior_gp)

      def __call__(self, batch_shape: torch.Size):
          return PriorMean(self.prior_gp, batch_shape=batch_shape)

Then in from_prior I'd just replace the mean factory by the new PriorMeanFactory and could remove the attribute from the surrogate class, but this would add an entirely new factory pattern to BayBE.
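[Editor's note] The sketched proposal can be fleshed out as a dependency-free illustration. All names below (MeanFactory, ConstantMeanFactory, PriorMeanFactory, the dict-based stand-in "mean modules") are hypothetical, not actual BayBE or gpytorch API; the point is only the shape of the pattern: the surrogate holds exactly one factory attribute, and from_prior swaps the default factory for one carrying a frozen copy of the prior.

```python
# Hypothetical sketch of the proposed mean-factory pattern; names and
# signatures are illustrative stand-ins, not the real BayBE/gpytorch API.
from abc import ABC, abstractmethod
from copy import deepcopy


class MeanFactory(ABC):
    """Creates a mean module once training-time info (batch shape) is known."""

    @abstractmethod
    def __call__(self, batch_shape: tuple) -> dict:
        ...


class ConstantMeanFactory(MeanFactory):
    """Default: a fresh constant mean (stand-in for gpytorch's ConstantMean)."""

    def __call__(self, batch_shape: tuple) -> dict:
        return {"kind": "constant", "batch_shape": batch_shape}


class PriorMeanFactory(MeanFactory):
    """Carries a frozen copy of a fitted prior GP for mean transfer."""

    def __init__(self, prior_gp) -> None:
        self.prior_gp = deepcopy(prior_gp)  # private, frozen copy

    def __call__(self, batch_shape: tuple) -> dict:
        return {"kind": "prior", "gp": self.prior_gp, "batch_shape": batch_shape}


# from_prior would then only do:
#     instance.mean_factory = PriorMeanFactory(prior_gp)
# and _fit would call `self.mean_factory(batch_shape)` with no TL branching.
```

This removes the need for a TL-only attribute: the factory slot is always populated (with the default in the non-TL case), which is the consistency argument raised below.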

Collaborator:

Yeah, I would prefer that a bit.

Although the main reason for the factory was that search space info is needed when creating the kernels, which is not yet available when specifying the attribute here on the surrogate. That doesn't seem to be the case here with the means, right?

So a factory is not strictly needed, but I still see two advantages why I would prefer it:

  • it wouldn't be an empty, unused attribute if no prior GP is used, as it would hold the default factory
  • it would be more consistent to have all kinds of factories rather than having a mixture of factories and other optional model-related attributes

About _model: this is supposed to hold the fitted botorch model, right? So would it make any sense to only partially initialize it with the means? If not, then forget that idea.

Collaborator:

Is this thread now still relevant, given that we agreed to a Factory approach in our meeting (iirc)?

"""The actual model."""

# Transfer learning fields
_prior_gp = field(init=False, default=None, eq=False)
Collaborator:

no type?

Collaborator:

I see, it's probably the same issue as with _model, so ideally you can paste the same comment that's there here as well.

Depending on the design, this attribute might also be removed and the comment becomes obsolete.

"""Prior GP to extract mean/covariance from for transfer learning."""
kalama-ai marked this conversation as resolved.

@staticmethod
def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
"""Create a Gaussian process surrogate from one of the defined presets."""
return make_gp_from_preset(preset)

@classmethod
def from_prior(
AVHopp marked this conversation as resolved.
cls,
prior_gp: SingleTaskGP,
Scienfitz marked this conversation as resolved.
kernel_factory: KernelFactory | None = None,
**kwargs,
) -> GaussianProcessSurrogate:
"""Create a GP surrogate with mean function transfer learning.
Collaborator:

I think this docstring needs a bit more explanation of what exactly is done and transferred. Also, the description in the Returns: part could contain more information (though that might not be needed if you add 2-3 sentences here describing what this does in more detail).


Args:
prior_gp: Fitted SingleTaskGP to use as prior
kernel_factory: Kernel factory for covariance components
**kwargs: Additional arguments for GaussianProcessSurrogate constructor

Returns:
New GaussianProcessSurrogate instance with transfer learning

Raises:
ValueError: If prior_gp is not fitted
Collaborator:

There is also a raise if the variable is not a SingleTaskGP, which is not mentioned here?

"""
from copy import deepcopy

from botorch.models import SingleTaskGP
Copilot AI, Jan 7, 2026:

The import statements are placed inside the method. Since deepcopy is already imported at the module level (line 5 in prior_modules.py) and SingleTaskGP is imported in the TYPE_CHECKING block (line 31), these local imports are redundant and should be removed in favor of the module-level imports.


# Validate prior GP is fitted
if not isinstance(prior_gp, SingleTaskGP):
raise ValueError("prior_gp must be a fitted SingleTaskGP instance")
if not hasattr(prior_gp, "train_inputs") or prior_gp.train_inputs is None:
raise ValueError("Prior GP must be fitted (have train_inputs) before use")

# Configure kernel factory (always needed since we only do mean transfer now)
if kernel_factory is None:
kernel_factory = DefaultKernelFactory()

# Create new surrogate instance
instance = cls(kernel_or_factory=kernel_factory, **kwargs)

# Configure for transfer learning
instance._prior_gp = deepcopy(prior_gp)

return instance

@override
def to_botorch(self) -> GPyTorchModel:
return self._model
@@ -152,22 +200,30 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
assert self._searchspace is not None

context = _ModelContext(self._searchspace)

numerical_idxs = context.get_numerical_indices(train_x.shape[-1])

# For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=list(numerical_idxs),
)
outcome_transform = Standardize(train_y.shape[-1])

# extract the batch shape of the training data
batch_shape = train_x.shape[:-2]

# Configure input/output transforms
if self._prior_gp is not None and hasattr(self._prior_gp, "input_transform"):
Comment thread
AVHopp marked this conversation as resolved.
# Use prior's transforms for consistency in transfer learning
input_transform = self._prior_gp.input_transform
outcome_transform = self._prior_gp.outcome_transform
Collaborator:

Since there is an explicit check for input_transform, is it always guaranteed to have an outcome_transform?

Why is the check for input_transform even needed?

else:
# For GPs, we let botorch handle scaling. See [Scaling Workaround] above.
input_transform = Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=numerical_idxs,
Copilot AI, Jan 7, 2026:

The indices parameter expects a list but numerical_idxs is a tuple. While this may work in practice, it's inconsistent with the previous implementation that used list(numerical_idxs) on line 217 in the original code. For consistency and to match the expected type, convert the tuple to a list.

Suggested change
indices=numerical_idxs,
indices=list(numerical_idxs),

)
outcome_transform = Standardize(train_y.shape[-1])
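[Editor's note] A toy calculation (hypothetical bounds, plain Python) illustrates why the branch above reuses the prior's transforms: if the new model normalized inputs with different bounds than the prior, the frozen PriorMean would be queried at the wrong coordinates in its own input space.

```python
# Toy illustration (hypothetical numbers) of why the new model must reuse the
# prior GP's input transform for mean transfer to be consistent.
def normalize(x: float, lower: float, upper: float) -> float:
    """Min-max scaling per dimension, as botorch's Normalize transform does."""
    return (x - lower) / (upper - lower)


raw_x = 5.0
prior_coord = normalize(raw_x, 0.0, 10.0)  # prior was trained on bounds [0, 10]
new_coord = normalize(raw_x, 0.0, 5.0)     # new data happens to span [0, 5]

# prior_coord == 0.5 while new_coord == 1.0: evaluating the frozen prior mean
# at new_coord would query it at the wrong location in its input space.
```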

# create GP mean
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
if self._prior_gp is not None:
mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape)
else:
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)

# define the covariance module for the numeric dimensions
base_covar_module = self.kernel_factory(
55 changes: 55 additions & 0 deletions baybe/surrogates/gaussian_process/prior_modules.py
Collaborator:

The name x_modules is a bit inconsistent compared to our other naming. Just means.py?

@@ -0,0 +1,55 @@
"""Prior modules for Gaussian process transfer learning."""

from __future__ import annotations

from copy import deepcopy
from typing import Any

import gpytorch
import torch
from botorch.models import SingleTaskGP
from torch import Tensor


class PriorMean(gpytorch.means.Mean):
Collaborator:

Some questions for understanding:

  1. What does this new class achieve that is not possible with the existing botorch constant mean class?
  2. When the incoming mean is a constant mean, this class would also effectively produce a constant mean?
  3. Afaik all our GPs have constant mean, so everything would forever be constant mean. Is this class here then necessary? Couldn't we just use the botorch constant mean class for the new TL case as well, except that the number is fixed and predetermined, i.e. somehow "set"?

Collaborator Author (kalama-ai), Jan 8, 2026:

I think there might be some misunderstanding here. The incoming mean is not constant since the prior GP is fitted on some data already and we are using its posterior here. Even if the prior GP originally had a ConstantMean, once trained, its posterior mean will not be constant anymore. Or am I misunderstanding your comment?
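[Editor's note] The point can be checked with a tiny hand-rolled GP regression (plain Python, RBF kernel, two hypothetical training points): even when the prior mean is zero (i.e. constant), the posterior mean varies with the input once the model is conditioned on data.

```python
# Minimal GP posterior mean m(x) = k(x, X) (K + sigma^2 I)^{-1} y, computed by
# hand for two training points. Data values are illustrative only.
import math


def k(a: float, b: float, lengthscale: float = 1.0) -> float:
    """Squared-exponential (RBF) kernel."""
    return math.exp(-0.5 * ((a - b) / lengthscale) ** 2)


X = [0.0, 2.0]          # training inputs
y = [1.0, -1.0]         # training targets
noise = 1e-6

# Solve (K + noise*I) alpha = y for the 2x2 case via the explicit inverse.
K = [[k(X[0], X[0]) + noise, k(X[0], X[1])],
     [k(X[1], X[0]), k(X[1], X[1]) + noise]]
det = K[0][0] * K[1][1] - K[0][1] * K[1][0]
alpha = [(K[1][1] * y[0] - K[0][1] * y[1]) / det,
         (-K[1][0] * y[0] + K[0][0] * y[1]) / det]


def posterior_mean(x: float) -> float:
    """Posterior mean of the fitted GP: non-constant despite a zero prior mean."""
    return k(x, X[0]) * alpha[0] + k(x, X[1]) * alpha[1]


# posterior_mean(0.0) is close to 1.0, posterior_mean(2.0) close to -1.0, and
# values in between differ: the posterior mean is a genuine function of x.
```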

Collaborator (Scienfitz), Jan 9, 2026:

I see, thanks for clarifying; I now see the need for the class.
Please let's just make sure this is optimized and does not impose any computational bottleneck.

Is it also right that this implementation is the variant with a completely frozen prior mean? I.e., the mean is not just a prior, but is forever the mean for our actual GP used in the campaign?

"""GPyTorch mean module using a trained GP as prior mean.

This mean module wraps a trained Gaussian Process and uses its predictions
Collaborator:

"process" in "Gaussian process" should not be capitalized (unless it's a headline or similar)

as the mean function for another GP.

Args:
gp: Trained Gaussian Process to use as mean function.
batch_shape: Batch shape for the mean module.
**kwargs: Additional keyword arguments.
Collaborator:

Is it necessary to include those in this class/the __init__? Currently they seem to be silently ignored, so I would propose to either remove them completely if possible or at least mention that they are being ignored.

"""

def __init__(
self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
) -> None:
super().__init__()

# Deep copy and freeze the GP
self.gp: SingleTaskGP = deepcopy(gp)
self.batch_shape: torch.Size = batch_shape

# Freeze parameters and set eval mode once
for param in self.gp.parameters():
param.requires_grad = False

def forward(self, x: Tensor) -> Tensor:
"""Compute the mean function using the wrapped GP.

Args:
x: Input tensor for which to compute the mean.

Returns:
Mean predictions from the wrapped GP.
"""
self.gp.eval()
Collaborator:

Wouldn't it make sense to move these eval statements into __init__, since they are only needed once?

self.gp.likelihood.eval()
AVHopp marked this conversation as resolved.
with torch.no_grad(), gpytorch.settings.fast_pred_var():
mean = self.gp(x).mean.detach()

# Handle batch dimensions
target_shape = torch.broadcast_shapes(self.batch_shape, x.shape[:-1])
return mean.reshape(target_shape)
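[Editor's note] The core idea of PriorMean can be mirrored in a dependency-free sketch: deep-copy a fitted predictor once, freeze it, and serve its predictions read-only as the mean of a new model. FrozenMean and the dict-based toy model below are hypothetical stand-ins for PriorMean and a fitted SingleTaskGP.

```python
# Dependency-free analogue of PriorMean: a frozen, copied predictor whose
# outputs become the mean function of another model.
from copy import deepcopy


class FrozenMean:
    """Stand-in for PriorMean: wraps an already-fitted predictor."""

    def __init__(self, fitted_model: dict) -> None:
        # Deep-copy once so later refits of the original cannot leak in,
        # mirroring deepcopy(gp) in PriorMean.__init__.
        self._model = deepcopy(fitted_model)

    def __call__(self, xs: list) -> list:
        # Read-only evaluation, analogous to torch.no_grad() + .detach().
        return [self._model["slope"] * x + self._model["intercept"] for x in xs]


prior = {"slope": 2.0, "intercept": 1.0}  # toy "fitted" model
mean = FrozenMean(prior)
prior["slope"] = -100.0  # mutating the original does not affect the frozen copy
```

The deep copy is what makes the "completely frozen prior mean" semantics discussed above hold: the wrapped model is insulated from any later changes to the object it was created from.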