Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions bofire/surrogates/multi_task_gp.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import warnings
from typing import Dict, Optional

import botorch
import numpy as np
import pandas as pd
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.kernels.positive_index import PositiveIndexKernel
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from gpytorch.kernels import IndexKernel
from gpytorch.mlls import ExactMarginalLogLikelihood

import bofire.kernels.api as kernels
Expand Down Expand Up @@ -77,14 +78,16 @@ def _fit_botorch(
)

if isinstance(self.task_prior, LKJPrior):
warnings.warn(
"The LKJ prior has issues when sampling from the prior, prior has been defaulted to None.",
UserWarning,
task_covar_module = next(
kernel
for kernel in self.model.covar_module.kernels
if isinstance(kernel, (IndexKernel, PositiveIndexKernel))
)
task_covar_module.register_prior(
"IndexKernelPrior",
priors.map(self.task_prior),
_index_kernel_prior_closure,
)
# once the issue is fixed, the following line should be uncommented
# self.model.task_covar_module.register_prior(
# "IndexKernelPrior", priors.map(self.lkj_prior), _index_kernel_prior_closure
# )
self.model.likelihood.noise_covar.noise_prior = priors.map(self.noise_prior)
if self.noise_constraint is not None:
self.model.likelihood.noise_covar.raw_noise_constraint = priors.map(
Expand Down
26 changes: 20 additions & 6 deletions tests/bofire/surrogates/test_multitask_gps.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import importlib

import gpytorch.priors
import pandas as pd
import pytest
import torch
from botorch.models import MultiTaskGP
from botorch.models.kernels.positive_index import PositiveIndexKernel
from botorch.models.transforms.input import InputStandardize, Normalize
from botorch.models.transforms.outcome import ChainedOutcomeTransform, Log, Standardize
from gpytorch.kernels import IndexKernel
from pandas.testing import assert_frame_equal

import bofire.surrogates.api as surrogates
Expand Down Expand Up @@ -142,12 +145,7 @@ def test_MultiTaskGPModel(kernel, scaler, output_scaler, task_prior):
model = surrogates.map(model)
with pytest.raises(ValueError):
model.dumps()
# if task_prior is not None, a warning should be raised
if task_prior is not None:
with pytest.warns(UserWarning):
model.fit(experiments)
else:
model.fit(experiments)
model.fit(experiments)
# check that the active_dims are set correctly
assert torch.allclose(
model.model.covar_module.kernels[0].active_dims,
Expand All @@ -169,6 +167,22 @@ def test_MultiTaskGPModel(kernel, scaler, output_scaler, task_prior):
assert preds.shape == (5, 2)
# check that model is composed correctly
assert isinstance(model.model, MultiTaskGP)
task_covar_module = next(
kernel
for kernel in model.model.covar_module.kernels
if isinstance(kernel, (IndexKernel, PositiveIndexKernel))
)
prior_names = [name for name, *_ in task_covar_module.named_priors()]
if task_prior is not None:
assert "IndexKernelPrior" in prior_names
index_kernel_prior = next(
prior
for name, _, prior, *_ in task_covar_module.named_priors()
if name == "IndexKernelPrior"
)
assert isinstance(index_kernel_prior, gpytorch.priors.LKJCovariancePrior)
else:
assert "IndexKernelPrior" not in prior_names
if output_scaler == ScalerEnum.STANDARDIZE:
assert isinstance(model.model.outcome_transform, Standardize)
elif output_scaler == ScalerEnum.LOG:
Expand Down
Loading