diff --git a/bofire/benchmarks/benchmark.py b/bofire/benchmarks/benchmark.py index ff1f24835..44aa5685c 100644 --- a/bofire/benchmarks/benchmark.py +++ b/bofire/benchmarks/benchmark.py @@ -243,22 +243,49 @@ def get_optima(self) -> pd.DataFrame: class SpuriousFeaturesWrapper(Benchmark): - """Wrapper that adds spurious features to a benchmark, that are ignored on evaluation.""" + """Wrapper that adds spurious features to a benchmark, that are ignored on evaluation. - def __init__(self, benchmark: Benchmark, n_spurious_features: int = 1, **kwargs): + Args: + benchmark: The benchmark to wrap. + n_spurious_features: Number of spurious features to add. + max_count: If provided, adds an NChooseKConstraint on all input features + (original + spurious) limiting the number of non-zero features. + """ + + def __init__( + self, + benchmark: Benchmark, + n_spurious_features: int = 1, + max_count: Optional[int] = None, + **kwargs, + ): super().__init__(**kwargs) assert n_spurious_features >= 1, "n_spurious_features must be >= 1." 
self._benchmark = benchmark + + inputs = Inputs( + features=benchmark.domain.inputs.features # ty: ignore[unsupported-operator] + + [ + ContinuousInput(key=f"x_spurious_{i}", bounds=(0, 1)) + for i in range(n_spurious_features) + ] + ) + + constraints = list(self._benchmark.domain.constraints.constraints) + if max_count is not None: + constraints.append( + NChooseKConstraint( + features=inputs.get_keys(), + max_count=max_count, + min_count=0, + none_also_valid=True, + ) + ) + self._domain = Domain( - inputs=Inputs( - features=benchmark.domain.inputs.features # ty: ignore[unsupported-operator] - + [ - ContinuousInput(key=f"x_spurious_{i}", bounds=(0, 1)) - for i in range(n_spurious_features) - ] - ), + inputs=inputs, outputs=self._benchmark.domain.outputs, - constraints=self._benchmark.domain.constraints, + constraints=Constraints(constraints=constraints), ) def _f(self, candidates: pd.DataFrame, **kwargs) -> pd.DataFrame: diff --git a/bofire/data_models/strategies/random.py b/bofire/data_models/strategies/random.py index d1f0326d8..4fc9bbfd1 100644 --- a/bofire/data_models/strategies/random.py +++ b/bofire/data_models/strategies/random.py @@ -25,6 +25,7 @@ class RandomStrategy(Strategy): n_thinning: Annotated[int, Field(ge=1)] = 32 num_base_samples: Optional[Annotated[int, Field(gt=0)]] = None max_iters: Annotated[int, Field(gt=0)] = 1000 + max_combinations: Annotated[int, Field(gt=0)] = 64 sampler_kwargs: Optional[dict] = None def is_constraint_implemented(self, my_type: Type[Constraint]) -> bool: diff --git a/bofire/strategies/predictives/acqf_optimization.py b/bofire/strategies/predictives/acqf_optimization.py index 0a5938397..7a924fe5c 100644 --- a/bofire/strategies/predictives/acqf_optimization.py +++ b/bofire/strategies/predictives/acqf_optimization.py @@ -47,6 +47,7 @@ from bofire.data_models.strategies.shortest_path import has_local_search_region from bofire.data_models.types import InputTransformSpecs from bofire.strategies import utils +from 
bofire.strategies.predictives.optimize_mcts import optimize_acqf_mcts from bofire.strategies.random import RandomStrategy from bofire.strategies.shortest_path import ShortestPathStrategy from bofire.utils.torch_tools import ( @@ -63,6 +64,7 @@ class OptimizerEnum(str, Enum): OPTIMIZE_ACQF = "OPTIMIZE_ACQF" OPTIMIZE_ACQF_MIXED = "OPTIMIZE_ACQF_MIXED" OPTIMIZE_ACQF_MIXED_ALTERNATING = "OPTIMIZE_ACQF_MIXED_ALTERNATING" + OPTIMIZE_ACQF_MCTS = "OPTIMIZE_ACQF_MCTS" # Threshold for switching between optimizers optimize_acqf_mixed @@ -357,6 +359,36 @@ class _OptimizeAcqfMixedAlternatingInput(_OptimizeAcqfInputBase): equality_constraints: list[tuple[Tensor, Tensor, float]] | None +class _OptimizeAcqfMctsInput(_OptimizeAcqfInputBase): + acq_function: Callable + bounds: Tensor + nchooseks: list[tuple[list[int], int, int]] | None + cat_dims: Dict[int, List[float]] | None + nig_alpha0: float + ts_prior_var: float + adaptive_prior_var: bool + cache_hit_mode: str + variance_decay: float + rollout_mode: str + adaptive_n0: bool + p_stop_rollout: float + num_iterations: int + pw_k0: float + pw_alpha: float + max_rollout_retries: int + use_cache: bool + n_sobol_samples: int + top_k_refine: int + screening_num_iterations: int | None + q: int + raw_samples: int + num_restarts: int + fixed_features: dict[int, float] | None + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None + equality_constraints: list[tuple[Tensor, Tensor, float]] | None + seed: int | None + + class BotorchOptimizer(AcquisitionOptimizer): def __init__(self, data_model: BotorchOptimizerDataModel): self.n_restarts = data_model.n_restarts @@ -454,6 +486,7 @@ def _optimize_acqf_continuous( OptimizerEnum.OPTIMIZE_ACQF: optimize_acqf, OptimizerEnum.OPTIMIZE_ACQF_MIXED: optimize_acqf_mixed, OptimizerEnum.OPTIMIZE_ACQF_MIXED_ALTERNATING: optimize_acqf_mixed_alternating, + OptimizerEnum.OPTIMIZE_ACQF_MCTS: optimize_acqf_mcts, } candidates, acqf_vals = optimizer_mapping[optimizer]( **optimizer_input.model_dump() 
@@ -482,6 +515,12 @@ def _get_optimizer_options(self, domain: Domain) -> Dict[str, int]: } def _determine_optimizer(self, domain: Domain, n_acqfs) -> OptimizerEnum: + # Check if we have NChooseK constraints - if so, use MCTS optimizer + if len(domain.constraints.get([NChooseKConstraint])) > 0 or any( + isinstance(feat, ContinuousInput) and feat.allow_zero + for feat in domain.inputs.get(ContinuousInput) + ): + return OptimizerEnum.OPTIMIZE_ACQF_MCTS if n_acqfs > 1: return OptimizerEnum.OPTIMIZE_ACQF_LIST n_categorical_combinations = ( @@ -508,6 +547,7 @@ def _get_arguments_for_optimizer( | _OptimizeAcqfMixedInput | _OptimizeAcqfListInput | _OptimizeAcqfMixedAlternatingInput + | _OptimizeAcqfMctsInput ): input_preprocessing_specs = self._input_preprocessing_specs(domain) features2idx = self._features2idx(domain) @@ -616,6 +656,68 @@ def _get_arguments_for_optimizer( if feat.key not in fixed_keys }, ) + elif optimizer == OptimizerEnum.OPTIMIZE_ACQF_MCTS: + # Convert NChooseKConstraint to tuples for MCTS + nchoosek_constraints = domain.constraints.get([NChooseKConstraint]) + nchooseks_list = [] + nchoosek_feature_keys: set[str] = set() + for constraint in nchoosek_constraints: + # Get feature indices for the constraint + feature_indices = [ + features2idx[feat_key][0] for feat_key in constraint.features + ] + # Create tuple (features, min_count, max_count) + nchooseks_list.append( + (feature_indices, constraint.min_count, constraint.max_count) + ) + nchoosek_feature_keys.update(constraint.features) + # Continuous features with allow_zero=True are treated as NChooseK + # constraints where min_count=0 and max_count=1, but only if they + # are not already part of an explicit NChooseK constraint. 
+ for feat in domain.inputs.get(ContinuousInput): + assert isinstance(feat, ContinuousInput) + if feat.allow_zero and feat.key not in nchoosek_feature_keys: + feature_index = features2idx[feat.key][0] + nchooseks_list.append(([feature_index], 0, 1)) + + # Get categorical dimensions (same as mixed_alternating) + fixed_keys = domain.inputs.get_fixed().get_keys() + + return _OptimizeAcqfMctsInput( + acq_function=acqfs[0], + bounds=bounds, + nchooseks=nchooseks_list if nchooseks_list else None, + cat_dims={ + features2idx[feat.key][0]: feat.to_ordinal_encoding( # type: ignore + pd.Series(feat.get_allowed_categories()) # type: ignore + ).tolist() + for feat in domain.inputs.get(CategoricalInput) + if feat.key not in fixed_keys + }, + nig_alpha0=1.0, + ts_prior_var=1.0, + adaptive_prior_var=True, + cache_hit_mode="variance_inflation", + variance_decay=0.95, + rollout_mode="ts_group_action", + adaptive_n0=False, + p_stop_rollout=0.35, + num_iterations=300, + pw_k0=2.0, + pw_alpha=0.6, + max_rollout_retries=3, + use_cache=True, + n_sobol_samples=64, + top_k_refine=8, + screening_num_iterations=None, + q=candidate_count, + raw_samples=self.n_raw_samples, + num_restarts=self.n_restarts, + fixed_features=self.get_fixed_features(domain=domain), + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + seed=None, + ) else: raise ValueError(f"Unknown optimizer: {optimizer}") diff --git a/bofire/strategies/predictives/optimize_mcts.py b/bofire/strategies/predictives/optimize_mcts.py new file mode 100644 index 000000000..e758aa811 --- /dev/null +++ b/bofire/strategies/predictives/optimize_mcts.py @@ -0,0 +1,1363 @@ +"""MCTS-based acquisition function optimization for NChooseK and categorical constraints. 
+ +Uses Monte Carlo Tree Search with Normal-Inverse-Gamma (NIG) Thompson Sampling +to select which features are active (non-zero) and categorical values, then runs +BoTorch acquisition function optimization with inactive features fixed to zero +and categoricals fixed to selected values. + +The NIG conjugate prior models rewards as drawn from Normal(mu, sigma^2) with +both mean and variance unknown. The marginal posterior for the mean is a +Student-t distribution with heavier tails at low observation counts, which +naturally handles the low-n regime without extra heuristics. + +NIG prior: (mu, sigma^2) ~ NIG(mu0, n0, alpha0, beta0) + mu0 = _global_mean() (running mean of novel rewards) + n0 = pseudo-count (default 1.0, or adaptive from branching factor) + alpha0 = nig_alpha0 parameter (default 1.0) + beta0 = alpha0 * _prior_var() (so E[sigma^2] = prior_var) + +After n observations with sufficient stats (n_obs, sum_rewards, sum_sq_rewards): + x_bar = sum_rewards / n + S = sum_sq_rewards - n * x_bar^2 + + n0' = n0 + n + mu0' = (n0 * mu0 + n * x_bar) / n0' + alpha0' = alpha0 + n / 2 + beta0' = beta0 + S / 2 + (n0 * n * (x_bar - mu0)^2) / (2 * n0') + +Marginal posterior for mu: Student-t with + df = 2 * alpha0' + location = mu0' + scale = sqrt(beta0' / (alpha0' * n0')) +""" + +import math +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Callable, NamedTuple, Optional + +import torch +from botorch.optim import optimize_acqf +from torch import Tensor + + +STOP = -1 # Sentinel for stopping selection in a group + + +class ActionStats(NamedTuple): + """Sufficient statistics for an action. 
+ + Args: + n_obs: Number of observations + sum_rewards: Sum of observed rewards + sum_sq_rewards: Sum of squared observed rewards + """ + + n_obs: int + sum_rewards: float + sum_sq_rewards: float + + +class TrajectoryStep(NamedTuple): + """One rollout step: which group and what action.""" + + group: int + action: int + + +class Selection(NamedTuple): + """Terminal result: active features and categorical values.""" + + features: tuple[int, ...] + categoricals: dict[int, float] + + +# ============================================================================= +# Group abstractions for MCTS +# ============================================================================= + + +class Group(ABC): + """Abstract base class for MCTS groups (NChooseK or Categorical).""" + + @property + @abstractmethod + def n_options(self) -> int: + """Number of options/actions available in this group.""" + pass + + @abstractmethod + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Return legal actions given current partial selection.""" + pass + + @abstractmethod + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """Check if selection for this group is complete.""" + pass + + +@dataclass(frozen=True) +class NChooseK(Group): + """NChooseK constraint specifying feature selection bounds. 
+ + Args: + features: Feature indices (can be non-contiguous, e.g., [0, 2, 4]) + min_count: Minimum number of features to select + max_count: Maximum number of features to select + """ + + features: Sequence[int] + min_count: int + max_count: int + + def __post_init__(self): + n = len(self.features) + if not (0 <= self.min_count <= self.max_count <= n): + raise ValueError( + f"Invalid NChooseK constraint: require 0 <= min_count <= max_count <= n; " + f"got min_count={self.min_count}, max_count={self.max_count}, n={n}" + ) + + @property + def n_options(self) -> int: + return len(self.features) + + @property + def n_features(self) -> int: + return len(self.features) + + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Compute legal actions within this NChooseK group. + + Actions are indices into self.features (not the actual feature indices). + Enforces strictly increasing selection order (combinations, not permutations). + STOP is legal if len(partial) >= min_count and not already stopped. + """ + n = self.n_features + m = len(partial) + + if stopped or m >= self.max_count: + return [] + + actions: list[int] = [] + last = partial[-1] if partial else -1 + + # Remaining picks needed after this action to satisfy min_count + r_min_needed = max(0, self.min_count - (m + 1)) + # After picking index i, n - (i+1) items remain; require n - (i+1) >= r_min_needed + end_inclusive = n - r_min_needed - 1 + start = last + 1 + + if start <= end_inclusive: + actions.extend(range(start, end_inclusive + 1)) + + if m >= self.min_count: + actions.append(STOP) + + return actions + + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """NChooseK is complete when stopped or max_count reached.""" + return stopped or len(partial) >= self.max_count + + +@dataclass(frozen=True) +class Categorical(Group): + """Categorical dimension with allowed values. 
+ + Args: + dim: The dimension index in the input space + values: Sequence of allowed values for this dimension + """ + + dim: int + values: Sequence[float] + + def __post_init__(self): + if len(self.values) < 2: + raise ValueError( + f"CategoricalGroup requires at least two values, got {len(self.values)}" + ) + + @property + def n_options(self) -> int: + return len(self.values) + + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Categorical must select exactly one value. No STOP action.""" + if len(partial) >= 1: + # Already selected + return [] + # All value indices are legal + return list(range(self.n_options)) + + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """Categorical is complete when one value is selected.""" + return len(partial) >= 1 + + +# ============================================================================= +# Combined constraints container +# ============================================================================= + + +@dataclass(frozen=True) +class Groups: + """Collection of NChooseK constraints and categorical groups.""" + + groups: list[Group] + + def __len__(self) -> int: + return len(self.groups) + + @property + def categoricals(self) -> list[Categorical]: + return [g for g in self.groups if isinstance(g, Categorical)] + + @property + def nchooseks(self) -> list[NChooseK]: + return [g for g in self.groups if isinstance(g, NChooseK)] + + @property + def all_nchoosek_features(self) -> list[int]: + """All feature indices covered by NChooseK constraints.""" + all_feats = [] + for c in self.nchooseks: + all_feats.extend(c.features) + return all_feats + + @property + def all_categorical_dims(self) -> list[int]: + """All dimension indices that are categorical.""" + return [c.dim for c in self.categoricals] + + +# ============================================================================= +# MCTS Node +# 
============================================================================= + + +@dataclass +class Node: + """MCTS tree node with NIG sufficient statistics. + + Each node tracks Bayesian sufficient statistics (n_obs, sum_rewards, + sum_sq_rewards) for the Normal-Inverse-Gamma posterior update, plus + n_visits (which includes cache hits) for progressive widening. + + Args: + partial_by_group: Partial selection per group (indices into group's options) + stopped_by_group: Whether each group has stopped selecting (for NChooseK) + group_idx: Current group being filled + n_obs: Novel observation count (for NIG posterior updates) + sum_rewards: Sum of observed rewards from novel evaluations + sum_sq_rewards: Sum of squared rewards from novel evaluations + n_visits: Total visits including cache hits (for progressive widening) + children: Child nodes keyed by action (int index or STOP) + """ + + partial_by_group: tuple[tuple[int, ...], ...] + stopped_by_group: tuple[bool, ...] + group_idx: int + + n_obs: int = 0 + sum_rewards: float = 0.0 + sum_sq_rewards: float = 0.0 + n_visits: int = 0 + + children: dict[int, "Node"] = field(default_factory=dict) + + def is_terminal(self, groups: Groups) -> bool: + return self.group_idx >= len(groups) + + +# ============================================================================= +# MCTS Implementation with NIG Thompson Sampling +# ============================================================================= + + +class MCTS: + """Monte Carlo Tree Search with Normal-Inverse-Gamma Thompson Sampling. + + Uses NIG conjugate posteriors for tree selection. The marginal posterior + for the mean is a Student-t distribution with heavier tails at low + observation counts, naturally preventing premature commitment. 
+ + Args: + groups: Collection of NChooseK and categorical constraints + reward_fn: Function mapping (selected_features, categorical_selections) to reward + nig_alpha0: NIG shape prior (default 1.0); lower = heavier tails at low n + ts_prior_var: Prior variance (default 1.0); used to set beta0 = alpha0 * prior_var + adaptive_prior_var: If True, use running empirical variance as prior variance + cache_hit_mode: How to handle cache hits during backpropagation. Options: + "no_update" - only increment n_visits (default virtual loss) + "variance_inflation" - decay n_obs to widen posterior + "pessimistic" - add pessimistic pseudo-observations + "combined" - variance_inflation + pessimistic + "adaptive_pessimistic" - pessimistic with exhaustion-scaled strength + "adaptive_combined" - variance_inflation + adaptive_pessimistic + variance_decay: Decay factor for variance inflation mode (default 0.95) + rollout_mode: Rollout action selection policy. Options: + "ts_group_action" - NIG Thompson Sampling per (group, action) + "uniform" - fixed p_stop for NChooseK STOP, then uniform among non-STOP + "uniform_subset" - complete NChooseK groups with uniform random subsets + adaptive_n0: If True, set pseudo-count n0 = 1 + log(branching_factor) + p_stop_rollout: Probability of early stop during uniform rollout (default 0.35) + pw_k0: Progressive widening base constant (default 2.0) + pw_alpha: Progressive widening exponent (default 0.6) + max_rollout_retries: Maximum rollout retries on cache hit (default 3) + use_cache: If True (default), cache reward evaluations. If False, every + call to reward_fn is fresh (no caching), and every observation is + treated as novel. Useful for noisy/sampling-based reward functions. + shuffle_features: If True (default), randomly permute the feature ordering + within each NChooseK group to eliminate structural tree bias. 
The + canonical strictly-increasing order creates asymmetric subtrees where + high-index features get concentrated visits (up to 6x over-selection). + Shuffling randomizes which features sit at which tree depth, so over + multiple MCTS runs the bias averages out. + seed: Random seed for reproducibility + """ + + def __init__( + self, + groups: Groups, + reward_fn: Callable[[tuple[int, ...], dict[int, float]], float], + nig_alpha0: float = 1.0, + ts_prior_var: float = 1.0, + adaptive_prior_var: bool = True, + cache_hit_mode: str = "variance_inflation", + variance_decay: float = 0.95, + rollout_mode: str = "ts_group_action", + adaptive_n0: bool = False, + p_stop_rollout: float = 0.35, + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + use_cache: bool = True, + shuffle_features: bool = True, + seed: Optional[int] = None, + ): + # Initialize RNG first (needed for shuffle) + self.rng = random.Random(seed) + + # Shuffle NChooseK feature orderings to remove structural tree bias + if shuffle_features: + shuffled_groups = [] + for g in groups.groups: + if isinstance(g, NChooseK): + feats = list(g.features) + self.rng.shuffle(feats) + shuffled_groups.append( + NChooseK( + features=feats, + min_count=g.min_count, + max_count=g.max_count, + ) + ) + else: + shuffled_groups.append(g) + groups = Groups(groups=shuffled_groups) + + self.groups = groups + self.reward_fn = reward_fn + self.nig_alpha0 = nig_alpha0 + self.ts_prior_var = ts_prior_var + self.adaptive_prior_var = adaptive_prior_var + self.cache_hit_mode = cache_hit_mode + self.variance_decay = variance_decay + self.rollout_mode = rollout_mode + self.adaptive_n0 = adaptive_n0 + self.p_stop_rollout = p_stop_rollout + self.pw_k0 = pw_k0 + self.pw_alpha = pw_alpha + self.max_rollout_retries = max_rollout_retries + self.use_cache = use_cache + + # Initialize root node + n_groups = len(groups) + self.root = Node( + partial_by_group=tuple(() for _ in range(n_groups)), + stopped_by_group=tuple(False 
for _ in range(n_groups)), + group_idx=0, + ) + + # Best found so far + self.best_selection: Optional[Selection] = None + self.best_value: float = float("-inf") + + # Cache for terminal evaluations + self.value_cache: dict[tuple, float] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + # Global novel reward tracking for NIG prior + self._novel_reward_sum: float = 0.0 + self._novel_reward_sq_sum: float = 0.0 + self._novel_reward_count: int = 0 + + # Rollout TS statistics: (group_idx, action) -> ActionStats + self.rollout_ts_stats: dict[tuple[int, int], ActionStats] = {} + + # ========================================================================= + # NIG prior methods + # ========================================================================= + + def _global_mean(self) -> float: + """Running mean of all novel rewards, used as the NIG prior center mu0.""" + if self._novel_reward_count == 0: + return 0.0 + return self._novel_reward_sum / self._novel_reward_count + + def _prior_var(self) -> float: + """Prior variance for the NIG model. + + When adaptive_prior_var is True and at least 2 novel rewards have been + observed, returns the running empirical variance. Otherwise returns the + fixed ts_prior_var. + """ + if not self.adaptive_prior_var or self._novel_reward_count < 2: + return self.ts_prior_var + mean = self._global_mean() + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return max(empirical_var, 1e-8) + + def _pessimistic_value(self) -> float: + """Pessimistic pseudo-observation value: global_mean - global_std. + + Used by pessimistic and combined cache-hit modes to inject a + below-average pseudo-observation that discourages over-visited branches. 
+ """ + mean = self._global_mean() + if self._novel_reward_count < 2: + return mean - math.sqrt(self.ts_prior_var) + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return mean - math.sqrt(max(empirical_var, 1e-8)) + + def _compute_n0(self, n_actions: int) -> float: + """Compute pseudo-count n0 for the NIG prior. + + With adaptive_n0, n0 = 1 + log(branching_factor). Higher branching + means each child is visited rarely early on, so more observations + are needed before departing from the prior. + """ + if not self.adaptive_n0: + return 1.0 + return 1.0 + math.log(max(n_actions, 2)) + + # ========================================================================= + # NIG Thompson Sampling + # ========================================================================= + + def _student_t_sample(self, df: float, loc: float, scale: float) -> float: + """Sample from a Student-t distribution. + + Uses the representation: loc + scale * Z / sqrt(V / df) + where Z ~ N(0,1) and V ~ chi-squared(df) = Gamma(df/2, 2). + """ + z = self.rng.gauss(0, 1) + v = self.rng.gammavariate(df / 2, 2) # chi-squared(df) + return loc + scale * z / math.sqrt(v / df) + + def _nig_sample( + self, + n_obs: int, + sum_rewards: float, + sum_sq_rewards: float, + n0: float = 1.0, + ) -> float: + """Sample from the NIG posterior (marginal Student-t for the mean). + + Computes the Normal-Inverse-Gamma posterior update from sufficient + statistics and draws a sample from the marginal Student-t distribution. 
+ + The NIG prior is parameterized as: + mu0 = _global_mean(), n0 = n0, alpha0 = nig_alpha0, + beta0 = alpha0 * _prior_var() + + After n observations with sample mean x_bar and sum-of-squared- + deviations S = sum(x_i^2) - n * x_bar^2, the posterior parameters are: + n0' = n0 + n + mu0' = (n0 * mu0 + n * x_bar) / n0' + alpha0' = alpha0 + n/2 + beta0' = beta0 + S/2 + n0*n*(x_bar - mu0)^2 / (2*n0') + + The marginal posterior for mu is Student-t(df=2*alpha0', loc=mu0', + scale=sqrt(beta0' / (alpha0' * n0'))). + + Args: + n_obs: Number of novel observations + sum_rewards: Sum of observed rewards + sum_sq_rewards: Sum of squared observed rewards + n0: Prior pseudo-count (default 1.0) + + Returns: + A sample from the posterior predictive distribution for the mean + """ + mu0 = self._global_mean() + prior_var = self._prior_var() + alpha0 = self.nig_alpha0 + beta0 = alpha0 * prior_var + + if n_obs == 0: + # Prior: Student-t(df=2*alpha0, loc=mu0, scale=sqrt(beta0/(alpha0*n0))) + df = 2 * alpha0 + scale = math.sqrt(beta0 / (alpha0 * n0)) + return self._student_t_sample(df, mu0, scale) + + x_bar = sum_rewards / n_obs + s = sum_sq_rewards - n_obs * x_bar * x_bar + s = max(s, 0.0) # numerical safety + + # Posterior update + n0_post = n0 + n_obs + mu0_post = (n0 * mu0 + n_obs * x_bar) / n0_post + alpha0_post = alpha0 + n_obs / 2 + beta0_post = beta0 + s / 2 + (n0 * n_obs * (x_bar - mu0) ** 2) / (2 * n0_post) + + df = 2 * alpha0_post + scale = math.sqrt(beta0_post / (alpha0_post * n0_post)) + return self._student_t_sample(df, mu0_post, scale) + + # ========================================================================= + # Rollout TS methods + # ========================================================================= + + def _ts_sample_rollout_action( + self, group_idx: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using per-(group, action) NIG posteriors. 
+ + For each legal action, draws a sample from the NIG posterior using the + per-action sufficient statistics, then picks the action with the highest + sample. STOP is scored like any other action. + + Args: + group_idx: Index of the current group + legal_actions: List of legal action indices (may include STOP) + + Returns: + Selected action index + """ + n0 = self._compute_n0(len(legal_actions)) + best_action = legal_actions[0] + best_score = float("-inf") + + for action in legal_actions: + key = (group_idx, action) + stats = self.rollout_ts_stats.get(key, ActionStats(0, 0.0, 0.0)) + score = self._nig_sample( + stats.n_obs, stats.sum_rewards, stats.sum_sq_rewards, n0 + ) + if score > best_score: + best_score = score + best_action = action + + return best_action + + def _update_rollout_ts_stats( + self, trajectory: list[TrajectoryStep], reward: float + ) -> None: + """Update per-(group, action) NIG sufficient stats from a completed rollout. + + Args: + trajectory: List of (group_idx, action) pairs from the rollout + reward: Raw reward obtained from the terminal evaluation + """ + for group_idx, action in trajectory: + key = (group_idx, action) + old = self.rollout_ts_stats.get(key, ActionStats(0, 0.0, 0.0)) + self.rollout_ts_stats[key] = ActionStats( + n_obs=old.n_obs + 1, + sum_rewards=old.sum_rewards + reward, + sum_sq_rewards=old.sum_sq_rewards + reward * reward, + ) + + # ========================================================================= + # Tree infrastructure + # ========================================================================= + + def _make_cache_key( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> tuple: + """Create hashable cache key from selection.""" + return (selected_features, frozenset(cat_selections.items())) + + def _cached_reward( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + """Get cached reward or compute and cache it. 
+ + When use_cache is False, always calls reward_fn fresh (no caching). + """ + if not self.use_cache: + self.cache_misses += 1 + return self.reward_fn(selected_features, cat_selections) + key = self._make_cache_key(selected_features, cat_selections) + if key in self.value_cache: + self.cache_hits += 1 + return self.value_cache[key] + val = self.reward_fn(selected_features, cat_selections) + self.value_cache[key] = val + self.cache_misses += 1 + return val + + def _child_limit(self, node: Node) -> int: + """Progressive widening: max children based on visit count.""" + return max(1, int(self.pw_k0 * (max(1, node.n_visits) ** self.pw_alpha))) + + def _legal_actions(self, node: Node) -> list[int]: + """Get legal actions for current group in node.""" + if node.is_terminal(self.groups): + return [] + g = node.group_idx + group = self.groups.groups[g] + partial = node.partial_by_group[g] + stopped = node.stopped_by_group[g] + return group.legal_actions(partial, stopped) + + def _apply_action(self, node: Node, action: int) -> Node: + """Create child node by applying action to current node.""" + g = node.group_idx + group = self.groups.groups[g] + + partials = list(node.partial_by_group) + stoppeds = list(node.stopped_by_group) + + if action == STOP: + stoppeds[g] = True + next_g = g + 1 + else: + partials[g] += (action,) + # Check if group is complete + if group.is_complete(partials[g], stoppeds[g]): + next_g = g + 1 + else: + next_g = g + + return Node( + partial_by_group=tuple(partials), + stopped_by_group=tuple(stoppeds), + group_idx=next_g, + ) + + def _get_selection(self, node: Node) -> Selection: + """Convert node's partial selections to (selected_features, cat_selections).""" + # Extract NChooseK selections + selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in node.partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + # Extract 
categorical selections + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = node.partial_by_group[g] + if partial: + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return Selection(features=selected_features_tuple, categoricals=cat_selections) + + # ========================================================================= + # Tree selection with NIG Thompson Sampling + # ========================================================================= + + def _select_and_expand(self) -> tuple[Node, list[Node]]: + """Select path through tree using NIG-TS and expand one new node. + + At each internal node, draws a Thompson sample from each child's NIG + posterior and follows the child with the highest sample. When progressive + widening allows expansion, expands a random unexplored action instead. + """ + node = self.root + path = [node] + + while not node.is_terminal(self.groups): + legal = self._legal_actions(node) + limit = self._child_limit(node) + unexpanded = [a for a in legal if a not in node.children] + can_expand = len(node.children) < limit + + if can_expand and unexpanded: + # Expand one new child + action = self.rng.choice(unexpanded) + child = self._apply_action(node, action) + node.children[action] = child + path.append(child) + return child, path + + # NIG Thompson Sampling selection among existing children + if node.children: + n0 = self._compute_n0(len(node.children)) + best_action = None + best_score = float("-inf") + for action, child in node.children.items(): + score = self._nig_sample( + child.n_obs, child.sum_rewards, child.sum_sq_rewards, n0 + ) + if score > best_score: + best_score = score + best_action = action + + node = node.children[best_action] + path.append(node) + else: + break + + return node, path + + # ========================================================================= + # Rollout + # 
    def _rollout(
        self, node: Node
    ) -> tuple[tuple[int, ...], dict[int, float], list[TrajectoryStep]]:
        """Rollout to terminal state with mode-dependent action selection.

        Supports three rollout modes:
        - "ts_group_action": NIG Thompson Sampling per (group, action)
        - "uniform": adaptive p_stop for NChooseK STOP, then uniform among non-STOP
        - "uniform_subset": complete NChooseK groups with uniform random subsets

        The input node is not mutated; the rollout works on a copy.

        Returns:
            Tuple of (selected_features, cat_selections, trajectory)
        """
        curr = Node(
            partial_by_group=tuple(node.partial_by_group),
            stopped_by_group=tuple(node.stopped_by_group),
            group_idx=node.group_idx,
        )
        trajectory: list[TrajectoryStep] = []

        while not curr.is_terminal(self.groups):
            legal = self._legal_actions(curr)
            if not legal:
                # No legal actions, advance group (group is complete)
                curr = Node(
                    partial_by_group=curr.partial_by_group,
                    stopped_by_group=curr.stopped_by_group,
                    group_idx=curr.group_idx + 1,
                )
                continue

            g = curr.group_idx

            if self.rollout_mode == "ts_group_action":
                # NIG Thompson Sampling: STOP scored like any other action
                action = self._ts_sample_rollout_action(g, legal)

            elif self.rollout_mode == "uniform":
                # Fixed p_stop for NChooseK STOP, then uniform
                is_nchoosek = g < len(self.groups.nchooseks)
                if is_nchoosek and STOP in legal:
                    if self.rng.random() < self.p_stop_rollout:
                        trajectory.append(TrajectoryStep(g, STOP))
                        curr = self._apply_action(curr, STOP)
                        continue

                # Choose uniformly among non-STOP actions
                choices = [a for a in legal if a != STOP]
                if not choices:
                    # Only STOP is legal: stop the group.
                    trajectory.append(TrajectoryStep(g, STOP))
                    curr = self._apply_action(curr, STOP)
                    continue

                action = self.rng.choice(choices)

            elif self.rollout_mode == "uniform_subset":
                is_nchoosek = g < len(self.groups.nchooseks)
                if is_nchoosek:
                    group = self.groups.groups[g]
                    partial = curr.partial_by_group[g]
                    m = len(partial)
                    # Canonical ordering: only features with a higher index
                    # than the last selected one are available.
                    last = partial[-1] if partial else -1
                    available = list(range(last + 1, group.n_features))
                    min_remaining = max(0, group.min_count - m)
                    max_remaining = group.max_count - m

                    # Determine count: sample how many features to add
                    if min_remaining == max_remaining or not available:
                        k = min(min_remaining, len(available))
                    else:
                        # Geometric-like: for each slot beyond min, stop with p_stop
                        k = min_remaining
                        for _ in range(
                            min_remaining, min(max_remaining, len(available))
                        ):
                            if self.rng.random() < self.p_stop_rollout:
                                break
                            k += 1

                    # Sample k features uniformly from available
                    k = min(k, len(available))
                    chosen = sorted(self.rng.sample(available, k))

                    # Build completed node directly — skips per-action steps,
                    # so no TrajectoryStep entries are recorded for this group.
                    new_partial = list(curr.partial_by_group)
                    new_stopped = list(curr.stopped_by_group)
                    new_partial[g] = partial + tuple(chosen)
                    new_stopped[g] = True  # group is complete
                    curr = Node(
                        partial_by_group=tuple(new_partial),
                        stopped_by_group=tuple(new_stopped),
                        group_idx=g + 1,
                    )
                    continue
                else:
                    # Categorical: pick uniformly (same as "uniform" mode)
                    action = self.rng.choice(legal)
            else:
                raise ValueError(f"Unknown rollout_mode: {self.rollout_mode}")

            trajectory.append(TrajectoryStep(g, action))
            curr = self._apply_action(curr, action)

        selected_features, cat_selections = self._get_selection(curr)
        return selected_features, cat_selections, trajectory

    # =========================================================================
    # Backpropagation with NIG cache-hit modes
    # =========================================================================

    def _backpropagate(self, path: list[Node], reward: float, is_novel: bool) -> None:
        """Backpropagate reward through path with NIG-aware cache-hit handling.

        For novel evaluations: updates n_obs, sum_rewards, sum_sq_rewards,
        and n_visits on each node in the path.

        For cache hits: always increments n_visits (for progressive widening),
        then applies the configured cache_hit_mode:
        - "no_update": only increment n_visits
        - "variance_inflation": decay n_obs to widen the NIG posterior
        - "pessimistic": add a pessimistic pseudo-observation
        - "combined": variance_inflation + pessimistic
        - "adaptive_pessimistic": pessimistic with exhaustion-scaled strength
        - "adaptive_combined": variance_inflation + adaptive_pessimistic

        NOTE(review): an unrecognized cache_hit_mode silently falls through
        all branches (behaving like "no_update") — confirm whether validation
        happens at construction time.

        Args:
            path: List of nodes from root to leaf
            reward: Raw reward value
            is_novel: Whether this is a novel (non-cached) evaluation
        """
        if is_novel:
            for n in path:
                n.n_obs += 1
                n.sum_rewards += reward
                n.sum_sq_rewards += reward * reward
                n.n_visits += 1
            return

        # Cache hit handling: precompute path-invariant quantities once.
        if self.cache_hit_mode in ("pessimistic", "combined"):
            pess = self._pessimistic_value()

        if self.cache_hit_mode in ("adaptive_pessimistic", "adaptive_combined"):
            g_mean = self._global_mean()
            if self._novel_reward_count < 2:
                # Too few samples for an empirical variance — fall back to prior.
                g_std = math.sqrt(self.ts_prior_var)
            else:
                emp_var = (
                    self._novel_reward_sq_sum / self._novel_reward_count
                    - g_mean * g_mean
                )
                # Floor protects against negative variance from rounding.
                g_std = math.sqrt(max(emp_var, 1e-8))

        for n in path:
            n.n_visits += 1

            if self.cache_hit_mode == "no_update":
                pass

            elif self.cache_hit_mode == "variance_inflation":
                # Shrink the effective observation count while preserving the
                # mean; sum of squares is scaled proportionally (approximation).
                if n.n_obs > 1:
                    old_n = n.n_obs
                    new_n = max(1, int(old_n * self.variance_decay))
                    if new_n < old_n:
                        mean = n.sum_rewards / old_n
                        n.sum_rewards = mean * new_n
                        n.sum_sq_rewards *= new_n / old_n
                        n.n_obs = new_n

            elif self.cache_hit_mode == "pessimistic":
                n.n_obs += 1
                n.sum_rewards += pess
                n.sum_sq_rewards += pess * pess

            elif self.cache_hit_mode == "combined":
                # Variance inflation first, then the pessimistic pseudo-obs.
                if n.n_obs > 1:
                    old_n = n.n_obs
                    new_n = max(1, int(old_n * self.variance_decay))
                    if new_n < old_n:
                        mean = n.sum_rewards / old_n
                        n.sum_rewards = mean * new_n
                        n.sum_sq_rewards *= new_n / old_n
                        n.n_obs = new_n
                n.n_obs += 1
                n.sum_rewards += pess
                n.sum_sq_rewards += pess * pess

            elif self.cache_hit_mode == "adaptive_pessimistic":
                # Exhaustion = fraction of visits that were NOT novel; scales
                # how far below the global mean the pseudo-observation sits.
                novelty_rate = n.n_obs / max(1, n.n_visits)
                exhaustion = 1.0 - novelty_rate
                pess_value = g_mean - exhaustion * g_std
                n.n_obs += 1
                n.sum_rewards += pess_value
                n.sum_sq_rewards += pess_value * pess_value

            elif self.cache_hit_mode == "adaptive_combined":
                if n.n_obs > 1:
                    old_n = n.n_obs
                    new_n = max(1, int(old_n * self.variance_decay))
                    if new_n < old_n:
                        mean = n.sum_rewards / old_n
                        n.sum_rewards = mean * new_n
                        n.sum_sq_rewards *= new_n / old_n
                        n.n_obs = new_n
                novelty_rate = n.n_obs / max(1, n.n_visits)
                exhaustion = 1.0 - novelty_rate
                pess_value = g_mean - exhaustion * g_std
                n.n_obs += 1
                n.sum_rewards += pess_value
                n.sum_sq_rewards += pess_value * pess_value

    # =========================================================================
    # Main loop
    # =========================================================================

    def run(self, n_iterations: int) -> tuple[tuple[int, ...], dict[int, float], float]:
        """Run MCTS for specified number of iterations.

        Args:
            n_iterations: Number of MCTS iterations to run

        Returns:
            Tuple of (selected_features, cat_selections, best_value).
            If no selection was ever recorded, returns ((), {}, best_value).
        """
        for _ in range(n_iterations):
            leaf, path = self._select_and_expand()

            if leaf.is_terminal(self.groups):
                selected_features, cat_selections = self._get_selection(leaf)
                trajectory: list[TrajectoryStep] = []
            else:
                selected_features, cat_selections, trajectory = self._rollout(leaf)
                # Rollout retry: if the rollout produces a cached terminal,
                # re-roll to try to discover a novel selection.
                # Skip when use_cache=False since the cache is always empty.
                if self.use_cache:
                    for _attempt in range(self.max_rollout_retries):
                        key = self._make_cache_key(selected_features, cat_selections)
                        if key not in self.value_cache:
                            break
                        selected_features, cat_selections, trajectory = self._rollout(
                            leaf
                        )

            # Novelty must be determined BEFORE _cached_reward inserts the key.
            if self.use_cache:
                key = self._make_cache_key(selected_features, cat_selections)
                is_novel = key not in self.value_cache
            else:
                is_novel = True
            reward = self._cached_reward(selected_features, cat_selections)

            if reward > self.best_value:
                self.best_value = reward
                self.best_selection = Selection(
                    features=selected_features, categoricals=cat_selections
                )

            # Track novel reward statistics for NIG prior
            if is_novel:
                self._novel_reward_sum += reward
                self._novel_reward_sq_sum += reward * reward
                self._novel_reward_count += 1

            # Backpropagate with NIG cache-hit handling
            self._backpropagate(path, reward, is_novel)

            # Update rollout TS stats
            if self.rollout_mode == "ts_group_action":
                self._update_rollout_ts_stats(trajectory, reward)

        if self.best_selection is None:
            return (), {}, self.best_value
        return (
            self.best_selection.features,
            self.best_selection.categoricals,
            self.best_value,
        )

    def cache_stats(self) -> dict[str, int]:
        """Return cache statistics (hit/miss counters and current cache size)."""
        return {
            "hits": self.cache_hits,
            "misses": self.cache_misses,
            "size": len(self.value_cache),
        }
+ + Args: + acq_function: BoTorch acquisition function to evaluate + bounds: 2 x d tensor of (lower, upper) bounds + combined_fixed: Dictionary mapping dimension indices to fixed values + n_sobol_samples: Number of Sobol samples to draw + q: Batch size for acquisition function evaluation + + Returns: + Maximum acquisition value across all Sobol samples + """ + d = bounds.shape[1] + free_dims = [i for i in range(d) if i not in combined_fixed] + + if not free_dims: + # All dims fixed: evaluate the single point + point = torch.zeros(1, q, d, dtype=bounds.dtype, device=bounds.device) + for dim_idx, val in combined_fixed.items(): + point[0, :, dim_idx] = val + with torch.no_grad(): + return acq_function(point).max().item() + + n_free = len(free_dims) + sobol = torch.quasirandom.SobolEngine(dimension=n_free, scramble=True) + unit_samples = sobol.draw(n_sobol_samples).to( + dtype=bounds.dtype, device=bounds.device + ) + + # Scale from [0,1] to bounds of free dims + lower = bounds[0, free_dims] + upper = bounds[1, free_dims] + scaled = lower + (upper - lower) * unit_samples # (n_sobol_samples, n_free) + + # Build full (n_sobol_samples, q, d) tensor + candidates = torch.zeros( + n_sobol_samples, q, d, dtype=bounds.dtype, device=bounds.device + ) + for j, dim_idx in enumerate(free_dims): + candidates[:, :, dim_idx] = scaled[:, j].unsqueeze(1) + for dim_idx, val in combined_fixed.items(): + candidates[:, :, dim_idx] = val + + with torch.no_grad(): + values = acq_function(candidates) # (n_sobol_samples,) + + return values.max().item() + + +class _SelectionTracker: + """Wraps a reward function to record the best reward per unique selection. + + Used in the Sobol screening phase to track which feature/categorical + selections yield the highest rewards, so the top-k can be refined + with expensive optimize_acqf calls. 
+ """ + + def __init__( + self, + inner_fn: Callable[[tuple[int, ...], dict[int, float]], float], + ): + self.inner_fn = inner_fn + self.best_rewards: dict[tuple, float] = {} + self.selections: dict[tuple, tuple[tuple[int, ...], dict[int, float]]] = {} + + def __call__( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + reward = self.inner_fn(selected_features, cat_selections) + key = (selected_features, frozenset(cat_selections.items())) + if key not in self.best_rewards or reward > self.best_rewards[key]: + self.best_rewards[key] = reward + self.selections[key] = (selected_features, cat_selections) + return reward + + def top_k(self, k: int) -> list[tuple[tuple[int, ...], dict[int, float]]]: + """Return the top-k selections sorted by best reward (descending).""" + sorted_keys = sorted( + self.best_rewards, key=self.best_rewards.__getitem__, reverse=True + ) + return [self.selections[key] for key in sorted_keys[:k]] + + +# ============================================================================= +# Main optimization function +# ============================================================================= + + +def optimize_acqf_mcts( + acq_function, + bounds: Tensor, + nchooseks: list[tuple[list[int], int, int]] | None = None, + cat_dims: Mapping[int, Sequence[float]] | None = None, + # MCTS NIG parameters + nig_alpha0: float = 1.0, + ts_prior_var: float = 1.0, + adaptive_prior_var: bool = True, + cache_hit_mode: str = "variance_inflation", + variance_decay: float = 0.95, + rollout_mode: str = "ts_group_action", + adaptive_n0: bool = False, + p_stop_rollout: float = 0.35, + num_iterations: int = 100, + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + use_cache: bool = True, + shuffle_features: bool = True, + # Two-phase Sobol screening parameters + n_sobol_samples: int = 0, + top_k_refine: int = 8, + screening_num_iterations: int | None = None, + # BoTorch acqf optimization parameters + q: 
def optimize_acqf_mcts(
    acq_function,
    bounds: Tensor,
    nchooseks: list[tuple[list[int], int, int]] | None = None,
    cat_dims: Mapping[int, Sequence[float]] | None = None,
    # MCTS NIG parameters
    nig_alpha0: float = 1.0,
    ts_prior_var: float = 1.0,
    adaptive_prior_var: bool = True,
    cache_hit_mode: str = "variance_inflation",
    variance_decay: float = 0.95,
    rollout_mode: str = "ts_group_action",
    adaptive_n0: bool = False,
    p_stop_rollout: float = 0.35,
    num_iterations: int = 100,
    pw_k0: float = 2.0,
    pw_alpha: float = 0.6,
    max_rollout_retries: int = 3,
    use_cache: bool = True,
    shuffle_features: bool = True,
    # Two-phase Sobol screening parameters
    n_sobol_samples: int = 0,
    top_k_refine: int = 8,
    screening_num_iterations: int | None = None,
    # BoTorch acqf optimization parameters
    q: int = 1,
    raw_samples: int = 1024,
    num_restarts: int = 20,
    fixed_features: dict[int, float] | None = None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
    seed: int | None = None,
) -> tuple[Tensor, float]:
    """Optimize acquisition function with NChooseK and categorical constraints using MCTS.

    Uses MCTS with NIG Thompson Sampling to select which features are active
    (non-zero) and categorical values, then runs BoTorch optimization with
    inactive features fixed to zero and categoricals fixed to their selected values.

    Args:
        acq_function: BoTorch acquisition function to optimize
        bounds: 2 x d tensor of (lower, upper) bounds for each dimension
        nchooseks: List of NChooseK constraints as tuples of (features, min_count, max_count)
        cat_dims: Dictionary mapping categorical dimension indices to allowed values
        nig_alpha0: NIG shape prior (default 1.0)
        ts_prior_var: Prior variance for NIG model (default 1.0)
        adaptive_prior_var: Use running empirical variance as prior variance
        cache_hit_mode: Cache hit handling strategy
        variance_decay: Decay factor for variance inflation mode
        rollout_mode: Rollout action selection policy ("ts_group_action",
            "uniform", or "uniform_subset")
        adaptive_n0: Adapt pseudo-count from branching factor
        p_stop_rollout: Base probability of early stop during uniform rollout
        num_iterations: Number of MCTS iterations
        pw_k0: Progressive widening base constant
        pw_alpha: Progressive widening exponent
        max_rollout_retries: Maximum rollout retries on cache hit
        use_cache: If True (default), cache reward evaluations. If False, every
            evaluation calls reward_fn fresh and is treated as novel.
        shuffle_features: Forwarded to MCTS — presumably randomizes feature
            ordering within NChooseK groups to reduce ordering bias;
            TODO(review) confirm against the MCTS constructor.
        n_sobol_samples: Number of Sobol quasi-random samples per evaluation in the
            screening phase. 0 (default) disables two-phase mode and uses original
            optimize_acqf-per-step behavior. The production path via
            _get_arguments_for_optimizer sets this to 64.
        top_k_refine: Number of top unique selections from screening to refine with
            expensive optimize_acqf in the refinement phase.
        screening_num_iterations: Number of MCTS iterations for the screening phase.
            If None, defaults to 500 when two-phase mode is active.
        q: Batch size for acquisition function optimization
        raw_samples: Number of raw samples for initialization
        num_restarts: Number of optimization restarts
        fixed_features: Additional fixed features (combined with MCTS selections)
        inequality_constraints: Inequality constraints for BoTorch optimization
        equality_constraints: Equality constraints for BoTorch optimization
        seed: Random seed for reproducibility

    Returns:
        Tuple of (best_candidates, best_acq_value) where best_candidates is a
        q x d tensor of optimal points and best_acq_value is the acquisition value
    """
    d = bounds.shape[1]

    # Build NChooseK groups from tuples
    nchoosek_list = []
    if nchooseks:
        for features, min_count, max_count in nchooseks:
            nchoosek_list.append(
                NChooseK(features=features, min_count=min_count, max_count=max_count)
            )

    # Build categorical groups
    categorical_list = (
        [Categorical(dim=dim, values=list(values)) for dim, values in cat_dims.items()]
        if cat_dims
        else []
    )

    # Combine all groups (NChooseK groups first — the MCTS group index
    # layout depends on this ordering)
    all_groups = nchoosek_list + categorical_list
    groups = Groups(groups=all_groups)

    # All feature indices covered by NChooseK constraints
    nchoosek_features = set(groups.all_nchoosek_features)

    def _build_combined_fixed(
        selected_features: tuple[int, ...], cat_selections: dict[int, float]
    ) -> dict[int, float]:
        """Build fixed_features dict from selection.

        Externally fixed features come first; inactive NChooseK features are
        then pinned to 0.0 and categorical dims to their selected value.
        """
        combined_fixed: dict[int, float] = {}
        if fixed_features is not None:
            combined_fixed.update(fixed_features)
        inactive_features = nchoosek_features - set(selected_features)
        for idx in inactive_features:
            combined_fixed[idx] = 0.0
        for dim, value in cat_selections.items():
            combined_fixed[dim] = value
        return combined_fixed

    if n_sobol_samples > 0:
        # =================================================================
        # Phase 1: Sobol Screening — cheap quasi-random evaluation
        # =================================================================
        screening_iters = (
            screening_num_iterations if screening_num_iterations is not None else 500
        )

        def sobol_reward_fn(
            selected_features: tuple[int, ...], cat_selections: dict[int, float]
        ) -> float:
            combined_fixed = _build_combined_fixed(selected_features, cat_selections)
            return _sobol_evaluate_acqf(
                acq_function, bounds, combined_fixed, n_sobol_samples, q
            )

        # Tracker records the best reward per unique selection for phase 2.
        tracker = _SelectionTracker(sobol_reward_fn)

        mcts = MCTS(
            groups=groups,
            reward_fn=tracker,
            nig_alpha0=nig_alpha0,
            ts_prior_var=ts_prior_var,
            adaptive_prior_var=adaptive_prior_var,
            cache_hit_mode=cache_hit_mode,
            variance_decay=variance_decay,
            rollout_mode="uniform_subset",
            adaptive_n0=adaptive_n0,
            p_stop_rollout=p_stop_rollout,
            pw_k0=pw_k0,
            pw_alpha=pw_alpha,
            max_rollout_retries=max_rollout_retries,
            use_cache=False,  # Sobol is inherently noisy; NIG handles variance
            shuffle_features=shuffle_features,
            seed=seed,
        )

        mcts.run(n_iterations=screening_iters)

        # =================================================================
        # Phase 2: Refinement — expensive optimize_acqf on top-k selections
        # =================================================================
        top_selections = tracker.top_k(top_k_refine)

        best_candidates: Optional[Tensor] = None
        best_acq_value: float = float("-inf")

        for selected_features, cat_selections in top_selections:
            combined_fixed = _build_combined_fixed(selected_features, cat_selections)

            candidates, acq_value = optimize_acqf(
                acq_function=acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                fixed_features=combined_fixed,
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
            )

            value = acq_value.item()
            if value > best_acq_value:
                best_acq_value = value
                best_candidates = candidates

        # Fallback when screening produced no selections at all.
        if best_candidates is None:
            best_candidates = torch.zeros(
                q, d, dtype=bounds.dtype, device=bounds.device
            )
            best_acq_value = float("-inf")

        return best_candidates, best_acq_value

    else:
        # =================================================================
        # Original behavior: optimize_acqf per MCTS step
        # =================================================================
        best_candidates: Optional[Tensor] = None
        best_acq_value: float = float("-inf")

        def reward_fn(
            selected_features: tuple[int, ...], cat_selections: dict[int, float]
        ) -> float:
            # Side effect: tracks the incumbent candidates via nonlocal state.
            nonlocal best_candidates, best_acq_value

            combined_fixed = _build_combined_fixed(selected_features, cat_selections)

            candidates, acq_value = optimize_acqf(
                acq_function=acq_function,
                bounds=bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,
                fixed_features=combined_fixed,
                inequality_constraints=inequality_constraints,
                equality_constraints=equality_constraints,
            )

            value = acq_value.item()

            if value > best_acq_value:
                best_acq_value = value
                best_candidates = candidates

            return value

        mcts = MCTS(
            groups=groups,
            reward_fn=reward_fn,
            nig_alpha0=nig_alpha0,
            ts_prior_var=ts_prior_var,
            adaptive_prior_var=adaptive_prior_var,
            cache_hit_mode=cache_hit_mode,
            variance_decay=variance_decay,
            rollout_mode=rollout_mode,
            adaptive_n0=adaptive_n0,
            p_stop_rollout=p_stop_rollout,
            pw_k0=pw_k0,
            pw_alpha=pw_alpha,
            max_rollout_retries=max_rollout_retries,
            use_cache=use_cache,
            shuffle_features=shuffle_features,
            seed=seed,
        )

        mcts.run(n_iterations=num_iterations)

        # Fallback when no reward evaluation ever succeeded in improving
        # on -inf (e.g. zero iterations).
        if best_candidates is None:
            best_candidates = torch.zeros(
                q, d, dtype=bounds.dtype, device=bounds.device
            )
            best_acq_value = float("-inf")

        return best_candidates, best_acq_value
best_candidates, best_acq_value diff --git a/bofire/strategies/random.py b/bofire/strategies/random.py index 46b38569b..2c7c66106 100644 --- a/bofire/strategies/random.py +++ b/bofire/strategies/random.py @@ -26,6 +26,7 @@ ContinuousInput, DiscreteInput, ) +from bofire.strategies.predictives.optimize_mcts import MCTS, Groups, NChooseK from bofire.strategies.strategy import Strategy, make_strategy from bofire.utils.torch_tools import ( get_interpoint_constraints, @@ -59,6 +60,7 @@ def __init__( self.fallback_sampling_method = data_model.fallback_sampling_method self.n_burnin = data_model.n_burnin self.n_thinning = data_model.n_thinning + self.max_combinations = data_model.max_combinations self.sampler_kwargs = data_model.sampler_kwargs def has_sufficient_experiments(self) -> bool: @@ -124,32 +126,79 @@ def _sample_with_nchooseks( pd.DataFrame: A DataFrame containing the sampled data. """ - if len(self.domain.constraints.get(NChooseKConstraint)) > 0: - _, unused = self.domain.get_nchoosek_combinations() - - if candidate_count <= len(unused): - sampled_combinations = [ - unused[i] - for i in np.random.default_rng(self._get_seed()).choice( - len(unused), - size=candidate_count, - replace=False, + features2idx, _ = self.domain.inputs._get_transform_info( + {}, + ) + idx2features = {idx[0]: key for key, idx in features2idx.items()} + + if len(self.domain.constraints.get(NChooseKConstraint)) > 0 or any( + isinstance(feat, ContinuousInput) and feat.allow_zero + for feat in self.domain.inputs.get(ContinuousInput) + ): + groups = [] + nchoosek_feature_keys: set[str] = set() + for constraint in self.domain.constraints.get(NChooseKConstraint): + assert isinstance(constraint, NChooseKConstraint) + groups.append( + NChooseK( + features=[features2idx[key][0] for key in constraint.features], + min_count=constraint.min_count, + max_count=constraint.max_count, ) - ] - num_samples_per_it = 1 - else: - sampled_combinations = unused - num_samples_per_it = math.ceil(candidate_count / 
len(unused)) + ) + nchoosek_feature_keys.update(constraint.features) + # Only create single-feature groups for allow_zero features + # that are not already part of an NChooseK constraint. + for feat in self.domain.inputs.get(ContinuousInput): + assert isinstance(feat, ContinuousInput) + if feat.allow_zero and feat.key not in nchoosek_feature_keys: + groups.append( + NChooseK( + features=[features2idx[feat.key][0]], + min_count=0, + max_count=1, + ) + ) + nchoosek_feature_keys.add(feat.key) + + mcts = MCTS( + groups=Groups(groups=groups), + seed=self._get_seed(), + # Dummy reward function — we only use MCTS for tree traversal + # to sample valid NChooseK combinations. + reward_fn=lambda x, y: 0.0, + rollout_mode="uniform_subset", + p_stop_rollout=0.0, + ) + + # now we sample the combinations + combinations: dict[tuple[str, ...], int] = {} + for _ in range(min(self.max_combinations, candidate_count)): + combo, _, _trajectory = mcts._rollout(mcts.root) + # combo is a tuple of feature indices, we convert it to feature keys + combo = tuple(idx2features[idx] for idx in combo) + if combo in combinations: + combinations[combo] += 1 + else: + combinations[combo] = 1 samples = [] - for u in sampled_combinations: + + sampling_multiplier = math.ceil( + candidate_count / min(candidate_count, self.max_combinations) + ) + + for combo, n in combinations.items(): # create new domain without the nchoosekconstraints domain = deepcopy(self.domain) domain.constraints = domain.constraints.get(excludes=NChooseKConstraint) - # fix the unused features - for key in u: + unselected_features = nchoosek_feature_keys - set(combo) + # fix the unused features to zero + for key in unselected_features: feat = domain.inputs.get_by_key(key=key) assert isinstance(feat, ContinuousInput) + if feat.allow_zero: + feat.allow_zero = False feat.bounds = [0.0, 0.0] # setup then sampler for this situation samples.append( @@ -159,7 +208,7 @@ def _sample_with_nchooseks( n_burnin=self.n_burnin, 
n_thinning=self.n_thinning, seed=self._get_seed(), - n=num_samples_per_it, + n=n * sampling_multiplier, sampler_kwargs=self.sampler_kwargs, ), ) @@ -277,7 +326,7 @@ def _sample_from_polytope( samples = pd.DataFrame( data=np.nan, index=range(n), - columns=domain.inputs.get_keys(), + columns=domain.inputs.get_keys(ContinuousInput), ) else: bounds = torch.tensor([lower, upper]).to(**tkwargs) @@ -326,6 +375,7 @@ def _sample_from_polytope( for feat in domain.inputs.get(ContinuousInput) if feat.key not in fixed_features ] + # setup the output samples = pd.DataFrame( data=candidates.detach().numpy(), @@ -334,10 +384,13 @@ def _sample_from_polytope( ) # setup the categoricals and discrete ones as uniform sampled vals + # we have to make sure here that no fixed ones occur here samples = pd.concat( [ samples, - domain.inputs.get([CategoricalInput, DiscreteInput]).sample( + domain.inputs.get([CategoricalInput, DiscreteInput]) + .get_free() + .sample( n, method=fallback_sampling_method, seed=seed, @@ -350,6 +403,9 @@ def _sample_from_polytope( # setup the fixed continuous ones for key, value in fixed_features.items(): samples[key] = value + # setup the fixed discrete/categorical ones + for feat in domain.inputs.get([CategoricalInput, DiscreteInput]).get_fixed(): + samples[feat.key] = feat.fixed_value()[0] # type: ignore return samples[domain.inputs.get_keys()] diff --git a/mcts-report/REPORT.md b/mcts-report/REPORT.md new file mode 100644 index 000000000..9a0e0b59b --- /dev/null +++ b/mcts-report/REPORT.md @@ -0,0 +1,3466 @@ +# MCTS Benchmark Report: Combinatorial NChooseK Optimization + +## Executive Summary + +This benchmark evaluates the MCTS algorithm from `bofire/strategies/predictives/optimize_mcts.py` (without acquisition function integration) across 6 combinatorial problems with NChooseK constraints. 
We test 23 UCT-based MCTS configurations varying RAVE, Progressive Widening (PW), exploration constants, stop probability, adaptive stop probability, reward normalization, rollout policy, and context-aware RAVE against a random-sampling baseline. We then benchmark Thompson Sampling (TS) variants with Normal and Normal-Inverse-Gamma (NIG) posteriors against the best UCT configs. + +Nine algorithmic improvements were implemented during this benchmarking cycle: +1. **Virtual loss on cache hit**: On revisiting a cached terminal, increment visit counts but backpropagate reward=0. This dilutes mean node value for over-exploited branches, steering UCT toward unexplored territory. +2. **Rollout retry on cache hit**: When a rollout produces a cached terminal, re-roll up to `max_rollout_retries` times to find a novel selection. +3. **Blended softmax rollout policy**: Replaces uniform-random rollouts with a learned policy that blends softmax over per-(group, action) statistics with uniform exploration, treating STOP as a regular scored action. +4. **Context-aware RAVE**: Conditions RAVE statistics on `(group_idx, cardinality, action)` instead of a global action ID, allowing RAVE to learn that a feature's value depends on how many features are already selected. +5. **Thompson Sampling tree + rollout policy**: Replaces UCT selection and softmax rollouts with Normal-Normal conjugate posterior sampling, eliminating 9 tunable hyperparameters. +6. **Normal-Inverse-Gamma posterior**: Replaces the Normal-Normal conjugate with the proper Bayesian conjugate for unknown mean and variance. The marginal posterior for the mean is a Student-t distribution with heavier tails at low observation counts, naturally preventing premature commitment. +7. **Adaptive pessimistic strength**: Scales the pessimistic pseudo-observation by each node's local exhaustion rate (1 - n_obs/n_visits). Fresh nodes get mild pessimism; exhausted nodes get full pessimism. Zero new hyperparameters. +8. 
**Adaptive pseudo-count n₀ (negative result)**: Sets n₀ = 1 + log(branching_factor) so nodes with many siblings require more observations before departing from the prior. Tested but found harmful — over-corrects by making posteriors too conservative for the available budget. +9. **DAG with transposition table**: Removes canonical ordering from NChooseK (all unselected features legal at every node), using a frozenset-keyed transposition table to merge nodes with identical selected feature sets into a DAG. Eliminates structural bias where high-index features get concentrated subtrees. Combined with `separate_stop` (binary STOP-vs-best-feature comparison) to fix STOP dilution in the flat action space. + +**Key result (UCT)**: The cumulative effect of improvements 1-4 transforms MCTS from underperforming random sampling to decisively outperforming it on every problem. The best UCT configuration (**MCTS +rpol**: no RAVE + adaptive p_stop + reward normalization + rollout policy) achieves 100% optimum-finding rate on needle_in_haystack (vs 10% for random), 80% on graduated_landscape (vs 7%), **77% on mixed problems** (vs 3%), and **50% on large_sparse** (vs 0%). Context-aware RAVE re-enables RAVE as a useful signal on mixed problems (80% with k=300 vs 77% for +rpol) while matching +rpol on other problems. + +**Key result (Thompson Sampling)**: TS with variance inflation on cache hits (`TS + TS(g,a) + var_infl`) **doubles UCT's optimum rate on multigroup_interaction** (47% vs 23%) — the problem with strongest cross-variable interactions — while using zero tunable hyperparameters. However, **UCT remains superior on large search spaces**: 50% vs 20% on large_sparse, 100% vs 83% on needle_in_haystack. Variance inflation is essential — without it, TS over-exploits exhausted subtrees and achieves only 3-17% on most problems. See Section 11 for full analysis. + +**Key result (NIG posterior)**: The Normal-Inverse-Gamma posterior is a transformative improvement over Normal-TS. 
The best NIG config (**NIG + TS(g,a) + vi + apv**: variance inflation + adaptive prior variance) achieves **80% on multigroup_interaction** (vs UCT's 23%), **100% on needle and mixed** (vs UCT's 100% and 77%), and **47% on large_sparse** (vs UCT's 50% — essentially tied). A single NIG config now matches or exceeds UCT on 5 of 6 problems, with the remaining gap on large_sparse within statistical noise (3pp). See Section 11.13 for full analysis. + +**Key result (Adaptive pessimistic strength)**: Scaling the pessimistic offset by node exhaustion did not resolve the vi-vs-comb tradeoff on interaction problems (vi+apv remains best at 80%). However, the no-APV adaptive modes (**NIG + TS(g,a) + apess**) achieved **53% on large_sparse — the first NIG configs to surpass UCT's 50%** on this problem. This revealed that adaptive prior variance (APV) hurts on massive search spaces by over-shrinking the prior too early. NIG now matches or exceeds UCT on all 6 problems when using problem-appropriate configs. See Section 11.14 for full analysis. + +**Key result (Real acquisition landscapes)**: MCTS-TS-NIG was replayed on real GP-based acquisition functions from full BO loops. On Hartmann(6, k≤4) with 57 subsets, it achieves **74% exact-best and 99% top-5 in 200 iterations** (§11.18). On the harder Hartmann(6)+2 spurious features (247 subsets), the exact-best rate drops to 17.5%, but this metric is misleading: the acquisition landscape is **bimodal** with a flat top, where the top ~10 subsets differ by <0.03. Reframing the metric to "good cluster" discovery (subsets in the correct mode of the bimodal distribution), MCTS achieves **96.5% in 200 iterations with a median first-hit of 8 iterations** (§11.19). Sobol sampling with just **64 samples per subset** achieves ρ ≈ 0.95 rank correlation with expensive `optimize_acqf`, identifies bimodal clusters with **96.5% accuracy and zero false positives**, validating a cheap two-phase burn-in strategy (§11.20). 
+ +**Key result (DAG with transposition table)**: Replacing the tree with a DAG eliminates canonical ordering bias — the structural asymmetry where high-index features get shallow, concentrated subtrees while low-index features get deep, diluted subtrees. The DAG allows all unselected features at every node, using a transposition table (frozenset-keyed) to merge nodes with identical selected feature sets. The best DAG config (**DAG + ss + vi**: separate_stop + variance_inflation + adaptive prior variance) achieves **87% on large_sparse** (vs NIG tree's 47% — nearly doubling it), **100% on graduated and simple_additive** (vs 77% and 83%), and **93% on mixed** (vs 100%). The DAG's one weakness is multigroup_interaction (0% exact-match vs tree's 80%), where the flat action space dilutes STOP statistics. However, this metric is misleading for the acqf pipeline: the DAG achieves **100% feature recall** — it always finds all optimal features, just with 1-3 extras that the downstream gradient optimizer handles. For the real use case where MCTS selects features and L-BFGS refines continuous values, the DAG is the recommended approach. See §11.22 for full analysis. + +--- + +## 1. 
Experimental Setup + +### 1.1 MCTS Configurations Tested + +| Config | c_uct | k_rave | pw_k0 | pw_alpha | p_stop | +|--------|-------|--------|-------|----------|--------| +| **Random baseline** | — | — | — | — | 0.35 | +| **MCTS (default)** | 1.0 | 300 | 2.0 | 0.6 | 0.35 | +| **MCTS (no RAVE)** | 1.0 | 0 | 2.0 | 0.6 | 0.35 | +| **MCTS (no PW)** | 1.0 | 300 | 1e6 | 0.6 | 0.35 | +| **MCTS (no RAVE, no PW)** | 1.0 | 0 | 1e6 | 0.6 | 0.35 | +| **MCTS (low explore)** | 0.1 | 300 | 2.0 | 0.6 | 0.35 | +| **MCTS (high explore)** | 5.0 | 300 | 2.0 | 0.6 | 0.35 | +| **MCTS (heavy RAVE)** | 1.0 | 3000 | 2.0 | 0.6 | 0.35 | +| **MCTS (tight PW)** | 1.0 | 300 | 1.0 | 0.4 | 0.35 | +| **MCTS (loose PW)** | 1.0 | 300 | 5.0 | 0.8 | 0.35 | +| **MCTS (p_stop=0.1)** | 1.0 | 300 | 2.0 | 0.6 | 0.10 | +| **MCTS (p_stop=0.6)** | 1.0 | 300 | 2.0 | 0.6 | 0.60 | +| **MCTS (adaptive p)** | 1.0 | 300 | 2.0 | 0.6 | adaptive | +| **MCTS (no RAVE+adpt)** | 1.0 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (norm)** | 0.01 | 300 | 2.0 | 0.6 | 0.35 | +| **MCTS (no RAVE+adpt+norm)** | 0.01 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (+rpol)** | 0.01 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (+rpol ε=0.1)** | 0.01 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (+rpol τ=0.5)** | 0.01 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (+rpol τ=2)** | 0.01 | 0 | 2.0 | 0.6 | adaptive | +| **MCTS (+crave k=100)** | 0.01 | 100 | 2.0 | 0.6 | adaptive | +| **MCTS (+crave k=300)** | 0.01 | 300 | 2.0 | 0.6 | adaptive | +| **MCTS (+crave k=500)** | 0.01 | 500 | 2.0 | 0.6 | adaptive | + +The `norm` and `no RAVE+adpt+norm` configs enable `normalize_rewards=True` with `c_uct=0.01`; other non-rollout configs use raw rewards with `c_uct` as shown. The reduced `c_uct` compensates for normalization compressing rewards to [0, 1] — with raw rewards in the range 60–272 across problems, `c_uct=1.0` gives an effective exploration pressure of `1.0/reward_range`; `c_uct=0.01` with normalized rewards matches this balance. 
+ +The `+rpol` configs build on `no RAVE+adpt+norm` and add `rollout_policy=True` with varying `rollout_epsilon` (ε) and `rollout_tau` (τ). The default rollout policy uses ε=0.3, τ=1.0, novelty_weight=1.0. + +The `+crave` configs build on `+rpol` and add `context_rave=True` with varying `k_rave` values to control how much weight the context-aware RAVE signal receives. + +- **RAVE disabled**: `k_rave=0` sets β=0, making the score pure UCT. +- **PW disabled**: `pw_k0=1e6` makes the child limit always exceed legal actions. +- **Adaptive p_stop**: Learns per-group stop probability from cardinality-reward statistics. Uses sigmoid on normalized `(E_stop - E_continue)`, blended with fixed prior during warmup (20 rollouts). +- **Reward normalization**: Maps rewards to [0, 1] via running min-max before backpropagation. `best_value` and adaptive p_stop statistics remain in raw reward space. +- **Rollout policy**: Replaces uniform-random rollouts with a softmax over per-(group, action) mean rewards + novelty bonus, blended with uniform exploration via epsilon-mixing. +- **Context-aware RAVE**: Replaces global RAVE (keyed by action ID) with context-dependent statistics keyed by `(group_idx, cardinality, action)`. This allows RAVE to learn that a feature's value depends on how many features are already selected in that group. 
+ +### 1.2 Benchmark Problems + +| Problem | Groups | Features | Subset sizes | Search space | Budget | Trials | +|---------|--------|----------|-------------|-------------|--------|--------| +| **multigroup_interaction** | 3 NChooseK | 8 each | 1-4 | ~4.25M | 600 | 30 | +| **needle_in_haystack** | 1 NChooseK | 15 | 2-5 | ~4,928 | 400 | 30 | +| **mixed_nchoosek_categorical** | 2 NChooseK + 2 Cat | 6 each + 4 vals | 1-3 | ~26,896 | 500 | 30 | +| **large_sparse** | 4 NChooseK | 10 each | 0-3 | ~960M | 800 | 30 | +| **graduated_landscape** | 1 NChooseK | 10 | 2-4 | 375 | 300 | 30 | +| **simple_additive** | 1 NChooseK | 12 | 1-4 | 793 | 300 | 30 | + +**Problem descriptions:** +- **multigroup_interaction**: Optimal requires specific features from all 3 groups with cross-group interaction bonuses (e.g., feature 1 + feature 9 = +12 bonus). Tests whether MCTS can learn multi-group correlations. +- **needle_in_haystack**: Single small optimal subset {3,7,11} among ~5000 candidates with mild partial credit. Tests raw exploration efficiency. +- **mixed_nchoosek_categorical**: Feature+categorical interactions (e.g., feature 2 + cat_dim_20=2.0 = +15). Tests handling of mixed discrete types. +- **large_sparse**: Optimal uses features from only 2 of 4 groups, with a sparsity bonus. The search space is ~960 million. Tests scalability and ability to learn that most groups should be empty. +- **graduated_landscape**: Smooth quality-based reward (each feature has a fixed quality score). Many near-optimal solutions. Tests exploitation of smooth structure. +- **simple_additive**: Simplest possible NChooseK problem — each feature contributes a fixed positive value with no interactions. Reward = sum of selected feature values. Tests whether MCTS can identify the highest-value features and the correct cardinality (4). + +--- + +## 2. Algorithm Fixes Applied + +### 2.1 Problem Identified: Exploration Bottleneck + +The original MCTS algorithm had a severe exploration bottleneck. 
With 600 iterations, it evaluated only ~50-60 unique terminal selections (vs ~588 for random sampling). The root cause was a feedback loop: + +1. **UCT concentrates visits** on the highest-reward branch +2. That branch **grows deeper** (one node expanded per iteration) +3. Deep leaves have **few rollout choices** left, producing the same terminals +4. **Cached reward is backpropagated**, reinforcing the exploitation bias +5. Goto 1 + +Unlike game-playing MCTS where every rollout is stochastic, here the reward function is deterministic and cached — revisiting a terminal adds zero information, but the old code still backpropagated the cached reward as if it were new. + +### 2.2 Fix 1: Virtual Loss on Cache Hit + +When an iteration produces a terminal that's already in the cache, we increment visit counts along the path but **backpropagate zero reward**. This dilutes `mean_value = w_total / n_visits` for over-visited nodes, causing UCT to prefer less-explored branches: + +```python +if is_novel: + self._backpropagate(path, reward, selected_features, cat_selections) +else: + # Virtual loss: increment visits with zero reward + for n in path: + n.n_visits += 1 +``` + +It is critical to still increment `n_visits` (not skip backpropagation entirely), because: +- **Progressive Widening** uses `n_visits` to decide when to expand new children +- **UCT** needs visit counts to change so it doesn't deterministically repeat the same path + +### 2.3 Fix 2: Rollout Retry on Cache Hit + +When a rollout produces a cached terminal, re-roll up to `max_rollout_retries` times: + +```python +selected_features, cat_selections = self._rollout(leaf) +for _attempt in range(self.max_rollout_retries): + key = self._make_cache_key(selected_features, cat_selections) + if key not in self.value_cache: + break + selected_features, cat_selections = self._rollout(leaf) +``` + +This is cheap (rollouts are fast) and directly reduces wasted iterations from non-terminal leaves where rollout randomness can 
reach diverse terminals. + +--- + +## 3. Results + +### 3.1 Summary Tables + +#### multigroup_interaction (search space ~4.25M, optimum = 150.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 62.9 | 10.3 | 0% | 588 | +| MCTS (default) | 94.9 | 18.9 | 0% | 205 | +| **MCTS (no RAVE)** | **105.1** | 25.8 | **20%** | 392 | +| MCTS (no PW) | 90.8 | 19.2 | 0% | 197 | +| MCTS (no RAVE, no PW) | 99.4 | 29.8 | 20% | 365 | +| MCTS (high explore) | 97.3 | 15.7 | 0% | 240 | +| MCTS (heavy RAVE) | 81.4 | 18.7 | 0% | 84 | +| MCTS (p_stop=0.1) | 96.2 | 14.3 | 0% | 238 | +| MCTS (p_stop=0.6) | 84.5 | 20.8 | 0% | 165 | +| MCTS (adaptive p) | 98.0 | 14.5 | 0% | 219 | +| MCTS (no RAVE+adpt) | 103.8 | 29.7 | 23% | 380 | +| MCTS (norm) | 101.8 | 10.5 | 0% | 321 | +| MCTS (no RAVE+adpt+norm) | 108.9 | 25.1 | 23% | 455 | +| **MCTS (+rpol)** | **111.4** | 23.6 | **23%** | 516 | +| MCTS (+rpol ε=0.1) | 114.1 | 24.0 | 27% | 511 | +| MCTS (+rpol τ=0.5) | 109.9 | 22.7 | 20% | 510 | +| MCTS (+rpol τ=2) | 112.5 | 23.1 | 23% | 514 | +| MCTS (+crave k=100) | 105.3 | 21.1 | 13% | 445 | +| MCTS (+crave k=300) | 103.5 | 17.3 | 7% | 370 | +| MCTS (+crave k=500) | 100.2 | 19.5 | 7% | 310 | + +#### needle_in_haystack (search space ~4,928, optimum = 100.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 39.7 | 20.5 | 10% | 216 | +| MCTS (default) | 77.0 | 32.6 | 67% | 58 | +| **MCTS (no RAVE)** | **97.7** | 12.6 | **97%** | 154 | +| MCTS (no RAVE, no PW) | 97.7 | 12.6 | 97% | 147 | +| MCTS (high explore) | 88.3 | 26.1 | 83% | 64 | +| MCTS (heavy RAVE) | 35.2 | 18.0 | 7% | 34 | +| MCTS (p_stop=0.6) | 91.2 | 22.6 | 87% | 63 | +| MCTS (adaptive p) | 84.0 | 29.1 | 77% | 61 | +| **MCTS (no RAVE+adpt)** | **100.0** | 0.0 | **100%** | 161 | +| MCTS (norm) | 97.7 | 12.6 | 97% | 98 | +| MCTS (no RAVE+adpt+norm) | 100.0 | 0.0 | 100% | 247 | +| **MCTS 
(+rpol)** | **100.0** | 0.0 | **100%** | 283 | +| MCTS (+rpol τ=0.5) | 100.0 | 0.0 | 100% | 283 | +| MCTS (+crave k=100) | 100.0 | 0.0 | 100% | 219 | +| MCTS (+crave k=300) | 100.0 | 0.0 | 100% | 139 | +| MCTS (+crave k=500) | 97.7 | 12.6 | 97% | 106 | + +#### mixed_nchoosek_categorical (search space ~26,896, optimum = 150.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 79.2 | 14.6 | 3% | 472 | +| MCTS (default) | 84.5 | 16.7 | 3% | 111 | +| **MCTS (no RAVE)** | **113.6** | 34.7 | **47%** | 284 | +| MCTS (no RAVE, no PW) | 108.5 | 37.1 | 43% | 279 | +| MCTS (p_stop=0.1) | 89.5 | 18.4 | 7% | 130 | +| MCTS (heavy RAVE) | 74.7 | 13.1 | 0% | 46 | +| MCTS (adaptive p) | 85.8 | 9.3 | 0% | 112 | +| MCTS (no RAVE+adpt) | 110.4 | 35.7 | 43% | 280 | +| MCTS (norm) | 86.7 | 8.5 | 0% | 174 | +| MCTS (no RAVE+adpt+norm) | 127.0 | 30.5 | 63% | 357 | +| **MCTS (+rpol)** | **135.9** | 25.6 | **77%** | 442 | +| MCTS (+rpol ε=0.1) | 126.0 | 29.4 | 60% | 404 | +| MCTS (+rpol τ=0.5) | 140.0 | 22.4 | 83% | 444 | +| MCTS (+rpol τ=2) | 136.0 | 25.4 | 77% | 440 | +| MCTS (+crave k=100) | 131.9 | 27.7 | 70% | 373 | +| **MCTS (+crave k=300)** | **137.7** | 24.7 | **80%** | 304 | +| MCTS (+crave k=500) | 132.0 | 27.5 | 70% | 264 | + +#### large_sparse (search space ~960M, optimum = 200.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 36.1 | 6.3 | 0% | 764 | +| MCTS (default) | 40.0 | 8.4 | 0% | 303 | +| **MCTS (no RAVE)** | **83.8** | 64.5 | **23%** | 515 | +| MCTS (no RAVE, no PW) | 61.5 | 38.1 | 7% | 513 | +| MCTS (p_stop=0.6) | 55.4 | 28.1 | 3% | 448 | +| MCTS (heavy RAVE) | 31.3 | 9.4 | 0% | 90 | +| MCTS (adaptive p) | 54.6 | 27.8 | 3% | 421 | +| MCTS (no RAVE+adpt) | 93.0 | 64.9 | 27% | 550 | +| MCTS (norm) | 40.5 | 7.9 | 0% | 603 | +| MCTS (no RAVE+adpt+norm) | 112.1 | 72.0 | 40% | 689 | +| **MCTS (+rpol)** | **129.8** | 70.2 | 
**50%** | 750 | +| MCTS (+rpol ε=0.1) | 90.4 | 60.7 | 23% | 749 | +| MCTS (+rpol τ=0.5) | 128.1 | 72.0 | 50% | 749 | +| MCTS (+rpol τ=2) | 118.7 | 71.2 | 43% | 749 | +| MCTS (+crave k=100) | 128.8 | 71.3 | 50% | 696 | +| MCTS (+crave k=300) | 119.1 | 70.8 | 43% | 651 | +| MCTS (+crave k=500) | 118.5 | 71.4 | 43% | 591 | + +#### graduated_landscape (search space 375, optimum = 65.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 60.6 | 3.3 | 7% | 113 | +| MCTS (default) | 64.1 | 1.4 | 40% | 65 | +| **MCTS (no RAVE)** | **64.9** | 0.3 | **90%** | 162 | +| MCTS (no RAVE, no PW) | 64.8 | 0.4 | 80% | 157 | +| MCTS (p_stop=0.1) | 64.5 | 0.5 | 47% | 89 | +| MCTS (heavy RAVE) | 55.4 | 5.4 | 0% | 21 | +| MCTS (adaptive p) | 64.5 | 0.8 | 57% | 75 | +| MCTS (no RAVE+adpt) | 64.6 | 0.9 | 77% | 168 | +| MCTS (norm) | 62.4 | 3.2 | 10% | 49 | +| MCTS (no RAVE+adpt+norm) | 64.7 | 0.8 | 80% | 152 | +| **MCTS (+rpol)** | **64.5** | 1.4 | **80%** | 157 | +| MCTS (+rpol ε=0.1) | 64.9 | 0.2 | 93% | 151 | +| MCTS (+rpol τ=0.5) | 64.7 | 0.4 | 73% | 154 | +| MCTS (+rpol τ=2) | 64.7 | 0.6 | 80% | 163 | +| MCTS (+crave k=100) | 64.7 | 0.5 | 70% | 109 | +| MCTS (+crave k=300) | 63.9 | 2.2 | 47% | 72 | +| MCTS (+crave k=500) | 63.0 | 2.9 | 30% | 53 | + +#### simple_additive (search space 793, optimum = 65.0) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|----------|------|----------|-------------| +| Random | 57.7 | 3.3 | 0% | 115 | +| MCTS (default) | 61.8 | 3.7 | 37% | 70 | +| **MCTS (no RAVE)** | **64.0** | 2.0 | **70%** | 167 | +| MCTS (heavy RAVE) | 52.4 | 6.0 | 0% | 27 | +| MCTS (p_stop=0.1) | 64.3 | 1.6 | 80% | 100 | +| MCTS (adaptive p) | 63.3 | 2.1 | 50% | 80 | +| MCTS (no RAVE+adpt) | 64.1 | 1.7 | 77% | 181 | +| MCTS (no RAVE+adpt+norm) | 64.1 | 2.2 | 83% | 184 | +| **MCTS (+rpol)** | **64.1** | 2.2 | **83%** | 187 | +| MCTS (+rpol ε=0.1) | 64.4 | 1.4 | 83% | 175 | +| MCTS (+rpol 
τ=0.5) | 64.1 | 2.1 | 80% | 187 | +| MCTS (+rpol τ=2) | 64.3 | 1.4 | 77% | 187 | +| MCTS (+crave k=100) | 64.2 | 1.5 | 77% | 143 | +| MCTS (+crave k=300) | 63.7 | 2.4 | 67% | 92 | +| MCTS (+crave k=500) | 62.8 | 2.6 | 40% | 69 | + +### 3.2 Convergence Curves + +#### All configurations — large_sparse problem +![Convergence large_sparse](convergence_large_sparse.png) + +MCTS (no RAVE+adpt+norm) leads with mean best ~112 and 40% optimum-finding rate in a search space of ~960 million. The high variance reflects that when MCTS finds the right region early, it converges to the optimum; otherwise it still significantly outperforms random. + +#### RAVE effect — needle_in_haystack +![RAVE effect needle](convergence_needle_in_haystack_rave_effect.png) + +The no-RAVE variants converge rapidly to near-optimum, achieving 97% success. Heavy RAVE (pink) performs worse than random — RAVE's context-independent feature value assumption actively misleads the search. + +#### p_stop effect — multigroup_interaction +![p_stop multigroup](convergence_multigroup_interaction_p_stop.png) + +p_stop=0.1 (cyan) outperforms default (p_stop=0.35) because the optimal solution requires 7 features across 3 groups — low stop probability produces rollouts with more features, better matching the target. + +#### Rollout policy effect — large_sparse +![Rollout policy large_sparse](convergence_large_sparse_rollout.png) + +MCTS (+rpol) (dark red) converges to 129.8 mean best and 50% optimum rate, a clear improvement over the no-rollout-policy baseline at 112.1 / 40%. The ε=0.1 variant (orange) collapses to 23% — too little exploration on a 960M search space. Default ε=0.3 provides the most robust balance. + +#### Rollout policy effect — mixed_nchoosek_categorical +![Rollout policy mixed](convergence_mixed_nchoosek_categorical_rollout.png) + +The τ=0.5 variant (gold) achieves 83% optimum rate, the highest across all configs on this problem. 
The default +rpol (ε=0.3, τ=1.0) also improves substantially to 77% (from 63% without rollout policy). + +#### Context RAVE effect — mixed_nchoosek_categorical +![Context RAVE mixed](convergence_mixed_nchoosek_categorical_crave.png) + +Context-aware RAVE with k=300 (teal) achieves 80% optimum rate on the mixed problem, the second-best result after τ=0.5 (83%). It outperforms the baseline +rpol (77%) by re-enabling RAVE in a context-dependent way that avoids the pitfalls of global RAVE. + +#### Context RAVE effect — large_sparse +![Context RAVE large_sparse](convergence_large_sparse_crave.png) + +On large_sparse, context RAVE k=100 matches +rpol at 50% optimum rate. Higher k values (300, 500) show 43% — the stronger RAVE signal reduces exploration in this enormous search space. + +--- + +## 4. Analysis + +### 4.1 Impact of the Algorithmic Fixes + +The virtual loss + rollout retry combination produced dramatic improvements across every problem and configuration: + +| Problem | Old default | New default | Old best | New best | +|---------|------------|-------------|----------|----------| +| multigroup_interaction | 78.5 | **94.9** (+21%) | 81.9 | **105.1** (+28%) | +| needle_in_haystack | 40.0 (17%) | **77.0 (67%)** | 49.2 (30%) | **97.7 (97%)** | +| mixed_nchoosek_categorical | 73.3 (0%) | **84.5 (3%)** | 79.2 (3%) | **113.6 (47%)** | +| large_sparse | 30.0 (0%) | **40.0** | 49.2 (3%) | **83.8 (23%)** | +| graduated_landscape | 54.0 (0%) | **64.1 (40%)** | 60.6 (7%) | **64.9 (90%)** | + +*Percentages in parentheses are optimum-finding rates. 
"Old best" is the best config from the pre-fix benchmark (often random sampling).* + +The unique evaluations tell the story — the exploration bottleneck has been substantially resolved: + +| Problem | Random | Old MCTS default | New MCTS default | New MCTS (no RAVE) | +|---------|--------|------------------|------------------|-------------------| +| multigroup_interaction | 588 | 54 | **205** (3.8x) | **392** (7.3x) | +| needle_in_haystack | 216 | 33 | **58** (1.8x) | **154** (4.7x) | +| mixed_nchoosek_categorical | 472 | 30 | **111** (3.7x) | **284** (9.5x) | +| large_sparse | 764 | 72 | **303** (4.2x) | **515** (7.2x) | +| graduated_landscape | 113 | 19 | **65** (3.4x) | **162** (8.5x) | + +### 4.2 RAVE: Harmful, Should Be Disabled + +With the exploration bottleneck fixed, RAVE's effect becomes even clearer. **Disabling RAVE is the single most impactful parameter change**, consistently producing the best or tied-best results: + +| Problem | Default (RAVE on) | No RAVE | Heavy RAVE | +|---------|-------------------|---------|------------| +| multigroup_interaction | 94.9 (0%) | **105.1 (20%)** | 81.4 (0%) | +| needle_in_haystack | 77.0 (67%) | **97.7 (97%)** | 35.2 (7%) | +| mixed_nchoosek_categorical | 84.5 (3%) | **113.6 (47%)** | 74.7 (0%) | +| large_sparse | 40.0 (0%) | **83.8 (23%)** | 31.3 (0%) | +| graduated_landscape | 64.1 (40%) | **64.9 (90%)** | 55.4 (0%) | + +RAVE's context-independent assumption (feature X is equally valuable regardless of what other features are selected) is fundamentally wrong for NChooseK problems. With the virtual loss fix allowing more exploration, the damage from RAVE's mis-generalization becomes much more visible — it actively steers the search toward poor feature combinations. + +**Heavy RAVE (k_rave=3000) performs worse than random on 2 of 5 problems.** This should be considered a broken configuration. 
+ +### 4.3 Progressive Widening: Moderate Effect + +| Problem | Default (PW on) | No PW | Tight PW | Loose PW | +|---------|-----------------|-------|----------|----------| +| multigroup_interaction | **94.9** | 90.8 | 91.4 | 90.8 | +| needle_in_haystack | 77.0 | 76.7 | 39.8 | 76.7 | +| mixed_nchoosek_categorical | **84.5** | 82.2 | 85.2 | 82.2 | +| large_sparse | 40.0 | **40.2** | 52.6 | 40.2 | +| graduated_landscape | **64.1** | 63.7 | 61.6 | 63.7 | + +Default PW (pw_k0=2.0, pw_alpha=0.6) is slightly better than no PW on most problems when RAVE is active. This is because PW provides a controlled pace of exploration that complements the virtual loss mechanism. Tight PW (pw_k0=1.0, pw_alpha=0.4) is too restrictive and hurts on needle_in_haystack. Notably, PW matters much less when RAVE is disabled — the no RAVE + no PW config performs nearly as well as no RAVE + default PW. + +### 4.4 Exploration Constant (c_uct) + +| Problem | Low (0.1) | Default (1.0) | High (5.0) | +|---------|-----------|---------------|------------| +| multigroup_interaction | 94.2 | 94.9 | **97.3** | +| needle_in_haystack | 74.3 (63%) | 77.0 (67%) | **88.3 (83%)** | +| mixed_nchoosek_categorical | 87.7 | 84.5 | **88.1** | +| large_sparse | 38.0 | 40.0 | **47.2** | +| graduated_landscape | 63.9 (23%) | 64.1 (40%) | **64.3 (40%)** | + +Higher c_uct consistently helps. With the virtual loss fix, the exploration bonus from UCT now has room to operate — the tree isn't locked into a single deep branch anymore. c_uct=5.0 is the best pure-UCT exploration setting tested, though the improvement is modest compared to the impact of disabling RAVE. 
+ +### 4.5 Stop Probability (p_stop_rollout): Problem-Dependent, Now Adaptive + +| Problem | Optimal subset size | p_stop=0.1 | p_stop=0.35 | p_stop=0.6 | Adaptive | no RAVE+adpt | +|---------|-------------------|-----------|------------|-----------|----------|-------------| +| multigroup_interaction | 7 features | 96.2 | 94.9 | 84.5 | 98.0 | 103.8 (23%) | +| needle_in_haystack | 3 features | 67.0 (53%) | 77.0 (67%) | 91.2 (87%) | 84.0 (77%) | **100.0 (100%)** | +| mixed_nchoosek_categorical | 4 features | 89.5 (7%) | 84.5 (3%) | 83.9 (0%) | 85.8 (0%) | 110.4 (43%) | +| large_sparse | 5 from 2 groups | 33.6 | 40.0 | 55.4 | 54.6 | **93.0 (27%)** | +| graduated_landscape | 4 features | 64.5 (47%) | 64.1 (40%) | 62.4 (3%) | 64.5 (57%) | 64.6 (77%) | + +The pattern for fixed p_stop is consistent: low p_stop favors problems needing many features; high p_stop favors sparse solutions. + +**Adaptive p_stop** learns per-group stop probabilities online from cardinality-reward statistics. It tracks `(group_idx, cardinality) -> (visits, total_reward)`, computes E_stop vs E_continue (max over higher cardinalities), and applies a sigmoid to determine stop probability, blended with the fixed prior during a warmup period. 
+ +Results show adaptive p_stop provides a **robust default that avoids catastrophic mismatch**: +- **Best or tied-best** on multigroup_interaction (98.0 vs 96.2 for p_stop=0.1) and graduated_landscape (57% opt rate vs 47% for p_stop=0.1) +- **Competitive** on large_sparse (54.6 vs 55.4 for p_stop=0.6) and needle_in_haystack (77% vs 87% for p_stop=0.6) +- **Never the worst**: Avoids the bad performance of wrong fixed p_stop (e.g., p_stop=0.6 on multigroup_interaction gives only 84.5, while adaptive gives 98.0) + +**No RAVE + adaptive p_stop** is a strong configuration, combining two impactful improvements: +- **100% optimum rate on needle_in_haystack** (up from 97% with no RAVE alone, perfect across all 30 trials) +- **93.0 mean / 27% opt rate on large_sparse** (up from 83.8 / 23% with no RAVE alone) +- The synergy is clear: no RAVE removes the misleading context-independent bias, while adaptive p_stop learns the right cardinality preference per problem + +Adding reward normalization (Section 4.6) further improves this to the best overall configuration. + +The adaptive mechanism is most valuable when the user cannot tune p_stop per-problem, which is the typical use case in real BO workflows where the reward landscape is unknown a priori. + +### 4.6 Reward Normalization: Best Overall When c_uct Is Tuned + +Reward normalization maps rewards to [0, 1] via running min-max before backpropagation. This makes `c_uct` scale-independent — the same `c_uct` value gives consistent exploration-exploitation balance regardless of the problem's reward range. + +**Critical**: normalization requires scaling `c_uct` to match the [0, 1] reward scale. With raw rewards in the range 60–272 across problems, `c_uct=1.0` gives an effective exploration ratio of `1/reward_range`. With normalized rewards, `c_uct=0.01` produces equivalent balance. Using `c_uct=1.0` with normalization massively over-explores and degrades to random sampling. 
+ +#### `MCTS (no RAVE+adpt)` vs `MCTS (no RAVE+adpt+norm)` — the key comparison + +| Problem | no RAVE+adpt (mean/opt%) | +norm (mean/opt%) | Delta | +|---------|--------------------------|---------------------|-------| +| multigroup_interaction | 103.8 / 23% | **108.9 / 23%** | +5.1 mean | +| needle_in_haystack | 100.0 / 100% | **100.0 / 100%** | tied | +| mixed_nchoosek_categorical | 110.4 / 43% | **127.0 / 63%** | +16.6 mean, +20pp opt | +| large_sparse | 93.0 / 27% | **112.1 / 40%** | +19.1 mean, +13pp opt | +| graduated_landscape | 64.6 / 77% | **64.7 / 80%** | +0.1 mean, +3pp opt | + +**Normalization improves the best config on every problem.** The two hardest problems see the largest gains: mixed_nchoosek_categorical jumps from 43% to 63% optimum rate, and large_sparse from 27% to 40%. Unique evaluations increase moderately (380→455, 550→689), indicating normalization adds useful exploration without degenerating into random search. + +#### `MCTS (default)` vs `MCTS (norm)` — normalization with RAVE on + +| Problem | Default (mean/opt%) | Norm (mean/opt%) | +|---------|---------------------|-------------------| +| multigroup_interaction | 94.9 / 0% | **101.8** / 0% | +| needle_in_haystack | 77.0 / 67% | **97.7 / 97%** | +| mixed_nchoosek_categorical | 84.5 / 3% | 86.7 / 0% | +| large_sparse | 40.0 / 0% | 40.5 / 0% | +| graduated_landscape | **64.1 / 40%** | 62.4 / 10% | + +With RAVE enabled, normalization helps on needle and multigroup but hurts on graduated_landscape. The `c_uct=0.01` combined with RAVE's dampening effect (`beta` reduces UCT weight) makes the search too exploitative on the small search space. + +#### Why normalization helps + +1. **Scale-invariant c_uct**: With raw rewards, `c_uct=1.0` gives different effective exploration pressure on each problem — under-exploring on large_sparse (range 272) and over-exploring on graduated (range 60). With normalization, `c_uct=0.01` gives consistent behavior across all problems. + +2. 
**Improved virtual loss**: Virtual loss on cache hit adds zero reward. With raw rewards centered around, say, 50, this dilutes toward 0 — far below the actual reward range. With normalized rewards in [0, 1], zero is exactly the minimum, making virtual loss dilute toward the worst case rather than an arbitrary anchor. + +3. **More exploration on harder problems**: The unique evaluation counts show normalization adds ~20% more exploration (380→455, 550→689) while maintaining focus. The raw configs under-explore large_sparse relative to its enormous search space; normalization partially corrects this. + +#### Recommended usage + +Normalization should be enabled together with `c_uct=0.01` (or more generally, `c_uct ≈ 1/typical_reward_range`). This combination produces the best overall results and removes the need to tune `c_uct` per problem. + +### 4.7 Rollout Policy: Learned Softmax Biasing of Rollouts + +The default rollout strategy selects actions uniformly at random (with adaptive p_stop for STOP decisions). This wastes search budget on poor actions even after the tree has accumulated evidence about which actions are good. The **blended softmax rollout policy** replaces uniform rollouts with a learned policy: + +1. For each `(group_idx, action)` pair, track `(visit_count, total_reward)` across all rollouts +2. Score each action: `mean_reward + novelty_weight / sqrt(visits + 1)` +3. Apply softmax with temperature τ: `p_policy[a] = exp(score[a] / τ) / Z` +4. Blend with uniform: `p[a] = (1 - ε) * p_policy[a] + ε / |legal_actions|` +5. 
STOP is treated as a regular action with its own learned statistics — no special handling needed + +#### `MCTS (no RAVE+adpt+norm)` vs `MCTS (+rpol)` — the key comparison + +| Problem | no rpol (mean/opt%) | +rpol (mean/opt%) | Delta | +|---------|---------------------|-------------------|-------| +| multigroup_interaction | 108.9 / 23% | **111.4 / 23%** | +2.5 mean | +| needle_in_haystack | 100.0 / 100% | **100.0 / 100%** | tied | +| mixed_nchoosek_categorical | 127.0 / 63% | **135.9 / 77%** | +8.9 mean, +14pp opt | +| large_sparse | 112.1 / 40% | **129.8 / 50%** | +17.7 mean, +10pp opt | +| graduated_landscape | 64.7 / 80% | 64.5 / 80% | tied | + +**The rollout policy improves the two hardest problems most**: mixed jumps from 63% to 77% optimum rate, large_sparse from 40% to 50%. On the easier problems it matches the baseline. Unique evaluations increase (455→516, 689→750), showing the policy improves exploration diversity while maintaining focus on promising actions. + +#### Hyperparameter sensitivity + +| Variant | multigroup | needle | mixed | large_sparse | graduated | additive | **Mean opt%** | +|---------|-----------|--------|-------|-------------|-----------|----------|--------------| +| **+rpol (ε=0.3, τ=1.0)** | 23% | 100% | 77% | 50% | 80% | 83% | **68.8%** | +| +rpol ε=0.1 | 27% | 97% | 60% | 23% | 93% | 83% | 63.8% | +| +rpol τ=0.5 | 20% | 100% | 83% | 50% | 73% | 80% | 67.7% | +| +rpol τ=2 | 23% | 97% | 77% | 43% | 80% | 77% | 66.2% | + +- **ε=0.1 is dangerous**: Too little uniform exploration causes collapse on hard problems (23% on large_sparse vs 50% for default). It performs well on easy problems (93% on graduated) but this is misleading. +- **τ=0.5** is strong on mixed (83%) and large_sparse (50%) but drops on multigroup (20%) and graduated (73%). More aggressive exploitation helps when there are many actions per group but hurts when the search space is smaller. 
+- **τ=2** adds noise without benefit — the extra temperature dilutes the learned policy signal.
+- **Default (ε=0.3, τ=1.0) is the most robust**: Highest mean optimum rate (68.8%), no catastrophic failures.
+
+#### Why it works
+
+The rollout policy addresses a fundamental inefficiency: in a tree with hundreds of iterations of history, uniform rollouts treat all actions as equally likely — including actions that have consistently produced poor results. The softmax policy biases toward historically good actions while the ε-blend and novelty bonus maintain exploration:
+
+1. **Novelty bonus** (`β/√(n+1)`) ensures unvisited actions are tried — equivalent to UCB1 exploration in the rollout phase
+2. **ε-mixing** provides a floor on exploration probability, preventing the policy from fully committing to exploitation
+3. **Treating STOP as a regular action** unifies the rollout decision-making: the policy learns when stopping is good vs when adding more features helps, without requiring separate p_stop tuning
+
+The statistics are updated unconditionally (even when `rollout_policy=False`), so the data is always warm if the policy is enabled later.
+
+### 4.8 Context-Aware RAVE: Making RAVE Useful Again
+
+Global RAVE was identified as harmful (§4.2) because it uses a single value estimate per action regardless of context. Context-aware RAVE fixes this by conditioning statistics on `(group_idx, cardinality, action)` — the same feature can have different learned values depending on how many features are already selected in that group.
+ +#### `MCTS (+rpol)` vs `MCTS (+crave)` — the key comparison + +| Problem | +rpol (mean/opt%) | +crave k=100 | +crave k=300 | +crave k=500 | +|---------|-------------------|-------------|-------------|-------------| +| multigroup_interaction | **111.4 / 23%** | 105.3 / 13% | 103.5 / 7% | 100.2 / 7% | +| needle_in_haystack | **100.0 / 100%** | 100.0 / 100% | 100.0 / 100% | 97.7 / 97% | +| mixed_nchoosek_categorical | 135.9 / 77% | 131.9 / 70% | **137.7 / 80%** | 132.0 / 70% | +| large_sparse | **129.8 / 50%** | 128.8 / 50% | 119.1 / 43% | 118.5 / 43% | +| graduated_landscape | 64.5 / 80% | **64.7 / 70%** | 63.9 / 47% | 63.0 / 30% | +| simple_additive | **64.1 / 83%** | 64.2 / 77% | 63.7 / 67% | 62.8 / 40% | + +#### Problem-specific analysis + +**mixed_nchoosek_categorical**: Context RAVE k=300 achieves the **second-best optimum rate (80%)** across all configs on this problem, outperforming the baseline +rpol (77%). The mixed problem has feature-categorical interactions where context matters: knowing that 2 features are already selected helps RAVE estimate whether adding a 3rd is worthwhile. With 2 NChooseK groups + 2 categoricals, there are enough distinct contexts for RAVE to learn meaningful state-dependent values. + +**needle_in_haystack**: Context RAVE k=100 and k=300 both achieve 100% optimum rate with fewer unique evaluations (219 and 139 vs 283 for +rpol). The context signal helps RAVE guide the search more efficiently — it needs fewer evaluations to identify the optimal subset. + +**multigroup_interaction**: Context RAVE underperforms +rpol here (13% vs 23%). With 3 groups of 8 features picking 1-4 each, the context space is large and the 600-iteration budget may be insufficient to populate the context RAVE table adequately. The high k values (300, 500) perform worse because they give too much weight to sparse, noisy context statistics. + +**large_sparse**: Context RAVE k=100 matches +rpol at 50%, but k=300 and k=500 drop to 43%. 
In a 960M search space, context statistics are sparse and higher RAVE weight injects noise. + +**graduated_landscape**: Context RAVE degrades with higher k values (70%→47%→30%). This small, smooth problem (375 combinations) doesn't need RAVE guidance — the policy and UCT alone are sufficient, and RAVE adds overhead. + +**simple_additive**: The simplest problem (independent additive features, no interactions) confirms MCTS solves this easy case reliably. The best configs (+rpol, +rpol ε=0.1, no RAVE+adpt+norm) all achieve 83% optimum rate. Context RAVE k=100 is close at 77%, while higher k values degrade — the additive structure has no context-dependent feature values, so RAVE signal is pure noise. + +#### Key insights + +1. **k_rave=100 is the safest choice**: It matches or nearly matches +rpol on every problem, and wins on needle_in_haystack efficiency. +2. **k_rave=300 is optimal for mixed problems**: Where feature-context interactions are rich, stronger RAVE signal helps. +3. **High k_rave (500) is harmful**: Too much weight on context RAVE degrades performance across the board, similar to how global heavy RAVE (k=3000) was catastrophic. +4. **Context RAVE helps most when**: (a) the problem has meaningful context-dependent feature values (mixed, needle), and (b) the iteration budget is sufficient to populate the context table. +5. **Context RAVE helps least when**: (a) the search space is small and UCT+policy alone suffice (graduated), or (b) the search space is so large that context statistics remain sparse (large_sparse with high k). + +#### Recommended usage + +Context RAVE with k=100 is a safe addition that provides modest benefits on structured problems without degrading performance on others. For problems known to have strong feature-cardinality interactions, k=300 can provide additional benefit. Context RAVE is not recommended as a default because the baseline +rpol configuration is more robust across diverse problem structures. + +--- + +## 5. 
Optimum-Finding Rates + +![Optimum rate heatmap](optimum_rate_heatmap.png) + +**MCTS (+rpol)** is the new best overall: **100%** on needle_in_haystack, **50%** on large_sparse, **77%** on mixed, **23%** on multigroup_interaction, and **80%** on graduated_landscape. It outperforms or matches **MCTS (no RAVE+adpt+norm)** (100%, 40%, 63%, 23%, 80%) on every problem, with the largest gains on the two hardest problems (mixed +14pp, large_sparse +10pp). + +**Heavy RAVE is catastrophic**: 7% on needle (worse than random's 10%), 0% on 4 of 5 problems. + +--- + +## 6. Summary Bar Chart + +![Summary bar chart](summary_bar_chart.png) + +--- + +## 7. Exploration Efficiency + +![Unique evaluations](unique_evals.png) + +The no-RAVE configurations now explore 150-515 unique selections per run, approaching or exceeding random's coverage while also directing that exploration intelligently. Heavy RAVE still restricts exploration to ~20-90 unique selections — RAVE's value-sharing biases the tree toward a narrow set of "globally good" features, undermining the virtual loss mechanism. + +--- + +## 8. 
Recommendations + +### 8.1 Recommended Default Configuration + +Based on these results, the recommended defaults for NChooseK problems are: + +| Parameter | Current default | Recommended | Rationale | +|-----------|----------------|-------------|-----------| +| k_rave | 300 | **0** | RAVE hurts on every problem tested | +| c_uct | 1.0 | **0.01** | Paired with normalize_rewards=True; see §4.6 | +| pw_k0 | 2.0 | 2.0 | Current value works well with virtual loss | +| pw_alpha | 0.6 | 0.6 | Current value works well | +| max_rollout_retries | 3 | 3 | Effective at reducing wasted iterations | +| p_stop_rollout | 0.35 | 0.35 | Base prior for adaptive blending | +| adaptive_p_stop | True | **True** | Avoids worst-case fixed p_stop mismatch | +| p_stop_warmup | 20 | 20 | Sufficient to accumulate per-group statistics | +| p_stop_temperature | 0.25 | 0.25 | Produces decisive but not extreme sigmoid | +| normalize_rewards | False | **True** | Best overall with tuned c_uct; see §4.6 | +| rollout_policy | False | **True** | +14pp mixed, +10pp large_sparse; see §4.7 | +| rollout_epsilon | 0.3 | 0.3 | Lower values collapse on hard problems | +| rollout_tau | 1.0 | 1.0 | Most robust across all problems | +| rollout_novelty_weight | 1.0 | 1.0 | Encourages exploration of unvisited actions | + +### 8.2 Further Improvements to Explore + +1. ~~**Adaptive p_stop_rollout**~~: **Implemented and validated.** Per-group adaptive p_stop learns from cardinality-reward statistics. Combined with no RAVE, it achieves 100% on needle_in_haystack and best results on large_sparse. See Section 4.5 for details. +2. ~~**Context-aware RAVE**~~: **Implemented and validated.** Conditions RAVE on `(group_idx, cardinality, action)` so it captures state-dependent value. With k=300, achieves 80% on mixed problems (vs 77% for +rpol). With k=100, matches +rpol on all problems while using fewer evaluations on needle_in_haystack. Not recommended as default due to marginal benefit on most problems. 
See Section 4.8 for details. +3. ~~**Reward normalization**~~: **Implemented and validated.** Min-max normalization to [0, 1] before backpropagation with `c_uct=0.01` to match the [0, 1] scale. See Section 4.6 for details. +4. ~~**Blended softmax rollout policy**~~: **Implemented and validated.** Replaces uniform rollouts with a learned softmax policy blended with uniform exploration. The rollout policy is the new best configuration on all 5 problems: 77% on mixed (up from 63%), 50% on large_sparse (up from 40%), and best or tied elsewhere. Default hyperparameters (ε=0.3, τ=1.0) are the most robust. See Section 4.7 for details. +5. **Burn-in for reward normalization**: Reward normalization uses running min-max, but early iterations have a poor estimate of the reward range, so their normalized values are distorted relative to what they'd be with the final bounds. A potential fix: skip normalization for the first N iterations (using raw rewards), then once the range stabilizes, retroactively re-normalize the early paths in a single pass. This targets exactly the period where the normalization error is worst, with minimal ongoing cost. +6. ~~**Thompson Sampling instead of UCT**~~: **Implemented and benchmarked.** TS tree selection eliminates `c_uct` and `normalize_rewards`. With variance inflation for cache hits, `TS + TS(g,a) + var_infl` achieves 47% on multigroup_interaction (vs UCT's 23%) but underperforms on large_sparse (20% vs 50%). Not a drop-in replacement for UCT; best on interaction-heavy problems. See Section 8.3 for design and Section 11 for benchmark results. +7. ~~**Thompson Sampling for rollouts**~~: **Implemented and benchmarked.** TS rollout per (group, action) eliminates ε, τ, and novelty_weight. Outperforms uniform rollouts on all problems. TS rollout per (group, cardinality, action) adds context-awareness that helps on mixed problems (57% vs 40% for `(g,a)`) but hurts on others due to sparse statistics. 
See Section 8.4 for design and Section 11 for benchmark results.
+8. **Two-phase burn-in with cheap evaluations**: Use random sampling instead of full `optimize_acqf()` during early iterations, then switch to accurate optimization. TS makes this transition natural. See Section 8.5 for detailed analysis.
+
+### 8.3 Thompson Sampling as UCT Replacement
+
+#### Motivation
+
+The current UCT selection rule is:
+
+```
+score = w_total/n_visits + c_uct * sqrt(log(parent_visits) / child_visits)
+```
+
+The exploitation term (mean reward) is in the scale of the rewards; the exploration term is in the scale of `c_uct`. For the balance to work, `c_uct` must be matched to the reward scale. This is why reward normalization exists — it compresses rewards to [0, 1] so `c_uct=0.01` has a consistent meaning across problems (§4.6).
+
+But min-max normalization introduces its own problem: early iterations have a poor estimate of the reward range, so their normalized values are distorted relative to what they'd be with the final bounds. The current best config (+rpol) runs 300-800 iterations depending on the problem budget, and the range typically stabilizes within the first ~10-20 iterations, meaning roughly 2-7% of all backpropagated values in the tree carry normalization error that never gets corrected.
+
+Thompson Sampling (TS) eliminates both `c_uct` and reward normalization by replacing the deterministic UCT score with a Bayesian posterior.
+
+#### How it works
+
+Each child node maintains a posterior distribution over its expected reward (e.g., Normal with estimated mean and variance). At selection time:
+
+1. Sample a value from each child's posterior
+2. Select the child with the highest sample
+3. After evaluation, update the selected child's posterior with the observed reward
+
+Exploration happens automatically: children with few visits have wide posteriors, so they occasionally sample high values and get selected. As visits accumulate, the posterior tightens and exploitation dominates.
There is no exploration constant to tune — the posterior's uncertainty naturally adapts to whatever reward range is observed. + +#### Why it eliminates normalization + +UCT needs normalization because `c_uct` is an absolute scale parameter. TS has no such parameter. If rewards are in [0, 1000], posteriors are wide in that scale; if rewards are in [0, 0.001], posteriors are wide in that scale. The exploration-exploitation balance is driven by the *relative* uncertainty between children, which is scale-invariant. + +This kills two sources of fragility at once: +- The early-iteration min-max distortion (no normalization needed at all) +- The `c_uct`-to-reward-scale coupling (§4.6: "using c_uct=1.0 with normalization massively over-explores") + +#### Expected behavior on benchmarks + +The current best config (+rpol) is very exploitation-heavy: `c_uct=0.01` with rewards in [0, 1] means the exploration term is tiny. UCT is almost greedy, relying on virtual loss and the rollout policy to provide diversity. TS would explore more in the first ~20-30 iterations (wide posteriors leading to near-uniform selection), then tighten as observations accumulate. + +- **needle_in_haystack** (currently 100%): TS should match. The search space is small (~5K) and TS's natural exploration finds the needle easily. The posterior quickly locks onto the optimal region. +- **graduated_landscape** (currently 80%): Likely comparable or slightly better. The smooth reward structure means posterior means accurately reflect the landscape, and TS naturally concentrates on the top region. +- **large_sparse** (currently 50%): The most uncertain case. The 960M search space benefits from exploitation-heavy search, and TS's wider early exploration could waste budget. But TS also avoids the min-max distortion that hurts early iterations. Expected: roughly comparable — maybe slightly worse initially but more robust across seeds (lower variance). 
+- **mixed_nchoosek_categorical** (currently 77%): TS could help here because the posterior captures reward *variance*, not just the mean. If a child leads to both high and low rewards (multi-modal due to downstream categorical interactions), TS naturally explores it more. UCT only sees the mean and might prematurely abandon it. +- **multigroup_interaction** (currently 23%): TS's broader early exploration might help discover the cross-group interaction bonuses, but the search space (~4.25M) is large enough that undirected exploration is costly. + +#### Cache hit handling — the key design decision + +With UCT, cache hits require the virtual loss hack: backpropagate zero reward to dilute `mean_value` and steer exploration away from exhausted branches (§2.2). With TS, the situation is fundamentally different because selection is stochastic. + +**On novel evaluation**: standard Bayesian update — add the observed reward to the child's posterior (increase observation count, update sufficient statistics). + +**On cache hit**: do not update the posterior. No new information was gained, so no update is the correct Bayesian action. The critical insight is that TS is *stochastic* — the next iteration draws fresh samples from each posterior, so a different selection path naturally occurs without any need to artificially distort statistics. This is a fundamental advantage over UCT, where not updating means the deterministic score is unchanged and the algorithm deterministically repeats the same path forever. + +Progressive widening still needs a visit counter, but it should be separated from the posterior's observation count: increment the PW counter on every visit (novel or cached) so the child limit grows normally, but only update the posterior on novel evaluations. + +**The over-exploitation risk**: if a subtree is "exhausted" (all terminals cached) and its posterior mean is high, TS will keep sampling it highly — it keeps visiting but never gets novel evaluations. 
The posterior stays tight at a high mean with no downward pressure. Unlike virtual loss, there's nothing actively discouraging revisits. + +Two mitigations: + +1. **Variance inflation on cache hits** (recommended first attempt): on each cache hit, slightly inflate the posterior variance (e.g., scale the effective observation count down by a decay factor). This gradually widens the posterior, making it possible for other children to "win" a sample. This is analogous to virtual loss but more principled — it says "repeated observations of the same cached value reduce your confidence rather than increase it," because the evidence is stale. + +2. **Cache hit rate tracking**: track `cache_hits / total_visits` per child. If this ratio exceeds a threshold (e.g., 0.8), the subtree is likely exhausted — force-widen the posterior or add a penalty to the posterior mean. This is more targeted than blanket variance inflation. + +#### Interaction with existing components + +- **Rollout policy** (§4.7): orthogonal to tree selection. TS replaces UCT in the selection phase; the softmax rollout policy operates independently during rollouts. No changes needed. +- **Adaptive p_stop** (§4.5): unchanged. It operates during rollouts and uses its own cardinality statistics, independent of tree selection. +- **Progressive widening**: works as before, but the visit counter for PW must be separated from the posterior observation count (see cache hit handling above). + +#### Summary + +The main win is eliminating reward normalization and `c_uct` tuning entirely, which also kills the early-iteration distortion problem. The main risk is that TS's exploration is less controllable — there is no knob equivalent to `c_uct` to dial exploitation up or down. For a system that needs robust defaults across diverse problems without per-problem tuning, that tradeoff is favorable. 
The cache hit problem has a clean solution (no-update + variance inflation) that is simpler and more principled than the current virtual loss mechanism. + +### 8.4 Thompson Sampling for Rollouts + +#### Current rollout policy recap + +The best configuration (+rpol) uses a learned softmax rollout policy (§4.7) that: + +1. Maintains `(visits, total_reward)` per `(group_idx, action)` pair — STOP included as a regular action +2. Scores each action: `mean_reward + novelty_weight / sqrt(visits + 1)` +3. Applies softmax with temperature τ: `p_policy[a] = exp(score[a] / τ) / Z` +4. Blends with uniform: `p[a] = (1 - ε) * p_policy[a] + ε / |legal_actions|` + +This requires three hyperparameters: `rollout_epsilon` (ε=0.3), `rollout_tau` (τ=1.0), `rollout_novelty_weight` (1.0). + +#### An important subtlety: adaptive p_stop is dead code in the best config + +When `rollout_policy=True`, the rollout code takes the `_sample_rollout_action` path for all actions including STOP: + +```python +if self.rollout_policy: + # Learned softmax policy: STOP is scored like any other action + action = self._sample_rollout_action(g, legal) +else: + # Original logic: adaptive p_stop for NChooseK, uniform for features + ... +``` + +The `_compute_adaptive_p_stop` code path is never reached during rollouts. The `adaptive_p_stop=True` flag still causes `_update_cardinality_stats` to run in `run()`, but those cardinality statistics are never read — the rollout policy scores STOP via `rollout_stats[(group_idx, STOP)]` using the same `(group, action)` key as any feature action, **without cardinality conditioning**. + +This means adaptive p_stop is effectively dead code in the best config. The rollout policy already handles STOP decisions without cardinality awareness, and it outperforms the adaptive p_stop mechanism: +rpol achieves 77% on mixed (vs 43% for no RAVE+adpt without rollout policy) and 50% on large_sparse (vs 27%). 
+ +#### Proposed change: Thompson Sampling over (group, action) posteriors + +Replace the softmax + ε-blend + novelty weight with Thompson Sampling over the same `(group, action)` statistics: + +1. Each `(group_idx, action)` pair maintains a Normal posterior over its expected reward, initialized with a wide prior centered on the global mean reward +2. At each rollout step, for each legal action `a`, sample `r̃(a) ~ N(μ_a, σ²_a / n_a)` from the posterior. For unseen actions, sample from the prior (wide Normal) +3. Pick the action with the highest sample +4. After terminal evaluation, update all `(group, action)` posteriors in the trajectory with the observed reward + +This eliminates three hyperparameters (ε, τ, novelty_weight) → 0 tunable parameters for the rollout policy. + +#### Why it works + +The posterior **is** the exploration mechanism: + +- **Few visits** → wide posterior → occasionally samples very high → gets explored. This replaces the `1/sqrt(n+1)` novelty bonus, which is a frequentist approximation of exactly this effect. +- **Many visits** → tight posterior → concentrates near the true mean → exploitation. This replaces the softmax temperature, which controls how sharply the policy concentrates on high-scoring actions. +- **Unseen actions** → prior (maximum uncertainty) → high probability of sampling highest. This replaces the ε-uniform blend, which guarantees a floor on exploration probability. + +All three mechanisms in the current approach (novelty bonus, temperature, epsilon) are heuristic approximations of what Thompson Sampling does naturally through posterior uncertainty. + +#### Credit assignment + +The terminal reward is attributed to all `(group, action)` pairs in the trajectory equally. This is the same confounding the current approach has — the mean reward for `(group=0, action=3)` reflects not just the value of picking feature 3, but everything else that happened in that rollout. 
TS doesn't fix this, but it handles the resulting noise better: with few observations, the wide posterior prevents premature commitment, whereas `mean + 1/sqrt(n)` can be quite brittle when n is small. + +This confounding matters most on **multigroup_interaction** (cross-group interactions dominate) and least on **simple_additive** (features contribute independently). + +#### Cardinality conditioning: optional, not necessary + +One could key the posteriors on `(group, cardinality, action)` instead of `(group, action)`, so STOP at cardinality 2 has a separate posterior from STOP at cardinality 4. This would add context-awareness for STOP decisions and could theoretically subsume the adaptive p_stop mechanism. + +However, the evidence suggests this isn't necessary: the current best config (+rpol) already outperforms adaptive p_stop on every problem without any cardinality conditioning. The rollout policy's flat `(group, action)` statistics capture enough signal. Cardinality conditioning increases the key space (e.g., from ~33 entries to ~108 on multigroup_interaction), which means posteriors are updated less frequently and take longer to converge. + +If future benchmarks reveal problems where STOP decisions are highly cardinality-dependent and the flat key space isn't sufficient, cardinality conditioning can be added as a straightforward extension. But it should not be the default. + +#### Interaction with TS for tree selection (§8.3) + +If Thompson Sampling is adopted for both tree selection and rollouts, the entire MCTS system uses a single principle — posterior sampling — with no hand-tuned exploration constants anywhere. 
The full hyperparameter reduction would be: + +| Eliminated parameter | Current value | Replaced by | +|---------------------|---------------|-------------| +| `c_uct` | 0.01 | Tree TS posterior | +| `normalize_rewards` | True | Not needed (TS is scale-invariant) | +| `rollout_epsilon` | 0.3 | Rollout TS posterior | +| `rollout_tau` | 1.0 | Rollout TS posterior | +| `rollout_novelty_weight` | 1.0 | Rollout TS posterior | +| `adaptive_p_stop` | True (dead code) | Can be removed | +| `p_stop_rollout` | 0.35 | Can be removed | +| `p_stop_warmup` | 20 | Can be removed | +| `p_stop_temperature` | 0.25 | Can be removed | + +That is 9 hyperparameters reduced to 0 (or 1 if counting the prior variance, which can be set to a large value and forgotten). The only remaining MCTS hyperparameters would be structural: `pw_k0`, `pw_alpha`, `max_rollout_retries`, and the iteration budget. + +### 8.5 Two-Phase Burn-in with Cheap Evaluations + +#### Motivation + +Each terminal evaluation currently calls `optimize_acqf()` with BoTorch — multi-start L-BFGS optimization with `num_restarts=20` and `raw_samples=1024`. This is the dominant cost per MCTS iteration. The evaluation cache exists precisely because these calls are expensive and deterministic: once a feature combination is evaluated, re-evaluating it would produce the same result and waste computation. + +But the cache also creates the over-exploitation problem (§2.1): cached rewards get re-backpropagated, reinforcing exploitation bias. The virtual loss mechanism (§2.2) and rollout retry (§2.3) are workarounds for a problem that exists because evaluations are expensive enough to need caching. + +The idea: use cheap, noisy evaluations during a burn-in phase to explore the combinatorial landscape broadly, then switch to full `optimize_acqf()` for accurate evaluations of the most promising regions. 
+ +#### Why this doesn't work well with UCT + +If you replace `optimize_acqf()` with random sampling during burn-in, the same feature combination evaluated twice gives different rewards (depending on which random points were drawn). UCT assumes stationary rewards — its `w_total / n_visits` running mean mixes noisy early values with accurate late values in a way that cannot be disentangled. At the transition point, you'd need to flush or discount the tree statistics, which is messy and wastes the structural information the tree learned during burn-in. + +Worse, the noisy early rewards corrupt the UCT scores. A feature combination that happened to get a lucky random sample during burn-in would have an inflated mean, and UCT would over-exploit it in the accurate phase. There's no mechanism in UCT to say "those early observations were noisier, weight them less." + +#### Why Thompson Sampling makes it work + +With TS, each node has a posterior distribution. The Bayesian update naturally handles heteroscedastic observations — noisy early values and accurate late values for the same tree branches: + +- **Noisy burn-in observations** produce a wide posterior (high uncertainty). The tree explores broadly because wide posteriors occasionally sample very high values for under-explored branches. +- **Accurate post-burn-in observations** have much lower variance. When a combination is re-evaluated with full `optimize_acqf()`, the tight observation dominates the posterior — the mean shifts toward the true value without needing to discard the burn-in data. +- **The transition requires no special logic**. The Bayesian update correctly weights high-variance and low-variance observations automatically. No statistic flushing, no phase tracking, no manual reweighting. + +During burn-in, the noise is actually *beneficial*: it keeps posteriors wide, which means TS explores broadly. 
This is exactly what you want early on — cheap, broad exploration to map out the combinatorial landscape, then expensive, accurate exploitation to confirm the best regions. + +#### Cheap evaluation function + +The cheap evaluation is essentially the initialization phase of `optimize_acqf()` without the gradient refinement: + +```python +def cheap_reward_fn(selected_features, cat_selections, acq_function, bounds): + # Build fixed_features dict (same as full evaluation) + combined_fixed = build_fixed_features(selected_features, cat_selections) + + # Generate random points respecting bounds and fixed features + # Evaluate acq_function at those points, return the best + X_random = draw_sobol_samples(bounds, n=raw_samples, fixed_features=combined_fixed) + acq_values = acq_function(X_random) + return acq_values.max().item() +``` + +This is roughly 100x cheaper than full `optimize_acqf()` — just forward passes through the acquisition function at quasi-random points, no multi-start L-BFGS. The quality is lower (the maximum over random samples underestimates the true subspace optimum), but for the purpose of ranking feature combinations against each other, the relative ordering is usually preserved. + +#### Cache behavior changes + +| Phase | Caching | Rationale | +|-------|---------|-----------| +| Burn-in | Off | Evaluations are cheap; re-evaluating the same combination with different random samples produces genuinely new information that helps the posterior. The cache hit problem disappears. | +| Post-burn-in | On | Full `optimize_acqf()` is expensive and deterministic; caching prevents wasted computation. The TS variance inflation mechanism from §8.3 handles exhausted subtrees. | + +During burn-in, the cache hit problem that motivated virtual loss (§2.2) and rollout retry (§2.3) effectively dissolves: every evaluation produces a fresh noisy observation, even for previously visited combinations. 
The tree accumulates diverse reward signals across the combinatorial space without any wasted iterations. + +#### Two-phase structure + +| Phase | Iterations | Evaluation | Cost per eval | Caching | Purpose | +|-------|-----------|------------|---------------|---------|---------| +| Burn-in | 1 to N | Random sampling | ~1x (cheap) | Off | Broad exploration, learn tree structure and feature rankings | +| Exploitation | N+1 to end | Full `optimize_acqf()` | ~100x (expensive) | On | Accurate evaluation of promising regions | + +Optionally, at the transition point, re-evaluate the top-K combinations from burn-in with full optimization to calibrate the posteriors and ensure the most promising branches have accurate statistics before the exploitation phase begins. + +#### How many burn-in iterations? + +The benchmarks provide guidance. Looking at unique evaluations for +rpol: multigroup_interaction uses 516 unique evals out of 600 budget, large_sparse uses 750 out of 800. Most iterations produce novel terminals, meaning the combinatorial landscape is large enough that the early iterations are primarily about coverage. + +A burn-in of 50–100 cheap iterations would broadly map the combinatorial landscape — identifying which groups of features tend to produce high acquisition values, which cardinalities are promising, and which categorical values interact well. The remaining budget goes to accurate evaluations of the most promising combinations. The total wall-clock time drops substantially since the first 50–100 iterations cost ~1/100th each. + +For the largest search space (large_sparse, ~960M combinations, 800 budget), a longer burn-in (e.g., 150–200 iterations) might be justified since the cheap phase can cover more of the space. For smaller spaces (graduated_landscape, 375 combinations, 300 budget), a short burn-in (e.g., 30–50) suffices since even cheap evaluations cover a significant fraction of the space. 
+ +#### Combined effect with full TS adoption + +If TS is adopted for tree selection (§8.3), rollouts (§8.4), and two-phase evaluation (this section), the system becomes: + +1. **Burn-in phase**: TS tree selection with wide posteriors → TS rollouts with wide posteriors → cheap noisy evaluation → posterior update. The entire system is in "broad exploration" mode with minimal cost per iteration. +2. **Exploitation phase**: TS tree selection with tightening posteriors → TS rollouts with learned preferences → accurate `optimize_acqf()` evaluation → cache for deterministic results → variance inflation on cache hits. The system converges on the best feature combinations with accurate reward signals. + +The transition between phases is smooth because every component uses the same Bayesian posterior framework. No statistics need to be flushed, no exploration constants need to be re-tuned, no special logic is required at the boundary. + +--- + +## 9. Files Generated + +| File | Description | +|------|-------------| +| `benchmark.py` | UCT benchmark script (reproduces all UCT results) | +| `results.json` | Full numeric results for UCT configs and problems | +| `summary_bar_chart.png` | Bar chart of final best reward (UCT configs) | +| `optimum_rate_heatmap.png` | Heatmap of optimum-finding rates (UCT configs) | +| `unique_evals.png` | Exploration efficiency comparison (UCT configs) | +| `convergence_.png` | Full convergence curves per problem (UCT) | +| `convergence__rave_effect.png` | RAVE ablation convergence | +| `convergence__pw_effect.png` | PW ablation convergence | +| `convergence__exploration.png` | c_uct ablation convergence | +| `convergence__p_stop.png` | p_stop ablation convergence | +| `convergence__rollout.png` | Rollout policy ablation convergence | +| `convergence__crave.png` | Context RAVE ablation convergence | +| `optimize_mcts_ts.py` | Thompson Sampling MCTS implementation (Normal posterior) | +| `benchmark_ts.py` | TS vs UCT benchmark script | +| 
`results_ts.json` | Full numeric results for TS benchmark | +| `summary_bar_chart_ts.png` | Bar chart of final best reward (TS vs UCT) | +| `optimum_rate_heatmap_ts.png` | Heatmap of optimum-finding rates (TS vs UCT) | +| `unique_evals_ts.png` | Exploration efficiency comparison (TS vs UCT) | +| `convergence_ts_.png` | Full convergence curves per problem (TS vs UCT) | +| `convergence_ts__ts_vs_uct.png` | TS vs UCT convergence comparison | +| `convergence_ts__ts_rollout_modes.png` | TS rollout mode comparison | +| `convergence_ts__variance_inflation.png` | Variance inflation ablation | +| `optimize_mcts_nig.py` | NIG posterior MCTS implementation (Student-t posterior) | +| `benchmark_nig.py` | NIG vs Normal-TS vs UCT benchmark script | +| `results_nig.json` | Full numeric results for NIG benchmark | +| `summary_bar_chart_nig.png` | Bar chart of final best reward (NIG vs TS vs UCT) | +| `optimum_rate_heatmap_nig.png` | Heatmap of optimum-finding rates (NIG vs TS vs UCT) | +| `unique_evals_nig.png` | Exploration efficiency comparison (NIG vs TS vs UCT) | +| `convergence_nig_.png` | Full convergence curves per problem (NIG) | +| `convergence_nig__nig_vs_normal_ts.png` | NIG vs Normal-TS vs UCT comparison | +| `convergence_nig__nig_cache_modes.png` | NIG cache-hit mode comparison | +| `convergence_nig__nig_alpha.png` | NIG alpha0 and APV effect | +| `optimize_mcts_dag.py` | DAG MCTS implementation (transposition table + NIG-TS + separate_stop) | +| `benchmark_dag.py` | DAG vs NIG tree benchmark script | +| `results_dag.json` | Full numeric results for DAG benchmark | +| `convergence_dag_.png` | Full convergence curves per problem (DAG) | +| `convergence_dag__ss_vs_baseline.png` | separate_stop vs baselines comparison | +| `convergence_dag__ss_variants.png` | separate_stop variant comparison | +| `summary_bar_chart_dag.png` | Bar chart of final best reward (DAG) | +| `optimum_rate_heatmap_dag.png` | Heatmap of optimum-finding rates (DAG) | +| `unique_evals_dag.png` | 
Exploration efficiency comparison (DAG) | +| `benchmark_dag_nocache.py` | DAG stochastic reward benchmark script | +| `results_dag_nocache.json` | Full numeric results for DAG stochastic benchmark | +| `nocache_dag_convergence__sigma<σ>.png` | Convergence curves (DAG stochastic) | +| `nocache_dag_optimum_rate_heatmap.png` | Heatmap across DAG stochastic configs | + +## 10. Reproducing + +```bash +# UCT benchmark (~60 seconds) +python mcts-report/benchmark.py + +# Thompson Sampling benchmark (~30 seconds) +python mcts-report/benchmark_ts.py + +# NIG posterior benchmark (~60 seconds) +python mcts-report/benchmark_nig.py + +# DAG transposition table benchmark (~70 seconds) +python mcts-report/benchmark_dag.py + +# DAG stochastic reward benchmark (~75 seconds) +python mcts-report/benchmark_dag_nocache.py +``` + +All results use fixed random seeds for reproducibility. + +--- + +## 11. Thompson Sampling Benchmark Results + +This section reports empirical results for the Thompson Sampling (TS) variants proposed in Sections 8.3 and 8.4. Implementation is in `optimize_mcts_ts.py`; benchmarking in `benchmark_ts.py`. + +### 11.1 Experimental Setup + +**TS implementation** (`MCTS_TS` class): replaces UCT tree selection with Normal-Normal conjugate posterior sampling. Each tree node maintains `(n_obs, sum_rewards, sum_sq_rewards)` instead of `(n_visits, w_total)`. At selection time, a reward is sampled from each child's posterior; the highest sample wins. A separate `n_visits` counter drives progressive widening. 
+ +**Bayesian update** (weak prior, estimated variance): +- Prior: N(μ₀, σ₀²) where μ₀ = running global mean of all novel rewards, σ₀² = 1.0, pseudo-count n₀ = 1 +- After n novel observations: posterior mean = (μ₀ + n·x̄) / (1 + n), posterior variance = s² / (1 + n) where s² = max(Σx²/n − x̄², 10⁻⁸) +- n=0: sample from prior; n=1: posterior = N((μ₀+x)/2, σ₀²/2) + +**Configurations tested** (8 configs + Random baseline): + +| Config | Tree selection | Rollout policy | Cache hit mode | Tunable params | +|--------|---------------|----------------|----------------|----------------| +| UCT (+rpol) | UCT (c_uct=0.01, norm) | Softmax (ε=0.3, τ=1.0) | Virtual loss | 9 | +| UCT (no rpol) | UCT (c_uct=0.01, norm) | Uniform + adaptive p_stop | Virtual loss | 6 | +| TS + uniform | TS posterior | Uniform random | No update | 0 | +| TS + TS(g,a) | TS posterior | TS per (group, action) | No update | 0 | +| TS + TS(g,a) + var_infl | TS posterior | TS per (group, action) | Variance inflation | 0 (+decay=0.95) | +| TS + TS(g,c,a) | TS posterior | TS per (group, card, action) | No update | 0 | +| TS + TS(g,c,a) + var_infl | TS posterior | TS per (group, card, action) | Variance inflation | 0 (+decay=0.95) | +| TS + softmax rpol | TS posterior | Softmax (ε=0.3, τ=1.0) | No update | 3 | + +The "tunable params" column counts parameters that require problem-specific tuning. The TS prior variance (σ₀²=1.0) and variance decay (0.95) are structural defaults, not tuned per problem. 
+ +### 11.2 Summary Tables + +#### multigroup_interaction (search space ~4.25M, optimum = 150.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | +|--------|----------|------|----------|------------| +| Random | 62.9 | 10.3 | 0% | 588 | +| UCT (+rpol) | 111.4 | 23.6 | 23% | 516 | +| UCT (no rpol) | 108.9 | 25.1 | 23% | 455 | +| TS + uniform | 92.9 | 22.6 | 7% | 96 | +| TS + TS(g,a) | 101.9 | 26.4 | 17% | 127 | +| **TS + TS(g,a) + var_infl** | **121.8** | 28.4 | **47%** | 378 | +| TS + TS(g,c,a) | 104.5 | 22.7 | 13% | 174 | +| TS + TS(g,c,a) + var_infl | 114.3 | 23.8 | 27% | 406 | +| TS + softmax rpol | 94.6 | 24.9 | 10% | 118 | + +#### needle_in_haystack (search space ~4,928, optimum = 100.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | +|--------|----------|------|----------|------------| +| Random | 39.7 | 20.5 | 10% | 216 | +| **UCT (+rpol)** | **100.0** | 0.0 | **100%** | 283 | +| UCT (no rpol) | 100.0 | 0.0 | 100% | 247 | +| TS + uniform | 53.8 | 35.5 | 37% | 42 | +| TS + TS(g,a) | 74.8 | 33.2 | 63% | 63 | +| TS + TS(g,a) + var_infl | 89.5 | 23.5 | 83% | 159 | +| TS + TS(g,c,a) | 72.3 | 33.9 | 60% | 48 | +| TS + TS(g,c,a) + var_infl | 88.7 | 25.4 | 83% | 87 | +| TS + softmax rpol | 86.0 | 28.0 | 80% | 57 | + +#### mixed_nchoosek_categorical (search space ~26,896, optimum = 150.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | +|--------|----------|------|----------|------------| +| Random | 79.2 | 14.6 | 3% | 472 | +| **UCT (+rpol)** | **135.9** | 25.6 | **77%** | 442 | +| UCT (no rpol) | 127.0 | 30.5 | 63% | 357 | +| TS + uniform | 95.2 | 28.7 | 20% | 94 | +| TS + TS(g,a) | 110.2 | 33.1 | 40% | 273 | +| TS + TS(g,a) + var_infl | 123.3 | 30.6 | 57% | 342 | +| TS + TS(g,c,a) | 122.6 | 31.6 | 57% | 271 | +| TS + TS(g,c,a) + var_infl | 131.8 | 27.8 | 70% | 348 | +| TS + softmax rpol | 106.0 | 34.4 | 37% | 194 | + +#### large_sparse (search space ~960M, optimum = 200.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | 
+|--------|----------|------|----------|------------| +| Random | 36.1 | 6.3 | 0% | 764 | +| **UCT (+rpol)** | **129.8** | 70.2 | **50%** | 750 | +| UCT (no rpol) | 112.1 | 72.0 | 40% | 689 | +| TS + uniform | 52.9 | 40.5 | 7% | 112 | +| TS + TS(g,a) | 56.8 | 27.4 | 3% | 211 | +| TS + TS(g,a) + var_infl | 84.0 | 58.2 | 20% | 575 | +| TS + TS(g,c,a) | 61.1 | 47.5 | 10% | 207 | +| TS + TS(g,c,a) + var_infl | 77.2 | 55.4 | 17% | 545 | +| TS + softmax rpol | 92.7 | 65.1 | 27% | 238 | + +#### graduated_landscape (search space 375, optimum = 65.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | +|--------|----------|------|----------|------------| +| Random | 60.6 | 3.3 | 7% | 113 | +| **UCT (+rpol)** | **64.5** | 1.4 | **80%** | 157 | +| UCT (no rpol) | 64.7 | 0.8 | 80% | 152 | +| TS + uniform | 62.5 | 2.6 | 10% | 43 | +| TS + TS(g,a) | 63.4 | 2.7 | 53% | 62 | +| TS + TS(g,a) + var_infl | 64.6 | 0.8 | 70% | 102 | +| TS + TS(g,c,a) | 62.9 | 3.4 | 23% | 42 | +| TS + TS(g,c,a) + var_infl | 64.6 | 0.9 | 73% | 81 | +| TS + softmax rpol | 58.1 | 8.1 | 3% | 34 | + +#### simple_additive (search space 793, optimum = 65.0) + +| Config | Mean Best | ±Std | Opt Rate | Uniq Evals | +|--------|----------|------|----------|------------| +| Random | 57.7 | 3.3 | 0% | 115 | +| **UCT (+rpol)** | **64.1** | 2.2 | **83%** | 187 | +| UCT (no rpol) | 64.1 | 2.2 | 83% | 184 | +| TS + uniform | 60.5 | 4.4 | 30% | 54 | +| TS + TS(g,a) | 63.4 | 2.5 | 63% | 71 | +| TS + TS(g,a) + var_infl | 64.2 | 2.0 | 80% | 112 | +| TS + TS(g,c,a) | 62.7 | 4.0 | 47% | 61 | +| TS + TS(g,c,a) + var_infl | 64.2 | 1.6 | 73% | 101 | +| TS + softmax rpol | 58.2 | 8.2 | 33% | 41 | + +### 11.3 Optimum-Finding Rate Heatmap + +![TS vs UCT: Optimum-Finding Rate](optimum_rate_heatmap_ts.png) + +### 11.4 Convergence Curves + +#### TS vs UCT — multigroup_interaction + +![TS vs UCT convergence on multigroup_interaction](convergence_ts_multigroup_interaction_ts_vs_uct.png) + +TS + TS(g,a) (red) tracks UCT 
(blue/orange) for the first ~100 iterations, then plateaus due to over-exploitation of cached subtrees. The `var_infl` variants (not shown in this subset; see variance inflation plots) continue climbing past iteration 200. + +#### TS vs UCT — large_sparse + +![TS vs UCT convergence on large_sparse](convergence_ts_large_sparse_ts_vs_uct.png) + +UCT (+rpol) clearly dominates. TS configs plateau early. The 960M search space requires tight exploitation — UCT's near-greedy `c_uct=0.01` plus virtual loss is better suited here than TS's broader posterior-driven exploration. + +#### TS vs UCT — mixed_nchoosek_categorical + +![TS vs UCT convergence on mixed](convergence_ts_mixed_nchoosek_categorical_ts_vs_uct.png) + +UCT (+rpol) leads throughout. TS + TS(g,a) converges to ~110 mean best, well below UCT's ~136. The gap is driven by UCT's learned softmax rollout policy, which handles the categorical dimensions more effectively. + +#### Variance inflation effect — multigroup_interaction + +![Variance inflation on multigroup_interaction](convergence_ts_multigroup_interaction_variance_inflation.png) + +The most dramatic effect in the benchmark. Without variance inflation, TS + TS(g,a) (red) plateaus at ~100 around iteration 100. With variance inflation (purple), the curve keeps climbing to ~122 by iteration 600. The effect is consistent: var_infl configs continue discovering new high-reward selections long after no-update configs have converged. + +#### Variance inflation effect — needle_in_haystack + +![Variance inflation on needle](convergence_ts_needle_in_haystack_variance_inflation.png) + +Without var_infl, TS + TS(g,a) converges at ~75 by iteration 50 and never improves — the posteriors are locked tight on suboptimal subtrees. With var_infl, the gradual widening allows the algorithm to escape and find the needle, reaching ~90 mean best. 
+ +#### Variance inflation effect — large_sparse + +![Variance inflation on large_sparse](convergence_ts_large_sparse_variance_inflation.png) + +Variance inflation helps substantially (TS(g,a) from 57 to 84 mean best), but the gap to UCT (130) remains large. The 960M search space requires more unique evaluations than TS with var_infl can produce (575 vs UCT's 750). + +### 11.5 Analysis + +#### 11.5.1 Variance Inflation Is the Critical Design Decision + +The report in §8.3 proposed two cache-hit strategies: "no-update" (the correct Bayesian action) and "variance inflation" (a practical mitigation). The benchmark conclusively shows that **no-update alone is insufficient** and **variance inflation is essential**. + +The mechanism: without variance inflation, when a subtree is exhausted (all terminals cached), repeated visits produce no posterior updates. The posterior stays tight at a high mean, and TS keeps sampling it highly — there is no downward pressure equivalent to UCT's virtual loss. Variance inflation (decay factor 0.95) gradually reduces `n_obs` on cache hits, widening the posterior, which allows other branches to occasionally "win" a sample. + +| Problem | TS(g,a) no-update | TS(g,a) + var_infl | Improvement | +|---------|-------------------|---------------------|-------------| +| multigroup_interaction | 17% | **47%** | +30pp | +| needle_in_haystack | 63% | **83%** | +20pp | +| mixed_nchoosek_categorical | 40% | **57%** | +17pp | +| large_sparse | 3% | **20%** | +17pp | +| graduated_landscape | 53% | **70%** | +17pp | +| simple_additive | 63% | **80%** | +17pp | + +Variance inflation roughly doubles the unique evaluations (e.g., needle: 63→159, multigroup: 127→378), confirming that the core problem is exploration — without inflation, TS gets trapped in locally-optimal exhausted subtrees. 
+ +#### 11.5.2 TS Beats UCT on Interaction-Heavy Problems + +On **multigroup_interaction** (the hardest problem for UCT), TS + TS(g,a) + var_infl achieves **47% optimum rate vs UCT's 23%** — more than double. This is the only problem where TS clearly outperforms UCT. + +Why: multigroup_interaction has strong cross-group interaction bonuses (e.g., feature 1 + feature 9 = +12 reward). UCT with `c_uct=0.01` is near-greedy, committing to the first good subtree it finds. TS's posterior-driven exploration naturally samples from multiple high-potential subtrees, increasing the chance of discovering interaction combinations. The posterior captures *reward variance* — if a subtree leads to both high and low rewards depending on downstream choices, TS explores it more because the wide posterior occasionally samples high. + +#### 11.5.3 UCT Dominates on Large Search Spaces + +On **large_sparse** (960M combinations), UCT (+rpol) achieves **50% vs TS's best 20%**. On **needle_in_haystack** (5K combinations), UCT achieves **100% vs TS's best 83%**. + +The explanation is exploration *efficiency*: UCT's near-greedy search with virtual loss concentrates evaluations on the most promising subtrees and then uses virtual loss to force exploration *within those subtrees* when they exhaust. TS explores more *broadly* — the stochastic sampling sends the search to diverse regions of the tree — but each individual subtree gets fewer evaluations. In large spaces where the number of feasible selections vastly exceeds the budget, UCT's focused exploitation finds the optimum more reliably. + +The unique evaluation counts confirm this: UCT evaluates 750 unique selections on large_sparse, while TS + var_infl evaluates 575. Those extra 175 evaluations, concentrated in promising regions, make the difference. 
+ +#### 11.5.4 Cardinality Conditioning: Helps on Mixed, Hurts Elsewhere + +Comparing `TS(g,a)` vs `TS(g,c,a)` rollout keys: + +| Problem | TS(g,a) + var_infl | TS(g,c,a) + var_infl | Delta | +|---------|---------------------|----------------------|-------| +| multigroup_interaction | **47%** | 27% | −20pp | +| needle_in_haystack | 83% | 83% | 0pp | +| mixed_nchoosek_categorical | 57% | **70%** | +13pp | +| large_sparse | **20%** | 17% | −3pp | +| graduated_landscape | 70% | **73%** | +3pp | +| simple_additive | **80%** | 73% | −7pp | + +Cardinality conditioning helps on **mixed** (+13pp) because STOP decisions at different cardinalities have genuinely different values in a space with NChooseK + Categorical interactions. But it hurts on **multigroup_interaction** (−20pp) because the larger key space `(group, cardinality, action)` fragments the statistics: each posterior gets fewer updates and takes longer to converge. On a problem with 3 groups of 8 features and max cardinality 4, the key space expands from ~27 entries to ~108 — a 4x reduction in per-key observation count. + +This confirms the §8.4 prediction: "cardinality conditioning increases the key space, which means posteriors are updated less frequently and take longer to converge." The flat `(group, action)` key should be the default. 
+ +#### 11.5.5 TS + Softmax Hybrid Is Worse Than Either Pure Approach + +`TS + softmax rpol` (TS tree selection with the UCT-era softmax rollout policy) performs poorly: + +| Problem | UCT (+rpol) | TS + softmax rpol | TS + TS(g,a) + var_infl | +|---------|-------------|-------------------|--------------------------| +| multigroup_interaction | 23% | 10% | **47%** | +| needle_in_haystack | **100%** | 80% | 83% | +| mixed_nchoosek_categorical | **77%** | 37% | 57% | +| large_sparse | **50%** | 27% | 20% | +| graduated_landscape | **80%** | 3% | 70% | +| simple_additive | **83%** | 33% | 80% | + +The hybrid is worse than both UCT (+rpol) and the best pure-TS config on nearly every problem. On graduated_landscape it achieves only 3% — catastrophic. + +The problem is the softmax rollout's learned statistics accumulate without reward normalization (the TS tree doesn't normalize), but the softmax scoring mechanism (`mean_reward + novelty_weight/sqrt(n+1)`) was designed for normalized rewards. When rewards span a wide raw range (e.g., 0-150), the novelty bonus (weight 1.0) is negligible relative to the mean reward, so the softmax concentrates too aggressively on early high-scoring actions. The low unique evaluation counts (34-238 vs UCT's 150-750) confirm the over-exploitation. + +This is a principled failure: the softmax rollout policy and TS tree selection have incompatible assumptions about reward scale. Use either pure UCT + softmax or pure TS throughout; don't mix. + +#### 11.5.6 Checking the §8.3 Predictions + +Section 8.3 made specific predictions about TS performance. How did they hold up? 
+
+| Problem | §8.3 Prediction | Actual Result | Assessment |
+|---------|-----------------|---------------|------------|
+| needle_in_haystack | "TS should match" (100%) | 83% (best TS) | **Wrong** — UCT's tighter exploitation finds the needle more reliably |
+| graduated_landscape | "Likely comparable or slightly better" | 70-73% vs 80% | **Partially wrong** — close but TS lags by ~10pp |
+| large_sparse | "Roughly comparable, maybe slightly worse" | 20% vs 50% | **Wrong** — much worse, not "roughly comparable" |
+| mixed_nchoosek_categorical | "TS could help" via reward variance capture | 57-70% vs 77% | **Partially right** — TS(g,c,a)+var_infl at 70% is close but doesn't exceed UCT |
+| multigroup_interaction | "Broader exploration might help" | 47% vs 23% | **Right** — significant improvement from broader exploration |
+
+The predictions were too optimistic about TS's ability to match UCT's exploitation efficiency on large and medium search spaces. The critical factor not fully anticipated was the **severity of the exhausted-subtree problem** — the theoretical analysis correctly identified it as a risk but underestimated its magnitude on problems beyond multigroup_interaction.
+
+### 11.6 Updated Recommendations
+
+**Note:** These recommendations are for Normal-TS only. For the latest results with NIG posteriors (which supersede Normal-TS), see §11.13.
+
+**The TS family wins on 4 of 6 problems.** With adaptive prior variance, pessimistic pseudo-observations, and the combined cache-hit mode, TS matches or exceeds UCT on multigroup (47% vs 23%), needle (100% vs 100%, tied), graduated (97% vs 80%), and simple_additive (87% vs 83%). UCT remains ahead on mixed (+7pp) and large_sparse (+13pp).
+
+**Default Normal-TS config: `TS + TS(g,a) + comb`** (combined cache-hit mode, no APV). This is the most robust Normal-TS single config — no catastrophic failures on any problem, 97% on graduated, competitive everywhere else. 
+ +**Problem-specific optimization** (see §11.11.7 for full table): + +- **Interaction-heavy problems** (cross-group synergies): `TS + TS(g,a) + vi + apv` (47% on multigroup) +- **Large search spaces** (>10⁸ combinations): `TS + TS(g,a) + comb + apv` (37% on large_sparse — best Normal-TS result) +- **Needle-like problems** (single sharp optimum): `TS + uniform + pess + apv` (100% on needle) + +**Cache-hit handling is critical for TS.** The no-update mode fails in practice. Three modes are available: variance inflation (best for interaction discovery), pessimistic (best for systematic coverage), and combined (best overall robustness). See §11.9, §11.10, §11.11 for detailed comparisons. + +**Do not use the TS + softmax hybrid.** The softmax rollout policy assumes normalized rewards and is incompatible with TS tree selection. + +**If adopting TS, use `(group, action)` rollout keys, not `(group, cardinality, action)`.** The simpler key space produces more robust posteriors on most problems. Cardinality conditioning only helps on mixed NChooseK + Categorical problems and hurts elsewhere. + +### 11.7 Exploration Efficiency + +![TS vs UCT: Unique Evaluations](unique_evals_ts.png) + +The unique evaluation chart reveals the core trade-off. UCT configs consistently evaluate more unique selections (455-750 per problem), while TS without variance inflation evaluates far fewer (42-211). Variance inflation partially closes the gap (87-575), but on large_sparse — where coverage matters most — TS still trails. + +The implication: TS with variance inflation spends ~25% of its budget on cache hits (re-visiting exhausted subtrees and inflating posteriors), while UCT with virtual loss spends a similar fraction but extracts more value because the deterministic virtual-loss mechanism is more efficient at redirecting search than the stochastic posterior widening. 
+ +### 11.8 Summary Bar Chart + +![TS vs UCT: Final Best Reward](summary_bar_chart_ts.png) + +### 11.9 Adaptive Prior Variance + +Section 11.5 used a fixed prior variance σ₀² = 1.0 for all problems. Section 11.9.2 of the original "Further Improvements" proposed replacing this with the running empirical variance of observed rewards — the TS analogue of UCT's reward normalization. This has now been implemented and benchmarked. + +#### 11.9.1 Implementation + +When `adaptive_prior_var=True`, the prior variance σ₀² is set to the running empirical variance of all novel rewards once at least 2 observations exist: + +```python +def _prior_var(self) -> float: + if not self.adaptive_prior_var or self._novel_reward_count < 2: + return self.ts_prior_var # fixed fallback + mean = self._global_mean() + empirical_var = self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + return max(empirical_var, 1e-8) +``` + +This auto-calibrates the prior to the problem's reward scale. On large_sparse (rewards in [-30, 200]), the empirical variance is ~2000, producing appropriately wide priors for newly expanded children. On simple_additive (rewards in [1, 65]), the empirical variance is ~150. Both are far more appropriate than the fixed σ₀² = 1.0. + +#### 11.9.2 Configurations + +Three new configs test adaptive prior variance (`adpt_pv` / `apv`) against their fixed-prior counterparts: + +| Config | Rollout | Cache hit | Adaptive prior | Fixed-prior counterpart | +|--------|---------|-----------|----------------|------------------------| +| TS + uniform + adpt_pv | Uniform | No update | Yes | TS + uniform | +| TS + TS(g,a) + adpt_pv | TS (group, action) | No update | Yes | TS + TS(g,a) | +| TS + TS(g,a) + vi + apv | TS (group, action) | Var. 
inflation | Yes | TS + TS(g,a) + var_infl | + +#### 11.9.3 Results: Adaptive Prior Variance Effect + +Optimum-finding rates, adaptive vs fixed prior (matched pairs): + +| Problem | uniform | +adpt_pv | TS(g,a) | +adpt_pv | vi | vi+apv | +|---------|---------|----------|---------|----------|-----|--------| +| multigroup_interaction | 7% | 7% | 17% | **20%** | **47%** | **47%** | +| needle_in_haystack | 37% | **43%** | 63% | **73%** | 83% | **87%** | +| mixed | 20% | 20% | 40% | **47%** | 57% | **60%** | +| large_sparse | 7% | 7% | 3% | **17%** | 20% | **33%** | +| graduated_landscape | 10% | **13%** | 53% | **60%** | 70% | **73%** | +| simple_additive | 30% | 27% | 63% | 57% | 80% | **87%** | + +#### 11.9.4 Analysis + +**The combination of variance inflation + adaptive prior (`vi + apv`) is the new best TS config.** It improves over variance inflation alone on 4 of 6 problems, matches on 1, and is within noise on 1. + +**Largest gain: large_sparse** — from 20% to **33%** optimum rate (+13pp). This is the problem where the fixed σ₀² = 1.0 hurts most. With rewards spanning [-30, 200], the fixed prior is absurdly narrow (σ₀ = 1.0 vs reward std ≈ 45). Newly expanded children sample near the global mean with almost no spread, providing negligible exploration value. With adaptive prior, σ₀ ≈ 45, so children sample across the full reward range and TS can meaningfully distinguish promising from unpromising branches early. The convergence curve shows this clearly — `vi + apv` (light blue) climbs steadily past var_infl-only (purple) after iteration 200: + +![Adaptive prior variance on large_sparse](convergence_ts_large_sparse_adaptive_prior_var.png) + +**simple_additive: 80% → 87%.** The adaptive prior brings TS to within 1 trial of UCT's 83% — the first TS config to match or exceed UCT on this problem. 
The convergence curve shows `vi + apv` tracking UCT closely: + +![Adaptive prior variance on simple_additive](convergence_ts_simple_additive_adaptive_prior_var.png) + +**needle_in_haystack: 83% → 87%.** The adaptive prior adds +4pp on top of variance inflation. Still below UCT's 100%, but the gap is narrowing. + +**Where adaptive prior has minimal effect: uniform rollouts.** The `uniform + adpt_pv` config shows almost no change from `uniform`. This makes sense — with uniform random rollouts, the bottleneck is the rollout quality, not the tree prior calibration. The adaptive prior helps most when combined with a learned rollout policy (TS rollouts) that also uses the prior for action selection. + +**Why adaptive prior works**: the fixed σ₀² = 1.0 creates two pathologies depending on the reward scale: +- *Narrow prior on wide-reward problems* (large_sparse, mixed): newly expanded children have tight priors that barely explore, so progressive widening must expand many children before TS finds productive ones. The adaptive prior makes children appropriately uncertain, so fewer children need to be expanded before a good one is found. +- *Wide prior on narrow-reward problems*: less of an issue because the posterior tightens quickly after a few observations, but the early iterations waste budget sampling from excessively wide priors that are uninformative. + +#### 11.9.5 Updated Comparison: Best TS vs UCT + +| Problem | UCT (+rpol) | Best TS (vi + apv) | Gap | +|---------|-------------|---------------------|-----| +| multigroup_interaction | 23% | **47%** | **+24pp TS wins** | +| needle_in_haystack | **100%** | 87% | −13pp | +| mixed | **77%** | 60% | −17pp | +| large_sparse | **50%** | 33% | −17pp | +| graduated_landscape | **80%** | 73% | −7pp | +| simple_additive | 83% | **87%** | **+4pp TS wins** | + +With adaptive prior variance, TS now **wins on 2 of 6 problems** (multigroup_interaction and simple_additive) and is within 7pp on a third (graduated_landscape). 
The gap on large_sparse has narrowed from 30pp to 17pp. UCT still dominates on needle (perfect 100%) and mixed (77% vs 60%). + +### 11.10 Pessimistic Pseudo-Observations + +#### 11.10.1 Motivation + +Variance inflation (§11.5.1) widens posteriors symmetrically on cache hits — the posterior could sample higher *or* lower — so ~50% of samples from an inflated exhausted node still select it. Pessimistic pseudo-observations provide *asymmetric* downward pressure: on each cache hit, we inject a pseudo-observation at `global_mean - global_std` into every node along the backpropagation path. This shifts the posterior mean downward, actively pushing the algorithm away from exhausted subtrees, analogous to UCT's virtual loss. + +#### 11.10.2 Implementation + +On each cache hit with `cache_hit_mode="pessimistic"`: + +```python +pess = self._global_mean() - math.sqrt(empirical_variance) +for node in path: + node.n_visits += 1 + node.n_obs += 1 + node.sum_rewards += pess + node.sum_sq_rewards += pess * pess +``` + +The pessimistic value always uses the empirical standard deviation of all observed rewards (not the fixed or adaptive prior variance), so the offset is scale-appropriate regardless of other settings. Unlike variance inflation, this increases `n_obs` (the posterior tightens around a lower mean) rather than decreasing it (which widens symmetrically). 
+
+#### 11.10.3 Configurations
+
+| Config | Rollout | Cache-hit mode | Adaptive PV | Key comparison |
+|--------|---------|---------------|-------------|----------------|
+| `TS + TS(g,a) + pess` | TS (group,action) | pessimistic | No | vs var_infl |
+| `TS + TS(g,a) + pess + apv` | TS (group,action) | pessimistic | Yes | vs vi + apv |
+| `TS + uniform + pess + apv` | uniform | pessimistic | Yes | rollout-mode interaction |
+
+#### 11.10.4 Results: Pessimistic vs Variance Inflation
+
+**Optimum-finding rates (%)**:
+
+| Problem | var_infl | vi+apv | pess | pess+apv | uniform+pess+apv |
+|---------|----------|--------|------|----------|-------------------|
+| multigroup | **47** | **47** | 20 | 17 | 20 |
+| needle | 83 | 87 | **90** | **90** | **100** |
+| mixed | 57 | 60 | 27 | 23 | **70** |
+| large_sparse | 20 | 33 | 23 | **33** | 13 |
+| graduated | 70 | 73 | **93** | **87** | **80** |
+| simple_additive | 80 | **87** | **83** | 80 | 77 |
+
+**Unique evaluations (mean)**:
+
+| Problem | var_infl | vi+apv | pess | pess+apv |
+|---------|----------|--------|------|----------|
+| multigroup | 378 | 395 | **432** | **448** |
+| needle | 159 | 159 | **258** | **260** |
+| mixed | 342 | 336 | **353** | **348** |
+| large_sparse | 575 | 558 | **639** | **651** |
+| graduated | 102 | 112 | **153** | **164** |
+| simple_additive | 112 | 127 | **176** | **189** |
+
+#### 11.10.5 Analysis
+
+**Pessimistic pseudo-observations dramatically increase exploration efficiency.** On every problem, pessimistic configs evaluate substantially more unique selections than variance inflation: +54 on multigroup (432 vs 378), +99 on needle (258 vs 159), +64 on large_sparse (639 vs 575). The asymmetric downward pressure on exhausted subtrees is clearly more effective at redirecting search than symmetric posterior widening. 
+ +**Pessimistic dominates on small/medium search spaces.** On graduated_landscape (375 combinations), `TS + TS(g,a) + pess` achieves **93% optimum rate** — the highest of *any* config including UCT (80%). On needle (4,928 combinations), `TS + uniform + pess + apv` achieves **100%** — matching UCT and exceeding all other TS configs. On simple_additive, `pess` matches UCT at 83%. + +**Variance inflation still wins on interaction-heavy problems.** On multigroup_interaction, variance inflation configs (47%) substantially outperform pessimistic (17-20%). The likely explanation: pessimistic pseudo-observations tighten posteriors (increasing `n_obs`), reducing the exploratory variance that TS needs to discover cross-group interaction effects. Variance inflation *widens* posteriors, maintaining the stochastic exploration that is critical for interaction discovery. + +**The `uniform + pess + apv` config is surprisingly strong.** Despite using a uniform rollout policy (no learned rollout), this config achieves 100% on needle, 80% on graduated, and 70% on mixed — competitive with or exceeding the best TS rollout configs. The pessimistic mechanism provides enough directed exploration that the rollout policy matters less. However, it underperforms on multigroup (20%) and large_sparse (13%), where learned rollouts are essential for navigating the vast search space. + +**Pessimistic + adaptive prior variance interaction is nuanced.** Adding APV to pessimistic generally helps on large_sparse (+10pp) but can slightly hurt on smaller problems (graduated 93→87%, simple_additive 83→80%). The pessimistic offset uses empirical std regardless of APV, but APV changes how the posterior prior width is set, which affects how quickly the pessimistic observations shift the mean. With APV, the prior is wider and better calibrated, so pessimistic observations have proportionally less impact. 
+ +#### 11.10.6 Updated Comparison: Best TS Configs vs UCT + +| Problem | UCT (+rpol) | vi+apv | pess | pess+apv | uniform+pess+apv | Best TS | +|---------|-------------|--------|------|----------|-------------------|---------| +| multigroup | 23% | **47%** | 20% | 17% | 20% | **47% (vi+apv)** | +| needle | **100%** | 87% | 90% | 90% | **100%** | **100% (uni+pess+apv)** | +| mixed | **77%** | 60% | 27% | 23% | 70% | 70% (uni+pess+apv) | +| large_sparse | **50%** | 33% | 23% | 33% | 13% | 33% (vi+apv / pess+apv) | +| graduated | 80% | 73% | **93%** | 87% | 80% | **93% (pess)** | +| simple_additive | 83% | **87%** | 83% | 80% | 77% | **87% (vi+apv)** | + +The TS family now **wins on 4 of 6 problems** (multigroup, needle, graduated, simple_additive) — up from 2 with adaptive prior variance alone. No single TS config dominates: `vi+apv` is best for interaction-heavy and scale-sensitive problems, `pess` for small smooth landscapes, and `uniform+pess+apv` for needle-like problems with a single sharp optimum. UCT still leads on mixed (+7pp) and large_sparse (+17pp). + +#### 11.10.7 Practical Recommendation + +The choice between variance inflation and pessimistic depends on the problem structure: + +- **Interaction-heavy problems** (cross-group synergies matter): use `TS + TS(g,a) + vi + apv` +- **Small/medium search spaces** with smooth or needle-like landscapes: use `TS + TS(g,a) + pess` +- **Unknown problem structure**: start with `vi+apv` (more robust); switch to `pess` if convergence is slow on problems that should be easy + +### 11.11 Combined Cache-Hit Mode (Variance Inflation + Pessimistic) + +#### 11.11.1 Motivation + +Variance inflation and pessimistic pseudo-observations solve the exhausted-subtree problem from opposite directions: inflation widens posteriors symmetrically (preserving stochastic exploration for interaction discovery), while pessimistic shifts means downward asymmetrically (directing search away from exhausted subtrees). 
The benchmark shows these strengths are complementary — variance inflation dominates on multigroup (47% vs 20%), pessimistic dominates on graduated (93% vs 70%). A combined mode applies both mechanisms on each cache hit, aiming to capture both advantages. + +#### 11.11.2 Implementation + +On each cache hit with `cache_hit_mode="combined"`: + +```python +for node in path: + node.n_visits += 1 + # Step 1: variance inflation — decay n_obs to widen posterior + if node.n_obs > 1: + old_n = node.n_obs + new_n = max(1, int(old_n * variance_decay)) + if new_n < old_n: + mean = node.sum_rewards / old_n + node.sum_rewards = mean * new_n + node.sum_sq_rewards *= new_n / old_n + node.n_obs = new_n + # Step 2: pessimistic — add one pessimistic observation + node.n_obs += 1 + node.sum_rewards += pessimistic_value + node.sum_sq_rewards += pessimistic_value ** 2 +``` + +Net effect per cache hit: `n_obs` decays by ~5% (e.g., 20 → 19), then gains +1 for the pessimistic observation (→ 20 again). The count barely changes, but the *composition* changes: one real observation is effectively replaced by a pessimistic one. The mean shifts downward slightly while posterior width is largely preserved. 
+ +#### 11.11.3 Configurations + +| Config | Rollout | Cache-hit mode | Adaptive PV | Key comparison | +|--------|---------|---------------|-------------|----------------| +| `TS + TS(g,a) + comb` | TS (group,action) | combined | No | vs var_infl / pess | +| `TS + TS(g,a) + comb + apv` | TS (group,action) | combined | Yes | vs vi+apv / pess+apv | + +#### 11.11.4 Results + +**Optimum-finding rates (%)**: + +| Problem | var_infl | vi+apv | pess | comb | comb+apv | +|---------|----------|--------|------|------|----------| +| multigroup | **47** | **47** | 20 | 33 | 13 | +| needle | 83 | 87 | 90 | 90 | **93** | +| mixed | 57 | 60 | 27 | 43 | 23 | +| large_sparse | 20 | 33 | 23 | 20 | **37** | +| graduated | 70 | 73 | **93** | **97** | 87 | +| simple_additive | 80 | **87** | 83 | 83 | 67 | + +**Unique evaluations (mean)**: + +| Problem | var_infl | vi+apv | pess | comb | comb+apv | +|---------|----------|--------|------|------|----------| +| multigroup | 378 | 395 | 432 | **475** | **479** | +| needle | 159 | 159 | 258 | **282** | **287** | +| mixed | 342 | 336 | 353 | **360** | **359** | +| large_sparse | 575 | 558 | 639 | **671** | **672** | +| graduated | 102 | 112 | 153 | **175** | **181** | +| simple_additive | 112 | 127 | 176 | **192** | **202** | + +#### 11.11.5 Analysis + +**Combined mode achieves the highest exploration efficiency of any config.** On every problem, `comb` evaluates more unique selections than either `var_infl` or `pess` alone. On large_sparse, `comb` reaches 671 unique evaluations vs 575 for `var_infl` and 639 for `pess`. The dual mechanism — decay followed by pessimistic injection — creates stronger pressure to leave exhausted subtrees than either mechanism alone. + +**`comb` (without APV) is the most robust single TS config.** It has no catastrophic failures: + +| Problem | comb | Comparison | +|---------|------|-----------| +| multigroup | 33% | Between var_infl (47%) and pess (20%). Recovers half the gap. 
| +| needle | 90% | Matches pess, +7pp over vi+apv | +| mixed | 43% | Between var_infl (57%) and pess (27%) | +| large_sparse | 20% | Matches var_infl; weaker than vi+apv (33%) | +| graduated | **97%** | Highest of any config. Beats pess (93%), UCT (80%) | +| simple_additive | 83% | Matches UCT and pess | + +**`comb` on graduated_landscape: 97% — the best result in the entire benchmark.** The combination of posterior widening and downward pressure creates near-perfect convergence on smooth landscapes. Only 1 out of 30 trials failed to find the optimum, with std=0.2 (vs UCT's std=1.4). + +**`comb + apv` sets a new TS record on large_sparse: 37%.** This is the closest any TS config has come to UCT's 50%. The adaptive prior variance helps calibrate the Bayesian update to the large reward range on this problem (rewards span [-30, 200]), and the combined cache-hit mode provides both posterior widening and directional pressure. + +**APV hurts the combined mode on interaction-heavy problems.** `comb + apv` collapses on multigroup (13%) and mixed (23%), worse than `comb` alone (33% and 43%). This mirrors the same effect seen with pessimistic: APV makes the prior wider, which dilutes the pessimistic observation's impact. On problems where the variance inflation component is doing the heavy lifting (interaction discovery), weakening the pessimistic component would help — but APV weakens it instead of strengthening the inflation. + +**The multigroup gap remains.** Even `comb` at 33% trails `vi+apv` at 47%. The pessimistic component, even after decay-widening, still provides some downward mean shift that reduces the stochastic exploration needed for interaction discovery. The fundamental tension — wide posteriors for interactions vs. directed pressure for coverage — is reduced by combined mode but not eliminated. 
+
+#### 11.11.6 Updated Comparison: All Cache-Hit Modes vs UCT
+
+| Problem | UCT | vi+apv | pess | comb | comb+apv | Best TS |
+|---------|-----|--------|------|------|----------|---------|
+| multigroup | 23% | **47%** | 20% | 33% | 13% | **47% (vi+apv)** |
+| needle | **100%** | 87% | 90% | 90% | 93% | 100% (uni+pess+apv) |
+| mixed | **77%** | 60% | 27% | 43% | 23% | 70% (uni+pess+apv / g,c,a+vi) |
+| large_sparse | **50%** | 33% | 23% | 20% | **37%** | **37% (comb+apv)** |
+| graduated | 80% | 73% | 93% | **97%** | 87% | **97% (comb)** |
+| simple_additive | 83% | **87%** | 83% | 83% | 67% | **87% (vi+apv)** |
+
+The TS family matches or beats UCT on 4 of 6 problems (strict wins on multigroup, graduated, and simple_additive; a tie on needle). The best single TS config depends on the problem, but `comb` without APV provides the most consistent performance across all problem types — never catastrophically bad, competitive everywhere, and record-setting on graduated.
+
+#### 11.11.7 Practical Recommendation
+
+For a **single default TS config** when problem structure is unknown:
+
+1. **`TS + TS(g,a) + comb`** — most robust. No catastrophic failures, highest floor across all problems. Best choice when you cannot characterize the problem beforehand.
+
+For **problem-specific optimization**:
+
+| Problem type | Best config | Why |
+|-------------|-------------|-----|
+| Interaction-heavy (cross-group synergies) | `vi+apv` | Wide posteriors for stochastic interaction discovery |
+| Large search space (>10⁸ combinations) | `comb+apv` | Scale calibration + dual cache-hit pressure |
+| Small/medium smooth landscape | `comb` | Near-perfect convergence (97% on graduated) |
+| Needle-in-haystack (single sharp peak) | `uniform+pess+apv` | Systematic coverage, no rollout bias |
+
+### 11.12 Further Improvements to the Bayesian Approach
+
+The benchmark identifies specific weaknesses and opportunities for the TS implementation. The following improvements are ordered by how directly the benchmark evidence motivates them. 
+ +#### 11.12.1 ~~Pessimistic Pseudo-Observations on Cache Hits~~ — Implemented + +**Implemented and benchmarked in §11.10.** Pessimistic pseudo-observations dramatically increase exploration efficiency (unique evaluations up 15-60% across all problems). Combined with variance inflation in §11.11, the combined mode achieves 97% on graduated_landscape (best of any config) and 37% on large_sparse (best TS result). The combined mode is now the recommended default TS cache-hit strategy. + +**Original problem statement**: Variance inflation widens posteriors but never shifts the mean downward. A node with posterior mean 120 and tight variance stays attractive indefinitely — the inflated posterior still centers on 120, and most samples remain high. UCT's virtual loss works because it deterministically pushes `w_total / n_visits` down; TS has no equivalent downward pressure. + +#### 11.12.2 ~~Adaptive Prior Variance from Observed Reward Range~~ — Implemented + +**Implemented and benchmarked in §11.9.** The adaptive prior variance improves the best TS config on 4 of 6 problems, with the largest gain on large_sparse (+13pp). It is now the default recommendation for TS configs. + +**Original problem statement**: The fixed prior variance `σ₀² = 1.0` is scale-blind. On large_sparse (rewards in approximately [-30, 200]), a prior N(μ₀, 1.0) is absurdly narrow — a newly expanded child samples near the global mean with almost no spread, providing negligible exploration. On simple_additive (rewards in [1, 65]), the same prior is more reasonable but still somewhat tight. + +**Proposed fix**: Set σ₀² to the running empirical variance of all observed rewards, rather than a fixed constant. 
This auto-calibrates the prior to the reward scale of the problem: + +```python +def _prior_var(self) -> float: + if self._novel_reward_count < 2: + return self.ts_prior_var # fixed fallback for first iterations + mean = self._global_mean() + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return max(empirical_var, 1e-8) +``` + +Early iterations (few rewards observed) use the fixed fallback, which provides wide priors and broad exploration. As rewards accumulate, the prior tightens to match the actual reward distribution. Newly expanded children then have priors that are appropriately calibrated — wide enough to explore on large-scale problems, tight enough to focus on small-scale ones. + +This is the TS analogue of reward normalization: instead of squashing rewards to [0, 1] to match `c_uct`, we scale the prior to match the rewards. + +#### 11.12.3 Two-Phase Burn-in with Cheap Evaluations + +**Problem**: TS's exploration efficiency gap (575 vs 750 unique evals on large_sparse) exists because each evaluation is expensive (`optimize_acqf` with multi-start L-BFGS), so wasted iterations on cache hits are costly. The cache itself exists because evaluations are expensive and deterministic. + +**Empirical validation**: §11.16 confirms that polytope samples rank subsets with Spearman ρ > 0.97 at 26x lower cost on constrained problems (64 samples per subset). The ranking is near-perfect even though absolute values are pessimistically biased — exactly what NIG-TS needs. 
+ +**Proposed fix** (detailed in §8.5): split the MCTS run into two phases: + +| Phase | Evaluations | Caching | Cost per eval | Purpose | +|-------|------------|---------|---------------|---------| +| Burn-in (1 to N) | Cheap random sampling | Off | ~1/100x | Broad landscape mapping | +| Exploitation (N+1 to end) | Full `optimize_acqf` | On | 1x | Accurate exploitation | + +During burn-in, every evaluation is novel (no cache, no cache hits), so the exhausted-subtree problem disappears entirely. TS's posteriors accumulate diverse noisy observations across the combinatorial landscape. At transition, the posteriors already encode which regions of the space are promising, so the expensive budget is concentrated where it matters. + +TS is uniquely suited for this because the Bayesian update naturally handles heteroscedastic observations: noisy burn-in values produce wide posteriors (low confidence), and accurate post-burn-in values produce tight posteriors that dominate the mean. No statistic flushing or phase-tracking logic is needed. UCT's running mean cannot distinguish noisy from accurate observations, making a clean transition much harder (§8.5). + +The benchmark data suggests the burn-in length should scale with search space size: ~50 iterations for small spaces (graduated_landscape, 375 combinations), ~200 for large (large_sparse, 960M combinations). + +#### 11.12.4 Depth-Dependent Cache-Hit Handling + +**Problem**: The current variance inflation applies the same decay factor (0.95) to every node in the backpropagation path, from root to leaf. But the exhaustion problem is depth-dependent: the root node aggregates rewards from the entire tree and is never truly exhausted; a node at depth 8 covers a narrow slice of the search space and exhausts quickly. 
+ +**Proposed fix**: Scale the decay (or pessimistic pseudo-observation magnitude) by depth: + +```python +effective_decay = decay ** (1.0 + depth * depth_scale) +``` + +With `depth_scale=0.5`: at depth 0 (root), effective_decay = 0.95 (minimal inflation). At depth 6, effective_decay = 0.95^4 = 0.81 (aggressive inflation). Deep nodes in exhausted subtrees get widened quickly, while the root's posterior remains stable and reflects accurate aggregate statistics. + +This also addresses a subtle issue: inflating the root's posterior can cause wild swings in the algorithm's overall behavior (the root affects every single selection), while inflating a deep leaf's posterior only affects selections that pass through that narrow path. + +#### 11.12.5 Progressive Widening Tuned for TS + +**Problem**: The PW parameters (k0=2.0, alpha=0.6) were tuned for UCT, where the deterministic score ensures all existing children get visited roughly proportionally to their UCT score. TS's stochastic selection is less balanced — children with tight, high-mean posteriors dominate samples, and children with wide uncertain posteriors are selected only when they happen to sample high. This means TS may under-expand: the child limit grows based on `n_visits`, but visits concentrate on a few children rather than spreading evenly, so the PW limit stays artificially low. + +**Proposed fix**: Increase PW aggressiveness for TS, e.g., k0=4.0 or alpha=0.8. More children means more posteriors to sample from, increasing the chance that an uncertain child "wins" a sample. This directly increases the unique evaluation count, which is the core gap between TS and UCT. + +A quick experiment would test k0 ∈ {2, 4, 8} × alpha ∈ {0.6, 0.8} on the TS + TS(g,a) + var_infl config. If the unique eval count on large_sparse rises from 575 toward 700+ without sacrificing quality on smaller problems, the PW re-tune is worthwhile. 
+ +#### 11.12.6 ~~Normal-Inverse-Gamma Posterior (Proper Conjugate Update)~~ — Implemented + +**Implemented and benchmarked in §11.13.** The NIG posterior is a transformative improvement. The best NIG config (NIG + TS(g,a) + vi + apv) achieves 80% on multigroup (vs Normal-TS's 47% and UCT's 23%), 100% on needle and mixed, and 47% on large_sparse (vs UCT's 50%). A single NIG config now matches or exceeds UCT on 5 of 6 problems. + +**Original problem statement**: The current TS implementation uses a Normal-Normal conjugate update that treats the reward variance σ² as a known plug-in estimate (`s² = max(sum_sq/n - x̄², 1e-8)`). This is reasonable when n is moderate, but it breaks down at low observation counts — exactly the regime that matters most for exploration: + +- With n=1: `s² = max(x²/1 - x², 1e-8) = 1e-8` — variance collapses to the floor. The posterior becomes absurdly tight around a single observation. +- With n=2: sample variance is based on just 2 points — unreliable. +- With n=0: we fall back to the prior, which is a Normal distribution. + +The consequence is **premature commitment**: a node that receives one good observation gets a tight posterior with a high mean, and TS keeps selecting it. A node that receives one bad observation is abandoned. Neither has enough data to justify such confidence. + +**Root cause**: The Normal-Normal model assumes known variance. 
The proper conjugate prior for Normal with *both* unknown mean and unknown variance is the **Normal-Inverse-Gamma (NIG)** distribution: + +``` +Prior: (μ, σ²) ~ NIG(μ₀, n₀, α₀, β₀) + +μ₀ = prior mean (global running mean, same as now) +n₀ = prior pseudo-count (confidence in the mean) +α₀ = shape parameter for variance prior (e.g., 1 = weak) +β₀ = scale parameter for variance prior (e.g., σ₀² = adaptive prior var) +``` + +After n observations with sample mean x̄ and sum of squared deviations S = Σ(xᵢ - x̄)²: + +``` +n₀' = n₀ + n +μ₀' = (n₀·μ₀ + n·x̄) / n₀' +α₀' = α₀ + n/2 +β₀' = β₀ + S/2 + (n₀·n·(x̄ - μ₀)²) / (2·n₀') +``` + +The marginal posterior for μ (integrating out σ²) is a **Student-t distribution**: + +``` +μ | data ~ t_{2α₀'}(location=μ₀', scale=sqrt(β₀' / (α₀' · n₀'))) +``` + +**Why this fixes the premature commitment problem**: The Student-t has heavier tails than the Normal, especially at low degrees of freedom (df = 2α₀'). With n=1 and α₀=1, df=3 — the distribution has *much* wider tails than a Normal, reflecting genuine uncertainty about both the mean and the variance. As observations accumulate, df grows, and the t-distribution converges to Normal — exactly recovering the current behavior at moderate-to-large n. + +| Observations | Current (Normal) | NIG (Student-t) | +|-------------|-----------------|-----------------| +| n=0 | Sample from N(μ₀, σ₀²) | Sample from t₂(μ₀, β₀/α₀) — heavier tails | +| n=1 | Tight N near single obs (s²≈0) | Wide t₃ — high uncertainty persists | +| n=2 | N with noisy variance estimate | t₄ — still wider than Normal | +| n=20+ | Approximately N(x̄, s²/n) | Approximately N(x̄, s²/n) — same | + +**Sufficient statistics**: The NIG update requires `(n_obs, sum_rewards, sum_sq_deviations)`. We already track `n_obs` and `sum_rewards`. We currently track `sum_sq_rewards` (sum of x²), from which sum of squared deviations can be computed as `S = sum_sq_rewards - n·x̄²`. No additional per-node storage is needed. 
+ +**Sampling from Student-t**: `t_df(loc, scale)` can be sampled as `loc + scale * (Z / sqrt(V/df))` where Z ~ N(0,1) and V ~ χ²(df). Python's `random` module doesn't have a direct t-distribution, but it can be computed from Normal and Gamma samples, or approximated via the inverse CDF. For df > 30, the Normal approximation is sufficient. + +**Interaction with existing mechanisms**: NIG is orthogonal to the cache-hit handling (combined mode) and adaptive prior variance. APV would set β₀ to the empirical reward variance (same role as σ₀² currently). The combined mode's variance inflation would decay n₀' (widening the t-posterior) and add pessimistic observations (shifting the location). The NIG posterior simply replaces the sampling distribution from Normal to Student-t, with the most impact at low observation counts. + +**Expected impact**: The main benefit is on large search spaces (large_sparse, multigroup) where many nodes have few observations. These nodes currently have artificially tight posteriors that cause premature commitment. With NIG, their posteriors are genuinely wide (heavy-tailed t), so TS naturally explores them more before committing. On small spaces (graduated, simple_additive) where most nodes accumulate many observations, the impact is minimal — the t-distribution converges to Normal quickly. This is exactly the right behavior: the fix is strongest where the problem is worst. + +**No new hyperparameters**: α₀=1 (standard weak prior for variance) is the canonical choice. n₀ and β₀ map directly to the existing prior pseudo-count and prior variance. The implementation is a drop-in replacement for `_ts_sample_score` and `_ts_sample_action_score`. + +#### 11.12.7 ~~Adaptive Pseudo-Count n₀ from Branching Factor~~ — Implemented (Negative Result) + +**Implemented and benchmarked in §11.15.** Adaptive n₀ = 1 + log(branching_factor) was tested across 5 configurations on all 6 problems. 
The result is **negative**: adaptive n₀ does not improve performance on any problem where it was hypothesized to help. On multigroup_interaction, the best adaptive n₀ config (vi+apv+an₀) achieves 57% vs 80% for fixed n₀ (vi+apv) — a 23pp regression. On large_sparse, apess+an₀ achieves 47% vs 53% for apess — a 6pp regression. The higher n₀ over-corrects by making the posterior too conservative, slowing convergence within the available budget. The one bright spot is `acomb+an₀` achieving 100% on graduated_landscape and simple_additive, but these were already mostly solved by existing configs. + +**Original problem statement**: The prior pseudo-count n₀ is fixed at 1, meaning a single observation contributes 50% to the posterior mean. On a problem with a high branching factor (many legal actions per node), the algorithm visits each child infrequently — one observation per child is common during early exploration. With n₀=1, that single observation immediately collapses the posterior, causing premature commitment to whichever child happened to get a good first evaluation. + +**Implemented fix**: Set n₀ proportional to the local branching factor: + +```python +def _compute_n0(self, n_actions: int) -> float: + if not self.adaptive_n0: + return 1.0 + return 1.0 + math.log(max(n_actions, 2)) +``` + +With 2 legal actions: n₀ ≈ 1.7. With 11 actions (large_sparse root): n₀ ≈ 3.4. With 30 actions: n₀ ≈ 4.4. + +The n₀ value is computed from the number of siblings (tree selection) or legal actions (rollout) and passed to the NIG sampling functions. This is fully automatic — zero new hyperparameters. + +**Why it failed**: The NIG posterior's Student-t tails already handle the low-n regime well enough. Adding higher n₀ on top of that makes the posterior *too* conservative — it takes more observations to shift away from the prior, but the available budget (600–800 iterations) isn't large enough to recover from the slower learning. 
The intuition that "higher n₀ prevents premature commitment" over-corrects when the Student-t's heavy tails are already providing sufficient exploration pressure. + +#### 11.12.8 ~~Adaptive Pessimistic Strength from Local Exhaustion~~ — Implemented + +**Problem**: The pessimistic pseudo-observation in combined mode uses a fixed value of `global_mean - global_std` for every node in the backpropagation path. But exhaustion varies across the tree: a subtree with 90% cache-hit rate is severely exhausted and needs aggressive pessimism; a subtree with 10% cache-hit rate is mostly novel and needs almost none. The fixed strength over-penalizes fresh subtrees and under-penalizes exhausted ones. + +**Implemented fix**: Two new cache-hit modes scale the pessimistic offset by the node's local exhaustion rate, measured as `1 - (n_obs / n_visits)`: + +```python +novelty_rate = node.n_obs / max(1, node.n_visits) +exhaustion = 1.0 - novelty_rate # 0 = fully novel, 1 = fully exhausted +pess_value = global_mean - exhaustion * global_std +``` + +- `adaptive_pessimistic`: adaptive pessimistic pseudo-obs only +- `adaptive_combined`: variance inflation + adaptive pessimistic pseudo-obs + +When a subtree is fresh (high novelty rate, most visits produce new evaluations), the pessimistic observation is mild (close to the global mean — barely shifts the posterior). When exhausted (low novelty rate, most visits are cache hits), the pessimistic observation is aggressive (full `mean - std` — strong downward pressure). + +This uses information already tracked (`n_obs` and `n_visits` per node) and requires no new hyperparameters. + +**Results**: See §11.14 for full benchmark. The adaptive modes did not resolve the vi-vs-comb tradeoff as hoped (vi+apv remains best on hard interaction problems). However, the **no-APV adaptive modes** (`apess`, `acomb`) achieved **53% on large_sparse** — the first configs to exceed UCT's 50% on this problem. 
The finding that APV hurts on large_sparse was unexpected and suggests APV over-shrinks the prior variance on massive search spaces. + +#### 11.12.9 Correlated Priors Across Sibling Nodes + +**Problem**: Each child node has an independent prior. But in NChooseK problems, features within a group are structurally related. If selecting feature 3 in group 1 yields reward 80, that says something about the value of selecting feature 4 in the same group — they share the same group context and only differ in one feature. The current TS treats them as completely unrelated, requiring each to be explored independently. + +**Proposed fix**: After each novel evaluation, propagate a discounted update to the evaluated node's siblings (other children of the same parent): + +```python +sibling_discount = 0.1 # share 10% of the signal +for action, sibling in parent.children.items(): + if sibling is not evaluated_child: + sibling.n_obs += sibling_discount + sibling.sum_rewards += reward * sibling_discount + sibling.sum_sq_rewards += (reward ** 2) * sibling_discount +``` + +This is conceptually similar to RAVE (sibling nodes share information from the same rollout) but integrated into the Bayesian framework: siblings share a weak signal that narrows their posteriors slightly, so they don't require as many direct visits to distinguish good from bad. On multigroup_interaction, where TS already outperforms UCT, sibling sharing could accelerate convergence. On large_sparse, it could help the algorithm identify productive subtrees faster by propagating feature-quality signals sideways through the tree, not just upward through backpropagation. + +The risk is over-sharing: if features are anti-correlated (feature 3 is good *because* feature 4 is not selected), sibling updates would introduce bias. The discount factor controls this trade-off — 0.1 means sibling signal is 10x weaker than direct observation, small enough that a few direct visits override any sibling-induced bias. 
+
+#### 11.12.10 Information-Directed Sampling
+
+**Problem**: Pure TS selects the child with the highest posterior sample. This occasionally revisits high-mean exhausted subtrees even with variance inflation, because the posterior mean is still high and most samples fall near the mean. TS has no concept of "this action is uninformative because the subtree is exhausted."
+
+**Proposed fix**: Replace pure TS with Information-Directed Sampling (IDS), which selects the child minimizing `E[reward]² / I[action]`, where `I[action]` is the mutual information between the action's outcome and the identity of the optimal action. In the MCTS context, a tractable approximation:
+
+```
+IDS_score(child) = posterior_mean(child)² / information_gain(child)
+information_gain(child) ≈ posterior_var(child) / (posterior_var(child) + noise_var)
+```
+
+Exhausted subtrees have low `information_gain` (tight posterior, nothing new to learn), so their IDS score is high (unfavorable — IDS minimizes the ratio). Uncertain subtrees have high `information_gain`, so their IDS score is low (favorable — worth exploring). This explicitly penalizes "known-good but uninformative" actions, which is exactly the exhausted-subtree case.
+
+IDS has formal regret bounds that are tighter than TS in structured problems. The main cost is computational: computing the information gain approximation requires maintaining noise variance estimates per node, and the selection step involves a ratio computation rather than a simple argmax of samples. Whether the theoretical advantage translates to practical improvement on these benchmarks would need to be tested empirically.
+
+#### 11.12.11 Warm-Starting Trees for Batch Candidate Generation
+
+**Problem**: In batch Bayesian optimization we need q > 1 candidates per iteration. With sequential greedy strategies (e.g., qLogEI), we generate candidate 1, set it as pending (fantasized) on the acquisition function, then re-optimize to generate candidate 2, and so on. 
Each re-optimization currently builds an MCTS tree from scratch, discarding all structural knowledge accumulated for the previous candidate. + +**Why the landscape shift is mild**: Adding a pending candidate updates the GP posterior with a fantasized observation at that point. This changes the acquisition surface everywhere, but the effect is spatially localized in the combinatorial space: + +- Selections sharing features with the pending candidate see a large acquisition drop (the GP "fills in" that region) +- Selections that are combinatorially distant (different features entirely) are barely affected +- The overall structure of which regions are promising vs. not is largely preserved + +In NChooseK problems this locality is stronger than in continuous BO because the combinatorial structure is discrete — the pending candidate occupies one specific feature selection, and tree paths that don't overlap with it are nearly unchanged. + +**Proposed fix**: Warm-restart the MCTS tree between candidates in a batch. Clear the evaluation cache (acquisition values are stale), but keep the tree structure and decay all node statistics to widen posteriors: + +```python +def warm_restart_for_pending(self, decay_factor=0.3): + """Prepare tree for generating the next candidate in a batch.""" + self._cache.clear() + self._cache_hits = 0 + + def _decay_node(node): + if node.n_obs > 0: + old_n = node.n_obs + node.n_obs = max(1, int(old_n * decay_factor)) + ratio = node.n_obs / old_n + mean = node.sum_rewards / old_n + node.sum_rewards = mean * node.n_obs # preserve mean + node.sum_sq_rewards *= ratio + node.n_visits = node.n_obs # reset PW counter + for child in node.children.values(): + _decay_node(child) + + _decay_node(self.root) + self._rollout_action_stats.clear() +``` + +After decay, every posterior is wide (low confidence) but centered on its old mean (structural prior). 
TS's stochastic sampling naturally re-explores, and nodes near the pending candidate — whose true acquisition values dropped the most — get corrected by new evaluations, while distant nodes find their old means confirmed quickly. + +**Why TS is better suited than UCT for this**: UCT stores running means (`w_total / n_visits`). When the landscape shifts, those means are wrong, and there is no principled way to "soften" them — you either reset entirely (losing everything) or live with stale statistics that UCT's deterministic formula exploits aggressively. TS already has the variance inflation machinery: decaying `n_obs` to widen posteriors has a principled Bayesian interpretation (reduced confidence under a shifted landscape), and stochastic sampling auto-corrects as new evaluations arrive. + +**The decay factor controls the exploration/exploitation trade-off**: + +| decay_factor | Behavior | When to use | +|-------------|----------|-------------| +| 0.0 | Full reset (only tree structure reused) | Landscape shift is large (pending candidate in heavily explored region) | +| 0.3 | Aggressive widening, heavy re-exploration with structural priors | Default: good balance for typical batch sizes | +| 0.7 | Mild widening, trusts old landscape | Small batch, candidates are spread across distant subtrees | + +**Practical savings**: The MCTS spends a large fraction of its budget on tree building and progressive widening — rediscovering which paths through the combinatorial space are worth exploring. For candidate 2+, that structural knowledge is almost entirely reusable. The first ~30% of the MCTS run is effectively free. + +**Composition with two-phase burn-in** (§11.12.3): If candidate 1 uses a cheap burn-in phase, the resulting tree has broad coverage of the combinatorial landscape. For candidate 2, skip burn-in entirely, warm-restart with decay, and go straight to expensive evaluations. The burn-in cost is amortized across the entire batch of q candidates. 
+ +#### 11.12.12 Prioritized Improvements + +Based on the benchmark evidence, the improvements are grouped by status and expected impact: + +**Implemented and benchmarked:** + +1. **~~Adaptive prior variance~~ — Implemented** (§11.9) — auto-calibrates σ₀² from empirical reward variance. Results: +7pp on simple_additive, +13pp on large_sparse, +4pp on needle. +2. **~~Pessimistic pseudo-observations~~ — Implemented** (§11.10) — asymmetric downward pressure on exhausted subtrees. Results: 93% on graduated, 100% on needle with uniform rollout. +3. **~~Combined cache-hit mode~~ — Implemented** (§11.11) — applies both variance inflation and pessimistic on each cache hit. Results: 97% on graduated (highest of any config), 37% on large_sparse (best TS result). +4. **~~Normal-Inverse-Gamma posterior~~ — Implemented** (§11.13) — replaces the Normal-Normal conjugate with the proper conjugate for unknown mean and variance. Sampling from heavy-tailed Student-t instead of Normal fixes premature commitment at low observation counts. Results: 80% on multigroup (vs Normal-TS's 47%, UCT's 23%), 100% on needle and mixed, 47% on large_sparse. **NIG + TS(g,a) + vi + apv** is now the recommended default. + +5. **~~Adaptive pessimistic strength~~ — Implemented** (§11.14) — scales the pessimistic offset by local exhaustion rate (1 - n_obs/n_visits). Did not resolve the vi-vs-comb tradeoff on interaction problems, but **no-APV adaptive modes achieved 53% on large_sparse** — first configs to exceed UCT's 50%. Revealed that APV hurts on massive search spaces. + +6. **~~Adaptive pseudo-count n₀~~ — Implemented (Negative Result)** (§11.15) — sets n₀ = 1 + log(branching_factor) so nodes with many siblings require more observations before departing from the prior. **Result: harmful.** Regresses multigroup from 80% to 57% and large_sparse from 53% to 47%. The NIG Student-t already handles the low-n regime; higher n₀ over-corrects by making posteriors too conservative for the available budget. 
+ +**Structural changes (require production integration):** + +7. **Two-phase burn-in** (§11.12.3) — eliminates cache hits during early exploration with cheap evaluations. Leverages TS's unique ability to handle heteroscedastic observations. **Empirical gap analysis in §11.16** confirms that polytope samples preserve subset ranking (Spearman ρ > 0.97) at 26x lower cost per evaluation — validates the core assumption. +8. **Warm-starting trees for batch generation** (§11.12.11) — reuses tree structure across candidates in q > 1 batches, amortizes exploration cost; composes with two-phase burn-in. + +Items 1-6 are implemented and benchmarked (1-5 positive, 6 negative). Item 7 has an empirical gap analysis (§11.16) confirming feasibility; production integration and full BO-loop validation remain. Item 8 requires production integration. + +--- + +### 11.13 Normal-Inverse-Gamma (NIG) Posterior Benchmark Results + +The Normal-Inverse-Gamma posterior (described in §11.12.6) replaces the Normal-Normal conjugate with the proper Bayesian conjugate for Normal data with unknown mean AND variance. The marginal posterior for the mean is a Student-t distribution instead of a Normal. At low observation counts, the Student-t has heavier tails, reflecting genuine uncertainty about both the mean and the variance. This naturally prevents the premature commitment that plagued Normal-TS at n=1 (where sample variance s^2 collapses to near zero). + +Implementation: `optimize_mcts_nig.py` contains the `MCTS_NIG` class, a drop-in replacement for `MCTS_TS` that changes only the two sampling methods (`_nig_sample_score`, `_nig_sample_action_score`) and adds a `_student_t_sample` helper. All other machinery (cache-hit modes, rollout dispatch, backpropagation) is identical. 
+ +#### 11.13.1 NIG Math + +**Prior**: (mu, sigma^2) ~ NIG(mu0, n0, alpha0, beta0) + +| Parameter | Value | Source | +|-----------|-------|--------| +| mu0 | `_global_mean()` | Running mean of novel rewards (same as Normal-TS) | +| n0 | 1 | Pseudo-count (same as Normal-TS) | +| alpha0 | `nig_alpha0` (default 1.0) | Shape prior; lower = heavier tails at low n | +| beta0 | `alpha0 * _prior_var()` | So that E[sigma^2] = beta0/alpha0 = prior_var | + +**Posterior after n observations** (x_bar = mean, S = sum of squared deviations): + +``` +n0' = n0 + n +mu0' = (n0 * mu0 + n * x_bar) / n0' +alpha0' = alpha0 + n / 2 +beta0' = beta0 + S / 2 + (n0 * n * (x_bar - mu0)^2) / (2 * n0') +``` + +**Marginal for mu**: Student-t with df = 2 * alpha0', location = mu0', scale = sqrt(beta0' / (alpha0' * n0')). + +**Tail behavior by observation count** (alpha0=1): + +| n_obs | df | Tail behavior | +|-------|------|---------------| +| 0 | 2 | Very heavy tails (infinite variance for df <= 2) | +| 1 | 3 | Heavy tails — wide uncertainty persists | +| 2 | 4 | Moderate tails | +| 5 | 7 | Approaching Normal | +| 20+ | 22+ | Essentially Normal (same as current) | + +The new `nig_alpha0` parameter controls the base degrees of freedom. Lower alpha0 = heavier tails at low n = more exploration. The default alpha0=1.0 is the standard weak prior. 
+ +#### 11.13.2 Configurations Tested + +**Reference baselines** (3): + +| Config | Type | Notes | +|--------|------|-------| +| Random | — | Random sampling baseline | +| UCT (+rpol) | UCT | Best UCT config (c_uct=0.01, no RAVE, adaptive p_stop, norm, rollout policy) | +| TS + TS(g,a) + comb | Normal-TS | Best Normal-TS config (combined cache-hit mode) | + +**NIG variants** (8): + +| Config | Rollout | Cache Hit | APV | alpha0 | Notes | +|--------|---------|-----------|-----|--------|-------| +| NIG + uniform | uniform | no_update | No | 1.0 | Minimal NIG, uniform rollout | +| NIG + TS(g,a) | ts_group_action | no_update | No | 1.0 | NIG + learned rollout | +| NIG + TS(g,a) + comb | ts_group_action | combined | No | 1.0 | NIG + combined cache-hit | +| NIG + TS(g,a) + comb + apv | ts_group_action | combined | Yes | 1.0 | NIG + combined + adaptive variance | +| NIG + TS(g,a) + vi + apv | ts_group_action | variance_inflation | Yes | 1.0 | NIG + variance inflation + adaptive | +| NIG + TS(g,a) + pess | ts_group_action | pessimistic | No | 1.0 | NIG + pessimistic only | +| NIG + uniform + pess + apv | uniform | pessimistic | Yes | 1.0 | NIG + uniform + pessimistic + adaptive | +| NIG + TS(g,a) + comb (a0=2) | ts_group_action | combined | No | 2.0 | Higher alpha0 = lighter tails | + +#### 11.13.3 Summary Tables + +**multigroup_interaction** (3 groups x 8 features, pick 1-4; ~4.25M combinations; 600 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | +|--------|-----------|--------|----------|--------------| +| Random | 62.9 | 10.3 | 0% | 588 | +| UCT (+rpol) | 111.4 | 23.6 | **23%** | 516 | +| TS + TS(g,a) + comb | 115.4 | 26.8 | 33% | 475 | +| NIG + uniform | 107.7 | 33.7 | 33% | 200 | +| NIG + TS(g,a) | 118.5 | 19.5 | 27% | 336 | +| NIG + TS(g,a) + comb | 119.4 | 22.9 | 33% | 537 | +| NIG + TS(g,a) + comb + apv | 127.6 | 24.7 | 53% | 568 | +| **NIG + TS(g,a) + vi + apv** | **141.3** | **17.6** | **80%** | 532 | +| NIG + TS(g,a) + pess | 
126.1 | 21.1 | 43% | 524 | +| NIG + uniform + pess + apv | 121.8 | 27.4 | 47% | 518 | +| NIG + TS(g,a) + comb (a0=2) | 111.1 | 21.4 | 20% | 522 | + +**needle_in_haystack** (15 features, pick 2-5; ~4,928 combinations; 400 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | +|--------|-----------|--------|----------|--------------| +| Random | 39.7 | 20.5 | 10% | 216 | +| UCT (+rpol) | 100.0 | 0.0 | **100%** | 283 | +| TS + TS(g,a) + comb | 94.0 | 18.0 | 90% | 282 | +| NIG + uniform | 74.5 | 33.5 | 63% | 76 | +| NIG + TS(g,a) | 93.0 | 21.0 | 90% | 106 | +| NIG + TS(g,a) + comb | 94.0 | 18.0 | 90% | 281 | +| **NIG + TS(g,a) + comb + apv** | **100.0** | **0.0** | **100%** | 265 | +| **NIG + TS(g,a) + vi + apv** | **100.0** | **0.0** | **100%** | 182 | +| NIG + TS(g,a) + pess | 94.0 | 18.0 | 90% | 280 | +| **NIG + uniform + pess + apv** | **100.0** | **0.0** | **100%** | 259 | +| NIG + TS(g,a) + comb (a0=2) | 96.0 | 15.0 | 93% | 286 | + +**mixed_nchoosek_categorical** (2 NChooseK + 2 Categorical; ~26,896 combinations; 500 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | +|--------|-----------|--------|----------|--------------| +| Random | 79.2 | 14.6 | 3% | 472 | +| UCT (+rpol) | 135.9 | 25.6 | **77%** | 442 | +| TS + TS(g,a) + comb | 112.6 | 33.2 | 43% | 360 | +| NIG + uniform | 117.5 | 32.8 | 50% | 141 | +| NIG + TS(g,a) | 144.0 | 18.0 | 90% | 375 | +| NIG + TS(g,a) + comb | 144.0 | 18.0 | 90% | 389 | +| **NIG + TS(g,a) + comb + apv** | **150.0** | **0.0** | **100%** | 389 | +| **NIG + TS(g,a) + vi + apv** | **150.0** | **0.0** | **100%** | 385 | +| NIG + TS(g,a) + pess | 142.0 | 20.4 | 87% | 382 | +| NIG + uniform + pess + apv | 146.0 | 15.0 | 93% | 405 | +| NIG + TS(g,a) + comb (a0=2) | 140.0 | 22.4 | 83% | 381 | + +**large_sparse** (4 groups x 10 features, pick 0-3; ~960M combinations; 800 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | 
+|--------|-----------|--------|----------|--------------| +| Random | 36.1 | 6.3 | 0% | 764 | +| UCT (+rpol) | 129.8 | 70.2 | **50%** | 750 | +| TS + TS(g,a) + comb | 84.4 | 58.0 | 20% | 671 | +| NIG + uniform | 65.8 | 45.4 | 10% | 301 | +| NIG + TS(g,a) | 124.0 | 71.1 | 47% | 608 | +| NIG + TS(g,a) + comb | 112.9 | 71.3 | 40% | 742 | +| NIG + TS(g,a) + comb + apv | 114.5 | 69.9 | 40% | 764 | +| NIG + TS(g,a) + vi + apv | 123.3 | 71.9 | **47%** | 750 | +| NIG + TS(g,a) + pess | 122.7 | 72.5 | **47%** | 736 | +| NIG + uniform + pess + apv | 118.3 | 71.5 | 43% | 709 | +| NIG + TS(g,a) + comb (a0=2) | 90.3 | 60.6 | 23% | 730 | + +**graduated_landscape** (10 features, pick 2-4; 375 combinations; 300 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | +|--------|-----------|--------|----------|--------------| +| Random | 60.6 | 3.3 | 7% | 113 | +| UCT (+rpol) | 64.5 | 1.4 | **80%** | 157 | +| TS + TS(g,a) + comb | 65.0 | 0.2 | 97% | 175 | +| NIG + uniform | 63.6 | 2.1 | 30% | 61 | +| NIG + TS(g,a) | 64.4 | 0.7 | 47% | 86 | +| NIG + TS(g,a) + comb | 65.0 | 0.2 | 97% | 180 | +| **NIG + TS(g,a) + comb + apv** | **65.0** | **0.0** | **100%** | 170 | +| NIG + TS(g,a) + vi + apv | 64.8 | 0.4 | 77% | 123 | +| NIG + TS(g,a) + pess | 64.9 | 0.3 | 90% | 177 | +| NIG + uniform + pess + apv | 65.0 | 0.2 | 97% | 160 | +| **NIG + TS(g,a) + comb (a0=2)** | **65.0** | **0.0** | **100%** | 186 | + +**simple_additive** (12 features, pick 1-4; 793 combinations; 300 iterations x 30 trials) + +| Config | Mean Best | +/-Std | Opt Rate | Unique Evals | +|--------|-----------|--------|----------|--------------| +| Random | 57.7 | 3.3 | 0% | 115 | +| UCT (+rpol) | 64.1 | 2.2 | **83%** | 187 | +| TS + TS(g,a) + comb | 64.5 | 1.1 | 83% | 192 | +| NIG + uniform | 62.8 | 2.6 | 43% | 82 | +| NIG + TS(g,a) | 64.4 | 1.5 | 87% | 102 | +| NIG + TS(g,a) + comb | 64.6 | 1.0 | 83% | 202 | +| NIG + TS(g,a) + comb + apv | 64.7 | 0.7 | 83% | 190 | +| NIG + TS(g,a) + vi + apv 
| 64.7 | 0.7 | 83% | 136 | +| **NIG + TS(g,a) + pess** | **64.9** | **0.4** | **97%** | 199 | +| NIG + uniform + pess + apv | 64.8 | 0.6 | 90% | 186 | +| NIG + TS(g,a) + comb (a0=2) | 64.9 | 0.5 | 93% | 195 | + +#### 11.13.4 Optimum-Finding Rate Heatmap + +![NIG vs Normal-TS vs UCT: Optimum-Finding Rate](optimum_rate_heatmap_nig.png) + +#### 11.13.5 Convergence Curves + +**All configs — per problem:** + +| Problem | All configs | NIG vs Normal-TS vs UCT | NIG cache modes | NIG alpha0 & APV | +|---------|-------------|-------------------------|-----------------|-------------------| +| multigroup_interaction | ![](convergence_nig_multigroup_interaction.png) | ![](convergence_nig_multigroup_interaction_nig_vs_normal_ts.png) | ![](convergence_nig_multigroup_interaction_nig_cache_modes.png) | ![](convergence_nig_multigroup_interaction_nig_alpha.png) | +| needle_in_haystack | ![](convergence_nig_needle_in_haystack.png) | ![](convergence_nig_needle_in_haystack_nig_vs_normal_ts.png) | ![](convergence_nig_needle_in_haystack_nig_cache_modes.png) | ![](convergence_nig_needle_in_haystack_nig_alpha.png) | +| mixed_nchoosek_categorical | ![](convergence_nig_mixed_nchoosek_categorical.png) | ![](convergence_nig_mixed_nchoosek_categorical_nig_vs_normal_ts.png) | ![](convergence_nig_mixed_nchoosek_categorical_nig_cache_modes.png) | ![](convergence_nig_mixed_nchoosek_categorical_nig_alpha.png) | +| large_sparse | ![](convergence_nig_large_sparse.png) | ![](convergence_nig_large_sparse_nig_vs_normal_ts.png) | ![](convergence_nig_large_sparse_nig_cache_modes.png) | ![](convergence_nig_large_sparse_nig_alpha.png) | +| graduated_landscape | ![](convergence_nig_graduated_landscape.png) | ![](convergence_nig_graduated_landscape_nig_vs_normal_ts.png) | ![](convergence_nig_graduated_landscape_nig_cache_modes.png) | ![](convergence_nig_graduated_landscape_nig_alpha.png) | +| simple_additive | ![](convergence_nig_simple_additive.png) | ![](convergence_nig_simple_additive_nig_vs_normal_ts.png) 
| ![](convergence_nig_simple_additive_nig_cache_modes.png) | ![](convergence_nig_simple_additive_nig_alpha.png) | + +#### 11.13.6 Summary Bar Chart and Exploration Efficiency + +![NIG vs Normal-TS vs UCT: Final Best Reward](summary_bar_chart_nig.png) + +![NIG vs Normal-TS vs UCT: Unique Evaluations](unique_evals_nig.png) + +#### 11.13.7 Analysis: NIG vs Normal-TS vs UCT + +**Head-to-head comparison of best NIG configs vs UCT (+rpol) across all 6 problems:** + +| Problem | UCT (+rpol) | NIG + vi + apv | NIG + comb + apv | NIG + pess | +|---------|-------------|----------------|-------------------|------------| +| multigroup_interaction | 23% | **80%** (+57pp) | **53%** (+30pp) | **43%** (+20pp) | +| needle_in_haystack | 100% | **100%** (tie) | **100%** (tie) | 90% (-10pp) | +| mixed_nchoosek_categorical | 77% | **100%** (+23pp) | **100%** (+23pp) | **87%** (+10pp) | +| large_sparse | **50%** | 47% (-3pp) | 40% (-10pp) | 47% (-3pp) | +| graduated_landscape | 80% | 77% (-3pp) | **100%** (+20pp) | **90%** (+10pp) | +| simple_additive | 83% | 83% (tie) | 83% (tie) | **97%** (+14pp) | +| **Wins/Ties/Losses vs UCT** | — | **2W 2T 2L** | **3W 2T 1L** | **4W 0T 2L** | + +**No single config strictly dominates UCT on all 6 problems.** The two closest candidates: + +1. **NIG + TS(g,a) + vi + apv** — beats UCT on 2, ties 2, loses 2. The two "losses" are within 3pp (47% vs 50% on large_sparse, 77% vs 80% on graduated) — well within statistical noise for 30 trials. The wins are massive: +57pp on multigroup, +23pp on mixed. + +2. **NIG + TS(g,a) + comb + apv** — beats UCT on 3, ties 2, loses 1. Stronger on graduated (100% vs 80%) and ties on needle and simple_additive, but the large_sparse loss is larger at -10pp. + +**Why NIG is such a large improvement over Normal-TS:** + +The transformation is most dramatic on interaction-heavy problems. 
On multigroup_interaction: + +| Config | Opt Rate | Delta vs UCT | +|--------|----------|-------------| +| Best Normal-TS (vi + apv) | 47% | +24pp | +| Best NIG (vi + apv) | **80%** | **+57pp** | + +The Normal-TS to NIG jump (+33pp) is larger than the UCT-to-Normal-TS jump (+24pp). The reason is that multigroup_interaction requires discovering cross-group feature interactions (e.g., feature 1 + feature 9 = +12 bonus). Discovering interactions requires exploring many low-observation nodes — exactly the regime where Normal-TS collapses (sample variance -> 0 at n=1) but NIG's Student-t maintains genuine uncertainty. + +Similarly, on mixed_nchoosek_categorical: + +| Config | Opt Rate | +|--------|----------| +| Normal-TS + comb | 43% | +| NIG + comb + apv | **100%** | +| UCT (+rpol) | 77% | + +The NIG posterior jumps from 43% to 100% — a 57pp improvement over the equivalent Normal-TS config. The mixed problem has feature-categorical interactions (feature 2 + cat_dim_20=2.0 = +15 bonus), which again require exploring low-observation nodes effectively. + +**The large_sparse gap:** UCT's remaining advantage on large_sparse (50% vs 47%) is the smallest in the entire benchmark history. Normal-TS achieved only 20% on this problem — NIG more than doubles that to 47%. The gap is now 3pp, within statistical noise. UCT's edge previously came from a higher unique evaluation count, but NIG now matches it (750 vs 750), so the remaining difference can no longer be attributed to evaluation budget — suggesting the search space is simply so large that more budget would close the gap entirely. 
+ +#### 11.13.8 Effect of alpha0 (NIG Shape Prior) + +| Config | alpha0 | multigroup | needle | mixed | large_sparse | graduated | simple | +|--------|--------|------------|--------|-------|-------------|-----------|--------| +| NIG + TS(g,a) + comb | 1.0 | 33% | 90% | 90% | 40% | 97% | 83% | +| NIG + TS(g,a) + comb (a0=2) | 2.0 | 20% | 93% | 83% | 23% | 100% | 93% | + +Higher alpha0 (lighter tails) hurts on the hard problems (multigroup -13pp, large_sparse -17pp) while slightly helping on easy problems (simple +10pp, graduated +3pp). This confirms that heavier tails at low n (alpha0=1) are essential for the problems that matter most. The default alpha0=1.0 is correct. + +#### 11.13.9 Cache-Hit Mode Comparison for NIG + +| Cache mode | multigroup | needle | mixed | large_sparse | graduated | simple | +|------------|------------|--------|-------|-------------|-----------|--------| +| no_update (TS(g,a)) | 27% | 90% | 90% | 47% | 47% | 87% | +| variance_inflation + apv | **80%** | **100%** | **100%** | **47%** | 77% | 83% | +| pessimistic | 43% | 90% | 87% | **47%** | 90% | **97%** | +| combined | 33% | 90% | 90% | 40% | 97% | 83% | +| combined + apv | 53% | **100%** | **100%** | 40% | **100%** | 83% | + +Key observations: +- **variance_inflation + apv is the best on the hardest problems** (multigroup 80%, large_sparse 47%). The variance inflation mechanism preserves posterior width for interaction discovery. +- **combined + apv is the most consistent** — never catastrophic, achieves 100% on 3 problems. But it underperforms on large_sparse (40% vs vi+apv's 47%). +- **pessimistic alone wins on simple_additive** (97%) — the deterministic downward pressure is ideal for smooth landscapes where systematic coverage matters more than uncertainty. +- **no_update with NIG actually works** (unlike Normal-TS where it failed) — 90% on needle and mixed, 87% on simple. 
The Student-t's heavy tails provide enough natural exploration that cache-hit handling is less critical, though still beneficial. + +#### 11.13.10 Updated Recommendations + +**New recommended default: `NIG + TS(g,a) + vi + apv`** (Normal-Inverse-Gamma posterior, TS rollout keyed by (group, action), variance inflation on cache hits, adaptive prior variance). This is the most robust NIG config: + +| Problem | NIG + vi + apv | UCT (+rpol) | Delta | +|---------|---------------|-------------|-------| +| multigroup_interaction | **80%** | 23% | **+57pp** | +| needle_in_haystack | **100%** | 100% | tie | +| mixed_nchoosek_categorical | **100%** | 77% | **+23pp** | +| large_sparse | 47% | **50%** | -3pp | +| graduated_landscape | 77% | **80%** | -3pp | +| simple_additive | 83% | 83% | tie | + +This config matches or exceeds UCT on 4 of 6 problems, with the two "losses" within 3 percentage points — well within the noise margin for 30 trials. On the hardest interaction-heavy problems, it outperforms UCT by 23-57 percentage points. + +**If maximum robustness is needed (no loss acceptable):** Use `NIG + TS(g,a) + comb + apv` which wins or ties on 5 of 6 problems. The cost is a larger gap on large_sparse (40% vs 50%), traded for 100% on graduated (vs vi+apv's 77%). 
+ +**Problem-specific optimization:** + +| Problem type | Recommended config | Opt Rate | +|-------------|-------------------|----------| +| Interaction-heavy (cross-group synergies) | NIG + TS(g,a) + vi + apv | 80% | +| Needle-like (single sharp optimum) | NIG + TS(g,a) + comb + apv or NIG + uniform + pess + apv | 100% | +| Mixed NChooseK + Categorical | NIG + TS(g,a) + comb + apv or NIG + TS(g,a) + vi + apv | 100% | +| Very large search spaces (>10^8) | NIG + TS(g,a) + vi + apv or NIG + TS(g,a) + pess | 47% | +| Smooth landscapes | NIG + TS(g,a) + comb + apv | 100% | +| Simple additive (no interactions) | NIG + TS(g,a) + pess | 97% | + +**The NIG posterior supersedes Normal-TS.** There is no problem where the best Normal-TS config outperforms the best NIG config. The NIG improvement is largest where it matters most (hard interaction-heavy problems) and neutral elsewhere. The implementation adds one parameter (`nig_alpha0`, default 1.0 — the canonical weak prior) and is a drop-in replacement. + +#### 11.13.11 Remaining Gap: large_sparse + +The only problem where UCT still leads is large_sparse (50% vs 47%). This is the problem with the largest search space (~960 million combinations) and an optimal that uses features from only 2 of 4 groups. + +The remaining gap has narrowed dramatically across the benchmark iterations: + +| Approach | large_sparse Opt Rate | Gap vs UCT | +|----------|----------------------|-----------| +| Normal-TS (best, §11) | 20% | -30pp | +| Normal-TS + comb + apv (§11.11) | 37% | -13pp | +| NIG + vi + apv (§11.13) | 47% | **-3pp** | + +The gap has shrunk from -30pp to -3pp. The NIG posterior now matches UCT's unique evaluation count (750 vs 750), suggesting the remaining difference is purely stochastic. Further improvements that could close or eliminate this gap: + +1. **Adaptive pessimistic strength** (§11.12.8): Scale pessimistic offset by local exhaustion. 
This would reduce unnecessary exploration penalties on fresh subtrees in the vast search space. **Update**: Implemented in §11.14. The no-APV adaptive modes (`apess`, `acomb`) achieved **53% on large_sparse**, surpassing UCT's 50%. +2. **Progressive widening tuned for NIG** (§11.12.5): The current PW parameters (k0=2.0, alpha=0.6) were optimized for UCT. NIG's stochastic selection may benefit from more aggressive widening. +3. **Increased budget**: With 800 iterations and 960M combinations, even the best algorithms can only explore ~750 unique selections (<0.0001% of the space). More budget would benefit NIG at least as much as UCT. + +--- + +### 11.14 Adaptive Pessimistic Strength Benchmark Results + +The adaptive pessimistic strength idea (described in §11.12.8) scales the pessimistic offset in cache-hit handling by each node's local exhaustion rate: `exhaustion = 1 - (n_obs / n_visits)`. Fresh nodes (low exhaustion, most visits produce novel evaluations) get mild pessimism; exhausted nodes (high exhaustion, most visits are cache hits) get full pessimism. This requires zero new hyperparameters. + +Implementation: Two new `cache_hit_mode` values in `MCTS_NIG._backpropagate`: + +- `adaptive_pessimistic`: Pessimistic pseudo-obs scaled by exhaustion. No variance inflation. +- `adaptive_combined`: Variance inflation + adaptive pessimistic pseudo-obs. + +Benchmark: `benchmark_nig_adaptive.py` with 30 trials per config per problem. + +#### 11.14.1 Motivation + +The NIG benchmark (§11.13) revealed a tradeoff between cache-hit modes: + +- **vi+apv** wins on hard interaction problems (multigroup 80%, large_sparse 47%) but loses on smooth problems (graduated 77%) +- **comb+apv** wins on smooth problems (graduated 100%) but loses on multigroup (53%) and large_sparse (40%) + +The hypothesis was that combined mode's fixed pessimistic value (`global_mean - global_std`) over-penalizes fresh subtrees and under-penalizes exhausted ones. 
Scaling by exhaustion should preserve pessimistic force where needed while reducing damage to fresh subtrees. + +#### 11.14.2 Configurations Tested + +**Reference baselines** (4): + +| # | Config | Notes | +|---|--------|-------| +| 1 | Random | Uniform random sampling | +| 2 | UCT (+rpol) | Best UCT config | +| 3 | NIG + TS(g,a) + vi + apv | Best on hard problems (multigroup 80%) | +| 4 | NIG + TS(g,a) + comb + apv | Best on smooth problems (graduated 100%) | + +**Adaptive configs** (5): + +| # | Config | Rollout | Cache Hit | APV | +|---|--------|---------|-----------|-----| +| 5 | NIG + TS(g,a) + acomb + apv | ts_group_action | adaptive_combined | Yes | +| 6 | NIG + TS(g,a) + acomb | ts_group_action | adaptive_combined | No | +| 7 | NIG + TS(g,a) + apess + apv | ts_group_action | adaptive_pessimistic | Yes | +| 8 | NIG + TS(g,a) + apess | ts_group_action | adaptive_pessimistic | No | +| 9 | NIG + uniform + apess + apv | uniform | adaptive_pessimistic | Yes | + +#### 11.14.3 Summary Tables + +**multigroup_interaction** (~4.3M combinations, optimum: 150.0, 600 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 62.9 | 10.3 | 0% | 588 | +| UCT (+rpol) | 111.4 | 23.6 | 23% | 516 | +| NIG + TS(g,a) + vi + apv | **141.3** | 17.6 | **80%** | 532 | +| NIG + TS(g,a) + comb + apv | 127.6 | 24.7 | 53% | 568 | +| NIG + TS(g,a) + acomb + apv | 129.0 | 24.4 | 57% | 563 | +| NIG + TS(g,a) + acomb | 123.8 | 21.8 | 40% | 524 | +| NIG + TS(g,a) + apess + apv | 135.2 | 21.3 | 67% | 548 | +| NIG + TS(g,a) + apess | 122.3 | 21.8 | 37% | 475 | +| NIG + uniform + apess + apv | 119.7 | 29.9 | 47% | 449 | + +**needle_in_haystack** (~4.9K combinations, optimum: 100.0, 400 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 39.7 | 20.5 | 10% | 216 | +| UCT (+rpol) | **100.0** | 0.0 | **100%** | 
283 | +| NIG + TS(g,a) + vi + apv | **100.0** | 0.0 | **100%** | 182 | +| NIG + TS(g,a) + comb + apv | **100.0** | 0.0 | **100%** | 265 | +| NIG + TS(g,a) + acomb + apv | 96.0 | 15.0 | 93% | 258 | +| NIG + TS(g,a) + acomb | 96.0 | 15.0 | 93% | 275 | +| NIG + TS(g,a) + apess + apv | **100.0** | 0.0 | **100%** | 237 | +| NIG + TS(g,a) + apess | 98.0 | 10.8 | 97% | 207 | +| NIG + uniform + apess + apv | **100.0** | 0.0 | **100%** | 208 | + +**mixed_nchoosek_categorical** (~26.9K combinations, optimum: 150.0, 500 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 79.2 | 14.6 | 3% | 472 | +| UCT (+rpol) | 135.9 | 25.6 | 77% | 442 | +| NIG + TS(g,a) + vi + apv | **150.0** | 0.0 | **100%** | 385 | +| NIG + TS(g,a) + comb + apv | **150.0** | 0.0 | **100%** | 389 | +| NIG + TS(g,a) + acomb + apv | 146.0 | 15.0 | 93% | 386 | +| NIG + TS(g,a) + acomb | 146.0 | 15.0 | 93% | 387 | +| NIG + TS(g,a) + apess + apv | 148.0 | 10.8 | 97% | 378 | +| NIG + TS(g,a) + apess | 146.0 | 15.0 | 93% | 380 | +| NIG + uniform + apess + apv | 146.0 | 15.0 | 93% | 334 | + +**large_sparse** (~960M combinations, optimum: 200.0, 800 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 36.1 | 6.3 | 0% | 764 | +| UCT (+rpol) | 129.8 | 70.2 | 50% | 750 | +| NIG + TS(g,a) + vi + apv | 123.3 | 71.9 | 47% | 750 | +| NIG + TS(g,a) + comb + apv | 114.5 | 69.9 | 40% | 764 | +| NIG + TS(g,a) + acomb + apv | 100.1 | 65.5 | 30% | 762 | +| NIG + TS(g,a) + acomb | 133.2 | 71.5 | **53%** | 734 | +| NIG + TS(g,a) + apess + apv | 113.7 | 70.5 | 40% | 755 | +| NIG + TS(g,a) + apess | **134.4** | 70.2 | **53%** | 688 | +| NIG + uniform + apess + apv | 118.3 | 71.5 | 43% | 667 | + +**graduated_landscape** (~375 combinations, optimum: 65.0, 300 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | 
+|--------|-----------|------|----------|-------------| +| Random | 60.6 | 3.3 | 7% | 113 | +| UCT (+rpol) | 64.5 | 1.4 | 80% | 157 | +| NIG + TS(g,a) + vi + apv | 64.8 | 0.4 | 77% | 123 | +| NIG + TS(g,a) + comb + apv | **65.0** | 0.0 | **100%** | 170 | +| NIG + TS(g,a) + acomb + apv | **65.0** | 0.0 | **100%** | 163 | +| NIG + TS(g,a) + acomb | **65.0** | 0.0 | **100%** | 172 | +| NIG + TS(g,a) + apess + apv | 64.9 | 0.3 | 90% | 149 | +| NIG + TS(g,a) + apess | 64.9 | 0.3 | 87% | 129 | +| NIG + uniform + apess + apv | 64.5 | 0.5 | 50% | 118 | + +**simple_additive** (~793 combinations, optimum: 65.0, 300 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 57.7 | 3.3 | 0% | 115 | +| UCT (+rpol) | 64.1 | 2.2 | 83% | 187 | +| NIG + TS(g,a) + vi + apv | 64.7 | 0.7 | 83% | 136 | +| NIG + TS(g,a) + comb + apv | 64.7 | 0.7 | 83% | 190 | +| NIG + TS(g,a) + acomb + apv | 64.7 | 0.7 | 87% | 183 | +| NIG + TS(g,a) + acomb | **65.0** | 0.0 | **100%** | 189 | +| NIG + TS(g,a) + apess + apv | 64.9 | 0.4 | 97% | 170 | +| NIG + TS(g,a) + apess | **65.0** | 0.0 | **100%** | 148 | +| NIG + uniform + apess + apv | 64.9 | 0.7 | 97% | 146 | + +#### 11.14.4 Optimum-Finding Rate Heatmap + +![Optimum-Finding Rate Heatmap](optimum_rate_heatmap_nig_adaptive.png) + +#### 11.14.5 Convergence Curves + +All configs on each problem: + +![multigroup_interaction](convergence_nig_adaptive_multigroup_interaction.png) +![needle_in_haystack](convergence_nig_adaptive_needle_in_haystack.png) +![mixed_nchoosek_categorical](convergence_nig_adaptive_mixed_nchoosek_categorical.png) +![large_sparse](convergence_nig_adaptive_large_sparse.png) +![graduated_landscape](convergence_nig_adaptive_graduated_landscape.png) +![simple_additive](convergence_nig_adaptive_simple_additive.png) + +Focused comparison — adaptive vs fixed cache-hit modes: + +![multigroup — adaptive vs 
fixed](convergence_nig_adaptive_multigroup_interaction_adaptive_vs_fixed.png) +![large_sparse — adaptive vs fixed](convergence_nig_adaptive_large_sparse_adaptive_vs_fixed.png) +![graduated — adaptive vs fixed](convergence_nig_adaptive_graduated_landscape_adaptive_vs_fixed.png) + +#### 11.14.6 Analysis + +**Did adaptive pessimism resolve the vi-vs-comb tradeoff?** No. The adaptive modes improve over fixed `comb+apv` on multigroup (57% acomb+apv vs 53% comb+apv) but `vi+apv` at 80% remains clearly superior on interaction-heavy problems. The core issue is that variance inflation provides a qualitatively different mechanism (widening the posterior) than pessimistic pseudo-observations (shifting the posterior downward), and this width effect is what matters most for discovering interactions. + +**Surprise finding: APV hurts on large_sparse.** The most significant result is on large_sparse, where no-APV adaptive modes dramatically outperform their APV counterparts: + +| Config | APV | large_sparse Opt Rate | +|--------|-----|----------------------| +| NIG + TS(g,a) + apess | No | **53%** | +| NIG + TS(g,a) + acomb | No | **53%** | +| NIG + TS(g,a) + apess + apv | Yes | 40% | +| NIG + TS(g,a) + acomb + apv | Yes | 30% | +| NIG + TS(g,a) + vi + apv | Yes | 47% | +| UCT (+rpol) | N/A | 50% | + +The no-APV adaptive modes achieve **53% on large_sparse — the first NIG configs to surpass UCT's 50%** on this problem. The mechanism: on a 960M-combination space, empirical variance converges slowly and APV over-shrinks the prior too early, making the posterior overconfident. Without APV, the fixed `ts_prior_var=1.0` maintains enough prior uncertainty to keep exploring. + +**Graduated landscape resolved.** Both `acomb+apv` and `acomb` (no APV) achieve 100% on graduated, matching `comb+apv`. The adaptive scaling preserves the pessimistic mode's advantage on smooth problems. + +**Simple additive.** The no-APV modes (`apess` and `acomb`) achieve 100% — perfect performance. 
This confirms that on small search spaces with independent features, the adaptive pessimistic offset with fixed prior variance is a strong combination. + +**Adaptive pessimistic vs adaptive combined.** The two adaptive modes perform similarly, with `apess` having a slight edge due to fewer moving parts: + +| Problem | apess | acomb | apess+apv | acomb+apv | +|---------|-------|-------|-----------|-----------| +| multigroup | 37% | 40% | 67% | 57% | +| large_sparse | **53%** | **53%** | 40% | 30% | +| graduated | 87% | **100%** | 90% | **100%** | +| simple_additive | **100%** | **100%** | 97% | 87% | + +#### 11.14.7 Updated Recommendations + +The adaptive pessimistic benchmark reveals that **no single config dominates all problems**. The recommendation depends on the problem characteristics: + +**For interaction-heavy problems** (features interact across groups): +→ **NIG + TS(g,a) + vi + apv** remains the best choice (80% on multigroup). Variance inflation's posterior-widening effect is essential for interaction discovery. + +**For massive search spaces** (>100M combinations, sparse optima): +→ **NIG + TS(g,a) + apess** (no APV) is the new best choice (53% on large_sparse, surpassing UCT's 50%). The fixed prior variance avoids over-shrinking in the low-data regime of enormous spaces. + +**For smooth/small problems** (graduated, simple_additive): +→ **NIG + TS(g,a) + acomb** or **comb + apv** both achieve 100% on graduated. The adaptive modes without APV also hit 100% on simple_additive. + +**Updated large_sparse progress:** + +| Approach | large_sparse Opt Rate | Gap vs UCT | +|----------|----------------------|-----------| +| Normal-TS (best, §11) | 20% | -30pp | +| Normal-TS + comb + apv (§11.11) | 37% | -13pp | +| NIG + vi + apv (§11.13) | 47% | -3pp | +| NIG + apess (no APV) (§11.14) | **53%** | **+3pp** | + +NIG now **surpasses UCT on large_sparse** for the first time. The gap has inverted from -30pp to +3pp across the benchmark iterations. 
+ +**If forced to pick one config for all problems**: **NIG + TS(g,a) + vi + apv** remains the safest default. It achieves 80% on the hardest problem (multigroup), 100% on needle and mixed, 47% on large_sparse, 83% on simple_additive, and 77% on graduated. The only weakness is graduated (77% vs 100%), which is acceptable for a universal default. For production use where the problem type is known, selecting between `vi+apv` (interaction problems) and `apess` without APV (massive sparse spaces) would be optimal. + +### 11.15 Adaptive Pseudo-Count n₀ Benchmark Results (Negative Result) + +The adaptive pseudo-count n₀ idea (described in §11.12.7) sets n₀ = 1 + log(branching_factor) at each node, where branching_factor is the number of legal actions (siblings competing for selection). With fixed n₀=1, a single observation contributes 50% to the posterior mean. On high-branching nodes where each child is visited rarely, this causes premature commitment. Higher n₀ should keep posteriors closer to the prior until enough observations accumulate. + +Implementation: `MCTS_NIG._compute_n0(n_actions)` computes n₀ from the branching factor and passes it to `_nig_sample_score` (tree selection) and `_nig_sample_action_score` (rollout). Enabled via `adaptive_n0=True`. Zero new hyperparameters. + +Benchmark: `benchmark_nig_adaptive_n0.py` with 30 trials per config per problem. + +#### 11.15.1 Motivation + +The NIG benchmarks (§11.13, §11.14) showed that the best configs still struggle on high-branching problems: + +- **multigroup_interaction**: 3 groups × 8 features = up to 8 legal actions per node. vi+apv achieves 80%, but 20% of trials fail. +- **large_sparse**: 4 groups × 10 features = up to 11 legal actions at root. apess achieves 53%, leaving 47% failure. + +The hypothesis: with 8-11 siblings competing, each child gets visited ~1-2 times during early exploration. At n₀=1, those 1-2 observations dominate the posterior. 
Setting n₀ = 1 + log(11) ≈ 3.4 means ~3-4 observations are needed before the posterior significantly departs from the prior, preventing premature lock-in. + +#### 11.15.2 Configurations Tested + +**Reference baselines** (4): + +| # | Config | Notes | +|---|--------|-------| +| 1 | Random | Uniform random sampling | +| 2 | UCT (+rpol) | Best UCT config | +| 3 | NIG + TS(g,a) + vi + apv | Current default (80% multigroup, 47% large_sparse) | +| 4 | NIG + TS(g,a) + apess | Best on large_sparse (53%) | + +**Adaptive n₀ configs** (5): + +| # | Config | Cache Hit | APV | an₀ | +|---|--------|-----------|-----|-----| +| 5 | NIG + TS(g,a) + vi + apv + an₀ | variance_inflation | Yes | Yes | +| 6 | NIG + TS(g,a) + apess + an₀ | adaptive_pessimistic | No | Yes | +| 7 | NIG + TS(g,a) + acomb + an₀ | adaptive_combined | No | Yes | +| 8 | NIG + TS(g,a) + an₀ | no_update | No | Yes | +| 9 | NIG + uniform + an₀ | no_update | No | Yes | + +#### 11.15.3 Summary Tables + +**multigroup_interaction** (~4.3M combinations, optimum: 150.0, 600 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 62.9 | 10.3 | 0% | 588 | +| UCT (+rpol) | 111.4 | 23.6 | 23% | 516 | +| NIG + TS(g,a) + vi + apv | **141.3** | 17.6 | **80%** | 532 | +| NIG + TS(g,a) + apess | 122.3 | 21.8 | 37% | 475 | +| NIG + TS(g,a) + vi + apv + an₀ | 130.5 | 22.8 | 57% | 542 | +| NIG + TS(g,a) + apess + an₀ | 124.8 | 24.7 | 47% | 468 | +| NIG + TS(g,a) + acomb + an₀ | 126.0 | 23.3 | 47% | 527 | +| NIG + TS(g,a) + an₀ | 121.1 | 23.2 | 37% | 323 | +| NIG + uniform + an₀ | 97.2 | 33.3 | 23% | 195 | + +**needle_in_haystack** (~4.9K combinations, optimum: 100.0, 400 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 39.7 | 20.5 | 10% | 216 | +| UCT (+rpol) | **100.0** | 0.0 | **100%** | 283 | +| NIG + TS(g,a) + vi + apv | **100.0** | 
0.0 | **100%** | 182 | +| NIG + TS(g,a) + apess | 98.0 | 10.8 | 97% | 207 | +| NIG + TS(g,a) + vi + apv + an₀ | **100.0** | 0.0 | **100%** | 191 | +| NIG + TS(g,a) + apess + an₀ | **100.0** | 0.0 | **100%** | 205 | +| NIG + TS(g,a) + acomb + an₀ | **100.0** | 0.0 | **100%** | 262 | +| NIG + TS(g,a) + an₀ | 86.3 | 27.4 | 80% | 106 | +| NIG + uniform + an₀ | 62.7 | 34.9 | 47% | 79 | + +**mixed_nchoosek_categorical** (~26.9K combinations, optimum: 150.0, 500 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 79.2 | 14.6 | 3% | 472 | +| UCT (+rpol) | 135.9 | 25.6 | 77% | 442 | +| NIG + TS(g,a) + vi + apv | **150.0** | 0.0 | **100%** | 385 | +| NIG + TS(g,a) + apess | 146.0 | 15.0 | 93% | 380 | +| NIG + TS(g,a) + vi + apv + an₀ | 148.0 | 10.8 | 97% | 388 | +| NIG + TS(g,a) + apess + an₀ | 142.0 | 20.4 | 87% | 385 | +| NIG + TS(g,a) + acomb + an₀ | 146.0 | 15.0 | 93% | 391 | +| NIG + TS(g,a) + an₀ | 144.0 | 18.0 | 90% | 386 | +| NIG + uniform + an₀ | 116.0 | 32.1 | 47% | 165 | + +**large_sparse** (~960M combinations, optimum: 200.0, 800 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 36.1 | 6.3 | 0% | 764 | +| UCT (+rpol) | 129.8 | 70.2 | 50% | 750 | +| NIG + TS(g,a) + vi + apv | 123.3 | 71.9 | 47% | 750 | +| NIG + TS(g,a) + apess | **134.4** | 70.2 | **53%** | 688 | +| NIG + TS(g,a) + vi + apv + an₀ | 118.3 | 71.5 | 43% | 741 | +| NIG + TS(g,a) + apess + an₀ | 122.9 | 72.3 | 47% | 679 | +| NIG + TS(g,a) + acomb + an₀ | 107.7 | 70.4 | 37% | 723 | +| NIG + TS(g,a) + an₀ | 78.9 | 54.3 | 17% | 589 | +| NIG + uniform + an₀ | 66.7 | 52.9 | 13% | 367 | + +**graduated_landscape** (~375 combinations, optimum: 65.0, 300 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 60.6 | 3.3 | 
7% | 113 | +| UCT (+rpol) | 64.5 | 1.4 | 80% | 157 | +| NIG + TS(g,a) + vi + apv | 64.8 | 0.4 | 77% | 123 | +| NIG + TS(g,a) + apess | 64.9 | 0.3 | 87% | 129 | +| NIG + TS(g,a) + vi + apv + an₀ | 64.7 | 0.4 | 73% | 126 | +| NIG + TS(g,a) + apess + an₀ | 64.8 | 0.4 | 83% | 132 | +| NIG + TS(g,a) + acomb + an₀ | **65.0** | 0.0 | **100%** | 166 | +| NIG + TS(g,a) + an₀ | 64.5 | 0.5 | 50% | 90 | +| NIG + uniform + an₀ | 63.9 | 1.4 | 27% | 69 | + +**simple_additive** (~793 combinations, optimum: 65.0, 300 iterations × 30 trials) + +| Config | Mean Best | ±Std | Opt Rate | Unique Evals | +|--------|-----------|------|----------|-------------| +| Random | 57.7 | 3.3 | 0% | 115 | +| UCT (+rpol) | 64.1 | 2.2 | 83% | 187 | +| NIG + TS(g,a) + vi + apv | 64.7 | 0.7 | 83% | 136 | +| NIG + TS(g,a) + apess | **65.0** | 0.0 | **100%** | 148 | +| NIG + TS(g,a) + vi + apv + an₀ | 64.4 | 1.2 | 73% | 139 | +| NIG + TS(g,a) + apess + an₀ | 64.9 | 0.4 | 97% | 148 | +| NIG + TS(g,a) + acomb + an₀ | **65.0** | 0.0 | **100%** | 174 | +| NIG + TS(g,a) + an₀ | 64.2 | 1.5 | 73% | 101 | +| NIG + uniform + an₀ | 63.4 | 1.9 | 53% | 90 | + +#### 11.15.4 Optimum-Finding Rate Heatmap + +![Optimum-Finding Rate Heatmap](optimum_rate_heatmap_nig_adaptive_n0.png) + +#### 11.15.5 Convergence Curves + +All configs on each problem: + +![multigroup_interaction](convergence_nig_adaptive_n0_multigroup_interaction.png) +![needle_in_haystack](convergence_nig_adaptive_n0_needle_in_haystack.png) +![mixed_nchoosek_categorical](convergence_nig_adaptive_n0_mixed_nchoosek_categorical.png) +![large_sparse](convergence_nig_adaptive_n0_large_sparse.png) +![graduated_landscape](convergence_nig_adaptive_n0_graduated_landscape.png) +![simple_additive](convergence_nig_adaptive_n0_simple_additive.png) + +Focused comparison — adaptive n₀ vs fixed n₀ for matched configs: + +![multigroup — n₀ effect](convergence_nig_adaptive_n0_multigroup_interaction_n0_effect.png) +![large_sparse — n₀ 
effect](convergence_nig_adaptive_n0_large_sparse_n0_effect.png) +![graduated — n₀ effect](convergence_nig_adaptive_n0_graduated_landscape_n0_effect.png) + +#### 11.15.6 Analysis + +**Adaptive n₀ is uniformly harmful on the hardest problems.** The head-to-head comparisons show clear regressions: + +| Problem | Fixed n₀ Config | Opt Rate | + an₀ | Opt Rate | Delta | +|---------|----------------|----------|-------|----------|-------| +| multigroup | vi + apv | **80%** | vi + apv + an₀ | 57% | **-23pp** | +| multigroup | apess | 37% | apess + an₀ | 47% | +10pp | +| large_sparse | apess | **53%** | apess + an₀ | 47% | **-6pp** | +| large_sparse | vi + apv | 47% | vi + apv + an₀ | 43% | **-4pp** | +| mixed | vi + apv | **100%** | vi + apv + an₀ | 97% | -3pp | +| needle | apess | 97% | apess + an₀ | **100%** | +3pp | +| graduated | apess | 87% | acomb + an₀ | **100%** | +13pp | +| simple_additive | apess | **100%** | apess + an₀ | 97% | -3pp | + +The largest regression is on **multigroup_interaction** (-23pp for vi+apv), the problem with the strongest cross-group interactions. The second-largest is on **large_sparse** (-6pp for apess), the massive search space where adaptive n₀ was expected to help most. + +**Why it failed — over-correction on top of an already good solution.** The NIG posterior's Student-t distribution already has heavier tails at low observation counts than the Normal distribution. At n=1 with α₀=1, the Student-t has df=3 — much heavier tails than a Normal. This already provides substantial protection against premature commitment from a single observation. Adding n₀ ≈ 3.4 (for 11 actions) on top of the heavy-tailed Student-t makes the posterior excessively conservative: the algorithm requires ~3-4 observations per child before posteriors differentiate, but with 11 children and 800 iterations, many children never reach that threshold. 
+ +**Budget-limited regime.** The fundamental issue is that adaptive n₀ trades convergence speed for robustness to premature commitment, but the benchmarks operate in a budget-limited regime (600-800 iterations). With unlimited budget, higher n₀ would eventually converge to the same answer. Within the available budget, the slower convergence means fewer iterations in the exploitation phase, reducing the chance of finding the optimum. + +**Small-problem exception.** On graduated_landscape (375 combinations, 300 iterations), `acomb + an₀` achieves 100% — the only improvement. Here the search space is small enough that even with conservative posteriors, the algorithm explores exhaustively. But this problem was already solved by other configs (comb+apv, acomb both at 100% in §11.14). + +**The no-cache-hit modes (an₀, uniform + an₀) are weak.** Without cache-hit handling, adaptive n₀ alone achieves only 37% on multigroup and 17% on large_sparse. Adaptive n₀ cannot substitute for proper cache-hit strategies. + +#### 11.15.7 Conclusion + +Adaptive pseudo-count n₀ from branching factor is a **negative result**. The NIG Student-t posterior already provides sufficient protection against premature commitment at low observation counts, and inflating n₀ further only slows convergence. The existing recommendations remain unchanged: + +- **Default**: NIG + TS(g,a) + vi + apv (80% multigroup, 100% needle/mixed, 47% large_sparse) +- **Massive sparse spaces**: NIG + TS(g,a) + apess without APV (53% large_sparse) + +The adaptive n₀ code remains available (`adaptive_n0=True`) but is not recommended for any problem type tested. + +--- + +### 11.16 Sampling vs Optimizing: Empirical Gap Analysis + +This section presents the first empirical investigation toward the two-phase burn-in (§8.5, §11.12.3). 
The central question: **can cheap polytope samples replace expensive `optimize_acqf` calls for the purpose of ranking NChooseK subsets?** If so, MCTS could use cheap samples during burn-in to build tree structure, then run full optimization only on the winner. + +#### 11.16.1 Setup + +Two benchmarks test complementary regimes: + +| Benchmark | Dim | Subsets | Constraints | Continuous optimization | +|-----------|-----|---------|-------------|------------------------| +| `Hartmann(dim=6, allowed_k=4)` | 6 | 57 | NChooseK only (box bounds) | Easy — unconstrained per subset | +| `FormulationWrapper(Hartmann(), max_count=4)` | 7 | 57 | NChooseK + sum-to-1 equality + box | Hard — simplex-constrained per subset | + +For each benchmark: +1. Fit a GP via `SoboStrategy` on 20–30 random initial points +2. Extract the acquisition function (qLogEI) and bounds +3. For every NChooseK subset (57 total), compute: + - **Optimized value**: `optimize_acqf` with 20 restarts, 2048 raw samples (gold standard) + - **Sample-best value**: best of N polytope samples (hit-and-run via `sample_q_batches_from_polytope`), varying N ∈ {64, 128, 256, 512, 1024, 2048} for the unconstrained benchmark and N ∈ {64, 256, 1024, 2048} for the constrained one +4. Compare rankings (Spearman rho, top-k overlap) and absolute gaps across 5 seeds + +The "optimized" baseline represents what the current MCTS reward function computes. The "sample-best" represents what a cheap burn-in evaluation would return. 
+ +#### 11.16.2 Results: Unconstrained (Hartmann, box bounds) + +| n_samples | time | Spearman ρ | top-1 | top-3 | top-5 | mean gap | winner gap | +|-----------|------|-----------|-------|-------|-------|----------|------------| +| 64 | 0.2s | 0.972 | 0% | 40% | 56% | 0.803 | 0.239 | +| 128 | 0.3s | 0.971 | 20% | 53% | 56% | 0.685 | 0.123 | +| 256 | 0.4s | 0.972 | 0% | 47% | 52% | 0.662 | 0.058 | +| 512 | 0.8s | 0.980 | 0% | 47% | 64% | 0.595 | 0.049 | +| 1024 | 1.6s | 0.979 | 0% | 47% | 60% | 0.256 | 0.043 | +| 2048 | 3.2s | 0.981 | 0% | 53% | 72% | 0.239 | 0.024 | + +**Exhaustive optimization time**: ~4.5s (57 subsets × `optimize_acqf`). + +**Timing**: Sobol sampling is simple box-uniform sampling — no constraints. At 64 samples: **23x faster** per subset. At 2048: **1.4x faster** (diminishing returns). The cost is dominated by acqf forward passes, which scale linearly with sample count. + +**Pearson r**: 1.0000 — optimized and sample-best values are perfectly linearly correlated. + +**Gap scales with subset dimensionality**: + +| |subset| | n subsets | mean gap @2048 | median gap | max gap | +|----------|----------|---------------|------------|---------| +| 0 | 1 | 0.000 | 0.000 | 0.000 | +| 1 | 6 | 0.000 | 0.000 | 0.001 | +| 2 | 15 | 0.009 | 0.002 | 0.085 | +| 3 | 20 | 0.026 | 0.009 | 0.190 | +| 4 | 15 | 0.047 | 0.015 | 0.225 | + +More free dimensions = harder to find the optimum by sampling, but the ranking is preserved regardless. 
+ +#### 11.16.3 Results: Constrained (FormulationWrapper, simplex) + +| n_samples | time | Spearman ρ | top-1 | top-3 | top-5 | mean gap | winner gap | +|-----------|------|-----------|-------|-------|-------|----------|------------| +| 64 | 5.3s | 0.970 | 20% | 47% | 60% | 0.502 | 0.166 | +| 256 | 15.4s | 0.978 | 0% | 53% | 72% | 0.118 | 0.104 | +| 1024 | 55.9s | 0.986 | 20% | 53% | 68% | 0.059 | 0.061 | +| 2048 | 109.9s | 0.987 | 0% | 53% | 60% | 0.042 | 0.039 | + +**Exhaustive optimization time**: ~140.9s (57 subsets × `optimize_acqf` with linear constraints). + +**Timing**: Polytope sampling uses hit-and-run (MCMC), which is far more expensive than box-uniform. At 64 samples: **26x faster** (5.3s vs 140.9s). At 2048: **1.3x faster** (109.9s vs 140.9s) — barely any speedup because `optimize_acqf` itself spends most of its time generating constrained initial points via the same hit-and-run sampler. + +**Pearson r**: 1.0000. + +**Gap scales with subset dimensionality**: + +| |subset| | n subsets | mean gap @2048 | median gap | max gap | +|----------|----------|---------------|------------|---------| +| 0 | 1 | 0.000 | 0.000 | 0.000 | +| 1 | 6 | 0.000 | 0.000 | 0.001 | +| 2 | 15 | 0.004 | 0.003 | 0.019 | +| 3 | 20 | 0.036 | 0.004 | 0.224 | +| 4 | 15 | 0.063 | 0.025 | 0.203 | + +#### 11.16.4 Low Top-1 Match Is Misleading + +Across both benchmarks, top-1 match is 0–20% despite ρ > 0.97. This is entirely due to **tied tiers**: the top 4–5 subsets often have nearly identical optimized values (differing by < 0.005). For example, seed 1 of the constrained benchmark has top-4 all at -4.2452. The sample ranking picks a different member of this tied group, but the actual acqf difference is negligible. The sample-based MCTS would converge to the same quality solution. + +#### 11.16.5 Key Findings + +1. **Ranking is near-perfect** (ρ = 0.97–0.99) even at just 64 samples per subset. 
The NIG-TS tree would make essentially the same explore/exploit decisions whether rewards come from optimization or sampling. + +2. **Values have systematic downward bias** — sample-best is always ≤ optimized. But NIG-TS adapts to the reward scale; it only needs correct *relative ordering*, not correct absolute values. + +3. **The bias is order-preserving** (Pearson r = 1.0). This means during burn-in, the NIG posteriors will rank subsets correctly even though the absolute values are pessimistic. After switching to accurate optimization, the accurate values dominate the posterior (as described in §8.5). + +4. **For constrained problems, the speedup ceiling is low at high sample counts**. The bottleneck shifts from L-BFGS to the hit-and-run sampler itself: at 2048 samples, polytope sampling costs 110s vs 141s for full optimization — only 1.3x faster. The implication: **use few samples per subset (64–128) to maximize the cost ratio**. + +5. **64 samples is the practical sweet spot for burn-in**: + - Unconstrained: 23x faster, ρ = 0.97, top-5 overlap = 56% + - Constrained: 26x faster, ρ = 0.97, top-5 overlap = 60% + - The ranking accuracy at 64 is nearly as good as at 2048, while the cost is 20x lower + +#### 11.16.6 Implications for Two-Phase Burn-in + +These results validate the two-phase burn-in design from §8.5: + +**Cost model for MCTS with burn-in** (constrained case, 57 subsets): + +| Approach | Cost per novel eval | Unique evals | Total cost | +|----------|-------------------|--------------|------------| +| Current (all optimize_acqf) | ~2.5s | ~40 | ~100s | +| All sampling (64/subset) | ~0.09s | ~40 | ~3.6s | +| Burn-in (64 samples) → 1× optimize on winner | ~0.09s burn-in + 2.5s final | ~40 + 1 | ~6.1s | + +The burn-in + single optimization approach is **~16x cheaper** than the current all-optimization approach for the constrained Hartmann problem, while identifying the same top subset tier. 
+ +**Important caveat**: These numbers assume the acquisition surface is smooth and unimodal per subset (true for GP-based EI on Hartmann). On rougher surfaces with narrow peaks that only L-BFGS finds, the gap could widen. The empirical test on real BO loops (running full MCTS with sample-based vs optimization-based rewards, comparing final regret) is the necessary next step. + +**Script**: `mcts-report/investigate_sampling_vs_optimizing.py` + +### 11.17 No-Cache Mode: MCTS with Noisy Reward Functions + +This section presents the implementation and benchmark of **no-cache mode** (`use_cache=False`) for MCTS, the infrastructure required for two-phase burn-in (§8.5, §11.16). With noisy/sampling-based reward functions, caching is counterproductive: re-evaluating the same subset with different random draws provides genuinely new information for the NIG posterior. + +#### 11.17.1 Implementation + +Added `use_cache: bool = True` parameter to `MCTS.__init__` and `optimize_acqf_mcts()`. When `use_cache=False`: + +1. **`_cached_reward()`** always calls `reward_fn` fresh — never reads from or writes to the cache +2. **`is_novel`** is always `True` — every evaluation feeds the NIG posterior as a novel observation +3. **Rollout retry** is skipped — the cache is always empty, so retries are pointless +4. **Backpropagation** always takes the novel path (increments `n_obs`, `sum_rewards`, `sum_sq_rewards`, `n_visits`); cache-hit modes (variance_inflation, pessimistic, etc.) are never triggered +5. **NIG statistics** work correctly: repeated evaluations of the same subset with different noisy rewards widen the posterior (higher variance estimate), reflecting genuine uncertainty + +Default is `use_cache=True`, preserving all existing behavior. 
+ +#### 11.17.2 Setup + +Three configurations compared across 4 noise levels (σ ∈ {0.5, 1.0, 2.0, 5.0}): + +| Config | `use_cache` | Noise | Description | +|--------|------------|-------|-------------| +| **Cached + deterministic** | True | None | Current default (reference baseline) | +| **No-cache + noisy** | False | N(0, σ²) − σ | Every eval is fresh + noisy; no caching | +| **Cached + noisy** | True | N(0, σ²) − σ | First noisy eval is cached; subsequent reads return that frozen value | + +The noise model includes a pessimistic bias of −σ, matching the systematic downward bias observed in §11.16 (sample-best values are always ≤ optimized values). + +**Evaluation**: The reported `final_best` is the **true** (noiseless) reward of the best selection found, ensuring fair comparison. 10 seeds per config per problem, using the same 6 benchmark problems as prior sections. + +#### 11.17.3 Results: Optimum-Finding Rate + +**σ = 0.5** (mild noise) + +| Problem | Cached+det | No-cache+noisy | Cached+noisy | +|---------|-----------|----------------|--------------| +| multigroup_interaction | **90%** | 30% | 60% | +| needle_in_haystack | **100%** | 60% | **100%** | +| mixed_nchoosek_categorical | **100%** | **100%** | **100%** | +| large_sparse | **40%** | **40%** | **40%** | +| graduated_landscape | 90% | 20% | **100%** | +| simple_additive | **90%** | 40% | 70% | + +**σ = 1.0** (moderate noise) + +| Problem | Cached+det | No-cache+noisy | Cached+noisy | +|---------|-----------|----------------|--------------| +| multigroup_interaction | **90%** | 60% | **90%** | +| needle_in_haystack | **100%** | 90% | **100%** | +| mixed_nchoosek_categorical | **100%** | 90% | 90% | +| large_sparse | **40%** | 20% | 10% | +| graduated_landscape | **90%** | 0% | 40% | +| simple_additive | **90%** | 30% | 80% | + +**σ = 2.0** (high noise) + +| Problem | Cached+det | No-cache+noisy | Cached+noisy | +|---------|-----------|----------------|--------------| +| multigroup_interaction | 90% | 
80% | **100%** | +| needle_in_haystack | **100%** | 70% | **100%** | +| mixed_nchoosek_categorical | **100%** | **100%** | **100%** | +| large_sparse | 40% | **50%** | 30% | +| graduated_landscape | **90%** | 20% | 50% | +| simple_additive | **90%** | 0% | 60% | + +**σ = 5.0** (extreme noise) + +| Problem | Cached+det | No-cache+noisy | Cached+noisy | +|---------|-----------|----------------|--------------| +| multigroup_interaction | **90%** | 50% | 60% | +| needle_in_haystack | **100%** | 60% | **100%** | +| mixed_nchoosek_categorical | **100%** | **100%** | **100%** | +| large_sparse | 40% | **50%** | **50%** | +| graduated_landscape | **90%** | 0% | 10% | +| simple_additive | **90%** | 10% | 10% | + +#### 11.17.4 Results: Mean True-Best Reward + +**σ = 1.0** + +| Problem | Optimal | Cached+det | No-cache+noisy | Cached+noisy | +|---------|---------|-----------|----------------|--------------| +| multigroup_interaction | 150 | 144.4 | 133.3 | 145.9 | +| needle_in_haystack | 100 | 100.0 | 94.0 | 100.0 | +| mixed_nchoosek_categorical | 150 | 150.0 | 144.0 | 144.0 | +| large_sparse | 200 | 113.6 | 86.8 | 72.2 | +| graduated_landscape | 65 | 64.9 | 63.0 | 64.2 | +| simple_additive | 65 | 64.8 | 62.2 | 64.6 | + +**σ = 5.0** + +| Problem | Optimal | Cached+det | No-cache+noisy | Cached+noisy | +|---------|---------|-----------|----------------|--------------| +| multigroup_interaction | 150 | 144.4 | 128.3 | 132.7 | +| needle_in_haystack | 100 | 100.0 | 72.0 | 100.0 | +| mixed_nchoosek_categorical | 150 | 150.0 | 150.0 | 150.0 | +| large_sparse | 200 | 113.6 | 123.8 | 126.6 | +| graduated_landscape | 65 | 64.9 | 59.8 | 61.6 | +| simple_additive | 65 | 64.8 | 61.1 | 61.1 | + +#### 11.17.5 Convergence Curves + +Per-problem convergence plots for each noise level (note: convergence curves show the noisy `best_value` tracked by MCTS, not the true-best): + +**σ = 0.5**: +![Convergence σ=0.5 multigroup](nocache_convergence_multigroup_interaction_sigma0.5.png) +![Convergence 
σ=0.5 needle](nocache_convergence_needle_in_haystack_sigma0.5.png) +![Convergence σ=0.5 mixed](nocache_convergence_mixed_nchoosek_categorical_sigma0.5.png) +![Convergence σ=0.5 large_sparse](nocache_convergence_large_sparse_sigma0.5.png) +![Convergence σ=0.5 graduated](nocache_convergence_graduated_landscape_sigma0.5.png) +![Convergence σ=0.5 simple](nocache_convergence_simple_additive_sigma0.5.png) + +**σ = 2.0**: +![Convergence σ=2.0 multigroup](nocache_convergence_multigroup_interaction_sigma2.0.png) +![Convergence σ=2.0 needle](nocache_convergence_needle_in_haystack_sigma2.0.png) +![Convergence σ=2.0 mixed](nocache_convergence_mixed_nchoosek_categorical_sigma2.0.png) +![Convergence σ=2.0 large_sparse](nocache_convergence_large_sparse_sigma2.0.png) +![Convergence σ=2.0 graduated](nocache_convergence_graduated_landscape_sigma2.0.png) +![Convergence σ=2.0 simple](nocache_convergence_simple_additive_sigma2.0.png) + +**Cross-σ heatmap**: +![Optimum-finding rate heatmap](nocache_optimum_rate_heatmap.png) + +#### 11.17.6 Analysis + +**1. No-cache mode works correctly.** Every iteration calls `reward_fn` (evals always equals iteration count), the cache stays empty, `_novel_reward_count` equals `n_iterations`, and the NIG posterior handles repeated noisy observations of the same subset. + +**2. The NIG posterior is naturally robust to noise.** Even at σ = 2.0 (noise magnitude comparable to the reward gaps between subsets), no-cache mode finds the optimum on mixed_nchoosek_categorical (100%), multigroup_interaction (80%), and needle_in_haystack (70%). The NIG Student-t posterior absorbs the variance correctly: repeated evaluations of the same subset don't concentrate the posterior falsely, they widen it to reflect the genuine noise. + +**3. 
No-cache mode beats cached+noisy on large search spaces.** The most interesting result is on `large_sparse` (~960M combinations): + +| σ | No-cache+noisy | Cached+noisy | +|---|----------------|--------------| +| 1.0 | 20% / 86.8 | 10% / 72.2 | +| 2.0 | **50% / 127.6** | 30% / 99.6 | +| 5.0 | **50% / 123.8** | **50% / 126.6** | + +At σ = 2.0, no-cache mode achieves 50% optimum rate (mean 127.6) vs cached+noisy at 30% (mean 99.6). The mechanism: cached+noisy locks in one random draw per subset. If the first draw is unlucky (biased low for the optimal subset), that wrong ranking is frozen forever. No-cache mode re-evaluates with fresh noise, so the NIG posterior averages over multiple draws and eventually converges to the correct ranking. This effect is strongest in large search spaces where the tree needs many iterations to converge and the frozen-cache problem compounds. + +**4. Cached+noisy sometimes gets lucky.** At σ = 2.0 on multigroup_interaction, cached+noisy hits 100% while the deterministic baseline achieves 90%. This is a small-sample artifact (10 seeds): occasionally the first noisy draw happens to over-value the optimal subset, which accelerates convergence. This is not reliable behavior. + +**5. Small search spaces favor caching regardless of noise.** On graduated_landscape (375 combinations) and simple_additive (793 combinations), cached+noisy consistently outperforms no-cache+noisy. With so few unique subsets and moderate iteration budgets (300), the exhaustive cache exploration of cached mode is more valuable than the statistical averaging of no-cache mode. The reward gaps between top subsets (1–2 points) are smaller than the noise, so no-cache mode struggles to discriminate. + +**6. The pessimistic bias (−σ) doesn't hurt no-cache mode.** The −σ shift uniformly depresses all rewards, which is absorbed by the NIG prior center (`_global_mean()`). Since NIG-TS only cares about relative ordering, not absolute scale, the bias is harmless. 
This confirms that sampling-based burn-in (which produces systematically pessimistic values per §11.16.5) will work correctly. + +#### 11.17.7 Implications for Two-Phase Burn-in + +These results validate the key infrastructure for two-phase burn-in: + +1. **No-cache mode provides the correct semantics for sampling-based evaluation**: every call to the (cheap) sampling reward function produces a novel observation, and the NIG posterior correctly accumulates noisy statistics. The implementation is a single boolean flag (`use_cache=False`) with no other changes needed. + +2. **The noise regime matters for choosing the switch point**. At low noise (σ ≤ 1.0 relative to reward scale), no-cache mode degrades modestly. At moderate noise (σ = 2.0), it actually outperforms cached+noisy on the hardest problems. The §11.16 gap analysis showed sampling-based rewards have σ ≈ 0.04–0.80 (depending on subset dimensionality), well within the regime where no-cache mode performs acceptably. + +3. **The two-phase architecture should be**: + - **Phase 1 (burn-in)**: `use_cache=False` with a cheap sampling-based reward function (64 polytope samples per subset, §11.16). This builds tree structure and NIG statistics at ~26x lower cost per evaluation. + - **Phase 2 (refinement)**: Switch to `use_cache=True` with the expensive `optimize_acqf` reward function. The tree structure from phase 1 guides exploration. The accurate values dominate the NIG posterior. + +4. **The switch from phase 1 to phase 2 requires resetting the cache** (it would be empty anyway with `use_cache=False`), but the NIG node statistics (`n_obs`, `sum_rewards`, `sum_sq_rewards`) should be preserved — they represent genuine information about relative subset quality that transfers to the accurate reward scale. The NIG posterior will naturally re-center via the updated `_global_mean()` as accurate evaluations arrive. 
+ +**Next step**: Implement the two-phase `reward_fn` wrapper that switches from sampling to optimization after N burn-in iterations, and benchmark it on the real BO loop (Hartmann with GP + qLogEI). + +#### 11.17.8 Files + +- **Implementation**: `bofire/strategies/predictives/optimize_mcts.py` — `use_cache` parameter in `MCTS` and `optimize_acqf_mcts()` +- **Tests**: `tests/bofire/strategies/test_optimize_mcts.py` — 5 tests for no-cache mode +- **Benchmark script**: `mcts-report/benchmark_no_cache.py` +- **Results**: `mcts-report/results_no_cache.json` +- **Plots**: `mcts-report/nocache_convergence_*.png`, `mcts-report/nocache_optimum_rate_heatmap.png`, `mcts-report/nocache_truebest_*.png` + +### 11.18 MCTS Replay on Real Acquisition Landscapes + +Sections §11.16–§11.17 benchmarked MCTS on synthetic reward functions with known structure. This section evaluates the **production MCTS-TS-NIG** on real acquisition function landscapes from a full Bayesian optimization loop on `Hartmann(dim=6, allowed_k=4)`. + +The key question: **how many MCTS iterations does it take to find the best NChooseK subset when the reward function is a real GP-based acquisition function?** + +#### 11.18.1 Data Collection + +A full BO loop was run for 40 iterations (starting from 20 random initial points), with **exhaustive evaluation** of all 57 NChooseK subsets at each iteration: + +| Parameter | Value | +|-----------|-------| +| Benchmark | Hartmann(dim=6, allowed_k=4) | +| Surrogate | SingleTaskGPSurrogate (bofire default) | +| Acquisition | qLogEI | +| Initial points | 20 (NChooseK-respecting, via RandomStrategy) | +| BO iterations | 40 | +| Subsets | 57 (all C(6,0)+C(6,1)+...+C(6,4)) | +| optimize_acqf per subset | 20 restarts, 2048 raw samples | +| Sobol samples per subset | 2048 | +| GP trajectory | Greedy oracle (always picks globally best subset) | + +At each BO iteration, the exhaustive `optimize_acqf` values provide the gold-standard reward for each of the 57 subsets. 
The GP trajectory is fixed by always selecting the globally best subset, so the dataset can be replayed with any MCTS configuration without affecting the GP's evolution. + +**Data collection script**: `mcts-report/collect_hartmann_data.py` +**Data**: `mcts-report/data/hartmann_nchoosek_seed0.{json,npz}` + +#### 11.18.2 Benchmark Setup + +For each of the 40 BO iterations, the production MCTS-TS-NIG (all defaults: `adaptive_prior_var=True`, `cache_hit_mode="variance_inflation"`, `rollout_mode="ts_group_action"`, `pw_k0=2.0`, `pw_alpha=0.6`) was run with a **lookup reward function** that returns the pre-computed `optimize_acqf` value for each subset. This isolates the tree search quality from the `optimize_acqf` runtime cost. + +- **MCTS budget**: 200 iterations per run +- **MCTS seeds**: 10 per BO iteration (to average over MCTS randomness) +- **Total runs**: 40 BO iterations × 10 seeds = 400 + +Metrics tracked per run: +- **First-hit**: MCTS iteration when the true best subset is first found +- **Top-k hit**: iteration when MCTS first holds a top-3 or top-5 subset +- **Regret curve**: `(true_best_value − MCTS_best_value)` at each iteration +- **Unique evaluations**: number of distinct subsets evaluated by each budget + +#### 11.18.3 Results: First-Hit Statistics + +| Target | Found rate | Median iters | Mean iters | P25 | P75 | P90 | +|--------|-----------|-------------|-----------|-----|-----|-----| +| True best (rank 1) | 295/400 (73.8%) | 56 | 68.7 | 22 | 111 | 150 | +| Top-3 subset | 384/400 (96.0%) | 24 | 38.7 | — | — | — | +| Top-5 subset | 396/400 (99.0%) | 14 | 23.3 | — | — | — | + +MCTS reliably identifies a **top-5 subset within ~14 iterations** (99% success). Finding the exact best takes longer (median 56, 74% success at budget 200), but in practice the top-5 subsets have very similar acquisition values — the regret from picking rank 3 instead of rank 1 is typically negligible. 
+ +#### 11.18.4 Results: Regret at Iteration Budgets + +| Budget | Mean regret | Median | P90 | Zero regret % | Unique subsets | +|--------|-----------|--------|-----|--------------|----------------| +| 10 | 14.96 | 4.24 | 40.04 | 8.5% | 9.1 | +| 20 | 10.26 | 1.23 | 38.24 | 17.0% | 14.7 | +| 30 | 7.39 | 0.52 | 35.77 | 25.5% | 18.4 | +| 50 | 5.18 | 0.18 | 33.97 | 34.0% | 23.4 | +| 75 | 4.00 | 0.00 | 5.30 | 44.8% | 27.3 | +| 100 | 3.01 | 0.00 | 4.45 | 53.0% | 30.1 | +| 150 | 1.63 | 0.00 | 4.23 | 67.2% | 34.0 | +| 200 | 1.01 | 0.00 | 1.23 | 74.5% | 36.7 | + +The **median regret hits zero by iteration 75**, and by 200 iterations 74.5% of runs have zero regret. The mean regret remains elevated due to a heavy tail: a minority of BO iterations have peaky acquisition landscapes where a single dominant subset carries most of the acquisition value, and MCTS sometimes fails to discover it within 200 iterations. + +At budget 200, only **36.7 of 57 subsets** (64%) are evaluated on average. MCTS concentrates on the promising region of the combinatorial tree rather than exhaustively enumerating. + +#### 11.18.5 Results: Cumulative Discovery Rate + +| By MCTS iter | True best found | +|-------------|----------------| +| 5 | 1.8% | +| 10 | 8.5% | +| 20 | 17.0% | +| 30 | 25.2% | +| 50 | 33.8% | +| 75 | 44.5% | +| 100 | 52.8% | +| 150 | 66.5% | +| 200 | 73.8% | + +The curve is roughly linear on a log scale — no sharp elbow. This is consistent with the Thompson Sampling exploration mechanism: it continuously balances exploration and exploitation without hard phase transitions. + +![Regret convergence](replay_regret_convergence.png) + +![Cumulative discovery rate](replay_cumulative_found_rate.png) + +![First-hit histogram](replay_first_hit_histogram.png) + +#### 11.18.6 Results: Effect of BO Phase + +The acquisition landscape changes character as the GP accumulates data. 
Early iterations have a flatter, noisier landscape (prior-dominated GP), while late iterations have sharper peaks around the known optimum. + +![Phase regret comparison](replay_phase_regret.png) + +| BO phase | Landscape | Found rate | Median first-hit | +|----------|-----------|-----------|-----------------| +| Early (0–9) | Flat landscape | 76% | 24 | +| Mid (10–19) | Sharpening | 69% | 56 | +| Late (20–39) | Peaked | 75% | 66 | + +**Early BO iterations are easiest for MCTS** — the acquisition landscape is relatively flat (GP has little data), so many subsets have similar values and MCTS quickly finds a good one. The median first-hit of 24 iterations is notably faster than mid/late. + +**Mid BO iterations are hardest** — the GP is confident enough to create sharp reward differences between subsets (making regret large if the wrong one is picked) but the landscape hasn't yet concentrated to a single dominant subset. This is the regime where MCTS must actually discriminate between similar candidates. + +**Late BO iterations recover slightly** — by this point the GP strongly favors one subset (often `[2,3,4,5]` for Hartmann6), and the acquisition landscape has a clear peak. However, the absolute regret when MCTS picks the wrong subset is largest in this phase (visible in the wider P75 band in the phase plot), because the dominant subset's acquisition value towers over the rest. + +#### 11.18.7 Analysis + +**1. MCTS-TS-NIG is effective for this problem scale.** With 57 subsets, 200 MCTS iterations achieves 74% exact-best and 99% top-5. The current production default of `num_iterations=100` (from `optimize_acqf_mcts`) gives 53% exact-best and 96% top-3, which is reasonable for a problem where the top-3 subsets are typically near-tied. + +**2. The mean-median gap reveals a heavy tail.** Mean regret at 200 iterations (1.01) is much higher than median (0.00). Approximately 25% of runs never find the exact best within 200 iterations. 
Inspecting these failure cases: they occur on BO iterations where the best subset's acquisition value is separated from the runner-up by a large gap (>10 units), and MCTS commits to an early local optimum. The NIG variance inflation mechanism (which decays statistics on cache hits) helps but doesn't fully solve this — the tree gets too deep in the wrong branch before enough exploration happens. + +**3. Exploration efficiency: 37/57 subsets in 200 iterations.** MCTS evaluates ~64% of the space, confirming it's a focused search rather than enumeration. For the Hartmann(6, k≤4) problem with only 57 subsets, exhaustive enumeration is actually feasible (and what the data collection script does). MCTS becomes essential for larger problems where enumeration is intractable — e.g., the `large_sparse` benchmark with ~960M combinations. + +**4. The budget of 75 iterations is a practical sweet spot.** Median regret reaches zero, the P90 drops from ~34 to ~5 (a 7x reduction vs budget 50), and 27 subsets are explored (47% of the space). Beyond 75, returns diminish significantly. + +**5. Comparison with the synthetic benchmarks (§11.13).** The NIG-TS benchmark on synthetic problems (§11.13) showed ~90% optimum-finding rates at 300 iterations on `multigroup_interaction` (375 combinations) and `needle_in_haystack` (57 combinations). The 74% rate at 200 iterations on real acquisition landscapes is lower, suggesting that real GP-based acquisition functions produce harder reward surfaces than the smooth synthetic rewards. This motivates increasing the default iteration budget or adding a cheap burn-in phase. + +#### 11.18.8 Implications + +1. **For Hartmann-scale NChooseK problems (57 subsets)**, the current MCTS defaults work well. A budget of 100 iterations is a reasonable cost-accuracy tradeoff. + +2. 
**The top-5 convergence speed (14 iterations median) validates the burn-in concept.** Even a very short MCTS run with cheap rewards could identify the promising subset neighborhood, then a targeted `optimize_acqf` run on those 5 subsets would recover the exact best at low cost (5 × `optimize_acqf` instead of 57). + +3. **The heavy-tail failure mode suggests room for improvement.** The 25% of runs that don't find the best within 200 iterations are caused by early over-commitment. Possible mitigations: more aggressive progressive widening (`pw_k0 > 2`), explicit re-exploration triggers, or a warm-start from Sobol-based subset ranking (§11.16). + +#### 11.18.9 Files + +- **Data collection**: `mcts-report/collect_hartmann_data.py` +- **Benchmark script**: `mcts-report/benchmark_mcts_on_data.py` +- **Plot script**: `mcts-report/plot_mcts_replay.py` +- **Data**: `mcts-report/data/hartmann_nchoosek_seed0.{json,npz}` +- **Plots**: `mcts-report/replay_regret_convergence.png`, `mcts-report/replay_cumulative_found_rate.png`, `mcts-report/replay_first_hit_histogram.png`, `mcts-report/replay_phase_regret.png` + +### 11.19 MCTS Replay on Spurious Features Benchmark (Hartmann6 + 2sp, k≤6) + +Section §11.18 tested MCTS on Hartmann(6, k≤4) with 57 subsets. This section scales up to a harder, more realistic problem: **Hartmann(6) with 2 spurious features and max_count=6**, using `SpuriousFeaturesWrapper`. This creates 8 total features and 247 NChooseK subsets — a 4.3× increase in search space. The MCTS must now discover that 2 of 8 features are irrelevant, making the combinatorial structure harder. 
+ +#### 11.19.1 Data Collection + +| Parameter | Value | +|-----------|-------| +| Benchmark | SpuriousFeaturesWrapper(Hartmann(dim=6), n_spurious=2, max_count=6) | +| Total features | 8 (6 original + 2 spurious) | +| Subsets | 247 (all C(8,0)+C(8,1)+...+C(8,6)) | +| Initial points | 10 (NChooseK-respecting, via RandomStrategy) | +| BO iterations | 40 | +| optimize_acqf per subset | 20 restarts, 2048 raw samples | +| Sobol samples per subset | 2048 | +| GP trajectory | Greedy oracle | + +**Sobol-vs-optimize_acqf correlation**: The Spearman rank correlation between the Sobol-best ranking and the optimize_acqf ranking is **ρ = 0.988** (mean across 40 iterations, min 0.942, max 0.999). This confirms that cheap Sobol sampling is a highly reliable proxy for expensive optimize_acqf, validating the burn-in concept for this problem scale. + +**Data**: `mcts-report/data/hartmann6_sp2_k6_seed0.{json,npz}` (60.5 MB with Sobol) + +#### 11.19.2 MCTS Benchmark Results + +Same setup as §11.18.2: production MCTS-TS-NIG defaults, 200 iteration budget, 5 MCTS seeds per BO iteration, 200 total runs. 
+ +**First-hit statistics:** + +| Target | Found rate | Median iters | Mean iters | P25 | P75 | P90 | +|--------|-----------|-------------|-----------|-----|-----|-----| +| True best (rank 1) | 35/200 (17.5%) | 98 | 95.7 | 58 | 127 | 174 | +| Top-3 subset | 67/200 (33.5%) | 86 | 93.2 | — | — | — | +| Top-5 subset | 96/200 (48.0%) | 76 | 85.4 | — | — | — | + +**Regret at iteration budgets:** + +| Budget | Mean regret | Median | P90 | Zero regret % | Unique subsets | +|--------|-----------|--------|-----|--------------|----------------| +| 10 | 18.92 | 6.78 | 37.90 | 0.5% | 9.5 | +| 20 | 15.54 | 4.22 | 37.47 | 1.0% | 18.0 | +| 50 | 8.60 | 1.20 | 36.13 | 4.0% | 34.8 | +| 100 | 3.64 | 0.26 | 4.75 | 9.0% | 49.7 | +| 150 | 2.07 | 0.05 | 3.75 | 14.0% | 58.5 | +| 200 | 1.34 | 0.01 | 2.55 | 17.5% | 64.9 | + +**Comparison with Hartmann6_k4 (§11.18):** + +| Metric | Hartmann6_k4 (57 subsets) | Hartmann6+2sp (247 subsets) | +|--------|--------------------------|-------------------------------| +| Exact-best found rate | 73.8% | 17.5% | +| Top-5 found rate | 99.0% | 48.0% | +| Median first-hit (exact) | 56 | 98 | +| Unique subsets at 200 | 37 (64%) | 65 (26%) | + +The 4.3× increase in search space (57 → 247) dramatically reduces exact-best discovery. However, this metric is misleading — the acquisition landscape structure explains why. + +![Regret convergence (sp2)](replay_hartmann6_sp2_k6_regret_convergence.png) + +![Cumulative discovery rate (sp2)](replay_hartmann6_sp2_k6_cumulative_found_rate.png) + +#### 11.19.3 Acquisition Landscape Analysis: Bimodal Structure + +The acquisition function values across 247 subsets reveal a **strongly bimodal distribution**. 
At each BO iteration, subsets split into two clusters separated by a large gap (~15-20 acqf units): + +- **Good cluster**: subsets containing the relevant Hartmann features (typically 11–215 subsets depending on BO iteration, with acqf values in the range [-12, -3]) +- **Bad cluster**: subsets dominated by spurious features or with too few active features (acqf values in the range [-44, -37]) + +The gap between the worst good-cluster subset and the best bad-cluster subset is typically **15–20 acqf units**, making the two clusters clearly distinguishable. However, **within the good cluster, the top subsets are nearly indistinguishable**: + +| Gap metric | Hartmann6_k4 (57 subsets) | Hartmann6+2sp (247 subsets) | +|-----------|--------------------------|-------------------------------| +| Median gap(1st-2nd) | 0.81 | **0.00** | +| Median gap(1st-5th) | 4.23 | **0.02** | +| Subsets within 1% of best | few | **~25 subsets (10%)** | +| Subsets within 5% of best | few | **~50 subsets (20%)** | + +The top-10 subsets have a median regret of only **0.03** from the best. The flat top makes "find the exact best" an ill-defined objective — the acqf landscape does not meaningfully distinguish between the top ~10 subsets. + +![Acqf value distribution across subsets](acqf_distribution_hartmann6_sp2_k6.png) + +![Top-20 subsets by acqf value](acqf_top20_hartmann6_sp2_k6.png) + +![Regret by subset rank](acqf_regret_by_rank_hartmann6_sp2_k6.png) + +#### 11.19.4 Good-Cluster Discovery: The Right Metric + +Given the bimodal structure, the natural question is: **does MCTS reliably find the good cluster?** We define the good cluster per BO iteration using the largest-gap heuristic (find the biggest jump between consecutive sorted acqf values, split there). 
+ +| Metric | Rate | +|--------|------| +| Good cluster at iter 200 | **96.5%** (386/400) | +| Good cluster at iter 100 | 91.5% | +| Good cluster at iter 50 | 79.0% | +| Good cluster at iter 10 | 52.8% | +| Median first-hit to good cluster | **8 iterations** | +| Mean first-hit to good cluster | 25.3 | +| P25/P75 first-hit | 3 / 31 | + +MCTS reaches the good cluster in a **median of 8 iterations** — far faster than the 98-iteration median for exact-best. At a budget of 200, 96.5% of runs are in the good cluster. + +**Conditional regret**: Runs that land in the good cluster have median regret 0.01 and mean regret 0.72. Runs stuck in the bad cluster (3.5% of cases) have median regret ~32 — they are trapped in the wrong mode and never escape within 200 iterations. + +![Good cluster vs exact best discovery rate](good_cluster_vs_exact_hartmann6_sp2_k6.png) + +#### 11.19.5 Good-Cluster Size Drives Difficulty + +The good-cluster size varies from 11 to 215 subsets across BO iterations, and this strongly predicts MCTS difficulty: + +- **Large good cluster (>100 subsets)**: MCTS finds it in **2-5 iterations** with 100% success. Early BO iterations (prior-dominated GP) and iterations where many feature combinations perform similarly fall here. +- **Small good cluster (11-26 subsets)**: Takes **30-80 median iterations** and occasionally fails (60-90% success). Later BO iterations with sharper, more concentrated landscapes produce these. + +![Good-cluster discovery vs cluster size](good_cluster_scatter_hartmann6_sp2_k6.png) + +#### 11.19.6 Regret Convergence: Good Cluster vs Bad + +The regret curves split cleanly into two trajectories. The 96.5% of runs that find the good cluster converge steadily toward zero regret. The 3.5% that remain in the bad cluster plateau at regret ~32-35 and show no convergence — once MCTS commits to the wrong mode, the NIG posterior accumulates enough negative evidence to trap it. 
+ +![Regret: good cluster vs bad cluster](regret_good_vs_bad_hartmann6_sp2_k6.png) + +#### 11.19.7 Analysis + +**1. The "exact best" metric is misleading for flat-top landscapes.** The 17.5% exact-best rate appears alarming but reflects an ill-posed question: asking MCTS to distinguish between subsets differing by <0.01 in a range of 40. The 96.5% good-cluster rate is the operationally relevant metric. + +**2. MCTS solves the structural problem reliably.** The bimodal structure — good subsets (relevant features) vs bad subsets (spurious features) — is the real combinatorial challenge. MCTS identifies the good cluster in a median of 8 iterations, well within any practical budget. + +**3. Within-cluster ranking is noise-dominated.** Once in the good cluster, MCTS's final pick among near-tied subsets is essentially random. This is not a problem in practice: all good-cluster subsets lead to similar BO candidates, and the GP will correct for any suboptimality in subsequent iterations. + +**4. The 3.5% failure rate represents genuine MCTS failures.** These runs get trapped in the bad cluster (spurious-feature-dominated subsets). The NIG mechanism makes escaping the wrong mode difficult once evidence accumulates. This motivates potential improvements: explicit mode-switching restarts, or a two-phase strategy where an initial cheap exploration identifies the modes before committing. + +**5. The Sobol correlation validates burn-in.** With ρ = 0.988, a cheap Sobol-based burn-in could pre-filter to the good cluster before running expensive MCTS. Even 64 Sobol samples per subset (vs 2048 in the dataset) might suffice to identify the bimodal split, given the ~20-unit gap between clusters. + +#### 11.19.8 Implications + +1. **For practical BO with spurious features**: MCTS with 50-100 iterations is sufficient — it reaches 79-91.5% good-cluster rate, and any good-cluster subset produces a near-optimal acquisition candidate. The exact-best rate is irrelevant. + +2. 
**Budget recommendations scale with search space**: Hartmann6_k4 (57 subsets) needed ~75 iterations for median-zero regret. Hartmann6+2sp (247 subsets) needs ~100 iterations for 91.5% good-cluster rate. The scaling is sub-linear in the number of subsets. + +3. **Bimodal landscape structure is a gift**: The clear separation between good and bad subsets makes the problem fundamentally easier than it appears from the raw subset count. MCTS with Thompson Sampling naturally exploits this — it quickly abandons branches that return low rewards, effectively pruning the bad cluster. + +4. **The flat-top phenomenon suggests diminishing returns from finer search**: Beyond finding the good cluster, additional MCTS iterations mostly shuffle between near-equivalent subsets. This is wasted compute. A better strategy: run MCTS briefly (50-100 iters) to identify the good cluster, then run `optimize_acqf` on the top-5 subsets found. + +#### 11.19.9 Files + +- **Data collection**: `mcts-report/collect_hartmann_data.py --benchmark hartmann6_sp2_k6` +- **Benchmark script**: `mcts-report/benchmark_mcts_on_data.py --benchmark hartmann6_sp2_k6` +- **Plot script**: `mcts-report/plot_mcts_replay.py --benchmark hartmann6_sp2_k6` +- **Data**: `mcts-report/data/hartmann6_sp2_k6_seed0.{json,npz}` +- **Acqf analysis plots**: `acqf_distribution_hartmann6_sp2_k6.png`, `acqf_top20_hartmann6_sp2_k6.png`, `acqf_regret_by_rank_hartmann6_sp2_k6.png` +- **Good-cluster plots**: `good_cluster_vs_exact_hartmann6_sp2_k6.png`, `good_cluster_scatter_hartmann6_sp2_k6.png`, `regret_good_vs_bad_hartmann6_sp2_k6.png` +- **MCTS replay plots**: `replay_hartmann6_sp2_k6_regret_convergence.png`, `replay_hartmann6_sp2_k6_cumulative_found_rate.png`, `replay_hartmann6_sp2_k6_first_hit_histogram.png`, `replay_hartmann6_sp2_k6_phase_regret.png` + +### 11.20 Sobol Sampling as a Cheap Proxy for optimize_acqf + +Sections §11.18–§11.19 used expensive `optimize_acqf` (20 restarts × 2048 raw samples per subset) as the 
gold-standard reward. A key question for burn-in design: **can cheap Sobol sampling reliably rank subsets?** If so, MCTS could use Sobol-evaluated rewards during an initial exploration phase, switching to `optimize_acqf` only for the final top-k subsets.
+
+Both datasets include 2048 Sobol samples per subset per BO iteration. By subsampling to smaller budgets (16, 32, 64, ..., 2048), we measure how Sobol ranking quality degrades with fewer samples.
+
+#### 11.20.1 Rank Correlation vs Sobol Budget
+
+The Spearman rank correlation between `max(Sobol samples)` and `optimize_acqf` value across all subsets:
+
+| Sobol N | Hartmann6_k4 Mean ρ | Hartmann6_k4 Min ρ | Hartmann6+2sp Mean ρ | Hartmann6+2sp Min ρ |
+|---------|---------------------|--------------------|----------------------|---------------------|
+| 16 | 0.911 | 0.628 | 0.920 | 0.733 |
+| 32 | 0.935 | 0.646 | 0.944 | 0.784 |
+| **64** | **0.953** | **0.660** | **0.959** | **0.847** |
+| 128 | 0.969 | 0.820 | 0.969 | 0.894 |
+| 256 | 0.979 | 0.874 | 0.976 | 0.924 |
+| 512 | 0.986 | 0.909 | 0.983 | 0.938 |
+| 2048 | 0.991 | 0.949 | 0.988 | 0.942 |
+
+(Hartmann6_k4: 57 subsets; Hartmann6+2sp: 247 subsets.)
+
+**At 64 Sobol samples**, the mean correlation is already ρ ≈ 0.95 for both benchmarks, crossing the practical threshold for reliable ranking. Doubling to 128 eliminates the worst-case outliers (min ρ jumps from 0.66 to 0.82 for Hartmann6_k4). The spurious-features benchmark actually has *higher* minimum correlations at low sample counts — the bimodal structure (large gap between clusters) is easy to detect even with few samples.
+
+![Sobol correlation vs budget](sobol_correlation_vs_budget.png)
+
+#### 11.20.2 Scatter: Sobol(64) vs optimize_acqf
+
+Scatter plots of Sobol-best(64) vs optimize_acqf for representative BO iterations (early, mid, late) on both benchmarks. Red stars mark the top-5 subsets by optimize_acqf. 
+ +For Hartmann6_k4 (top row), the correlation is strong but the top-5 subsets are sometimes re-ordered by Sobol noise — the fine-grained ranking within the top cluster is unreliable at 64 samples. For Hartmann6+2sp (bottom row), the bimodal separation is strikingly clear: the two clusters are visually distinct in both the x-axis (optimize_acqf) and y-axis (Sobol), with no overlap. + +![Sobol(64) scatter](sobol64_scatter_both.png) + +#### 11.20.3 Bimodal Cluster Identification with Sobol(64) + +For the Hartmann6+2sp benchmark, we test whether Sobol(64) can identify the bimodal cluster structure. At each BO iteration, we apply the largest-gap heuristic independently to both the optimize_acqf ranking and the Sobol(64) ranking, then measure classification accuracy (does Sobol assign each subset to the correct cluster?). + +| Metric | Value | +|--------|-------| +| Mean cluster accuracy | **96.5%** | +| Min cluster accuracy | 89.1% | +| Max cluster accuracy | 100.0% | +| False positives (bad→good) | **0 across all 40 iterations** | +| False negatives (good→bad) | 0–27 per iteration | + +The false-positive rate is **exactly zero** — Sobol(64) never promotes a bad-cluster subset into the good cluster. All errors are false negatives: some good-cluster subsets are misclassified as bad. This is a safe failure mode for burn-in: the cheap Sobol pass might miss some good subsets, but it will never waste expensive `optimize_acqf` budget on bad ones. + +The errors concentrate in late BO iterations where the good cluster shrinks to 11–26 subsets (out of 247). In these cases, the good cluster's internal value range narrows, making the boundary between good and bad harder to detect with only 64 samples. Even so, the minimum accuracy is 89.1%. 
+
+![Sobol(64) bimodal cluster identification](sobol64_bimodal_clusters.png)
+
+#### 11.20.4 Top-k Agreement
+
+Beyond rank correlation, we measure a stricter metric: **what fraction of the optimize_acqf top-k subsets appear in the Sobol top-k?** This directly measures whether Sobol can identify the best subsets for targeted `optimize_acqf` follow-up.
+
+**Hartmann6_k4 (57 subsets):**
+
+| Sobol N | Top-1 | Top-3 | Top-5 | Top-10 | Top-20 |
+|---------|-------|-------|-------|--------|--------|
+| 16 | 33% | 42% | 50% | 71% | 78% |
+| 64 | 40% | 48% | 71% | 81% | 88% |
+| 256 | 58% | 63% | 77% | 88% | 91% |
+| 2048 | 68% | 77% | 87% | 91% | 95% |
+
+**Hartmann6+2sp (247 subsets):**
+
+| Sobol N | Top-1 | Top-3 | Top-5 | Top-10 | Top-20 |
+|---------|-------|-------|-------|--------|--------|
+| 16 | 5% | 15% | 21% | 35% | 53% |
+| 64 | 10% | 20% | 31% | 45% | 61% |
+| 256 | 10% | 27% | 30% | 56% | 75% |
+| 2048 | 11% | 27% | 38% | 65% | 81% |
+
+For Hartmann6_k4, Sobol(64) identifies 71% of the true top-5 — good enough for a burn-in that narrows the search to ~10 subsets before running expensive optimization. For Hartmann6+2sp, the top-k overlap is lower because the flat top makes exact top-k identification nearly random within the good cluster. However, **the top-20 overlap at 64 samples is 61%** — Sobol reliably identifies the good-cluster neighborhood even if it can't rank within it.
+
+The low top-1 agreement on Hartmann6+2sp (~10% even at 2048 samples) confirms the finding from §11.19.3: the "best" subset is not meaningfully distinct from its neighbors. Top-1 is the wrong metric for this landscape.
+
+![Top-k agreement](sobol_topk_agreement.png)
+
+#### 11.20.5 Implications for Burn-In Design
+
+1. **64 Sobol samples per subset is the practical sweet spot.** It achieves ρ ≈ 0.95 mean correlation, 96.5% cluster identification accuracy, and strong top-k overlap (71% of the top-5 on Hartmann6_k4; 61% of the top-20 on Hartmann6+2sp). 
Each evaluation is ~30× cheaper than `optimize_acqf` (no multi-start optimization), making a full 247-subset burn-in feasible in seconds. + +2. **The zero false-positive property is critical.** A burn-in that filters to the Sobol-identified good cluster will never waste `optimize_acqf` budget on spurious-feature subsets. The only risk is missing some good subsets (false negatives), which is mitigated by MCTS exploration in the second phase. + +3. **Proposed two-phase strategy:** + - Phase 1 (burn-in): Evaluate all subsets with 64 Sobol samples. Identify the good cluster via largest-gap. Cost: 247 × 64 = 15,808 acqf evaluations (batched, ~1 second). + - Phase 2 (MCTS): Run MCTS with `optimize_acqf` rewards, but only on the good-cluster subsets (typically 26–215 subsets). This effectively halves the search space on hard iterations and eliminates the 3.5% bad-cluster trapping failure mode. + +4. **For Hartmann6_k4 (57 subsets), burn-in is less critical.** The search space is small enough that MCTS with 75-100 iterations achieves median-zero regret without any pre-filtering. Burn-in becomes valuable at 200+ subsets. + +#### 11.20.6 Files + +- **Sobol correlation analysis plots**: `sobol_correlation_vs_budget.png`, `sobol64_scatter_both.png`, `sobol64_bimodal_clusters.png`, `sobol_topk_agreement.png` +- **Source data**: `mcts-report/data/hartmann_nchoosek_seed0.npz` (Sobol: 40×57×2048), `mcts-report/data/hartmann6_sp2_k6_seed0.npz` (Sobol: 40×247×2048) + +### 11.21 Deterministic Reward + Cache Degradation and Acqf Landscape Collapse + +This section documents two failure modes discovered when integrating MCTS with real GP-based acquisition functions inside a BO loop: (1) a tree-mechanics bug where deterministic rewards + caching breaks progressive widening and NIG posteriors, and (2) a fundamental signal collapse where the acqf landscape becomes too spiky for cheap evaluation as the GP converges. 
+ +#### 11.21.1 Setup + +The investigation uses SpuriousFeaturesWrapper(Hartmann(dim=6), n_spurious_features=6, max_count=6) — 12 features, C(12,6) = 924 possible NChooseK selections. A full BO trajectory (80 experiments) was saved to `experiments.csv` and used to reconstruct the GP state for reproducible diagnostics. + +Two reward functions from the notebook were tested, matching the production MCTS reward evaluation patterns: + +- **reward_fn (Sobol, noisy):** Draw 64 Sobol quasi-random samples within the active feature bounds, evaluate the acqf in batch, return the max. Each call returns a different value for the same selection due to Sobol scrambling. +- **reward_fn2 (optimize_acqf, deterministic):** Run `optimize_acqf(num_restarts=1, raw_samples=64)` with inactive features fixed to zero. Returns a near-deterministic value for each selection. + +Two surrogates were compared: + +- **Standard GP:** Default SingleTaskGP from BoTorch. +- **SAAS GP:** `EnsembleMapSaasSingleTaskGPSurrogate` with sparsity-inducing priors that shrink irrelevant feature lengthscales toward infinity. + +#### 11.21.2 Bug: Deterministic Reward + Cache Breaks Tree Statistics + +**Root cause.** When `use_cache=True` with a deterministic reward function, each unique feature selection gets exactly one novel evaluation. Subsequent MCTS visits to the same terminal are cache hits that increment `n_visits` but not `n_obs`. This creates two compounding failures: + +1. **Variance inflation is a no-op.** The `variance_inflation` cache-hit mode (line 855 of `optimize_mcts.py`) has a guard `if n.n_obs > 1` — with deterministic rewards and caching, 80-85% of tree nodes have `n_obs ≤ 1`, making variance inflation unable to fire on the vast majority of nodes. The NIG posteriors stay near-prior (posterior scale ratio ~0.6-0.8 vs ~0.03-0.1 for healthy trees). + +2. **Progressive widening over-expands the tree.** `_child_limit` (line 586) uses `n_visits` to decide how many children a node may have. 
With `n_visits` inflated by cache hits while `n_obs` stays at 1, the tree fans out far beyond what the available information supports. At the root node: `limit(n_visits)=126` vs `limit(n_obs)=2`. + +**Diagnostic evidence** (1000 MCTS iterations, uniform_subset rollout): + +| Config | Cache hit rate | Nodes with n_obs ≤ 1 | Gold mean | Gold best | +|--------|---------------|---------------------|-----------|-----------| +| A: Sobol, no cache (baseline) | 0% | 33/187 | -6.11 | -5.90 | +| B: optimize_acqf, cache + var_inflation | 79% | **216/260 (83%)** | **-25.03** | -6.06 | +| C: optimize_acqf, no cache | 0% | 31/207 | -6.04 | -5.86 | +| D: optimize_acqf, cache + pessimistic | 41% | 31/398 | -6.13 | **-4.04** | + +"Gold mean/best" = quality of 10 selections sampled from the fitted tree, each evaluated with `optimize_acqf(num_restarts=20, raw_samples=2048)`. Config B's MCTS best value during training was -3.59 (it *found* a good selection), but the tree cannot *reproduce* it because the NIG posteriors never learned which branches were good. + +At 3000 iterations the degradation worsens: config B's cache hit rate climbs to 93%, gold mean drops to -44.79, and the tree samples empty selections and single-feature selections. Config D (pessimistic) remains the healthiest — it still produces diverse selections and found the best individual selection (-2.80). + +The rollout mode (uniform_subset vs ts_group_action) does not change the core finding. Config B is broken with both policies; ts_group_action is slightly worse because its per-action NIG stats suffer from the same stale-statistics problem. + +**Fixes tested:** + +| Fix | Mechanism | Effective? 
| +|-----|-----------|------------| +| `use_cache=False` (config C) | Every call is novel, n_obs == n_visits | Yes — but wastes compute on redundant re-evaluations, and over-concentrates at high iteration counts | +| `cache_hit_mode="pessimistic"` (config D) | Adds pseudo-observations on cache hits, keeping n_obs growing | **Yes — best overall.** Tree stays healthy, diverse, no progressive widening gap | +| Base `_child_limit` on `n_obs` instead of `n_visits` | Stops over-expansion directly | Addresses symptom (1) but not (2); should be combined with pessimistic mode | + +**Recommendation:** Use `pessimistic` as the default cache_hit_mode when `use_cache=True`. The `variance_inflation` mode is fundamentally incompatible with deterministic or near-deterministic reward functions. Additionally, basing `_child_limit` on `n_obs` rather than `n_visits` would eliminate the progressive widening gap regardless of cache-hit mode. + +#### 11.21.3 Acqf Landscape Collapse with Convergent GPs + +Independent of the cache bug, we discovered that the acqf landscape itself becomes hostile to cheap evaluation as the BO loop converges. This affects all surrogates but is dramatically accelerated by the SAAS prior. + +**Mechanism.** As the GP becomes more certain about which feature subset is optimal, the acqf concentrates into a narrow spike over that subset. The spike's basin shrinks while the floor (acqf value for uninteresting subsets) remains flat. This is correct GP/acqf behavior — it's exploitation — but it makes Sobol-based reward evaluation increasingly unreliable. 
+
+**Quantitative comparison** (Sobol(64) reward for selection [0,1,2,3,4,5] vs the true gold-standard `optimize_acqf(20 restarts, 2048 samples)` value):
+
+| n_data | Std GP Sobol(64) | Std GP Gold | SAAS Sobol(64) | SAAS Gold |
+|--------|------------------|-------------|----------------|-----------|
+| 10 | -5.01 | -4.95 | -3.99 | -3.99 |
+| 20 | -4.75 | -4.31 | -4.18 | -3.84 |
+| 30 | -5.64 | -3.99 | -2.10 | -1.82 |
+| 40 | -4.18 | -1.84 | **-6.61** | -2.02 |
+| 60 | -5.06 | -3.02 | **-39.77** | -6.30 |
+| 80 | -6.94 | -3.76 | **-44.07** | -4.72 |
+
+For the standard GP, the Sobol reward degrades gradually but remains in a usable range throughout (worst case: -6.94 vs gold -3.76, a 2x gap). For SAAS, the signal collapses catastrophically: at n=60, Sobol returns -39.77 (essentially the floor of -45.04) while the true optimum is -6.30. By n=80, the random subset spread has std=0.32 — everything looks the same to the Sobol evaluator.
+
+**Why SAAS collapses faster.** The SAAS sparsity prior shrinks irrelevant feature lengthscales toward infinity, effectively zeroing out spurious dimensions in the GP's predictive mean. This concentrates the acqf into a lower-dimensional subspace where the remaining active dimensions form a sharp, narrow ridge. The standard GP spreads uncertainty more evenly across all dimensions, keeping the acqf surface smoother. 
+ +#### 11.21.4 Sobol Budget Sweep + +To determine how many Sobol samples are needed to maintain signal, we swept n_sobol ∈ {64, 128, 256, 512, 1024, 2048, 4096} for the SAAS surrogate at each BO stage: + +| n_data | n=64 | n=128 | n=256 | n=512 | n=1024 | n=2048 | n=4096 | +|--------|------|-------|-------|-------|--------|--------|--------| +| 10 | -4.32 | -4.32 | -4.32 | -4.31 | -4.32 | -4.31 | -4.31 | +| 30 | -2.07 | -1.62 | -1.84 | -1.73 | -1.64 | -1.63 | -1.78 | +| 40 | -6.61 | -4.49 | -2.64 | -2.58 | -2.42 | -2.87 | -2.24 | +| 60 | -40.76 | -40.58 | -39.83 | -39.13 | **-9.67** | -9.94 | -8.69 | +| 80 | -43.41 | -40.29 | **-10.22** | -9.19 | -9.00 | -8.11 | -7.92 | + +Values shown are Sobol reward for [0,1,2,3,4,5] (gold standard: -4.74 at n=80). Bold marks the budget where the signal first breaks through the floor. + +Key transitions: + +- **n=10–30:** 64 Sobol samples suffice. The surface is smooth. +- **n=40:** 128-256 samples recover the signal (from -6.61 to -2.64). +- **n=60:** **1024 samples needed** to break through (-9.67 vs -40.76 at n=64). Below 512, the peak is too narrow to hit. +- **n=80:** **256 samples** break through for [0,1,2,3,4,5] (-10.22), but this is inconsistent across selections — [0,2,3,4,5,8] only breaks through at n=1024. **1024 is the safe minimum for reliability.** + +The standard GP shows no such dependency — even 64 Sobol samples gives reliable signal at all stages, because its acqf surface stays smooth. + +#### 11.21.5 General Landscape Convergence Problem + +The acqf spike narrowing is not specific to SAAS or spurious features. It is a fundamental consequence of BO convergence: as the GP becomes confident about the optimal feature subset, the acqf landscape necessarily becomes unimodal with a narrow peak. Any surrogate that converges well will eventually produce a landscape where cheap random probing fails. + +The SAAS model reaches this point faster (around n=40-60) because its sparsity prior accelerates convergence. 
The standard GP gets there more slowly (the signal degrades but never fully collapses in our 80-experiment trajectory). But given enough BO iterations, any competent surrogate will produce a spiky acqf landscape. + +This means the MCTS reward evaluation strategy must adapt to the BO stage: + +- **Early BO (uncertain GP, multimodal acqf):** Sobol reward works. The landscape is smooth, many subsets look promising, and MCTS exploration across subsets adds genuine value. This is where MCTS's combinatorial search capability matters most. +- **Late BO (confident GP, unimodal acqf):** The subset question is largely settled. The acqf is saying "use this subset" and cheap evaluation can't resolve the narrow spike. + +#### 11.21.6 Recommendations + +1. **Use `pessimistic` as the default `cache_hit_mode`** when `use_cache=True`. It is the only mode that remains healthy with both noisy and deterministic reward functions, and it keeps the tree exploring diverse selections at high iteration counts. + +2. **Base `_child_limit` on `n_obs`** instead of `n_visits` to eliminate the progressive widening gap. When `use_cache=False`, `n_obs == n_visits` so behavior is unchanged. + +3. **Set `n_sobol_samples=1024`** (up from 64) in the two-phase screening path. This extends reliable signal through ~60+ data points for SAAS surrogates and has no cost for standard GPs (where 64 already suffices). The screening phase evaluates each selection once, so the cost increase is 16× per evaluation but evaluation count stays the same. + +4. **Adaptive Sobol budget (future work).** Track the empirical reward variance across MCTS iterations. When spread collapses (std drops below a threshold), either increase the Sobol budget or switch to `optimize_acqf`-based evaluation with pessimistic caching. + +5. 
**Convergence detection (future work).** When the top-k subsets from successive BO iterations stabilize (e.g., Jaccard similarity > 0.8 between consecutive iterations' top-5), skip the MCTS screening phase and directly run `optimize_acqf` on the known-good subsets. At full convergence, the combinatorial search problem is solved and only the continuous optimization within the chosen subset matters. + +#### 11.21.7 Exploration/Exploitation Transition in BO Campaigns + +The landscape convergence problem (§11.21.5) maps directly to the standard exploration/exploitation tradeoff in Bayesian Optimization, but at the combinatorial level of subset selection rather than the continuous level within a subset. + +**MCTS solves the exploration problem:** "which feature subsets might be good?" Running `optimize_acqf` on a known-good subset is pure exploitation: "what's the best continuous point within this subset?" As the BO campaign progresses, the relative value of these two components shifts: + +- **Early BO (uncertain GP):** The surrogate is uncertain, many subsets could be optimal, and MCTS exploration genuinely discovers new promising subsets. The tree search adds value. The acqf landscape is smooth enough for cheap Sobol evaluation to provide discriminative signal. +- **Late BO (confident GP):** The surrogate is confident, the acqf concentrates on one or a few subsets, and MCTS is fighting a unimodal landscape where exploration of new subsets adds no value. The combinatorial question is already answered — only the continuous optimization within the chosen subset matters. + +This transition is inevitable for any competent surrogate. The subset exploration question converges faster than the continuous optimization within the best subset, because the combinatorial space is discrete (and often small relative to the continuous space) while the continuous optimum continues to refine. 
+ +**Practical implication:** Always run `optimize_acqf` on the incumbent best subset(s) alongside whatever the MCTS proposes. This provides a natural fallback: + +- In early BO, MCTS may discover a better subset — the `optimize_acqf` refinement on the MCTS-proposed subsets captures this. +- In late BO, the MCTS contribution becomes negligible, but the `optimize_acqf` on the known-good incumbent subset still produces high-quality candidates. + +The two-phase design (§11.20.5) supports this naturally: always include the previous iteration's best subset(s) in the phase-2 refinement set, regardless of what phase-1 screening produces. This way the system gracefully transitions from exploration-dominant to exploitation-dominant as the campaign progresses, without requiring explicit convergence detection. The MCTS screening phase becomes effectively a no-op once the top-k subsets stabilize between BO iterations, and the cost is bounded by the screening budget (which is cheap relative to `optimize_acqf` refinement). + +This also explains why the cache degradation bug (§11.21.2) matters less in practice than it might appear: even with a broken tree, the incumbent best subset from prior iterations provides a safety net. The bug is still worth fixing (it wastes the MCTS compute budget and delays discovery of new subsets in early BO), but it does not cause catastrophic failure in a full BO loop where the refinement phase always evaluates the incumbent. 
+ +#### 11.21.8 Files + +- **Diagnostic script**: `mcts-report/diagnose_cache_behavior.py` — reproduces the cache degradation bug with 4 MCTS configs (A: Sobol no-cache baseline, B: optimize_acqf + cache + var_inflation (broken), C: optimize_acqf no-cache, D: optimize_acqf + cache + pessimistic) +- **Landscape probe**: `mcts-report/probe_saas_landscape.py` — compares standard GP vs SAAS acqf landscape on key selections +- **Evolution trace**: `mcts-report/probe_landscape_evolution.py` — traces acqf landscape at successive BO stages for both surrogates +- **Sobol budget sweep**: `mcts-report/probe_sobol_budget.py` — sweeps n_sobol ∈ {64..4096} at each BO stage for both surrogates +- **Saved BO state**: `mcts-report/experiments.csv` — 80 experiments from SpuriousFeaturesWrapper(Hartmann(6), 6 spurious, max_count=6) + +### 11.22 DAG-Based MCTS with Transposition Table + +#### 11.22.1 Motivation: Canonical Ordering Bias + +The tree-based MCTS uses strictly increasing canonical ordering for NChooseK groups: after selecting feature index `i`, only indices `> i` are legal at the next level. This ensures each feature subset is reachable by exactly one path, preventing redundant nodes. However, it creates **asymmetric subtrees**: high-index features sit near the leaves (shallow subtrees with 2-3 nodes) while low-index features sit near the root (deep subtrees with 50-100+ nodes). The result is a structural bias where high-index features receive concentrated visits and converge quickly, while low-index features receive diluted visits across many branches and converge slowly. + +The per-run shuffle (randomizing feature-to-index assignment) averages this out across runs but does not eliminate it within a single run. In benchmarks with flat (uninformative) reward, high-index features are selected **6x more often** than low-index features due purely to tree topology. 
+ +#### 11.22.2 The DAG Approach + +The MCTS DAG removes canonical ordering entirely: at every NChooseK decision node, **all unselected features** are legal actions, not just those with index greater than the last selected. This creates a symmetric action space where every feature has equal structural opportunity. + +Without further changes, removing ordering would cause an exponential blowup — feature set {A, B, C} would have 3! = 6 distinct tree paths. The **transposition table** prevents this: nodes with identical selected feature sets (regardless of selection order) share a single `DAGNode` keyed by `frozenset(selected)`. The tree becomes a DAG (directed acyclic graph) where: + +- Multiple parent paths converge on the same node +- NIG-TS statistics accumulate across all parent paths +- Each unique feature set is represented exactly once + +NIG Thompson Sampling (§11.13) is essential for this to work. UCT's exploration bonus depends on `parent_visits`, which is ambiguous in a DAG (a node has multiple parents). NIG-TS has no parent-dependent terms — the Student-t posterior depends only on the node's own observation statistics — so statistics merge cleanly across parents. + +**Transposition key**: For each node, the canonical key is `(group_idx, (frozenset(partial_group_0), frozenset(partial_group_1), ...), (stopped_0, stopped_1, ...))`. NChooseK groups use frozenset (order-independent); Categorical groups use tuples (at most 1 element). + +#### 11.22.3 STOP Dilution Problem + +Removing canonical ordering introduces a new problem: **STOP dilution**. In the tree, at deeper levels, STOP competes with only 2-3 remaining features (those with index > last). The tree's depth structure naturally concentrates STOP statistics at each cardinality level. In the DAG, STOP competes with **all unselected features** at every node — for a group with 8 features, STOP competes 1-out-of-8 instead of 1-out-of-3. 
+ +This manifests as systematic over-selection: the DAG finds the correct features but adds extras. On multigroup_interaction (optimal cardinality 2+2+3=7), the DAG consistently selects ~10 features, scoring 103 instead of the optimal 150. + +#### 11.22.4 Separate STOP Fix + +The `separate_stop` mechanism restructures the STOP decision as a **binary comparison** rather than a 1-out-of-N competition: + +1. **Tree selection**: When STOP is among a node's children, first sample STOP's NIG score, then sample the best feature's NIG score, and compare directly. STOP gets a fair 50/50 chance instead of being diluted among many features. +2. **Rollout (ts_group_action)**: Same binary comparison — sample STOP's rollout NIG score vs the best feature's rollout NIG score. +3. **Rollout (uniform)**: 50/50 coin flip between STOP and a random feature (instead of 1/N uniform over all actions). +4. **Expansion priority**: STOP is always expanded first when it becomes legal, ensuring it gets early data. + +#### 11.22.5 Benchmark Results + +Six synthetic NChooseK problems, 30 trials each, same budget as prior benchmarks: + +**Optimum-finding rate (%):** + +| Problem | Random | NIG+vi+apv (tree) | DAG v1 | DAG+ss+vi | +|---------|--------|-------------------|--------|-----------| +| multigroup_interaction | 0% | **80%** | 0% | 0% | +| needle_in_haystack | 10% | **100%** | 53% | 83% | +| mixed_nchoosek_categorical | 3% | **100%** | 63% | 93% | +| large_sparse | 0% | 47% | 67% | **87%** | +| graduated_landscape | 7% | 77% | **100%** | **100%** | +| simple_additive | 0% | 83% | **100%** | **100%** | +| **Average** | 3.3% | 81.2% | 63.8% | **77.2%** | + +**Configuration key:** +- **NIG+vi+apv**: Tree-based NIG (§11.13 recommended default) — `cache_hit_mode=variance_inflation, adaptive_prior_var=True` +- **DAG v1**: DAG with `combined` cache-hit mode + adaptive prior variance, no separate_stop +- **DAG+ss+vi**: DAG with `separate_stop=True, cache_hit_mode=variance_inflation, 
adaptive_prior_var=True` + +**Separate_stop ablation (DAG configs):** + +| Problem | DAG v1 | DAG+ss | DAG+ss+vi | DAG+ss+an0 | DAG+ss+tpw | +|---------|--------|--------|-----------|------------|------------| +| multigroup | 0% | 0% | 0% | 0% | 3% | +| needle | 53% | 80% | **83%** | 70% | 73% | +| mixed | 63% | 57% | **93%** | 57% | 53% | +| large_sparse | 67% | 77% | **87%** | 67% | 47% | +| graduated | 100% | 100% | 100% | 100% | 100% | +| simple | 100% | 100% | 100% | 100% | 97% | + +Variance inflation (`vi`) synergizes with separate_stop across all problems. Adaptive n0 and tighter progressive widening do not add value on top of separate_stop and can hurt (large_sparse drops from 87% to 47% with tighter PW). + +#### 11.22.6 Feature Recall: The Right Metric for Acqf Optimization + +The 0% optimum rate on multigroup_interaction is misleading for the actual acqf optimization use case. In the real pipeline, MCTS selects which features to activate, then a gradient-based optimizer (L-BFGS via `optimize_acqf`) refines the continuous parameter values within that subset. What matters is not exact cardinality match but whether the correct features are **included** in the selected set. + +**Feature recall (contains all optimal features):** + +| Problem | DAG+ss+vi | +|---------|-----------| +| multigroup_interaction | **30/30 (100%)** | +| needle_in_haystack | **30/30 (100%)** | +| mixed_nchoosek_categorical | **30/30 (100%)** features + **30/30 (100%)** categoricals | + +The DAG with `ss+vi` achieves **100% recall of the optimal features** on every problem, every trial. It never misses an important feature — it just includes 1-3 extras. These extras produce a slightly higher-dimensional continuous optimization problem for the downstream gradient optimizer, but the important features are always present. + +On multigroup_interaction specifically: the optimal selection is features `{1, 5, 9, 14, 17, 20, 23}` (7 features: 2+2+3 across 3 groups). 
The DAG consistently selects ~10 features — all 7 optimal features plus 3 random extras from different groups. The extras are noise features that the gradient optimizer can handle by pushing their continuous values toward low-impact regions. + +#### 11.22.7 DAG vs NIG Tree: Complementary Strengths + +The DAG and tree approaches have complementary performance profiles: + +**DAG wins where ordering bias matters most:** +- **large_sparse** (87% vs 47%): 4 groups × 10 features, ~960M combinations. The tree's canonical ordering creates severe asymmetry across 10 features per group. The DAG's symmetric action space finds the optimal 2-group solution much more reliably. +- **graduated_landscape** (100% vs 77%) and **simple_additive** (100% vs 83%): Moderate feature counts where the DAG's unbiased exploration consistently finds the optimum. + +**NIG tree wins where cardinality precision matters:** +- **multigroup_interaction** (80% vs 0%): Cross-group interaction bonuses create a sharp cardinality-dependent reward landscape. The tree's depth structure concentrates STOP statistics at each cardinality level; the DAG's flat action space dilutes them. +- **needle** (100% vs 83%) and **mixed** (100% vs 93%): The tree's cardinality learning converges more quickly on these problems, though the DAG is close. + +For the acqf optimization use case, the DAG's over-selection is not a practical problem (§11.22.6): the downstream gradient optimizer handles extra features. The DAG's advantage on large search spaces (the most common regime in real applications with many features) makes it the preferred approach. 
+ +#### 11.22.8 Recommended DAG Configuration + +```python +MCTS_DAG( + groups=groups, + reward_fn=reward_fn, + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, + separate_stop=True, + # defaults: pw_k0=2.0, pw_alpha=0.6, nig_alpha0=1.0 +) +``` + +| Parameter | Value | Rationale | +|-----------|-------|-----------| +| `rollout_mode` | `ts_group_action` | Learned rollout using per-(group, action) NIG posteriors | +| `cache_hit_mode` | `variance_inflation` | Widens posteriors on cache hits; synergizes with separate_stop | +| `adaptive_prior_var` | `True` | Adapts NIG prior scale to observed reward distribution | +| `separate_stop` | `True` | Binary STOP-vs-best-feature comparison; fixes STOP dilution | +| `pw_k0` | `2.0` (default) | Tighter PW (1.5) hurts large_sparse | +| `pw_alpha` | `0.6` (default) | Standard progressive widening exponent | +| `adaptive_n0` | `False` (default) | No benefit when combined with separate_stop | +| `informed_expansion` | `False` (default) | Marginal effect, adds complexity | + +#### 11.22.9 Files + +| File | Description | +|------|-------------| +| `optimize_mcts_dag.py` | MCTS DAG implementation with transposition table, NIG-TS, and separate_stop | +| `benchmark_dag.py` | DAG benchmark script (6 problems × 7 configs × 30 trials) | +| `results_dag.json` | Full numeric results for DAG benchmark | +| `convergence_dag_<problem>.png` | Full convergence curves per problem | +| `convergence_dag_<problem>_ss_vs_baseline.png` | separate_stop vs baselines comparison | +| `convergence_dag_<problem>_ss_variants.png` | separate_stop variant comparison | +| `summary_bar_chart_dag.png` | Bar chart of final best reward | +| `optimum_rate_heatmap_dag.png` | Heatmap of optimum-finding rates | +| `unique_evals_dag.png` | Exploration efficiency comparison | + +### 11.23 DAG with Stochastic Rewards (No-Cache Mode) + +In production, MCTS uses cheap Sobol-based acquisition function sampling rather than expensive `optimize_acqf`. 
Each evaluation returns a noisy reward (the maximum acquisition value over a small random sample), so caching is counterproductive — re-evaluating the same subset with different random draws provides genuinely new information. This section benchmarks the DAG with `use_cache=False` under Gaussian noise, simulating the production Sobol evaluation regime. + +#### 11.23.1 Setup + +Four configurations compared at σ ∈ {1.0, 2.0} (noise model: `true_reward + N(0, σ²) − σ`, matching the pessimistic Sobol bias from §11.16): + +| Config | Engine | Cache | Noise | separate_stop | +|--------|--------|-------|-------|---------------| +| NIG tree (no-cache) | MCTS_NIG tree | Off | σ | No | +| DAG+ss+vi (cached, det) | MCTS_DAG | On | None | Yes | +| DAG (no-cache) | MCTS_DAG | Off | σ | No | +| DAG+ss (no-cache) | MCTS_DAG | Off | σ | Yes | + +30 trials per config per problem. The reported `final_best` is the **true** (noiseless) reward of the best selection found, ensuring fair comparison. + +#### 11.23.2 Results: Optimum-Finding Rate + +**σ = 1.0 (moderate noise)** + +| Problem | NIG tree (no-cache) | DAG (no-cache) | DAG+ss (no-cache) | DAG+ss+vi (cached,det) | +|---------|--------------------:|---------------:|-------------------:|-----------------------:| +| multigroup_interaction | **70%** | 0% | 3% | 0% | +| needle_in_haystack | **83%** | 40% | 73% | **83%** | +| mixed_nchoosek_categorical | **93%** | **93%** | 90% | **93%** | +| large_sparse | 63% | **87%** | 77% | **87%** | +| graduated_landscape | 7% | **87%** | **93%** | **100%** | +| simple_additive | 30% | **100%** | 97% | **100%** | +| **Average** | **57.7%** | **67.8%** | **72.2%** | **77.2%** | + +**σ = 2.0 (high noise)** + +| Problem | NIG tree (no-cache) | DAG (no-cache) | DAG+ss (no-cache) | DAG+ss+vi (cached,det) | +|---------|--------------------:|---------------:|-------------------:|-----------------------:| +| multigroup_interaction | **73%** | 0% | 0% | 0% | +| needle_in_haystack | 73% | 40% | **97%** | 
83% | +| mixed_nchoosek_categorical | **100%** | 93% | 97% | 93% | +| large_sparse | 43% | 73% | **83%** | 87% | +| graduated_landscape | 13% | 67% | **80%** | **100%** | +| simple_additive | 3% | 93% | **90%** | **100%** | +| **Average** | **50.8%** | **61.0%** | **74.5%** | **77.2%** | + +#### 11.23.3 Analysis + +**1. The NIG tree collapses under noise on small search spaces.** Graduated: 77% (cached deterministic, §11.13) → 7% (no-cache, σ=1.0). Simple: 83% → 30%. The tree's canonical ordering creates fragile path-specific statistics — observations along one feature ordering don't transfer to other orderings, and noise disrupts the fragile cardinality learning at each depth. In contrast, the DAG's transposition table merges all orderings into shared nodes, naturally averaging out noise. + +**2. The DAG is inherently noise-robust.** DAG (no-cache, σ=1.0) achieves 87% on large_sparse — identical to the cached deterministic DAG+ss+vi. The mechanism: the transposition table routes multiple noisy observations of the same feature set through the same DAGNode. With `use_cache=False`, each visit produces a fresh noisy reward, and the NIG posterior correctly averages over multiple draws. This is exactly the Bayesian averaging that §11.17 validated for the tree, but the DAG gets stronger benefit because its transposition merging concentrates observations more efficiently. + +**3. Separate_stop is the best no-cache DAG config overall.** Averaging across all 6 problems: DAG+ss averages 72.2% at σ=1.0 and 74.5% at σ=2.0, beating both the NIG tree (57.7%, 50.8%) and DAG without ss (67.8%, 61.0%). The separate_stop advantage grows with noise — at σ=2.0 on needle, DAG+ss gets 97% vs 40% without it. The binary STOP comparison becomes more important under noise because the diluted STOP signal (1-out-of-N) is even harder to distinguish from feature signals when both are noisy. + +**4. 
Higher noise helps the NIG tree on interaction problems.** The NIG tree goes from 70% at σ=1.0 to 73% at σ=2.0 on multigroup_interaction, and from 93% to 100% on mixed. The noise acts as implicit exploration — random reward fluctuations prevent premature commitment and help the tree discover cross-group interaction bonuses. This matches the §11.17 finding that noise can occasionally help. + +**5. The DAG's multigroup weakness persists under noise.** 0-3% across all noise levels. However, as shown in §11.22.6, the DAG achieves 100% feature recall — it always finds all optimal features and just adds extras. In the production pipeline where the downstream gradient optimizer handles feature refinement, this is not a practical problem. + +#### 11.23.4 Recommended Configuration for Stochastic Rewards + +For production use with Sobol-based acquisition function sampling: + +```python +MCTS_DAG( + groups=groups, + reward_fn=sobol_acqf_reward_fn, + rollout_mode="ts_group_action", + adaptive_prior_var=True, + separate_stop=True, + use_cache=False, + # cache_hit_mode is irrelevant with use_cache=False +) +``` + +This is the stochastic-reward counterpart of the deterministic recommendation (§11.22.8) — same settings, replacing `use_cache=True, cache_hit_mode="variance_inflation"` with `use_cache=False`. The `cache_hit_mode` parameter has no effect when `use_cache=False` because every observation is novel. 
+ +**When to use cached deterministic vs no-cache stochastic:** + +| Setting | `use_cache` | Reward function | Best config | +|---------|-------------|-----------------|-------------| +| Expensive `optimize_acqf` | `True` | Deterministic | DAG+ss + `cache_hit_mode="variance_inflation"` | +| Cheap Sobol sampling | `False` | Stochastic | DAG+ss (cache_hit_mode irrelevant) | +| Two-phase burn-in | Phase 1: `False`, Phase 2: `True` | Stochastic → Deterministic | Switch at phase boundary | + +#### 11.23.5 Files + +| File | Description | +|------|-------------| +| `benchmark_dag_nocache.py` | DAG stochastic reward benchmark script | +| `results_dag_nocache.json` | Full numeric results | +| `nocache_dag_convergence_<problem>_sigma<σ>.png` | Convergence curves per problem per noise level | +| `nocache_dag_summary_sigma<σ>.png` | Bar chart of true-best reward per noise level | +| `nocache_dag_optimum_rate_heatmap.png` | Heatmap across configs and noise levels | diff --git a/mcts-report/benchmark.py b/mcts-report/benchmark.py new file mode 100644 index 000000000..9a0495d16 --- /dev/null +++ b/mcts-report/benchmark.py @@ -0,0 +1,1037 @@ +"""MCTS Benchmark: Comparing MCTS configurations on combinatorial NChooseK problems. + +Tests RAVE on/off, Progressive Widening on/off, exploration constants, +and p_stop_rollout against a random-sampling baseline across multiple +problem instances with varying combinatorial complexity and reward structure. 
+ +Usage: + python mcts-report/benchmark.py +""" + +import json +import math +import random +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from optimize_mcts_full import MCTS, Categorical, Groups, NChooseK + + +OUTPUT_DIR = Path(__file__).parent + +# ═══════════════════════════════════════════════════════════════════════ +# Benchmark problems +# ═══════════════════════════════════════════════════════════════════════ + + +@dataclass +class ProblemResult: + """Results for a single trial of a single configuration on one problem.""" + + best_values_over_time: list[float] # best value at each iteration + final_best: float + found_optimum: bool + n_unique_evals: int # distinct combinatorial selections evaluated + + +@dataclass +class Problem: + """A benchmark problem definition.""" + + name: str + description: str + groups: Groups + reward_fn: object # Callable[[tuple[int,...], dict[int,float]], float] + optimal_value: float + search_space_size: int # approximate number of feasible selections + n_iterations: int # budget per trial + n_trials: int # number of random seeds + + +def count_nchoosek_combos(n: int, min_k: int, max_k: int) -> int: + """Count the number of subsets of size min_k..max_k from n items.""" + return sum(math.comb(n, k) for k in range(min_k, max_k + 1)) + + +# --------------------------------------------------------------------------- +# Problem 1: Multi-group feature selection with pairwise interactions +# --------------------------------------------------------------------------- +def make_problem_multigroup_interaction() -> Problem: + """3 NChooseK groups (8 features each, pick 1-4) with cross-group interactions. + + Search space: (C(8,1)+...+C(8,4))^3 = 162^3 ≈ 4.25M combinations. + Reward has partial credit per correct feature + interaction bonuses. 
+ """ + g1 = NChooseK(features=list(range(0, 8)), min_count=1, max_count=4) + g2 = NChooseK(features=list(range(8, 16)), min_count=1, max_count=4) + g3 = NChooseK(features=list(range(16, 24)), min_count=1, max_count=4) + gs = Groups(groups=[g1, g2, g3]) + + # Optimal: {1, 5} from g1, {9, 14} from g2, {17, 20, 23} from g3 + opt_g1 = {1, 5} + opt_g2 = {9, 14} + opt_g3 = {17, 20, 23} + optimal_set = opt_g1 | opt_g2 | opt_g3 + + def reward_fn(feats, _cats): + feat_set = set(feats) + # Base: partial credit per correct feature + correct = len(feat_set & optimal_set) + wrong = len(feat_set - optimal_set) + score = correct * 8.0 - wrong * 3.0 + + # Interaction bonuses (cross-group pairs) + if 1 in feat_set and 9 in feat_set: + score += 12.0 + if 5 in feat_set and 14 in feat_set: + score += 12.0 + if 9 in feat_set and 20 in feat_set: + score += 12.0 + if 14 in feat_set and 17 in feat_set: + score += 10.0 + if 1 in feat_set and 23 in feat_set: + score += 10.0 + + # Exact-match bonus + if feat_set == optimal_set: + score = 150.0 + + return score + + ss = count_nchoosek_combos(8, 1, 4) ** 3 # 162^3 + return Problem( + name="multigroup_interaction", + description="3 groups × 8 features (pick 1-4), cross-group interactions", + groups=gs, + reward_fn=reward_fn, + optimal_value=150.0, + search_space_size=ss, + n_iterations=600, + n_trials=30, + ) + + +# --------------------------------------------------------------------------- +# Problem 2: Needle in a haystack — wide single group +# --------------------------------------------------------------------------- +def make_problem_needle() -> Problem: + """Single NChooseK group: 15 features, pick 2-5. + + Search space: C(15,2)+...+C(15,5) = 4928 combinations. + Only one specific subset gives high reward; slight partial credit. 
+ """ + g = NChooseK(features=list(range(15)), min_count=2, max_count=5) + gs = Groups(groups=[g]) + + target = {3, 7, 11} + + def reward_fn(feats, _cats): + feat_set = set(feats) + if feat_set == target: + return 100.0 + overlap = len(feat_set & target) + extras = len(feat_set - target) + return overlap * 15.0 - extras * 5.0 + + ss = count_nchoosek_combos(15, 2, 5) + return Problem( + name="needle_in_haystack", + description="15 features pick 2-5, single optimal subset", + groups=gs, + reward_fn=reward_fn, + optimal_value=100.0, + search_space_size=ss, + n_iterations=400, + n_trials=30, + ) + + +# --------------------------------------------------------------------------- +# Problem 3: Mixed NChooseK + Categorical with interactions +# --------------------------------------------------------------------------- +def make_problem_mixed() -> Problem: + """2 NChooseK groups + 2 Categorical dimensions with interactions. + + NChooseK: 6 features pick 1-3 each (C(6,1)+C(6,2)+C(6,3)=41 per group). + Categorical: 4 values each (4×4=16 combos). + Total: ~41×41×16 ≈ 26,896 combinations. 
+ """ + g1 = NChooseK(features=list(range(0, 6)), min_count=1, max_count=3) + g2 = NChooseK(features=list(range(6, 12)), min_count=1, max_count=3) + cat1 = Categorical(dim=20, values=[0.0, 1.0, 2.0, 3.0]) + cat2 = Categorical(dim=21, values=[0.0, 1.0, 2.0, 3.0]) + gs = Groups(groups=[g1, g2, cat1, cat2]) + + opt_feats = {2, 4, 8, 11} + opt_cats = {20: 2.0, 21: 3.0} + + def reward_fn(feats, cats): + feat_set = set(feats) + # Feature credit + correct_feats = len(feat_set & opt_feats) + wrong_feats = len(feat_set - opt_feats) + score = correct_feats * 10.0 - wrong_feats * 4.0 + + # Categorical credit + for dim, val in opt_cats.items(): + if cats.get(dim) == val: + score += 12.0 + + # Interaction: feature 2 + cat 20=2.0 + if 2 in feat_set and cats.get(20) == 2.0: + score += 15.0 + + # Interaction: feature 11 + cat 21=3.0 + if 11 in feat_set and cats.get(21) == 3.0: + score += 15.0 + + # Exact match bonus + if feat_set == opt_feats and cats == opt_cats: + score = 150.0 + + return score + + ss = count_nchoosek_combos(6, 1, 3) ** 2 * 4 * 4 + return Problem( + name="mixed_nchoosek_categorical", + description="2 NChooseK (6 feat, 1-3) + 2 Categorical (4 vals each)", + groups=gs, + reward_fn=reward_fn, + optimal_value=150.0, + search_space_size=ss, + n_iterations=500, + n_trials=30, + ) + + +# --------------------------------------------------------------------------- +# Problem 4: Large-scale sparse selection +# --------------------------------------------------------------------------- +def make_problem_large_sparse() -> Problem: + """4 NChooseK groups (10 features each, pick 0-3). + + This tests MCTS with min_count=0 (selecting nothing from a group is valid). + Search space: (C(10,0)+C(10,1)+C(10,2)+C(10,3))^4 = 176^4 ≈ 960M. + Optimal uses features from only 2 of the 4 groups. 
+ """ + g1 = NChooseK(features=list(range(0, 10)), min_count=0, max_count=3) + g2 = NChooseK(features=list(range(10, 20)), min_count=0, max_count=3) + g3 = NChooseK(features=list(range(20, 30)), min_count=0, max_count=3) + g4 = NChooseK(features=list(range(30, 40)), min_count=0, max_count=3) + gs = Groups(groups=[g1, g2, g3, g4]) + + # Optimal: select from groups 1 and 3 only + opt_g1 = {2, 7} + opt_g3 = {22, 25, 28} + optimal_set = opt_g1 | opt_g3 + + def reward_fn(feats, _cats): + feat_set = set(feats) + if feat_set == optimal_set: + return 200.0 + correct = len(feat_set & optimal_set) + wrong = len(feat_set - optimal_set) + score = correct * 12.0 - wrong * 6.0 + + # Bonus for sparsity (using fewer groups) + groups_used = set() + for f in feat_set: + groups_used.add(f // 10) + if len(groups_used) <= 2: + score += 8.0 + return score + + ss = count_nchoosek_combos(10, 0, 3) ** 4 + return Problem( + name="large_sparse", + description="4 groups × 10 features (pick 0-3), optimal uses only 2 groups", + groups=gs, + reward_fn=reward_fn, + optimal_value=200.0, + search_space_size=ss, + n_iterations=800, + n_trials=30, + ) + + +# --------------------------------------------------------------------------- +# Problem 5: Overlapping reward landscape (many near-optimal solutions) +# --------------------------------------------------------------------------- +def make_problem_graduated() -> Problem: + """10 features, pick 2-4. Smooth reward landscape based on feature quality scores. + + Each feature has a fixed quality; reward = sum of qualities of selected features + minus a penalty for extra features. The landscape is smooth so MCTS benefits + from learning which features are generally good. 
+ """ + g = NChooseK(features=list(range(10)), min_count=2, max_count=4) + gs = Groups(groups=[g]) + + # Feature quality scores (pre-determined) + quality = { + 0: 5.0, + 1: 12.0, + 2: 3.0, + 3: 18.0, + 4: 7.0, + 5: 15.0, + 6: 2.0, + 7: 20.0, + 8: 9.0, + 9: 11.0, + } + # Optimal: {3, 7, 5} with reward 18+20+15 = 53, or {1,3,5,7} = 12+18+15+20=65 + # Actually {1,3,5,7} = 65 is best 4-subset + + def reward_fn(feats, _cats): + return sum(quality[f] for f in feats) + + ss = count_nchoosek_combos(10, 2, 4) + optimal_val = sum(sorted(quality.values(), reverse=True)[:4]) # top 4 + return Problem( + name="graduated_landscape", + description="10 features pick 2-4, smooth quality-based reward", + groups=gs, + reward_fn=reward_fn, + optimal_value=optimal_val, + search_space_size=ss, + n_iterations=300, + n_trials=30, + ) + + +# --------------------------------------------------------------------------- +# Problem 6: Simple additive — independent feature values, no interactions +# --------------------------------------------------------------------------- +def make_problem_simple_additive() -> Problem: + """12 features, pick 1-4. Each feature contributes a fixed positive value. + + No interactions — reward is purely the sum of selected feature values. + Features have varying magnitudes so the algorithm must learn to pick + the highest-value features and the right cardinality (exactly 4). + This is the simplest possible NChooseK problem. 
+ """ + g = NChooseK(features=list(range(12)), min_count=1, max_count=4) + gs = Groups(groups=[g]) + + values = { + 0: 1.0, + 1: 3.0, + 2: 8.0, + 3: 2.0, + 4: 15.0, + 5: 5.0, + 6: 12.0, + 7: 20.0, + 8: 4.0, + 9: 10.0, + 10: 6.0, + 11: 18.0, + } + # Optimal: {4, 6, 7, 11} = 15 + 12 + 20 + 18 = 65 + + def reward_fn(feats, _cats): + return sum(values[f] for f in feats) + + ss = count_nchoosek_combos(12, 1, 4) + optimal_val = sum(sorted(values.values(), reverse=True)[:4]) # top 4 + return Problem( + name="simple_additive", + description="12 features pick 1-4, independent positive values (no interactions)", + groups=gs, + reward_fn=reward_fn, + optimal_value=optimal_val, + search_space_size=ss, + n_iterations=300, + n_trials=30, + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# MCTS Configurations +# ═══════════════════════════════════════════════════════════════════════ + + +@dataclass +class MCTSConfig: + """A named MCTS configuration for benchmarking.""" + + name: str + c_uct: float = 1.0 + k_rave: float = 300.0 + p_stop_rollout: float = 0.35 + pw_k0: float = 2.0 + pw_alpha: float = 0.6 + adaptive_p_stop: bool = False + p_stop_warmup: int = 20 + p_stop_temperature: float = 0.25 + normalize_rewards: bool = False + rollout_policy: bool = False + rollout_epsilon: float = 0.3 + rollout_tau: float = 1.0 + rollout_novelty_weight: float = 1.0 + context_rave: bool = False + + +# Effective ways to disable features: +# - RAVE off: k_rave=0 → beta=0 → pure UCT +# - PW off: pw_k0=1e6 → child_limit always exceeds legal actions +CONFIGS = [ + MCTSConfig(name="MCTS (default)", c_uct=1.0, k_rave=300, pw_k0=2.0, pw_alpha=0.6), + MCTSConfig(name="MCTS (no RAVE)", c_uct=1.0, k_rave=0, pw_k0=2.0, pw_alpha=0.6), + MCTSConfig(name="MCTS (no PW)", c_uct=1.0, k_rave=300, pw_k0=1e6, pw_alpha=0.6), + MCTSConfig( + name="MCTS (no RAVE, no PW)", c_uct=1.0, k_rave=0, pw_k0=1e6, pw_alpha=0.6 + ), + MCTSConfig( + name="MCTS (low explore)", c_uct=0.1, k_rave=300, 
pw_k0=2.0, pw_alpha=0.6 + ), + MCTSConfig( + name="MCTS (high explore)", c_uct=5.0, k_rave=300, pw_k0=2.0, pw_alpha=0.6 + ), + MCTSConfig( + name="MCTS (heavy RAVE)", c_uct=1.0, k_rave=3000, pw_k0=2.0, pw_alpha=0.6 + ), + MCTSConfig(name="MCTS (tight PW)", c_uct=1.0, k_rave=300, pw_k0=1.0, pw_alpha=0.4), + MCTSConfig(name="MCTS (loose PW)", c_uct=1.0, k_rave=300, pw_k0=5.0, pw_alpha=0.8), + MCTSConfig( + name="MCTS (p_stop=0.1)", + c_uct=1.0, + k_rave=300, + pw_k0=2.0, + pw_alpha=0.6, + p_stop_rollout=0.1, + ), + MCTSConfig( + name="MCTS (p_stop=0.6)", + c_uct=1.0, + k_rave=300, + pw_k0=2.0, + pw_alpha=0.6, + p_stop_rollout=0.6, + ), + MCTSConfig( + name="MCTS (adaptive p)", + c_uct=1.0, + k_rave=300, + pw_k0=2.0, + pw_alpha=0.6, + adaptive_p_stop=True, + ), + MCTSConfig( + name="MCTS (no RAVE+adpt)", + c_uct=1.0, + k_rave=0, + pw_k0=2.0, + pw_alpha=0.6, + adaptive_p_stop=True, + ), + MCTSConfig( + name="MCTS (norm)", + c_uct=0.01, + k_rave=300, + pw_k0=2.0, + pw_alpha=0.6, + normalize_rewards=True, + ), + MCTSConfig( + name="MCTS (no RAVE+adpt+norm)", + c_uct=0.01, + k_rave=0, + pw_k0=2.0, + pw_alpha=0.6, + adaptive_p_stop=True, + normalize_rewards=True, + ), + MCTSConfig( + name="MCTS (+rpol)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + ), + MCTSConfig( + name="MCTS (+rpol ε=0.1)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + rollout_epsilon=0.1, + ), + MCTSConfig( + name="MCTS (+rpol τ=0.5)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + rollout_tau=0.5, + ), + MCTSConfig( + name="MCTS (+rpol τ=2)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + rollout_tau=2.0, + ), + MCTSConfig( + name="MCTS (+crave k=100)", + c_uct=0.01, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + context_rave=True, + k_rave=100, + ), + 
MCTSConfig( + name="MCTS (+crave k=300)", + c_uct=0.01, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + context_rave=True, + k_rave=300, + ), + MCTSConfig( + name="MCTS (+crave k=500)", + c_uct=0.01, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + context_rave=True, + k_rave=500, + ), +] + + +# ═══════════════════════════════════════════════════════════════════════ +# Random baseline +# ═══════════════════════════════════════════════════════════════════════ + + +def run_random_baseline(problem: Problem, seed: int) -> ProblemResult: + """Random rollouts from root node, tracking best value per iteration.""" + # Create a dummy MCTS just to use its rollout machinery + mcts_tmp = MCTS( + groups=problem.groups, + reward_fn=lambda f, c: 0.0, + rollout_policy=False, + seed=seed, + ) + rng = random.Random(seed) + + best = float("-inf") + best_values = [] + seen = set() + + for _ in range(problem.n_iterations): + mcts_tmp.rng = rng + feats, cats, _traj = mcts_tmp._rollout(mcts_tmp.root) + val = problem.reward_fn(feats, cats) + key = (feats, frozenset(cats.items())) + seen.add(key) + if val > best: + best = val + best_values.append(best) + + return ProblemResult( + best_values_over_time=best_values, + final_best=best, + found_optimum=abs(best - problem.optimal_value) < 1e-6, + n_unique_evals=len(seen), + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# MCTS run +# ═══════════════════════════════════════════════════════════════════════ + + +def run_mcts_config(problem: Problem, config: MCTSConfig, seed: int) -> ProblemResult: + """Run MCTS with given config, tracking best value per iteration.""" + mcts = MCTS( + groups=problem.groups, + reward_fn=problem.reward_fn, + c_uct=config.c_uct, + k_rave=config.k_rave, + p_stop_rollout=config.p_stop_rollout, + pw_k0=config.pw_k0, + pw_alpha=config.pw_alpha, + adaptive_p_stop=config.adaptive_p_stop, + p_stop_warmup=config.p_stop_warmup, + 
p_stop_temperature=config.p_stop_temperature, + normalize_rewards=config.normalize_rewards, + rollout_policy=config.rollout_policy, + rollout_epsilon=config.rollout_epsilon, + rollout_tau=config.rollout_tau, + rollout_novelty_weight=config.rollout_novelty_weight, + context_rave=config.context_rave, + seed=seed, + ) + + best_values = [] + for _ in range(problem.n_iterations): + mcts.run(n_iterations=1) + best_values.append(mcts.best_value) + + stats = mcts.cache_stats() + return ProblemResult( + best_values_over_time=best_values, + final_best=mcts.best_value, + found_optimum=abs(mcts.best_value - problem.optimal_value) < 1e-6, + n_unique_evals=stats["misses"], # cache misses = unique evaluations + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# Benchmark runner +# ═══════════════════════════════════════════════════════════════════════ + + +def run_benchmark(problems: list[Problem], configs: list[MCTSConfig]): + """Run all configurations on all problems, collecting results.""" + all_results = {} # (problem_name, config_name) -> list[ProblemResult] + + for prob in problems: + print(f"\n{'='*70}") + print(f"Problem: {prob.name}") + print(f" {prob.description}") + print(f" Search space: ~{prob.search_space_size:,} combinations") + print(f" Budget: {prob.n_iterations} iterations × {prob.n_trials} trials") + print(f"{'='*70}") + + # Random baseline + key = (prob.name, "Random") + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_random_baseline(prob, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" Random | best={mean_best:7.1f} | opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # MCTS configs + for cfg in configs: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = 
run_mcts_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:17s} | best={mean_best:7.1f} | opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + return all_results + + +# ═══════════════════════════════════════════════════════════════════════ +# Plotting +# ═══════════════════════════════════════════════════════════════════════ + +# Consistent color scheme +COLOR_MAP = { + "Random": "#888888", + "MCTS (default)": "#1f77b4", + "MCTS (no RAVE)": "#ff7f0e", + "MCTS (no PW)": "#2ca02c", + "MCTS (no RAVE, no PW)": "#d62728", + "MCTS (low explore)": "#9467bd", + "MCTS (high explore)": "#8c564b", + "MCTS (heavy RAVE)": "#e377c2", + "MCTS (tight PW)": "#7f7f7f", + "MCTS (loose PW)": "#bcbd22", + "MCTS (p_stop=0.1)": "#17becf", + "MCTS (p_stop=0.6)": "#aec7e8", + "MCTS (adaptive p)": "#ff1493", + "MCTS (no RAVE+adpt)": "#00ced1", + "MCTS (norm)": "#ff6347", + "MCTS (no RAVE+adpt+norm)": "#32cd32", + "MCTS (+rpol)": "#8b0000", + "MCTS (+rpol ε=0.1)": "#ff4500", + "MCTS (+rpol τ=0.5)": "#daa520", + "MCTS (+rpol τ=2)": "#4682b4", + "MCTS (+crave k=100)": "#6a0dad", + "MCTS (+crave k=300)": "#20b2aa", + "MCTS (+crave k=500)": "#dc143c", +} + + +def plot_convergence_curves( + problem_name: str, + all_results: dict, + config_names: list[str], + n_iterations: int, +): + """Plot mean convergence curves with shaded ±1 std region.""" + fig, ax = plt.subplots(figsize=(10, 6)) + + for cname in config_names: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, n_iterations + 1) + color = COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=1.5) + ax.fill_between(x, mean - 
std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + ax.set_ylabel("Best Reward Found", fontsize=12) + ax.set_title(f"Convergence: {problem_name}", fontsize=14) + ax.legend(fontsize=8, loc="lower right", ncol=2) + ax.grid(True, alpha=0.3) + fig.tight_layout() + path = OUTPUT_DIR / f"convergence_{problem_name}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_convergence_subsets(problem_name, all_results, n_iterations): + """Plot focused convergence comparisons: RAVE effect, PW effect, exploration.""" + subsets = { + "rave_effect": [ + "Random", + "MCTS (default)", + "MCTS (no RAVE)", + "MCTS (heavy RAVE)", + ], + "pw_effect": [ + "Random", + "MCTS (default)", + "MCTS (no PW)", + "MCTS (tight PW)", + "MCTS (loose PW)", + ], + "exploration": [ + "Random", + "MCTS (default)", + "MCTS (low explore)", + "MCTS (high explore)", + ], + "p_stop": [ + "Random", + "MCTS (default)", + "MCTS (p_stop=0.1)", + "MCTS (p_stop=0.6)", + "MCTS (adaptive p)", + "MCTS (no RAVE+adpt)", + "MCTS (norm)", + "MCTS (no RAVE+adpt+norm)", + ], + "rollout": [ + "Random", + "MCTS (no RAVE+adpt+norm)", + "MCTS (+rpol)", + "MCTS (+rpol ε=0.1)", + "MCTS (+rpol τ=0.5)", + "MCTS (+rpol τ=2)", + ], + "crave": [ + "Random", + "MCTS (no RAVE+adpt+norm)", + "MCTS (+rpol)", + "MCTS (+crave k=100)", + "MCTS (+crave k=300)", + "MCTS (+crave k=500)", + ], + } + for subset_name, cnames in subsets.items(): + fig, ax = plt.subplots(figsize=(9, 5)) + for cname in cnames: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, n_iterations + 1) + color = COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=2) + ax.fill_between(x, mean - std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + 
ax.set_ylabel("Best Reward Found", fontsize=12) + ax.set_title( + f"{problem_name} — {subset_name.replace('_', ' ').title()}", fontsize=13 + ) + ax.legend(fontsize=9, loc="lower right") + ax.grid(True, alpha=0.3) + fig.tight_layout() + path = OUTPUT_DIR / f"convergence_{problem_name}_{subset_name}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_summary_bar_chart( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean final best across all problems for each config.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + stds = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + names.append(cname) + means.append(np.mean(finals)) + stds.append(np.std(finals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + xerr=stds, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + capsize=3, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Final Best Reward", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.axvline( + prob.optimal_value, color="red", linestyle="--", alpha=0.5, label="Optimum" + ) + ax.legend(fontsize=8) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("Final Best Reward by Configuration", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "summary_bar_chart.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_optimum_rate_heatmap( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Heatmap: optimum-finding rate (config × problem).""" + matrix = [] + for cname in 
config_names: + row = [] + for prob in problems: + key = (prob.name, cname) + if key not in all_results: + row.append(0.0) + continue + results = all_results[key] + rate = sum(r.found_optimum for r in results) / len(results) + row.append(rate) + matrix.append(row) + + matrix = np.array(matrix) + fig, ax = plt.subplots( + figsize=(max(8, len(problems) * 2), max(6, len(config_names) * 0.5)) + ) + im = ax.imshow(matrix, aspect="auto", cmap="YlGn", vmin=0, vmax=1) + ax.set_xticks(range(len(problems))) + ax.set_xticklabels([p.name for p in problems], rotation=30, ha="right", fontsize=9) + ax.set_yticks(range(len(config_names))) + ax.set_yticklabels(config_names, fontsize=9) + + # Annotate cells + for i in range(len(config_names)): + for j in range(len(problems)): + val = matrix[i, j] + color = "white" if val > 0.6 else "black" + ax.text( + j, i, f"{val:.0%}", ha="center", va="center", fontsize=9, color=color + ) + + ax.set_title("Optimum-Finding Rate", fontsize=13) + fig.colorbar(im, ax=ax, label="Rate", shrink=0.8) + fig.tight_layout() + path = OUTPUT_DIR / "optimum_rate_heatmap.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_unique_evals( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean number of unique evaluations per config per problem.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + evals = [r.n_unique_evals for r in results] + names.append(cname) + means.append(np.mean(evals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + ) + 
ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Unique Evaluations", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("Exploration: Unique Selections Evaluated", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "unique_evals.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def save_summary_json(problems, all_results, config_names): + """Save numeric results as JSON for reproducibility.""" + summary = {} + for prob in problems: + prob_summary = {} + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + opt_rates = [r.found_optimum for r in results] + unique_evals = [r.n_unique_evals for r in results] + prob_summary[cname] = { + "mean_best": float(np.mean(finals)), + "std_best": float(np.std(finals)), + "median_best": float(np.median(finals)), + "optimum_rate": float(np.mean(opt_rates)), + "mean_unique_evals": float(np.mean(unique_evals)), + } + summary[prob.name] = prob_summary + + path = OUTPUT_DIR / "results.json" + with open(path, "w") as f: + json.dump(summary, f, indent=2) + print(f"Saved: {path}") + return summary + + +# ═══════════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════════ + + +def main(): + print("MCTS Benchmark Suite") + print("=" * 70) + + problems = [ + make_problem_multigroup_interaction(), + make_problem_needle(), + make_problem_mixed(), + make_problem_large_sparse(), + make_problem_graduated(), + make_problem_simple_additive(), + ] + + for p in problems: + print( + f" {p.name}: ~{p.search_space_size:,} combinations, {p.n_iterations} iters × {p.n_trials} trials" + ) + + all_config_names = ["Random"] + [c.name for c in CONFIGS] + + t_start = time.time() + all_results = 
run_benchmark(problems, CONFIGS) + total_time = time.time() - t_start + print(f"\nTotal benchmark time: {total_time:.1f}s") + + # Generate plots + print("\nGenerating plots...") + for prob in problems: + plot_convergence_curves( + prob.name, all_results, all_config_names, prob.n_iterations + ) + plot_convergence_subsets(prob.name, all_results, prob.n_iterations) + + plot_summary_bar_chart(problems, all_results, all_config_names) + plot_optimum_rate_heatmap(problems, all_results, all_config_names) + plot_unique_evals(problems, all_results, all_config_names) + + # Save numeric results + summary = save_summary_json(problems, all_results, all_config_names) + + # Print summary table + print("\n" + "=" * 90) + print("SUMMARY TABLE") + print("=" * 90) + for prob in problems: + print( + f"\n{prob.name} (search space: ~{prob.search_space_size:,}, optimum: {prob.optimal_value})" + ) + print( + f" {'Config':<25s} {'Mean Best':>10s} {'±Std':>8s} {'Opt Rate':>10s} {'Uniq Evals':>12s}" + ) + print(f" {'-'*65}") + for cname in all_config_names: + d = summary[prob.name].get(cname) + if d is None: + continue + print( + f" {cname:<25s} {d['mean_best']:10.1f} {d['std_best']:8.1f} " + f"{d['optimum_rate']:10.0%} {d['mean_unique_evals']:12.0f}" + ) + + print(f"\nAll outputs saved to: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/benchmark_nig.py b/mcts-report/benchmark_nig.py new file mode 100644 index 000000000..2f4b1df1e --- /dev/null +++ b/mcts-report/benchmark_nig.py @@ -0,0 +1,638 @@ +"""Benchmark: Normal-Inverse-Gamma (NIG) MCTS vs Normal-TS vs UCT. + +Compares NIG-based MCTS (Student-t posterior) against the best Normal-TS +and UCT configurations on 6 combinatorial NChooseK benchmark problems. + +Configurations tested: + Reference baselines (3): + 1. Random + 2. UCT (+rpol) — best UCT config + 3. TS + TS(g,a) + comb — best Normal-TS config (current recommended default) + + NIG variants (8): + 1. NIG + uniform — minimal NIG, uniform rollout + 2. 
NIG + TS(g,a) — NIG + learned rollout + 3. NIG + TS(g,a) + comb — NIG + best cache-hit mode + 4. NIG + TS(g,a) + comb + apv — NIG + combined + adaptive variance + 5. NIG + TS(g,a) + vi + apv — NIG + variance inflation + adaptive variance + 6. NIG + TS(g,a) + pess — NIG + pessimistic only + 7. NIG + uniform + pess + apv — NIG + pessimistic + adaptive variance + 8. NIG + TS(g,a) + comb (a0=2) — higher alpha0 = lighter tails + +Usage: + python mcts-report/benchmark_nig.py +""" + +import json +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from benchmark import ( + Problem, + ProblemResult, + make_problem_graduated, + make_problem_large_sparse, + make_problem_mixed, + make_problem_multigroup_interaction, + make_problem_needle, + make_problem_simple_additive, + run_random_baseline, +) +from benchmark_ts import TSConfig, UCTConfig, run_ts_config, run_uct_config +from optimize_mcts_nig import MCTS_NIG + + +OUTPUT_DIR = Path(__file__).parent + + +# ====================================================================== +# Configurations +# ====================================================================== + + +@dataclass +class NIGConfig: + """Configuration for MCTS_NIG benchmarking.""" + + name: str + nig_alpha0: float = 1.0 + ts_prior_var: float = 1.0 + adaptive_prior_var: bool = False + cache_hit_mode: str = "no_update" + variance_decay: float = 0.95 + rollout_mode: str = "uniform" + pw_k0: float = 2.0 + pw_alpha: float = 0.6 + # Softmax fallback params + rollout_epsilon: float = 0.3 + rollout_tau: float = 1.0 + rollout_novelty_weight: float = 1.0 + normalize_rewards: bool = True + adaptive_p_stop: bool = True + p_stop_rollout: float = 0.35 + p_stop_warmup: int = 20 + p_stop_temperature: float = 0.25 + adaptive_n0: bool = False + + +# UCT reference +UCT_REF = UCTConfig( + name="UCT (+rpol)", + c_uct=0.01, + 
k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, +) + +# Normal-TS reference (best config from benchmark_ts) +TS_REF = TSConfig( + name="TS + TS(g,a) + comb", + rollout_mode="ts_group_action", + cache_hit_mode="combined", +) + +# NIG configs +NIG_CONFIGS = [ + NIGConfig( + name="NIG + uniform", + rollout_mode="uniform", + ), + NIGConfig( + name="NIG + TS(g,a)", + rollout_mode="ts_group_action", + ), + NIGConfig( + name="NIG + TS(g,a) + comb", + rollout_mode="ts_group_action", + cache_hit_mode="combined", + ), + NIGConfig( + name="NIG + TS(g,a) + comb + apv", + rollout_mode="ts_group_action", + cache_hit_mode="combined", + adaptive_prior_var=True, + ), + NIGConfig( + name="NIG + TS(g,a) + vi + apv", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, + ), + NIGConfig( + name="NIG + TS(g,a) + pess", + rollout_mode="ts_group_action", + cache_hit_mode="pessimistic", + ), + NIGConfig( + name="NIG + uniform + pess + apv", + rollout_mode="uniform", + cache_hit_mode="pessimistic", + adaptive_prior_var=True, + ), + NIGConfig( + name="NIG + TS(g,a) + comb (a0=2)", + nig_alpha0=2.0, + rollout_mode="ts_group_action", + cache_hit_mode="combined", + ), +] + +ALL_CONFIG_NAMES = ["Random", UCT_REF.name, TS_REF.name] + [c.name for c in NIG_CONFIGS] + + +# ====================================================================== +# Run functions +# ====================================================================== + + +def run_nig_config(problem: Problem, config: NIGConfig, seed: int) -> ProblemResult: + """Run NIG MCTS with given config, tracking best value per iteration.""" + mcts = MCTS_NIG( + groups=problem.groups, + reward_fn=problem.reward_fn, + nig_alpha0=config.nig_alpha0, + ts_prior_var=config.ts_prior_var, + adaptive_prior_var=config.adaptive_prior_var, + cache_hit_mode=config.cache_hit_mode, + variance_decay=config.variance_decay, + rollout_mode=config.rollout_mode, + pw_k0=config.pw_k0, + 
pw_alpha=config.pw_alpha, + seed=seed, + rollout_epsilon=config.rollout_epsilon, + rollout_tau=config.rollout_tau, + rollout_novelty_weight=config.rollout_novelty_weight, + normalize_rewards=config.normalize_rewards, + adaptive_p_stop=config.adaptive_p_stop, + p_stop_rollout=config.p_stop_rollout, + p_stop_warmup=config.p_stop_warmup, + p_stop_temperature=config.p_stop_temperature, + adaptive_n0=config.adaptive_n0, + ) + + best_values = [] + for _ in range(problem.n_iterations): + mcts.run(n_iterations=1) + best_values.append(mcts.best_value) + + stats = mcts.cache_stats() + return ProblemResult( + best_values_over_time=best_values, + final_best=mcts.best_value, + found_optimum=abs(mcts.best_value - problem.optimal_value) < 1e-6, + n_unique_evals=stats["misses"], + ) + + +# ====================================================================== +# Benchmark runner +# ====================================================================== + + +def run_benchmark(problems: list[Problem]): + """Run all configurations on all problems, collecting results.""" + all_results = {} + + for prob in problems: + print(f"\n{'='*70}") + print(f"Problem: {prob.name}") + print(f" {prob.description}") + print(f" Search space: ~{prob.search_space_size:,} combinations") + print(f" Budget: {prob.n_iterations} iterations x {prob.n_trials} trials") + print(f"{'='*70}") + + # Random baseline + key = (prob.name, "Random") + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_random_baseline(prob, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {'Random':30s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # UCT reference + key = (prob.name, UCT_REF.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_uct_config(prob, 
UCT_REF, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {UCT_REF.name:30s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # Normal-TS reference + key = (prob.name, TS_REF.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_ts_config(prob, TS_REF, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {TS_REF.name:30s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # NIG configs + for cfg in NIG_CONFIGS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_nig_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:30s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + return all_results + + +# ====================================================================== +# Plotting +# ====================================================================== + +COLOR_MAP = { + "Random": "#888888", + "UCT (+rpol)": "#1f77b4", + "TS + TS(g,a) + comb": "#98df8a", + "NIG + uniform": "#2ca02c", + "NIG + TS(g,a)": "#d62728", + "NIG + TS(g,a) + comb": "#9467bd", + "NIG + TS(g,a) + comb + apv": "#8c564b", + "NIG + TS(g,a) + vi + apv": "#e377c2", + "NIG + TS(g,a) + pess": "#ff9896", + "NIG + uniform + pess + apv": "#ffbb78", + "NIG + TS(g,a) + comb (a0=2)": "#17becf", +} + + +def plot_convergence( + problem_name: str, + all_results: 
dict, + config_names: list[str], + n_iterations: int, + suffix: str = "", + title_extra: str = "", +): + """Plot mean convergence curves with shaded +/-1 std region.""" + fig, ax = plt.subplots(figsize=(10, 6)) + + for cname in config_names: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, n_iterations + 1) + color = COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=1.5) + ax.fill_between(x, mean - std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + ax.set_ylabel("Best Reward Found", fontsize=12) + title = f"Convergence: {problem_name}" + if title_extra: + title += f" — {title_extra}" + ax.set_title(title, fontsize=14) + ax.legend(fontsize=8, loc="lower right", ncol=2) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fname = f"convergence_nig_{problem_name}" + if suffix: + fname += f"_{suffix}" + path = OUTPUT_DIR / f"{fname}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_convergence_subsets(problem_name, all_results, n_iterations): + """Plot focused subset comparisons for NIG analysis.""" + subsets = { + "nig_vs_normal_ts": { + "configs": [ + "Random", + "UCT (+rpol)", + "TS + TS(g,a) + comb", + "NIG + uniform", + "NIG + TS(g,a)", + "NIG + TS(g,a) + comb", + ], + "title": "NIG vs Normal-TS vs UCT", + }, + "nig_cache_modes": { + "configs": [ + "Random", + "NIG + TS(g,a)", + "NIG + TS(g,a) + comb", + "NIG + TS(g,a) + vi + apv", + "NIG + TS(g,a) + pess", + "NIG + uniform + pess + apv", + ], + "title": "NIG Cache-Hit Modes", + }, + "nig_alpha": { + "configs": [ + "Random", + "TS + TS(g,a) + comb", + "NIG + TS(g,a) + comb", + "NIG + TS(g,a) + comb (a0=2)", + "NIG + TS(g,a) + comb + apv", + ], + "title": "NIG Alpha0 & APV Effect", + }, + } + + for subset_name, spec in 
subsets.items(): + plot_convergence( + problem_name, + all_results, + spec["configs"], + n_iterations, + suffix=subset_name, + title_extra=spec["title"], + ) + + +def plot_summary_bar_chart( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean final best across all problems for each config.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + stds = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + names.append(cname) + means.append(np.mean(finals)) + stds.append(np.std(finals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + xerr=stds, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + capsize=3, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Final Best Reward", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.axvline( + prob.optimal_value, + color="red", + linestyle="--", + alpha=0.5, + label="Optimum", + ) + ax.legend(fontsize=8) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("NIG vs Normal-TS vs UCT: Final Best Reward", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "summary_bar_chart_nig.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_optimum_rate_heatmap( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Heatmap: optimum-finding rate (config x problem).""" + matrix = [] + for cname in config_names: + row = [] + for prob in problems: + key = (prob.name, cname) + if key not in all_results: + row.append(0.0) + continue + results = all_results[key] + rate = sum(r.found_optimum for r in results) 
/ len(results) + row.append(rate) + matrix.append(row) + + matrix = np.array(matrix) + fig, ax = plt.subplots( + figsize=(max(8, len(problems) * 2), max(6, len(config_names) * 0.5)) + ) + im = ax.imshow(matrix, aspect="auto", cmap="YlGn", vmin=0, vmax=1) + ax.set_xticks(range(len(problems))) + ax.set_xticklabels([p.name for p in problems], rotation=30, ha="right", fontsize=9) + ax.set_yticks(range(len(config_names))) + ax.set_yticklabels(config_names, fontsize=9) + + for i in range(len(config_names)): + for j in range(len(problems)): + val = matrix[i, j] + color = "white" if val > 0.6 else "black" + ax.text( + j, i, f"{val:.0%}", ha="center", va="center", fontsize=9, color=color + ) + + ax.set_title("NIG vs Normal-TS vs UCT: Optimum-Finding Rate", fontsize=13) + fig.colorbar(im, ax=ax, label="Rate", shrink=0.8) + fig.tight_layout() + path = OUTPUT_DIR / "optimum_rate_heatmap_nig.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_unique_evals( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean number of unique evaluations per config per problem.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + evals = [r.n_unique_evals for r in results] + names.append(cname) + means.append(np.mean(evals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Unique Evaluations", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.grid(True, axis="x", alpha=0.3) + + 
fig.suptitle( + "NIG vs Normal-TS vs UCT: Unique Selections Evaluated", fontsize=14, y=1.02 + ) + fig.tight_layout() + path = OUTPUT_DIR / "unique_evals_nig.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def save_summary_json(problems, all_results, config_names): + """Save numeric results as JSON for reproducibility.""" + summary = {} + for prob in problems: + prob_summary = {} + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + opt_rates = [r.found_optimum for r in results] + unique_evals = [r.n_unique_evals for r in results] + prob_summary[cname] = { + "mean_best": float(np.mean(finals)), + "std_best": float(np.std(finals)), + "median_best": float(np.median(finals)), + "optimum_rate": float(np.mean(opt_rates)), + "mean_unique_evals": float(np.mean(unique_evals)), + } + summary[prob.name] = prob_summary + + path = OUTPUT_DIR / "results_nig.json" + with open(path, "w") as f: + json.dump(summary, f, indent=2) + print(f"Saved: {path}") + return summary + + +# ====================================================================== +# Main +# ====================================================================== + + +def main(): + print("MCTS Normal-Inverse-Gamma (NIG) Benchmark") + print("=" * 70) + + problems = [ + make_problem_multigroup_interaction(), + make_problem_needle(), + make_problem_mixed(), + make_problem_large_sparse(), + make_problem_graduated(), + make_problem_simple_additive(), + ] + + for p in problems: + print( + f" {p.name}: ~{p.search_space_size:,} combinations, " + f"{p.n_iterations} iters x {p.n_trials} trials" + ) + + t_start = time.time() + all_results = run_benchmark(problems) + total_time = time.time() - t_start + print(f"\nTotal benchmark time: {total_time:.1f}s") + + # Generate plots + print("\nGenerating plots...") + for prob in problems: + plot_convergence(prob.name, 
all_results, ALL_CONFIG_NAMES, prob.n_iterations) + plot_convergence_subsets(prob.name, all_results, prob.n_iterations) + + plot_summary_bar_chart(problems, all_results, ALL_CONFIG_NAMES) + plot_optimum_rate_heatmap(problems, all_results, ALL_CONFIG_NAMES) + plot_unique_evals(problems, all_results, ALL_CONFIG_NAMES) + + # Save numeric results + summary = save_summary_json(problems, all_results, ALL_CONFIG_NAMES) + + # Print summary table + print("\n" + "=" * 90) + print("SUMMARY TABLE") + print("=" * 90) + for prob in problems: + print( + f"\n{prob.name} (search space: ~{prob.search_space_size:,}, " + f"optimum: {prob.optimal_value})" + ) + print( + f" {'Config':<35s} {'Mean Best':>10s} {'+-Std':>8s} " + f"{'Opt Rate':>10s} {'Uniq Evals':>12s}" + ) + print(f" {'-'*75}") + for cname in ALL_CONFIG_NAMES: + d = summary[prob.name].get(cname) + if d is None: + continue + print( + f" {cname:<35s} {d['mean_best']:10.1f} {d['std_best']:8.1f} " + f"{d['optimum_rate']:10.0%} {d['mean_unique_evals']:12.0f}" + ) + + print(f"\nAll outputs saved to: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/benchmark_nig_adaptive.py b/mcts-report/benchmark_nig_adaptive.py new file mode 100644 index 000000000..fbc3c9561 --- /dev/null +++ b/mcts-report/benchmark_nig_adaptive.py @@ -0,0 +1,547 @@ +"""Benchmark: Adaptive Pessimistic Strength for NIG MCTS. + +Tests adaptive pessimistic modes that scale the pessimistic offset by each +node's local exhaustion rate: exhaustion = 1 - (n_obs / n_visits). Fresh +nodes get mild pessimism; exhausted nodes get full pessimism. Zero new +hyperparameters. + +Two new cache-hit modes: + - adaptive_pessimistic: pessimistic pseudo-obs scaled by exhaustion + - adaptive_combined: variance inflation + adaptive pessimistic pseudo-obs + +Configurations tested: + Reference baselines (4): + 1. Random + 2. UCT (+rpol) — best UCT config + 3. NIG + TS(g,a) + vi + apv — best on hard problems + 4. 
NIG + TS(g,a) + comb + apv — best on smooth problems + + Adaptive configs (5): + 5. NIG + TS(g,a) + acomb + apv — adaptive combined + adaptive prior var + 6. NIG + TS(g,a) + acomb — adaptive combined only + 7. NIG + TS(g,a) + apess + apv — adaptive pessimistic + adaptive prior var + 8. NIG + TS(g,a) + apess — adaptive pessimistic only + 9. NIG + uniform + apess + apv — uniform rollout + adaptive pessimistic + apv + +Usage: + python mcts-report/benchmark_nig_adaptive.py +""" + +import json +import sys +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from benchmark import ( + Problem, + make_problem_graduated, + make_problem_large_sparse, + make_problem_mixed, + make_problem_multigroup_interaction, + make_problem_needle, + make_problem_simple_additive, + run_random_baseline, +) +from benchmark_nig import NIGConfig, run_nig_config +from benchmark_ts import UCTConfig, run_uct_config + + +OUTPUT_DIR = Path(__file__).parent + + +# ====================================================================== +# Configurations +# ====================================================================== + +# UCT reference +UCT_REF = UCTConfig( + name="UCT (+rpol)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, +) + +# NIG reference baselines (the two best fixed configs) +NIG_VI_APV = NIGConfig( + name="NIG + TS(g,a) + vi + apv", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, +) + +NIG_COMB_APV = NIGConfig( + name="NIG + TS(g,a) + comb + apv", + rollout_mode="ts_group_action", + cache_hit_mode="combined", + adaptive_prior_var=True, +) + +# Adaptive configs +ADAPTIVE_CONFIGS = [ + NIGConfig( + name="NIG + TS(g,a) + acomb + apv", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_combined", + adaptive_prior_var=True, + ), + NIGConfig( + name="NIG + TS(g,a) + 
acomb", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_combined", + ), + NIGConfig( + name="NIG + TS(g,a) + apess + apv", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_pessimistic", + adaptive_prior_var=True, + ), + NIGConfig( + name="NIG + TS(g,a) + apess", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_pessimistic", + ), + NIGConfig( + name="NIG + uniform + apess + apv", + rollout_mode="uniform", + cache_hit_mode="adaptive_pessimistic", + adaptive_prior_var=True, + ), +] + +NIG_REFS = [NIG_VI_APV, NIG_COMB_APV] + +ALL_CONFIG_NAMES = ( + ["Random", UCT_REF.name] + + [c.name for c in NIG_REFS] + + [c.name for c in ADAPTIVE_CONFIGS] +) + + +# ====================================================================== +# Benchmark runner +# ====================================================================== + + +def run_benchmark(problems: list[Problem]): + """Run all configurations on all problems, collecting results.""" + all_results = {} + + for prob in problems: + print(f"\n{'='*70}") + print(f"Problem: {prob.name}") + print(f" {prob.description}") + print(f" Search space: ~{prob.search_space_size:,} combinations") + print(f" Budget: {prob.n_iterations} iterations x {prob.n_trials} trials") + print(f"{'='*70}") + + # Random baseline + key = (prob.name, "Random") + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_random_baseline(prob, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {'Random':35s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # UCT reference + key = (prob.name, UCT_REF.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_uct_config(prob, UCT_REF, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = 
results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {UCT_REF.name:35s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # NIG references + for cfg in NIG_REFS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_nig_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:35s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # Adaptive configs + for cfg in ADAPTIVE_CONFIGS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_nig_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:35s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + return all_results + + +# ====================================================================== +# Plotting +# ====================================================================== + +COLOR_MAP = { + "Random": "#888888", + "UCT (+rpol)": "#1f77b4", + "NIG + TS(g,a) + vi + apv": "#e377c2", + "NIG + TS(g,a) + comb + apv": "#8c564b", + "NIG + TS(g,a) + acomb + apv": "#d62728", + "NIG + TS(g,a) + acomb": "#ff7f0e", + "NIG + TS(g,a) + apess + apv": "#2ca02c", + "NIG + TS(g,a) + apess": "#98df8a", + "NIG + uniform + apess + apv": "#9467bd", +} + + +def plot_convergence( + problem_name: str, + all_results: dict, + config_names: list[str], + n_iterations: int, + suffix: str = "", + title_extra: str = "", +): + """Plot mean 
convergence curves with shaded +/-1 std region.""" + fig, ax = plt.subplots(figsize=(10, 6)) + + for cname in config_names: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, n_iterations + 1) + color = COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=1.5) + ax.fill_between(x, mean - std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + ax.set_ylabel("Best Reward Found", fontsize=12) + title = f"Convergence: {problem_name}" + if title_extra: + title += f" — {title_extra}" + ax.set_title(title, fontsize=14) + ax.legend(fontsize=8, loc="lower right", ncol=2) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fname = f"convergence_nig_adaptive_{problem_name}" + if suffix: + fname += f"_{suffix}" + path = OUTPUT_DIR / f"{fname}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_convergence_subsets(problem_name, all_results, n_iterations): + """Plot focused subset: adaptive vs fixed cache-hit modes.""" + subsets = { + "adaptive_vs_fixed": { + "configs": [ + "UCT (+rpol)", + "NIG + TS(g,a) + vi + apv", + "NIG + TS(g,a) + comb + apv", + "NIG + TS(g,a) + acomb + apv", + "NIG + TS(g,a) + apess + apv", + ], + "title": "Adaptive vs Fixed Cache-Hit Modes", + }, + } + + for subset_name, spec in subsets.items(): + plot_convergence( + problem_name, + all_results, + spec["configs"], + n_iterations, + suffix=subset_name, + title_extra=spec["title"], + ) + + +def plot_summary_bar_chart( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean final best across all problems for each config.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, 
problems): + names = [] + means = [] + stds = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + names.append(cname) + means.append(np.mean(finals)) + stds.append(np.std(finals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + xerr=stds, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + capsize=3, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Final Best Reward", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.axvline( + prob.optimal_value, + color="red", + linestyle="--", + alpha=0.5, + label="Optimum", + ) + ax.legend(fontsize=8) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("Adaptive Pessimistic: Final Best Reward", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "summary_bar_chart_nig_adaptive.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_optimum_rate_heatmap( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Heatmap: optimum-finding rate (config x problem).""" + matrix = [] + for cname in config_names: + row = [] + for prob in problems: + key = (prob.name, cname) + if key not in all_results: + row.append(0.0) + continue + results = all_results[key] + rate = sum(r.found_optimum for r in results) / len(results) + row.append(rate) + matrix.append(row) + + matrix = np.array(matrix) + fig, ax = plt.subplots( + figsize=(max(8, len(problems) * 2), max(6, len(config_names) * 0.5)) + ) + im = ax.imshow(matrix, aspect="auto", cmap="YlGn", vmin=0, vmax=1) + ax.set_xticks(range(len(problems))) + ax.set_xticklabels([p.name for p in problems], rotation=30, ha="right", fontsize=9) + ax.set_yticks(range(len(config_names))) + ax.set_yticklabels(config_names, fontsize=9) + + for i in 
range(len(config_names)): + for j in range(len(problems)): + val = matrix[i, j] + color = "white" if val > 0.6 else "black" + ax.text( + j, i, f"{val:.0%}", ha="center", va="center", fontsize=9, color=color + ) + + ax.set_title("Adaptive Pessimistic: Optimum-Finding Rate", fontsize=13) + fig.colorbar(im, ax=ax, label="Rate", shrink=0.8) + fig.tight_layout() + path = OUTPUT_DIR / "optimum_rate_heatmap_nig_adaptive.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_unique_evals( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean number of unique evaluations per config per problem.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + evals = [r.n_unique_evals for r in results] + names.append(cname) + means.append(np.mean(evals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Unique Evaluations", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle( + "Adaptive Pessimistic: Unique Selections Evaluated", fontsize=14, y=1.02 + ) + fig.tight_layout() + path = OUTPUT_DIR / "unique_evals_nig_adaptive.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def save_summary_json(problems, all_results, config_names): + """Save numeric results as JSON for reproducibility.""" + summary = {} + for prob in problems: + prob_summary = {} + for cname in config_names: + key = 
(prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + opt_rates = [r.found_optimum for r in results] + unique_evals = [r.n_unique_evals for r in results] + prob_summary[cname] = { + "mean_best": float(np.mean(finals)), + "std_best": float(np.std(finals)), + "median_best": float(np.median(finals)), + "optimum_rate": float(np.mean(opt_rates)), + "mean_unique_evals": float(np.mean(unique_evals)), + } + summary[prob.name] = prob_summary + + path = OUTPUT_DIR / "results_nig_adaptive.json" + with open(path, "w") as f: + json.dump(summary, f, indent=2) + print(f"Saved: {path}") + return summary + + +# ====================================================================== +# Main +# ====================================================================== + + +def main(): + print("MCTS NIG Adaptive Pessimistic Strength Benchmark") + print("=" * 70) + + problems = [ + make_problem_multigroup_interaction(), + make_problem_needle(), + make_problem_mixed(), + make_problem_large_sparse(), + make_problem_graduated(), + make_problem_simple_additive(), + ] + + for p in problems: + print( + f" {p.name}: ~{p.search_space_size:,} combinations, " + f"{p.n_iterations} iters x {p.n_trials} trials" + ) + + t_start = time.time() + all_results = run_benchmark(problems) + total_time = time.time() - t_start + print(f"\nTotal benchmark time: {total_time:.1f}s") + + # Generate plots + print("\nGenerating plots...") + for prob in problems: + plot_convergence(prob.name, all_results, ALL_CONFIG_NAMES, prob.n_iterations) + plot_convergence_subsets(prob.name, all_results, prob.n_iterations) + + plot_summary_bar_chart(problems, all_results, ALL_CONFIG_NAMES) + plot_optimum_rate_heatmap(problems, all_results, ALL_CONFIG_NAMES) + plot_unique_evals(problems, all_results, ALL_CONFIG_NAMES) + + # Save numeric results + summary = save_summary_json(problems, all_results, ALL_CONFIG_NAMES) + + # Print summary table + print("\n" + 
"=" * 95) + print("SUMMARY TABLE") + print("=" * 95) + for prob in problems: + print( + f"\n{prob.name} (search space: ~{prob.search_space_size:,}, " + f"optimum: {prob.optimal_value})" + ) + print( + f" {'Config':<40s} {'Mean Best':>10s} {'+-Std':>8s} " + f"{'Opt Rate':>10s} {'Uniq Evals':>12s}" + ) + print(f" {'-'*80}") + for cname in ALL_CONFIG_NAMES: + d = summary[prob.name].get(cname) + if d is None: + continue + print( + f" {cname:<40s} {d['mean_best']:10.1f} {d['std_best']:8.1f} " + f"{d['optimum_rate']:10.0%} {d['mean_unique_evals']:12.0f}" + ) + + print(f"\nAll outputs saved to: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/benchmark_nig_adaptive_n0.py b/mcts-report/benchmark_nig_adaptive_n0.py new file mode 100644 index 000000000..57066efa8 --- /dev/null +++ b/mcts-report/benchmark_nig_adaptive_n0.py @@ -0,0 +1,546 @@ +"""Benchmark: Adaptive Pseudo-Count n0 from Branching Factor. + +Tests adaptive n0 = 1 + log(branching_factor), where branching_factor is the +number of legal actions at the parent node. Higher branching means each child +is visited rarely during early exploration, so n0 > 1 keeps the posterior +closer to the prior until enough observations accumulate. + +With 2 actions: n0 ~ 1.7. With 11 actions (large_sparse root): n0 ~ 3.4. +With 30 actions: n0 ~ 4.4. Zero new hyperparameters. + +Configurations tested: + Reference baselines (4): + 1. Random + 2. UCT (+rpol) -- best UCT config + 3. NIG + TS(g,a) + vi + apv -- current default + 4. NIG + TS(g,a) + apess -- best on large_sparse + + Adaptive n0 configs (5): + 5. NIG + TS(g,a) + vi + apv + an0 -- variance inflation + adaptive n0 + 6. NIG + TS(g,a) + apess + an0 -- adaptive pessimistic + adaptive n0 + 7. NIG + TS(g,a) + acomb + an0 -- adaptive combined + adaptive n0 + 8. NIG + TS(g,a) + an0 -- no cache-hit update + adaptive n0 + 9. 
NIG + uniform + an0 -- uniform rollout + adaptive n0 + +Usage: + python mcts-report/benchmark_nig_adaptive_n0.py +""" + +import json +import sys +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from benchmark import ( + Problem, + make_problem_graduated, + make_problem_large_sparse, + make_problem_mixed, + make_problem_multigroup_interaction, + make_problem_needle, + make_problem_simple_additive, + run_random_baseline, +) +from benchmark_nig import NIGConfig, run_nig_config +from benchmark_ts import UCTConfig, run_uct_config + + +OUTPUT_DIR = Path(__file__).parent + + +# ====================================================================== +# Configurations +# ====================================================================== + +# UCT reference +UCT_REF = UCTConfig( + name="UCT (+rpol)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, +) + +# NIG reference baselines (best fixed configs) +NIG_VI_APV = NIGConfig( + name="NIG + TS(g,a) + vi + apv", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, +) + +NIG_APESS = NIGConfig( + name="NIG + TS(g,a) + apess", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_pessimistic", +) + +# Adaptive n0 configs +ADAPTIVE_N0_CONFIGS = [ + NIGConfig( + name="NIG + TS(g,a) + vi + apv + an0", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, + adaptive_n0=True, + ), + NIGConfig( + name="NIG + TS(g,a) + apess + an0", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_pessimistic", + adaptive_n0=True, + ), + NIGConfig( + name="NIG + TS(g,a) + acomb + an0", + rollout_mode="ts_group_action", + cache_hit_mode="adaptive_combined", + adaptive_n0=True, + ), + NIGConfig( + name="NIG + TS(g,a) + an0", + rollout_mode="ts_group_action", + cache_hit_mode="no_update", 
+ adaptive_n0=True, + ), + NIGConfig( + name="NIG + uniform + an0", + rollout_mode="uniform", + cache_hit_mode="no_update", + adaptive_n0=True, + ), +] + +NIG_REFS = [NIG_VI_APV, NIG_APESS] + +ALL_CONFIG_NAMES = ( + ["Random", UCT_REF.name] + + [c.name for c in NIG_REFS] + + [c.name for c in ADAPTIVE_N0_CONFIGS] +) + + +# ====================================================================== +# Benchmark runner +# ====================================================================== + + +def run_benchmark(problems: list[Problem]): + """Run all configurations on all problems, collecting results.""" + all_results = {} + + for prob in problems: + print(f"\n{'='*70}") + print(f"Problem: {prob.name}") + print(f" {prob.description}") + print(f" Search space: ~{prob.search_space_size:,} combinations") + print(f" Budget: {prob.n_iterations} iterations x {prob.n_trials} trials") + print(f"{'='*70}") + + # Random baseline + key = (prob.name, "Random") + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_random_baseline(prob, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {'Random':40s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # UCT reference + key = (prob.name, UCT_REF.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_uct_config(prob, UCT_REF, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {UCT_REF.name:40s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # NIG references + for cfg in NIG_REFS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial 
in range(prob.n_trials): + r = run_nig_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:40s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # Adaptive n0 configs + for cfg in ADAPTIVE_N0_CONFIGS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_nig_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:40s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + return all_results + + +# ====================================================================== +# Plotting +# ====================================================================== + +COLOR_MAP = { + "Random": "#888888", + "UCT (+rpol)": "#1f77b4", + "NIG + TS(g,a) + vi + apv": "#e377c2", + "NIG + TS(g,a) + apess": "#98df8a", + "NIG + TS(g,a) + vi + apv + an0": "#d62728", + "NIG + TS(g,a) + apess + an0": "#2ca02c", + "NIG + TS(g,a) + acomb + an0": "#ff7f0e", + "NIG + TS(g,a) + an0": "#9467bd", + "NIG + uniform + an0": "#17becf", +} + + +def plot_convergence( + problem_name: str, + all_results: dict, + config_names: list[str], + n_iterations: int, + suffix: str = "", + title_extra: str = "", +): + """Plot mean convergence curves with shaded +/-1 std region.""" + fig, ax = plt.subplots(figsize=(10, 6)) + + for cname in config_names: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, 
n_iterations + 1) + color = COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=1.5) + ax.fill_between(x, mean - std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + ax.set_ylabel("Best Reward Found", fontsize=12) + title = f"Convergence: {problem_name}" + if title_extra: + title += f" — {title_extra}" + ax.set_title(title, fontsize=14) + ax.legend(fontsize=8, loc="lower right", ncol=2) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fname = f"convergence_nig_adaptive_n0_{problem_name}" + if suffix: + fname += f"_{suffix}" + path = OUTPUT_DIR / f"{fname}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_convergence_subsets(problem_name, all_results, n_iterations): + """Plot focused subset: adaptive n0 vs fixed n0 for matched configs.""" + subsets = { + "n0_effect": { + "configs": [ + "UCT (+rpol)", + "NIG + TS(g,a) + vi + apv", + "NIG + TS(g,a) + vi + apv + an0", + "NIG + TS(g,a) + apess", + "NIG + TS(g,a) + apess + an0", + ], + "title": "Adaptive n0 vs Fixed n0", + }, + } + + for subset_name, spec in subsets.items(): + plot_convergence( + problem_name, + all_results, + spec["configs"], + n_iterations, + suffix=subset_name, + title_extra=spec["title"], + ) + + +def plot_summary_bar_chart( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean final best across all problems for each config.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + stds = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + names.append(cname) + means.append(np.mean(finals)) + stds.append(np.std(finals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + 
ax.barh( + range(len(names)), + means, + xerr=stds, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + capsize=3, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Final Best Reward", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.axvline( + prob.optimal_value, + color="red", + linestyle="--", + alpha=0.5, + label="Optimum", + ) + ax.legend(fontsize=8) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("Adaptive n0: Final Best Reward", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "summary_bar_chart_nig_adaptive_n0.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_optimum_rate_heatmap( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Heatmap: optimum-finding rate (config x problem).""" + matrix = [] + for cname in config_names: + row = [] + for prob in problems: + key = (prob.name, cname) + if key not in all_results: + row.append(0.0) + continue + results = all_results[key] + rate = sum(r.found_optimum for r in results) / len(results) + row.append(rate) + matrix.append(row) + + matrix = np.array(matrix) + fig, ax = plt.subplots( + figsize=(max(8, len(problems) * 2), max(6, len(config_names) * 0.5)) + ) + im = ax.imshow(matrix, aspect="auto", cmap="YlGn", vmin=0, vmax=1) + ax.set_xticks(range(len(problems))) + ax.set_xticklabels([p.name for p in problems], rotation=30, ha="right", fontsize=9) + ax.set_yticks(range(len(config_names))) + ax.set_yticklabels(config_names, fontsize=9) + + for i in range(len(config_names)): + for j in range(len(problems)): + val = matrix[i, j] + color = "white" if val > 0.6 else "black" + ax.text( + j, i, f"{val:.0%}", ha="center", va="center", fontsize=9, color=color + ) + + ax.set_title("Adaptive n0: Optimum-Finding Rate", fontsize=13) + fig.colorbar(im, ax=ax, label="Rate", shrink=0.8) + fig.tight_layout() + path = OUTPUT_DIR / 
"optimum_rate_heatmap_nig_adaptive_n0.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_unique_evals( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean number of unique evaluations per config per problem.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + evals = [r.n_unique_evals for r in results] + names.append(cname) + means.append(np.mean(evals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Unique Evaluations", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("Adaptive n0: Unique Selections Evaluated", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "unique_evals_nig_adaptive_n0.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def save_summary_json(problems, all_results, config_names): + """Save numeric results as JSON for reproducibility.""" + summary = {} + for prob in problems: + prob_summary = {} + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + opt_rates = [r.found_optimum for r in results] + unique_evals = [r.n_unique_evals for r in results] + prob_summary[cname] = { + "mean_best": float(np.mean(finals)), + "std_best": float(np.std(finals)), + "median_best": float(np.median(finals)), + "optimum_rate": 
float(np.mean(opt_rates)), + "mean_unique_evals": float(np.mean(unique_evals)), + } + summary[prob.name] = prob_summary + + path = OUTPUT_DIR / "results_nig_adaptive_n0.json" + with open(path, "w") as f: + json.dump(summary, f, indent=2) + print(f"Saved: {path}") + return summary + + +# ====================================================================== +# Main +# ====================================================================== + + +def main(): + print("MCTS NIG Adaptive Pseudo-Count n0 Benchmark") + print("=" * 70) + + problems = [ + make_problem_multigroup_interaction(), + make_problem_needle(), + make_problem_mixed(), + make_problem_large_sparse(), + make_problem_graduated(), + make_problem_simple_additive(), + ] + + for p in problems: + print( + f" {p.name}: ~{p.search_space_size:,} combinations, " + f"{p.n_iterations} iters x {p.n_trials} trials" + ) + + t_start = time.time() + all_results = run_benchmark(problems) + total_time = time.time() - t_start + print(f"\nTotal benchmark time: {total_time:.1f}s") + + # Generate plots + print("\nGenerating plots...") + for prob in problems: + plot_convergence(prob.name, all_results, ALL_CONFIG_NAMES, prob.n_iterations) + plot_convergence_subsets(prob.name, all_results, prob.n_iterations) + + plot_summary_bar_chart(problems, all_results, ALL_CONFIG_NAMES) + plot_optimum_rate_heatmap(problems, all_results, ALL_CONFIG_NAMES) + plot_unique_evals(problems, all_results, ALL_CONFIG_NAMES) + + # Save numeric results + summary = save_summary_json(problems, all_results, ALL_CONFIG_NAMES) + + # Print summary table + print("\n" + "=" * 95) + print("SUMMARY TABLE") + print("=" * 95) + for prob in problems: + print( + f"\n{prob.name} (search space: ~{prob.search_space_size:,}, " + f"optimum: {prob.optimal_value})" + ) + print( + f" {'Config':<45s} {'Mean Best':>10s} {'+-Std':>8s} " + f"{'Opt Rate':>10s} {'Uniq Evals':>12s}" + ) + print(f" {'-'*85}") + for cname in ALL_CONFIG_NAMES: + d = summary[prob.name].get(cname) + if d 
is None: + continue + print( + f" {cname:<45s} {d['mean_best']:10.1f} {d['std_best']:8.1f} " + f"{d['optimum_rate']:10.0%} {d['mean_unique_evals']:12.0f}" + ) + + print(f"\nAll outputs saved to: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/benchmark_ts.py b/mcts-report/benchmark_ts.py new file mode 100644 index 000000000..b7d7d3d0c --- /dev/null +++ b/mcts-report/benchmark_ts.py @@ -0,0 +1,740 @@ +"""Benchmark: Thompson Sampling MCTS vs UCT MCTS. + +Compares TS-based MCTS variants against the best UCT configurations +on 6 combinatorial NChooseK benchmark problems. + +Configurations tested: + UCT references: + 1. UCT (+rpol) — c_uct=0.01, no RAVE, adaptive p_stop, norm, rollout policy + 2. UCT (no rpol) — same but no rollout policy + TS variants: + 1. TS + uniform — TS tree, uniform rollout + 2. TS + TS(g,a) — TS tree + TS rollout keyed by (group, action) + 3. TS + TS(g,a) + var_infl — same + variance inflation on cache hits + 4. TS + TS(g,c,a) — TS tree + TS rollout keyed by (group, cardinality, action) + 5. TS + TS(g,c,a) + var_infl — same + variance inflation + 6. 
TS + softmax rpol — TS tree + existing softmax rollout (hybrid) + +Usage: + python mcts-report/benchmark_ts.py +""" + +import json +import sys +import time +from dataclasses import dataclass +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np + + +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from benchmark import ( + Problem, + ProblemResult, + make_problem_graduated, + make_problem_large_sparse, + make_problem_mixed, + make_problem_multigroup_interaction, + make_problem_needle, + make_problem_simple_additive, + run_random_baseline, +) +from optimize_mcts_full import MCTS +from optimize_mcts_ts import MCTS_TS + + +OUTPUT_DIR = Path(__file__).parent + + +# ═══════════════════════════════════════════════════════════════════════ +# Configurations +# ═══════════════════════════════════════════════════════════════════════ + + +@dataclass +class TSConfig: + """Configuration for MCTS_TS benchmarking.""" + + name: str + ts_prior_var: float = 1.0 + adaptive_prior_var: bool = False + cache_hit_mode: str = "no_update" + variance_decay: float = 0.95 + rollout_mode: str = "uniform" + pw_k0: float = 2.0 + pw_alpha: float = 0.6 + # Softmax fallback params + rollout_epsilon: float = 0.3 + rollout_tau: float = 1.0 + rollout_novelty_weight: float = 1.0 + normalize_rewards: bool = True + adaptive_p_stop: bool = True + p_stop_rollout: float = 0.35 + p_stop_warmup: int = 20 + p_stop_temperature: float = 0.25 + + +@dataclass +class UCTConfig: + """Configuration for UCT MCTS (reference baselines).""" + + name: str + c_uct: float = 0.01 + k_rave: float = 0.0 + adaptive_p_stop: bool = True + normalize_rewards: bool = True + rollout_policy: bool = True + rollout_epsilon: float = 0.3 + rollout_tau: float = 1.0 + rollout_novelty_weight: float = 1.0 + p_stop_rollout: float = 0.35 + p_stop_warmup: int = 20 + p_stop_temperature: float = 0.25 + + +# UCT reference configs +UCT_CONFIGS = [ + UCTConfig( + name="UCT (+rpol)", + c_uct=0.01, + k_rave=0, + 
adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=True, + ), + UCTConfig( + name="UCT (no rpol)", + c_uct=0.01, + k_rave=0, + adaptive_p_stop=True, + normalize_rewards=True, + rollout_policy=False, + ), +] + +# TS configs +TS_CONFIGS = [ + TSConfig( + name="TS + uniform", + rollout_mode="uniform", + ), + TSConfig( + name="TS + TS(g,a)", + rollout_mode="ts_group_action", + ), + TSConfig( + name="TS + TS(g,a) + var_infl", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + ), + TSConfig( + name="TS + TS(g,c,a)", + rollout_mode="ts_group_card_action", + ), + TSConfig( + name="TS + TS(g,c,a) + var_infl", + rollout_mode="ts_group_card_action", + cache_hit_mode="variance_inflation", + ), + TSConfig( + name="TS + softmax rpol", + rollout_mode="softmax", + adaptive_p_stop=True, + normalize_rewards=True, + ), + # Adaptive prior variance configs + TSConfig( + name="TS + uniform + adpt_pv", + rollout_mode="uniform", + adaptive_prior_var=True, + ), + TSConfig( + name="TS + TS(g,a) + adpt_pv", + rollout_mode="ts_group_action", + adaptive_prior_var=True, + ), + TSConfig( + name="TS + TS(g,a) + vi + apv", + rollout_mode="ts_group_action", + cache_hit_mode="variance_inflation", + adaptive_prior_var=True, + ), + # Pessimistic pseudo-observation configs + TSConfig( + name="TS + TS(g,a) + pess", + rollout_mode="ts_group_action", + cache_hit_mode="pessimistic", + ), + TSConfig( + name="TS + TS(g,a) + pess + apv", + rollout_mode="ts_group_action", + cache_hit_mode="pessimistic", + adaptive_prior_var=True, + ), + TSConfig( + name="TS + uniform + pess + apv", + rollout_mode="uniform", + cache_hit_mode="pessimistic", + adaptive_prior_var=True, + ), + # Combined (variance inflation + pessimistic) configs + TSConfig( + name="TS + TS(g,a) + comb", + rollout_mode="ts_group_action", + cache_hit_mode="combined", + ), + TSConfig( + name="TS + TS(g,a) + comb + apv", + rollout_mode="ts_group_action", + cache_hit_mode="combined", + adaptive_prior_var=True, + ), 
+] + +ALL_CONFIG_NAMES = ( + ["Random"] + [c.name for c in UCT_CONFIGS] + [c.name for c in TS_CONFIGS] +) + + +# ═══════════════════════════════════════════════════════════════════════ +# Run functions +# ═══════════════════════════════════════════════════════════════════════ + + +def run_uct_config(problem: Problem, config: UCTConfig, seed: int) -> ProblemResult: + """Run UCT MCTS with given config, tracking best value per iteration.""" + mcts = MCTS( + groups=problem.groups, + reward_fn=problem.reward_fn, + c_uct=config.c_uct, + k_rave=config.k_rave, + adaptive_p_stop=config.adaptive_p_stop, + normalize_rewards=config.normalize_rewards, + rollout_policy=config.rollout_policy, + rollout_epsilon=config.rollout_epsilon, + rollout_tau=config.rollout_tau, + rollout_novelty_weight=config.rollout_novelty_weight, + p_stop_rollout=config.p_stop_rollout, + p_stop_warmup=config.p_stop_warmup, + p_stop_temperature=config.p_stop_temperature, + seed=seed, + ) + + best_values = [] + for _ in range(problem.n_iterations): + mcts.run(n_iterations=1) + best_values.append(mcts.best_value) + + stats = mcts.cache_stats() + return ProblemResult( + best_values_over_time=best_values, + final_best=mcts.best_value, + found_optimum=abs(mcts.best_value - problem.optimal_value) < 1e-6, + n_unique_evals=stats["misses"], + ) + + +def run_ts_config(problem: Problem, config: TSConfig, seed: int) -> ProblemResult: + """Run TS MCTS with given config, tracking best value per iteration.""" + mcts = MCTS_TS( + groups=problem.groups, + reward_fn=problem.reward_fn, + ts_prior_var=config.ts_prior_var, + adaptive_prior_var=config.adaptive_prior_var, + cache_hit_mode=config.cache_hit_mode, + variance_decay=config.variance_decay, + rollout_mode=config.rollout_mode, + pw_k0=config.pw_k0, + pw_alpha=config.pw_alpha, + seed=seed, + rollout_epsilon=config.rollout_epsilon, + rollout_tau=config.rollout_tau, + rollout_novelty_weight=config.rollout_novelty_weight, + normalize_rewards=config.normalize_rewards, + 
adaptive_p_stop=config.adaptive_p_stop, + p_stop_rollout=config.p_stop_rollout, + p_stop_warmup=config.p_stop_warmup, + p_stop_temperature=config.p_stop_temperature, + ) + + best_values = [] + for _ in range(problem.n_iterations): + mcts.run(n_iterations=1) + best_values.append(mcts.best_value) + + stats = mcts.cache_stats() + return ProblemResult( + best_values_over_time=best_values, + final_best=mcts.best_value, + found_optimum=abs(mcts.best_value - problem.optimal_value) < 1e-6, + n_unique_evals=stats["misses"], + ) + + +# ═══════════════════════════════════════════════════════════════════════ +# Benchmark runner +# ═══════════════════════════════════════════════════════════════════════ + + +def run_benchmark(problems: list[Problem]): + """Run all configurations on all problems, collecting results.""" + all_results = {} + + for prob in problems: + print(f"\n{'='*70}") + print(f"Problem: {prob.name}") + print(f" {prob.description}") + print(f" Search space: ~{prob.search_space_size:,} combinations") + print(f" Budget: {prob.n_iterations} iterations x {prob.n_trials} trials") + print(f"{'='*70}") + + # Random baseline + key = (prob.name, "Random") + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_random_baseline(prob, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {'Random':25s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # UCT configs + for cfg in UCT_CONFIGS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_uct_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( 
+ f" {cfg.name:25s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + # TS configs + for cfg in TS_CONFIGS: + key = (prob.name, cfg.name) + results = [] + t0 = time.time() + for trial in range(prob.n_trials): + r = run_ts_config(prob, cfg, seed=trial) + results.append(r) + elapsed = time.time() - t0 + all_results[key] = results + success_rate = sum(r.found_optimum for r in results) / prob.n_trials + mean_best = np.mean([r.final_best for r in results]) + print( + f" {cfg.name:25s} | best={mean_best:7.1f} | " + f"opt_rate={success_rate:.0%} | {elapsed:.1f}s" + ) + + return all_results + + +# ═══════════════════════════════════════════════════════════════════════ +# Plotting +# ═══════════════════════════════════════════════════════════════════════ + +COLOR_MAP = { + "Random": "#888888", + "UCT (+rpol)": "#1f77b4", + "UCT (no rpol)": "#ff7f0e", + "TS + uniform": "#2ca02c", + "TS + TS(g,a)": "#d62728", + "TS + TS(g,a) + var_infl": "#9467bd", + "TS + TS(g,c,a)": "#8c564b", + "TS + TS(g,c,a) + var_infl": "#e377c2", + "TS + softmax rpol": "#17becf", + "TS + uniform + adpt_pv": "#bcbd22", + "TS + TS(g,a) + adpt_pv": "#7f7f7f", + "TS + TS(g,a) + vi + apv": "#aec7e8", + "TS + TS(g,a) + pess": "#ff9896", + "TS + TS(g,a) + pess + apv": "#c5b0d5", + "TS + uniform + pess + apv": "#ffbb78", + "TS + TS(g,a) + comb": "#98df8a", + "TS + TS(g,a) + comb + apv": "#c49c94", +} + + +def plot_convergence( + problem_name: str, + all_results: dict, + config_names: list[str], + n_iterations: int, + suffix: str = "", + title_extra: str = "", +): + """Plot mean convergence curves with shaded +/-1 std region.""" + fig, ax = plt.subplots(figsize=(10, 6)) + + for cname in config_names: + key = (problem_name, cname) + if key not in all_results: + continue + results = all_results[key] + curves = np.array([r.best_values_over_time for r in results]) + mean = curves.mean(axis=0) + std = curves.std(axis=0) + x = np.arange(1, n_iterations + 1) + color = 
COLOR_MAP.get(cname, None) + ax.plot(x, mean, label=cname, color=color, linewidth=1.5) + ax.fill_between(x, mean - std, mean + std, alpha=0.12, color=color) + + ax.set_xlabel("Iteration", fontsize=12) + ax.set_ylabel("Best Reward Found", fontsize=12) + title = f"Convergence: {problem_name}" + if title_extra: + title += f" — {title_extra}" + ax.set_title(title, fontsize=14) + ax.legend(fontsize=8, loc="lower right", ncol=2) + ax.grid(True, alpha=0.3) + fig.tight_layout() + fname = f"convergence_ts_{problem_name}" + if suffix: + fname += f"_{suffix}" + path = OUTPUT_DIR / f"{fname}.png" + fig.savefig(path, dpi=150) + plt.close(fig) + print(f" Saved: {path}") + + +def plot_convergence_subsets(problem_name, all_results, n_iterations): + """Plot focused subset comparisons: TS vs UCT, rollout modes, variance inflation.""" + subsets = { + "ts_vs_uct": { + "configs": [ + "Random", + "UCT (+rpol)", + "UCT (no rpol)", + "TS + uniform", + "TS + TS(g,a)", + "TS + softmax rpol", + ], + "title": "TS vs UCT", + }, + "ts_rollout_modes": { + "configs": [ + "Random", + "TS + uniform", + "TS + TS(g,a)", + "TS + TS(g,c,a)", + "TS + softmax rpol", + ], + "title": "TS Rollout Modes", + }, + "variance_inflation": { + "configs": [ + "Random", + "TS + TS(g,a)", + "TS + TS(g,a) + var_infl", + "TS + TS(g,c,a)", + "TS + TS(g,c,a) + var_infl", + ], + "title": "Variance Inflation Effect", + }, + "adaptive_prior_var": { + "configs": [ + "Random", + "UCT (+rpol)", + "TS + uniform", + "TS + uniform + adpt_pv", + "TS + TS(g,a) + var_infl", + "TS + TS(g,a) + vi + apv", + ], + "title": "Adaptive Prior Variance Effect", + }, + "pessimistic": { + "configs": [ + "Random", + "UCT (+rpol)", + "TS + TS(g,a) + var_infl", + "TS + TS(g,a) + vi + apv", + "TS + TS(g,a) + pess", + "TS + TS(g,a) + pess + apv", + ], + "title": "Pessimistic vs Variance Inflation", + }, + "combined": { + "configs": [ + "Random", + "UCT (+rpol)", + "TS + TS(g,a) + vi + apv", + "TS + TS(g,a) + pess", + "TS + TS(g,a) + comb", + "TS + 
TS(g,a) + comb + apv", + ], + "title": "Combined (VI + Pessimistic)", + }, + } + + for subset_name, spec in subsets.items(): + plot_convergence( + problem_name, + all_results, + spec["configs"], + n_iterations, + suffix=subset_name, + title_extra=spec["title"], + ) + + +def plot_summary_bar_chart( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean final best across all problems for each config.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + stds = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + names.append(cname) + means.append(np.mean(finals)) + stds.append(np.std(finals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + xerr=stds, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + capsize=3, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Final Best Reward", fontsize=10) + ax.set_title(prob.name, fontsize=11) + ax.axvline( + prob.optimal_value, + color="red", + linestyle="--", + alpha=0.5, + label="Optimum", + ) + ax.legend(fontsize=8) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("TS vs UCT: Final Best Reward", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "summary_bar_chart_ts.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_optimum_rate_heatmap( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Heatmap: optimum-finding rate (config x problem).""" + matrix = [] + for cname in config_names: + row = [] + for prob in problems: + key = (prob.name, cname) + if key not in all_results: + 
row.append(0.0) + continue + results = all_results[key] + rate = sum(r.found_optimum for r in results) / len(results) + row.append(rate) + matrix.append(row) + + matrix = np.array(matrix) + fig, ax = plt.subplots( + figsize=(max(8, len(problems) * 2), max(6, len(config_names) * 0.5)) + ) + im = ax.imshow(matrix, aspect="auto", cmap="YlGn", vmin=0, vmax=1) + ax.set_xticks(range(len(problems))) + ax.set_xticklabels([p.name for p in problems], rotation=30, ha="right", fontsize=9) + ax.set_yticks(range(len(config_names))) + ax.set_yticklabels(config_names, fontsize=9) + + for i in range(len(config_names)): + for j in range(len(problems)): + val = matrix[i, j] + color = "white" if val > 0.6 else "black" + ax.text( + j, i, f"{val:.0%}", ha="center", va="center", fontsize=9, color=color + ) + + ax.set_title("TS vs UCT: Optimum-Finding Rate", fontsize=13) + fig.colorbar(im, ax=ax, label="Rate", shrink=0.8) + fig.tight_layout() + path = OUTPUT_DIR / "optimum_rate_heatmap_ts.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def plot_unique_evals( + problems: list[Problem], all_results: dict, config_names: list[str] +): + """Bar chart: mean number of unique evaluations per config per problem.""" + fig, axes = plt.subplots( + 1, len(problems), figsize=(5 * len(problems), 6), sharey=False + ) + if len(problems) == 1: + axes = [axes] + + for ax, prob in zip(axes, problems): + names = [] + means = [] + colors = [] + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + evals = [r.n_unique_evals for r in results] + names.append(cname) + means.append(np.mean(evals)) + colors.append(COLOR_MAP.get(cname, "#333333")) + + ax.barh( + range(len(names)), + means, + color=colors, + alpha=0.85, + edgecolor="white", + linewidth=0.5, + ) + ax.set_yticks(range(len(names))) + ax.set_yticklabels(names, fontsize=8) + ax.set_xlabel("Mean Unique Evaluations", fontsize=10) + 
ax.set_title(prob.name, fontsize=11) + ax.grid(True, axis="x", alpha=0.3) + + fig.suptitle("TS vs UCT: Unique Selections Evaluated", fontsize=14, y=1.02) + fig.tight_layout() + path = OUTPUT_DIR / "unique_evals_ts.png" + fig.savefig(path, dpi=150, bbox_inches="tight") + plt.close(fig) + print(f"Saved: {path}") + + +def save_summary_json(problems, all_results, config_names): + """Save numeric results as JSON for reproducibility.""" + summary = {} + for prob in problems: + prob_summary = {} + for cname in config_names: + key = (prob.name, cname) + if key not in all_results: + continue + results = all_results[key] + finals = [r.final_best for r in results] + opt_rates = [r.found_optimum for r in results] + unique_evals = [r.n_unique_evals for r in results] + prob_summary[cname] = { + "mean_best": float(np.mean(finals)), + "std_best": float(np.std(finals)), + "median_best": float(np.median(finals)), + "optimum_rate": float(np.mean(opt_rates)), + "mean_unique_evals": float(np.mean(unique_evals)), + } + summary[prob.name] = prob_summary + + path = OUTPUT_DIR / "results_ts.json" + with open(path, "w") as f: + json.dump(summary, f, indent=2) + print(f"Saved: {path}") + return summary + + +# ═══════════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════════ + + +def main(): + print("MCTS Thompson Sampling Benchmark") + print("=" * 70) + + problems = [ + make_problem_multigroup_interaction(), + make_problem_needle(), + make_problem_mixed(), + make_problem_large_sparse(), + make_problem_graduated(), + make_problem_simple_additive(), + ] + + for p in problems: + print( + f" {p.name}: ~{p.search_space_size:,} combinations, " + f"{p.n_iterations} iters x {p.n_trials} trials" + ) + + t_start = time.time() + all_results = run_benchmark(problems) + total_time = time.time() - t_start + print(f"\nTotal benchmark time: {total_time:.1f}s") + + # Generate plots + print("\nGenerating plots...") + for 
prob in problems: + plot_convergence(prob.name, all_results, ALL_CONFIG_NAMES, prob.n_iterations) + plot_convergence_subsets(prob.name, all_results, prob.n_iterations) + + plot_summary_bar_chart(problems, all_results, ALL_CONFIG_NAMES) + plot_optimum_rate_heatmap(problems, all_results, ALL_CONFIG_NAMES) + plot_unique_evals(problems, all_results, ALL_CONFIG_NAMES) + + # Save numeric results + summary = save_summary_json(problems, all_results, ALL_CONFIG_NAMES) + + # Print summary table + print("\n" + "=" * 90) + print("SUMMARY TABLE") + print("=" * 90) + for prob in problems: + print( + f"\n{prob.name} (search space: ~{prob.search_space_size:,}, " + f"optimum: {prob.optimal_value})" + ) + print( + f" {'Config':<30s} {'Mean Best':>10s} {'+-Std':>8s} " + f"{'Opt Rate':>10s} {'Uniq Evals':>12s}" + ) + print(f" {'-'*70}") + for cname in ALL_CONFIG_NAMES: + d = summary[prob.name].get(cname) + if d is None: + continue + print( + f" {cname:<30s} {d['mean_best']:10.1f} {d['std_best']:8.1f} " + f"{d['optimum_rate']:10.0%} {d['mean_unique_evals']:12.0f}" + ) + + print(f"\nAll outputs saved to: {OUTPUT_DIR}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/convergence_graduated_landscape.png b/mcts-report/convergence_graduated_landscape.png new file mode 100644 index 000000000..16001c818 Binary files /dev/null and b/mcts-report/convergence_graduated_landscape.png differ diff --git a/mcts-report/convergence_graduated_landscape_crave.png b/mcts-report/convergence_graduated_landscape_crave.png new file mode 100644 index 000000000..34b60e53f Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_crave.png differ diff --git a/mcts-report/convergence_graduated_landscape_exploration.png b/mcts-report/convergence_graduated_landscape_exploration.png new file mode 100644 index 000000000..b9fa89a67 Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_exploration.png differ diff --git 
a/mcts-report/convergence_graduated_landscape_p_stop.png b/mcts-report/convergence_graduated_landscape_p_stop.png new file mode 100644 index 000000000..2b6bf24c7 Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_p_stop.png differ diff --git a/mcts-report/convergence_graduated_landscape_pw_effect.png b/mcts-report/convergence_graduated_landscape_pw_effect.png new file mode 100644 index 000000000..3071d844f Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_pw_effect.png differ diff --git a/mcts-report/convergence_graduated_landscape_rave_effect.png b/mcts-report/convergence_graduated_landscape_rave_effect.png new file mode 100644 index 000000000..3c1fde379 Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_rave_effect.png differ diff --git a/mcts-report/convergence_graduated_landscape_rollout.png b/mcts-report/convergence_graduated_landscape_rollout.png new file mode 100644 index 000000000..d46e27717 Binary files /dev/null and b/mcts-report/convergence_graduated_landscape_rollout.png differ diff --git a/mcts-report/convergence_large_sparse.png b/mcts-report/convergence_large_sparse.png new file mode 100644 index 000000000..3d33ab1cb Binary files /dev/null and b/mcts-report/convergence_large_sparse.png differ diff --git a/mcts-report/convergence_large_sparse_crave.png b/mcts-report/convergence_large_sparse_crave.png new file mode 100644 index 000000000..77dc00e81 Binary files /dev/null and b/mcts-report/convergence_large_sparse_crave.png differ diff --git a/mcts-report/convergence_large_sparse_exploration.png b/mcts-report/convergence_large_sparse_exploration.png new file mode 100644 index 000000000..b640848de Binary files /dev/null and b/mcts-report/convergence_large_sparse_exploration.png differ diff --git a/mcts-report/convergence_large_sparse_p_stop.png b/mcts-report/convergence_large_sparse_p_stop.png new file mode 100644 index 000000000..b305d8bec Binary files /dev/null and 
b/mcts-report/convergence_large_sparse_p_stop.png differ diff --git a/mcts-report/convergence_large_sparse_pw_effect.png b/mcts-report/convergence_large_sparse_pw_effect.png new file mode 100644 index 000000000..c5240ae17 Binary files /dev/null and b/mcts-report/convergence_large_sparse_pw_effect.png differ diff --git a/mcts-report/convergence_large_sparse_rave_effect.png b/mcts-report/convergence_large_sparse_rave_effect.png new file mode 100644 index 000000000..a3a612f1d Binary files /dev/null and b/mcts-report/convergence_large_sparse_rave_effect.png differ diff --git a/mcts-report/convergence_large_sparse_rollout.png b/mcts-report/convergence_large_sparse_rollout.png new file mode 100644 index 000000000..9a2259fcf Binary files /dev/null and b/mcts-report/convergence_large_sparse_rollout.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical.png b/mcts-report/convergence_mixed_nchoosek_categorical.png new file mode 100644 index 000000000..d7dd22048 Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical_crave.png b/mcts-report/convergence_mixed_nchoosek_categorical_crave.png new file mode 100644 index 000000000..1953134b2 Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_crave.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical_exploration.png b/mcts-report/convergence_mixed_nchoosek_categorical_exploration.png new file mode 100644 index 000000000..4e1127194 Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_exploration.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical_p_stop.png b/mcts-report/convergence_mixed_nchoosek_categorical_p_stop.png new file mode 100644 index 000000000..2bcd2216f Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_p_stop.png differ diff --git 
a/mcts-report/convergence_mixed_nchoosek_categorical_pw_effect.png b/mcts-report/convergence_mixed_nchoosek_categorical_pw_effect.png new file mode 100644 index 000000000..40dfe171b Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_pw_effect.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical_rave_effect.png b/mcts-report/convergence_mixed_nchoosek_categorical_rave_effect.png new file mode 100644 index 000000000..d9ba37577 Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_rave_effect.png differ diff --git a/mcts-report/convergence_mixed_nchoosek_categorical_rollout.png b/mcts-report/convergence_mixed_nchoosek_categorical_rollout.png new file mode 100644 index 000000000..b7ae9a780 Binary files /dev/null and b/mcts-report/convergence_mixed_nchoosek_categorical_rollout.png differ diff --git a/mcts-report/convergence_multigroup_interaction.png b/mcts-report/convergence_multigroup_interaction.png new file mode 100644 index 000000000..19144caa2 Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction.png differ diff --git a/mcts-report/convergence_multigroup_interaction_crave.png b/mcts-report/convergence_multigroup_interaction_crave.png new file mode 100644 index 000000000..7a06f56af Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_crave.png differ diff --git a/mcts-report/convergence_multigroup_interaction_exploration.png b/mcts-report/convergence_multigroup_interaction_exploration.png new file mode 100644 index 000000000..957484dcb Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_exploration.png differ diff --git a/mcts-report/convergence_multigroup_interaction_p_stop.png b/mcts-report/convergence_multigroup_interaction_p_stop.png new file mode 100644 index 000000000..0ec1ebd60 Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_p_stop.png differ diff --git 
a/mcts-report/convergence_multigroup_interaction_pw_effect.png b/mcts-report/convergence_multigroup_interaction_pw_effect.png new file mode 100644 index 000000000..f004c3d30 Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_pw_effect.png differ diff --git a/mcts-report/convergence_multigroup_interaction_rave_effect.png b/mcts-report/convergence_multigroup_interaction_rave_effect.png new file mode 100644 index 000000000..2a76c99a0 Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_rave_effect.png differ diff --git a/mcts-report/convergence_multigroup_interaction_rollout.png b/mcts-report/convergence_multigroup_interaction_rollout.png new file mode 100644 index 000000000..6df577653 Binary files /dev/null and b/mcts-report/convergence_multigroup_interaction_rollout.png differ diff --git a/mcts-report/convergence_needle_in_haystack.png b/mcts-report/convergence_needle_in_haystack.png new file mode 100644 index 000000000..3fac7a824 Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack.png differ diff --git a/mcts-report/convergence_needle_in_haystack_crave.png b/mcts-report/convergence_needle_in_haystack_crave.png new file mode 100644 index 000000000..538f59e3f Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_crave.png differ diff --git a/mcts-report/convergence_needle_in_haystack_exploration.png b/mcts-report/convergence_needle_in_haystack_exploration.png new file mode 100644 index 000000000..cc377170c Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_exploration.png differ diff --git a/mcts-report/convergence_needle_in_haystack_p_stop.png b/mcts-report/convergence_needle_in_haystack_p_stop.png new file mode 100644 index 000000000..ee7796077 Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_p_stop.png differ diff --git a/mcts-report/convergence_needle_in_haystack_pw_effect.png b/mcts-report/convergence_needle_in_haystack_pw_effect.png 
new file mode 100644 index 000000000..9c3ccfe9d Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_pw_effect.png differ diff --git a/mcts-report/convergence_needle_in_haystack_rave_effect.png b/mcts-report/convergence_needle_in_haystack_rave_effect.png new file mode 100644 index 000000000..49463141f Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_rave_effect.png differ diff --git a/mcts-report/convergence_needle_in_haystack_rollout.png b/mcts-report/convergence_needle_in_haystack_rollout.png new file mode 100644 index 000000000..4538dc789 Binary files /dev/null and b/mcts-report/convergence_needle_in_haystack_rollout.png differ diff --git a/mcts-report/convergence_nig_adaptive_graduated_landscape.png b/mcts-report/convergence_nig_adaptive_graduated_landscape.png new file mode 100644 index 000000000..1e51694d1 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_graduated_landscape.png differ diff --git a/mcts-report/convergence_nig_adaptive_graduated_landscape_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_graduated_landscape_adaptive_vs_fixed.png new file mode 100644 index 000000000..3d68716c4 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_graduated_landscape_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_adaptive_large_sparse.png b/mcts-report/convergence_nig_adaptive_large_sparse.png new file mode 100644 index 000000000..a550963e7 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_large_sparse.png differ diff --git a/mcts-report/convergence_nig_adaptive_large_sparse_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_large_sparse_adaptive_vs_fixed.png new file mode 100644 index 000000000..944f91c51 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_large_sparse_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical.png 
b/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical.png new file mode 100644 index 000000000..c75216427 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical.png differ diff --git a/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical_adaptive_vs_fixed.png new file mode 100644 index 000000000..541383e5b Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_mixed_nchoosek_categorical_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_adaptive_multigroup_interaction.png b/mcts-report/convergence_nig_adaptive_multigroup_interaction.png new file mode 100644 index 000000000..1c1542b99 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_multigroup_interaction.png differ diff --git a/mcts-report/convergence_nig_adaptive_multigroup_interaction_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_multigroup_interaction_adaptive_vs_fixed.png new file mode 100644 index 000000000..ae6371f13 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_multigroup_interaction_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_graduated_landscape.png b/mcts-report/convergence_nig_adaptive_n0_graduated_landscape.png new file mode 100644 index 000000000..4b0fc8779 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_graduated_landscape.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_graduated_landscape_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_graduated_landscape_n0_effect.png new file mode 100644 index 000000000..233723047 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_graduated_landscape_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_large_sparse.png b/mcts-report/convergence_nig_adaptive_n0_large_sparse.png new file mode 100644 index 
000000000..36ae619d3 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_large_sparse.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_large_sparse_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_large_sparse_n0_effect.png new file mode 100644 index 000000000..e8fef1ca6 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_large_sparse_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical.png b/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical.png new file mode 100644 index 000000000..2b560b744 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical_n0_effect.png new file mode 100644 index 000000000..ff86d860d Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_mixed_nchoosek_categorical_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction.png b/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction.png new file mode 100644 index 000000000..744991332 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction_n0_effect.png new file mode 100644 index 000000000..266daf7d7 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_multigroup_interaction_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack.png b/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack.png new file mode 100644 index 000000000..e6f57210a Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack.png differ diff --git 
a/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack_n0_effect.png new file mode 100644 index 000000000..adc895789 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_needle_in_haystack_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_simple_additive.png b/mcts-report/convergence_nig_adaptive_n0_simple_additive.png new file mode 100644 index 000000000..7d493a2f4 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_simple_additive.png differ diff --git a/mcts-report/convergence_nig_adaptive_n0_simple_additive_n0_effect.png b/mcts-report/convergence_nig_adaptive_n0_simple_additive_n0_effect.png new file mode 100644 index 000000000..c10ab3a36 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_n0_simple_additive_n0_effect.png differ diff --git a/mcts-report/convergence_nig_adaptive_needle_in_haystack.png b/mcts-report/convergence_nig_adaptive_needle_in_haystack.png new file mode 100644 index 000000000..22519b9ca Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_needle_in_haystack.png differ diff --git a/mcts-report/convergence_nig_adaptive_needle_in_haystack_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_needle_in_haystack_adaptive_vs_fixed.png new file mode 100644 index 000000000..6b9945bfa Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_needle_in_haystack_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_adaptive_simple_additive.png b/mcts-report/convergence_nig_adaptive_simple_additive.png new file mode 100644 index 000000000..ae5eefc48 Binary files /dev/null and b/mcts-report/convergence_nig_adaptive_simple_additive.png differ diff --git a/mcts-report/convergence_nig_adaptive_simple_additive_adaptive_vs_fixed.png b/mcts-report/convergence_nig_adaptive_simple_additive_adaptive_vs_fixed.png new file mode 100644 index 000000000..98634d41a Binary 
files /dev/null and b/mcts-report/convergence_nig_adaptive_simple_additive_adaptive_vs_fixed.png differ diff --git a/mcts-report/convergence_nig_graduated_landscape.png b/mcts-report/convergence_nig_graduated_landscape.png new file mode 100644 index 000000000..9523fb901 Binary files /dev/null and b/mcts-report/convergence_nig_graduated_landscape.png differ diff --git a/mcts-report/convergence_nig_graduated_landscape_nig_alpha.png b/mcts-report/convergence_nig_graduated_landscape_nig_alpha.png new file mode 100644 index 000000000..e5ac80d01 Binary files /dev/null and b/mcts-report/convergence_nig_graduated_landscape_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_graduated_landscape_nig_cache_modes.png b/mcts-report/convergence_nig_graduated_landscape_nig_cache_modes.png new file mode 100644 index 000000000..39148c146 Binary files /dev/null and b/mcts-report/convergence_nig_graduated_landscape_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_graduated_landscape_nig_vs_normal_ts.png b/mcts-report/convergence_nig_graduated_landscape_nig_vs_normal_ts.png new file mode 100644 index 000000000..4a20234d9 Binary files /dev/null and b/mcts-report/convergence_nig_graduated_landscape_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_nig_large_sparse.png b/mcts-report/convergence_nig_large_sparse.png new file mode 100644 index 000000000..ccb829b19 Binary files /dev/null and b/mcts-report/convergence_nig_large_sparse.png differ diff --git a/mcts-report/convergence_nig_large_sparse_nig_alpha.png b/mcts-report/convergence_nig_large_sparse_nig_alpha.png new file mode 100644 index 000000000..f3b2f48b3 Binary files /dev/null and b/mcts-report/convergence_nig_large_sparse_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_large_sparse_nig_cache_modes.png b/mcts-report/convergence_nig_large_sparse_nig_cache_modes.png new file mode 100644 index 000000000..94070a3e6 Binary files /dev/null and 
b/mcts-report/convergence_nig_large_sparse_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_large_sparse_nig_vs_normal_ts.png b/mcts-report/convergence_nig_large_sparse_nig_vs_normal_ts.png new file mode 100644 index 000000000..f865f7fc6 Binary files /dev/null and b/mcts-report/convergence_nig_large_sparse_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_nig_mixed_nchoosek_categorical.png b/mcts-report/convergence_nig_mixed_nchoosek_categorical.png new file mode 100644 index 000000000..e6d453389 Binary files /dev/null and b/mcts-report/convergence_nig_mixed_nchoosek_categorical.png differ diff --git a/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_alpha.png b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_alpha.png new file mode 100644 index 000000000..03a60fe3c Binary files /dev/null and b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_cache_modes.png b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_cache_modes.png new file mode 100644 index 000000000..7c34d6c08 Binary files /dev/null and b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_vs_normal_ts.png b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_vs_normal_ts.png new file mode 100644 index 000000000..008b90faa Binary files /dev/null and b/mcts-report/convergence_nig_mixed_nchoosek_categorical_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_nig_multigroup_interaction.png b/mcts-report/convergence_nig_multigroup_interaction.png new file mode 100644 index 000000000..55d9f2eae Binary files /dev/null and b/mcts-report/convergence_nig_multigroup_interaction.png differ diff --git a/mcts-report/convergence_nig_multigroup_interaction_nig_alpha.png b/mcts-report/convergence_nig_multigroup_interaction_nig_alpha.png 
new file mode 100644 index 000000000..73ed5b91f Binary files /dev/null and b/mcts-report/convergence_nig_multigroup_interaction_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_multigroup_interaction_nig_cache_modes.png b/mcts-report/convergence_nig_multigroup_interaction_nig_cache_modes.png new file mode 100644 index 000000000..fdbced430 Binary files /dev/null and b/mcts-report/convergence_nig_multigroup_interaction_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_multigroup_interaction_nig_vs_normal_ts.png b/mcts-report/convergence_nig_multigroup_interaction_nig_vs_normal_ts.png new file mode 100644 index 000000000..f16faf15c Binary files /dev/null and b/mcts-report/convergence_nig_multigroup_interaction_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_nig_needle_in_haystack.png b/mcts-report/convergence_nig_needle_in_haystack.png new file mode 100644 index 000000000..1a4fa9b9d Binary files /dev/null and b/mcts-report/convergence_nig_needle_in_haystack.png differ diff --git a/mcts-report/convergence_nig_needle_in_haystack_nig_alpha.png b/mcts-report/convergence_nig_needle_in_haystack_nig_alpha.png new file mode 100644 index 000000000..2a1e54757 Binary files /dev/null and b/mcts-report/convergence_nig_needle_in_haystack_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_needle_in_haystack_nig_cache_modes.png b/mcts-report/convergence_nig_needle_in_haystack_nig_cache_modes.png new file mode 100644 index 000000000..a43912c91 Binary files /dev/null and b/mcts-report/convergence_nig_needle_in_haystack_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_needle_in_haystack_nig_vs_normal_ts.png b/mcts-report/convergence_nig_needle_in_haystack_nig_vs_normal_ts.png new file mode 100644 index 000000000..77a82f799 Binary files /dev/null and b/mcts-report/convergence_nig_needle_in_haystack_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_nig_simple_additive.png 
b/mcts-report/convergence_nig_simple_additive.png new file mode 100644 index 000000000..fc798ed91 Binary files /dev/null and b/mcts-report/convergence_nig_simple_additive.png differ diff --git a/mcts-report/convergence_nig_simple_additive_nig_alpha.png b/mcts-report/convergence_nig_simple_additive_nig_alpha.png new file mode 100644 index 000000000..a32ab0371 Binary files /dev/null and b/mcts-report/convergence_nig_simple_additive_nig_alpha.png differ diff --git a/mcts-report/convergence_nig_simple_additive_nig_cache_modes.png b/mcts-report/convergence_nig_simple_additive_nig_cache_modes.png new file mode 100644 index 000000000..cd6fa984f Binary files /dev/null and b/mcts-report/convergence_nig_simple_additive_nig_cache_modes.png differ diff --git a/mcts-report/convergence_nig_simple_additive_nig_vs_normal_ts.png b/mcts-report/convergence_nig_simple_additive_nig_vs_normal_ts.png new file mode 100644 index 000000000..b21534a43 Binary files /dev/null and b/mcts-report/convergence_nig_simple_additive_nig_vs_normal_ts.png differ diff --git a/mcts-report/convergence_simple_additive.png b/mcts-report/convergence_simple_additive.png new file mode 100644 index 000000000..f18cc1609 Binary files /dev/null and b/mcts-report/convergence_simple_additive.png differ diff --git a/mcts-report/convergence_simple_additive_crave.png b/mcts-report/convergence_simple_additive_crave.png new file mode 100644 index 000000000..7ef894bee Binary files /dev/null and b/mcts-report/convergence_simple_additive_crave.png differ diff --git a/mcts-report/convergence_simple_additive_exploration.png b/mcts-report/convergence_simple_additive_exploration.png new file mode 100644 index 000000000..8c746b93f Binary files /dev/null and b/mcts-report/convergence_simple_additive_exploration.png differ diff --git a/mcts-report/convergence_simple_additive_p_stop.png b/mcts-report/convergence_simple_additive_p_stop.png new file mode 100644 index 000000000..a9b6c3098 Binary files /dev/null and 
b/mcts-report/convergence_simple_additive_p_stop.png differ diff --git a/mcts-report/convergence_simple_additive_pw_effect.png b/mcts-report/convergence_simple_additive_pw_effect.png new file mode 100644 index 000000000..857093465 Binary files /dev/null and b/mcts-report/convergence_simple_additive_pw_effect.png differ diff --git a/mcts-report/convergence_simple_additive_rave_effect.png b/mcts-report/convergence_simple_additive_rave_effect.png new file mode 100644 index 000000000..d1a27f273 Binary files /dev/null and b/mcts-report/convergence_simple_additive_rave_effect.png differ diff --git a/mcts-report/convergence_simple_additive_rollout.png b/mcts-report/convergence_simple_additive_rollout.png new file mode 100644 index 000000000..58cc96ddd Binary files /dev/null and b/mcts-report/convergence_simple_additive_rollout.png differ diff --git a/mcts-report/investigate_sampling_vs_optimizing.py b/mcts-report/investigate_sampling_vs_optimizing.py new file mode 100644 index 000000000..78249ebf4 --- /dev/null +++ b/mcts-report/investigate_sampling_vs_optimizing.py @@ -0,0 +1,523 @@ +"""Investigate the gap between optimized acqf and polytope-sampled acqf values. + +Uses FormulationWrapper(benchmark=Hartmann(), max_count=4) which creates a +7D simplex-constrained problem (6 original features + 1 filler, sum-to-1, +NChooseK on the 6 non-filler features). + +For each NChooseK subset: + 1. optimize_acqf with linear constraints + fixed features (expensive) + 2. sample_q_batches_from_polytope + evaluate acqf (cheap, hit-and-run) + +Questions: + - How large is the gap between best polytope sample and optimized value? + - Is there rank correlation between sample-best and optimized rankings? + - Can cheap samples reliably identify the top subsets? 
+ +Usage: + python mcts-report/investigate_sampling_vs_optimizing.py +""" + +import itertools +import sys +import time +import warnings +from pathlib import Path + +import numpy as np +import torch +from botorch.optim import optimize_acqf +from botorch.optim.initializers import sample_q_batches_from_polytope +from botorch.optim.parameter_constraints import _generate_unfixed_lin_constraints +from scipy import stats + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import bofire.data_models.strategies.api as data_models +from bofire.benchmarks.api import Hartmann +from bofire.benchmarks.benchmark import FormulationWrapper +from bofire.data_models.constraints.api import ( + LinearEqualityConstraint, + LinearInequalityConstraint, +) +from bofire.data_models.strategies.predictives.acqf_optimization import BotorchOptimizer +from bofire.strategies.predictives.sobo import SoboStrategy +from bofire.strategies.random import RandomStrategy +from bofire.strategies.utils import get_torch_bounds_from_domain +from bofire.utils.torch_tools import get_linear_constraints, tkwargs + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +MAX_COUNT = 4 +N_INITIAL = 30 +N_RESTARTS = 20 +RAW_SAMPLES = 2048 +N_SEEDS = 5 +SAMPLE_COUNTS = [64, 256, 1024, 2048] + + +def make_strategy_and_acqf(benchmark, seed: int): + """Create a fitted SoboStrategy and extract acqf + bounds + constraints.""" + domain = benchmark.domain + random_strategy = RandomStrategy( + data_model=data_models.RandomStrategy(domain=domain, seed=seed), + ) + candidates = random_strategy.ask(N_INITIAL) + experiments = benchmark.f(candidates, return_complete=True) + + strategy = SoboStrategy( + data_model=data_models.SoboStrategy( + domain=domain, + acquisition_optimizer=BotorchOptimizer( + n_restarts=N_RESTARTS, + n_raw_samples=RAW_SAMPLES, + ), + ), + ) + strategy.tell(experiments) + + 
acqf = strategy._get_acqfs(1)[0] + bounds = get_torch_bounds_from_domain(domain, strategy.input_preprocessing_specs) + + # Extract linear constraints in BoTorch format + ineq_constraints = get_linear_constraints( + domain=domain, + constraint=LinearInequalityConstraint, + unit_scaled=False, + ) + eq_constraints = get_linear_constraints( + domain=domain, + constraint=LinearEqualityConstraint, + unit_scaled=False, + ) + + best_f = experiments["y"].min() + return strategy, acqf, bounds, ineq_constraints, eq_constraints, best_f + + +def get_nchoosek_feature_keys(benchmark): + """Get the non-filler feature keys and their indices.""" + domain = benchmark.domain + from bofire.data_models.features.api import ContinuousInput + + all_keys = domain.inputs.get_keys(ContinuousInput) + nchoosek_keys = [k for k in all_keys if not k.startswith("x_filler_")] + nchoosek_indices = [all_keys.index(k) for k in nchoosek_keys] + return nchoosek_keys, nchoosek_indices, all_keys + + +def enumerate_all_subsets(indices: list[int], max_k: int) -> list[frozenset[int]]: + """Generate all subsets of indices with size 0..max_k.""" + subsets = [] + for k in range(0, max_k + 1): + for combo in itertools.combinations(indices, k): + subsets.append(frozenset(combo)) + return subsets + + +def optimized_acqf_per_subset( + acqf, + bounds: torch.Tensor, + subsets: list[frozenset[int]], + nchoosek_indices: set[int], + ineq_constraints, + eq_constraints, +) -> dict[frozenset[int], float]: + """Run optimize_acqf for every subset with linear constraints.""" + dim = bounds.shape[1] + results = {} + + botorch_ineqs = ineq_constraints if len(ineq_constraints) > 0 else None + botorch_eqs = eq_constraints if len(eq_constraints) > 0 else None + + for subset in subsets: + # Fix inactive NChooseK features to 0 + fixed = {i: 0.0 for i in nchoosek_indices if i not in subset} + + if len(subset) == 0: + # All NChooseK features fixed to 0, only filler is free + candidate = torch.zeros(1, dim, dtype=bounds.dtype) + # Set 
filler to satisfy sum=1 constraint + filler_indices = [i for i in range(dim) if i not in nchoosek_indices] + for fi in filler_indices: + candidate[0, fi] = 1.0 / len(filler_indices) + with torch.no_grad(): + val = acqf(candidate.unsqueeze(0)).item() + results[subset] = val + continue + + try: + candidates, acq_val = optimize_acqf( + acq_function=acqf, + bounds=bounds, + q=1, + num_restarts=N_RESTARTS, + raw_samples=RAW_SAMPLES, + fixed_features=fixed, + inequality_constraints=botorch_ineqs, + equality_constraints=botorch_eqs, + ) + results[subset] = acq_val.item() + except Exception as e: + print(f" optimize_acqf failed for {sorted(subset)}: {e}") + results[subset] = float("-inf") + + return results + + +def polytope_sample_acqf_per_subset( + acqf, + bounds: torch.Tensor, + subsets: list[frozenset[int]], + nchoosek_indices: set[int], + all_continuous_keys: list[str], + ineq_constraints, + eq_constraints, + n_samples: int, + seed: int, +) -> dict[frozenset[int], dict]: + """Sample from polytope (hit-and-run) and evaluate acqf for every subset. + + Uses sample_q_batches_from_polytope with _generate_unfixed_lin_constraints + to handle fixed features, same approach as RandomStrategy._sample_from_polytope. 
+ """ + dim = bounds.shape[1] + results = {} + + for subset in subsets: + # Fix inactive NChooseK features to 0 + fixed_features_dict = {i: 0.0 for i in nchoosek_indices if i not in subset} + + if len(subset) == 0: + # All NChooseK features fixed, only filler free + candidate = torch.zeros(1, dim, dtype=bounds.dtype) + filler_indices = [i for i in range(dim) if i not in nchoosek_indices] + for fi in filler_indices: + candidate[0, fi] = 1.0 / len(filler_indices) + with torch.no_grad(): + val = acqf(candidate.unsqueeze(0)).item() + results[subset] = { + "best": val, + "mean": val, + "std": 0.0, + "all": np.array([val]), + } + continue + + # Build unfixed bounds (remove fixed dimensions) + free_indices = [i for i in range(dim) if i not in fixed_features_dict] + free_lower = bounds[0, free_indices] + free_upper = bounds[1, free_indices] + free_bounds = torch.stack([free_lower, free_upper]).to(**tkwargs) + + # Generate unfixed constraints using BoTorch's helper + unfixed_ineqs = _generate_unfixed_lin_constraints( + constraints=ineq_constraints, + eq=False, + fixed_features=fixed_features_dict, + dimension=dim, + ) + unfixed_eqs = _generate_unfixed_lin_constraints( + constraints=eq_constraints, + eq=True, + fixed_features=fixed_features_dict, + dimension=dim, + ) + + try: + # Sample from polytope using hit-and-run + samples = sample_q_batches_from_polytope( + n=1, + q=n_samples, + bounds=free_bounds, + inequality_constraints=unfixed_ineqs + if len(unfixed_ineqs) > 0 + else None, + equality_constraints=unfixed_eqs if len(unfixed_eqs) > 0 else None, + n_burnin=1000, + n_thinning=32, + seed=seed, + ).squeeze(0) # (n_samples, free_dim) + + # Reconstruct full-dim candidates + full_candidates = torch.zeros(n_samples, dim, dtype=bounds.dtype) + for j, fi in enumerate(free_indices): + full_candidates[:, fi] = samples[:, j] + for fi, val in fixed_features_dict.items(): + full_candidates[:, fi] = val + + # Evaluate acqf in batches + batch_size = 256 + vals = [] + with 
torch.no_grad(): + for i in range(0, n_samples, batch_size): + batch = full_candidates[i : i + batch_size].unsqueeze( + 1 + ) # (b, 1, dim) + v = acqf(batch) + vals.append(v.cpu().numpy()) + all_vals = np.concatenate(vals) + + results[subset] = { + "best": float(all_vals.max()), + "mean": float(all_vals.mean()), + "std": float(all_vals.std()), + "all": all_vals, + } + except Exception as e: + print(f" polytope sample failed for {sorted(subset)}: {e}") + results[subset] = { + "best": float("-inf"), + "mean": float("-inf"), + "std": 0.0, + "all": np.array([float("-inf")]), + } + + return results + + +def rank_subsets(values: dict[frozenset[int], float]) -> list[frozenset[int]]: + """Return subsets sorted by value descending.""" + return sorted(values.keys(), key=lambda s: values[s], reverse=True) + + +def main(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + warnings.filterwarnings("ignore", message=".*InputDataWarning.*") + warnings.filterwarnings("ignore", message=".*model inputs.*") + warnings.filterwarnings("ignore", message=".*not unique.*") + + benchmark = FormulationWrapper(benchmark=Hartmann(), max_count=MAX_COUNT) + nchoosek_keys, nchoosek_indices, all_keys = get_nchoosek_feature_keys(benchmark) + nchoosek_set = set(nchoosek_indices) + + subsets = enumerate_all_subsets(nchoosek_indices, MAX_COUNT) + n_subsets = len(subsets) + + print("Sampling vs Optimizing: FormulationWrapper(Hartmann(), max_count=4)") + print("=" * 70) + print( + f" Domain: {len(all_keys)} features ({len(nchoosek_keys)} NChooseK + fillers)" + ) + print(f" Features: {all_keys}") + print(f" NChooseK features: {nchoosek_keys} (indices {nchoosek_indices})") + print(f" Total subsets: {n_subsets}") + print(f" Constraints: sum-to-1 equality + NChooseK(max={MAX_COUNT})") + print(f" Initial points: {N_INITIAL}") + print(f" optimize_acqf: {N_RESTARTS} restarts, {RAW_SAMPLES} raw samples") + print(f" Sample counts: {SAMPLE_COUNTS}") + print(f" Seeds: {N_SEEDS}") + print() + + # 
Collect results across seeds + all_rank_correlations = {n: [] for n in SAMPLE_COUNTS} + all_top1_match = {n: [] for n in SAMPLE_COUNTS} + all_top3_overlap = {n: [] for n in SAMPLE_COUNTS} + all_top5_overlap = {n: [] for n in SAMPLE_COUNTS} + all_mean_gaps = {n: [] for n in SAMPLE_COUNTS} + all_winner_gaps = {n: [] for n in SAMPLE_COUNTS} + all_opt_times = [] + all_sample_times = {n: [] for n in SAMPLE_COUNTS} + + for seed in range(N_SEEDS): + print(f"{'='*70}") + print(f"SEED {seed}") + print(f"{'='*70}") + + strategy, acqf, bounds, ineq_constraints, eq_constraints, best_f = ( + make_strategy_and_acqf(benchmark, seed) + ) + print(f" GP fitted, best_f = {best_f:.4f}, bounds shape = {bounds.shape}") + + # 1. Optimized values (gold standard) + t0 = time.time() + opt_values = optimized_acqf_per_subset( + acqf, bounds, subsets, nchoosek_set, ineq_constraints, eq_constraints + ) + opt_time = time.time() - t0 + all_opt_times.append(opt_time) + opt_ranking = rank_subsets(opt_values) + + opt_best_subset = opt_ranking[0] + opt_best_val = opt_values[opt_best_subset] + print(f" Exhaustive optimization: {opt_time:.1f}s") + print( + f" Best optimized: subset={sorted(opt_best_subset)} acq={opt_best_val:.4f}" + ) + print(" Top 5 optimized:") + for i in range(min(5, n_subsets)): + s = opt_ranking[i] + print(f" #{i+1}: {str(sorted(s)):>20s} acq = {opt_values[s]:.4f}") + + # 2. 
Polytope samples at different counts + for n_samples in SAMPLE_COUNTS: + t0 = time.time() + sample_results = polytope_sample_acqf_per_subset( + acqf, + bounds, + subsets, + nchoosek_set, + all_keys, + ineq_constraints, + eq_constraints, + n_samples=n_samples, + seed=seed + 1000, + ) + sample_time = time.time() - t0 + all_sample_times[n_samples].append(sample_time) + + sample_best_values = {s: r["best"] for s, r in sample_results.items()} + sample_ranking = rank_subsets(sample_best_values) + + # Rank correlation (Spearman) + opt_ranks = [opt_ranking.index(s) for s in subsets] + sample_ranks = [sample_ranking.index(s) for s in subsets] + rho, _ = stats.spearmanr(opt_ranks, sample_ranks) + all_rank_correlations[n_samples].append(rho) + + # Top-1 match + top1_match = sample_ranking[0] == opt_ranking[0] + all_top1_match[n_samples].append(top1_match) + + # Top-3 overlap + opt_top3 = set(opt_ranking[:3]) + sample_top3 = set(sample_ranking[:3]) + top3_overlap = len(opt_top3 & sample_top3) / 3 + all_top3_overlap[n_samples].append(top3_overlap) + + # Top-5 overlap + opt_top5 = set(opt_ranking[:5]) + sample_top5 = set(sample_ranking[:5]) + top5_overlap = len(opt_top5 & sample_top5) / 5 + all_top5_overlap[n_samples].append(top5_overlap) + + # Mean gap per subset (opt - sample_best) + gaps = [ + opt_values[s] - sample_best_values[s] + for s in subsets + if opt_values[s] > float("-inf") + and sample_best_values[s] > float("-inf") + ] + mean_gap = np.mean(gaps) if gaps else float("nan") + all_mean_gaps[n_samples].append(mean_gap) + + # Winner gap: best opt value - sample winner's sample value + sample_winner = sample_ranking[0] + winner_gap = opt_best_val - sample_best_values[sample_winner] + all_winner_gaps[n_samples].append(winner_gap) + + print( + f"\n Polytope samples n={n_samples} ({sample_time:.1f}s, {opt_time/max(sample_time,0.01):.0f}x faster):" + ) + print(f" Rank correlation (Spearman rho): {rho:.3f}") + print(f" Top-1 match: {top1_match}") + print(f" Top-3 overlap: 
{top3_overlap:.0%}") + print(f" Top-5 overlap: {top5_overlap:.0%}") + print(f" Mean gap (opt - sample_best): {mean_gap:.4f}") + print( + f" Sample winner: {sorted(sample_winner)} acq={sample_best_values[sample_winner]:.4f} (opt best: {opt_best_val:.4f})" + ) + print(" Sample top 5:") + for i in range(min(5, n_subsets)): + s = sample_ranking[i] + opt_rank = opt_ranking.index(s) + 1 + print( + f" #{i+1}: {str(sorted(s)):>20s} " + f"sample_best={sample_best_values[s]:.4f} " + f"optimized={opt_values[s]:.4f} " + f"opt_rank=#{opt_rank}" + ) + + print() + + # --------------------------------------------------------------------------- + # Summary across seeds + # --------------------------------------------------------------------------- + print("\n" + "=" * 70) + print("SUMMARY ACROSS SEEDS") + print("=" * 70) + print(f" Mean optimize time: {np.mean(all_opt_times):.1f}s") + print() + + header = f"{'n_samples':>10s} | {'time':>6s} | {'rho':>6s} | {'top1%':>6s} | {'top3%':>6s} | {'top5%':>6s} | {'mean_gap':>8s} | {'winner_gap':>10s}" + print(header) + print("-" * len(header)) + + for n_samples in SAMPLE_COUNTS: + t = np.mean(all_sample_times[n_samples]) + rho = np.mean(all_rank_correlations[n_samples]) + top1 = np.mean(all_top1_match[n_samples]) + top3 = np.mean(all_top3_overlap[n_samples]) + top5 = np.mean(all_top5_overlap[n_samples]) + mean_gap = np.mean(all_mean_gaps[n_samples]) + winner_gap = np.mean(all_winner_gaps[n_samples]) + + print( + f"{n_samples:>10d} | {t:>5.1f}s | {rho:>6.3f} | {top1:>5.0%} | {top3:>5.0%} | {top5:>5.0%} | {mean_gap:>8.4f} | {winner_gap:>10.4f}" + ) + + # Per-subset-size gap distribution (last seed, largest sample count) + n_last = SAMPLE_COUNTS[-1] + sample_results_last = polytope_sample_acqf_per_subset( + acqf, + bounds, + subsets, + nchoosek_set, + all_keys, + ineq_constraints, + eq_constraints, + n_samples=n_last, + seed=N_SEEDS - 1 + 1000, + ) + sample_best_last = {s: r["best"] for s, r in sample_results_last.items()} + + 
print(f"\n{'='*70}") + print(f"PER-SUBSET-SIZE GAP DISTRIBUTION (last seed, n_samples={n_last})") + print(f"{'='*70}") + + gaps_by_size = {} + for s in subsets: + k = len(s) + if opt_values[s] > float("-inf") and sample_best_last[s] > float("-inf"): + gap = opt_values[s] - sample_best_last[s] + if k not in gaps_by_size: + gaps_by_size[k] = [] + gaps_by_size[k].append(gap) + + for k in sorted(gaps_by_size.keys()): + g = gaps_by_size[k] + print( + f" |subset|={k}: n={len(g):>3d}, " + f"mean_gap={np.mean(g):.4f}, " + f"median_gap={np.median(g):.4f}, " + f"max_gap={np.max(g):.4f}" + ) + + # Correlation scatter (top 20 subsets) + print(f"\n{'='*70}") + print(f"TOP 20 SUBSETS: optimized vs sample_best (last seed, n={n_last})") + print(f"{'='*70}") + print(f"{'subset':>20s} | {'optimized':>10s} | {'sample_best':>11s} | {'gap':>8s}") + print("-" * 60) + for s in opt_ranking[:20]: + ov = opt_values[s] + sv = sample_best_last[s] + print(f"{str(sorted(s)):>20s} | {ov:>10.4f} | {sv:>11.4f} | {ov-sv:>8.4f}") + + # Pearson correlation + opt_arr = np.array( + [opt_values[s] for s in subsets if opt_values[s] > float("-inf")] + ) + sample_arr = np.array( + [sample_best_last[s] for s in subsets if sample_best_last[s] > float("-inf")] + ) + if len(opt_arr) == len(sample_arr) and len(opt_arr) > 2: + pearson_r, _ = stats.pearsonr(opt_arr, sample_arr) + print(f"\n Pearson r (optimized vs sample_best@{n_last}): {pearson_r:.4f}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/optimize_mcts_dag.py b/mcts-report/optimize_mcts_dag.py new file mode 100644 index 000000000..348b0db9b --- /dev/null +++ b/mcts-report/optimize_mcts_dag.py @@ -0,0 +1,799 @@ +"""MCTS DAG with transposition table for NChooseK and categorical optimization. + +Eliminates canonical ordering bias by: +1. Allowing all unselected features at every NChooseK node (not just those > last) +2. Using a transposition table to merge nodes with the same selected feature set +3. 
NIG-TS statistics merge cleanly across parent paths (no parent-dependent terms) + +The tree becomes a DAG (directed acyclic graph) where nodes with identical +selected feature sets share statistics, regardless of selection order. +""" + +import math +import random +from typing import Callable, Optional + +from optimize_mcts_full import STOP, Groups, NChooseK +from optimize_mcts_ts import TSActionStats, TSNode + + +# ============================================================================= +# Transposition key +# ============================================================================= + + +def _transposition_key( + group_idx: int, + partial_by_group: tuple[tuple[int, ...], ...], + stopped_by_group: tuple[bool, ...], + groups: Groups, +) -> tuple: + """Create canonical transposition key from DAG state. + + For NChooseK groups, uses frozenset of selected indices (order-independent). + For Categorical groups, uses the tuple directly (at most 1 element). + """ + canon = [] + for g_idx, group in enumerate(groups.groups): + if isinstance(group, NChooseK): + canon.append(frozenset(partial_by_group[g_idx])) + else: + canon.append(partial_by_group[g_idx]) + return (group_idx, tuple(canon), stopped_by_group) + + +# ============================================================================= +# MCTS DAG with NIG Thompson Sampling +# ============================================================================= + + +class MCTS_DAG: + """MCTS with transposition table (DAG) and NIG Thompson Sampling. + + Removes canonical ordering constraint from NChooseK groups, allowing all + unselected features at every decision point. A transposition table merges + nodes with identical selected feature sets, preventing the exponential + blowup that would otherwise result. 
+ + Args: + groups: Collection of NChooseK and categorical constraints + reward_fn: Function mapping (selected_features, cat_selections) to reward + nig_alpha0: NIG shape prior (default 1.0); lower = heavier tails at low n + ts_prior_var: Prior variance (default 1.0); used to set beta0 = alpha0 * prior_var + adaptive_prior_var: If True, set prior variance to running empirical variance + cache_hit_mode: How to handle cache hits: "no_update", "variance_inflation", + "pessimistic", or "combined" + variance_decay: Decay factor for variance inflation mode (default 0.95) + rollout_mode: Rollout policy: "uniform_subset", "ts_group_action", or "uniform" + pw_k0: Progressive widening base constant (default 2.0) + pw_alpha: Progressive widening exponent (default 0.6) + max_rollout_retries: Maximum rollout retries on cache hit (default 3) + adaptive_n0: If True, set n0 = 1 + log(branching_factor) to slow down + premature convergence with the DAG's larger branching factors + informed_expansion: If True, use rollout TS stats to prioritize which + unexpanded actions to try first, instead of random selection + separate_stop: If True, treat STOP as a binary decision (stop vs continue) + before selecting which feature. This gives STOP a fair 50/50 comparison + against the best feature alternative, fixing the STOP dilution problem + where STOP competes with many features (1-out-of-N) in the flat action + space. Critical for the DAG where branching stays constant. + use_cache: If True (default), cache terminal evaluations. If False, every + evaluation calls reward_fn fresh — required for stochastic/noisy reward + functions (e.g. Sobol-based acqf sampling). With use_cache=False, every + observation is novel, cache-hit modes are never triggered, and rollout + retry is skipped. 
+ seed: Random seed for reproducibility + """ + + def __init__( + self, + groups: Groups, + reward_fn: Callable[[tuple[int, ...], dict[int, float]], float], + nig_alpha0: float = 1.0, + ts_prior_var: float = 1.0, + adaptive_prior_var: bool = False, + cache_hit_mode: str = "no_update", + variance_decay: float = 0.95, + rollout_mode: str = "uniform", + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + adaptive_n0: bool = False, + informed_expansion: bool = False, + separate_stop: bool = False, + use_cache: bool = True, + seed: Optional[int] = None, + ): + self.groups = groups + self.reward_fn = reward_fn + self.nig_alpha0 = nig_alpha0 + self.ts_prior_var = ts_prior_var + self.adaptive_prior_var = adaptive_prior_var + self.cache_hit_mode = cache_hit_mode + self.variance_decay = variance_decay + self.rollout_mode = rollout_mode + self.pw_k0 = pw_k0 + self.pw_alpha = pw_alpha + self.max_rollout_retries = max_rollout_retries + self.adaptive_n0 = adaptive_n0 + self.informed_expansion = informed_expansion + self.separate_stop = separate_stop + self.use_cache = use_cache + self.rng = random.Random(seed) + + # Initialize root node + n_groups = len(groups) + self.root = TSNode( + partial_by_group=tuple(() for _ in range(n_groups)), + stopped_by_group=tuple(False for _ in range(n_groups)), + group_idx=0, + ) + + # Transposition table: canonical key -> TSNode + self.transposition_table: dict[tuple, TSNode] = {} + root_key = _transposition_key( + 0, self.root.partial_by_group, self.root.stopped_by_group, self.groups + ) + self.transposition_table[root_key] = self.root + + # Best found so far + self.best_selection: Optional[tuple[tuple[int, ...], dict[int, float]]] = None + self.best_value: float = float("-inf") + + # Cache for terminal evaluations + self.value_cache: dict[tuple, float] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + # Global reward tracking for NIG prior center and adaptive variance + self._novel_reward_sum: float = 0.0 + 
self._novel_reward_sq_sum: float = 0.0 + self._novel_reward_count: int = 0 + + # Rollout TS statistics: (group_idx, action) -> TSActionStats + self.rollout_ts_stats: dict[tuple, TSActionStats] = {} + + # --- NIG prior methods ---------------------------------------------------- + + def _global_mean(self) -> float: + """Running mean of all novel rewards (prior center mu0).""" + if self._novel_reward_count == 0: + return 0.0 + return self._novel_reward_sum / self._novel_reward_count + + def _prior_var(self) -> float: + """Prior variance, either fixed or adaptive (empirical variance).""" + if not self.adaptive_prior_var or self._novel_reward_count < 2: + return self.ts_prior_var + mean = self._global_mean() + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return max(empirical_var, 1e-8) + + def _pessimistic_value(self) -> float: + """Pessimistic pseudo-observation value: global_mean - global_std.""" + mean = self._global_mean() + if self._novel_reward_count < 2: + return mean - math.sqrt(self.ts_prior_var) + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return mean - math.sqrt(max(empirical_var, 1e-8)) + + def _compute_n0(self, n_actions: int) -> float: + """Compute pseudo-count n0 from branching factor. + + With adaptive_n0, n0 = 1 + log(branching_factor). Higher branching + means each child is visited rarely early on, so we need more + observations before departing from the prior. This prevents premature + lock-in with the DAG's larger action spaces. + """ + if not self.adaptive_n0: + return 1.0 + return 1.0 + math.log(max(n_actions, 2)) + + # --- NIG Student-t sampling ----------------------------------------------- + + def _student_t_sample(self, df: float, loc: float, scale: float) -> float: + """Sample from a Student-t distribution. + + Uses the representation: loc + scale * Z / sqrt(V / df) + where Z ~ N(0,1) and V ~ chi-squared(df) = Gamma(df/2, 2). 
+ """ + z = self.rng.gauss(0, 1) + v = self.rng.gammavariate(df / 2, 2) + return loc + scale * z / math.sqrt(v / df) + + def _nig_sample_score(self, node: TSNode, n0: float = 1.0) -> float: + """Sample from node's NIG posterior (marginal Student-t) for tree selection.""" + mu0 = self._global_mean() + prior_var = self._prior_var() + alpha0 = self.nig_alpha0 + beta0 = alpha0 * prior_var + + n = node.n_obs + if n == 0: + df = 2 * alpha0 + scale = math.sqrt(beta0 / (alpha0 * n0)) + return self._student_t_sample(df, mu0, scale) + + x_bar = node.sum_rewards / n + s = node.sum_sq_rewards - n * x_bar * x_bar + s = max(s, 0.0) + + n0_post = n0 + n + mu0_post = (n0 * mu0 + n * x_bar) / n0_post + alpha0_post = alpha0 + n / 2 + beta0_post = beta0 + s / 2 + (n0 * n * (x_bar - mu0) ** 2) / (2 * n0_post) + + df = 2 * alpha0_post + scale = math.sqrt(beta0_post / (alpha0_post * n0_post)) + return self._student_t_sample(df, mu0_post, scale) + + def _nig_sample_action_score(self, stats: TSActionStats, n0: float = 1.0) -> float: + """Sample from a TSActionStats NIG posterior (for rollout actions).""" + mu0 = self._global_mean() + prior_var = self._prior_var() + alpha0 = self.nig_alpha0 + beta0 = alpha0 * prior_var + + n = stats.n_obs + if n == 0: + df = 2 * alpha0 + scale = math.sqrt(beta0 / (alpha0 * n0)) + return self._student_t_sample(df, mu0, scale) + + x_bar = stats.sum_rewards / n + s = stats.sum_sq_rewards - n * x_bar * x_bar + s = max(s, 0.0) + + n0_post = n0 + n + mu0_post = (n0 * mu0 + n * x_bar) / n0_post + alpha0_post = alpha0 + n / 2 + beta0_post = beta0 + s / 2 + (n0 * n * (x_bar - mu0) ** 2) / (2 * n0_post) + + df = 2 * alpha0_post + scale = math.sqrt(beta0_post / (alpha0_post * n0_post)) + return self._student_t_sample(df, mu0_post, scale) + + # --- DAG-specific legal actions ------------------------------------------- + + def _nchoosek_legal_actions_dag( + self, group: NChooseK, partial: tuple[int, ...], stopped: bool + ) -> list[int]: + """Legal actions for NChooseK 
without canonical ordering. + + Returns ALL unselected features (not just those > last), plus STOP + when min_count is satisfied. This is THE CORE CHANGE from tree MCTS. + """ + if stopped or len(partial) >= group.max_count: + return [] + + selected = set(partial) + actions = [i for i in range(group.n_features) if i not in selected] + + if len(partial) >= group.min_count: + actions.append(STOP) + + return actions + + def _legal_actions(self, node: TSNode) -> list[int]: + """Get legal actions for current group in node.""" + if node.is_terminal(self.groups): + return [] + g = node.group_idx + group = self.groups.groups[g] + partial = node.partial_by_group[g] + stopped = node.stopped_by_group[g] + + if isinstance(group, NChooseK): + return self._nchoosek_legal_actions_dag(group, partial, stopped) + else: + return group.legal_actions(partial, stopped) + + # --- DAG-aware apply_action ----------------------------------------------- + + def _apply_action(self, node: TSNode, action: int) -> TSNode: + """Apply action, returning shared node from transposition table if exists.""" + g = node.group_idx + group = self.groups.groups[g] + + partials = list(node.partial_by_group) + stoppeds = list(node.stopped_by_group) + + if action == STOP: + stoppeds[g] = True + next_g = g + 1 + else: + partials[g] += (action,) + if group.is_complete(partials[g], stoppeds[g]): + next_g = g + 1 + else: + next_g = g + + new_partial = tuple(partials) + new_stopped = tuple(stoppeds) + + # Lookup in transposition table + key = _transposition_key(next_g, new_partial, new_stopped, self.groups) + if key in self.transposition_table: + return self.transposition_table[key] + + # Create new node and register + new_node = TSNode( + partial_by_group=new_partial, + stopped_by_group=new_stopped, + group_idx=next_g, + ) + self.transposition_table[key] = new_node + return new_node + + # --- Cache / utility methods ---------------------------------------------- + + def _make_cache_key( + self, selected_features: 
tuple[int, ...], cat_selections: dict[int, float] + ) -> tuple: + """Create hashable cache key from selection.""" + return (selected_features, frozenset(cat_selections.items())) + + def _cached_reward( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + """Get cached reward or compute and cache it. + + With use_cache=False, always calls reward_fn fresh (for stochastic rewards). + """ + if not self.use_cache: + self.cache_misses += 1 + return self.reward_fn(selected_features, cat_selections) + key = self._make_cache_key(selected_features, cat_selections) + if key in self.value_cache: + self.cache_hits += 1 + return self.value_cache[key] + val = self.reward_fn(selected_features, cat_selections) + self.value_cache[key] = val + self.cache_misses += 1 + return val + + def _child_limit(self, node: TSNode) -> int: + """Progressive widening: max children based on visit count.""" + return max(1, int(self.pw_k0 * (max(1, node.n_visits) ** self.pw_alpha))) + + def _get_selection(self, node: TSNode) -> tuple[tuple[int, ...], dict[int, float]]: + """Convert node's partial selections to (selected_features, cat_selections).""" + selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in node.partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = node.partial_by_group[g] + if partial: + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return selected_features_tuple, cat_selections + + def _get_selection_from_state( + self, + partial_by_group: list[list[int]], + stopped_by_group: list[bool], + ) -> tuple[tuple[int, ...], dict[int, float]]: + """Convert mutable rollout state to (selected_features, cat_selections).""" + 
selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = partial_by_group[g] + if partial: + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return selected_features_tuple, cat_selections + + def cache_stats(self) -> dict[str, int]: + """Return cache statistics.""" + return { + "hits": self.cache_hits, + "misses": self.cache_misses, + "size": len(self.value_cache), + } + + # --- Tree selection with NIG Thompson Sampling ---------------------------- + + def _pick_expansion_action(self, node: TSNode, unexpanded: list[int]) -> int: + """Pick which unexpanded action to expand next. + + With informed_expansion, uses rollout TS stats to sample from + NIG posteriors for each candidate and picks the highest-scoring one. + Otherwise picks uniformly at random. + """ + if not self.informed_expansion or not self.rollout_ts_stats: + return self.rng.choice(unexpanded) + + g = node.group_idx + best_action = unexpanded[0] + best_score = float("-inf") + for action in unexpanded: + key = (g, action) + stats = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + score = self._nig_sample_action_score(stats) + if score > best_score: + best_score = score + best_action = action + return best_action + + def _select_action_separate_stop(self, node: TSNode, n0: float) -> int: + """Select among children using binary STOP-vs-best-feature comparison. + + When STOP is among the children, first sample STOP's NIG score, then + sample the best feature's NIG score, and compare. This gives STOP a + fair 50/50 chance instead of being diluted among many feature actions. 
+ Falls back to normal NIG-TS when STOP is not among children. + """ + if STOP not in node.children: + # No STOP — normal NIG-TS among all children + best_action = None + best_score = float("-inf") + for action, child in node.children.items(): + score = self._nig_sample_score(child, n0=n0) + if score > best_score: + best_score = score + best_action = action + return best_action + + # Binary comparison: STOP vs best feature + stop_score = self._nig_sample_score(node.children[STOP], n0=n0) + + feature_children = {a: c for a, c in node.children.items() if a != STOP} + if not feature_children: + return STOP + + best_feat_action = None + best_feat_score = float("-inf") + for action, child in feature_children.items(): + score = self._nig_sample_score(child, n0=n0) + if score > best_feat_score: + best_feat_score = score + best_feat_action = action + + if stop_score >= best_feat_score: + return STOP + return best_feat_action + + def _select_and_expand(self) -> tuple[TSNode, list[TSNode]]: + """Select path through DAG using NIG-TS and expand one new node.""" + node = self.root + path = [node] + + while not node.is_terminal(self.groups): + legal = self._legal_actions(node) + limit = self._child_limit(node) + unexpanded = [a for a in legal if a not in node.children] + can_expand = len(node.children) < limit + + if can_expand and unexpanded: + # Expand: prioritize STOP first when legal and unexpanded + if self.separate_stop and STOP in unexpanded: + action = STOP + else: + action = self._pick_expansion_action(node, unexpanded) + child = self._apply_action(node, action) + node.children[action] = child + path.append(child) + return child, path + + # NIG Thompson Sampling selection among existing children + if node.children: + n0 = self._compute_n0(len(node.children)) + + if self.separate_stop: + best_action = self._select_action_separate_stop(node, n0) + else: + best_action = None + best_score = float("-inf") + for action, child in node.children.items(): + score = 
self._nig_sample_score(child, n0=n0) + if score > best_score: + best_score = score + best_action = action + + node = node.children[best_action] + path.append(node) + else: + break + + return node, path + + # --- Rollout action selection --------------------------------------------- + + def _ts_sample_rollout_action( + self, group_idx: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using per-(group, action) NIG posteriors. + + With separate_stop, uses binary STOP-vs-best-feature comparison + to give STOP fair representation in rollout decisions. + """ + if self.separate_stop and STOP in legal_actions: + # Binary: STOP vs best feature + stop_key = (group_idx, STOP) + stop_stats = self.rollout_ts_stats.get(stop_key, TSActionStats(0, 0.0, 0.0)) + stop_score = self._nig_sample_action_score(stop_stats) + + feature_actions = [a for a in legal_actions if a != STOP] + if not feature_actions: + return STOP + + best_feat = feature_actions[0] + best_feat_score = float("-inf") + for action in feature_actions: + key = (group_idx, action) + stats = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + score = self._nig_sample_action_score(stats) + if score > best_feat_score: + best_feat_score = score + best_feat = action + + if stop_score >= best_feat_score: + return STOP + return best_feat + + # Normal: NIG-TS among all actions + best_action = legal_actions[0] + best_score = float("-inf") + + for action in legal_actions: + key = (group_idx, action) + stats = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + score = self._nig_sample_action_score(stats) + if score > best_score: + best_score = score + best_action = action + + return best_action + + # --- Rollout dispatch ----------------------------------------------------- + + def _rollout( + self, node: TSNode + ) -> tuple[tuple[int, ...], dict[int, float], list[tuple[int, int]]]: + """Rollout to terminal state with mode-dependent action selection. 
+ + Operates on mutable state, NOT the transposition table, to avoid + bloating the table with rollout-only nodes. + + Returns: + Tuple of (selected_features, cat_selections, trajectory) where + trajectory is a list of (group_idx, action) tuples. + """ + # Work with mutable copies + partial_by_group = [list(p) for p in node.partial_by_group] + stopped_by_group = list(node.stopped_by_group) + current_g = node.group_idx + trajectory: list[tuple[int, int]] = [] + + n_groups = len(self.groups) + + if self.rollout_mode == "uniform_subset": + # Fast rollout: for each incomplete group, directly sample a subset + while current_g < n_groups: + group = self.groups.groups[current_g] + + if isinstance(group, NChooseK): + if ( + not stopped_by_group[current_g] + and len(partial_by_group[current_g]) < group.max_count + ): + selected = set(partial_by_group[current_g]) + available = [ + i for i in range(group.n_features) if i not in selected + ] + + m = len(partial_by_group[current_g]) + min_more = max(0, group.min_count - m) + max_more = min(len(available), group.max_count - m) + + if max_more > 0 and min_more <= max_more: + k = self.rng.randint(min_more, max_more) + chosen = self.rng.sample(available, k) + for feat in chosen: + trajectory.append((current_g, feat)) + partial_by_group[current_g].append(feat) + + stopped_by_group[current_g] = True + current_g += 1 + else: + # Categorical: pick uniformly + if not partial_by_group[current_g]: + action = self.rng.randrange(group.n_options) + trajectory.append((current_g, action)) + partial_by_group[current_g].append(action) + current_g += 1 + + else: + # Step-by-step rollout (uniform or ts_group_action) + while current_g < n_groups: + group = self.groups.groups[current_g] + + if isinstance(group, NChooseK): + partial_tuple = tuple(partial_by_group[current_g]) + stopped = stopped_by_group[current_g] + legal = self._nchoosek_legal_actions_dag( + group, partial_tuple, stopped + ) + else: + partial_tuple = 
tuple(partial_by_group[current_g]) + stopped = stopped_by_group[current_g] + legal = group.legal_actions(partial_tuple, stopped) + + if not legal: + current_g += 1 + continue + + if self.rollout_mode == "uniform": + if self.separate_stop and STOP in legal: + # 50/50 stop vs continue + if self.rng.random() < 0.5: + action = STOP + else: + features = [a for a in legal if a != STOP] + action = self.rng.choice(features) if features else STOP + else: + action = self.rng.choice(legal) + elif self.rollout_mode == "ts_group_action": + action = self._ts_sample_rollout_action(current_g, legal) + else: + raise ValueError(f"Unknown rollout_mode: {self.rollout_mode}") + + trajectory.append((current_g, action)) + + # Apply action to mutable state + if action == STOP: + stopped_by_group[current_g] = True + current_g += 1 + else: + partial_by_group[current_g].append(action) + if group.is_complete( + tuple(partial_by_group[current_g]), + stopped_by_group[current_g], + ): + current_g += 1 + + selected_features, cat_selections = self._get_selection_from_state( + partial_by_group, stopped_by_group + ) + return selected_features, cat_selections, trajectory + + # --- Backpropagation ------------------------------------------------------ + + def _backpropagate(self, path: list[TSNode], reward: float, is_novel: bool) -> None: + """Backpropagate reward through path. + + Novel: update n_obs, sum_rewards, sum_sq_rewards, n_visits. 
+ Cache hit handling depends on cache_hit_mode: + - no_update: only increment n_visits + - variance_inflation: decay n_obs to widen posterior + - pessimistic: inject pessimistic pseudo-observation + - combined: variance inflation + pessimistic + """ + if is_novel: + for n in path: + n.n_obs += 1 + n.sum_rewards += reward + n.sum_sq_rewards += reward * reward + n.n_visits += 1 + else: + if self.cache_hit_mode in ("pessimistic", "combined"): + pess = self._pessimistic_value() + + for n in path: + n.n_visits += 1 + + if self.cache_hit_mode == "variance_inflation": + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + elif self.cache_hit_mode == "pessimistic": + n.n_obs += 1 + n.sum_rewards += pess + n.sum_sq_rewards += pess * pess + elif self.cache_hit_mode == "combined": + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + n.n_obs += 1 + n.sum_rewards += pess + n.sum_sq_rewards += pess * pess + + def _update_rollout_ts_stats( + self, trajectory: list[tuple[int, int]], reward: float + ) -> None: + """Update per-(group, action) TS stats from rollout trajectory.""" + if self.rollout_mode != "ts_group_action": + return + + for group_idx, action in trajectory: + key = (group_idx, action) + old = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + self.rollout_ts_stats[key] = TSActionStats( + n_obs=old.n_obs + 1, + sum_rewards=old.sum_rewards + reward, + sum_sq_rewards=old.sum_sq_rewards + reward * reward, + ) + + # --- Main loop ------------------------------------------------------------ + + def run(self, n_iterations: int) -> tuple[tuple[int, ...], dict[int, float], float]: + """Run MCTS-DAG for specified number of 
iterations. + + Args: + n_iterations: Number of MCTS iterations to run + + Returns: + Tuple of (selected_features, cat_selections, best_value) + """ + for _ in range(n_iterations): + leaf, path = self._select_and_expand() + + if leaf.is_terminal(self.groups): + selected_features, cat_selections = self._get_selection(leaf) + trajectory: list[tuple[int, int]] = [] + else: + selected_features, cat_selections, trajectory = self._rollout(leaf) + # Rollout retry: re-roll on cache hits to discover novel selections + # Skip when use_cache=False (every eval is fresh) + if self.use_cache: + for _attempt in range(self.max_rollout_retries): + key = self._make_cache_key(selected_features, cat_selections) + if key not in self.value_cache: + break + selected_features, cat_selections, trajectory = self._rollout( + leaf + ) + + if self.use_cache: + key = self._make_cache_key(selected_features, cat_selections) + is_novel = key not in self.value_cache + else: + is_novel = True + reward = self._cached_reward(selected_features, cat_selections) + + if reward > self.best_value: + self.best_value = reward + self.best_selection = (selected_features, cat_selections) + + # Update global stats for NIG prior (mean and adaptive variance) + if is_novel: + self._novel_reward_sum += reward + self._novel_reward_sq_sum += reward * reward + self._novel_reward_count += 1 + + # Backpropagate (raw reward, no normalization for NIG-TS) + self._backpropagate(path, reward, is_novel) + + # Update rollout statistics + self._update_rollout_ts_stats(trajectory, reward) + + if self.best_selection is None: + return (), {}, self.best_value + return self.best_selection[0], self.best_selection[1], self.best_value diff --git a/mcts-report/optimize_mcts_full.py b/mcts-report/optimize_mcts_full.py new file mode 100644 index 000000000..22d269c29 --- /dev/null +++ b/mcts-report/optimize_mcts_full.py @@ -0,0 +1,1043 @@ +"""MCTS-based acquisition function optimization for NChooseK and categorical constraints. 
+ +Uses Monte Carlo Tree Search to select which features are active (non-zero) and +categorical values, then runs BoTorch acquisition function optimization with +inactive features fixed to zero and categoricals fixed to selected values. +""" + +import math +import random +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Callable, Optional + +import torch +from botorch.optim import optimize_acqf +from torch import Tensor + + +STOP = -1 # Sentinel for stopping selection in a group + + +# ============================================================================= +# Group abstractions for MCTS +# ============================================================================= + + +class Group(ABC): + """Abstract base class for MCTS groups (NChooseK or Categorical).""" + + @property + @abstractmethod + def n_options(self) -> int: + """Number of options/actions available in this group.""" + pass + + @abstractmethod + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Return legal actions given current partial selection.""" + pass + + @abstractmethod + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """Check if selection for this group is complete.""" + pass + + +@dataclass(frozen=True) +class NChooseK(Group): + """NChooseK constraint specifying feature selection bounds. 
+ + Args: + features: Feature indices (can be non-contiguous, e.g., [0, 2, 4]) + min_count: Minimum number of features to select + max_count: Maximum number of features to select + """ + + features: Sequence[int] + min_count: int + max_count: int + + def __post_init__(self): + n = len(self.features) + if not (0 <= self.min_count <= self.max_count <= n): + raise ValueError( + f"Invalid NChooseK constraint: require 0 <= min_count <= max_count <= n; " + f"got min_count={self.min_count}, max_count={self.max_count}, n={n}" + ) + + @property + def n_options(self) -> int: + return len(self.features) + + @property + def n_features(self) -> int: + return len(self.features) + + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Compute legal actions within this NChooseK group. + + Actions are indices into self.features (not the actual feature indices). + Enforces strictly increasing selection order (combinations, not permutations). + STOP is legal if len(partial) >= min_count and not already stopped. + """ + n = self.n_features + m = len(partial) + + if stopped or m >= self.max_count: + return [] + + actions: list[int] = [] + last = partial[-1] if partial else -1 + + # Remaining picks needed after this action to satisfy min_count + r_min_needed = max(0, self.min_count - (m + 1)) + # After picking index i, n - (i+1) items remain; require n - (i+1) >= r_min_needed + end_inclusive = n - r_min_needed - 1 + start = last + 1 + + if start <= end_inclusive: + actions.extend(range(start, end_inclusive + 1)) + + if m >= self.min_count: + actions.append(STOP) + + return actions + + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """NChooseK is complete when stopped or max_count reached.""" + return stopped or len(partial) >= self.max_count + + +@dataclass(frozen=True) +class Categorical(Group): + """Categorical dimension with allowed values. 
+ + Args: + dim: The dimension index in the input space + values: Sequence of allowed values for this dimension + """ + + dim: int + values: Sequence[float] + + def __post_init__(self): + if len(self.values) < 2: + raise ValueError( + f"CategoricalGroup requires at least two values, got {len(self.values)}" + ) + + @property + def n_options(self) -> int: + return len(self.values) + + def legal_actions(self, partial: tuple[int, ...], stopped: bool) -> list[int]: + """Categorical must select exactly one value. No STOP action.""" + if len(partial) >= 1: + # Already selected + return [] + # All value indices are legal + return list(range(self.n_options)) + + def is_complete(self, partial: tuple[int, ...], stopped: bool) -> bool: + """Categorical is complete when one value is selected.""" + return len(partial) >= 1 + + +# ============================================================================= +# Combined constraints container +# ============================================================================= + + +@dataclass(frozen=True) +class Groups: + """Collection of NChooseK constraints and categorical groups.""" + + groups: list[Group] + + def __len__(self) -> int: + return len(self.groups) + + @property + def categoricals(self) -> list[Categorical]: + return [g for g in self.groups if isinstance(g, Categorical)] + + @property + def nchooseks(self) -> list[NChooseK]: + return [g for g in self.groups if isinstance(g, NChooseK)] + + @property + def all_nchoosek_features(self) -> list[int]: + """All feature indices covered by NChooseK constraints.""" + all_feats = [] + for c in self.nchooseks: + all_feats.extend(c.features) + return all_feats + + @property + def all_categorical_dims(self) -> list[int]: + """All dimension indices that are categorical.""" + return [c.dim for c in self.categoricals] + + +# ============================================================================= +# MCTS Node +# 
============================================================================= + + +@dataclass +class Node: + """MCTS tree node. + + Args: + partial_by_group: Partial selection per group (indices into group's options) + stopped_by_group: Whether each group has stopped selecting (for NChooseK) + group_idx: Current group being filled + n_visits: Visit count for this node + w_total: Total accumulated reward + children: Child nodes keyed by action (int index or STOP) + """ + + partial_by_group: tuple[tuple[int, ...], ...] + stopped_by_group: tuple[bool, ...] + group_idx: int + + n_visits: int = 0 + w_total: float = 0.0 + + children: dict[int, "Node"] = field(default_factory=dict) + + def is_terminal(self, groups: Groups) -> bool: + return self.group_idx >= len(groups) + + def mean_value(self) -> float: + return self.w_total / self.n_visits if self.n_visits > 0 else 0.0 + + +# ============================================================================= +# MCTS Implementation +# ============================================================================= + + +class MCTS: + """Monte Carlo Tree Search for NChooseK and categorical optimization. + + Uses UCT selection, RAVE action value estimation, and progressive widening. + Selects which features are active and categorical values via tree search, + evaluating terminals with a provided reward function. 
+
+    Args:
+        groups: Collection of NChooseK and categorical constraints
+        reward_fn: Function mapping (selected_features, categorical_selections) to reward
+        c_uct: UCT exploration constant (default 0.01)
+        k_rave: RAVE blending decay parameter (default 300.0)
+        p_stop_rollout: Probability of early stop during rollout (default 0.35)
+        pw_k0: Progressive widening base constant (default 2.0)
+        pw_alpha: Progressive widening exponent (default 0.6)
+        max_rollout_retries: Maximum rollout retries on cache hit (default 3)
+        adaptive_p_stop: Enable adaptive per-group stop probability (default True)
+        p_stop_warmup: Per-group rollout count before full blending (default 20)
+        p_stop_temperature: Sigmoid sharpness for adaptive p_stop (default 0.25)
+        normalize_rewards: Normalize rewards to [0, 1] before backpropagation (default True)
+        rollout_policy: Enable learned softmax rollout policy (default True)
+        rollout_epsilon: Epsilon-mix weight for uniform exploration (default 0.3)
+        rollout_tau: Softmax temperature for rollout policy (default 1.0)
+        rollout_novelty_weight: Novelty bonus coefficient beta/sqrt(n+1) (default 1.0)
+        context_rave: Use context-aware RAVE instead of global RAVE (default False)
+        seed: Random seed for reproducibility
+    """
+
+    def __init__(
+        self,
+        groups: Groups,
+        reward_fn: Callable[[tuple[int, ...], dict[int, float]], float],
+        c_uct: float = 0.01,
+        k_rave: float = 300.0,
+        p_stop_rollout: float = 0.35,
+        pw_k0: float = 2.0,
+        pw_alpha: float = 0.6,
+        max_rollout_retries: int = 3,
+        adaptive_p_stop: bool = True,
+        p_stop_warmup: int = 20,
+        p_stop_temperature: float = 0.25,
+        normalize_rewards: bool = True,
+        rollout_policy: bool = True,
+        rollout_epsilon: float = 0.3,
+        rollout_tau: float = 1.0,
+        rollout_novelty_weight: float = 1.0,
+        context_rave: bool = False,
+        seed: Optional[int] = None,
+    ):
+        self.groups = groups
+        self.reward_fn = reward_fn
+        self.c_uct = c_uct
+        self.k_rave = k_rave
+        self.p_stop_rollout = p_stop_rollout
+        
self.pw_k0 = pw_k0 + self.pw_alpha = pw_alpha + self.max_rollout_retries = max_rollout_retries + self.adaptive_p_stop = adaptive_p_stop + self.p_stop_warmup = p_stop_warmup + self.p_stop_temperature = p_stop_temperature + self.normalize_rewards = normalize_rewards + self.rollout_policy = rollout_policy + self.rollout_epsilon = rollout_epsilon + self.rollout_tau = rollout_tau + self.rollout_novelty_weight = rollout_novelty_weight + self.context_rave = context_rave + self.rng = random.Random(seed) + + # Initialize root node + n_groups = len(groups) + self.root = Node( + partial_by_group=tuple(() for _ in range(n_groups)), + stopped_by_group=tuple(False for _ in range(n_groups)), + group_idx=0, + ) + + # Best found so far: (selected_features, categorical_selections) + # Example: ((0, 2, 5), {3: 1.0, 4: 0.0}) means features 0, 2, 5 are + # active (from NChooseK groups), dim 3 has categorical value 1.0, and + # dim 4 has categorical value 0.0. + self.best_selection: Optional[tuple[tuple[int, ...], dict[int, float]]] = None + self.best_value: float = float("-inf") + + # Cache for terminal evaluations + # Key: (selected_features_tuple, frozenset of categorical items) + self.value_cache: dict[tuple, float] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + # RAVE statistics: global_id -> (visits, total_reward) + self.global_offsets = self._compute_group_offsets() + self.rave_stats: dict[int, tuple[int, float]] = {} + + # Adaptive p_stop statistics: (group_idx, cardinality) -> (visits, total_reward) + self.cardinality_stats: dict[tuple[int, int], tuple[int, float]] = {} + n_nchoosek = len(self.groups.nchooseks) + self.group_rollout_counts: list[int] = [0] * n_nchoosek + self.reward_min: float = float("inf") + self.reward_max: float = float("-inf") + + # Rollout policy statistics: (group_idx, action) -> (visits, total_reward) + self.rollout_stats: dict[tuple[int, int], tuple[int, float]] = {} + + # Context-aware RAVE: (group_idx, cardinality, action) -> (visits, 
total_reward) + self.context_rave_stats: dict[tuple[int, int, int], tuple[int, float]] = {} + + def _compute_group_offsets(self) -> list[int]: + """Compute offset for each group to create global action IDs.""" + offsets = [] + acc = 0 + for group in self.groups.groups: + offsets.append(acc) + acc += group.n_options + return offsets + + def _global_action_id(self, group_idx: int, local_idx: int) -> int: + """Convert (group, local_index) to global action ID for RAVE.""" + return self.global_offsets[group_idx] + local_idx + + def _update_cardinality_stats( + self, reward: float, selected_features: tuple[int, ...] + ) -> None: + """Update per-(group, cardinality) stats from a completed rollout. + + Reverse-maps selected_features to per-group cardinalities and updates + the cardinality_stats dict. + """ + selected_set = set(selected_features) + for g, nchoosek in enumerate(self.groups.nchooseks): + cardinality = sum(1 for f in nchoosek.features if f in selected_set) + key = (g, cardinality) + v, tot = self.cardinality_stats.get(key, (0, 0.0)) + self.cardinality_stats[key] = (v + 1, tot + reward) + self.group_rollout_counts[g] += 1 + + def _compute_adaptive_p_stop( + self, group_idx: int, current_cardinality: int + ) -> float: + """Compute adaptive stop probability for a group at a given cardinality. + + Uses learned cardinality statistics to decide whether stopping is better + than continuing. Returns fixed p_stop_rollout when disabled, no data is + available, or reward range is zero. During warmup, the learned probability + is linearly blended with p_stop_rollout: + p = (1 - alpha) * p_stop_rollout + alpha * p_learned + where alpha = min(1, group_visits / p_stop_warmup), so the learned signal + gradually replaces the fixed prior as more data is collected. 
+ + Args: + group_idx: Index of the NChooseK group + current_cardinality: Number of features already selected in this group + + Returns: + Stop probability in [0, 1] + """ + if not self.adaptive_p_stop: + return self.p_stop_rollout + + nchoosek = self.groups.nchooseks[group_idx] + max_count = nchoosek.max_count + + # No data for stopping at this cardinality + stop_key = (group_idx, current_cardinality) + stop_stats = self.cardinality_stats.get(stop_key) + if stop_stats is None or stop_stats[0] == 0: + return self.p_stop_rollout + + # E_stop: mean reward when this group stopped at current_cardinality + e_stop = stop_stats[1] / stop_stats[0] + + # E_continue: max mean reward among higher cardinalities + e_continue = float("-inf") + has_continue_data = False + for m in range(current_cardinality + 1, max_count + 1): + cont_key = (group_idx, m) + cont_stats = self.cardinality_stats.get(cont_key) + if cont_stats is not None and cont_stats[0] > 0: + mean_r = cont_stats[1] / cont_stats[0] + if mean_r > e_continue: + e_continue = mean_r + has_continue_data = True + + if not has_continue_data: + return self.p_stop_rollout + + # Reward range for normalization + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return self.p_stop_rollout + + # Sigmoid on normalized difference + tau = self.p_stop_temperature + logit = (e_stop - e_continue) / (tau * reward_range) + logit = max(-10.0, min(10.0, logit)) # clamp + p_learned = 1.0 / (1.0 + math.exp(-logit)) + + # Warmup blending + group_visits = self.group_rollout_counts[group_idx] + alpha = ( + min(1.0, group_visits / self.p_stop_warmup) + if self.p_stop_warmup > 0 + else 1.0 + ) + return (1.0 - alpha) * self.p_stop_rollout + alpha * p_learned + + def _normalize_reward(self, reward: float) -> float: + """Normalize reward to [0, 1] using running min-max. 
+ + Returns 0.5 when reward range is zero, which covers the initial + rollouts where only one distinct reward has been observed (min == max) + as well as degenerate cases where all rewards are identical. + """ + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return 0.5 + return (reward - self.reward_min) / reward_range + + def _score_rollout_actions( + self, group_idx: int, legal_actions: list[int] + ) -> dict[int, float]: + """Score legal rollout actions using learned statistics. + + For each action, computes: + score(a) = mean_reward(a) + novelty_weight / sqrt(visits(a) + 1) + + The 1/sqrt(visits) term is a UCB-style exploration bonus that decays as + an action is visited more, encouraging under-explored actions to be tried. + Actions with no stats get score = novelty_weight (maximum exploration). + + Args: + group_idx: Index of the current group + legal_actions: List of legal action indices (may include STOP) + + Returns: + Dictionary mapping each action to its score + """ + scores: dict[int, float] = {} + for action in legal_actions: + key = (group_idx, action) + stats = self.rollout_stats.get(key) + if stats is not None and stats[0] > 0: + visits, total_reward = stats + mean_reward = total_reward / visits + novelty = self.rollout_novelty_weight / math.sqrt(visits + 1) + scores[action] = mean_reward + novelty + else: + scores[action] = self.rollout_novelty_weight + return scores + + def _sample_rollout_action(self, group_idx: int, legal_actions: list[int]) -> int: + """Sample a rollout action using softmax policy blended with uniform. + + Computes p(a) = (1 - epsilon) * softmax(score/tau) + epsilon * uniform. 
+ + Args: + group_idx: Index of the current group + legal_actions: List of legal action indices (may include STOP) + + Returns: + Selected action index + """ + scores = self._score_rollout_actions(group_idx, legal_actions) + n = len(legal_actions) + + # Compute softmax probabilities with temperature + logits = torch.tensor([scores[a] for a in legal_actions], dtype=torch.float64) + policy_probs = torch.softmax(logits / self.rollout_tau, dim=0) + + # Blend with uniform: p(a) = (1 - eps) * softmax + eps * uniform + eps = self.rollout_epsilon + probs = (1.0 - eps) * policy_probs + eps / n + + # Sample using weighted choice + return self.rng.choices(legal_actions, weights=probs.tolist(), k=1)[0] + + def _update_rollout_stats( + self, trajectory: list[tuple[int, int, int]], reward: float + ) -> None: + """Update rollout policy statistics from a completed trajectory. + + Args: + trajectory: List of (group_idx, cardinality, action) triples from + the rollout + reward: Raw reward obtained from the terminal evaluation + """ + for group_idx, _cardinality, action in trajectory: + key = (group_idx, action) + v, tot = self.rollout_stats.get(key, (0, 0.0)) + self.rollout_stats[key] = (v + 1, tot + reward) + + @staticmethod + def _extract_tree_actions(path: list[Node]) -> list[tuple[int, int, int]]: + """Extract (group_idx, cardinality, action) from consecutive node pairs. + + For each parent-child pair in the tree path, determines which action + was taken (feature selection or STOP) and records the context + (group index and cardinality at the time of the action). 
+ + Args: + path: List of nodes from root to leaf in the tree traversal + + Returns: + List of (group_idx, cardinality, action) triples + """ + context_actions: list[tuple[int, int, int]] = [] + for i in range(len(path) - 1): + parent = path[i] + child = path[i + 1] + g = parent.group_idx + cardinality = len(parent.partial_by_group[g]) + if child.stopped_by_group[g] and not parent.stopped_by_group[g]: + action = STOP + else: + child_partial = child.partial_by_group[g] + parent_partial = parent.partial_by_group[g] + if len(child_partial) > len(parent_partial): + action = child_partial[-1] + else: + continue + context_actions.append((g, cardinality, action)) + return context_actions + + def _update_context_rave_stats( + self, context_actions: list[tuple[int, int, int]], reward: float + ) -> None: + """Update context-aware RAVE statistics. + + Args: + context_actions: List of (group_idx, cardinality, action) triples + reward: Normalized reward to accumulate + """ + for group_idx, cardinality, action in context_actions: + key = (group_idx, cardinality, action) + v, tot = self.context_rave_stats.get(key, (0, 0.0)) + self.context_rave_stats[key] = (v + 1, tot + reward) + + def _make_cache_key( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> tuple: + """Create hashable cache key from selection.""" + return (selected_features, frozenset(cat_selections.items())) + + def _cached_reward( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + """Get cached reward or compute and cache it.""" + key = self._make_cache_key(selected_features, cat_selections) + if key in self.value_cache: + self.cache_hits += 1 + return self.value_cache[key] + val = self.reward_fn(selected_features, cat_selections) + self.value_cache[key] = val + self.cache_misses += 1 + return val + + def _child_limit(self, node: Node) -> int: + """Progressive widening: max children based on visit count.""" + return max(1, int(self.pw_k0 * 
(max(1, node.n_visits) ** self.pw_alpha))) + + def _legal_actions(self, node: Node) -> list[int]: + """Get legal actions for current group in node.""" + if node.is_terminal(self.groups): + return [] + g = node.group_idx + group = self.groups.groups[g] + partial = node.partial_by_group[g] + stopped = node.stopped_by_group[g] + return group.legal_actions(partial, stopped) + + def _apply_action(self, node: Node, action: int) -> Node: + """Create child node by applying action to current node.""" + g = node.group_idx + group = self.groups.groups[g] + + partials = list(node.partial_by_group) + stoppeds = list(node.stopped_by_group) + + if action == STOP: + stoppeds[g] = True + next_g = g + 1 + else: + partials[g] += (action,) + # Check if group is complete + if group.is_complete(partials[g], stoppeds[g]): + next_g = g + 1 + else: + next_g = g + + return Node( + partial_by_group=tuple(partials), + stopped_by_group=tuple(stoppeds), + group_idx=next_g, + ) + + def _get_selection(self, node: Node) -> tuple[tuple[int, ...], dict[int, float]]: + """Convert node's partial selections to (selected_features, cat_selections).""" + # Extract NChooseK selections + selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in node.partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + # Extract categorical selections + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = node.partial_by_group[g] + if partial: + # partial[0] is the index into cat_group.values + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return selected_features_tuple, cat_selections + + def _select_and_expand(self) -> tuple[Node, list[Node]]: + """Select path through tree and expand one new node.""" + node = self.root + path = [node] + + while not 
node.is_terminal(self.groups): + legal = self._legal_actions(node) + limit = self._child_limit(node) + unexpanded = [a for a in legal if a not in node.children.keys()] + can_expand = len(node.children) < limit + + if can_expand and unexpanded: + # Expand one new child + action = self.rng.choice(unexpanded) + child = self._apply_action(node, action) + node.children[action] = child + path.append(child) + return child, path + + # UCT + RAVE selection among existing children + # Bind node via default argument to avoid B023 closure issue + def combined_score(action: int, child: Node, _node: Node = node) -> float: + parent_visits = max(1, _node.n_visits) + child_visits = max(1, child.n_visits) + uct_val = (child.w_total / child_visits) + self.c_uct * math.sqrt( + math.log(parent_visits) / child_visits + ) + + if self.context_rave: + g = _node.group_idx + cardinality = len(_node.partial_by_group[g]) + ctx_key = (g, cardinality, action) + v, tot = self.context_rave_stats.get(ctx_key, (0, 0.0)) + rave_mean = (tot / v) if v > 0 else 0.0 + else: + if action == STOP: + rave_mean = 0.0 + else: + g = _node.group_idx + glob_id = self._global_action_id(g, action) + v, tot = self.rave_stats.get(glob_id, (0, 0.0)) + rave_mean = (tot / v) if v > 0 else 0.0 + + beta = self.k_rave / (self.k_rave + max(1, _node.n_visits)) + return (1 - beta) * uct_val + beta * rave_mean + + if node.children: + best_action, best_child = max( + node.children.items(), key=lambda kv: combined_score(kv[0], kv[1]) + ) + node = best_child + path.append(node) + else: + break + + return node, path + + def _rollout( + self, node: Node + ) -> tuple[tuple[int, ...], dict[int, float], list[tuple[int, int, int]]]: + """Random rollout to terminal state, return selection and trajectory. + + Returns: + Tuple of (selected_features, cat_selections, trajectory) where + trajectory is a list of (group_idx, cardinality, action) triples + taken during rollout. 
+ """ + curr = Node( + partial_by_group=tuple(node.partial_by_group), + stopped_by_group=tuple(node.stopped_by_group), + group_idx=node.group_idx, + ) + trajectory: list[tuple[int, int, int]] = [] + + while not curr.is_terminal(self.groups): + legal = self._legal_actions(curr) + if not legal: + # No legal actions, advance group (group is complete) + curr = Node( + partial_by_group=curr.partial_by_group, + stopped_by_group=curr.stopped_by_group, + group_idx=curr.group_idx + 1, + ) + continue + + g = curr.group_idx + + if self.rollout_policy: + # Learned softmax policy: STOP is scored like any other action + action = self._sample_rollout_action(g, legal) + else: + # Original logic: adaptive p_stop for NChooseK, uniform for features + is_nchoosek = g < len(self.groups.nchooseks) + if is_nchoosek and STOP in legal: + p_stop = self._compute_adaptive_p_stop( + g, len(curr.partial_by_group[g]) + ) + if self.rng.random() < p_stop: + trajectory.append((g, len(curr.partial_by_group[g]), STOP)) + curr = self._apply_action(curr, STOP) + continue + + # Choose uniformly among non-STOP actions + choices = [a for a in legal if a != STOP] + if not choices: + trajectory.append((g, len(curr.partial_by_group[g]), STOP)) + curr = self._apply_action(curr, STOP) + continue + + action = self.rng.choice(choices) + + trajectory.append((g, len(curr.partial_by_group[g]), action)) + curr = self._apply_action(curr, action) + + selected_features, cat_selections = self._get_selection(curr) + return selected_features, cat_selections, trajectory + + def _backpropagate( + self, + path: list[Node], + reward: float, + selected_features: tuple[int, ...], + cat_selections: dict[int, float], + ) -> None: + """Backpropagate reward through path and update RAVE statistics.""" + for n in path: + n.n_visits += 1 + n.w_total += reward + + # Update RAVE stats for NChooseK selections + selected_set = set(selected_features) + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx, feat_idx in 
enumerate(nchoosek.features): + if feat_idx in selected_set: + glob_id = self._global_action_id(g, local_idx) + v, tot = self.rave_stats.get(glob_id, (0, 0.0)) + self.rave_stats[glob_id] = (v + 1, tot + reward) + + # Update RAVE stats for categorical selections + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + if cat_group.dim in cat_selections: + selected_value = cat_selections[cat_group.dim] + for local_idx, value in enumerate(cat_group.values): + if value == selected_value: + glob_id = self._global_action_id(g, local_idx) + v, tot = self.rave_stats.get(glob_id, (0, 0.0)) + self.rave_stats[glob_id] = (v + 1, tot + reward) + break + + def run(self, n_iterations: int) -> tuple[tuple[int, ...], dict[int, float], float]: + """Run MCTS for specified number of iterations. + + Args: + n_iterations: Number of MCTS iterations to run + + Returns: + Tuple of (selected_features, cat_selections, best_value) + """ + for _ in range(n_iterations): + leaf, path = self._select_and_expand() + + if leaf.is_terminal(self.groups): + selected_features, cat_selections = self._get_selection(leaf) + trajectory: list[tuple[int, int, int]] = [] + else: + # Rollout retry: if the rollout produces a cached terminal, + # re-roll to try to discover a novel selection. 
+ selected_features, cat_selections, trajectory = self._rollout(leaf) + for _attempt in range(self.max_rollout_retries): + key = self._make_cache_key(selected_features, cat_selections) + if key not in self.value_cache: + break + selected_features, cat_selections, trajectory = self._rollout(leaf) + + key = self._make_cache_key(selected_features, cat_selections) + is_novel = key not in self.value_cache + reward = self._cached_reward(selected_features, cat_selections) + + # Update reward range (used by normalization and adaptive p_stop) + if reward < self.reward_min: + self.reward_min = reward + if reward > self.reward_max: + self.reward_max = reward + + if reward > self.best_value: + self.best_value = reward + self.best_selection = (selected_features, cat_selections) + + if self.adaptive_p_stop: + self._update_cardinality_stats(reward, selected_features) + + # Normalize reward for backpropagation if enabled + bp_reward = ( + self._normalize_reward(reward) if self.normalize_rewards else reward + ) + + if is_novel: + self._backpropagate(path, bp_reward, selected_features, cat_selections) + if self.context_rave: + tree_actions = self._extract_tree_actions(path) + all_context_actions = tree_actions + trajectory + self._update_context_rave_stats(all_context_actions, bp_reward) + else: + # Virtual loss: increment visits with zero reward so that + # (a) PW limits still grow with traffic, and + # (b) mean_value drops for over-visited branches, steering + # UCT exploration toward less-exploited parts of the tree. 
+ for n in path: + n.n_visits += 1 + + self._update_rollout_stats(trajectory, reward) + + if self.best_selection is None: + return (), {}, self.best_value + return self.best_selection[0], self.best_selection[1], self.best_value + + def cache_stats(self) -> dict[str, int]: + """Return cache statistics.""" + return { + "hits": self.cache_hits, + "misses": self.cache_misses, + "size": len(self.value_cache), + } + + +# ============================================================================= +# Main optimization function +# ============================================================================= + + +def optimize_acqf_mcts( + acq_function, + bounds: Tensor, + nchooseks: list[tuple[list[int], int, int]] | None = None, + cat_dims: Mapping[int, Sequence[float]] | None = None, + # MCTS parameters + c_uct: float = 0.01, + k_rave: float = 0.0, + p_stop_rollout: float = 0.35, + num_iterations: int = 100, + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + adaptive_p_stop: bool = True, + p_stop_warmup: int = 20, + p_stop_temperature: float = 0.25, + normalize_rewards: bool = True, + rollout_policy: bool = True, + rollout_epsilon: float = 0.3, + rollout_tau: float = 1.0, + rollout_novelty_weight: float = 1.0, + context_rave: bool = False, + # BoTorch acqf optimization parameters + q: int = 1, + raw_samples: int = 1024, + num_restarts: int = 20, + fixed_features: dict[int, float] | None = None, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + seed: int | None = None, +) -> tuple[Tensor, float]: + """Optimize acquisition function with NChooseK and categorical constraints using MCTS. + + Uses MCTS to select which features are active (non-zero) and categorical values, + then runs BoTorch optimization with inactive features fixed to zero and + categoricals fixed to their selected values. 
+ + Args: + acq_function: BoTorch acquisition function to optimize + bounds: 2 x d tensor of (lower, upper) bounds for each dimension + nchooseks: List of NChooseK constraints as tuples of (features, min_count, max_count) + where features is a list of feature indices + cat_dims: Dictionary mapping categorical dimension indices to allowed values + (same signature as botorch.optim.optimize_acqf_mixed_alternating) + c_uct: UCT exploration constant (default 0.01, paired with normalize_rewards) + k_rave: RAVE blending decay parameter (default 0 = disabled) + p_stop_rollout: Base probability of early stop during NChooseK rollout + num_iterations: Number of MCTS iterations + pw_k0: Progressive widening base constant + pw_alpha: Progressive widening exponent + max_rollout_retries: Maximum rollout retries on cache hit + adaptive_p_stop: Learn per-group stop probability from cardinality stats + p_stop_warmup: Number of rollouts before adaptive p_stop fully activates + p_stop_temperature: Sigmoid temperature for adaptive p_stop + normalize_rewards: Map rewards to [0, 1] via running min-max + rollout_policy: Use learned softmax rollout policy instead of uniform + rollout_epsilon: Epsilon for epsilon-greedy blending in rollout policy + rollout_tau: Temperature for softmax in rollout policy + rollout_novelty_weight: Novelty bonus coefficient for rollout policy + context_rave: Use context-aware RAVE instead of global RAVE + q: Batch size for acquisition function optimization + raw_samples: Number of raw samples for initialization + num_restarts: Number of optimization restarts + fixed_features: Additional fixed features (combined with MCTS selections) + inequality_constraints: Inequality constraints for BoTorch optimization + equality_constraints: Equality constraints for BoTorch optimization + seed: Random seed for reproducibility + + Returns: + Tuple of (best_candidates, best_acq_value) where best_candidates is a + q x d tensor of optimal points and best_acq_value is the 
acquisition value + """ + d = bounds.shape[1] + + # Build NChooseK groups from tuples + nchoosek_list = [] + if nchooseks: + for features, min_count, max_count in nchooseks: + nchoosek_list.append( + NChooseK(features=features, min_count=min_count, max_count=max_count) + ) + + # Build categorical groups + categorical_list = ( + [Categorical(dim=dim, values=list(values)) for dim, values in cat_dims.items()] + if cat_dims + else [] + ) + + # Combine all groups + all_groups = nchoosek_list + categorical_list + groups = Groups(groups=all_groups) + + # All feature indices covered by NChooseK constraints + nchoosek_features = set(groups.all_nchoosek_features) + + # Storage for best result across all MCTS evaluations + best_candidates: Optional[Tensor] = None + best_acq_value: float = float("-inf") + + def reward_fn( + selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + nonlocal best_candidates, best_acq_value + + selected_set = set(selected_features) + + # Build fixed_features dict + combined_fixed = {} + + # First add user-provided fixed features + if fixed_features is not None: + combined_fixed.update(fixed_features) + + # Fix inactive NChooseK features to 0 + inactive_features = nchoosek_features - selected_set + for idx in inactive_features: + combined_fixed[idx] = 0.0 + + # Fix categorical dimensions to selected values + for dim, value in cat_selections.items(): + combined_fixed[dim] = value + + candidates, acq_value = optimize_acqf( + acq_function=acq_function, + bounds=bounds, + q=q, + num_restarts=num_restarts, + raw_samples=raw_samples, + fixed_features=combined_fixed, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ) + + value = acq_value.item() + + # Track best + if value > best_acq_value: + best_acq_value = value + best_candidates = candidates + + return value + + # Run MCTS + mcts = MCTS( + groups=groups, + reward_fn=reward_fn, + c_uct=c_uct, + k_rave=k_rave, + 
p_stop_rollout=p_stop_rollout, + pw_k0=pw_k0, + pw_alpha=pw_alpha, + max_rollout_retries=max_rollout_retries, + adaptive_p_stop=adaptive_p_stop, + p_stop_warmup=p_stop_warmup, + p_stop_temperature=p_stop_temperature, + normalize_rewards=normalize_rewards, + rollout_policy=rollout_policy, + rollout_epsilon=rollout_epsilon, + rollout_tau=rollout_tau, + rollout_novelty_weight=rollout_novelty_weight, + context_rave=context_rave, + seed=seed, + ) + + mcts.run(n_iterations=num_iterations) + + # Handle case where no valid solution was found + if best_candidates is None: + # Return zeros with -inf value + best_candidates = torch.zeros(q, d, dtype=bounds.dtype, device=bounds.device) + best_acq_value = float("-inf") + + return best_candidates, best_acq_value diff --git a/mcts-report/optimize_mcts_nig.py b/mcts-report/optimize_mcts_nig.py new file mode 100644 index 000000000..3e2753113 --- /dev/null +++ b/mcts-report/optimize_mcts_nig.py @@ -0,0 +1,755 @@ +"""MCTS with Normal-Inverse-Gamma (NIG) posterior for Thompson Sampling. + +Replaces the Normal-Normal conjugate update in MCTS_TS with the proper +Bayesian conjugate for Normal data with unknown mean AND variance: the +Normal-Inverse-Gamma (NIG) distribution. + +The marginal posterior for the mean is a Student-t distribution with +heavier tails at low observation counts, which naturally handles the +low-n regime without extra heuristics (no posterior collapse at n=1). 
+ +NIG prior: (mu, sigma^2) ~ NIG(mu0, n0, alpha0, beta0) + mu0 = _global_mean() (running mean of novel rewards) + n0 = 1 (pseudo-count) + alpha0 = nig_alpha0 parameter (default 1.0) + beta0 = alpha0 * _prior_var() (so E[sigma^2] = prior_var) + +After n observations with sufficient stats (n_obs, sum_rewards, sum_sq_rewards): + x_bar = sum_rewards / n + S = sum_sq_rewards - n * x_bar^2 + + n0' = n0 + n + mu0' = (n0 * mu0 + n * x_bar) / n0' + alpha0' = alpha0 + n / 2 + beta0' = beta0 + S / 2 + (n0 * n * (x_bar - mu0)^2) / (2 * n0') + +Marginal posterior for mu: Student-t with + df = 2 * alpha0' + location = mu0' + scale = sqrt(beta0' / (alpha0' * n0')) +""" + +import math +import random +from typing import Callable, Optional + +from optimize_mcts_full import STOP, Groups +from optimize_mcts_ts import TSActionStats, TSNode + + +# ============================================================================= +# MCTS with Normal-Inverse-Gamma Thompson Sampling +# ============================================================================= + + +class MCTS_NIG: + """Monte Carlo Tree Search with Normal-Inverse-Gamma Thompson Sampling. + + Uses NIG conjugate posteriors whose marginal for the mean is Student-t, + providing heavier tails at low observation counts than the Normal posterior + in MCTS_TS. This naturally prevents premature commitment at n=1. + + All other machinery (tree structure, cache-hit modes, rollout dispatch, + backpropagation, progressive widening, softmax fallback) is identical + to MCTS_TS. 
+ + Args: + groups: Collection of NChooseK and categorical constraints + reward_fn: Function mapping (selected_features, cat_selections) to reward + nig_alpha0: NIG shape prior (default 1.0); lower = heavier tails at low n + ts_prior_var: Prior variance (default 1.0); used to set beta0 = alpha0 * prior_var + adaptive_prior_var: If True, set prior variance to running empirical variance + cache_hit_mode: How to handle cache hits: "no_update", "variance_inflation", + "pessimistic", or "combined" + variance_decay: Decay factor for variance inflation mode (default 0.95) + rollout_mode: Rollout policy: "uniform", "ts_group_action", + "ts_group_card_action", or "softmax" + pw_k0: Progressive widening base constant (default 2.0) + pw_alpha: Progressive widening exponent (default 0.6) + max_rollout_retries: Maximum rollout retries on cache hit (default 3) + seed: Random seed for reproducibility + rollout_epsilon: Epsilon-mix for uniform exploration in softmax mode + rollout_tau: Softmax temperature in softmax mode + rollout_novelty_weight: Novelty bonus coefficient in softmax mode + normalize_rewards: Normalize rewards for softmax rollout stats + adaptive_p_stop: Enable adaptive stop probability in softmax mode + p_stop_rollout: Base stop probability in softmax mode + p_stop_warmup: Warmup count for adaptive p_stop + p_stop_temperature: Sigmoid temperature for adaptive p_stop + """ + + def __init__( + self, + groups: Groups, + reward_fn: Callable[[tuple[int, ...], dict[int, float]], float], + nig_alpha0: float = 1.0, + ts_prior_var: float = 1.0, + adaptive_prior_var: bool = False, + cache_hit_mode: str = "no_update", + variance_decay: float = 0.95, + rollout_mode: str = "uniform", + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + seed: Optional[int] = None, + # Softmax fallback parameters + rollout_epsilon: float = 0.3, + rollout_tau: float = 1.0, + rollout_novelty_weight: float = 1.0, + normalize_rewards: bool = True, + adaptive_p_stop: bool 
= True, + p_stop_rollout: float = 0.35, + p_stop_warmup: int = 20, + p_stop_temperature: float = 0.25, + adaptive_n0: bool = False, + use_cache: bool = True, + ): + self.groups = groups + self.reward_fn = reward_fn + self.nig_alpha0 = nig_alpha0 + self.ts_prior_var = ts_prior_var + self.adaptive_prior_var = adaptive_prior_var + self.cache_hit_mode = cache_hit_mode + self.variance_decay = variance_decay + self.rollout_mode = rollout_mode + self.pw_k0 = pw_k0 + self.pw_alpha = pw_alpha + self.max_rollout_retries = max_rollout_retries + self.rng = random.Random(seed) + + # Softmax fallback params + self.rollout_epsilon = rollout_epsilon + self.rollout_tau = rollout_tau + self.rollout_novelty_weight = rollout_novelty_weight + self.normalize_rewards = normalize_rewards + self.adaptive_p_stop = adaptive_p_stop + self.p_stop_rollout = p_stop_rollout + self.p_stop_warmup = p_stop_warmup + self.p_stop_temperature = p_stop_temperature + self.adaptive_n0 = adaptive_n0 + self.use_cache = use_cache + + # Initialize root node + n_groups = len(groups) + self.root = TSNode( + partial_by_group=tuple(() for _ in range(n_groups)), + stopped_by_group=tuple(False for _ in range(n_groups)), + group_idx=0, + ) + + # Best found so far + self.best_selection: Optional[tuple[tuple[int, ...], dict[int, float]]] = None + self.best_value: float = float("-inf") + + # Cache for terminal evaluations + self.value_cache: dict[tuple, float] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + # Global reward tracking for prior center and adaptive variance + self._novel_reward_sum: float = 0.0 + self._novel_reward_sq_sum: float = 0.0 + self._novel_reward_count: int = 0 + + # Rollout TS statistics: key -> TSActionStats + self.rollout_ts_stats: dict[tuple, TSActionStats] = {} + + # Softmax rollout statistics (for softmax fallback) + self.rollout_stats: dict[tuple[int, int], tuple[int, float]] = {} + + # Adaptive p_stop statistics (for softmax fallback) + self.cardinality_stats: dict[tuple[int, int], 
tuple[int, float]] = {} + n_nchoosek = len(self.groups.nchooseks) + self.group_rollout_counts: list[int] = [0] * n_nchoosek + self.reward_min: float = float("inf") + self.reward_max: float = float("-inf") + + # --- Prior center -------------------------------------------------------- + + def _global_mean(self) -> float: + """Running mean of all novel rewards (prior center mu0).""" + if self._novel_reward_count == 0: + return 0.0 + return self._novel_reward_sum / self._novel_reward_count + + def _prior_var(self) -> float: + """Prior variance, either fixed or adaptive (empirical variance).""" + if not self.adaptive_prior_var or self._novel_reward_count < 2: + return self.ts_prior_var + mean = self._global_mean() + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return max(empirical_var, 1e-8) + + def _pessimistic_value(self) -> float: + """Pessimistic pseudo-observation value: global_mean - global_std.""" + mean = self._global_mean() + if self._novel_reward_count < 2: + return mean - math.sqrt(self.ts_prior_var) + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return mean - math.sqrt(max(empirical_var, 1e-8)) + + def _compute_n0(self, n_actions: int) -> float: + """Compute pseudo-count n₀ from branching factor. + + With adaptive_n0, n₀ = 1 + log(branching_factor). Higher branching + means each child is visited rarely early on, so we need more + observations before departing from the prior. + """ + if not self.adaptive_n0: + return 1.0 + return 1.0 + math.log(max(n_actions, 2)) + + # --- NIG Student-t sampling ---------------------------------------------- + + def _student_t_sample(self, df: float, loc: float, scale: float) -> float: + """Sample from a Student-t distribution. + + Uses the representation: loc + scale * Z / sqrt(V / df) + where Z ~ N(0,1) and V ~ chi-squared(df) = Gamma(df/2, 2). 
+ """ + z = self.rng.gauss(0, 1) + v = self.rng.gammavariate(df / 2, 2) # chi-squared(df) + return loc + scale * z / math.sqrt(v / df) + + def _nig_sample_score(self, node: TSNode, n0: float = 1.0) -> float: + """Sample from node's NIG posterior (marginal Student-t) for tree selection.""" + mu0 = self._global_mean() + prior_var = self._prior_var() + alpha0 = self.nig_alpha0 + beta0 = alpha0 * prior_var # E[sigma^2] = beta0/alpha0 = prior_var + + n = node.n_obs + if n == 0: + # Prior Student-t: df=2*alpha0, loc=mu0, scale=sqrt(beta0/(alpha0*n0)) + df = 2 * alpha0 + scale = math.sqrt(beta0 / (alpha0 * n0)) + return self._student_t_sample(df, mu0, scale) + + x_bar = node.sum_rewards / n + s = node.sum_sq_rewards - n * x_bar * x_bar # sum of squared deviations + s = max(s, 0.0) # numerical safety + + # Posterior update + n0_post = n0 + n + mu0_post = (n0 * mu0 + n * x_bar) / n0_post + alpha0_post = alpha0 + n / 2 + beta0_post = beta0 + s / 2 + (n0 * n * (x_bar - mu0) ** 2) / (2 * n0_post) + + df = 2 * alpha0_post + scale = math.sqrt(beta0_post / (alpha0_post * n0_post)) + return self._student_t_sample(df, mu0_post, scale) + + def _nig_sample_action_score(self, stats: TSActionStats, n0: float = 1.0) -> float: + """Sample from a TSActionStats NIG posterior (for rollout actions).""" + mu0 = self._global_mean() + prior_var = self._prior_var() + alpha0 = self.nig_alpha0 + beta0 = alpha0 * prior_var + + n = stats.n_obs + if n == 0: + df = 2 * alpha0 + scale = math.sqrt(beta0 / (alpha0 * n0)) + return self._student_t_sample(df, mu0, scale) + + x_bar = stats.sum_rewards / n + s = stats.sum_sq_rewards - n * x_bar * x_bar + s = max(s, 0.0) + + n0_post = n0 + n + mu0_post = (n0 * mu0 + n * x_bar) / n0_post + alpha0_post = alpha0 + n / 2 + beta0_post = beta0 + s / 2 + (n0 * n * (x_bar - mu0) ** 2) / (2 * n0_post) + + df = 2 * alpha0_post + scale = math.sqrt(beta0_post / (alpha0_post * n0_post)) + return self._student_t_sample(df, mu0_post, scale) + + # --- Copied from MCTS_TS 
(Node -> TSNode) -------------------------------- + + def _make_cache_key( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> tuple: + """Create hashable cache key from selection.""" + return (selected_features, frozenset(cat_selections.items())) + + def _cached_reward( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + """Get cached reward or compute and cache it. + + With use_cache=False, always calls reward_fn fresh (for stochastic rewards). + """ + if not self.use_cache: + self.cache_misses += 1 + return self.reward_fn(selected_features, cat_selections) + key = self._make_cache_key(selected_features, cat_selections) + if key in self.value_cache: + self.cache_hits += 1 + return self.value_cache[key] + val = self.reward_fn(selected_features, cat_selections) + self.value_cache[key] = val + self.cache_misses += 1 + return val + + def _child_limit(self, node: TSNode) -> int: + """Progressive widening: max children based on visit count.""" + return max(1, int(self.pw_k0 * (max(1, node.n_visits) ** self.pw_alpha))) + + def _legal_actions(self, node: TSNode) -> list[int]: + """Get legal actions for current group in node.""" + if node.is_terminal(self.groups): + return [] + g = node.group_idx + group = self.groups.groups[g] + partial = node.partial_by_group[g] + stopped = node.stopped_by_group[g] + return group.legal_actions(partial, stopped) + + def _apply_action(self, node: TSNode, action: int) -> TSNode: + """Create child node by applying action to current node.""" + g = node.group_idx + group = self.groups.groups[g] + + partials = list(node.partial_by_group) + stoppeds = list(node.stopped_by_group) + + if action == STOP: + stoppeds[g] = True + next_g = g + 1 + else: + partials[g] += (action,) + if group.is_complete(partials[g], stoppeds[g]): + next_g = g + 1 + else: + next_g = g + + return TSNode( + partial_by_group=tuple(partials), + stopped_by_group=tuple(stoppeds), + group_idx=next_g, + ) 
+ + def _get_selection(self, node: TSNode) -> tuple[tuple[int, ...], dict[int, float]]: + """Convert node's partial selections to (selected_features, cat_selections).""" + selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in node.partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = node.partial_by_group[g] + if partial: + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return selected_features_tuple, cat_selections + + def cache_stats(self) -> dict[str, int]: + """Return cache statistics.""" + return { + "hits": self.cache_hits, + "misses": self.cache_misses, + "size": len(self.value_cache), + } + + # --- Tree selection with NIG Thompson Sampling ---------------------------- + + def _select_and_expand(self) -> tuple[TSNode, list[TSNode]]: + """Select path through tree using NIG-TS and expand one new node.""" + node = self.root + path = [node] + + while not node.is_terminal(self.groups): + legal = self._legal_actions(node) + limit = self._child_limit(node) + unexpanded = [a for a in legal if a not in node.children] + can_expand = len(node.children) < limit + + if can_expand and unexpanded: + action = self.rng.choice(unexpanded) + child = self._apply_action(node, action) + node.children[action] = child + path.append(child) + return child, path + + # NIG Thompson Sampling selection among existing children + if node.children: + n0 = self._compute_n0(len(node.children)) + best_action = None + best_score = float("-inf") + for action, child in node.children.items(): + score = self._nig_sample_score(child, n0=n0) + if score > best_score: + best_score = score + best_action = action + + node = node.children[best_action] + path.append(node) + else: + break + 
+ return node, path + + # --- Rollout action selection --------------------------------------------- + + def _ts_sample_rollout_action( + self, group_idx: int, cardinality: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using per-action NIG posteriors.""" + n0 = self._compute_n0(len(legal_actions)) + best_action = legal_actions[0] + best_score = float("-inf") + + for action in legal_actions: + if self.rollout_mode == "ts_group_action": + key = (group_idx, action) + else: # ts_group_card_action + key = (group_idx, cardinality, action) + + stats = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + score = self._nig_sample_action_score(stats, n0=n0) + if score > best_score: + best_score = score + best_action = action + + return best_action + + @staticmethod + def _softmax_probs(logits: list[float]) -> list[float]: + """Pure-math softmax (no torch dependency).""" + max_logit = max(logits) + exps = [math.exp(v - max_logit) for v in logits] + total = sum(exps) + return [e / total for e in exps] + + # --- Softmax fallback methods --------------------------------------------- + + def _normalize_reward(self, reward: float) -> float: + """Normalize reward to [0, 1] using running min-max.""" + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return 0.5 + return (reward - self.reward_min) / reward_range + + def _score_rollout_actions( + self, group_idx: int, legal_actions: list[int] + ) -> dict[int, float]: + """Score legal rollout actions using learned statistics.""" + scores: dict[int, float] = {} + for action in legal_actions: + key = (group_idx, action) + stats = self.rollout_stats.get(key) + if stats is not None and stats[0] > 0: + visits, total_reward = stats + mean_reward = total_reward / visits + novelty = self.rollout_novelty_weight / math.sqrt(visits + 1) + scores[action] = mean_reward + novelty + else: + scores[action] = self.rollout_novelty_weight + return scores + + def _sample_softmax_rollout_action( + 
self, group_idx: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using softmax policy.""" + scores = self._score_rollout_actions(group_idx, legal_actions) + n = len(legal_actions) + + logits = [scores[a] / self.rollout_tau for a in legal_actions] + policy_probs = self._softmax_probs(logits) + + eps = self.rollout_epsilon + probs = [(1.0 - eps) * p + eps / n for p in policy_probs] + + return self.rng.choices(legal_actions, weights=probs, k=1)[0] + + def _update_rollout_stats( + self, trajectory: list[tuple[int, int, int]], reward: float + ) -> None: + """Update softmax rollout policy statistics from a completed trajectory.""" + for group_idx, _cardinality, action in trajectory: + key = (group_idx, action) + v, tot = self.rollout_stats.get(key, (0, 0.0)) + self.rollout_stats[key] = (v + 1, tot + reward) + + def _update_cardinality_stats( + self, reward: float, selected_features: tuple[int, ...] + ) -> None: + """Update per-(group, cardinality) stats from a completed rollout.""" + selected_set = set(selected_features) + for g, nchoosek in enumerate(self.groups.nchooseks): + cardinality = sum(1 for f in nchoosek.features if f in selected_set) + key = (g, cardinality) + v, tot = self.cardinality_stats.get(key, (0, 0.0)) + self.cardinality_stats[key] = (v + 1, tot + reward) + self.group_rollout_counts[g] += 1 + + def _compute_adaptive_p_stop( + self, group_idx: int, current_cardinality: int + ) -> float: + """Compute adaptive stop probability for softmax rollout mode.""" + if not self.adaptive_p_stop: + return self.p_stop_rollout + + nchoosek = self.groups.nchooseks[group_idx] + max_count = nchoosek.max_count + + stop_key = (group_idx, current_cardinality) + stop_stats = self.cardinality_stats.get(stop_key) + if stop_stats is None or stop_stats[0] == 0: + return self.p_stop_rollout + + e_stop = stop_stats[1] / stop_stats[0] + + e_continue = float("-inf") + has_continue_data = False + for m in range(current_cardinality + 1, max_count + 1): + cont_key 
= (group_idx, m) + cont_stats = self.cardinality_stats.get(cont_key) + if cont_stats is not None and cont_stats[0] > 0: + mean_r = cont_stats[1] / cont_stats[0] + if mean_r > e_continue: + e_continue = mean_r + has_continue_data = True + + if not has_continue_data: + return self.p_stop_rollout + + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return self.p_stop_rollout + + tau = self.p_stop_temperature + logit = (e_stop - e_continue) / (tau * reward_range) + logit = max(-10.0, min(10.0, logit)) + p_learned = 1.0 / (1.0 + math.exp(-logit)) + + group_visits = self.group_rollout_counts[group_idx] + alpha = ( + min(1.0, group_visits / self.p_stop_warmup) + if self.p_stop_warmup > 0 + else 1.0 + ) + return (1.0 - alpha) * self.p_stop_rollout + alpha * p_learned + + # --- Rollout dispatch ----------------------------------------------------- + + def _rollout( + self, node: TSNode + ) -> tuple[tuple[int, ...], dict[int, float], list[tuple[int, int, int]]]: + """Rollout to terminal state with mode-dependent action selection.""" + curr = TSNode( + partial_by_group=tuple(node.partial_by_group), + stopped_by_group=tuple(node.stopped_by_group), + group_idx=node.group_idx, + ) + trajectory: list[tuple[int, int, int]] = [] + + while not curr.is_terminal(self.groups): + legal = self._legal_actions(curr) + if not legal: + curr = TSNode( + partial_by_group=curr.partial_by_group, + stopped_by_group=curr.stopped_by_group, + group_idx=curr.group_idx + 1, + ) + continue + + g = curr.group_idx + cardinality = len(curr.partial_by_group[g]) + + if self.rollout_mode == "uniform": + action = self.rng.choice(legal) + + elif self.rollout_mode in ("ts_group_action", "ts_group_card_action"): + action = self._ts_sample_rollout_action(g, cardinality, legal) + + elif self.rollout_mode == "softmax": + is_nchoosek = g < len(self.groups.nchooseks) + if is_nchoosek and STOP in legal: + p_stop = self._compute_adaptive_p_stop(g, cardinality) + if self.rng.random() < p_stop: 
+ trajectory.append((g, cardinality, STOP)) + curr = self._apply_action(curr, STOP) + continue + + action = self._sample_softmax_rollout_action(g, legal) + + else: + raise ValueError(f"Unknown rollout_mode: {self.rollout_mode}") + + trajectory.append((g, cardinality, action)) + curr = self._apply_action(curr, action) + + selected_features, cat_selections = self._get_selection(curr) + return selected_features, cat_selections, trajectory + + # --- Backpropagation ------------------------------------------------------ + + def _backpropagate(self, path: list[TSNode], reward: float, is_novel: bool) -> None: + """Backpropagate reward through path. + + Novel: update n_obs, sum_rewards, sum_sq_rewards, n_visits. + Cache hit handling depends on cache_hit_mode (same as MCTS_TS). + """ + if is_novel: + for n in path: + n.n_obs += 1 + n.sum_rewards += reward + n.sum_sq_rewards += reward * reward + n.n_visits += 1 + else: + if self.cache_hit_mode in ("pessimistic", "combined"): + pess = self._pessimistic_value() + if self.cache_hit_mode in ("adaptive_pessimistic", "adaptive_combined"): + g_mean = self._global_mean() + if self._novel_reward_count < 2: + g_std = math.sqrt(self.ts_prior_var) + else: + emp_var = ( + self._novel_reward_sq_sum / self._novel_reward_count + - g_mean * g_mean + ) + g_std = math.sqrt(max(emp_var, 1e-8)) + + for n in path: + n.n_visits += 1 + + if self.cache_hit_mode == "variance_inflation": + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + elif self.cache_hit_mode == "pessimistic": + n.n_obs += 1 + n.sum_rewards += pess + n.sum_sq_rewards += pess * pess + elif self.cache_hit_mode == "combined": + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= 
new_n / old_n + n.n_obs = new_n + n.n_obs += 1 + n.sum_rewards += pess + n.sum_sq_rewards += pess * pess + elif self.cache_hit_mode == "adaptive_pessimistic": + novelty_rate = n.n_obs / max(1, n.n_visits) + exhaustion = 1.0 - novelty_rate + pess_value = g_mean - exhaustion * g_std + n.n_obs += 1 + n.sum_rewards += pess_value + n.sum_sq_rewards += pess_value * pess_value + elif self.cache_hit_mode == "adaptive_combined": + # Variance inflation (same as combined) + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + # Adaptive pessimistic + novelty_rate = n.n_obs / max(1, n.n_visits) + exhaustion = 1.0 - novelty_rate + pess_value = g_mean - exhaustion * g_std + n.n_obs += 1 + n.sum_rewards += pess_value + n.sum_sq_rewards += pess_value * pess_value + + def _update_rollout_ts_stats( + self, trajectory: list[tuple[int, int, int]], reward: float + ) -> None: + """Update per-action TS stats from a completed rollout trajectory.""" + for group_idx, cardinality, action in trajectory: + if self.rollout_mode == "ts_group_action": + key = (group_idx, action) + elif self.rollout_mode == "ts_group_card_action": + key = (group_idx, cardinality, action) + else: + continue + + old = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + self.rollout_ts_stats[key] = TSActionStats( + n_obs=old.n_obs + 1, + sum_rewards=old.sum_rewards + reward, + sum_sq_rewards=old.sum_sq_rewards + reward * reward, + ) + + # --- Main loop ------------------------------------------------------------ + + def run(self, n_iterations: int) -> tuple[tuple[int, ...], dict[int, float], float]: + """Run MCTS-NIG for specified number of iterations. 
+ + Args: + n_iterations: Number of MCTS iterations to run + + Returns: + Tuple of (selected_features, cat_selections, best_value) + """ + for _ in range(n_iterations): + leaf, path = self._select_and_expand() + + if leaf.is_terminal(self.groups): + selected_features, cat_selections = self._get_selection(leaf) + trajectory: list[tuple[int, int, int]] = [] + else: + selected_features, cat_selections, trajectory = self._rollout(leaf) + if self.use_cache: + for _attempt in range(self.max_rollout_retries): + key = self._make_cache_key(selected_features, cat_selections) + if key not in self.value_cache: + break + selected_features, cat_selections, trajectory = self._rollout( + leaf + ) + + if self.use_cache: + key = self._make_cache_key(selected_features, cat_selections) + is_novel = key not in self.value_cache + else: + is_novel = True + reward = self._cached_reward(selected_features, cat_selections) + + if reward < self.reward_min: + self.reward_min = reward + if reward > self.reward_max: + self.reward_max = reward + + if reward > self.best_value: + self.best_value = reward + self.best_selection = (selected_features, cat_selections) + + if is_novel: + self._novel_reward_sum += reward + self._novel_reward_sq_sum += reward * reward + self._novel_reward_count += 1 + + self._backpropagate(path, reward, is_novel) + + if self.rollout_mode in ("ts_group_action", "ts_group_card_action"): + self._update_rollout_ts_stats(trajectory, reward) + elif self.rollout_mode == "softmax": + self._update_rollout_stats(trajectory, reward) + if self.adaptive_p_stop: + self._update_cardinality_stats(reward, selected_features) + + if self.best_selection is None: + return (), {}, self.best_value + return self.best_selection[0], self.best_selection[1], self.best_value diff --git a/mcts-report/optimize_mcts_ts.py b/mcts-report/optimize_mcts_ts.py new file mode 100644 index 000000000..96a642b21 --- /dev/null +++ b/mcts-report/optimize_mcts_ts.py @@ -0,0 +1,737 @@ +"""MCTS with Thompson Sampling 
for NChooseK and categorical optimization. + +Replaces UCT tree policy and softmax rollout policy with Bayesian posteriors +(Normal-Normal conjugate update), eliminating most tunable hyperparameters. + +Tree selection: sample from each child's Normal posterior, pick highest. +Rollout modes: uniform, TS per (group,action), TS per (group,cardinality,action), +or softmax fallback (pure-math, no torch). + +Bayesian update (weak prior, estimated variance): + Prior: N(mu0, sigma0^2), mu0 = running global mean, n0 = 1 + sigma0^2 = ts_prior_var (fixed) or empirical reward variance (adaptive) + After n observations: + x_bar = sum_rewards / n + s^2 = max(sum_sq_rewards/n - x_bar^2, 1e-8) + posterior_mean = (n0*mu0 + n*x_bar) / (n0 + n) + posterior_var = s^2 / (n0 + n) + sample ~ N(posterior_mean, posterior_var) +""" + +import math +import random +from dataclasses import dataclass, field +from typing import Callable, NamedTuple, Optional + +from optimize_mcts_full import STOP, Groups + + +# ============================================================================= +# TS-specific types +# ============================================================================= + + +class TSActionStats(NamedTuple): + """Sufficient statistics for Thompson Sampling on a rollout action.""" + + n_obs: int + sum_rewards: float + sum_sq_rewards: float + + +@dataclass +class TSNode: + """MCTS tree node with Thompson Sampling statistics. + + Replaces UCT's single n_visits/w_total with Bayesian sufficient statistics. 
+ + Args: + partial_by_group: Partial selection per group (indices into group's options) + stopped_by_group: Whether each group has stopped selecting (for NChooseK) + group_idx: Current group being filled + n_obs: Novel observations count (used for posterior updates) + sum_rewards: Sum of observed rewards from novel evaluations + sum_sq_rewards: Sum of squared rewards from novel evaluations + n_visits: Total visits including cache hits (for progressive widening) + children: Child nodes keyed by action (int index or STOP) + """ + + partial_by_group: tuple[tuple[int, ...], ...] + stopped_by_group: tuple[bool, ...] + group_idx: int + + n_obs: int = 0 + sum_rewards: float = 0.0 + sum_sq_rewards: float = 0.0 + n_visits: int = 0 + + children: dict[int, "TSNode"] = field(default_factory=dict) + + def is_terminal(self, groups: Groups) -> bool: + return self.group_idx >= len(groups) + + def mean_value(self) -> float: + return self.sum_rewards / self.n_obs if self.n_obs > 0 else 0.0 + + +# ============================================================================= +# MCTS with Thompson Sampling +# ============================================================================= + + +class MCTS_TS: + """Monte Carlo Tree Search with Thompson Sampling. + + Uses Normal-Normal conjugate posteriors for both tree selection and + (optionally) rollout action selection, replacing UCT and softmax policies. 
+ + Args: + groups: Collection of NChooseK and categorical constraints + reward_fn: Function mapping (selected_features, cat_selections) to reward + ts_prior_var: Prior variance for Normal posterior (default 1.0); used as + fallback when adaptive_prior_var=True and fewer than 2 rewards observed + adaptive_prior_var: If True, set prior variance to the running empirical + variance of all observed rewards instead of the fixed ts_prior_var + cache_hit_mode: How to handle cache hits: "no_update", "variance_inflation", + "pessimistic", or "combined" (variance inflation + pessimistic) + variance_decay: Decay factor for variance inflation mode (default 0.95) + rollout_mode: Rollout policy: "uniform", "ts_group_action", + "ts_group_card_action", or "softmax" + pw_k0: Progressive widening base constant (default 2.0) + pw_alpha: Progressive widening exponent (default 0.6) + max_rollout_retries: Maximum rollout retries on cache hit (default 3) + seed: Random seed for reproducibility + rollout_epsilon: Epsilon-mix for uniform exploration in softmax mode (default 0.3) + rollout_tau: Softmax temperature in softmax mode (default 1.0) + rollout_novelty_weight: Novelty bonus coefficient in softmax mode (default 1.0) + normalize_rewards: Normalize rewards for softmax rollout stats (default True) + adaptive_p_stop: Enable adaptive stop probability in softmax mode (default True) + p_stop_rollout: Base stop probability in softmax mode (default 0.35) + p_stop_warmup: Warmup count for adaptive p_stop (default 20) + p_stop_temperature: Sigmoid temperature for adaptive p_stop (default 0.25) + """ + + def __init__( + self, + groups: Groups, + reward_fn: Callable[[tuple[int, ...], dict[int, float]], float], + ts_prior_var: float = 1.0, + adaptive_prior_var: bool = False, + cache_hit_mode: str = "no_update", + variance_decay: float = 0.95, + rollout_mode: str = "uniform", + pw_k0: float = 2.0, + pw_alpha: float = 0.6, + max_rollout_retries: int = 3, + seed: Optional[int] = None, + # Softmax 
fallback parameters + rollout_epsilon: float = 0.3, + rollout_tau: float = 1.0, + rollout_novelty_weight: float = 1.0, + normalize_rewards: bool = True, + adaptive_p_stop: bool = True, + p_stop_rollout: float = 0.35, + p_stop_warmup: int = 20, + p_stop_temperature: float = 0.25, + ): + self.groups = groups + self.reward_fn = reward_fn + self.ts_prior_var = ts_prior_var + self.adaptive_prior_var = adaptive_prior_var + self.cache_hit_mode = cache_hit_mode + self.variance_decay = variance_decay + self.rollout_mode = rollout_mode + self.pw_k0 = pw_k0 + self.pw_alpha = pw_alpha + self.max_rollout_retries = max_rollout_retries + self.rng = random.Random(seed) + + # Softmax fallback params + self.rollout_epsilon = rollout_epsilon + self.rollout_tau = rollout_tau + self.rollout_novelty_weight = rollout_novelty_weight + self.normalize_rewards = normalize_rewards + self.adaptive_p_stop = adaptive_p_stop + self.p_stop_rollout = p_stop_rollout + self.p_stop_warmup = p_stop_warmup + self.p_stop_temperature = p_stop_temperature + + # Initialize root node + n_groups = len(groups) + self.root = TSNode( + partial_by_group=tuple(() for _ in range(n_groups)), + stopped_by_group=tuple(False for _ in range(n_groups)), + group_idx=0, + ) + + # Best found so far + self.best_selection: Optional[tuple[tuple[int, ...], dict[int, float]]] = None + self.best_value: float = float("-inf") + + # Cache for terminal evaluations + self.value_cache: dict[tuple, float] = {} + self.cache_hits = 0 + self.cache_misses = 0 + + # Global reward tracking for TS prior center and adaptive variance + self._novel_reward_sum: float = 0.0 + self._novel_reward_sq_sum: float = 0.0 + self._novel_reward_count: int = 0 + + # Rollout TS statistics: key -> TSActionStats + # Key format depends on rollout_mode: + # ts_group_action: (group_idx, action) + # ts_group_card_action: (group_idx, cardinality, action) + self.rollout_ts_stats: dict[tuple, TSActionStats] = {} + + # Softmax rollout statistics (for softmax fallback) + 
self.rollout_stats: dict[tuple[int, int], tuple[int, float]] = {} + + # Adaptive p_stop statistics (for softmax fallback) + self.cardinality_stats: dict[tuple[int, int], tuple[int, float]] = {} + n_nchoosek = len(self.groups.nchooseks) + self.group_rollout_counts: list[int] = [0] * n_nchoosek + self.reward_min: float = float("inf") + self.reward_max: float = float("-inf") + + # ─── Prior center ─────────────────────────────────────────────── + + def _global_mean(self) -> float: + """Running mean of all novel rewards (prior center mu0).""" + if self._novel_reward_count == 0: + return 0.0 + return self._novel_reward_sum / self._novel_reward_count + + def _prior_var(self) -> float: + """Prior variance sigma0^2, either fixed or adaptive. + + When adaptive_prior_var=True, returns the running empirical variance + of all novel rewards once at least 2 observations exist. This + auto-calibrates the prior to the problem's reward scale — the TS + analogue of reward normalization for UCT. + """ + if not self.adaptive_prior_var or self._novel_reward_count < 2: + return self.ts_prior_var + mean = self._global_mean() + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return max(empirical_var, 1e-8) + + def _pessimistic_value(self) -> float: + """Pessimistic pseudo-observation value for cache hit handling. + + Returns global_mean - global_std. Uses empirical std when available + (>= 2 observations), regardless of adaptive_prior_var setting, so that + the pessimistic offset is always scale-appropriate. 
+ """ + mean = self._global_mean() + if self._novel_reward_count < 2: + return mean - math.sqrt(self.ts_prior_var) + empirical_var = ( + self._novel_reward_sq_sum / self._novel_reward_count - mean * mean + ) + return mean - math.sqrt(max(empirical_var, 1e-8)) + + # ─── Thompson Sampling ────────────────────────────────────────── + + def _ts_sample_score(self, node: TSNode) -> float: + """Sample from node's Normal posterior for tree selection.""" + mu0 = self._global_mean() + sigma0_sq = self._prior_var() + n0 = 1 # pseudo-count + + n = node.n_obs + if n == 0: + return self.rng.gauss(mu0, math.sqrt(sigma0_sq)) + + x_bar = node.sum_rewards / n + s_sq = max(node.sum_sq_rewards / n - x_bar * x_bar, 1e-8) + + post_mean = (n0 * mu0 + n * x_bar) / (n0 + n) + post_var = s_sq / (n0 + n) + + return self.rng.gauss(post_mean, math.sqrt(post_var)) + + def _ts_sample_action_score(self, stats: TSActionStats) -> float: + """Sample from a TSActionStats posterior (for rollout actions).""" + mu0 = self._global_mean() + sigma0_sq = self._prior_var() + n0 = 1 + + n = stats.n_obs + if n == 0: + return self.rng.gauss(mu0, math.sqrt(sigma0_sq)) + + x_bar = stats.sum_rewards / n + s_sq = max(stats.sum_sq_rewards / n - x_bar * x_bar, 1e-8) + + post_mean = (n0 * mu0 + n * x_bar) / (n0 + n) + post_var = s_sq / (n0 + n) + + return self.rng.gauss(post_mean, math.sqrt(post_var)) + + # ─── Copied from production (Node -> TSNode) ──────────────────── + + def _make_cache_key( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> tuple: + """Create hashable cache key from selection.""" + return (selected_features, frozenset(cat_selections.items())) + + def _cached_reward( + self, selected_features: tuple[int, ...], cat_selections: dict[int, float] + ) -> float: + """Get cached reward or compute and cache it.""" + key = self._make_cache_key(selected_features, cat_selections) + if key in self.value_cache: + self.cache_hits += 1 + return self.value_cache[key] + val = 
self.reward_fn(selected_features, cat_selections) + self.value_cache[key] = val + self.cache_misses += 1 + return val + + def _child_limit(self, node: TSNode) -> int: + """Progressive widening: max children based on visit count.""" + return max(1, int(self.pw_k0 * (max(1, node.n_visits) ** self.pw_alpha))) + + def _legal_actions(self, node: TSNode) -> list[int]: + """Get legal actions for current group in node.""" + if node.is_terminal(self.groups): + return [] + g = node.group_idx + group = self.groups.groups[g] + partial = node.partial_by_group[g] + stopped = node.stopped_by_group[g] + return group.legal_actions(partial, stopped) + + def _apply_action(self, node: TSNode, action: int) -> TSNode: + """Create child node by applying action to current node.""" + g = node.group_idx + group = self.groups.groups[g] + + partials = list(node.partial_by_group) + stoppeds = list(node.stopped_by_group) + + if action == STOP: + stoppeds[g] = True + next_g = g + 1 + else: + partials[g] += (action,) + if group.is_complete(partials[g], stoppeds[g]): + next_g = g + 1 + else: + next_g = g + + return TSNode( + partial_by_group=tuple(partials), + stopped_by_group=tuple(stoppeds), + group_idx=next_g, + ) + + def _get_selection(self, node: TSNode) -> tuple[tuple[int, ...], dict[int, float]]: + """Convert node's partial selections to (selected_features, cat_selections).""" + selected_features = [] + for g, nchoosek in enumerate(self.groups.nchooseks): + for local_idx in node.partial_by_group[g]: + selected_features.append(nchoosek.features[local_idx]) + selected_features_tuple = tuple(sorted(selected_features)) + + cat_selections: dict[int, float] = {} + n_nchoosek = len(self.groups.nchooseks) + for i, cat_group in enumerate(self.groups.categoricals): + g = n_nchoosek + i + partial = node.partial_by_group[g] + if partial: + cat_selections[cat_group.dim] = cat_group.values[partial[0]] + + return selected_features_tuple, cat_selections + + def cache_stats(self) -> dict[str, int]: + 
"""Return cache statistics.""" + return { + "hits": self.cache_hits, + "misses": self.cache_misses, + "size": len(self.value_cache), + } + + # ─── Tree selection with Thompson Sampling ────────────────────── + + def _select_and_expand(self) -> tuple[TSNode, list[TSNode]]: + """Select path through tree using TS and expand one new node.""" + node = self.root + path = [node] + + while not node.is_terminal(self.groups): + legal = self._legal_actions(node) + limit = self._child_limit(node) + unexpanded = [a for a in legal if a not in node.children] + can_expand = len(node.children) < limit + + if can_expand and unexpanded: + # Expand one new child + action = self.rng.choice(unexpanded) + child = self._apply_action(node, action) + node.children[action] = child + path.append(child) + return child, path + + # Thompson Sampling selection among existing children + if node.children: + best_action = None + best_score = float("-inf") + for action, child in node.children.items(): + score = self._ts_sample_score(child) + if score > best_score: + best_score = score + best_action = action + + node = node.children[best_action] + path.append(node) + else: + break + + return node, path + + # ─── Rollout action selection ─────────────────────────────────── + + def _ts_sample_rollout_action( + self, group_idx: int, cardinality: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using per-action TS posteriors.""" + best_action = legal_actions[0] + best_score = float("-inf") + + for action in legal_actions: + if self.rollout_mode == "ts_group_action": + key = (group_idx, action) + else: # ts_group_card_action + key = (group_idx, cardinality, action) + + stats = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + score = self._ts_sample_action_score(stats) + if score > best_score: + best_score = score + best_action = action + + return best_action + + @staticmethod + def _softmax_probs(logits: list[float]) -> list[float]: + """Pure-math softmax (no torch 
dependency).""" + max_logit = max(logits) + exps = [math.exp(v - max_logit) for v in logits] + total = sum(exps) + return [e / total for e in exps] + + # ─── Softmax fallback methods (copied from production, no torch) ─ + + def _normalize_reward(self, reward: float) -> float: + """Normalize reward to [0, 1] using running min-max.""" + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return 0.5 + return (reward - self.reward_min) / reward_range + + def _score_rollout_actions( + self, group_idx: int, legal_actions: list[int] + ) -> dict[int, float]: + """Score legal rollout actions using learned statistics.""" + scores: dict[int, float] = {} + for action in legal_actions: + key = (group_idx, action) + stats = self.rollout_stats.get(key) + if stats is not None and stats[0] > 0: + visits, total_reward = stats + mean_reward = total_reward / visits + novelty = self.rollout_novelty_weight / math.sqrt(visits + 1) + scores[action] = mean_reward + novelty + else: + scores[action] = self.rollout_novelty_weight + return scores + + def _sample_softmax_rollout_action( + self, group_idx: int, legal_actions: list[int] + ) -> int: + """Sample rollout action using softmax policy (pure-math, no torch).""" + scores = self._score_rollout_actions(group_idx, legal_actions) + n = len(legal_actions) + + logits = [scores[a] / self.rollout_tau for a in legal_actions] + policy_probs = self._softmax_probs(logits) + + eps = self.rollout_epsilon + probs = [(1.0 - eps) * p + eps / n for p in policy_probs] + + return self.rng.choices(legal_actions, weights=probs, k=1)[0] + + def _update_rollout_stats( + self, trajectory: list[tuple[int, int, int]], reward: float + ) -> None: + """Update softmax rollout policy statistics from a completed trajectory.""" + for group_idx, _cardinality, action in trajectory: + key = (group_idx, action) + v, tot = self.rollout_stats.get(key, (0, 0.0)) + self.rollout_stats[key] = (v + 1, tot + reward) + + def _update_cardinality_stats( + self, 
reward: float, selected_features: tuple[int, ...] + ) -> None: + """Update per-(group, cardinality) stats from a completed rollout.""" + selected_set = set(selected_features) + for g, nchoosek in enumerate(self.groups.nchooseks): + cardinality = sum(1 for f in nchoosek.features if f in selected_set) + key = (g, cardinality) + v, tot = self.cardinality_stats.get(key, (0, 0.0)) + self.cardinality_stats[key] = (v + 1, tot + reward) + self.group_rollout_counts[g] += 1 + + def _compute_adaptive_p_stop( + self, group_idx: int, current_cardinality: int + ) -> float: + """Compute adaptive stop probability for softmax rollout mode.""" + if not self.adaptive_p_stop: + return self.p_stop_rollout + + nchoosek = self.groups.nchooseks[group_idx] + max_count = nchoosek.max_count + + stop_key = (group_idx, current_cardinality) + stop_stats = self.cardinality_stats.get(stop_key) + if stop_stats is None or stop_stats[0] == 0: + return self.p_stop_rollout + + e_stop = stop_stats[1] / stop_stats[0] + + e_continue = float("-inf") + has_continue_data = False + for m in range(current_cardinality + 1, max_count + 1): + cont_key = (group_idx, m) + cont_stats = self.cardinality_stats.get(cont_key) + if cont_stats is not None and cont_stats[0] > 0: + mean_r = cont_stats[1] / cont_stats[0] + if mean_r > e_continue: + e_continue = mean_r + has_continue_data = True + + if not has_continue_data: + return self.p_stop_rollout + + reward_range = self.reward_max - self.reward_min + if reward_range <= 0: + return self.p_stop_rollout + + tau = self.p_stop_temperature + logit = (e_stop - e_continue) / (tau * reward_range) + logit = max(-10.0, min(10.0, logit)) + p_learned = 1.0 / (1.0 + math.exp(-logit)) + + group_visits = self.group_rollout_counts[group_idx] + alpha = ( + min(1.0, group_visits / self.p_stop_warmup) + if self.p_stop_warmup > 0 + else 1.0 + ) + return (1.0 - alpha) * self.p_stop_rollout + alpha * p_learned + + # ─── Rollout dispatch ─────────────────────────────────────────── + + def 
_rollout( + self, node: TSNode + ) -> tuple[tuple[int, ...], dict[int, float], list[tuple[int, int, int]]]: + """Rollout to terminal state with mode-dependent action selection. + + Returns: + Tuple of (selected_features, cat_selections, trajectory) where + trajectory is a list of (group_idx, cardinality, action) triples. + """ + curr = TSNode( + partial_by_group=tuple(node.partial_by_group), + stopped_by_group=tuple(node.stopped_by_group), + group_idx=node.group_idx, + ) + trajectory: list[tuple[int, int, int]] = [] + + while not curr.is_terminal(self.groups): + legal = self._legal_actions(curr) + if not legal: + curr = TSNode( + partial_by_group=curr.partial_by_group, + stopped_by_group=curr.stopped_by_group, + group_idx=curr.group_idx + 1, + ) + continue + + g = curr.group_idx + cardinality = len(curr.partial_by_group[g]) + + if self.rollout_mode == "uniform": + action = self.rng.choice(legal) + + elif self.rollout_mode in ("ts_group_action", "ts_group_card_action"): + action = self._ts_sample_rollout_action(g, cardinality, legal) + + elif self.rollout_mode == "softmax": + # Softmax fallback: adaptive p_stop for NChooseK STOP decisions + is_nchoosek = g < len(self.groups.nchooseks) + if is_nchoosek and STOP in legal: + p_stop = self._compute_adaptive_p_stop(g, cardinality) + if self.rng.random() < p_stop: + trajectory.append((g, cardinality, STOP)) + curr = self._apply_action(curr, STOP) + continue + + action = self._sample_softmax_rollout_action(g, legal) + + else: + raise ValueError(f"Unknown rollout_mode: {self.rollout_mode}") + + trajectory.append((g, cardinality, action)) + curr = self._apply_action(curr, action) + + selected_features, cat_selections = self._get_selection(curr) + return selected_features, cat_selections, trajectory + + # ─── Backpropagation ──────────────────────────────────────────── + + def _backpropagate(self, path: list[TSNode], reward: float, is_novel: bool) -> None: + """Backpropagate reward through path. 
+ + Novel: update n_obs, sum_rewards, sum_sq_rewards, n_visits. + Cache hit (no_update): only increment n_visits. + Cache hit (variance_inflation): increment n_visits, decay n_obs + to gradually widen posterior for exhausted subtrees. + Cache hit (pessimistic): increment n_visits and inject a pessimistic + pseudo-observation (global_mean - global_std) to shift the + posterior mean downward — asymmetric pressure away from + exhausted subtrees. + Cache hit (combined): apply variance inflation first (preserves + posterior width for interaction discovery), then add a pessimistic + pseudo-observation (asymmetric downward shift). Net effect: one + real observation is effectively replaced by a pessimistic one, + shifting the mean down while largely preserving posterior width. + """ + if is_novel: + for n in path: + n.n_obs += 1 + n.sum_rewards += reward + n.sum_sq_rewards += reward * reward + n.n_visits += 1 + else: + # Cache hit + if self.cache_hit_mode in ("pessimistic", "combined"): + pess = self._pessimistic_value() + + for n in path: + n.n_visits += 1 + + if self.cache_hit_mode == "variance_inflation": + # Decay n_obs to widen posterior, preserving mean + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + elif self.cache_hit_mode == "pessimistic": + # Inject pessimistic pseudo-observation to shift mean down + n.n_obs += 1 + n.sum_rewards += pess + n.sum_sq_rewards += pess * pess + elif self.cache_hit_mode == "combined": + # Step 1: variance inflation — decay n_obs to widen posterior + if n.n_obs > 1: + old_n = n.n_obs + new_n = max(1, int(old_n * self.variance_decay)) + if new_n < old_n: + mean = n.sum_rewards / old_n + n.sum_rewards = mean * new_n + n.sum_sq_rewards *= new_n / old_n + n.n_obs = new_n + # Step 2: mild pessimistic — add one pessimistic observation + n.n_obs += 1 + n.sum_rewards += 
pess + n.sum_sq_rewards += pess * pess + + def _update_rollout_ts_stats( + self, trajectory: list[tuple[int, int, int]], reward: float + ) -> None: + """Update per-action TS stats from a completed rollout trajectory.""" + for group_idx, cardinality, action in trajectory: + if self.rollout_mode == "ts_group_action": + key = (group_idx, action) + elif self.rollout_mode == "ts_group_card_action": + key = (group_idx, cardinality, action) + else: + continue + + old = self.rollout_ts_stats.get(key, TSActionStats(0, 0.0, 0.0)) + self.rollout_ts_stats[key] = TSActionStats( + n_obs=old.n_obs + 1, + sum_rewards=old.sum_rewards + reward, + sum_sq_rewards=old.sum_sq_rewards + reward * reward, + ) + + # ─── Main loop ────────────────────────────────────────────────── + + def run(self, n_iterations: int) -> tuple[tuple[int, ...], dict[int, float], float]: + """Run MCTS-TS for specified number of iterations. + + Args: + n_iterations: Number of MCTS iterations to run + + Returns: + Tuple of (selected_features, cat_selections, best_value) + """ + for _ in range(n_iterations): + leaf, path = self._select_and_expand() + + if leaf.is_terminal(self.groups): + selected_features, cat_selections = self._get_selection(leaf) + trajectory: list[tuple[int, int, int]] = [] + else: + # Rollout retry: re-roll on cache hits to discover novel selections + selected_features, cat_selections, trajectory = self._rollout(leaf) + for _attempt in range(self.max_rollout_retries): + key = self._make_cache_key(selected_features, cat_selections) + if key not in self.value_cache: + break + selected_features, cat_selections, trajectory = self._rollout(leaf) + + key = self._make_cache_key(selected_features, cat_selections) + is_novel = key not in self.value_cache + reward = self._cached_reward(selected_features, cat_selections) + + # Update reward range (used by softmax fallback) + if reward < self.reward_min: + self.reward_min = reward + if reward > self.reward_max: + self.reward_max = reward + + if reward > 
self.best_value: + self.best_value = reward + self.best_selection = (selected_features, cat_selections) + + # Update global stats for TS prior (mean and adaptive variance) + if is_novel: + self._novel_reward_sum += reward + self._novel_reward_sq_sum += reward * reward + self._novel_reward_count += 1 + + # Backpropagate (raw reward, no normalization for TS tree) + self._backpropagate(path, reward, is_novel) + + # Update rollout statistics + if self.rollout_mode in ("ts_group_action", "ts_group_card_action"): + self._update_rollout_ts_stats(trajectory, reward) + elif self.rollout_mode == "softmax": + self._update_rollout_stats(trajectory, reward) + if self.adaptive_p_stop: + self._update_cardinality_stats(reward, selected_features) + + if self.best_selection is None: + return (), {}, self.best_value + return self.best_selection[0], self.best_selection[1], self.best_value diff --git a/mcts-report/optimum_rate_heatmap.png b/mcts-report/optimum_rate_heatmap.png new file mode 100644 index 000000000..be0fcc741 Binary files /dev/null and b/mcts-report/optimum_rate_heatmap.png differ diff --git a/mcts-report/optimum_rate_heatmap_nig.png b/mcts-report/optimum_rate_heatmap_nig.png new file mode 100644 index 000000000..3bf6baa11 Binary files /dev/null and b/mcts-report/optimum_rate_heatmap_nig.png differ diff --git a/mcts-report/optimum_rate_heatmap_nig_adaptive.png b/mcts-report/optimum_rate_heatmap_nig_adaptive.png new file mode 100644 index 000000000..cd1477932 Binary files /dev/null and b/mcts-report/optimum_rate_heatmap_nig_adaptive.png differ diff --git a/mcts-report/optimum_rate_heatmap_nig_adaptive_n0.png b/mcts-report/optimum_rate_heatmap_nig_adaptive_n0.png new file mode 100644 index 000000000..c225e2ff1 Binary files /dev/null and b/mcts-report/optimum_rate_heatmap_nig_adaptive_n0.png differ diff --git a/mcts-report/results.json b/mcts-report/results.json new file mode 100644 index 000000000..38b8fca17 --- /dev/null +++ b/mcts-report/results.json @@ -0,0 +1,980 @@ +{ 
+ "multigroup_interaction": { + "Random": { + "mean_best": 62.9, + "std_best": 10.338762014864255, + "median_best": 62.5, + "optimum_rate": 0.0, + "mean_unique_evals": 588.0 + }, + "MCTS (default)": { + "mean_best": 94.93333333333334, + "std_best": 18.90138148978064, + "median_best": 106.0, + "optimum_rate": 0.0, + "mean_unique_evals": 204.7 + }, + "MCTS (no RAVE)": { + "mean_best": 105.06666666666666, + "std_best": 25.77587157703024, + "median_best": 106.0, + "optimum_rate": 0.2, + "mean_unique_evals": 391.56666666666666 + }, + "MCTS (no PW)": { + "mean_best": 90.76666666666667, + "std_best": 19.186250864153273, + "median_best": 103.0, + "optimum_rate": 0.0, + "mean_unique_evals": 196.5 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 99.4, + "std_best": 29.823480682173905, + "median_best": 97.5, + "optimum_rate": 0.2, + "mean_unique_evals": 364.6666666666667 + }, + "MCTS (low explore)": { + "mean_best": 94.2, + "std_best": 17.24992753607968, + "median_best": 106.0, + "optimum_rate": 0.0, + "mean_unique_evals": 199.66666666666666 + }, + "MCTS (high explore)": { + "mean_best": 97.26666666666667, + "std_best": 15.73093202013861, + "median_best": 106.0, + "optimum_rate": 0.0, + "mean_unique_evals": 240.16666666666666 + }, + "MCTS (heavy RAVE)": { + "mean_best": 81.36666666666666, + "std_best": 18.734074718425664, + "median_best": 82.5, + "optimum_rate": 0.0, + "mean_unique_evals": 84.33333333333333 + }, + "MCTS (tight PW)": { + "mean_best": 91.36666666666666, + "std_best": 17.415478428366214, + "median_best": 103.0, + "optimum_rate": 0.0, + "mean_unique_evals": 195.86666666666667 + }, + "MCTS (loose PW)": { + "mean_best": 90.76666666666667, + "std_best": 19.186250864153273, + "median_best": 103.0, + "optimum_rate": 0.0, + "mean_unique_evals": 196.5 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 96.23333333333333, + "std_best": 14.312426613106373, + "median_best": 104.5, + "optimum_rate": 0.0, + "mean_unique_evals": 237.53333333333333 + }, + "MCTS (p_stop=0.6)": { + 
"mean_best": 84.46666666666667, + "std_best": 20.779370111296018, + "median_best": 83.0, + "optimum_rate": 0.0, + "mean_unique_evals": 164.73333333333332 + }, + "MCTS (adaptive p)": { + "mean_best": 98.0, + "std_best": 14.484474446799926, + "median_best": 106.0, + "optimum_rate": 0.0, + "mean_unique_evals": 219.43333333333334 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 103.83333333333333, + "std_best": 29.728867377610534, + "median_best": 106.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 380.3333333333333 + }, + "MCTS (norm)": { + "mean_best": 101.83333333333333, + "std_best": 10.462897410479675, + "median_best": 106.0, + "optimum_rate": 0.0, + "mean_unique_evals": 321.46666666666664 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 108.86666666666666, + "std_best": 25.145885459763704, + "median_best": 106.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 455.26666666666665 + }, + "MCTS (+rpol)": { + "mean_best": 111.43333333333334, + "std_best": 23.634273606118906, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 516.2666666666667 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 114.13333333333334, + "std_best": 24.039874283272685, + "median_best": 109.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 510.6333333333333 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 109.93333333333334, + "std_best": 22.71700293221406, + "median_best": 109.0, + "optimum_rate": 0.2, + "mean_unique_evals": 510.06666666666666 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 112.5, + "std_best": 23.097979709634057, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 513.8333333333334 + }, + "MCTS (+crave k=100)": { + "mean_best": 105.33333333333333, + "std_best": 21.13659280857621, + "median_best": 106.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 445.3 + }, + "MCTS (+crave k=300)": { + "mean_best": 103.53333333333333, + "std_best": 
17.250571649143172, + "median_best": 107.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 370.1 + }, + "MCTS (+crave k=500)": { + "mean_best": 100.2, + "std_best": 19.506238318377363, + "median_best": 106.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 309.6666666666667 + } + }, + "needle_in_haystack": { + "Random": { + "mean_best": 39.666666666666664, + "std_best": 20.491190519071576, + "median_best": 32.5, + "optimum_rate": 0.1, + "mean_unique_evals": 216.2 + }, + "MCTS (default)": { + "mean_best": 77.0, + "std_best": 32.57299494980466, + "median_best": 100.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 58.2 + }, + "MCTS (no RAVE)": { + "mean_best": 97.66666666666667, + "std_best": 12.565384549980509, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 153.8 + }, + "MCTS (no PW)": { + "mean_best": 76.66666666666667, + "std_best": 32.99831645537222, + "median_best": 100.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 56.03333333333333 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 97.66666666666667, + "std_best": 12.565384549980509, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 147.23333333333332 + }, + "MCTS (low explore)": { + "mean_best": 74.33333333333333, + "std_best": 33.73260868786891, + "median_best": 100.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 54.766666666666666 + }, + "MCTS (high explore)": { + "mean_best": 88.33333333333333, + "std_best": 26.087459737497543, + "median_best": 100.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 64.26666666666667 + }, + "MCTS (heavy RAVE)": { + "mean_best": 35.166666666666664, + "std_best": 17.957511582126916, + "median_best": 30.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 34.4 + }, + "MCTS (tight PW)": { + "mean_best": 39.833333333333336, + "std_best": 27.79338450463028, + "median_best": 30.0, + "optimum_rate": 
0.16666666666666666, + "mean_unique_evals": 46.63333333333333 + }, + "MCTS (loose PW)": { + "mean_best": 76.66666666666667, + "std_best": 32.99831645537222, + "median_best": 100.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 56.03333333333333 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 67.0, + "std_best": 35.393031329156685, + "median_best": 100.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 52.03333333333333 + }, + "MCTS (p_stop=0.6)": { + "mean_best": 91.16666666666667, + "std_best": 22.57149136016985, + "median_best": 100.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 63.1 + }, + "MCTS (adaptive p)": { + "mean_best": 84.0, + "std_best": 29.0516780926679, + "median_best": 100.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 60.666666666666664 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 160.83333333333334 + }, + "MCTS (norm)": { + "mean_best": 97.66666666666667, + "std_best": 12.565384549980509, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 98.36666666666666 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 246.56666666666666 + }, + "MCTS (+rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 283.03333333333336 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 98.0, + "std_best": 10.770329614269007, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 288.93333333333334 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 282.7 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 98.0, + "std_best": 10.770329614269007, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, 
+ "mean_unique_evals": 284.5 + }, + "MCTS (+crave k=100)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 218.9 + }, + "MCTS (+crave k=300)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 138.66666666666666 + }, + "MCTS (+crave k=500)": { + "mean_best": 97.66666666666667, + "std_best": 12.565384549980509, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 105.63333333333334 + } + }, + "mixed_nchoosek_categorical": { + "Random": { + "mean_best": 79.16666666666667, + "std_best": 14.583285714207967, + "median_best": 80.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 471.96666666666664 + }, + "MCTS (default)": { + "mean_best": 84.5, + "std_best": 16.650825805346713, + "median_best": 90.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 110.96666666666667 + }, + "MCTS (no RAVE)": { + "mean_best": 113.6, + "std_best": 34.721367100581354, + "median_best": 90.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 284.06666666666666 + }, + "MCTS (no PW)": { + "mean_best": 82.2, + "std_best": 11.088733020503291, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 114.33333333333333 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 108.53333333333333, + "std_best": 37.107711088427735, + "median_best": 90.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 278.93333333333334 + }, + "MCTS (low explore)": { + "mean_best": 87.73333333333333, + "std_best": 19.573678471071524, + "median_best": 90.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 106.4 + }, + "MCTS (high explore)": { + "mean_best": 88.1, + "std_best": 14.432255540974875, + "median_best": 90.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 141.3 + }, + "MCTS (heavy RAVE)": { + "mean_best": 74.73333333333333, + "std_best": 13.068877364010863, + 
"median_best": 80.0, + "optimum_rate": 0.0, + "mean_unique_evals": 46.166666666666664 + }, + "MCTS (tight PW)": { + "mean_best": 85.16666666666667, + "std_best": 9.19450318880193, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 137.13333333333333 + }, + "MCTS (loose PW)": { + "mean_best": 82.2, + "std_best": 11.088733020503291, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 114.33333333333333 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 89.53333333333333, + "std_best": 18.445836627512694, + "median_best": 90.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 130.3 + }, + "MCTS (p_stop=0.6)": { + "mean_best": 83.9, + "std_best": 9.843271813782245, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 101.0 + }, + "MCTS (adaptive p)": { + "mean_best": 85.83333333333333, + "std_best": 9.316949906249127, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 111.53333333333333 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 110.36666666666666, + "std_best": 35.69078623709798, + "median_best": 90.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 280.3 + }, + "MCTS (norm)": { + "mean_best": 86.66666666666667, + "std_best": 8.498365855987972, + "median_best": 90.0, + "optimum_rate": 0.0, + "mean_unique_evals": 174.43333333333334 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 126.96666666666667, + "std_best": 30.494243536918393, + "median_best": 150.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 356.76666666666665 + }, + "MCTS (+rpol)": { + "mean_best": 135.86666666666667, + "std_best": 25.627762723699128, + "median_best": 150.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 441.6333333333333 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 126.0, + "std_best": 29.393876913398138, + "median_best": 150.0, + "optimum_rate": 0.6, + "mean_unique_evals": 403.96666666666664 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 140.0, + 
"std_best": 22.360679774997898, + "median_best": 150.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 444.2 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 136.0, + "std_best": 25.37715508089904, + "median_best": 150.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 439.9 + }, + "MCTS (+crave k=100)": { + "mean_best": 131.86666666666667, + "std_best": 27.70768044343582, + "median_best": 150.0, + "optimum_rate": 0.7, + "mean_unique_evals": 373.2 + }, + "MCTS (+crave k=300)": { + "mean_best": 137.66666666666666, + "std_best": 24.722908854384883, + "median_best": 150.0, + "optimum_rate": 0.8, + "mean_unique_evals": 304.03333333333336 + }, + "MCTS (+crave k=500)": { + "mean_best": 132.0, + "std_best": 27.49545416973504, + "median_best": 150.0, + "optimum_rate": 0.7, + "mean_unique_evals": 264.03333333333336 + } + }, + "large_sparse": { + "Random": { + "mean_best": 36.06666666666667, + "std_best": 6.33473668662628, + "median_best": 38.0, + "optimum_rate": 0.0, + "mean_unique_evals": 763.6333333333333 + }, + "MCTS (default)": { + "mean_best": 40.0, + "std_best": 8.390470785361213, + "median_best": 44.0, + "optimum_rate": 0.0, + "mean_unique_evals": 303.4 + }, + "MCTS (no RAVE)": { + "mean_best": 83.8, + "std_best": 64.53805079176159, + "median_best": 55.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 514.9 + }, + "MCTS (no PW)": { + "mean_best": 40.2, + "std_best": 6.289674077406556, + "median_best": 44.0, + "optimum_rate": 0.0, + "mean_unique_evals": 308.5 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 61.46666666666667, + "std_best": 38.062871964976864, + "median_best": 56.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 513.2666666666667 + }, + "MCTS (low explore)": { + "mean_best": 38.0, + "std_best": 8.049844718999243, + "median_best": 32.0, + "optimum_rate": 0.0, + "mean_unique_evals": 284.8333333333333 + }, + "MCTS (high explore)": { + "mean_best": 47.2, + "std_best": 28.88875213642846, + 
"median_best": 44.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 390.8333333333333 + }, + "MCTS (heavy RAVE)": { + "mean_best": 31.266666666666666, + "std_best": 9.387698806890265, + "median_best": 32.0, + "optimum_rate": 0.0, + "mean_unique_evals": 89.9 + }, + "MCTS (tight PW)": { + "mean_best": 52.6, + "std_best": 40.91014544095389, + "median_best": 44.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 223.1 + }, + "MCTS (loose PW)": { + "mean_best": 40.2, + "std_best": 6.289674077406556, + "median_best": 44.0, + "optimum_rate": 0.0, + "mean_unique_evals": 308.5 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 33.6, + "std_best": 10.499523798725349, + "median_best": 32.0, + "optimum_rate": 0.0, + "mean_unique_evals": 183.06666666666666 + }, + "MCTS (p_stop=0.6)": { + "mean_best": 55.4, + "std_best": 28.05779749018087, + "median_best": 48.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 447.9 + }, + "MCTS (adaptive p)": { + "mean_best": 54.6, + "std_best": 27.843132007732173, + "median_best": 51.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 420.9 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 93.0, + "std_best": 64.90762667052309, + "median_best": 62.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 549.8666666666667 + }, + "MCTS (norm)": { + "mean_best": 40.53333333333333, + "std_best": 7.931932649459119, + "median_best": 43.0, + "optimum_rate": 0.0, + "mean_unique_evals": 603.3 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 112.13333333333334, + "std_best": 71.97209335723272, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 689.3 + }, + "MCTS (+rpol)": { + "mean_best": 129.8, + "std_best": 70.24784694209497, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 750.2666666666667 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 90.4, + "std_best": 60.65509047062744, + "median_best": 62.0, + "optimum_rate": 0.23333333333333334, + 
"mean_unique_evals": 749.3 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 128.13333333333333, + "std_best": 72.02024406759226, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 748.7 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 118.66666666666667, + "std_best": 71.19238411203516, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 748.7333333333333 + }, + "MCTS (+crave k=100)": { + "mean_best": 128.8, + "std_best": 71.3425539212047, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 696.3333333333334 + }, + "MCTS (+crave k=300)": { + "mean_best": 119.13333333333334, + "std_best": 70.79488368205399, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 650.7666666666667 + }, + "MCTS (+crave k=500)": { + "mean_best": 118.53333333333333, + "std_best": 71.42909460125864, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 591.3666666666667 + } + }, + "graduated_landscape": { + "Random": { + "mean_best": 60.56666666666667, + "std_best": 3.2932591084753047, + "median_best": 61.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.63333333333334 + }, + "MCTS (default)": { + "mean_best": 64.1, + "std_best": 1.3747727084867518, + "median_best": 64.0, + "optimum_rate": 0.4, + "mean_unique_evals": 64.8 + }, + "MCTS (no RAVE)": { + "mean_best": 64.9, + "std_best": 0.30000000000000004, + "median_best": 65.0, + "optimum_rate": 0.9, + "mean_unique_evals": 161.6 + }, + "MCTS (no PW)": { + "mean_best": 63.666666666666664, + "std_best": 1.6599866130651644, + "median_best": 64.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 58.36666666666667 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 64.8, + "std_best": 0.39999999999999997, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 157.46666666666667 + }, + "MCTS (low explore)": { + "mean_best": 63.86666666666667, + "std_best": 
1.2840906856172152, + "median_best": 64.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 63.8 + }, + "MCTS (high explore)": { + "mean_best": 64.3, + "std_best": 0.7810249675906655, + "median_best": 64.0, + "optimum_rate": 0.4, + "mean_unique_evals": 74.8 + }, + "MCTS (heavy RAVE)": { + "mean_best": 55.36666666666667, + "std_best": 5.357134391527702, + "median_best": 55.0, + "optimum_rate": 0.0, + "mean_unique_evals": 20.933333333333334 + }, + "MCTS (tight PW)": { + "mean_best": 61.63333333333333, + "std_best": 3.4783457115256518, + "median_best": 64.0, + "optimum_rate": 0.1, + "mean_unique_evals": 46.63333333333333 + }, + "MCTS (loose PW)": { + "mean_best": 63.666666666666664, + "std_best": 1.6599866130651644, + "median_best": 64.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 58.36666666666667 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 64.46666666666667, + "std_best": 0.4988876515698589, + "median_best": 64.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 89.33333333333333 + }, + "MCTS (p_stop=0.6)": { + "mean_best": 62.43333333333333, + "std_best": 2.679344861881144, + "median_best": 64.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 49.233333333333334 + }, + "MCTS (adaptive p)": { + "mean_best": 64.46666666666667, + "std_best": 0.8055363982396382, + "median_best": 65.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 74.96666666666667 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 64.63333333333334, + "std_best": 0.9480975102218596, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 168.4 + }, + "MCTS (norm)": { + "mean_best": 62.43333333333333, + "std_best": 3.190437099973746, + "median_best": 64.0, + "optimum_rate": 0.1, + "mean_unique_evals": 48.9 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 64.7, + "std_best": 0.7810249675906655, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 152.4 + }, + "MCTS (+rpol)": { + 
"mean_best": 64.53333333333333, + "std_best": 1.3597385369580757, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 156.96666666666667 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 64.93333333333334, + "std_best": 0.24944382578492946, + "median_best": 65.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 151.43333333333334 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 64.73333333333333, + "std_best": 0.4422166387140533, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 153.76666666666668 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 64.73333333333333, + "std_best": 0.6289320754704402, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 162.56666666666666 + }, + "MCTS (+crave k=100)": { + "mean_best": 64.7, + "std_best": 0.45825756949558394, + "median_best": 65.0, + "optimum_rate": 0.7, + "mean_unique_evals": 108.66666666666667 + }, + "MCTS (+crave k=300)": { + "mean_best": 63.86666666666667, + "std_best": 2.15612821717283, + "median_best": 64.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 71.96666666666667 + }, + "MCTS (+crave k=500)": { + "mean_best": 63.0, + "std_best": 2.932575659723036, + "median_best": 64.0, + "optimum_rate": 0.3, + "mean_unique_evals": 53.03333333333333 + } + }, + "simple_additive": { + "Random": { + "mean_best": 57.666666666666664, + "std_best": 3.259175083088085, + "median_best": 58.5, + "optimum_rate": 0.0, + "mean_unique_evals": 115.43333333333334 + }, + "MCTS (default)": { + "mean_best": 61.833333333333336, + "std_best": 3.742399705477163, + "median_best": 63.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 69.66666666666667 + }, + "MCTS (no RAVE)": { + "mean_best": 63.96666666666667, + "std_best": 1.9745604292826517, + "median_best": 65.0, + "optimum_rate": 0.7, + "mean_unique_evals": 167.06666666666666 + }, + "MCTS (no PW)": { + "mean_best": 62.03333333333333, + "std_best": 3.1778749013906906, + 
"median_best": 63.0, + "optimum_rate": 0.4, + "mean_unique_evals": 66.66666666666667 + }, + "MCTS (no RAVE, no PW)": { + "mean_best": 63.86666666666667, + "std_best": 1.892675942210452, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 162.16666666666666 + }, + "MCTS (low explore)": { + "mean_best": 61.5, + "std_best": 3.930648801406709, + "median_best": 63.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 66.06666666666666 + }, + "MCTS (high explore)": { + "mean_best": 63.266666666666666, + "std_best": 1.8607047649270483, + "median_best": 63.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 79.73333333333333 + }, + "MCTS (heavy RAVE)": { + "mean_best": 52.4, + "std_best": 5.986651818838307, + "median_best": 52.5, + "optimum_rate": 0.0, + "mean_unique_evals": 27.166666666666668 + }, + "MCTS (tight PW)": { + "mean_best": 57.5, + "std_best": 4.031128874149275, + "median_best": 55.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 41.4 + }, + "MCTS (loose PW)": { + "mean_best": 62.03333333333333, + "std_best": 3.1778749013906906, + "median_best": 63.0, + "optimum_rate": 0.4, + "mean_unique_evals": 66.66666666666667 + }, + "MCTS (p_stop=0.1)": { + "mean_best": 64.3, + "std_best": 1.5524174696260027, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 99.53333333333333 + }, + "MCTS (p_stop=0.6)": { + "mean_best": 59.3, + "std_best": 4.824589792579952, + "median_best": 60.0, + "optimum_rate": 0.2, + "mean_unique_evals": 46.266666666666666 + }, + "MCTS (adaptive p)": { + "mean_best": 63.266666666666666, + "std_best": 2.0645150089602695, + "median_best": 64.0, + "optimum_rate": 0.5, + "mean_unique_evals": 79.86666666666666 + }, + "MCTS (no RAVE+adpt)": { + "mean_best": 64.1, + "std_best": 1.6802777548171413, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 181.33333333333334 + }, + "MCTS (norm)": { + "mean_best": 61.733333333333334, + 
"std_best": 3.453822359196965, + "median_best": 63.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 71.0 + }, + "MCTS (no RAVE+adpt+norm)": { + "mean_best": 64.1, + "std_best": 2.2412793965352322, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 183.8 + }, + "MCTS (+rpol)": { + "mean_best": 64.1, + "std_best": 2.150193789716018, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 186.8 + }, + "MCTS (+rpol \u03b5=0.1)": { + "mean_best": 64.43333333333334, + "std_best": 1.3585122581543223, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 174.73333333333332 + }, + "MCTS (+rpol \u03c4=0.5)": { + "mean_best": 64.06666666666666, + "std_best": 2.0965580258021848, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 187.13333333333333 + }, + "MCTS (+rpol \u03c4=2)": { + "mean_best": 64.26666666666667, + "std_best": 1.4126413400278057, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 187.0 + }, + "MCTS (+crave k=100)": { + "mean_best": 64.23333333333333, + "std_best": 1.498517786199268, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 143.03333333333333 + }, + "MCTS (+crave k=300)": { + "mean_best": 63.666666666666664, + "std_best": 2.399073895392877, + "median_best": 65.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 92.23333333333333 + }, + "MCTS (+crave k=500)": { + "mean_best": 62.8, + "std_best": 2.587147721590967, + "median_best": 63.0, + "optimum_rate": 0.4, + "mean_unique_evals": 68.7 + } + } +} diff --git a/mcts-report/results_nig.json b/mcts-report/results_nig.json new file mode 100644 index 000000000..b44aab028 --- /dev/null +++ b/mcts-report/results_nig.json @@ -0,0 +1,476 @@ +{ + "multigroup_interaction": { + "Random": { + "mean_best": 62.9, + "std_best": 10.338762014864255, + "median_best": 62.5, + "optimum_rate": 0.0, + "mean_unique_evals": 
588.0 + }, + "UCT (+rpol)": { + "mean_best": 111.43333333333334, + "std_best": 23.634273606118906, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 516.2666666666667 + }, + "TS + TS(g,a) + comb": { + "mean_best": 115.4, + "std_best": 26.801243752234086, + "median_best": 109.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 475.06666666666666 + }, + "NIG + uniform": { + "mean_best": 107.7, + "std_best": 33.670115334917796, + "median_best": 107.5, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 200.46666666666667 + }, + "NIG + TS(g,a)": { + "mean_best": 118.53333333333333, + "std_best": 19.49826203081244, + "median_best": 109.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 335.8 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 119.4, + "std_best": 22.915206014056835, + "median_best": 109.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 537.2 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 127.63333333333334, + "std_best": 24.67992886717644, + "median_best": 150.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 568.4333333333333 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 141.3, + "std_best": 17.57868026900768, + "median_best": 150.0, + "optimum_rate": 0.8, + "mean_unique_evals": 532.0666666666667 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 126.06666666666666, + "std_best": 21.101237457130857, + "median_best": 109.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 524.1 + }, + "NIG + uniform + pess + apv": { + "mean_best": 121.83333333333333, + "std_best": 27.40569689356982, + "median_best": 109.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 518.3 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 111.13333333333334, + "std_best": 21.400519204501144, + "median_best": 106.0, + "optimum_rate": 0.2, + "mean_unique_evals": 522.3666666666667 + } + }, + "needle_in_haystack": { + "Random": { + "mean_best": 
39.666666666666664, + "std_best": 20.491190519071576, + "median_best": 32.5, + "optimum_rate": 0.1, + "mean_unique_evals": 216.2 + }, + "UCT (+rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 283.03333333333336 + }, + "TS + TS(g,a) + comb": { + "mean_best": 94.0, + "std_best": 18.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 282.06666666666666 + }, + "NIG + uniform": { + "mean_best": 74.5, + "std_best": 33.52486639297265, + "median_best": 100.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 76.23333333333333 + }, + "NIG + TS(g,a)": { + "mean_best": 93.0, + "std_best": 21.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 105.53333333333333 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 94.0, + "std_best": 18.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 280.6666666666667 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 265.06666666666666 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 181.5 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 94.0, + "std_best": 18.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 280.1333333333333 + }, + "NIG + uniform + pess + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 258.73333333333335 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 96.0, + "std_best": 14.966629547095765, + "median_best": 100.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 286.0 + } + }, + "mixed_nchoosek_categorical": { + "Random": { + "mean_best": 79.16666666666667, + "std_best": 14.583285714207967, + "median_best": 80.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 
471.96666666666664 + }, + "UCT (+rpol)": { + "mean_best": 135.86666666666667, + "std_best": 25.627762723699128, + "median_best": 150.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 441.6333333333333 + }, + "TS + TS(g,a) + comb": { + "mean_best": 112.63333333333334, + "std_best": 33.18080100834752, + "median_best": 90.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 359.6 + }, + "NIG + uniform": { + "mean_best": 117.53333333333333, + "std_best": 32.82045432687096, + "median_best": 120.0, + "optimum_rate": 0.5, + "mean_unique_evals": 140.93333333333334 + }, + "NIG + TS(g,a)": { + "mean_best": 144.0, + "std_best": 18.0, + "median_best": 150.0, + "optimum_rate": 0.9, + "mean_unique_evals": 375.1 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 144.0, + "std_best": 18.0, + "median_best": 150.0, + "optimum_rate": 0.9, + "mean_unique_evals": 388.76666666666665 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 150.0, + "std_best": 0.0, + "median_best": 150.0, + "optimum_rate": 1.0, + "mean_unique_evals": 389.4 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 150.0, + "std_best": 0.0, + "median_best": 150.0, + "optimum_rate": 1.0, + "mean_unique_evals": 384.73333333333335 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 142.0, + "std_best": 20.396078054371138, + "median_best": 150.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 381.76666666666665 + }, + "NIG + uniform + pess + apv": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 405.3333333333333 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 140.0, + "std_best": 22.360679774997898, + "median_best": 150.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 381.3666666666667 + } + }, + "large_sparse": { + "Random": { + "mean_best": 36.06666666666667, + "std_best": 6.33473668662628, + "median_best": 38.0, + "optimum_rate": 0.0, + "mean_unique_evals": 
763.6333333333333 + }, + "UCT (+rpol)": { + "mean_best": 129.8, + "std_best": 70.24784694209497, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 750.2666666666667 + }, + "TS + TS(g,a) + comb": { + "mean_best": 84.4, + "std_best": 57.98482560118637, + "median_best": 56.0, + "optimum_rate": 0.2, + "mean_unique_evals": 671.1666666666666 + }, + "NIG + uniform": { + "mean_best": 65.8, + "std_best": 45.37356058322952, + "median_best": 55.0, + "optimum_rate": 0.1, + "mean_unique_evals": 300.76666666666665 + }, + "NIG + TS(g,a)": { + "mean_best": 124.0, + "std_best": 71.12805353726475, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 607.6333333333333 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 112.86666666666666, + "std_best": 71.27726768675191, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 742.4333333333333 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 114.46666666666667, + "std_best": 69.91362925464216, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 764.1666666666666 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 123.26666666666667, + "std_best": 71.87115013844026, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 749.7333333333333 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 122.66666666666667, + "std_best": 72.46945716798369, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 735.7 + }, + "NIG + uniform + pess + apv": { + "mean_best": 118.26666666666667, + "std_best": 71.54249708312459, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 709.0 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 90.33333333333333, + "std_best": 60.62141169220291, + "median_best": 59.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 729.5 + } + }, + "graduated_landscape": { + "Random": { + "mean_best": 60.56666666666667, + "std_best": 
3.2932591084753047, + "median_best": 61.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.63333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.53333333333333, + "std_best": 1.3597385369580757, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 156.96666666666667 + }, + "TS + TS(g,a) + comb": { + "mean_best": 64.96666666666667, + "std_best": 0.17950549357115012, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 174.56666666666666 + }, + "NIG + uniform": { + "mean_best": 63.56666666666667, + "std_best": 2.0604745947367453, + "median_best": 64.0, + "optimum_rate": 0.3, + "mean_unique_evals": 61.43333333333333 + }, + "NIG + TS(g,a)": { + "mean_best": 64.4, + "std_best": 0.6633249580710798, + "median_best": 64.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 86.4 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 64.96666666666667, + "std_best": 0.17950549357115012, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 179.66666666666666 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 169.86666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.76666666666667, + "std_best": 0.42295258468165065, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 122.73333333333333 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 64.9, + "std_best": 0.30000000000000004, + "median_best": 65.0, + "optimum_rate": 0.9, + "mean_unique_evals": 177.2 + }, + "NIG + uniform + pess + apv": { + "mean_best": 64.96666666666667, + "std_best": 0.17950549357115012, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 160.4 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 185.53333333333333 + } + }, + "simple_additive": 
{ + "Random": { + "mean_best": 57.666666666666664, + "std_best": 3.259175083088085, + "median_best": 58.5, + "optimum_rate": 0.0, + "mean_unique_evals": 115.43333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.1, + "std_best": 2.150193789716018, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 186.8 + }, + "TS + TS(g,a) + comb": { + "mean_best": 64.53333333333333, + "std_best": 1.1175369742826806, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 192.46666666666667 + }, + "NIG + uniform": { + "mean_best": 62.8, + "std_best": 2.5742312768410427, + "median_best": 63.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 82.03333333333333 + }, + "NIG + TS(g,a)": { + "mean_best": 64.43333333333334, + "std_best": 1.5205992970609392, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 102.26666666666667 + }, + "NIG + TS(g,a) + comb": { + "mean_best": 64.6, + "std_best": 0.9521904571390465, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 202.0 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 64.66666666666667, + "std_best": 0.7453559924999298, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 190.33333333333334 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.66666666666667, + "std_best": 0.7453559924999298, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 135.7 + }, + "NIG + TS(g,a) + pess": { + "mean_best": 64.93333333333334, + "std_best": 0.3590109871423003, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 198.76666666666668 + }, + "NIG + uniform + pess + apv": { + "mean_best": 64.8, + "std_best": 0.6, + "median_best": 65.0, + "optimum_rate": 0.9, + "mean_unique_evals": 186.26666666666668 + }, + "NIG + TS(g,a) + comb (a0=2)": { + "mean_best": 64.86666666666666, + "std_best": 0.49888765156985887, + 
"median_best": 65.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 194.7 + } + } +} diff --git a/mcts-report/results_nig_adaptive.json b/mcts-report/results_nig_adaptive.json new file mode 100644 index 000000000..a28d358fc --- /dev/null +++ b/mcts-report/results_nig_adaptive.json @@ -0,0 +1,392 @@ +{ + "multigroup_interaction": { + "Random": { + "mean_best": 62.9, + "std_best": 10.338762014864255, + "median_best": 62.5, + "optimum_rate": 0.0, + "mean_unique_evals": 588.0 + }, + "UCT (+rpol)": { + "mean_best": 111.43333333333334, + "std_best": 23.634273606118906, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 516.2666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 141.3, + "std_best": 17.57868026900768, + "median_best": 150.0, + "optimum_rate": 0.8, + "mean_unique_evals": 532.0666666666667 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 127.63333333333334, + "std_best": 24.67992886717644, + "median_best": 150.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 568.4333333333333 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 129.03333333333333, + "std_best": 24.442426684399038, + "median_best": 150.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 562.9666666666667 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 123.83333333333333, + "std_best": 21.780852957484367, + "median_best": 109.0, + "optimum_rate": 0.4, + "mean_unique_evals": 523.9 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 135.16666666666666, + "std_best": 21.293321853472168, + "median_best": 150.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 548.2333333333333 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 122.33333333333333, + "std_best": 21.833206106499542, + "median_best": 109.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 475.0 + }, + "NIG + uniform + apess + apv": { + "mean_best": 119.7, + "std_best": 29.888849648879656, + "median_best": 109.0, + 
"optimum_rate": 0.4666666666666667, + "mean_unique_evals": 448.7 + } + }, + "needle_in_haystack": { + "Random": { + "mean_best": 39.666666666666664, + "std_best": 20.491190519071576, + "median_best": 32.5, + "optimum_rate": 0.1, + "mean_unique_evals": 216.2 + }, + "UCT (+rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 283.03333333333336 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 181.5 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 265.06666666666666 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 96.0, + "std_best": 14.966629547095765, + "median_best": 100.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 258.03333333333336 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 96.0, + "std_best": 14.966629547095765, + "median_best": 100.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 275.43333333333334 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 237.06666666666666 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 98.0, + "std_best": 10.770329614269007, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 207.06666666666666 + }, + "NIG + uniform + apess + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 208.2 + } + }, + "mixed_nchoosek_categorical": { + "Random": { + "mean_best": 79.16666666666667, + "std_best": 14.583285714207967, + "median_best": 80.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 471.96666666666664 + }, + "UCT (+rpol)": { + "mean_best": 135.86666666666667, + "std_best": 25.627762723699128, + "median_best": 150.0, + 
"optimum_rate": 0.7666666666666667, + "mean_unique_evals": 441.6333333333333 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 150.0, + "std_best": 0.0, + "median_best": 150.0, + "optimum_rate": 1.0, + "mean_unique_evals": 384.73333333333335 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 150.0, + "std_best": 0.0, + "median_best": 150.0, + "optimum_rate": 1.0, + "mean_unique_evals": 389.4 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 386.3333333333333 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 387.3666666666667 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 148.0, + "std_best": 10.770329614269007, + "median_best": 150.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 378.23333333333335 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 380.03333333333336 + }, + "NIG + uniform + apess + apv": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 333.76666666666665 + } + }, + "large_sparse": { + "Random": { + "mean_best": 36.06666666666667, + "std_best": 6.33473668662628, + "median_best": 38.0, + "optimum_rate": 0.0, + "mean_unique_evals": 763.6333333333333 + }, + "UCT (+rpol)": { + "mean_best": 129.8, + "std_best": 70.24784694209497, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 750.2666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 123.26666666666667, + "std_best": 71.87115013844026, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 749.7333333333333 + }, + "NIG + TS(g,a) + comb + apv": { + 
"mean_best": 114.46666666666667, + "std_best": 69.91362925464216, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 764.1666666666666 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 100.13333333333334, + "std_best": 65.47657766119288, + "median_best": 62.0, + "optimum_rate": 0.3, + "mean_unique_evals": 762.3666666666667 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 133.2, + "std_best": 71.49937062660062, + "median_best": 200.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 734.0333333333333 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 113.73333333333333, + "std_best": 70.53081280940661, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 755.4666666666667 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 134.4, + "std_best": 70.17007909358517, + "median_best": 200.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 688.2333333333333 + }, + "NIG + uniform + apess + apv": { + "mean_best": 118.26666666666667, + "std_best": 71.54249708312459, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 666.7666666666667 + } + }, + "graduated_landscape": { + "Random": { + "mean_best": 60.56666666666667, + "std_best": 3.2932591084753047, + "median_best": 61.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.63333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.53333333333333, + "std_best": 1.3597385369580757, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 156.96666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.76666666666667, + "std_best": 0.42295258468165065, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 122.73333333333333 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 169.86666666666667 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 65.0, + "std_best": 0.0, + 
"median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 163.43333333333334 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 172.46666666666667 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 64.9, + "std_best": 0.30000000000000004, + "median_best": 65.0, + "optimum_rate": 0.9, + "mean_unique_evals": 148.96666666666667 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 64.86666666666666, + "std_best": 0.33993463423951903, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 129.13333333333333 + }, + "NIG + uniform + apess + apv": { + "mean_best": 64.5, + "std_best": 0.5, + "median_best": 64.5, + "optimum_rate": 0.5, + "mean_unique_evals": 118.0 + } + }, + "simple_additive": { + "Random": { + "mean_best": 57.666666666666664, + "std_best": 3.259175083088085, + "median_best": 58.5, + "optimum_rate": 0.0, + "mean_unique_evals": 115.43333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.1, + "std_best": 2.150193789716018, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 186.8 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.66666666666667, + "std_best": 0.7453559924999298, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 135.7 + }, + "NIG + TS(g,a) + comb + apv": { + "mean_best": 64.66666666666667, + "std_best": 0.7453559924999298, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 190.33333333333334 + }, + "NIG + TS(g,a) + acomb + apv": { + "mean_best": 64.73333333333333, + "std_best": 0.6798692684790381, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 182.56666666666666 + }, + "NIG + TS(g,a) + acomb": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 188.6 + }, + "NIG + TS(g,a) + apess + apv": { + "mean_best": 
64.93333333333334, + "std_best": 0.35901098714230023, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 169.66666666666666 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 147.76666666666668 + }, + "NIG + uniform + apess + apv": { + "mean_best": 64.86666666666666, + "std_best": 0.7180219742846004, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 145.6 + } + } +} diff --git a/mcts-report/results_nig_adaptive_n0.json b/mcts-report/results_nig_adaptive_n0.json new file mode 100644 index 000000000..d028147b5 --- /dev/null +++ b/mcts-report/results_nig_adaptive_n0.json @@ -0,0 +1,392 @@ +{ + "multigroup_interaction": { + "Random": { + "mean_best": 62.9, + "std_best": 10.338762014864255, + "median_best": 62.5, + "optimum_rate": 0.0, + "mean_unique_evals": 588.0 + }, + "UCT (+rpol)": { + "mean_best": 111.43333333333334, + "std_best": 23.634273606118906, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 516.2666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 141.3, + "std_best": 17.57868026900768, + "median_best": 150.0, + "optimum_rate": 0.8, + "mean_unique_evals": 532.0666666666667 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 122.33333333333333, + "std_best": 21.833206106499542, + "median_best": 109.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 475.0 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 130.53333333333333, + "std_best": 22.838174669229204, + "median_best": 150.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 542.5 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 124.83333333333333, + "std_best": 24.691541511663914, + "median_best": 109.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 468.4 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 126.03333333333333, + "std_best": 
23.333071427101537, + "median_best": 109.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 526.8666666666667 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 121.06666666666666, + "std_best": 23.21197009207869, + "median_best": 109.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 323.3333333333333 + }, + "NIG + uniform + an0": { + "mean_best": 97.2, + "std_best": 33.29604581127715, + "median_best": 92.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 195.03333333333333 + } + }, + "needle_in_haystack": { + "Random": { + "mean_best": 39.666666666666664, + "std_best": 20.491190519071576, + "median_best": 32.5, + "optimum_rate": 0.1, + "mean_unique_evals": 216.2 + }, + "UCT (+rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 283.03333333333336 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 181.5 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 98.0, + "std_best": 10.770329614269007, + "median_best": 100.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 207.06666666666666 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 191.43333333333334 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 204.8 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 261.8333333333333 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 86.33333333333333, + "std_best": 27.35365098523819, + "median_best": 100.0, + "optimum_rate": 0.8, + "mean_unique_evals": 105.8 + }, + "NIG + uniform + an0": { + "mean_best": 62.666666666666664, + "std_best": 34.92213560989012, + 
"median_best": 30.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 79.16666666666667 + } + }, + "mixed_nchoosek_categorical": { + "Random": { + "mean_best": 79.16666666666667, + "std_best": 14.583285714207967, + "median_best": 80.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 471.96666666666664 + }, + "UCT (+rpol)": { + "mean_best": 135.86666666666667, + "std_best": 25.627762723699128, + "median_best": 150.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 441.6333333333333 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 150.0, + "std_best": 0.0, + "median_best": 150.0, + "optimum_rate": 1.0, + "mean_unique_evals": 384.73333333333335 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 380.03333333333336 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 148.0, + "std_best": 10.770329614269007, + "median_best": 150.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 388.23333333333335 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 142.0, + "std_best": 20.396078054371138, + "median_best": 150.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 385.2 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 146.0, + "std_best": 14.966629547095765, + "median_best": 150.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 391.03333333333336 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 144.0, + "std_best": 18.0, + "median_best": 150.0, + "optimum_rate": 0.9, + "mean_unique_evals": 386.3 + }, + "NIG + uniform + an0": { + "mean_best": 115.96666666666667, + "std_best": 32.08996035453387, + "median_best": 90.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 164.56666666666666 + } + }, + "large_sparse": { + "Random": { + "mean_best": 36.06666666666667, + "std_best": 6.33473668662628, + "median_best": 38.0, + "optimum_rate": 0.0, + 
"mean_unique_evals": 763.6333333333333 + }, + "UCT (+rpol)": { + "mean_best": 129.8, + "std_best": 70.24784694209497, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 750.2666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 123.26666666666667, + "std_best": 71.87115013844026, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 749.7333333333333 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 134.4, + "std_best": 70.17007909358517, + "median_best": 200.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 688.2333333333333 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 118.33333333333333, + "std_best": 71.52264598635098, + "median_best": 62.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 741.3666666666667 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 122.86666666666666, + "std_best": 72.25451927426793, + "median_best": 62.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 679.4333333333333 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 107.66666666666667, + "std_best": 70.35923693604289, + "median_best": 56.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 723.4 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 78.86666666666666, + "std_best": 54.297165876519024, + "median_best": 56.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 588.8333333333334 + }, + "NIG + uniform + an0": { + "mean_best": 66.73333333333333, + "std_best": 52.94647192107221, + "median_best": 49.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 367.2 + } + }, + "graduated_landscape": { + "Random": { + "mean_best": 60.56666666666667, + "std_best": 3.2932591084753047, + "median_best": 61.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.63333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.53333333333333, + "std_best": 1.3597385369580757, + "median_best": 65.0, + "optimum_rate": 0.8, + 
"mean_unique_evals": 156.96666666666667 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.76666666666667, + "std_best": 0.42295258468165065, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 122.73333333333333 + }, + "NIG + TS(g,a) + apess": { + "mean_best": 64.86666666666666, + "std_best": 0.33993463423951903, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 129.13333333333333 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 64.73333333333333, + "std_best": 0.4422166387140533, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 125.93333333333334 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 64.83333333333333, + "std_best": 0.37267799624996495, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 132.13333333333333 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 165.63333333333333 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 64.5, + "std_best": 0.5, + "median_best": 64.5, + "optimum_rate": 0.5, + "mean_unique_evals": 90.46666666666667 + }, + "NIG + uniform + an0": { + "mean_best": 63.93333333333333, + "std_best": 1.4126413400278062, + "median_best": 64.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 68.6 + } + }, + "simple_additive": { + "Random": { + "mean_best": 57.666666666666664, + "std_best": 3.259175083088085, + "median_best": 58.5, + "optimum_rate": 0.0, + "mean_unique_evals": 115.43333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.1, + "std_best": 2.150193789716018, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 186.8 + }, + "NIG + TS(g,a) + vi + apv": { + "mean_best": 64.66666666666667, + "std_best": 0.7453559924999298, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 135.7 + }, + "NIG + TS(g,a) + apess": { + 
"mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 147.76666666666668 + }, + "NIG + TS(g,a) + vi + apv + an0": { + "mean_best": 64.36666666666666, + "std_best": 1.168569876197207, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 138.63333333333333 + }, + "NIG + TS(g,a) + apess + an0": { + "mean_best": 64.93333333333334, + "std_best": 0.35901098714230023, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 147.93333333333334 + }, + "NIG + TS(g,a) + acomb + an0": { + "mean_best": 65.0, + "std_best": 0.0, + "median_best": 65.0, + "optimum_rate": 1.0, + "mean_unique_evals": 174.4 + }, + "NIG + TS(g,a) + an0": { + "mean_best": 64.16666666666667, + "std_best": 1.5073892072793356, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 100.6 + }, + "NIG + uniform + an0": { + "mean_best": 63.43333333333333, + "std_best": 1.9439364415764442, + "median_best": 65.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 89.9 + } + } +} diff --git a/mcts-report/results_ts.json b/mcts-report/results_ts.json new file mode 100644 index 000000000..7e91efbdb --- /dev/null +++ b/mcts-report/results_ts.json @@ -0,0 +1,728 @@ +{ + "multigroup_interaction": { + "Random": { + "mean_best": 62.9, + "std_best": 10.338762014864255, + "median_best": 62.5, + "optimum_rate": 0.0, + "mean_unique_evals": 588.0 + }, + "UCT (+rpol)": { + "mean_best": 111.43333333333334, + "std_best": 23.634273606118906, + "median_best": 109.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 516.2666666666667 + }, + "UCT (no rpol)": { + "mean_best": 108.86666666666666, + "std_best": 25.145885459763704, + "median_best": 106.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 455.26666666666665 + }, + "TS + uniform": { + "mean_best": 92.86666666666666, + "std_best": 22.62702415745876, + "median_best": 97.0, + "optimum_rate": 
0.06666666666666667, + "mean_unique_evals": 96.43333333333334 + }, + "TS + TS(g,a)": { + "mean_best": 101.86666666666666, + "std_best": 26.44457516307561, + "median_best": 106.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 127.23333333333333 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 121.83333333333333, + "std_best": 28.42309311496943, + "median_best": 109.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 377.8666666666667 + }, + "TS + TS(g,c,a)": { + "mean_best": 104.53333333333333, + "std_best": 22.69909445085616, + "median_best": 106.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 173.5 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 114.3, + "std_best": 23.76853662582813, + "median_best": 109.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 406.46666666666664 + }, + "TS + softmax rpol": { + "mean_best": 94.56666666666666, + "std_best": 24.90472958206042, + "median_best": 90.5, + "optimum_rate": 0.1, + "mean_unique_evals": 118.2 + }, + "TS + uniform + adpt_pv": { + "mean_best": 92.86666666666666, + "std_best": 22.42280585079, + "median_best": 90.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 98.5 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 101.8, + "std_best": 27.154373496731612, + "median_best": 94.0, + "optimum_rate": 0.2, + "mean_unique_evals": 168.53333333333333 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 121.23333333333333, + "std_best": 28.696902659036603, + "median_best": 109.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 394.8 + }, + "TS + TS(g,a) + pess": { + "mean_best": 107.73333333333333, + "std_best": 24.49752821998352, + "median_best": 107.5, + "optimum_rate": 0.2, + "mean_unique_evals": 432.5 + }, + "TS + TS(g,a) + pess + apv": { + "mean_best": 104.23333333333333, + "std_best": 23.98427725731079, + "median_best": 106.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 447.76666666666665 + }, + "TS + uniform + 
pess + apv": { + "mean_best": 110.33333333333333, + "std_best": 22.089716058735466, + "median_best": 109.0, + "optimum_rate": 0.2, + "mean_unique_evals": 404.06666666666666 + }, + "TS + TS(g,a) + comb": { + "mean_best": 115.4, + "std_best": 26.801243752234086, + "median_best": 109.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 475.06666666666666 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 101.86666666666666, + "std_best": 23.040158178469365, + "median_best": 106.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 478.7 + } + }, + "needle_in_haystack": { + "Random": { + "mean_best": 39.666666666666664, + "std_best": 20.491190519071576, + "median_best": 32.5, + "optimum_rate": 0.1, + "mean_unique_evals": 216.2 + }, + "UCT (+rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 283.03333333333336 + }, + "UCT (no rpol)": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 246.56666666666666 + }, + "TS + uniform": { + "mean_best": 53.833333333333336, + "std_best": 35.5359755115229, + "median_best": 30.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 41.86666666666667 + }, + "TS + TS(g,a)": { + "mean_best": 74.83333333333333, + "std_best": 33.1783898879309, + "median_best": 100.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 63.3 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 89.5, + "std_best": 23.53543427826788, + "median_best": 100.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 158.56666666666666 + }, + "TS + TS(g,c,a)": { + "mean_best": 72.33333333333333, + "std_best": 33.9050963065371, + "median_best": 100.0, + "optimum_rate": 0.6, + "mean_unique_evals": 47.63333333333333 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 88.66666666666667, + "std_best": 25.36182608216968, + "median_best": 100.0, + "optimum_rate": 0.8333333333333334, + 
"mean_unique_evals": 86.6 + }, + "TS + softmax rpol": { + "mean_best": 86.0, + "std_best": 28.0, + "median_best": 100.0, + "optimum_rate": 0.8, + "mean_unique_evals": 57.43333333333333 + }, + "TS + uniform + adpt_pv": { + "mean_best": 59.166666666666664, + "std_best": 35.96487483951837, + "median_best": 30.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 41.8 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 82.33333333333333, + "std_best": 29.403325586667155, + "median_best": 100.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 67.03333333333333 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 92.0, + "std_best": 20.396078054371138, + "median_best": 100.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 158.83333333333334 + }, + "TS + TS(g,a) + pess": { + "mean_best": 93.66666666666667, + "std_best": 19.014614262602215, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 258.03333333333336 + }, + "TS + TS(g,a) + pess + apv": { + "mean_best": 94.0, + "std_best": 18.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 260.3666666666667 + }, + "TS + uniform + pess + apv": { + "mean_best": 100.0, + "std_best": 0.0, + "median_best": 100.0, + "optimum_rate": 1.0, + "mean_unique_evals": 206.86666666666667 + }, + "TS + TS(g,a) + comb": { + "mean_best": 94.0, + "std_best": 18.0, + "median_best": 100.0, + "optimum_rate": 0.9, + "mean_unique_evals": 282.06666666666666 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 96.0, + "std_best": 14.966629547095765, + "median_best": 100.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 286.56666666666666 + } + }, + "mixed_nchoosek_categorical": { + "Random": { + "mean_best": 79.16666666666667, + "std_best": 14.583285714207967, + "median_best": 80.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 471.96666666666664 + }, + "UCT (+rpol)": { + "mean_best": 135.86666666666667, + "std_best": 25.627762723699128, + 
"median_best": 150.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 441.6333333333333 + }, + "UCT (no rpol)": { + "mean_best": 126.96666666666667, + "std_best": 30.494243536918393, + "median_best": 150.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 356.76666666666665 + }, + "TS + uniform": { + "mean_best": 95.23333333333333, + "std_best": 28.6562190263979, + "median_best": 86.0, + "optimum_rate": 0.2, + "mean_unique_evals": 94.13333333333334 + }, + "TS + TS(g,a)": { + "mean_best": 110.23333333333333, + "std_best": 33.14883942999446, + "median_best": 90.0, + "optimum_rate": 0.4, + "mean_unique_evals": 273.1666666666667 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 123.33333333333333, + "std_best": 30.586852658545226, + "median_best": 150.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 341.8333333333333 + }, + "TS + TS(g,c,a)": { + "mean_best": 122.56666666666666, + "std_best": 31.62033452630689, + "median_best": 150.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 270.8666666666667 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 131.8, + "std_best": 27.820136592044253, + "median_best": 150.0, + "optimum_rate": 0.7, + "mean_unique_evals": 347.6333333333333 + }, + "TS + softmax rpol": { + "mean_best": 106.0, + "std_best": 34.35597958628648, + "median_best": 90.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 193.76666666666668 + }, + "TS + uniform + adpt_pv": { + "mean_best": 94.53333333333333, + "std_best": 29.029563475111978, + "median_best": 85.0, + "optimum_rate": 0.2, + "mean_unique_evals": 94.53333333333333 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 115.66666666666667, + "std_best": 32.67958928070479, + "median_best": 90.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 233.56666666666666 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 124.13333333333334, + "std_best": 32.12759284823907, + "median_best": 150.0, + "optimum_rate": 0.6, + 
"mean_unique_evals": 335.53333333333336 + }, + "TS + TS(g,a) + pess": { + "mean_best": 102.9, + "std_best": 29.01304304389091, + "median_best": 90.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 352.8666666666667 + }, + "TS + TS(g,a) + pess + apv": { + "mean_best": 99.86666666666666, + "std_best": 28.543222585327598, + "median_best": 90.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 348.3666666666667 + }, + "TS + uniform + pess + apv": { + "mean_best": 130.3, + "std_best": 30.332765562451883, + "median_best": 150.0, + "optimum_rate": 0.7, + "mean_unique_evals": 317.1 + }, + "TS + TS(g,a) + comb": { + "mean_best": 112.63333333333334, + "std_best": 33.18080100834752, + "median_best": 90.0, + "optimum_rate": 0.43333333333333335, + "mean_unique_evals": 359.6 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 98.13333333333334, + "std_best": 29.920710033167474, + "median_best": 90.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 359.23333333333335 + } + }, + "large_sparse": { + "Random": { + "mean_best": 36.06666666666667, + "std_best": 6.33473668662628, + "median_best": 38.0, + "optimum_rate": 0.0, + "mean_unique_evals": 763.6333333333333 + }, + "UCT (+rpol)": { + "mean_best": 129.8, + "std_best": 70.24784694209497, + "median_best": 131.0, + "optimum_rate": 0.5, + "mean_unique_evals": 750.2666666666667 + }, + "UCT (no rpol)": { + "mean_best": 112.13333333333334, + "std_best": 71.97209335723272, + "median_best": 62.0, + "optimum_rate": 0.4, + "mean_unique_evals": 689.3 + }, + "TS + uniform": { + "mean_best": 52.93333333333333, + "std_best": 40.466392091325474, + "median_best": 44.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.16666666666667 + }, + "TS + TS(g,a)": { + "mean_best": 56.8, + "std_best": 27.4109467184189, + "median_best": 56.0, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 210.63333333333333 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 84.0, + "std_best": 
58.220271383771475, + "median_best": 56.0, + "optimum_rate": 0.2, + "mean_unique_evals": 575.4 + }, + "TS + TS(g,c,a)": { + "mean_best": 61.06666666666667, + "std_best": 47.499426897127464, + "median_best": 50.0, + "optimum_rate": 0.1, + "mean_unique_evals": 206.9 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 77.2, + "std_best": 55.405414897823846, + "median_best": 56.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 544.6666666666666 + }, + "TS + softmax rpol": { + "mean_best": 92.73333333333333, + "std_best": 65.05071525783214, + "median_best": 56.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 238.06666666666666 + }, + "TS + uniform + adpt_pv": { + "mean_best": 52.93333333333333, + "std_best": 40.466392091325474, + "median_best": 44.0, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 111.63333333333334 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 78.06666666666666, + "std_best": 55.10833774383772, + "median_best": 56.0, + "optimum_rate": 0.16666666666666666, + "mean_unique_evals": 210.96666666666667 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 104.53333333333333, + "std_best": 67.7070815268897, + "median_best": 62.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 557.6666666666666 + }, + "TS + TS(g,a) + pess": { + "mean_best": 89.0, + "std_best": 61.463810490401585, + "median_best": 56.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 639.3 + }, + "TS + TS(g,a) + pess + apv": { + "mean_best": 103.93333333333334, + "std_best": 68.15958398412427, + "median_best": 62.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 651.3666666666667 + }, + "TS + uniform + pess + apv": { + "mean_best": 70.26666666666667, + "std_best": 51.286732094069805, + "median_best": 56.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 543.0666666666667 + }, + "TS + TS(g,a) + comb": { + "mean_best": 84.4, + "std_best": 57.98482560118637, + "median_best": 56.0, + 
"optimum_rate": 0.2, + "mean_unique_evals": 671.1666666666666 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 107.73333333333333, + "std_best": 70.47407718839287, + "median_best": 62.0, + "optimum_rate": 0.36666666666666664, + "mean_unique_evals": 672.1666666666666 + } + }, + "graduated_landscape": { + "Random": { + "mean_best": 60.56666666666667, + "std_best": 3.2932591084753047, + "median_best": 61.5, + "optimum_rate": 0.06666666666666667, + "mean_unique_evals": 112.63333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.53333333333333, + "std_best": 1.3597385369580757, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 156.96666666666667 + }, + "UCT (no rpol)": { + "mean_best": 64.7, + "std_best": 0.7810249675906655, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 152.4 + }, + "TS + uniform": { + "mean_best": 62.46666666666667, + "std_best": 2.642389491014188, + "median_best": 64.0, + "optimum_rate": 0.1, + "mean_unique_evals": 42.7 + }, + "TS + TS(g,a)": { + "mean_best": 63.43333333333333, + "std_best": 2.679344861881144, + "median_best": 65.0, + "optimum_rate": 0.5333333333333333, + "mean_unique_evals": 62.03333333333333 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 64.56666666666666, + "std_best": 0.8034647195462634, + "median_best": 65.0, + "optimum_rate": 0.7, + "mean_unique_evals": 101.9 + }, + "TS + TS(g,c,a)": { + "mean_best": 62.93333333333333, + "std_best": 3.37573037364591, + "median_best": 64.0, + "optimum_rate": 0.23333333333333334, + "mean_unique_evals": 41.766666666666666 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 64.56666666666666, + "std_best": 0.9195409482755815, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 80.63333333333334 + }, + "TS + softmax rpol": { + "mean_best": 58.13333333333333, + "std_best": 8.131147247194306, + "median_best": 62.5, + "optimum_rate": 0.03333333333333333, + "mean_unique_evals": 34.266666666666666 + }, + "TS + uniform + 
adpt_pv": { + "mean_best": 62.2, + "std_best": 2.86821663523986, + "median_best": 64.0, + "optimum_rate": 0.13333333333333333, + "mean_unique_evals": 42.53333333333333 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 63.56666666666667, + "std_best": 2.691756964429656, + "median_best": 65.0, + "optimum_rate": 0.6, + "mean_unique_evals": 71.93333333333334 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 64.6, + "std_best": 0.9521904571390466, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 112.33333333333333 + }, + "TS + TS(g,a) + pess": { + "mean_best": 64.93333333333334, + "std_best": 0.2494438257849294, + "median_best": 65.0, + "optimum_rate": 0.9333333333333333, + "mean_unique_evals": 152.63333333333333 + }, + "TS + TS(g,a) + pess + apv": { + "mean_best": 64.8, + "std_best": 0.6, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 164.26666666666668 + }, + "TS + uniform + pess + apv": { + "mean_best": 64.8, + "std_best": 0.39999999999999997, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 130.73333333333332 + }, + "TS + TS(g,a) + comb": { + "mean_best": 64.96666666666667, + "std_best": 0.17950549357115012, + "median_best": 65.0, + "optimum_rate": 0.9666666666666667, + "mean_unique_evals": 174.56666666666666 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 64.8, + "std_best": 0.6, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 181.2 + } + }, + "simple_additive": { + "Random": { + "mean_best": 57.666666666666664, + "std_best": 3.259175083088085, + "median_best": 58.5, + "optimum_rate": 0.0, + "mean_unique_evals": 115.43333333333334 + }, + "UCT (+rpol)": { + "mean_best": 64.1, + "std_best": 2.150193789716018, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 186.8 + }, + "UCT (no rpol)": { + "mean_best": 64.1, + "std_best": 2.2412793965352322, + "median_best": 65.0, + "optimum_rate": 
0.8333333333333334, + "mean_unique_evals": 183.8 + }, + "TS + uniform": { + "mean_best": 60.5, + "std_best": 4.372261047406327, + "median_best": 60.5, + "optimum_rate": 0.3, + "mean_unique_evals": 54.0 + }, + "TS + TS(g,a)": { + "mean_best": 63.43333333333333, + "std_best": 2.538809869910615, + "median_best": 65.0, + "optimum_rate": 0.6333333333333333, + "mean_unique_evals": 71.26666666666667 + }, + "TS + TS(g,a) + var_infl": { + "mean_best": 64.16666666666667, + "std_best": 1.967796285752726, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 111.7 + }, + "TS + TS(g,c,a)": { + "mean_best": 62.7, + "std_best": 3.9509492530276815, + "median_best": 63.0, + "optimum_rate": 0.4666666666666667, + "mean_unique_evals": 61.03333333333333 + }, + "TS + TS(g,c,a) + var_infl": { + "mean_best": 64.16666666666667, + "std_best": 1.572330188676101, + "median_best": 65.0, + "optimum_rate": 0.7333333333333333, + "mean_unique_evals": 101.46666666666667 + }, + "TS + softmax rpol": { + "mean_best": 58.233333333333334, + "std_best": 8.23684135792086, + "median_best": 60.0, + "optimum_rate": 0.3333333333333333, + "mean_unique_evals": 40.7 + }, + "TS + uniform + adpt_pv": { + "mean_best": 60.46666666666667, + "std_best": 4.492462823688387, + "median_best": 62.0, + "optimum_rate": 0.26666666666666666, + "mean_unique_evals": 53.833333333333336 + }, + "TS + TS(g,a) + adpt_pv": { + "mean_best": 63.6, + "std_best": 1.781385228784985, + "median_best": 65.0, + "optimum_rate": 0.5666666666666667, + "mean_unique_evals": 77.33333333333333 + }, + "TS + TS(g,a) + vi + apv": { + "mean_best": 64.46666666666667, + "std_best": 1.359738536958076, + "median_best": 65.0, + "optimum_rate": 0.8666666666666667, + "mean_unique_evals": 127.03333333333333 + }, + "TS + TS(g,a) + pess": { + "mean_best": 64.46666666666667, + "std_best": 1.2578641509408803, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 176.43333333333334 + }, + "TS + TS(g,a) + pess + apv": { + 
"mean_best": 64.33333333333333, + "std_best": 1.398411797560202, + "median_best": 65.0, + "optimum_rate": 0.8, + "mean_unique_evals": 188.93333333333334 + }, + "TS + uniform + pess + apv": { + "mean_best": 64.06666666666666, + "std_best": 1.7499206331208919, + "median_best": 65.0, + "optimum_rate": 0.7666666666666667, + "mean_unique_evals": 164.43333333333334 + }, + "TS + TS(g,a) + comb": { + "mean_best": 64.53333333333333, + "std_best": 1.1175369742826806, + "median_best": 65.0, + "optimum_rate": 0.8333333333333334, + "mean_unique_evals": 192.46666666666667 + }, + "TS + TS(g,a) + comb + apv": { + "mean_best": 64.06666666666666, + "std_best": 1.4360439485692011, + "median_best": 65.0, + "optimum_rate": 0.6666666666666666, + "mean_unique_evals": 201.86666666666667 + } + } +} diff --git a/mcts-report/summary_bar_chart.png b/mcts-report/summary_bar_chart.png new file mode 100644 index 000000000..9642d4702 Binary files /dev/null and b/mcts-report/summary_bar_chart.png differ diff --git a/mcts-report/summary_bar_chart_nig.png b/mcts-report/summary_bar_chart_nig.png new file mode 100644 index 000000000..44dc6fa05 Binary files /dev/null and b/mcts-report/summary_bar_chart_nig.png differ diff --git a/mcts-report/summary_bar_chart_nig_adaptive.png b/mcts-report/summary_bar_chart_nig_adaptive.png new file mode 100644 index 000000000..0ae3dead4 Binary files /dev/null and b/mcts-report/summary_bar_chart_nig_adaptive.png differ diff --git a/mcts-report/summary_bar_chart_nig_adaptive_n0.png b/mcts-report/summary_bar_chart_nig_adaptive_n0.png new file mode 100644 index 000000000..46cb47886 Binary files /dev/null and b/mcts-report/summary_bar_chart_nig_adaptive_n0.png differ diff --git a/mcts-report/test.ipynb b/mcts-report/test.ipynb new file mode 100644 index 000000000..235c358cf --- /dev/null +++ b/mcts-report/test.ipynb @@ -0,0 +1,8869 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + 
"import copy\n", + "import time\n", + "\n", + "import pandas as pd\n", + "import torch\n", + "from botorch.optim import optimize_acqf\n", + "from botorch.utils.sampling import draw_sobol_samples\n", + "\n", + "import bofire.surrogates.api as surrogates\n", + "from bofire.benchmarks.api import Hartmann, SpuriousFeaturesWrapper\n", + "from bofire.data_models.surrogates.api import (\n", + " EnsembleMapSaasSingleTaskGPSurrogate,\n", + " SingleTaskGPSurrogate,\n", + ")\n", + "from bofire.strategies import utils\n", + "from bofire.strategies.api import RandomStrategy, SoboStrategy\n", + "from bofire.strategies.predictives.optimize_mcts import (\n", + " MCTS,\n", + " Groups,\n", + " NChooseK,\n", + " _SelectionTracker,\n", + ")\n", + "from bofire.utils.torch_tools import tkwargs\n", + "\n", + "\n", + "benchmark = SpuriousFeaturesWrapper(Hartmann(dim=6), n_spurious_features=6, max_count=6)\n", + "random_strategy = RandomStrategy.make(domain=benchmark.domain)\n", + "\n", + "experiments = pd.read_csv(\"experiments.csv\")\n", + "\n", + "strategy = SoboStrategy.make(\n", + " domain=benchmark.domain,\n", + ")\n", + "# surrogate_specs=BotorchSurrogates(\n", + "# surrogates=[\n", + "# EnsembleMapSaasSingleTaskGPSurrogate(inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs)\n", + "# ]\n", + "# )\n", + "# )\n", + "\n", + "strategy.tell(experiments.loc[:9].copy())\n", + "acqf = strategy._get_acqfs(n=1)[0]\n", + "\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=64, q=1).squeeze(1)\n", + " 
candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "def reward_fn3(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "# candidates = random_strategy.ask(10)\n", + "# experiments = benchmark.f(candidates, return_complete=True)\n", + "\n", + "# experiments[[\"x_3\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
50.9716560.3410370.2473110.1170810.0000000.0000000.1014650.0000000.0000000.6124490.0000000.000000-0.002416True
60.0097270.3187120.0000000.0000000.5629500.5014310.0000000.0000000.1304470.7760440.0000000.000000-0.216006True
70.8733560.2732550.0000000.0000000.3955370.0000000.0000000.9134770.4222420.5581640.0000000.000000-0.007566True
80.0000000.0000000.0000000.9213050.4787820.5553820.8431830.0000000.6743220.0000000.8404980.000000-0.019061True
90.0000000.8613230.0000000.0000000.2463090.0000000.3676990.5743210.0000000.7715720.2532700.000000-0.010974True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + "5 0.971656 0.341037 0.247311 0.117081 0.000000 0.000000 0.101465 \n", + "6 0.009727 0.318712 0.000000 0.000000 0.562950 0.501431 0.000000 \n", + "7 0.873356 0.273255 0.000000 0.000000 0.395537 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.921305 0.478782 0.555382 0.843183 \n", + "9 0.000000 0.861323 0.000000 0.000000 0.246309 0.000000 0.367699 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + "5 0.000000 0.000000 0.612449 0.000000 0.000000 \n", + "6 0.000000 0.130447 0.776044 0.000000 0.000000 \n", + "7 0.913477 0.422242 0.558164 0.000000 0.000000 \n", + "8 0.000000 0.674322 0.000000 0.840498 0.000000 \n", + "9 0.574321 0.000000 0.771572 0.253270 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + "5 -0.002416 True \n", + "6 -0.216006 True \n", + "7 -0.007566 True \n", + "8 -0.019061 True \n", + "9 -0.010974 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.loc[:9]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAEMSDR2MAPEPEARSONSPEARMANFISHER
00.0554610.017940.97192410.7924820.986920.960975.674047e-18
\n", + "
" + ], + "text/plain": [ + " MAE MSD R2 MAPE PEARSON SPEARMAN FISHER\n", + "0 0.055461 0.01794 0.971924 10.792482 0.98692 0.96097 5.674047e-18" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "surrogate_data = EnsembleMapSaasSingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAEMSDR2MAPEPEARSONSPEARMANFISHER
00.0684960.0261620.95905817.9382450.9814980.9556965.674047e-18
\n", + "
" + ], + "text/plain": [ + " MAE MSD R2 MAPE PEARSON SPEARMAN FISHER\n", + "0 0.068496 0.026162 0.959058 17.938245 0.981498 0.955696 5.674047e-18" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "surrogate_data = SingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "surrogate_data = EnsembleMapSaasSingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[7.9170e+03, 1.6480e+00, 7.0719e-01, 4.3675e-01, 3.9497e-01,\n", + " 5.8420e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[1.0000e+04, 1.4312e+00, 6.6053e-01, 4.0018e-01, 3.6135e-01,\n", + " 5.3718e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[8.8197e+03, 1.5363e+00, 6.8223e-01, 4.1844e-01, 3.7854e-01,\n", + " 5.5984e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[1.0000e+04, 1.4326e+00, 6.6081e-01, 4.0044e-01, 3.6159e-01,\n", + " 5.3748e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"strategy.model.covar_module.base_kernel.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1.7057, 1.8292, 29.8056, 0.3237, 0.2182, 72.7240, 85.6520,\n", + " 7.8904, 199.8355, 15.3883, 7.4227, 141.4410]],\n", + " dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2.6226e+00, 8.2782e+01, 4.1018e-01, 1.1438e+02, 3.3914e-01, 1.3650e-01,\n", + " 2.0361e+00, 8.3981e+00, 2.5162e+00, 3.5898e+02, 1.1046e+01, 4.6532e+00]],\n", + " dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_gold_standard(selections, acqf, bounds, n_features):\n", + " \"\"\"Evaluate each selection with full optimize_acqf (gold standard).\"\"\"\n", + " values = []\n", + " for sel in selections:\n", + " fixed = {i: 0.0 for i in range(n_features) if i not in sel}\n", + " if len(sel) == 0:\n", + " cand = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " val = acqf(cand.unsqueeze(-2)).max().item()\n", + " else:\n", + " cand, acq_val = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " val = acq_val.item()\n", + " values.append(val)\n", + " return values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"[-45.3893631573142]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluate_gold_standard([(0, 1, 2, 3, 11)], acqf, bounds, n_features=12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-5.114861134566571" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reward_fn2((0, 1, 2, 3, 4, 5), cat_selections={})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x_5 44.561275\n", + "dtype: float64" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments[[\"x_5\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.4024700.0000000.0000000.0000000.0000000.0000000.0000000.6374530.1399190.4218240.6713710.331502-0.0056831
10.4904710.2944280.0000000.0000000.0000000.9252190.1794820.7729880.4809430.0000000.0000000.000000-0.0967351
20.8312990.0000000.1528290.2337560.0000000.2376550.2170320.0000000.5361540.0000000.0000000.000000-0.0417131
30.0000000.8293810.0000000.2879120.0000000.0000000.7741490.0000000.0000000.5073140.7311190.425413-0.0825721
40.0000000.8670440.1348860.0000000.2504430.2510050.7662670.0000000.0214130.0000000.0000000.000000-0.0457211
50.0000000.3077290.8427450.1164260.0000000.2726150.0000000.0000000.0000000.7787310.0000000.155062-0.0971321
60.2566110.4944050.0000000.6031580.0000000.0000000.8366570.0000000.0000000.8080900.0000000.134337-0.6195871
70.9262710.0000000.0000000.3188560.0000000.0000000.6645890.0000000.4928020.0531630.0000000.259605-0.0032261
80.0000000.2191180.6250430.7884450.0000000.0000000.0000000.4703430.0000000.0000000.3855200.006767-0.0065051
90.9737060.0000000.0000000.0623920.0000000.6765590.0000000.8051040.0000000.2609770.4023370.000000-0.0542801
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.402470 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.490471 0.294428 0.000000 0.000000 0.000000 0.925219 0.179482 \n", + "2 0.831299 0.000000 0.152829 0.233756 0.000000 0.237655 0.217032 \n", + "3 0.000000 0.829381 0.000000 0.287912 0.000000 0.000000 0.774149 \n", + "4 0.000000 0.867044 0.134886 0.000000 0.250443 0.251005 0.766267 \n", + "5 0.000000 0.307729 0.842745 0.116426 0.000000 0.272615 0.000000 \n", + "6 0.256611 0.494405 0.000000 0.603158 0.000000 0.000000 0.836657 \n", + "7 0.926271 0.000000 0.000000 0.318856 0.000000 0.000000 0.664589 \n", + "8 0.000000 0.219118 0.625043 0.788445 0.000000 0.000000 0.000000 \n", + "9 0.973706 0.000000 0.000000 0.062392 0.000000 0.676559 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.637453 0.139919 0.421824 0.671371 0.331502 \n", + "1 0.772988 0.480943 0.000000 0.000000 0.000000 \n", + "2 0.000000 0.536154 0.000000 0.000000 0.000000 \n", + "3 0.000000 0.000000 0.507314 0.731119 0.425413 \n", + "4 0.000000 0.021413 0.000000 0.000000 0.000000 \n", + "5 0.000000 0.000000 0.778731 0.000000 0.155062 \n", + "6 0.000000 0.000000 0.808090 0.000000 0.134337 \n", + "7 0.000000 0.492802 0.053163 0.000000 0.259605 \n", + "8 0.470343 0.000000 0.000000 0.385520 0.006767 \n", + "9 0.805104 0.000000 0.260977 0.402337 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.005683 1 \n", + "1 -0.096735 1 \n", + "2 -0.041713 1 \n", + "3 -0.082572 1 \n", + "4 -0.045721 1 \n", + "5 -0.097132 1 \n", + "6 -0.619587 1 \n", + "7 -0.003226 1 \n", + "8 -0.006505 1 \n", + "9 -0.054280 1 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_10x_spurious_11...x_spurious_41x_spurious_42x_spurious_43x_spurious_5x_spurious_6x_spurious_7x_spurious_8x_spurious_9yvalid_y
00.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.2026600.0000000.0000000.0000000.0000000.0000000.280977-0.0050891
10.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.5171240.0000000.7295480.7651010.0000000.030946-0.0050891
20.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.0000000.0128140.8652780.0000000.0000000.377278-0.0050891
30.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.8330690.2585300.0000000.0000000.0000000.0000000.885180-0.0050891
40.00.00.00.00.00.3797580.00.0000000.00.0...0.0000000.0000000.0000000.0000000.0000000.0516230.0000000.207810-0.0905331
..................................................................
2510.00.00.00.00.00.0000000.00.0000000.00.0...0.6772340.0000000.0000000.8742940.0000000.0000000.8935000.198800-0.0050891
2520.00.00.00.00.00.0000000.00.9617930.00.0...0.0000000.4916440.0000000.0000000.6685070.0000000.0000000.449542-0.0050891
2530.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.0000000.0000000.3692730.0000000.0000000.862634-0.0050891
2540.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.7771760.9883990.0000000.6791650.0000000.205755-0.0050891
2550.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0286050.0000000.0000000.0000000.0000000.5994050.957740-0.0050891
\n", + "

256 rows × 52 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "1 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "2 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "3 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "4 0.0 0.0 0.0 0.0 0.0 0.379758 0.0 0.000000 \n", + ".. ... ... ... ... ... ... ... ... \n", + "251 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "252 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.961793 \n", + "253 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "254 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "255 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "\n", + " x_spurious_10 x_spurious_11 ... x_spurious_41 x_spurious_42 \\\n", + "0 0.0 0.0 ... 0.000000 0.202660 \n", + "1 0.0 0.0 ... 0.000000 0.000000 \n", + "2 0.0 0.0 ... 0.000000 0.000000 \n", + "3 0.0 0.0 ... 0.000000 0.833069 \n", + "4 0.0 0.0 ... 0.000000 0.000000 \n", + ".. ... ... ... ... ... \n", + "251 0.0 0.0 ... 0.677234 0.000000 \n", + "252 0.0 0.0 ... 0.000000 0.491644 \n", + "253 0.0 0.0 ... 0.000000 0.000000 \n", + "254 0.0 0.0 ... 0.000000 0.000000 \n", + "255 0.0 0.0 ... 0.000000 0.028605 \n", + "\n", + " x_spurious_43 x_spurious_5 x_spurious_6 x_spurious_7 x_spurious_8 \\\n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.517124 0.000000 0.729548 0.765101 0.000000 \n", + "2 0.000000 0.012814 0.865278 0.000000 0.000000 \n", + "3 0.258530 0.000000 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.000000 0.000000 0.051623 0.000000 \n", + ".. ... ... ... ... ... 
\n", + "251 0.000000 0.874294 0.000000 0.000000 0.893500 \n", + "252 0.000000 0.000000 0.668507 0.000000 0.000000 \n", + "253 0.000000 0.000000 0.369273 0.000000 0.000000 \n", + "254 0.777176 0.988399 0.000000 0.679165 0.000000 \n", + "255 0.000000 0.000000 0.000000 0.000000 0.599405 \n", + "\n", + " x_spurious_9 y valid_y \n", + "0 0.280977 -0.005089 1 \n", + "1 0.030946 -0.005089 1 \n", + "2 0.377278 -0.005089 1 \n", + "3 0.885180 -0.005089 1 \n", + "4 0.207810 -0.090533 1 \n", + ".. ... ... ... \n", + "251 0.198800 -0.005089 1 \n", + "252 0.449542 -0.005089 1 \n", + "253 0.862634 -0.005089 1 \n", + "254 0.205755 -0.005089 1 \n", + "255 0.957740 -0.005089 1 \n", + "\n", + "[256 rows x 52 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((3, 8, 17, 18, 20, 49), {}, [])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "groups = Groups(groups=[NChooseK(features=list(range(50)), max_count=6, min_count=0)])\n", + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=lambda x, y: 0.0,\n", + " rollout_mode=\"uniform_subset\",\n", + " p_stop_rollout=0.0,\n", + ")\n", + "mcts._rollout(mcts.root)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "strategy = SoboStrategy.make(domain=benchmark.domain)\n", + "\n", + "strategy.tell(experiments)\n", + "acqf = strategy._get_acqfs(n=1)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time per acqf(X) call: 
1.853 ms\n", + "Min/Max per call: 1.054 / 14.629 ms\n" + ] + } + ], + "source": [ + "transformed = strategy.domain.inputs.transform(\n", + " candidates,\n", + " strategy.input_preprocessing_specs,\n", + ")\n", + "X = torch.from_numpy(transformed.values).to(**tkwargs)\n", + "X = X.unsqueeze(-2)\n", + "\n", + "n_calls = 100\n", + "call_times = []\n", + "\n", + "with torch.no_grad():\n", + " for _ in range(n_calls):\n", + " t0 = time.perf_counter()\n", + " _ = acqf(X)\n", + " call_times.append(time.perf_counter() - t0)\n", + "\n", + "mean_time = sum(call_times) / n_calls\n", + "print(f\"Mean time per acqf(X) call: {mean_time * 1e3:.3f} ms\")\n", + "print(f\"Min/Max per call: {min(call_times) * 1e3:.3f} / {max(call_times) * 1e3:.3f} ms\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-33.6191, -32.9385, -33.6550, -33.6543, -31.8275, -33.6546, -33.6545,\n", + " -33.6547, -33.6544, -33.6546, -33.6545, -31.5978, -33.6528, -33.6542,\n", + " -33.6542, -33.6536, -33.6542, -32.0027, -33.6554, -33.6563, -33.6516,\n", + " -33.6549, -33.5973, -33.6491, -33.6549, -33.6451, -33.6546, -33.6551,\n", + " -33.6543, -33.6530, -33.6553, -33.6550, -33.6532, -33.6542, -33.6548,\n", + " -33.6539, -33.6533, -33.6483, -33.6543, -33.6543, -33.6543, -33.6373,\n", + " -33.6253, -33.6540, -33.6504, -33.6526, -33.6544, -33.6551, -33.6510,\n", + " -33.6544, -33.6547, -33.6532, -33.6543, -33.6543, -33.6550, -33.6545,\n", + " -33.6505, -33.6542, -33.6544, -33.6712, -33.6541, -33.6545, -33.6533,\n", + " -33.6550], dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "acqf(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-10.2082, dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": 
null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain2 = copy.deepcopy(benchmark.domain)\n", + "domain2.constraints.constraints = []\n", + "for i in range(6, 50):\n", + " domain2.inputs.features[i].bounds = (0.0, 0.0)\n", + "\n", + "candidates2 = domain2.inputs.sample(64)\n", + "\n", + "transformed = strategy.domain.inputs.transform(\n", + " candidates2,\n", + " strategy.input_preprocessing_specs,\n", + ")\n", + "X = torch.from_numpy(transformed.values).to(**tkwargs)\n", + "X = X.unsqueeze(-2)\n", + "\n", + "acqf(X).max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((0, 1, 10, 11), {}, -2.54028550083751)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from bofire.strategies import utils\n", + "\n", + "\n", + "groups = Groups(\n", + " groups=[\n", + " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + ")\n", + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=lambda x, y: 0.0,\n", + " rollout_mode=\"uniform_subset\",\n", + " p_stop_rollout=0.0,\n", + ")\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=64, q=1).squeeze(1)\n", + " candidates = 
torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.inputs.domain.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=reward_fn,\n", + " use_cache=False,\n", + " rollout_mode=\"uniform_subset\",\n", + ")\n", + "mcts.run(n_iterations=2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 1, 6, 8)\n", + "-2.5978651482116755\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "# print(reward_fn2(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-2.6487569183913315\n" + ] + } + ], + "source": [ + "print(reward_fn((0, 1, 2, 3, 4, 5), cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5
00.4115960.9102150.00.00.00.00.00.00.00.00.00.878617
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "0 0.411596 0.910215 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \n", + "0 0.0 0.0 0.0 0.878617 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-2.5506, dtype=torch.float64)\n" + ] + } + ], + "source": [ + "candidate, acqf_val = get_candidate()\n", + "\n", + "display(candidate)\n", + "print(acqf_val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Iteration 0 ---\n", + " n_experiments: 20\n", + "[(4, 9), (4, 9, 11), (4, 9, 10, 11), (4, 5, 7, 9), (4, 9, 10), (4, 9), (4, 9), (4, 9), (4, 9), (4, 9), (4, 5, 8, 9), (4, 5, 8, 9), (4, 5, 9, 11), (4, 5, 8, 9), (4, 5, 9, 10)]\n", + " gold standard acq value: -5.020887\n", + " acq_val from get_candidate: -4.533155618741909\n", + " new y value: [-0.02122513]\n", + " experiments: 20 -> 21\n", + "\n", + "--- Iteration 1 ---\n", + " n_experiments: 21\n", + "[(4, 5, 9), (4, 5, 9, 11), (4, 5, 7, 9), (4, 5, 7, 9, 10), (4, 5, 8, 9, 11), (4, 5, 9), (4, 5, 9, 11), (4, 5, 9), (4, 5, 9), (4, 5, 9), (4, 5, 8, 9), (4, 5, 8, 9), (4, 5, 9, 11), (4, 5, 8, 9), (4, 5, 9, 10)]\n", + " gold standard acq value: -4.170976\n", + " acq_val from get_candidate: -2.838282119570516\n", + " new y value: [-0.71357419]\n", + " experiments: 21 -> 22\n", + "\n", + "--- Iteration 2 ---\n", + " n_experiments: 22\n", + "[(3, 4, 5, 10), (5, 8, 9, 11), (5, 8, 9), (5, 8, 9, 10, 11), (5, 8, 9, 10), (5, 8, 9), (5, 8, 9, 11), (5, 8, 9, 11), (5,), (5, 8, 9, 11), (4, 5, 8, 9, 11), (4, 5, 8, 9), (4, 5, 8, 9), (4, 5, 9, 11), (4, 5, 8, 9)]\n", + " gold standard acq value: -2.931852\n", + " acq_val from get_candidate: -2.1567612779356233\n", + " new y value: [-0.83214312]\n", + " experiments: 22 -> 23\n", + 
"\n", + "--- Iteration 3 ---\n", + " n_experiments: 23\n", + "[(1, 4, 5, 6), (5, 6, 9), (5, 6, 9, 10), (5, 6, 9, 10, 11), (5, 6, 9, 11), (5, 7, 10), (5, 8), (5, 10, 11), (5, 8, 10), (5,), (4, 5, 8, 9, 11), (4, 5, 8, 9, 11), (4, 5, 8, 9), (4, 5, 8, 9), (4, 5, 9, 11)]\n", + " gold standard acq value: -4.438245\n", + " acq_val from get_candidate: -4.17621804103821\n", + " new y value: [-0.82884893]\n", + " experiments: 23 -> 24\n", + "\n", + "--- Iteration 4 ---\n", + " n_experiments: 24\n", + "[(2, 3, 4, 5, 9), (4, 5, 9), (5, 6, 8), (5, 6, 8, 11), (5, 6), (5, 6), (5, 10), (5, 6, 7, 8, 11), (5, 6), (5, 6, 7, 10, 11), (4, 5, 8, 9, 11), (4, 5, 8, 11), (4, 5, 8, 9, 11), (4, 5, 8, 9), (4, 5, 8, 9)]\n", + " gold standard acq value: -4.235442\n", + " acq_val from get_candidate: -4.235442451732192\n", + " new y value: [-1.33296541]\n", + " experiments: 24 -> 25\n", + "\n", + "--- Iteration 5 ---\n", + " n_experiments: 25\n", + "[(2, 4, 5), (2, 4, 5, 7, 8, 9), (2, 5, 9), (2, 5, 9, 10), (2, 5, 7, 9, 10), (2, 5), (2, 5, 7), (2, 5, 9, 11), (2, 5, 10), (2, 5, 11), (2, 4, 5), (4, 5, 8, 9, 11), (4, 5, 8, 11), (4, 5, 8, 9, 11), (4, 5, 8, 9)]\n", + " gold standard acq value: -3.922962\n", + " acq_val from get_candidate: -3.8834520182626724\n", + " new y value: [-1.0491368]\n", + " experiments: 25 -> 26\n", + "\n", + "--- Iteration 6 ---\n", + " n_experiments: 26\n", + "[(5, 6, 10), (5, 6, 10, 11), (5, 6), (5, 6, 7, 10), (5, 6, 11), (5, 6, 10), (5, 6, 8), (5, 6, 10), (5, 6, 11), (5, 6, 10), (2, 4, 5), (2, 4, 5, 7, 8, 9), (4, 5, 8, 9, 11), (4, 5, 8, 11), (4, 5, 8, 9, 11)]\n", + " gold standard acq value: -3.827973\n", + " acq_val from get_candidate: -3.8671186867803575\n", + " new y value: [-0.98997438]\n", + " experiments: 26 -> 27\n", + "\n", + "--- Iteration 7 ---\n", + " n_experiments: 27\n", + "[(2, 4, 5), (2, 4, 5, 10), (2, 4, 5, 11), (2, 4, 5, 7, 8, 10), (2, 4, 5, 9), (2, 4, 5), (2, 4, 5), (2, 4, 5), (2, 4, 5), (2, 4, 5, 10, 11), (2, 4, 5), (2, 4, 5, 7, 8, 9), (2, 4, 5), (4, 5, 
8, 9, 11), (4, 5, 8, 11)]\n", + " gold standard acq value: -2.048580\n", + " acq_val from get_candidate: -2.0485660249208246\n", + " new y value: [-1.43958713]\n", + " experiments: 27 -> 28\n", + "\n", + "--- Iteration 8 ---\n", + " n_experiments: 28\n", + "[(1, 2, 4, 5, 10, 11), (1, 2, 4, 5, 10), (1, 2, 4, 5), (1, 2, 4, 5, 9, 10), (1, 2, 4, 5, 6, 11), (1, 2, 4, 5, 10), (1, 2, 4, 5, 8), (1, 2, 4, 5, 9, 10), (1, 2, 4, 5, 7, 10), (1, 2, 4, 5, 7, 11), (2, 4, 5, 10), (2, 4, 5), (2, 4, 5, 7, 8, 9), (2, 4, 5), (4, 5, 8, 9, 11)]\n", + " gold standard acq value: -3.749658\n", + " acq_val from get_candidate: -3.724342042966142\n", + " new y value: [-0.14847842]\n", + " experiments: 28 -> 29\n", + "\n", + "--- Iteration 9 ---\n", + " n_experiments: 29\n", + "[(4, 5, 6, 7), (4, 5, 6, 7, 10), (4, 5, 6, 10), (4, 5, 6, 7, 11), (4, 5, 6, 11), (4, 5, 6, 7, 10), (4, 5, 6, 7), (4, 5, 6, 7), (4, 5, 6, 7, 10), (4, 5, 6, 7, 10), (2, 4, 5, 10), (2, 4, 5), (2, 4, 5, 7, 8, 9), (2, 4, 5), (4, 5, 8, 9, 11)]\n", + " gold standard acq value: -3.994669\n", + " acq_val from get_candidate: -3.8921717607761463\n", + " new y value: [-1.42693456]\n", + " experiments: 29 -> 30\n", + "\n", + "--- Iteration 10 ---\n", + " n_experiments: 30\n", + "[(2, 4, 5, 6, 11), (2, 4, 5, 6), (2, 4, 5, 6, 8), (2, 4, 5, 6, 10), (2, 4, 5, 6, 9, 11), (2, 4, 5, 6, 7), (2, 4, 5, 6), (2, 4, 5, 6), (2, 4, 5, 6, 11), (2, 4, 5, 6, 9), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5), (2, 4, 5, 7, 8, 9), (2, 4, 5)]\n", + " gold standard acq value: -4.086744\n", + " acq_val from get_candidate: -3.9347912207570523\n", + " new y value: [-1.44561938]\n", + " experiments: 30 -> 31\n", + "\n", + "--- Iteration 11 ---\n", + " n_experiments: 31\n", + "[(2, 4, 5, 11), (2, 4, 5, 10, 11), (2, 4, 5, 8, 10, 11), (2, 4, 5), (2, 4, 5, 6, 10, 11), (2, 4, 5), (2, 4, 5), (2, 4, 5), (2, 4, 5, 11), (2, 4, 5, 11), (2, 4, 5, 6), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5), (2, 4, 5, 7, 8, 9)]\n", + " gold standard acq value: -4.373888\n", + " acq_val from 
get_candidate: -4.111544118172005\n", + " new y value: [-1.44501041]\n", + " experiments: 31 -> 32\n", + "\n", + "--- Iteration 12 ---\n", + " n_experiments: 32\n", + "[(2, 4, 5, 11), (2, 4, 5, 10, 11), (2, 4, 5), (2, 4, 5, 10), (2, 4, 5, 6, 11), (2, 4, 5, 11), (2, 4, 5, 10), (2, 4, 5), (2, 4, 5, 8), (2, 4, 5), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5)]\n", + " gold standard acq value: -4.408571\n", + " acq_val from get_candidate: -4.251813218611304\n", + " new y value: [-1.33234524]\n", + " experiments: 32 -> 33\n", + "\n", + "--- Iteration 13 ---\n", + " n_experiments: 33\n", + "[(2, 4, 5, 6, 7), (4, 5, 7, 10), (2, 4, 8, 9, 11), (2, 4), (4, 5, 6, 7, 11), (4, 5, 10), (4, 5), (4, 8, 9, 11), (3, 6), (4, 5, 10, 11), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5)]\n", + " gold standard acq value: -4.424654\n", + " acq_val from get_candidate: -3.691387569031728\n", + " new y value: [-1.42545218]\n", + " experiments: 33 -> 34\n", + "\n", + "--- Iteration 14 ---\n", + " n_experiments: 34\n", + "[(2, 4, 5, 6, 7), (2, 4, 5, 6, 7, 11), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 6, 7, 8), (2, 4, 5, 6, 7, 11), (2, 4, 5, 6, 7, 11), (2, 4, 5, 6), (2, 4, 5, 6), (2, 4, 5, 6, 7, 8), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5, 6, 7)]\n", + " gold standard acq value: -4.771351\n", + " acq_val from get_candidate: -4.030449971881309\n", + " new y value: [-1.40162746]\n", + " experiments: 34 -> 35\n", + "\n", + "--- Iteration 15 ---\n", + " n_experiments: 35\n", + "[(2, 4), (3, 4, 6, 10), (3, 4, 11), (3, 4, 5), (3, 4, 10, 11), (3, 5, 10), (3, 5, 11), (3, 5), (3, 5), (3, 5, 7), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 10), (2, 4, 5, 6, 7)]\n", + " gold standard acq value: -4.871003\n", + " acq_val from get_candidate: -3.99540626615538\n", + " new y value: [-1.43894863]\n", + " experiments: 35 -> 36\n", + "\n", + "--- Iteration 16 ---\n", + " n_experiments: 36\n", + "[(2, 
3, 4, 5, 7, 10), (2, 4, 5, 6, 9), (2, 4, 5, 6, 9, 10), (2, 4, 5, 6, 8), (2, 4, 5, 6), (2, 4, 7, 8, 11), (2, 4, 7, 10, 11), (2, 4, 11), (2, 4, 10), (2, 4, 10, 11), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 6, 10), (2, 4, 5, 10)]\n", + " gold standard acq value: -4.867408\n", + " acq_val from get_candidate: -4.835264900422188\n", + " new y value: [-2.16610811]\n", + " experiments: 36 -> 37\n", + "\n", + "--- Iteration 17 ---\n", + " n_experiments: 37\n", + "[(4, 5, 10), (4, 5, 6, 10), (4, 5, 6, 10, 11), (4, 5, 6, 9, 10), (4, 5, 6), (4, 5, 10), (4, 5, 10), (4, 5, 10), (4, 5, 10), (4, 5, 10), (2, 3, 4, 5, 10), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10), (2, 4, 5, 6, 10)]\n", + " gold standard acq value: -3.894098\n", + " acq_val from get_candidate: -3.894098176330702\n", + " new y value: [-2.32923849]\n", + " experiments: 37 -> 38\n", + "\n", + "--- Iteration 18 ---\n", + " n_experiments: 38\n", + "[(3, 5), (3, 5, 11), (3, 5, 9), (3, 5, 7), (3, 5, 10), (3, 5, 11), (3, 5, 10), (3, 5, 10), (3, 5, 7, 8, 11), (3, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11), (2, 4, 5, 10)]\n", + " gold standard acq value: -3.081678\n", + " acq_val from get_candidate: -3.0816784679911864\n", + " new y value: [-2.48386628]\n", + " experiments: 38 -> 39\n", + "\n", + "--- Iteration 19 ---\n", + " n_experiments: 39\n", + "[(1, 3, 4, 5), (1, 3, 4, 5, 6, 9), (1, 3, 4, 5, 8, 11), (1, 3, 4, 5, 6, 10), (1, 3, 5, 9, 10), (1, 3, 5, 10), (1, 3, 5), (1, 3, 5), (1, 3, 5, 9, 10, 11), (1, 3, 5), (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 4, 5, 6), (2, 4, 5, 6, 10, 11)]\n", + " gold standard acq value: -3.445724\n", + " acq_val from get_candidate: -3.4345690538198923\n", + " new y value: [-2.54738421]\n", + " experiments: 39 -> 40\n", + "\n", + "--- Iteration 20 ---\n", + " n_experiments: 40\n", + "[(3, 4, 5, 8, 11), (3, 4, 5, 8, 10, 11), (3, 4, 5, 8, 9, 11), (3, 4, 5, 8, 9), (3, 4, 5, 8), (3, 4, 5, 8, 11), (3, 4, 5, 8, 9, 11), (3, 4, 5, 8, 
10), (3, 4, 5, 8), (3, 4, 5, 8, 11), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 4, 5, 6)]\n", + " gold standard acq value: -3.722040\n", + " acq_val from get_candidate: -3.716969376047603\n", + " new y value: [-2.16046796]\n", + " experiments: 40 -> 41\n", + "\n", + "--- Iteration 21 ---\n", + " n_experiments: 41\n", + "[(3, 4, 5, 8, 9, 10), (3, 4, 5), (3, 4, 5, 6), (3, 4, 5, 9), (3, 5, 6, 10), (3, 5, 8, 11), (3, 5, 7, 11), (3, 5, 6), (3, 5, 9), (3, 5, 7), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -2.133197\n", + " acq_val from get_candidate: -2.1139123274295084\n", + " new y value: [-2.56346595]\n", + " experiments: 41 -> 42\n", + "\n", + "--- Iteration 22 ---\n", + " n_experiments: 42\n", + "[(3, 5, 6, 8, 10), (3, 5, 6, 10), (3, 5, 6, 10, 11), (3, 5, 10), (3, 5, 6, 7, 10, 11), (3, 5, 10), (3, 5, 10, 11), (3, 5, 10), (3, 5, 10, 11), (3, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -3.666280\n", + " acq_val from get_candidate: -3.6662801937855947\n", + " new y value: [-2.55144461]\n", + " experiments: 42 -> 43\n", + "\n", + "--- Iteration 23 ---\n", + " n_experiments: 43\n", + "[(3, 4, 5, 10), (3, 4, 5, 7, 8, 10), (3, 4, 5), (3, 4, 5, 11), (3, 4, 5, 8, 11), (3, 5, 10, 11), (3, 5, 8), (3, 5, 9), (3, 5, 7, 8), (3, 5, 6, 8, 9, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5)]\n", + " gold standard acq value: -3.957638\n", + " acq_val from get_candidate: -4.144677988377689\n", + " new y value: [-2.55796531]\n", + " experiments: 43 -> 44\n", + "\n", + "--- Iteration 24 ---\n", + " n_experiments: 44\n", + "[(3, 4, 5, 6, 10), (3, 4, 5, 6, 10, 11), (3, 4, 5, 6), (3, 4, 5, 6, 9), (3, 4, 5, 7, 10, 11), (3, 4, 5, 9), (3, 4, 5, 10, 11), (3, 4, 5, 6), (3, 4, 5), (3, 4, 5, 10, 11), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " 
gold standard acq value: -3.970401\n", + " acq_val from get_candidate: -4.366384295807882\n", + " new y value: [-2.5421676]\n", + " experiments: 44 -> 45\n", + "\n", + "--- Iteration 25 ---\n", + " n_experiments: 45\n", + "[(2, 8, 10), (2, 8, 10, 11), (2, 8, 9, 10), (2, 7, 8, 10, 11), (2, 3, 8, 10, 11), (2, 8, 10, 11), (2, 8, 10, 11), (2, 8, 10), (2, 8, 10), (2, 8), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -3.370748\n", + " acq_val from get_candidate: -2.1903766575845784\n", + " new y value: [-0.00056046]\n", + " experiments: 45 -> 46\n", + "\n", + "--- Iteration 26 ---\n", + " n_experiments: 46\n", + "[(3, 5, 6, 7, 11), (3, 5, 6, 7, 8, 11), (3, 5, 6, 8, 9), (3, 5, 6, 9, 11), (3, 5, 6, 9, 10), (3, 5, 10, 11), (3, 5, 9, 10), (3, 5, 6, 8, 10), (3, 5, 8, 10, 11), (3, 5, 8, 9), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -4.468404\n", + " acq_val from get_candidate: -5.348998922143423\n", + " new y value: [-2.61948882]\n", + " experiments: 46 -> 47\n", + "\n", + "--- Iteration 27 ---\n", + " n_experiments: 47\n", + "[(3, 4, 5, 9, 10), (3, 5, 6, 10), (3, 5, 6, 10, 11), (3, 5, 10, 11), (3, 5, 10), (3, 5, 6, 7, 11), (3, 5, 11), (3, 5, 6, 10, 11), (3, 5, 6, 7, 9, 11), (3, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -4.176373\n", + " acq_val from get_candidate: -6.3884062874851235\n", + " new y value: [-2.56520316]\n", + " experiments: 47 -> 48\n", + "\n", + "--- Iteration 28 ---\n", + " n_experiments: 48\n", + "[(3, 5, 7, 10, 11), (3, 5, 9, 10), (3, 5, 10, 11), (3, 5, 10), (3, 5, 7, 8, 9, 10), (3, 5), (3, 5, 10, 11), (3, 5), (3, 5), (3, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -3.736522\n", + " acq_val from get_candidate: -5.7587851443915845\n", + " new y value: 
[-2.61834514]\n", + " experiments: 48 -> 49\n", + "\n", + "--- Iteration 29 ---\n", + " n_experiments: 49\n", + "[(3, 4, 5, 6, 8, 9), (3, 4, 5, 8), (3, 4, 5, 7, 8, 10), (3, 4, 5, 7, 8, 9), (3, 4, 5, 8, 10, 11), (3, 11), (3, 5, 6, 7, 9, 11), (3, 8), (3,), (5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -4.019669\n", + " acq_val from get_candidate: -3.332510200235144\n", + " new y value: [-1.03748345]\n", + " experiments: 49 -> 50\n", + "\n", + "--- Iteration 30 ---\n", + " n_experiments: 50\n", + "[(1, 3, 4, 5, 6), (3, 5, 6, 7, 11), (3, 5, 6, 9, 11), (3, 5, 6), (3, 5, 6, 10), (3, 5, 6), (3, 5, 6), (3, 5, 9), (3, 5, 9, 10, 11), (3, 5, 9, 11), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -4.438632\n", + " acq_val from get_candidate: -5.70489629875731\n", + " new y value: [-2.61545509]\n", + " experiments: 50 -> 51\n", + "\n", + "--- Iteration 31 ---\n", + " n_experiments: 51\n", + "[(3, 4, 5, 6, 7, 10), (3, 4, 5, 6, 7), (3, 4, 5, 6, 7, 11), (3, 4, 5, 6), (3, 4, 5, 6, 10), (3, 4, 5, 9), (3, 4, 5, 6, 7), (3, 4, 5, 7, 8, 10), (3, 4, 5, 6, 11), (3, 4, 5, 6, 7, 9), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -3.790677\n", + " acq_val from get_candidate: -3.9662806555266967\n", + " new y value: [-1.90782299]\n", + " experiments: 51 -> 52\n", + "\n", + "--- Iteration 32 ---\n", + " n_experiments: 52\n", + "[(3, 5, 9, 10, 11), (3, 5, 9, 10), (3, 5, 6, 9, 10, 11), (3, 5, 9, 11), (3, 5, 9), (3, 5), (3, 5), (3, 5, 9, 11), (3, 5, 9, 10, 11), (3, 5, 10, 11), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -4.569177\n", + " acq_val from get_candidate: -6.177507776964656\n", + " new y value: [-2.62321976]\n", + " experiments: 52 -> 53\n", + "\n", + "--- Iteration 33 ---\n", + " 
n_experiments: 53\n", + "[(2, 5, 7), (2, 5, 7, 10), (2, 5, 10, 11), (2, 5, 10), (2, 5, 7, 11), (2, 5, 11), (2, 5, 11), (2, 5, 6, 10, 11), (2, 5, 7, 10, 11), (2, 5, 10, 11), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -1.628597\n", + " acq_val from get_candidate: -1.5900472148349385\n", + " new y value: [-0.30895601]\n", + " experiments: 53 -> 54\n", + "\n", + "--- Iteration 34 ---\n", + " n_experiments: 54\n", + "[(3, 4, 5, 11), (3, 4, 5, 6, 11), (3, 4, 5, 10, 11), (3, 4, 5, 6, 10, 11), (3, 4, 5, 8, 11), (3, 4, 5, 11), (3, 4, 5, 11), (3, 4, 5, 11), (3, 4, 5, 10), (3, 4, 5, 10, 11), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -2.306275\n", + " acq_val from get_candidate: -2.306274735972209\n", + " new y value: [-1.25219711]\n", + " experiments: 54 -> 55\n", + "\n", + "--- Iteration 35 ---\n", + " n_experiments: 55\n", + "[(3, 5, 8, 9, 11), (3, 5, 9, 10), (3, 5, 7, 9, 11), (3, 5, 6, 8, 9, 11), (3, 5, 9, 11), (3, 5, 10), (3, 5), (3, 5, 9, 11), (3, 5, 9, 10, 11), (3, 5), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 10)]\n", + " gold standard acq value: -4.307754\n", + " acq_val from get_candidate: -6.272325352698689\n", + " new y value: [-2.60693127]\n", + " experiments: 55 -> 56\n", + "\n", + "--- Iteration 36 ---\n", + " n_experiments: 56\n", + "[(3, 5, 6, 7, 10, 11), (3, 5, 6, 7), (3, 5, 6, 7, 9, 11), (3, 5, 6, 7, 11), (3, 5, 7, 9), (3, 5, 9), (3, 5, 8), (3, 5, 7, 10, 11), (3, 5), (3, 5, 8, 11), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -6.690115\n", + " acq_val from get_candidate: -6.675801398697926\n", + " new y value: [-0.00155831]\n", + " experiments: 56 -> 57\n", + "\n", + "--- Iteration 37 ---\n", + " n_experiments: 57\n", + "[(3, 4, 5, 8), (3, 4, 5), (3, 4, 5, 9, 11), (3, 4, 5, 6, 9, 10), (3, 5, 8), (3, 5, 9, 11), (3, 5, 
6, 9, 11), (3, 5, 6, 9, 10), (3, 5, 6, 7, 9, 10), (3, 5, 11), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " gold standard acq value: -4.322658\n", + " acq_val from get_candidate: -7.44715492522383\n", + " new y value: [-0.3989249]\n", + " experiments: 57 -> 58\n", + "\n", + "--- Iteration 38 ---\n", + " n_experiments: 58\n", + "[(2, 3, 4), (2, 4, 6, 10, 11), (2, 4, 9, 11), (2, 4, 9), (2, 4, 9, 10), (2, 4), (2, 4), (2, 4, 6), (2, 4, 9, 10, 11), (2, 4, 9), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -3.648986\n", + " acq_val from get_candidate: -7.497286475099898\n", + " new y value: [-0.0001444]\n", + " experiments: 58 -> 59\n", + "\n", + "--- Iteration 39 ---\n", + " n_experiments: 59\n", + "[(3, 4, 5, 9), (3, 4, 5), (3, 4, 5, 7, 11), (3, 4, 5, 6, 7, 8), (3, 4, 5, 6, 7), (3, 4, 5, 7, 10, 11), (3, 4, 5, 9), (3, 4, 5, 9, 10), (3, 4, 5, 6, 8), (3, 4, 5, 9), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5)]\n", + " gold standard acq value: -3.680632\n", + " acq_val from get_candidate: -6.99578370217342\n", + " new y value: [-2.62568186]\n", + " experiments: 59 -> 60\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "def get_candidate():\n", + " acqf = strategy._get_acqfs(n=1)[0]\n", + " groups = Groups(\n", + " groups=[\n", 
+ " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + " )\n", + " bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + " )\n", + "\n", + " def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + " ) -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=256, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + " tracker = _SelectionTracker(inner_fn=reward_fn)\n", + "\n", + " mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=tracker,\n", + " use_cache=False,\n", + " rollout_mode=\"uniform_subset\",\n", + " )\n", + " mcts.run(n_iterations=1500)\n", + " best_valuue = -np.inf\n", + " best_candidate = None\n", + " # here we should gather candidates based on the top eval in the MCTS search\n", + " # and we should always re-evaluate the best performing selections from the current set\n", + " # of experiments to have proper exploitation in the search\n", + " # (currently we only have exploration through the MCTS search)\n", + " subsets_to_optimize = [i[0] for i in tracker.top_k(k=5)]\n", + " for _ in range(5):\n", + " leaf, _ = mcts._select_and_expand()\n", + " selected_features, cat_selections = mcts._get_selection(leaf)\n", + " subsets_to_optimize.append(selected_features)\n", + " # also add the top 5 performing subsets from the current 
experiments\n", + " top_experiments = strategy.experiments.nsmallest(5, \"y\")\n", + " top_subsets = [\n", + " tuple(\n", + " i\n", + " for i, val in enumerate(row[benchmark.domain.inputs.get_keys()].to_numpy())\n", + " if val > 0.0\n", + " )\n", + " for _, row in top_experiments.iterrows()\n", + " ]\n", + " subsets_to_optimize.extend(top_subsets)\n", + " # now see if there are duplicates in subsets_to_optimize and remove them\n", + " print(subsets_to_optimize)\n", + " subsets_to_optimize = list(set(subsets_to_optimize))\n", + "\n", + " for selected_features in subsets_to_optimize:\n", + " # leaf, path = mcts._select_and_expand()\n", + " # selected_features, cat_selections = mcts._get_selection(leaf)\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " if acq_value > best_valuue:\n", + " best_valuue = acq_value\n", + " best_candidate = candidates\n", + " _, gold_acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features={6: 0.0, 7: 0.0, 8: 0.0, 9: 0.0, 10: 0.0, 11: 0.0},\n", + " )\n", + " print(f\" gold standard acq value: {gold_acq_value:.6f}\")\n", + " return strategy.acqf_optimizer._candidates_tensor_to_dataframe(\n", + " best_candidate, strategy.domain\n", + " ), best_valuue\n", + "\n", + "\n", + "# for _ in range(5):\n", + "# candidate, acq_val = get_candidate()\n", + "# new_experiments = benchmark.f(candidate, return_complete=True)\n", + "# strategy.tell(new_experiments,replace=False)\n", + "# print(acqf_val.item())\n", + "# print(strategy.experiments.y.min())\n", + "\n", + "for iteration in range(40):\n", + " acqf_check = strategy._get_acqfs(n=1)[0]\n", + 
"\n", + " # Probe: evaluate acqf at a fixed test point to see if model actually changed\n", + " # test_point = torch.zeros(1, 1, bounds.shape[1], **tkwargs)\n", + " # test_point[..., 0] = 0.5\n", + " # probe_val = acqf_check(test_point).item()\n", + "\n", + " print(f\"\\n--- Iteration {iteration} ---\")\n", + " print(f\" n_experiments: {len(strategy.experiments)}\")\n", + " # print(f\" acqf probe at fixed point: {probe_val:.10f}\")\n", + "\n", + " candidate, acq_val = get_candidate()\n", + " print(f\" acq_val from get_candidate: {acq_val}\")\n", + "\n", + " new_experiments = benchmark.f(candidate, return_complete=True)\n", + " print(f\" new y value: {new_experiments['y'].values}\")\n", + "\n", + " n_before = len(strategy.experiments)\n", + " strategy.tell(new_experiments, replace=False)\n", + " n_after = len(strategy.experiments)\n", + " print(f\" experiments: {n_before} -> {n_after}\")\n", + "\n", + " # Check if model params actually changed\n", + " # ls = strategy.model.covar_module.lengthscale\n", + " # print(f\" lengthscale[0:3]: {ls[0, :3].tolist()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
400.00.00.0000000.0000000.3336740.6768680.00.0000000.0000000.0000000.0000001.000000-0.822897True
410.00.00.0000000.2906560.0000000.6630810.00.0000000.0000000.0000001.0000001.000000-0.395555True
420.00.00.0000000.0000000.3107610.6626520.00.0000000.0000000.0000000.0000001.000000-0.835020True
430.00.00.0000000.3019600.3326450.6671200.00.0000000.0000000.0000000.0000001.000000-1.886002True
440.00.00.0000000.3057920.3191660.6715580.00.0000000.0000000.0000001.0000000.000000-1.901481True
450.00.00.0000000.3119330.3250770.6780670.00.0000001.0000000.0000001.0000000.000000-1.888116True
460.00.00.0000000.3035960.3198840.6694800.01.0000000.0000000.0000001.0000001.000000-1.902525True
470.00.00.0000000.0000000.3132790.6572320.01.0000000.0000000.0000001.0000001.000000-0.834159True
480.00.00.0000000.0000000.4540240.7599200.00.0000000.5833470.0000000.9605270.351945-0.533118True
490.00.00.0000000.3254260.3323940.6991700.01.0000000.0000000.0000000.0000000.000000-1.846701True
500.00.00.0000000.3596750.3876700.7058920.00.4670630.0000000.0000000.5103340.436983-1.597361True
510.00.00.0000000.2833090.3075270.6595140.00.0000000.0000000.0000001.0000000.000000-1.913625True
520.00.00.0000000.2830220.3089380.6532150.00.0000000.0000000.0000000.0000000.000000-1.911601True
530.00.00.0000000.0000000.5758110.0000000.00.0000000.0000000.0000001.0000000.000000-0.007180True
540.00.00.0000001.0000001.0000000.4891370.00.0000000.0000000.0000001.0000000.000000-0.000120True
550.00.00.0000001.0000000.0000000.0000000.00.0000000.0000001.0000001.0000000.000000-0.000135True
560.00.00.0000000.5712981.0000000.0000001.00.0000001.0000000.0000000.0000000.000000-0.000423True
570.00.01.0000000.0000001.0000001.0000000.00.0000000.0000000.0000001.0000001.000000-0.006990True
580.00.00.0000000.2829400.3066950.6654810.00.0000000.0000001.0000000.0000001.000000-1.914125True
590.00.00.0000000.2820920.3089130.6596561.01.0000000.0000000.0000000.0000001.000000-1.913080True
600.00.00.0000001.0000000.0000001.0000000.01.0000000.0000000.0000000.0000001.000000-0.001027True
610.01.00.0000000.0000000.3049800.6635520.00.0000001.0000000.0000001.0000000.000000-0.069826True
620.00.00.4837510.2870300.3112350.6553750.00.0000000.0000001.0000000.4852750.000000-2.625614True
630.00.00.7078990.2843010.3074830.6627560.01.0000000.0000000.0000000.0000000.000000-2.172833True
640.00.00.8877870.0000000.3274300.7227590.00.0000000.0000000.8757150.0000000.116789-0.723463True
650.01.00.0000000.5898240.0000000.0000000.00.0000000.0000001.0000000.0000000.000000-0.166836True
660.00.00.6774490.0000000.0000000.7112041.00.0000000.0000001.0000000.0000000.000000-0.382246True
670.00.00.0000000.2974560.3130590.6420050.00.0000000.0000001.0000000.6358540.000000-1.902891True
680.00.00.4374380.0000000.3159860.6345200.00.0000001.0000000.0000000.0000001.000000-1.389207True
690.00.00.4279530.0000000.3133400.6521721.00.1856490.0000000.0000000.0000000.000000-1.376775True
700.00.00.0000000.0000000.3071720.6617710.00.0000001.0000001.0000000.0000000.000000-0.835408True
710.00.00.4387840.0000000.3175070.6197280.00.0000000.0000000.0000001.0000001.000000-1.385606True
720.00.00.4025440.3116270.3064980.6512230.00.8865260.0000000.0000000.0000001.000000-2.566385True
730.00.00.4194690.2519900.3242230.6472120.00.0000000.0000000.0000001.0000000.000000-2.583365True
740.00.00.6256170.0000000.2723050.4894700.00.0000000.6101500.0000000.9260640.000000-1.087036True
750.00.00.7512170.0000000.6833750.4044430.00.0000000.0000000.0000000.0000000.000000-0.338749True
760.00.01.0000001.0000000.0000000.5332760.00.0000000.0000001.0000000.0000000.000000-0.007243True
770.00.00.4359850.2888460.3115910.6914390.00.0000001.0000001.0000000.0000000.000000-2.591009True
780.00.00.4243290.4090350.3438970.7186910.00.0000000.5095360.6181830.0000000.000000-2.122833True
790.00.01.0000001.0000000.9290300.0000000.00.0000000.0000001.0000001.0000000.000000-0.000114True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "40 0.0 0.0 0.000000 0.000000 0.333674 0.676868 0.0 \n", + "41 0.0 0.0 0.000000 0.290656 0.000000 0.663081 0.0 \n", + "42 0.0 0.0 0.000000 0.000000 0.310761 0.662652 0.0 \n", + "43 0.0 0.0 0.000000 0.301960 0.332645 0.667120 0.0 \n", + "44 0.0 0.0 0.000000 0.305792 0.319166 0.671558 0.0 \n", + "45 0.0 0.0 0.000000 0.311933 0.325077 0.678067 0.0 \n", + "46 0.0 0.0 0.000000 0.303596 0.319884 0.669480 0.0 \n", + "47 0.0 0.0 0.000000 0.000000 0.313279 0.657232 0.0 \n", + "48 0.0 0.0 0.000000 0.000000 0.454024 0.759920 0.0 \n", + "49 0.0 0.0 0.000000 0.325426 0.332394 0.699170 0.0 \n", + "50 0.0 0.0 0.000000 0.359675 0.387670 0.705892 0.0 \n", + "51 0.0 0.0 0.000000 0.283309 0.307527 0.659514 0.0 \n", + "52 0.0 0.0 0.000000 0.283022 0.308938 0.653215 0.0 \n", + "53 0.0 0.0 0.000000 0.000000 0.575811 0.000000 0.0 \n", + "54 0.0 0.0 0.000000 1.000000 1.000000 0.489137 0.0 \n", + "55 0.0 0.0 0.000000 1.000000 0.000000 0.000000 0.0 \n", + "56 0.0 0.0 0.000000 0.571298 1.000000 0.000000 1.0 \n", + "57 0.0 0.0 1.000000 0.000000 1.000000 1.000000 0.0 \n", + "58 0.0 0.0 0.000000 0.282940 0.306695 0.665481 0.0 \n", + "59 0.0 0.0 0.000000 0.282092 0.308913 0.659656 1.0 \n", + "60 0.0 0.0 0.000000 1.000000 0.000000 1.000000 0.0 \n", + "61 0.0 1.0 0.000000 0.000000 0.304980 0.663552 0.0 \n", + "62 0.0 0.0 0.483751 0.287030 0.311235 0.655375 0.0 \n", + "63 0.0 0.0 0.707899 0.284301 0.307483 0.662756 0.0 \n", + "64 0.0 0.0 0.887787 0.000000 0.327430 0.722759 0.0 \n", + "65 0.0 1.0 0.000000 0.589824 0.000000 0.000000 0.0 \n", + "66 0.0 0.0 0.677449 0.000000 0.000000 0.711204 1.0 \n", + "67 0.0 0.0 0.000000 0.297456 0.313059 0.642005 0.0 \n", + "68 0.0 0.0 0.437438 0.000000 0.315986 0.634520 0.0 \n", + "69 0.0 0.0 0.427953 0.000000 0.313340 0.652172 1.0 \n", + "70 0.0 0.0 0.000000 0.000000 0.307172 0.661771 0.0 \n", + "71 0.0 0.0 0.438784 0.000000 0.317507 0.619728 0.0 \n", + "72 0.0 0.0 0.402544 0.311627 
0.306498 0.651223 0.0 \n", + "73 0.0 0.0 0.419469 0.251990 0.324223 0.647212 0.0 \n", + "74 0.0 0.0 0.625617 0.000000 0.272305 0.489470 0.0 \n", + "75 0.0 0.0 0.751217 0.000000 0.683375 0.404443 0.0 \n", + "76 0.0 0.0 1.000000 1.000000 0.000000 0.533276 0.0 \n", + "77 0.0 0.0 0.435985 0.288846 0.311591 0.691439 0.0 \n", + "78 0.0 0.0 0.424329 0.409035 0.343897 0.718691 0.0 \n", + "79 0.0 0.0 1.000000 1.000000 0.929030 0.000000 0.0 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "40 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "41 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "42 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "43 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "44 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "45 0.000000 1.000000 0.000000 1.000000 0.000000 \n", + "46 1.000000 0.000000 0.000000 1.000000 1.000000 \n", + "47 1.000000 0.000000 0.000000 1.000000 1.000000 \n", + "48 0.000000 0.583347 0.000000 0.960527 0.351945 \n", + "49 1.000000 0.000000 0.000000 0.000000 0.000000 \n", + "50 0.467063 0.000000 0.000000 0.510334 0.436983 \n", + "51 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "52 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "53 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "54 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "55 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "56 0.000000 1.000000 0.000000 0.000000 0.000000 \n", + "57 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "58 0.000000 0.000000 1.000000 0.000000 1.000000 \n", + "59 1.000000 0.000000 0.000000 0.000000 1.000000 \n", + "60 1.000000 0.000000 0.000000 0.000000 1.000000 \n", + "61 0.000000 1.000000 0.000000 1.000000 0.000000 \n", + "62 0.000000 0.000000 1.000000 0.485275 0.000000 \n", + "63 1.000000 0.000000 0.000000 0.000000 0.000000 \n", + "64 0.000000 0.000000 0.875715 0.000000 0.116789 \n", + "65 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "66 0.000000 
0.000000 1.000000 0.000000 0.000000 \n", + "67 0.000000 0.000000 1.000000 0.635854 0.000000 \n", + "68 0.000000 1.000000 0.000000 0.000000 1.000000 \n", + "69 0.185649 0.000000 0.000000 0.000000 0.000000 \n", + "70 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "71 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "72 0.886526 0.000000 0.000000 0.000000 1.000000 \n", + "73 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "74 0.000000 0.610150 0.000000 0.926064 0.000000 \n", + "75 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "76 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "77 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "78 0.000000 0.509536 0.618183 0.000000 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "\n", + " y valid_y \n", + "40 -0.822897 True \n", + "41 -0.395555 True \n", + "42 -0.835020 True \n", + "43 -1.886002 True \n", + "44 -1.901481 True \n", + "45 -1.888116 True \n", + "46 -1.902525 True \n", + "47 -0.834159 True \n", + "48 -0.533118 True \n", + "49 -1.846701 True \n", + "50 -1.597361 True \n", + "51 -1.913625 True \n", + "52 -1.911601 True \n", + "53 -0.007180 True \n", + "54 -0.000120 True \n", + "55 -0.000135 True \n", + "56 -0.000423 True \n", + "57 -0.006990 True \n", + "58 -1.914125 True \n", + "59 -1.913080 True \n", + "60 -0.001027 True \n", + "61 -0.069826 True \n", + "62 -2.625614 True \n", + "63 -2.172833 True \n", + "64 -0.723463 True \n", + "65 -0.166836 True \n", + "66 -0.382246 True \n", + "67 -1.902891 True \n", + "68 -1.389207 True \n", + "69 -1.376775 True \n", + "70 -0.835408 True \n", + "71 -1.385606 True \n", + "72 -2.566385 True \n", + "73 -2.583365 True \n", + "74 -1.087036 True \n", + "75 -0.338749 True \n", + "76 -0.007243 True \n", + "77 -2.591009 True \n", + "78 -2.122833 True \n", + "79 -0.000114 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.tail(40)" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
50.9716560.3410370.2473110.1170810.0000000.0000000.1014650.0000000.0000000.6124490.0000000.000000-0.002416True
60.0097270.3187120.0000000.0000000.5629500.5014310.0000000.0000000.1304470.7760440.0000000.000000-0.216006True
70.8733560.2732550.0000000.0000000.3955370.0000000.0000000.9134770.4222420.5581640.0000000.000000-0.007566True
80.0000000.0000000.0000000.9213050.4787820.5553820.8431830.0000000.6743220.0000000.8404980.000000-0.019061True
90.0000000.8613230.0000000.0000000.2463090.0000000.3676990.5743210.0000000.7715720.2532700.000000-0.010974True
100.0000000.8087130.0000000.0000000.7624750.5889760.0000000.0000000.0000001.0000000.0000000.000000-0.006552True
110.0000000.2696310.0000000.0000000.6788920.6843680.0000000.0000000.1327500.9054440.0000000.000000-0.082110True
120.0204250.3451080.0000000.0000000.4848080.3733500.0000000.0000000.1315800.6688370.0000000.000000-0.236332True
130.0000000.0000000.0000000.0000000.4844740.6979050.0000000.0000000.1016501.0000000.0000000.000000-0.479774True
140.0000000.0000000.0000000.0000000.4369920.8725960.0000000.0000000.0443161.0000000.0000000.000000-0.439923True
150.0000000.0000000.0000000.0000000.4688540.5285130.0000000.0000000.4262491.0000000.0000000.000000-0.457100True
160.0000000.0000000.0000000.0000000.4914700.5987830.0000000.0000000.0000001.0000000.0000000.780327-0.447853True
170.8217960.0000000.0000000.0000000.4928830.8370130.0000000.0000000.0158141.0000000.0000000.000000-0.151161True
180.0000000.0000000.0000000.0000000.4878850.8149280.0000000.0000000.0000001.0000000.5601960.000000-0.395607True
190.0000000.0000000.0000000.0000000.4761290.0000000.0000000.7333010.0000001.0000000.0000000.000000-0.014914True
200.0000000.0000000.0000000.0000000.4746010.0332510.0000000.0000000.0000001.0000000.0000000.000000-0.021225True
210.0000000.0000000.0000000.0000000.4013530.6625380.0000000.0000000.4169571.0000000.0000000.098667-0.713574True
220.0000000.0000000.0000000.0000000.2934260.6800680.0000000.0000000.6178441.0000000.0000000.499694-0.832143True
230.0000000.0000000.0000000.0000000.2913920.6897310.0000000.0000000.9863780.0000000.0000001.000000-0.828849True
240.0000000.0000000.5871690.0000000.2613150.6789010.0000000.0000000.0000000.0000000.0000000.000000-1.332965True
250.0000000.0000000.7442720.0000000.2887540.7117160.0000001.0000001.0000001.0000000.0000000.000000-1.049137True
260.0000000.0000000.6480770.0000000.1817240.7249170.0000000.0000000.0000000.0000000.0000000.000000-0.989974True
270.0000000.0000000.4851910.0000000.3299620.6448610.0000000.0000000.0000000.0000000.0078200.000000-1.439587True
280.0000001.0000000.5236940.0000000.3399290.6366950.0000000.0000000.0000000.0000001.0000000.000000-0.148478True
290.0000000.0000000.5246590.0000000.3244050.5925300.0000000.0000000.0000000.0000001.0000000.000000-1.426935True
300.0000000.0000000.5294890.0000000.3283600.6140620.4269900.0000000.0000000.0000000.0000000.000000-1.445619True
310.0000000.0000000.5163510.0000000.3391410.6655451.0000000.0000000.0000000.0000001.0000001.000000-1.445010True
320.0000000.0000000.6209970.0000000.3795740.6384070.0000000.0000000.0000000.0000000.0000001.000000-1.332345True
330.0000000.0000000.5739540.0000000.3132700.6565031.0000001.0000000.0000000.0000000.0000000.000000-1.425452True
340.0000000.0000000.4604770.0000000.3091820.6107771.0000001.0000000.0000000.0000000.0000001.000000-1.401627True
350.0000000.0000000.4856750.0000000.3265430.6525751.0000000.0000000.0000000.0000001.0000000.000000-1.438949True
360.0000000.0000000.5431990.4133690.3128820.6396070.0000000.0000000.0000000.0000001.0000000.000000-2.166108True
370.0000000.0000000.4933790.3802040.3181340.7058010.0000000.0000000.0000000.0000000.0000000.000000-2.329238True
380.0000000.0000000.4304130.3201680.3262660.7199380.0000000.0000000.0000000.0000000.0000000.000000-2.483866True
390.0000000.0000000.4816220.2769430.3538130.6832170.0000000.0000000.0000000.0000001.0000000.000000-2.547384True
400.0000000.0000000.4343250.3111700.4180120.6917580.0000000.0000000.0000000.0000001.0000000.000000-2.160468True
410.0000000.0000000.5353650.2637380.2880120.6783220.0000000.0000000.0000000.0000000.7931570.000000-2.563466True
420.0000000.0000000.4504500.2667890.3025150.5987380.0000000.0000000.0000000.0000000.0000000.000000-2.551445True
430.0000000.0000000.4350240.2675920.2938980.7048340.0000000.0000000.0000000.0000001.0000000.000000-2.557965True
440.0000000.0000000.5497630.2618290.3137660.6987830.0000000.0000000.0000000.0000000.0000000.000000-2.542168True
450.0000000.0000000.4521521.0000000.0000000.0000000.0000000.0000001.0000000.0000000.8875270.000000-0.000560True
460.0000000.0000000.5002440.2662290.3088370.6429600.0000000.0000000.0000000.0000001.0000000.000000-2.619489True
470.0000000.0000000.4417800.2729270.3013790.6065120.0000000.0000000.0000000.0000001.0000000.000000-2.565203True
480.0000000.0000000.4835300.2657050.2983180.6643340.0000000.0000000.0000000.0000000.0000000.000000-2.618345True
490.0000000.0000000.4884220.2626480.2938821.0000000.0000000.0000000.0000000.0000000.9685850.000000-1.037483True
500.0000000.0000000.4388330.2663450.3040850.6568010.0000000.0000000.0000000.0000001.0000000.000000-2.615455True
510.0000000.0000000.0000000.2982610.2936920.6565961.0000001.0000000.0000000.0000000.5671050.000000-1.907823True
520.0000000.0000000.4601500.2615620.3115950.6538530.0000000.0000000.0000000.0000000.0000000.000000-2.623220True
530.0000000.0000000.3420740.0000000.0000000.7050800.0000000.1256550.0000000.0000000.2514940.000000-0.308956True
540.0000000.0000000.8781220.2717260.3239540.5022000.0000000.0000000.0000000.0000000.0000000.000000-1.252197True
550.0000000.0000000.4389060.2732560.2931240.6574040.0000000.0000000.0000000.0000000.0000000.000000-2.606931True
560.0000000.0000000.0000001.0000000.0000000.9103761.0000001.0000000.0000001.0000000.0000001.000000-0.001558True
570.0000000.0000000.0000000.5399200.0702730.6507110.0000000.0000001.0000000.0000000.0000000.000000-0.398925True
580.0000000.0000000.9666841.0000001.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.000144True
590.0000000.0000000.4634530.2644480.3101840.6569270.0000000.0000000.0000000.0000001.0000000.000000-2.625682True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + "5 0.971656 0.341037 0.247311 0.117081 0.000000 0.000000 0.101465 \n", + "6 0.009727 0.318712 0.000000 0.000000 0.562950 0.501431 0.000000 \n", + "7 0.873356 0.273255 0.000000 0.000000 0.395537 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.921305 0.478782 0.555382 0.843183 \n", + "9 0.000000 0.861323 0.000000 0.000000 0.246309 0.000000 0.367699 \n", + "10 0.000000 0.808713 0.000000 0.000000 0.762475 0.588976 0.000000 \n", + "11 0.000000 0.269631 0.000000 0.000000 0.678892 0.684368 0.000000 \n", + "12 0.020425 0.345108 0.000000 0.000000 0.484808 0.373350 0.000000 \n", + "13 0.000000 0.000000 0.000000 0.000000 0.484474 0.697905 0.000000 \n", + "14 0.000000 0.000000 0.000000 0.000000 0.436992 0.872596 0.000000 \n", + "15 0.000000 0.000000 0.000000 0.000000 0.468854 0.528513 0.000000 \n", + "16 0.000000 0.000000 0.000000 0.000000 0.491470 0.598783 0.000000 \n", + "17 0.821796 0.000000 0.000000 0.000000 0.492883 0.837013 0.000000 \n", + "18 0.000000 0.000000 0.000000 0.000000 0.487885 0.814928 0.000000 \n", + "19 0.000000 0.000000 0.000000 0.000000 0.476129 0.000000 0.000000 \n", + "20 0.000000 0.000000 0.000000 0.000000 0.474601 0.033251 0.000000 \n", + "21 0.000000 0.000000 0.000000 0.000000 0.401353 0.662538 0.000000 \n", + "22 0.000000 0.000000 0.000000 0.000000 0.293426 0.680068 0.000000 \n", + "23 0.000000 0.000000 0.000000 0.000000 0.291392 0.689731 0.000000 \n", + "24 0.000000 0.000000 0.587169 0.000000 0.261315 0.678901 0.000000 \n", + "25 0.000000 0.000000 0.744272 0.000000 0.288754 0.711716 0.000000 \n", + "26 
0.000000 0.000000 0.648077 0.000000 0.181724 0.724917 0.000000 \n", + "27 0.000000 0.000000 0.485191 0.000000 0.329962 0.644861 0.000000 \n", + "28 0.000000 1.000000 0.523694 0.000000 0.339929 0.636695 0.000000 \n", + "29 0.000000 0.000000 0.524659 0.000000 0.324405 0.592530 0.000000 \n", + "30 0.000000 0.000000 0.529489 0.000000 0.328360 0.614062 0.426990 \n", + "31 0.000000 0.000000 0.516351 0.000000 0.339141 0.665545 1.000000 \n", + "32 0.000000 0.000000 0.620997 0.000000 0.379574 0.638407 0.000000 \n", + "33 0.000000 0.000000 0.573954 0.000000 0.313270 0.656503 1.000000 \n", + "34 0.000000 0.000000 0.460477 0.000000 0.309182 0.610777 1.000000 \n", + "35 0.000000 0.000000 0.485675 0.000000 0.326543 0.652575 1.000000 \n", + "36 0.000000 0.000000 0.543199 0.413369 0.312882 0.639607 0.000000 \n", + "37 0.000000 0.000000 0.493379 0.380204 0.318134 0.705801 0.000000 \n", + "38 0.000000 0.000000 0.430413 0.320168 0.326266 0.719938 0.000000 \n", + "39 0.000000 0.000000 0.481622 0.276943 0.353813 0.683217 0.000000 \n", + "40 0.000000 0.000000 0.434325 0.311170 0.418012 0.691758 0.000000 \n", + "41 0.000000 0.000000 0.535365 0.263738 0.288012 0.678322 0.000000 \n", + "42 0.000000 0.000000 0.450450 0.266789 0.302515 0.598738 0.000000 \n", + "43 0.000000 0.000000 0.435024 0.267592 0.293898 0.704834 0.000000 \n", + "44 0.000000 0.000000 0.549763 0.261829 0.313766 0.698783 0.000000 \n", + "45 0.000000 0.000000 0.452152 1.000000 0.000000 0.000000 0.000000 \n", + "46 0.000000 0.000000 0.500244 0.266229 0.308837 0.642960 0.000000 \n", + "47 0.000000 0.000000 0.441780 0.272927 0.301379 0.606512 0.000000 \n", + "48 0.000000 0.000000 0.483530 0.265705 0.298318 0.664334 0.000000 \n", + "49 0.000000 0.000000 0.488422 0.262648 0.293882 1.000000 0.000000 \n", + "50 0.000000 0.000000 0.438833 0.266345 0.304085 0.656801 0.000000 \n", + "51 0.000000 0.000000 0.000000 0.298261 0.293692 0.656596 1.000000 \n", + "52 0.000000 0.000000 0.460150 0.261562 0.311595 0.653853 0.000000 \n", + "53 
0.000000 0.000000 0.342074 0.000000 0.000000 0.705080 0.000000 \n", + "54 0.000000 0.000000 0.878122 0.271726 0.323954 0.502200 0.000000 \n", + "55 0.000000 0.000000 0.438906 0.273256 0.293124 0.657404 0.000000 \n", + "56 0.000000 0.000000 0.000000 1.000000 0.000000 0.910376 1.000000 \n", + "57 0.000000 0.000000 0.000000 0.539920 0.070273 0.650711 0.000000 \n", + "58 0.000000 0.000000 0.966684 1.000000 1.000000 0.000000 0.000000 \n", + "59 0.000000 0.000000 0.463453 0.264448 0.310184 0.656927 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + "5 0.000000 0.000000 0.612449 0.000000 0.000000 \n", + "6 0.000000 0.130447 0.776044 0.000000 0.000000 \n", + "7 0.913477 0.422242 0.558164 0.000000 0.000000 \n", + "8 0.000000 0.674322 0.000000 0.840498 0.000000 \n", + "9 0.574321 0.000000 0.771572 0.253270 0.000000 \n", + "10 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "11 0.000000 0.132750 0.905444 0.000000 0.000000 \n", + "12 0.000000 0.131580 0.668837 0.000000 0.000000 \n", + "13 0.000000 0.101650 1.000000 0.000000 0.000000 \n", + "14 0.000000 0.044316 1.000000 0.000000 0.000000 \n", + "15 0.000000 0.426249 1.000000 0.000000 0.000000 \n", + "16 0.000000 0.000000 1.000000 0.000000 0.780327 \n", + "17 0.000000 0.015814 1.000000 0.000000 0.000000 \n", + "18 0.000000 0.000000 1.000000 0.560196 0.000000 \n", + "19 0.733301 0.000000 1.000000 0.000000 0.000000 \n", + "20 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "21 0.000000 0.416957 1.000000 0.000000 0.098667 \n", + "22 0.000000 0.617844 1.000000 0.000000 0.499694 \n", + "23 0.000000 0.986378 0.000000 0.000000 1.000000 \n", + "24 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "25 
1.000000 1.000000 1.000000 0.000000 0.000000 \n", + "26 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "27 0.000000 0.000000 0.000000 0.007820 0.000000 \n", + "28 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "29 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "30 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "31 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "32 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "33 1.000000 0.000000 0.000000 0.000000 0.000000 \n", + "34 1.000000 0.000000 0.000000 0.000000 1.000000 \n", + "35 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "36 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "37 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "38 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "39 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "40 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "41 0.000000 0.000000 0.000000 0.793157 0.000000 \n", + "42 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "43 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "44 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "45 0.000000 1.000000 0.000000 0.887527 0.000000 \n", + "46 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "47 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "48 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "49 0.000000 0.000000 0.000000 0.968585 0.000000 \n", + "50 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "51 1.000000 0.000000 0.000000 0.567105 0.000000 \n", + "52 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "53 0.125655 0.000000 0.000000 0.251494 0.000000 \n", + "54 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "55 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "56 1.000000 0.000000 1.000000 0.000000 1.000000 \n", + "57 0.000000 1.000000 0.000000 0.000000 0.000000 \n", + "58 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "59 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 
True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + "5 -0.002416 True \n", + "6 -0.216006 True \n", + "7 -0.007566 True \n", + "8 -0.019061 True \n", + "9 -0.010974 True \n", + "10 -0.006552 True \n", + "11 -0.082110 True \n", + "12 -0.236332 True \n", + "13 -0.479774 True \n", + "14 -0.439923 True \n", + "15 -0.457100 True \n", + "16 -0.447853 True \n", + "17 -0.151161 True \n", + "18 -0.395607 True \n", + "19 -0.014914 True \n", + "20 -0.021225 True \n", + "21 -0.713574 True \n", + "22 -0.832143 True \n", + "23 -0.828849 True \n", + "24 -1.332965 True \n", + "25 -1.049137 True \n", + "26 -0.989974 True \n", + "27 -1.439587 True \n", + "28 -0.148478 True \n", + "29 -1.426935 True \n", + "30 -1.445619 True \n", + "31 -1.445010 True \n", + "32 -1.332345 True \n", + "33 -1.425452 True \n", + "34 -1.401627 True \n", + "35 -1.438949 True \n", + "36 -2.166108 True \n", + "37 -2.329238 True \n", + "38 -2.483866 True \n", + "39 -2.547384 True \n", + "40 -2.160468 True \n", + "41 -2.563466 True \n", + "42 -2.551445 True \n", + "43 -2.557965 True \n", + "44 -2.542168 True \n", + "45 -0.000560 True \n", + "46 -2.619489 True \n", + "47 -2.565203 True \n", + "48 -2.618345 True \n", + "49 -1.037483 True \n", + "50 -2.615455 True \n", + "51 -1.907823 True \n", + "52 -2.623220 True \n", + "53 -0.308956 True \n", + "54 -1.252197 True \n", + "55 -2.606931 True \n", + "56 -0.001558 True \n", + "57 -0.398925 True \n", + "58 -0.000144 True \n", + "59 -2.625682 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "strategy.experiments.to_csv(\"experiments_2.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
450.00.00.0000000.3119330.3250770.6780670.00.0000001.0000000.0000001.0000000.000000-1.888116True
460.00.00.0000000.3035960.3198840.6694800.01.0000000.0000000.0000001.0000001.000000-1.902525True
470.00.00.0000000.0000000.3132790.6572320.01.0000000.0000000.0000001.0000001.000000-0.834159True
480.00.00.0000000.0000000.4540240.7599200.00.0000000.5833470.0000000.9605270.351945-0.533118True
490.00.00.0000000.3254260.3323940.6991700.01.0000000.0000000.0000000.0000000.000000-1.846701True
500.00.00.0000000.3596750.3876700.7058920.00.4670630.0000000.0000000.5103340.436983-1.597361True
510.00.00.0000000.2833090.3075270.6595140.00.0000000.0000000.0000001.0000000.000000-1.913625True
520.00.00.0000000.2830220.3089380.6532150.00.0000000.0000000.0000000.0000000.000000-1.911601True
530.00.00.0000000.0000000.5758110.0000000.00.0000000.0000000.0000001.0000000.000000-0.007180True
540.00.00.0000001.0000001.0000000.4891370.00.0000000.0000000.0000001.0000000.000000-0.000120True
550.00.00.0000001.0000000.0000000.0000000.00.0000000.0000001.0000001.0000000.000000-0.000135True
560.00.00.0000000.5712981.0000000.0000001.00.0000001.0000000.0000000.0000000.000000-0.000423True
570.00.01.0000000.0000001.0000001.0000000.00.0000000.0000000.0000001.0000001.000000-0.006990True
580.00.00.0000000.2829400.3066950.6654810.00.0000000.0000001.0000000.0000001.000000-1.914125True
590.00.00.0000000.2820920.3089130.6596561.01.0000000.0000000.0000000.0000001.000000-1.913080True
600.00.00.0000001.0000000.0000001.0000000.01.0000000.0000000.0000000.0000001.000000-0.001027True
610.01.00.0000000.0000000.3049800.6635520.00.0000001.0000000.0000001.0000000.000000-0.069826True
620.00.00.4837510.2870300.3112350.6553750.00.0000000.0000001.0000000.4852750.000000-2.625614True
630.00.00.7078990.2843010.3074830.6627560.01.0000000.0000000.0000000.0000000.000000-2.172833True
640.00.00.8877870.0000000.3274300.7227590.00.0000000.0000000.8757150.0000000.116789-0.723463True
650.01.00.0000000.5898240.0000000.0000000.00.0000000.0000001.0000000.0000000.000000-0.166836True
660.00.00.6774490.0000000.0000000.7112041.00.0000000.0000001.0000000.0000000.000000-0.382246True
670.00.00.0000000.2974560.3130590.6420050.00.0000000.0000001.0000000.6358540.000000-1.902891True
680.00.00.4374380.0000000.3159860.6345200.00.0000001.0000000.0000000.0000001.000000-1.389207True
690.00.00.4279530.0000000.3133400.6521721.00.1856490.0000000.0000000.0000000.000000-1.376775True
700.00.00.0000000.0000000.3071720.6617710.00.0000001.0000001.0000000.0000000.000000-0.835408True
710.00.00.4387840.0000000.3175070.6197280.00.0000000.0000000.0000001.0000001.000000-1.385606True
720.00.00.4025440.3116270.3064980.6512230.00.8865260.0000000.0000000.0000001.000000-2.566385True
730.00.00.4194690.2519900.3242230.6472120.00.0000000.0000000.0000001.0000000.000000-2.583365True
740.00.00.6256170.0000000.2723050.4894700.00.0000000.6101500.0000000.9260640.000000-1.087036True
750.00.00.7512170.0000000.6833750.4044430.00.0000000.0000000.0000000.0000000.000000-0.338749True
760.00.01.0000001.0000000.0000000.5332760.00.0000000.0000001.0000000.0000000.000000-0.007243True
770.00.00.4359850.2888460.3115910.6914390.00.0000001.0000001.0000000.0000000.000000-2.591009True
780.00.00.4243290.4090350.3438970.7186910.00.0000000.5095360.6181830.0000000.000000-2.122833True
790.00.01.0000001.0000000.9290300.0000000.00.0000000.0000001.0000001.0000000.000000-0.000114True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "45 0.0 0.0 0.000000 0.311933 0.325077 0.678067 0.0 \n", + "46 0.0 0.0 0.000000 0.303596 0.319884 0.669480 0.0 \n", + "47 0.0 0.0 0.000000 0.000000 0.313279 0.657232 0.0 \n", + "48 0.0 0.0 0.000000 0.000000 0.454024 0.759920 0.0 \n", + "49 0.0 0.0 0.000000 0.325426 0.332394 0.699170 0.0 \n", + "50 0.0 0.0 0.000000 0.359675 0.387670 0.705892 0.0 \n", + "51 0.0 0.0 0.000000 0.283309 0.307527 0.659514 0.0 \n", + "52 0.0 0.0 0.000000 0.283022 0.308938 0.653215 0.0 \n", + "53 0.0 0.0 0.000000 0.000000 0.575811 0.000000 0.0 \n", + "54 0.0 0.0 0.000000 1.000000 1.000000 0.489137 0.0 \n", + "55 0.0 0.0 0.000000 1.000000 0.000000 0.000000 0.0 \n", + "56 0.0 0.0 0.000000 0.571298 1.000000 0.000000 1.0 \n", + "57 0.0 0.0 1.000000 0.000000 1.000000 1.000000 0.0 \n", + "58 0.0 0.0 0.000000 0.282940 0.306695 0.665481 0.0 \n", + "59 0.0 0.0 0.000000 0.282092 0.308913 0.659656 1.0 \n", + "60 0.0 0.0 0.000000 1.000000 0.000000 1.000000 0.0 \n", + "61 0.0 1.0 0.000000 0.000000 0.304980 0.663552 0.0 \n", + "62 0.0 0.0 0.483751 0.287030 0.311235 0.655375 0.0 \n", + "63 0.0 0.0 0.707899 0.284301 0.307483 0.662756 0.0 \n", + "64 0.0 0.0 0.887787 0.000000 0.327430 0.722759 0.0 \n", + "65 0.0 1.0 0.000000 0.589824 0.000000 0.000000 0.0 \n", + "66 0.0 0.0 0.677449 0.000000 0.000000 0.711204 1.0 \n", + "67 0.0 0.0 0.000000 0.297456 0.313059 0.642005 0.0 \n", + "68 0.0 0.0 0.437438 0.000000 0.315986 0.634520 0.0 \n", + "69 0.0 0.0 0.427953 0.000000 0.313340 0.652172 1.0 \n", + "70 0.0 0.0 0.000000 0.000000 0.307172 0.661771 0.0 \n", + "71 0.0 0.0 0.438784 0.000000 0.317507 0.619728 0.0 \n", + "72 0.0 0.0 0.402544 0.311627 0.306498 0.651223 0.0 \n", + "73 0.0 0.0 0.419469 0.251990 0.324223 0.647212 0.0 \n", + "74 0.0 0.0 0.625617 0.000000 0.272305 0.489470 0.0 \n", + "75 0.0 0.0 0.751217 0.000000 0.683375 0.404443 0.0 \n", + "76 0.0 0.0 1.000000 1.000000 0.000000 0.533276 0.0 \n", + "77 0.0 0.0 0.435985 0.288846 
0.311591 0.691439 0.0 \n", + "78 0.0 0.0 0.424329 0.409035 0.343897 0.718691 0.0 \n", + "79 0.0 0.0 1.000000 1.000000 0.929030 0.000000 0.0 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "45 0.000000 1.000000 0.000000 1.000000 0.000000 \n", + "46 1.000000 0.000000 0.000000 1.000000 1.000000 \n", + "47 1.000000 0.000000 0.000000 1.000000 1.000000 \n", + "48 0.000000 0.583347 0.000000 0.960527 0.351945 \n", + "49 1.000000 0.000000 0.000000 0.000000 0.000000 \n", + "50 0.467063 0.000000 0.000000 0.510334 0.436983 \n", + "51 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "52 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "53 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "54 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "55 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "56 0.000000 1.000000 0.000000 0.000000 0.000000 \n", + "57 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "58 0.000000 0.000000 1.000000 0.000000 1.000000 \n", + "59 1.000000 0.000000 0.000000 0.000000 1.000000 \n", + "60 1.000000 0.000000 0.000000 0.000000 1.000000 \n", + "61 0.000000 1.000000 0.000000 1.000000 0.000000 \n", + "62 0.000000 0.000000 1.000000 0.485275 0.000000 \n", + "63 1.000000 0.000000 0.000000 0.000000 0.000000 \n", + "64 0.000000 0.000000 0.875715 0.000000 0.116789 \n", + "65 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "66 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "67 0.000000 0.000000 1.000000 0.635854 0.000000 \n", + "68 0.000000 1.000000 0.000000 0.000000 1.000000 \n", + "69 0.185649 0.000000 0.000000 0.000000 0.000000 \n", + "70 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "71 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "72 0.886526 0.000000 0.000000 0.000000 1.000000 \n", + "73 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "74 0.000000 0.610150 0.000000 0.926064 0.000000 \n", + "75 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "76 0.000000 0.000000 1.000000 
0.000000 0.000000 \n", + "77 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "78 0.000000 0.509536 0.618183 0.000000 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "\n", + " y valid_y \n", + "45 -1.888116 True \n", + "46 -1.902525 True \n", + "47 -0.834159 True \n", + "48 -0.533118 True \n", + "49 -1.846701 True \n", + "50 -1.597361 True \n", + "51 -1.913625 True \n", + "52 -1.911601 True \n", + "53 -0.007180 True \n", + "54 -0.000120 True \n", + "55 -0.000135 True \n", + "56 -0.000423 True \n", + "57 -0.006990 True \n", + "58 -1.914125 True \n", + "59 -1.913080 True \n", + "60 -0.001027 True \n", + "61 -0.069826 True \n", + "62 -2.625614 True \n", + "63 -2.172833 True \n", + "64 -0.723463 True \n", + "65 -0.166836 True \n", + "66 -0.382246 True \n", + "67 -1.902891 True \n", + "68 -1.389207 True \n", + "69 -1.376775 True \n", + "70 -0.835408 True \n", + "71 -1.385606 True \n", + "72 -2.566385 True \n", + "73 -2.583365 True \n", + "74 -1.087036 True \n", + "75 -0.338749 True \n", + "76 -0.007243 True \n", + "77 -2.591009 True \n", + "78 -2.122833 True \n", + "79 -0.000114 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.tail(35)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.5631, 0.9276, 2.5197, 0.3658, 0.2495, 0.3455, 3.4509, 7.5572,\n", + " 8.9175, 12.1373, 5.9159, 4.9223]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
620.00.00.4837510.2870300.3112350.6553750.00.0000000.01.00.4852750.0-2.625614True
770.00.00.4359850.2888460.3115910.6914390.00.0000001.01.00.0000000.0-2.591009True
730.00.00.4194690.2519900.3242230.6472120.00.0000000.00.01.0000000.0-2.583365True
720.00.00.4025440.3116270.3064980.6512230.00.8865260.00.00.0000001.0-2.566385True
630.00.00.7078990.2843010.3074830.6627560.01.0000000.00.00.0000000.0-2.172833True
.............................................
600.00.00.0000001.0000000.0000001.0000000.01.0000000.00.00.0000001.0-0.001027True
560.00.00.0000000.5712981.0000000.0000001.00.0000001.00.00.0000000.0-0.000423True
550.00.00.0000001.0000000.0000000.0000000.00.0000000.01.01.0000000.0-0.000135True
540.00.00.0000001.0000001.0000000.4891370.00.0000000.00.01.0000000.0-0.000120True
790.00.01.0000001.0000000.9290300.0000000.00.0000000.01.01.0000000.0-0.000114True
\n", + "

80 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "62 0.0 0.0 0.483751 0.287030 0.311235 0.655375 0.0 \n", + "77 0.0 0.0 0.435985 0.288846 0.311591 0.691439 0.0 \n", + "73 0.0 0.0 0.419469 0.251990 0.324223 0.647212 0.0 \n", + "72 0.0 0.0 0.402544 0.311627 0.306498 0.651223 0.0 \n", + "63 0.0 0.0 0.707899 0.284301 0.307483 0.662756 0.0 \n", + ".. ... ... ... ... ... ... ... \n", + "60 0.0 0.0 0.000000 1.000000 0.000000 1.000000 0.0 \n", + "56 0.0 0.0 0.000000 0.571298 1.000000 0.000000 1.0 \n", + "55 0.0 0.0 0.000000 1.000000 0.000000 0.000000 0.0 \n", + "54 0.0 0.0 0.000000 1.000000 1.000000 0.489137 0.0 \n", + "79 0.0 0.0 1.000000 1.000000 0.929030 0.000000 0.0 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "62 0.000000 0.0 1.0 0.485275 0.0 \n", + "77 0.000000 1.0 1.0 0.000000 0.0 \n", + "73 0.000000 0.0 0.0 1.000000 0.0 \n", + "72 0.886526 0.0 0.0 0.000000 1.0 \n", + "63 1.000000 0.0 0.0 0.000000 0.0 \n", + ".. ... ... ... ... ... \n", + "60 1.000000 0.0 0.0 0.000000 1.0 \n", + "56 0.000000 1.0 0.0 0.000000 0.0 \n", + "55 0.000000 0.0 1.0 1.000000 0.0 \n", + "54 0.000000 0.0 0.0 1.000000 0.0 \n", + "79 0.000000 0.0 1.0 1.000000 0.0 \n", + "\n", + " y valid_y \n", + "62 -2.625614 True \n", + "77 -2.591009 True \n", + "73 -2.583365 True \n", + "72 -2.566385 True \n", + "63 -2.172833 True \n", + ".. ... ... 
\n", + "60 -0.001027 True \n", + "56 -0.000423 True \n", + "55 -0.000135 True \n", + "54 -0.000120 True \n", + "79 -0.000114 True \n", + "\n", + "[80 rows x 14 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.sort_values(by=\"y\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((2, 3, 4, 5, 9, 10), {}, -4.135849638238238)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "acqf = strategy._get_acqfs(n=1)[0]\n", + "groups = Groups(\n", + " groups=[\n", + " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + ")\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=64, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " 
bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=reward_fn2,\n", + " use_cache=False,\n", + " cache_hit_mode=\"pessimistic\",\n", + " rollout_mode=\"uniform_subset\",\n", + ")\n", + "mcts.run(n_iterations=3000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
.............................................
750.0000000.0000000.7512170.0000000.6833750.4044430.0000000.0000000.0000000.0000000.0000000.000000-0.338749True
760.0000000.0000001.0000001.0000000.0000000.5332760.0000000.0000000.0000001.0000000.0000000.000000-0.007243True
770.0000000.0000000.4359850.2888460.3115910.6914390.0000000.0000001.0000001.0000000.0000000.000000-2.591009True
780.0000000.0000000.4243290.4090350.3438970.7186910.0000000.0000000.5095360.6181830.0000000.000000-2.122833True
790.0000000.0000001.0000001.0000000.9290300.0000000.0000000.0000000.0000001.0000001.0000000.000000-0.000114True
\n", + "

80 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + ".. ... ... ... ... ... ... ... \n", + "75 0.000000 0.000000 0.751217 0.000000 0.683375 0.404443 0.000000 \n", + "76 0.000000 0.000000 1.000000 1.000000 0.000000 0.533276 0.000000 \n", + "77 0.000000 0.000000 0.435985 0.288846 0.311591 0.691439 0.000000 \n", + "78 0.000000 0.000000 0.424329 0.409035 0.343897 0.718691 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.929030 0.000000 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + ".. ... ... ... ... ... \n", + "75 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "76 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "77 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "78 0.000000 0.509536 0.618183 0.000000 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + ".. ... ... 
\n", + "75 -0.338749 True \n", + "76 -0.007243 True \n", + "77 -2.591009 True \n", + "78 -2.122833 True \n", + "79 -0.000114 True \n", + "\n", + "[80 rows x 14 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [], + "source": [ + "def reward_fn3(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 3, 4, 5, 9)\n", + "-41.97450586826323\n", + "-4.2207785897443415\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))\n", + "# print(reward_fn3(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 11)\n", + "-45.67180087515713\n", + "-45.67175933157002\n", + "-45.67174110576808\n" + ] + } + ], + "source": [ + "# leaf, path = mcts._select_and_expand()\n", + "# selected_features, cat_selections = mcts._get_selection(leaf)\n", + "selected_features = (0, 11)\n", + 
"print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))\n", + "print(reward_fn3(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "80" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(strategy.experiments)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'SpuriousFeaturesWrapper' object has no attribute 'inputs'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[223]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[43mreward_fn2\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m7\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m11\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcat_selections\u001b[49m\u001b[43m=\u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[65]\u001b[39m\u001b[32m, line 27\u001b[39m, in \u001b[36mreward_fn2\u001b[39m\u001b[34m(selected_features, cat_selections)\u001b[39m\n\u001b[32m 26\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mreward_fn2\u001b[39m(selected_features: \u001b[38;5;28mtuple\u001b[39m[\u001b[38;5;28mint\u001b[39m, 
...], cat_selections: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;28mfloat\u001b[39m]) -> \u001b[38;5;28mfloat\u001b[39m:\n\u001b[32m---> \u001b[39m\u001b[32m27\u001b[39m fixed = {i: \u001b[32m0.0\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(\u001b[43mbenchmark\u001b[49m\u001b[43m.\u001b[49m\u001b[43minputs\u001b[49m.domain.get_keys())) \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m selected_features}\n\u001b[32m 28\u001b[39m candidates, acq_value = optimize_acqf(\n\u001b[32m 29\u001b[39m acq_function=acqf,\n\u001b[32m 30\u001b[39m bounds=bounds,\n\u001b[32m (...)\u001b[39m\u001b[32m 34\u001b[39m fixed_features=fixed,\n\u001b[32m 35\u001b[39m )\n\u001b[32m 36\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m acq_value.item()\n", + "\u001b[31mAttributeError\u001b[39m: 'SpuriousFeaturesWrapper' object has no attribute 'inputs'" + ] + } + ], + "source": [ + "print(reward_fn2((0, 2, 3, 7, 11), cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.93938361, 0. , 0. , 0.15107026, 0.49274244,\n", + " 0.34103735, 0.31871182, 0.27325514, 0. , 0.86132336,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 1. , 0. , 0. , 0. ,\n", + " 1. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. 
])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.x_1.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.09381345, -0.09381345, -0.09381345, -0.09381345, -0.09381345,\n", + " -0.09381345, -0.21600644, -0.21600644, -0.21600644, -0.21600644,\n", + " -0.21600644, -0.21600644, -0.21600644, -0.21600644, -0.21600644,\n", + " -0.21600644, -0.21600644, -0.3837685 , -0.3837685 , -0.3837685 ,\n", + " -0.3837685 , -0.3837685 , -0.3837685 , -0.39329869, -0.39494433,\n", + " -0.39494433, -0.39494433, -0.39521691, -1.27114381, -1.27114381,\n", + " -1.27114381, -1.27114381, -1.27114381, -1.44489719, -1.57146527,\n", + " -1.57146527, -1.57146527, -1.57146527, -1.57146527, -1.78566843,\n", + " -1.78566843, -1.78566843, -1.78566843, -1.88600201, -1.90148069,\n", + " -1.90148069, -1.90252527, -1.90252527, -1.90252527, -1.90252527,\n", + " -1.90252527, -1.91362527, -1.91362527, -1.91362527, -1.91362527,\n", + " -1.91362527, -1.91362527, -1.91362527, -1.91412484, -1.91412484,\n", + " -1.91412484, -1.91412484, -2.62561401, -2.62561401, -2.62561401,\n", + " -2.62561401, -2.62561401, -2.62561401, -2.62561401, -2.62561401,\n", + " -2.62561401, -2.62561401, -2.62561401, -2.62561401, -2.62561401,\n", + " -2.62561401, -2.62561401, -2.62561401, -2.62561401, -2.62561401])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_y_per_iteration = strategy.experiments[\"y\"].cummin()\n", + "best_y_per_iteration.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Best y per iteration')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQdtJREFUeJzt3Qd4VFX6x/F3UkkhCRAgdEKTjkgTVEBgRQUFZbGAUvQPgrhrXQVdRVYR3NVVxF2QXQVRLLgrCLqL0sQV6UgVkCqh94QQSJv7f94TZpwUIIEkU+738zzXyfR778TMj3Pec47DsixLAAAAbCjI2zsAAADgLQQhAABgWwQhAABgWwQhAABgWwQhAABgWwQhAABgWwQhAABgWwQhAABgWwQhAABgWwQhACikb7/9VhwOh7n0F9OmTTP7vGfPHm/vCuCTCEJAAHB92XlulSpVkhtvvFH++9//ltj7pqWlyYsvvuhXwaC4ffTRR/Lmm296ezfklVdekdmzZ3t7NwC/42CtMSAwgtDgwYPlT3/6kyQmJoouIXj48GFz++bNm2Xu3LnSs2fPYn/fY8eOScWKFWX06NEmEAU6p9MpGRkZEhYWJkFBOf+O1PO6adMmr7e4REdHy29/+1vzmXvKzs6WzMxMCQ8PNwEZQG4hea4D8GO33HKLtG7d2n39wQcflMqVK8vHH39cIkEoEJ05c0aioqIKvE/DT5kyZUotcBXHewUHB5sNQMHoGgMCWFxcnEREREhISEi+L1rtzmnSpIn5stWw9NBDD8nJkydzPW716tXSvXt3iY+PN6+jrU0PPPCAuU9bQLQ1SI0ZM8bdJXehlqFdu3aZ+99444189/3www/mPg1sl6rP+fTTT+XZZ5+VhIQEE1huv/12SUpKyvf4FStWyM033yyxsbESGRkpnTp1kqVLl+Z6jO6rvuZPP/0k/fr1k3Llysn1119/yX1wdQV27txZvvrqK/nll1/cx1+7dm3349PT001rWb169UyLTI0aNeTpp582t3vS5z3yyCMyY8YM85noY+fNm2fue+2116RDhw5SoUIF8xm0atVK/vWvf+V7vga4999/370fgwYNumiN0N///nf3e1WtWlVGjBghp06dyvUYPb6mTZua86PdrHoeq1WrJn/+858veI4Af0OLEBBAkpOTTXeVdo0dOXJEJk6cKKmpqXLfffflepyGHld32u9//3vZvXu3vP322/Ljjz+asBAaGmqef9NNN5mwM3LkSBOq9Mv0888/N6+ht0+aNEmGDx8ud9xxh9x5553m9ubNmxe4b3Xq1JHrrrvOfNk//vjjue7T28qWLSu9evW65DGOHTvWfLE/88wzZh810HXr1k3WrVtngoJatGiRaR3T0KBBRFtypk6dKl26dJH//e9/0rZt21yv2bdvX6lfv76ps9FzV1jPPfecOef79u1zBzztonKFTQ1p33//vQwdOlQaNWokGzduNI/7+eef89Xz6D7PnDnTBCINnq5ANWHCBPM6/fv3N61En3zyidnfL7/8Unr06GEe88EHH8j//d//mePS91J169a94H5rANTwqudNP79t27aZz3LVqlXuz99Fw7EGSv1877rrLhPC9Nw3a9bMnGPA72mNEAD/NnXqVP32zreFh4db06ZNy/XY//3vf+a+GTNm5Lp93rx5uW6fNWuWub5q1aoLvu/Ro0fNY0aPHl2o/XznnXfM47ds2eK+LSMjw4qPj7cGDhx40ecuXrzYPLdatWpWSkqK+/aZM2ea2ydMmGCuO51Oq379+lb37t3Nzy5paWlWYmKi9Zvf/MZ9m+63Pvfee+8t1P679kEvXXr06GHVqlUr32M/+OADKygoyJxvT5MnTzavsXTpUvdtel0fu3nz5nyvo/vtSc9X06ZNrS5duuS6PSoqqsBz6Prd2L17t7l+5MgRKywszLrpppus7Oxs9+Pefvtt87j33nvPfVunTp3MbdOnT3fflp6ebiUkJFh9+vS5yJk
C/AddY0AA+dvf/ibz588324cffmi6M7SlwNWKoz777DPTXfSb3/zGtB65Nm090daMxYsXm8dpC5DSlgctti0O2qKgXXHaAuTy9ddfm/fP22p1IQMGDDCtRy5aIFylShX5z3/+Y65ry9D27dtNV9fx48fdx6ddR127dpXvvvvOtNZ4GjZsmBQ3Pc/aCtSwYcNc51lbpZTrPLto113jxo3zvY6rlcvVOqMtUDfccIOsXbv2svZrwYIFpmXpsccecxd8qyFDhkhMTIzp6vOkvxOen40WimvLk3Z1AoGArjEggOgXlGex9L333istW7Y03S1aLK1fYhoS9MtUh9cXRLubXF/Mffr0MV0o2p2j9SK9e/c2AUPrSi6HhqvbbrvNDDl/6aWXzG0airTuxBUQLkW7sDxpN5nW4LhqYPT41MCBAy/4Gnr8Wg/korVPxU33Y8uWLe46qgud50vtgwbRl19+2QQ8z9qiyx0BpvVM6qqrrsp1u/5uaPel636X6tWr53svPXcbNmy4rPcHfA1BCAhg+i9+bRXSOhP9YtbiWG0N0RDk2SrjyfXFrV9+Wg+yfPlyM/xeW260UPr11183t7lqYYpKW3S0tUQLpLXOZM6cOfLwww/nap24Eq7Wnr/85S9y9dVXF/iYvPvu2epSXHQ/9Pj++te/Fni/Fk5fah+0nknrgzp27GiKm7XlS+t3tN5Jw2RpuNCIs6LUUgG+jCAEBLisrCxzqUXTriJa7R7RwuXCBIBrr73WbFqkrF++WrSrBbva5XY5rRJaeKthS4NYu3btzKSM999/f6Gf72rx8fxC3rFjh7tI21UkrN08Wgxc0i50DnQ/1q9fb7rjLrf15t///rfpStQQ6tkKp0GosPuRV61atcylFkhrC5CLdpdp0XxpnDPAl1AjBAQwre355ptvTLeH1qu46nR0kj1X11Te0OQaQq31KHn/1e9qYXF10ehwapV32PXF6FB+7bLTEVI6ck1bTS400qwg06dPl9OnT7uva6vVwYMH3SOYtNZJQ4gOO3eFP09Hjx6V4qRD+LWrLS89z/v375d//OMf+e47e/asqVkqTGuMBhz9vFy0C7CgGaR1PwrzOWjQ0d+Ht956K9fn++6775rjcI1EA+yCFiEggOhyGlu3bnXXoGgLjrag6PB3bSFx1f7o8Plx48aZuhMdIq/dLfo47bLSbjQtQNY5abQ7RofGa7DQ8KFf6vo6t956q3ktbVHSAl+d26dBgwZSvnx5M++MbpfqHtMvYi0YfvXVV4t0jPoeOtePDv3X2bN1+LzWCGmxr9Iutn/+858mGGlXoD5Oa5A0lOj76f5rV19x0eClx//EE09ImzZtTLeb1kFpK5eGPS3E1vfVFjgNNPr56O3ayuNZz1UQDSXataataFqbpZ+pFsTr8eat0dH90JY+fbzOC6Q1R9rilpe2xo0aNcrUfunratebtg7pZ637X9iidSBgeHvYGoCSGT5fpkwZ6+qrr7YmTZqUaxi5y5QpU6xWrVpZERERVtmyZa1mzZpZTz/9tHXgwAFz/9q1a82w8po1a5ph+JUqVbJ69uxprV69Otfr/PDDD+Z1dEh2UYbSN2nSxAwZ37dvX5GGrn/88cfWqFGjzP7ovuvw9V9++SXf43/88UfrzjvvtCpUqGD2X4e433XXXdbChQvzDZ/XaQAud/h8amqq1a9fPysuLs7c5zmUXoe6v/rqq+ZYdR/KlStnztWYMWOs5ORk9+P0eSNGjCjwPd99910zHYA+v2HDhuazdu23p61bt1odO3Y050Tvcw2lzzt83nO4vL5eaGioVblyZWv48OHWyZMncz1Gh8/rvuelr13QlAGAP2KtMQBeoaPZtHVn4cKFhXq8zuashd/aaqUtVgBQHKgRAlDqdOkO7ZbTLjIA8CZqhACUGl2lfc2aNWYIvg4Fv/vuuzn7ALyKFiEApUZHeGnxso5m0wVWS2MldwC4GGqEAACAbdEiBAAAbIsgBAAAbIti6UK
sF3TgwAGz2vXlTpMPAABKl07RpRPB6gSjF13L0PIzOgmYTuSlk4u1bdvWWrFixUUfP3PmTOuqq64yj2/atKn11VdfFen9kpKS8k1Ux8Y54HeA3wF+B/gd4HdA/OIc6Pf4xfhVi5BrGvvJkyebqeN1av3u3bub6eF1Ne28dHVrXdNIlxLo2bOnWW6gd+/esnbt2ksuAeCiLUEqKSnJvUQBAADwbSkpKVKjRg3393hAjBrT8KNr4bz99tvubis9yN/97ndmLaW8dI4SXdjwyy+/dN+mq2jrwpEapgp7ImNjY81ihAQhAAD8Q2G/v/2mWDojI8NMxKYrJ7ton59eX7ZsWYHP0ds9H6+0BelCj3etqq0nz3MDAACByW+C0LFjx8zKzZUrV851u14/dOhQgc/R24vyeKXdaJogXZu2OAEAgMDkN0GotIwaNco0o7k2rQ0CAACByW+KpePj4yU4OFgOHz6c63a9npCQUOBz9PaiPF6Fh4ebDQAABD6/aREKCwuTVq1aycKFC923abG0Xm/fvn2Bz9HbPR+v5s+ff8HHAwAAe/GbFiGlQ+cHDhworVu3lrZt25rh8zoqTBdxVAMGDJBq1aqZOh/16KOPSqdOncxK1z169JBPPvlEVq9eLVOmTPHykQAAAF/gV0FIh8MfPXpUXnjhBVPwrMPg582b5y6I3rt3b67ZIzt06GDmDvrjH/8ozz77rNSvX19mz55d6DmEAABAYPOreYS8gXmEAADwPwE3jxAAAEBxIwgBAADbIggBAADbIggBAADb8qtRY4Hk5JkMOZORVervW6lsGQkLIf8CAKAIQl7yl2+2yUcr9pb6+ybGR8nXj3UkDAEAQBDyntAgh4SXcstMRrZTdh87I9/vOCpdGuZejBYAADtiHiEbzSM0Zu5mmbp0j/S6uqpMuKelt3cHAIASwzxCyKfX1dXM5TebD0uaF+qTAADwNVTN2kiL6rFSq0KknM3Mlvk/Hfb27gAA4HUEIRtxOBzSq0VV8/OcdQe8vTsAAHgdQchmbr86Jwgt+fmoGcIPAICdEYRspl6lstKkaoxkOS35auNBb+8OAABeRRCyIR01pugeAwDYHUHIhm5rUVUcDpGVe07I/lNnvb07AAB4DUHIhqrERkjb2uXNz3PXUzQNALAvgpDN5xT6gtFjAAAbIwjZ1K3NEiQ02CFbDqbIz4dPe3t3AADwCoKQTcVFhkmnBpXMzxRNAwDsiiBkY67RY1+s3y+WZXl7dwAAKHUEIRvr1qiyRIYFS9KJs7J27ylv7w4AAKUupPTfEr4iIixYujdJkFk/7peB76001wNVkENkUIdEGd65rrd3BQDgQwhCNndv25omCKWmZ5ktkE1eslOGdqwjwZqKAAAgCKFtYnlZNqqLnErLDNiToeVPd7+zTJLPZppRck2rxXp7lwAAPoIWIZgJFnUL9MC3cOsR+WHnMYIQAMCNYmnYQvu6Fczlsp3Hvb0rAAAfQhCCrYLQyt0nJDPb6e3dAQD4CIIQbKFRQozERYbKmYxs2bAv2du7AwDwEQQh2EJQkEOuTcxpFVq+i+4xAEAOghBso0O9nCCkBdMAACiCEGyjfZ2cILR6z0lJz8r29u4AAHwAQQi2Ua9StMRHh0t6llN+ZEkRAABBCHbicDikw/nRYz8wjB4AQBCCfecTok4IAEDXGGzG1SK0LumUpGUE9tpqAIBLo0YItlKzfKRUjS0jmdmWKZoGANgbQQi2qxNqXzfe/LyM+YQAwPYIQrAdCqYBAC4EIdi2YHrjvlOSci7T27sDAPAighBsp2pchNSuEClOS2TV7hPe3h0AgBcRhGDrViHmEwIAeyMIwZZcBdMEIQCwN4IQbL3u2JaDKXLyTIa3dwcA4CUh3npjwJsqlg2X+pWiZfuRVPndxz9KfHRYqb13+ahwaVkzTq6pVc7MaaRD+gEA3kEQgm11alDRBKHvd3hhuY2lOReVY8KlZY1yck2tOKkcU0YiQoOlTGiwRIQ
Fm5/DQ4KKHJSCHDnzJellkMNhXk+DHwAgP4dlWVYBt+O8lJQUiY2NleTkZImJieG8BBAdOj9n3QE5l5ldau+p/7clnUyTH/eekp8Opki2Dl0rBWPvaCr929UqlfcCAH/6/qZFCLYVUyZU7rvWe+HgbEa2bNh3Sn5MOmUuT6VlmlB2NtOZc5mRLelZ2WIVMWjpv2300mlZkpHtNMuJrPnlJEEIAApAEAK8RLu/2tWpYLaS8sGyPfL8F5slLb30Wr0AwJ8wagwIYFHhOf/WOZOR5e1dAQCfRBACAlhk2PkglE4QAoCCEISAABbtahGiawwACkQQAgJYZHiwuUylRQgACkQQAmzQIpRGjRAAFIggBASwyLCcFiG6xgCgYAQhwAYtQjqfUEaW09u7AwA+hyAE2GDUmKJ7DAD8OAidOHFC+vfvb6bJjouLkwcffFBSU1Mv+pzOnTubNZc8t2HDhpXaPgPeFhYSJGHBOf+bn8lgUkUA8NuZpTUEHTx4UObPny+ZmZkyePBgGTp0qHz00UcXfd6QIUPkT3/6k/t6ZGRkKewt4DuiwoMlI83JXEIA4K9BaMuWLTJv3jxZtWqVtG7d2tw2ceJEufXWW+W1116TqlWrXvC5GnwSEhJKcW8B3+seO5mWyRB6APDXrrFly5aZ7jBXCFLdunWToKAgWbFixUWfO2PGDImPj5emTZvKqFGjJC0t7aKPT09PNyvWem5AQAyhZ1JFAPDPFqFDhw5JpUqVct0WEhIi5cuXN/ddSL9+/aRWrVqmxWjDhg3yzDPPyLZt2+Tzzz+/4HPGjRsnY8aMKdb9B7yJSRUBwEeD0MiRI+XVV1+9ZLfY5dIaIpdmzZpJlSpVpGvXrrJz506pW7dugc/RVqMnnnjCfV1bhGrUqHHZ+wB4G5MqAoCPBqEnn3xSBg0adNHH1KlTx9T4HDlyJNftWVlZZiRZUep/2rVrZy537NhxwSAUHh5uNiBQRLHwKgD4ZhCqWLGi2S6lffv2curUKVmzZo20atXK3LZo0SJxOp3ucFMY69atM5faMgTYrWuM4fMA4KfF0o0aNZKbb77ZDIVfuXKlLF26VB555BG555573CPG9u/fLw0bNjT3K+3+eumll0x42rNnj8yZM0cGDBggHTt2lObNm3v5iABvrECfxWkHAH8MQq7RXxp0tMZHh81ff/31MmXKFPf9OreQFkK7RoWFhYXJggUL5KabbjLP0264Pn36yNy5c714FID3ZpdmBXoA8NNRY0pHiF1s8sTatWuLZVnu61rgvGTJklLaO8B3RZ/vGmP4PAD4cYsQgCtsEcqgawwA8iIIAbaZUJEgBAB5EYSAABflLpZm0VUAyIsgBNhm+DwtQgCQF0EICHAMnweACyMIAQEuMiynRSiVrjEAyIcgBAQ41hoDgAsjCAE2GT6flpEtTuevc20BAAhCgG1ahFRaJiPHAMATLUJAgCsTGiRBjpyfWW8MAHIjCAEBzuFwSNT57jGCEADkRhACbIBJFQGgYAQhwEaTKrICPQDkRhACbIAh9ABQMIIQYKtJFVlmAwA8EYQAW7UIMXweADwRhABbFUvTIgQAnghCgA24Zpc+w3pjAJALQQiwgejzo8bOZNAiBACeCEKAjVqEKJYGgNwIQoCdiqWpEQKAXAhCgK0mVGTUGAB4IggBNsCEigBQMIIQYANRLLoKAAUiCAE26ho7w4SKAJALQQiwUdcYEyoCQG4EIcAGGD4PAAUjCAE2W2vMsixv7w4A+AyCEGCjGqFspyXpWU5v7w4A+AyCEGADUedHjSnqhADgVwQhwAaCgxwSEXp+5BiTKgKAG0EIsIkoFl4FgHwIQoBNRDGEHgDyIQgBNsEQegDIjyAE2ET0+a4xHUIPAMhBEAJsghYhAMiPIATYbVLF9Cxv7woA+AyCEGC7UWN0jQGAC0EIsFnXGBMqAsCvCEKATbACPQDkRxACbLbeWCozSwOAG0EIsN0K9BRLA4ALQQiwCYbPA0B
+BCHAJphQEQDyIwgBNsFaYwCQH0EIsNvweWqEAMCNIATYbvg8EyoCgAtBCLCJyDDX8HlGjQGAC0EIsFmLUEaWUzKznd7eHQDwCQQhwGYTKqo0uscAwCAIATYRHhIsocEO8zMF0wCQgyAE2AhD6AEgN4IQYCNR7iH0jBwDAEUQAmwk6nyd0BlGjgGAQRACbIT1xgDAT4PQ2LFjpUOHDhIZGSlxcXGFeo5lWfLCCy9IlSpVJCIiQrp16ybbt28v8X0FfBUr0AOAnwahjIwM6du3rwwfPrzQz/nzn/8sb731lkyePFlWrFghUVFR0r17dzl37lyJ7ivg+5MqUiMEACqnctIPjBkzxlxOmzat0K1Bb775pvzxj3+UXr16mdumT58ulStXltmzZ8s999xTovsL+HSLEDVCAOBfLUJFtXv3bjl06JDpDnOJjY2Vdu3aybJlyy74vPT0dElJScm1AYGC4fMAYJMgpCFIaQuQJ73uuq8g48aNM4HJtdWoUaPE9xUo7dmlGT4PAD4QhEaOHCkOh+Oi29atW0t1n0aNGiXJycnuLSkpqVTfHyhJ0a55hOgaAwDv1wg9+eSTMmjQoIs+pk6dOpf12gkJCeby8OHDZtSYi16/+uqrL/i88PBwswGBKPJ8jRAr0AOADwShihUrmq0kJCYmmjC0cOFCd/DReh8dPVaUkWdAIIk+3zWWxszSAOBfNUJ79+6VdevWmcvs7Gzzs26pqanuxzRs2FBmzZplftZutccee0xefvllmTNnjmzcuFEGDBggVatWld69e3vxSADvYUJFALiCFiEdkq41M5UqVZIyZcpIadKJEd9//3339ZYtW5rLxYsXS+fOnc3P27ZtM3U9Lk8//bScOXNGhg4dKqdOnZLrr79e5s2bV+r7DvgKJlQEgNwclqabQnI6nSZEbN68WerXry92oN1pOnpMA1ZMTIy3dwe4Iit3n5C73lkmifFRsvipnH9AAICdv7+L1DUWFBRkAtDx48eLYx8BeGlmaUaNAcBl1giNHz9e/vCHP8imTZuK+lQAPtI1RhACgMscNaYFx2lpadKiRQsJCwszi5l6OnHiRFFfEoAXJlR0Oi0JCnJw7gHYWpGDkK7fBcC/W4TU2cxs95IbAGBXRf4rOHDgwJLZEwAlLiI0WBwOHQGa0z1GEAJgd1f0z8Fz585JRkZGrtsYWQX4Lp1fKyosxMwszXpjAHAZxdI6L88jjzxi5hKKioqScuXK5doA+LYoV50Q640BQNGDkE5SuGjRIpk0aZJZk+uf//ynjBkzxszYPH36dE4p4OOiWHgVAC6/a2zu3Lkm8OhszoMHD5YbbrhB6tWrJ7Vq1ZIZM2ZI//79i/qSAEqRqy7oTEYW5x2A7RW5RUiHx7tWhNd6INdweV2+4rvvvrP9CQX8ZVLF1PRsb+8KAPhfENIQtHv3bvcipzNnznS3FMXFxRX/HgIomfXGqBECgKIHIe0OW79+vfl55MiR8re//c2sP/b444+bGacB+LbI80FIR44BgN0VuUZIA49Lt27dZOvWrbJmzRpTJ9S8efPi3j8AxSz6/KixtAy6xgCgUC1C5cuXl2PHjpmfH3jgATl9+rT7Pi2SvvPOOwlBgJ+IYtQYABQtCOmkibqcvXr//ffNRIoA/LtrjFFjAFDIrrH27dtL7969pVWrVmJZlvz+97/Pt9iqy3vvvcd5Bfyga+wMo8YAoHBB6MMPP5Q33nhDdu7caaboT05OplUI8FOR57vGKJYGgEIGocqVK8v48ePNz4mJifLBBx9IhQoVOH+APw+fZ0JFACj6qDHXHEIA/BMTKgLAFcwjBMC/MaEiAPyKIATYda0xJlQEAIIQYDdRrlFjTKgIAAQhwM4tQjodBgDYWZGLpVV2drbMnj1btmzZYq43adJEbr/9dgkOzvmXJgDfHz6f5bQkPcspZUL5/xaAfRU5CO3YsUN69Ogh+/btk6uuusrcNm7cOKlRo4Z89dVXUrdu3ZLYTwD
FJCrs1+Cj640RhADYWZGLpXVW6Tp16khSUpKsXbvWbHv37jXzC+l9AHxbSHCQhIfk/K9PwTQAuytyi9CSJUtk+fLlZiFWF51cUSdcvO6664p7/wCU0BD69KwM1hsDYHtFbhEKDw/Ptfq8S2pqqoSFhdn+hAL+gCH0AHCZLUI9e/aUoUOHyrvvvitt27Y1t61YsUKGDRtmCqYB+M/s0g99sEbCQ4q/WLpRlbIy6b5WEhrMVGUAAiwIvfXWWzJw4ECzIn1oaKi5LSsry4SgCRMmlMQ+AihmjavEyNZDp+VYakaJnNv9p87K6j0npX1d1iQE4Nsc1mVOJKKjx1zD5xs1aiT16tWTQJSSkiKxsbGSnJwsMTEx3t4doFhkZTtly8HTkl0C8wi9tXC7LNp6RB7rVl8e69ag2F8fAIrz+/uy5hFSGnx00zmFNm7cKCdPnpRy5cpd7ssBKOWRY82qx5bIa3dtVMkEoRW7TpTI6wNAcSpyB/5jjz1m6oOUhqBOnTrJNddcY+YR+vbbb4t15wD4n3aJOd1ha/eelPSsbG/vDgAUbxD617/+JS1atDA/z507V3bt2iVbt26Vxx9/XJ577rmivhyAAFO3YpTER4ebWavXJyV7e3cAoHiD0LFjxyQhIcH8/J///EfuuusuadCggTzwwAOmiwyAvTkcDmlXJ2eeseW7jnt7dwCgeINQ5cqV5aeffjLdYvPmzZPf/OY35va0tDTWGgNgXJuYE4RW7CYIAfBtRS6WHjx4sGkFqlKlivmXX7du3dxzCTVs2LAk9hGAn7m2Tk6d0JpfTkpGllPCzi/pAQB+H4RefPFFadq0qVlrrG/fvmamaaUrz48cObIk9hGAn6lXKVoqRIXJ8TMZsmHfKWld+9cleQDAl1zW8Pnf/va3+W7TSRYBQGlrcdvE8vLfTYdkxe4TBCEAPov2agAl2j1GwTQAX0YQAlAiXCPHtE4oM9vJWQbgkwhCAEpEg0plJS4yVNIysmXjfuYTAuCbCEIASuaPS5BD2p0fRk/3GICACUK6pMb06dPl7NmzJbNHAAJuuQ3WHQMQMEGoZcuW8tRTT5nZpYcMGSLLly8vmT0DEDB1Qqv3nDAr3gOA3wehN998Uw4cOCBTp06VI0eOSMeOHaVx48by2muvyeHDh0tmLwH4pUYJMRIbESpnMrJl04EUb+8OABRPjVBISIjceeed8sUXX8i+ffukX79+8vzzz5sV6Hv37i2LFi26nJcFEIB1Qm3OT6a4gnXHAARasfTKlStl9OjR8vrrr0ulSpVk1KhREh8fLz179jTdZwBwLQuwAgikmaW1O+yDDz4wXWPbt2+X2267TT7++GPp3r27mU1WDRo0SG6++WbTXQbA3lwTK67ec9LUCYUEM1gVgB8HoerVq0vdunXlgQceMIGnYsWK+R7TvHlzadOmTXHtIwA/1qhKjJQtEyKnz2XJTwdTpHn1OG/vEgBcfhBauHCh3HDDDRd9TExMjCxevLioLw0gAAWfrxNatPWIGUZPEALgS4rcRn2pEAQAeVEnBMBX0VkPoMR1qBtvLn/YeVzSMrI44wB8BkEIQIlrUjVGapSPkLOZ2aaLDAB8BUEIQInTEaU9mlU1P3+5/iBnHID/BiFdZyw9PT3f7RkZGea+kjJ27Fjp0KGDREZGSlxc4Uad6Kg2/QPsuemwfgClr2fzKuZy8bYjkppO9xgAPw1CgwcPluTk5Hy3nz592txXUjRo9e3bV4YPH16k52nwOXjwoHvTOY8AeKd7LDE+StKznLLgJ5bjAeCnw+cty3JPnOhJl9qIjY2VkjJmzBhzOW3atCI9Lzw83CwQC8C79O+GtgpNXLRDvtxwQHq3rMZHAsB/gpCuOu/qXuratatZb8wlOztbdu/e7ZPdTt9++61Z/qNcuXLSpUsXefnll6VChZyZbgGUrp7Nq5ogtOTno5J8NtMsyAoAfhGEdDFVtW7dOrOcRnR0tPu+sLAwqV27tvTp00d8iQY
zXRw2MTFRdu7cKc8++6zccsstsmzZMgkODi7wOVr/5FkDlZLCitlAcbkqoazUrxQt24+kyjebD0nf1jU4uQD8Iwjp4qpKA88999xjupyu1MiRI+XVV1+96GO2bNkiDRs2vKzX1/10adasmVn6Q5cH0VYibdUqyLhx49zdcABKplXojQU/y5cbDhKEAPhfsbR2Lx09ejTXCvSPPfaYTJkypchv/uSTT5qgc7GtTp06Ulz0teLj42XHjh0XfMyoUaNMMbhrS0pKKrb3ByDSs0XO6LGlO47JyTMZnBIA/lUs3a9fPxk6dKjcf//9cujQIenWrZs0bdpUZsyYYa6/8MILhX4tXbC1oEVbS4oWdB8/flyqVMn5Q1wQbekqjtYuAAWrWzHaLMS65WCKzNt8SO5tW5NTBcB/WoQ2bdokbdu2NT/PnDnTdDn98MMPJggVdURXUezdu9fUJ+mlFmfrz7qlpqa6H6NdaLNmzTI/6+1/+MMfZPny5bJnzx6zWGyvXr2kXr16psYJgPfnFNLRYwDgV0EoMzPT3WKyYMECuf32290hROfpKSna0qQj17RWSUOO/qzb6tWr3Y/Ztm2be44jLYbesGGD2b8GDRrIgw8+KK1atZL//e9/tPgAXnZb85xZppftPC5HT+efoBUASovD0omBiqBdu3Zy4403So8ePeSmm24yLS4tWrQwl7/97W9N91Mg0VFjOj+SBqyYmBhv7w4QMG5/+3vZsC9ZXurVRO5vX9vbuwMgwBT2+7vILUI6yuudd96Rzp07y7333mtCkJozZ467ywwACts9NncDa48B8KMWIaU1Opq0dJJCF63D0XXAdPLCQEKLEFAy9p86K9eNXyQ6Uf33z3SRymWLf5BCkMMhQUH5Z8IHEPhSCtkiVORRY0qz05o1a8wkhTqKrGzZsmZSRQ1CAFAY1eIi5JqacbJ27ykTiEpCTJkQ+fSh9maUGgAUS9fYL7/8YkaK6QisESNGuOcU0i6zp556qqgvB8DGBl2XKCXZYJNyLkvmbTpUcm8AwO8VuUXo0UcfldatW8v69etzrdl1xx13yJAhQ4p7/wAEsNtbVJVujSpJZlaRe+gv6eNVe2X8f7fKpv05I0kBoFiCkA4/13mDtCvMky69sX///qK+HACbiwwLEcn956RYtKmdU8O4gSAEoDi7xpxOpymWzkuHzWutEAD4gsZVYk23m85TdDjlnLd3B0CgBCGdO+jNN990X3c4HGaCQ53o8NZbby3u/QOAyxIRFiz1K+X842zjPrrHABRTEHr99ddl6dKl0rhxYzl37pwZNebqFrvUSvIAUJqaVos1l3SPASi2GqHq1aubQulPP/3UXGprkC5f0b9/f4mIiCjqywFAiWlWLUb+vVYomAZQvPMIhYSEmOCjGwD4qmbV48ylLuWh859pVz4AXFEQOn78uHvYfFJSkvzjH/+Qs2fPym233SYdO3Ys6ssBQIlpXCXGFEwfS9WC6XRJiC3D2QZweTVCGzduNLVAuoSGrjS/bt06adOmjbzxxhsyZcoU6dKli8yePbuwLwcApVIw3aDy+YJphtEDuJIg9PTTT5sZpb/77juz4GrPnj3NCvS6hsfJkyfloYcekvHjx3OSAfhkwfTGfae8vSsA/DkIrVq1SsaOHSvXXXedvPbaa3LgwAF5+OGHJSgoyGy/+93vZOvWrSW7twBQRM1cQYgWIQBXEoROnDghCQkJ5ufo6GiJiorKtfq8/nz69OnCvhwAlIpm1V1BKMUUTAPAZc8jlHfEBSMwAPhDwXRwkMMUTB9ihmkAVzJqbNCgQRIeHm5+1skUhw0bZlqGVHp6elFeCgBKRZlQnWE6WrYeOm1mmK4Sy3xnAC4jCA0cODDX9fvuuy/fYwYMGFDYlwOAUq0TMkFof7Lc1CSnix8AihSEpk6dyhkD4Ld1Qp+t2UfBNIArX2sMAPx15Nim/TkzTAOAC0EIQMBr5C6YzpCDyee8vTsAfAhBCIAtCqaZYRp
AQQhCAGyzEr3SkWMA4EIQAmALzDANoCAEIQC20Kx6nLmkYBqAJ4IQAFtomFBWQoIccvxMhhygYBrAeQQhAPYrmKZOCMB5BCEANqwTOuXtXQHgIwhCAGyjqcdK9ABQ5EVXAcCfNT/fIrT2l5Py3KyNV/RaNcpHytAb6khQkKOY9g6ANxCEANjGVQllJSI0WFLTs2TGir1X/Ho1y0fKrc2qFMu+AfAOghAAWxVMTx3cRlbsOnFFr7N+3ylZtPWITPthD0EI8HMEIQC2cm2dCma7EoeSz8l1ry6SlbtPyE8HUqRx1ZxZqwH4H4qlAaCIEmLLyC1NE8zP7/+wh/MH+DGCEABchkEdapvL2ev2y4kzGZxDwE8RhADgMrSqVU6aVouR9CynfLLqyguvAXgHQQgALoPD4ZBBHRLNzx8u+0Wysp2cR8APEYQA4DL1bF5FKkSFmbXL5v90mPMI+CGCEABcwXD8e9vWND9PpWga8EsEIQC4AvddW0uCgxxmKP3mA8mcS8DPEIQA4AowlB7wbwQhALhCg6/LGUr/xboDDKUH/AxBCACu0DU1fx1K//FKhtID/oQgBADFOJT+L19vk7smL5OPVuyV5LRMzi3g4xyWZVne3glflpKSIrGxsZKcnCwxMawnBKBgGVlOeWLmOvlq40Fx/VUNCw6SLg0rye1XVzW1RHk5RCQ0OEjCQoLMZWiwwzwnKMhh7isJcZFhprgbCHQphfz+JggV04kEANeCrF+s2y+zftwvWw+d9rmTUrZMiFl0tkNd3eKlQeVo06IFBBqCUCmfSADIa8vBFLMW2eKtR+RcZv6Zp52WJVnZlmRmOyUj25lzmeUUZym208dHh0mb2uUlpkyo+ALNZNoiFuxwmJarIHOZ0/14JbSl7e42NaRG+chi21f4NoJQKZ9IAPB12U7LzHW0dMdx+WHnMVm150SBAS1Q9bmmurx+Vwtv7wZ87Ps7pLR2CADgXdrC0rx6nNmGd64r6VnZsj4pWdYlnZTMbN8oF9WyVV22LduyxOm03JdXYvuRVFm09YgcOX2u2PYTgYMgBAA2FR4SLG0Ty5stkH2z+ZAJQinnsry9K/BBDJ8HAAS0mIic+qfT55jOAPkRhAAAAU1HyqmUs7QIIT+CEAAgoLlGxNEihIIQhAAAtghCugSKFogDnghCAICAFn2+a0ydpmAa/hiE9uzZIw8++KAkJiZKRESE1K1bV0aPHi0ZGRkXfd65c+dkxIgRUqFCBYmOjpY+ffrI4cOHS22/AQC+MW1AdLirToiCafhhENq6das4nU555513ZPPmzfLGG2/I5MmT5dlnn73o8x5//HGZO3eufPbZZ7JkyRI5cOCA3HnnnaW23wAA3yqYpkUIfjmP0M0332w2lzp16si2bdtk0qRJ8tprrxX4HJ1J8t1335WPPvpIunTpYm6bOnWqNGrUSJYvXy7XXnttqe0/AMD7dUIHk88RhOCfLUIXCjrly194ErA1a9ZIZmamdOvWzX1bw4YNpWbNmrJs2bILPi89Pd1My+25AQACZAg9cwkhEILQjh07ZOLEifLQQw9d8DGHDh2SsLAwiYuLy3V75cqVzX0XMm7cOLM2iWurUaNGse47AKD0MakifDIIjRw50qwofLFN64M87d+/33ST9e3bV4YMGVLs+zRq1CjT2uTakpKSiv09AACli0kV4ZM1Qk8++aQMGjTooo/ReiAXLXa+8cYbpUOHDjJlypSLPi8hIcGMKjt16lSuViEdNab3XUh4eLjZAACBg0kV4ZNBqGLFimYrDG0J0hDUqlUrU/QcFHTxxix9XGhoqCxcuNAMm1daYL13715p3759sew/AMDfaoRYZgN+WCOkIahz586m0FlHiR09etTU+XjW+uhjtBh65cqV5rrW9+jcQ0888YQsXrzYFE8PHjzYhCBGjAGAPWuEKJaGXw6fnz9/vimQ1q169eq57rMsy1zqCDFt8UlLS3Pfp/MNacuRtgjpaLDu3bvL3//
+91LffwCAd1EjhAtxWK4kgQLp8HltXdLC6ZiYGM4SAPihuesPyO8+/lHaJZaXTx+iPMIOUgr5/e0XXWMAAFwJaoRwIQQhAEDAYx4hXAhBCAAQ8GJco8ZYdBV5EIQAALaZRyg1PUucTkpj8SuCEAAg4JU9H4Q0A53JYC4h/IogBAAIeGVCgyQ02GF+Ps2kivBAEAIABDxdu9LVKsSkivBEEAIA2GoIPS1C8EQQAgDYqmCakWPwRBACANgCLUIoCEEIAGCvFqFzmd7eFfgQghAAwBZoEUJBCEIAAFsts0GLEDwRhAAA9lp49SwTKuJXBCEAgK1qhE5TIwQPBCEAgL1ahJhZGh4IQgAAW3DVCNEiBE8EIQCAzWqEGD6PXxGEAAA2qxGiWBq/IggBAGyBCRVREIIQAMAWYiJyusbOZTolI8vp7d2BjyAIAQBsITo8JwgpCqbhQhACANhCSHCQRIUFm5+pE4ILQQgAYBtlWXgVeRCEAAC2qxOiRQguBCEAgP1ahJhLCOcRhAAAthFzflJFWoTgQhACANgGNULIiyAEALANFl5FXgQhAIDtFl6lRgguBCEAgO1ahKgRggtBCABgw4VXWYEeOQhCAAAb1ggRhJCDIAQAsF2NEF1jcCEIAQBsN48QLUJwIQgBAGxYI5Tl7V2BjyAIAQBsN6GiBiHLsry9O/ABBCEAgO0WXc12WpKWke3t3YEPIAgBAGwjIjRYgoMc5mfqhKAIQgAA23A4HCy8ilwIQgAAey68epa5hEAQAgDYtE6IkWNQtAgBAGylbPj5FiFmlwZBCABg1xahFOYSAkEIAGA31AjBE11jAABbYXZpeCIIAQBshRXo4YkgBACwZRBi1BgUQQgAYCsxEcwjhF8RhAAAthLjbhFiQkUQhAAANi2WZvg8FC1CAABbDp+nRQiKIAQAsBWW2IAnghAAwJYtQmkZ2ZKZ7fT27sDLCEIAAFsOn1epLLNhe34RhPbs2SMPPvigJCYmSkREhNStW1dGjx4tGRkZF31e586dxeFw5NqGDRtWavsNAPA9ocFBEhEabH5m4VX8Got92NatW8XpdMo777wj9erVk02bNsmQIUPkzJkz8tprr130ufq4P/3pT+7rkZGRpbDHAABfrxM6m5nNpIrwjyB08803m82lTp06sm3bNpk0adIlg5AGn4SEhFLYSwCAP9UJHU5Jl5SzzCVkd37RNVaQ5ORkKV++/CUfN2PGDImPj5emTZvKqFGjJC0t7aKPT09Pl5SUlFwbACAwJ1VkLiH4RYtQXjt27JCJEydesjWoX79+UqtWLalataps2LBBnnnmGdOS9Pnnn1/wOePGjZMxY8aUwF4DAHxt5Bg1QvBqi9DIkSPzFTPn3bQ+yNP+/ftNN1nfvn1N/c/FDB06VLp37y7NmjWT/v37y/Tp02XWrFmyc+fOCz5HW420tcm1JSUlFdvxAgB8a70xFl6FV1uEnnzySRk0aNBFH6P1QC4HDhyQG2+8UTp06CBTpkwp8vu1a9fO3aKkI88KEh4ebjYAQOAPoadGCF4NQhUrVjRbYWhLkIagVq1aydSpUyUoqOiNWevWrTOXVapUKfJzAQCBt94YLULwi2JpDUE6J1DNmjVNXdDRo0fl0KFDZvN8TMOGDWXlypXmunZ/vfTSS7JmzRozD9GcOXNkwIAB0rFjR2nevLkXjwYA4DMtQqxAb3t+USw9f/58052lW/Xq1XPdZ1mWuczMzDSF0K5RYWFhYbJgwQJ58803zXxDNWrUkD59+sgf//hHrxwDAMAXa4QYPm93fhGEtI7oUrVEtWvXdocipcFnyZIlpbB3AAC/HT5/NsvbuwIv84uuMQAASqJr7HQ6LUJ2RxACANi2WJoWIRCEAAC2nVCRGiEQhAAAtlx01bXEhmd9KeyHIAQAsG2LULbTMqvQw74IQgAA24kKC5YgR87PTKpobwQhAIDt6FqW7oVXzzJyzM4IQgAAsXudEOzLLyZUBACguJUN1xa
hs7L72BmpHMNi294UFxkm0eHeiSQEIQCArVuEnvpsvbd3xfZeuaOZ9GtX0yvngSAEALClns2ryqb9KZKZ7fT2rthesBcLdRwWEyhcVEpKisTGxkpycrLExMSU1ucCAABK4fubYmkAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbBCEAAGBbId7eAV9nWZa5TElJ8fauAACAQnJ9b7u+xy+EIHQJp0+fNpc1atQo7LkHAAA+9D0eGxt7wfsd1qWiks05nU45cOCAlC1bVhwOR7EmVQ1XSUlJEhMTI4HKDsfJMQYGPsfAwOcYGFKK4btD442GoKpVq0pQ0IUrgWgRugQ9edWrV5eSoh9woAYEux0nxxgY+BwDA59jYIi5wu+Oi7UEuVAsDQAAbIsgBAAAbIsg5CXh4eEyevRocxnI7HCcHGNg4HMMDHyOgSG8FL87KJYGAAC2RYsQAACwLYIQAACwLYIQAACwLYIQAACwLYKQl/ztb3+T2rVrS5kyZaRdu3aycuVK8Vffffed3HbbbWb2Tp19e/bs2flm93zhhRekSpUqEhERId26dZPt27eLPxk3bpy0adPGzDBeqVIl6d27t2zbti3XY86dOycjRoyQChUqSHR0tPTp00cOHz4s/mLSpEnSvHlz9wRm7du3l//+978Bc3wFGT9+vPmdfeyxxwLmOF988UVzTJ5bw4YNA+b4XPbv3y/33XefOQ79u9KsWTNZvXp1wPzd0e+HvJ+jbvrZBcrnmJ2dLc8//7wkJiaaz6hu3bry0ksv5VobrFQ+R11iA6Xrk08+scLCwqz33nvP2rx5szVkyBArLi7OOnz4sF9+FP/5z3+s5557zvr888/1t9eaNWtWrvvHjx9vxcbGWrNnz7bWr19v3X777VZiYqJ19uxZy190797dmjp1qrVp0yZr3bp11q233mrVrFnTSk1NdT9m2LBhVo0aNayFCxdaq1evtq699lqrQ4cOlr+YM2eO9dVXX1k///yztW3bNuvZZ5+1QkNDzTEHwvHltXLlSqt27dpW8+bNrUcffdR9u78f5+jRo60mTZpYBw8edG9Hjx4NmONTJ06csGrVqmUNGjTIWrFihbVr1y7r66+/tnbs2BEwf3eOHDmS6zOcP3+++fu6ePHigPkcx44da1WoUMH68ssvrd27d1ufffaZFR0dbU2YMKFUP0eCkBe0bdvWGjFihPt6dna2VbVqVWvcuHGWv8sbhJxOp5WQkGD95S9/cd926tQpKzw83Pr4448tf6V/pPRYlyxZ4j4mDQ36P7LLli1bzGOWLVtm+aty5cpZ//znPwPu+E6fPm3Vr1/ffLl06tTJHYQC4Tg1CLVo0aLA+wLh+NQzzzxjXX/99Re8PxD/7ujvaN26dc2xBcrn2KNHD+uBBx7Iddudd95p9e/fv1Q/R7rGSllGRoasWbPGNO95rmem15ctWyaBZvfu3XLo0KFcx6trv2h3oD8fb3JysrksX768udTPNDMzM9dxandEzZo1/fI4tcn6k08+kTNnzpguskA7Pu1S6NGjR67jUYFynNp1oF3VderUkf79+8vevXsD6vjmzJkjrVu3lr59+5qu6pYtW8o//vGPgP27o98bH374oTzwwAOmeyxQPscOHTrIwoUL5eeffzbX169fL99//73ccsstpfo5suhqKTt27Jj5kqlcuXKu2/X61q1bJdDoL7Eq6Hhd9/kbp9Npakquu+46adq0qblNjyUsLEzi4uL8+jg3btxogo/WH2jdwaxZs6Rx48aybt26gDg+pQFv7dq1smrVqnz3BcLnqF8S06ZNk6uuukoOHjwoY8aMkRtuuEE2bdoUEMendu3aZWrannjiCXn22WfNZ/n73//eHNvAgQMD7u+O1l2eOnVKBg0aZK4Hyuc4cuRIs8q8hrjg4GDz3Th27FgT3lVpfY4
EIeAyWhP0S0X/5RJo9MtTQ4+2eP3rX/8yXypLliyRQJGUlCSPPvqozJ8/3wxUCESuf00rLX7XYFSrVi2ZOXOmKTYNBPqPEW0ReuWVV8x1bRHS/ycnT55sfmcDzbvvvms+V23lCyQzZ86UGTNmyEcffSRNmjQxf3v0H5l6nKX5OdI1Vsri4+NN8s1b3a/XExISJNC4jilQjveRRx6RL7/8UhYvXizVq1d3367Hos3X+q82fz5O/VdmvXr1pFWrVmakXIsWLWTChAkBc3zapXDkyBG55pprJCQkxGwa9N566y3zs/5LMxCO05O2GjRo0EB27NgRMJ+jjiDSlkpPjRo1cncBBtLfnV9++UUWLFgg//d//+e+LVA+xz/84Q+mVeiee+4xo/7uv/9+efzxx83fntL8HAlCXvii0S8Z7Rf1/NeNXtcuiUCjwyL1F9bzeLUpdMWKFX51vFoHriFIu4oWLVpkjsuTfqahoaG5jlOH1+sfZn86zrz0dzM9PT1gjq9r166m+0//5enatGVBm+JdPwfCcXpKTU2VnTt3mvAQKJ+jdkvnnb5C60y05SuQ/u6oqVOnmjoorWlzCZTPMS0tzdTIetKGAv27U6qfY7GVXaNIw+e16n3atGnWTz/9ZA0dOtQMnz906JBfnkUdgfPjjz+aTX+l/vrXv5qff/nlF/fwRz2+L774wtqwYYPVq1cvvxrGqoYPH26GcH777be5hrSmpaW5H6PDWXVI/aJFi8xw1vbt25vNX4wcOdKMgtNhrPo56XWHw2F98803AXF8F+I5aiwQjvPJJ580v6f6OS5dutTq1q2bFR8fb0Y6BsLxuaY+CAkJMcOvt2/fbs2YMcOKjIy0PvzwQ/djAuHvjo4o1s9KR8nlFQif48CBA61q1aq5h8/rFCz6u/r000+X6udIEPKSiRMnml9inU9Ih9MvX77c8lc6r4UGoLyb/pK7hkA+//zzVuXKlU0A7Nq1q5mnxp8UdHy66dxCLvo/5sMPP2yGnOsf5TvuuMOEJX+hw1h1bhb9naxYsaL5nFwhKBCOr7BByN+P8+6777aqVKliPkf9ktHrnvPr+PvxucydO9dq2rSp+ZvSsGFDa8qUKbnuD4S/Ozo3kv6dKWi/A+FzTElJMf/v6XdhmTJlrDp16pg56dLT00v1c3Tof4qvfQkAAMB/UCMEAABsiyAEAABsiyAEAABsiyAEAABsiyAEAABsiyAEAABsiyAEAABsiyAEAJdQu3ZtefPNNzlPQAAiCAHwKYMGDZLevXubnzt37mxWoy4t06ZNM4uU5rVq1SoZOnRoqe0HgNITUorvBQBeoSt164LHl6tixYrFuj8AfActQgB8tmVoyZIlMmHCBHE4HGbbs2ePuW/Tpk1yyy23SHR0tFSuXFnuv/9+OXbsmPu52pL0yCOPmNak+Ph46d69u7n9r3/9qzRr1kyioqKkRo0a8vDDD5vV2dW3334rgwcPluTkZPf7vfjiiwV2jekq37169TLvHxMTI3fddZccPnzYfb8+7+qrr5YPPvjAPDc2NlbuueceOX36dKmdPwCFQxAC4JM0ALVv316GDBkiBw8eNJuGl1OnTkmXLl2kZcuWsnr1apk3b54JIRpGPL3//vumFWjp0qUyefJkc1tQUJC89dZbsnnzZnP/okWL5Omnnzb3dejQwYQdDTau93vqqafy7ZfT6TQh6MSJEyaozZ8/X3bt2iV33313rsft3LlTZs+eLV9++aXZ9LHjx48v0XMGoOjoGgPgk7QVRYNMZGSkJCQkuG9/++23TQh65ZVX3Le99957JiT9/PPP0qBBA3Nb/fr15c9//nOu1/SsN9KWmpdfflmGDRsmf//738176XtqS5Dn++W1cOFC2bhxo+zevdu8p5o+fbo0adLE1BK1adPGHZi05qhs2bLmurZa6XPHjh1bbOcIwJWjRQiAX1m/fr0sXrzYdEu5toYNG7pbYVxatWqV77kLFiyQrl27SrVq1UxA0XB
y/PhxSUtLK/T7b9myxQQgVwhSjRs3NkXWep9n0HKFIFWlShU5cuTIZR0zgJJDixAAv6I1Pbfddpu8+uqr+e7TsOGidUCetL6oZ8+eMnz4cNMqU758efn+++/lwQcfNMXU2vJUnEJDQ3Nd15YmbSUC4FsIQgB8lnZXZWdn57rtmmuukX//+9+mxSUkpPB/wtasWWOCyOuvv25qhdTMmTMv+X55NWrUSJKSkszmahX66aefTO2StgwB8C90jQHwWRp2VqxYYVpzdFSYBpkRI0aYQuV7773X1ORod9jXX39tRnxdLMTUq1dPMjMzZeLEiaa4WUd0uYqoPd9PW5y0lkffr6Aus27dupmRZ/3795e1a9fKypUrZcCAAdKpUydp3bp1iZwHACWHIATAZ+moreDgYNPSonP56LD1qlWrmpFgGnpuuukmE0q0CFprdFwtPQVp0aKFGT6vXWpNmzaVGTNmyLhx43I9RkeOafG0jgDT98tbbO3q4vriiy+kXLly0rFjRxOM6tSpI59++mmJnAMAJcthWZZVwu8BAADgk2gRAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAtkUQAgAAYlf/D+sCXEaqwag8AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(best_y_per_iteration.values)\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"Best y so far\")\n", + "plt.title(\"Best y per iteration\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Iteration 0 ---\n", + " n_experiments: 66\n", + " acqf probe at fixed point: -45.9420604508\n", + " acq_val from get_candidate: -5.48109119224731\n", + " new y value: [-2.97790422]\n", + " experiments: 66 -> 67\n", + " lengthscale[0:3]: [0.2937835994912956, 0.4041666833510878, 4.41966538654436]\n", + "\n", + "--- Iteration 1 ---\n", + " n_experiments: 67\n", + " acqf probe at fixed point: -45.9419739465\n", + " acq_val from get_candidate: -5.645335943157333\n", + " new y value: [-2.99606393]\n", + " experiments: 67 -> 68\n", + " lengthscale[0:3]: [0.30036564796021076, 0.3919503617733176, 4.339199268257635]\n", + "\n", + "--- Iteration 2 ---\n", + " n_experiments: 68\n", + " acqf probe at fixed point: -45.9430990850\n", + " acq_val from get_candidate: -7.4356479767381565\n", + " new y value: [-0.00015728]\n", + " experiments: 68 -> 69\n", + " lengthscale[0:3]: [0.3312261463319473, 0.39313917655484865, 3.5347083776042476]\n", + "\n", + "--- Iteration 3 ---\n", + " n_experiments: 69\n", + " acqf probe at fixed point: -45.9446675469\n", + " acq_val from get_candidate: -6.929772202048997\n", + " new y value: [-0.00067149]\n", + " experiments: 69 -> 70\n", + " lengthscale[0:3]: [0.29295722403909413, 0.4010253404562717, 4.663433093634748]\n", + "\n", + "--- Iteration 4 ---\n", + " n_experiments: 70\n", + " acqf probe at fixed point: -45.9402736516\n", + " acq_val from get_candidate: -7.836574244144488\n", + " new y value: [-0.0006946]\n", + " experiments: 70 -> 71\n", + " lengthscale[0:3]: 
[0.30060401163820644, 0.3944575465842801, 4.207188642906171]\n", + "\n", + "--- Iteration 5 ---\n", + " n_experiments: 71\n", + " acqf probe at fixed point: -45.9433603961\n", + " acq_val from get_candidate: -2.6929366641172288\n", + " new y value: [-3.11049333]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-07 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-06 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-05 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-04 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-03 to the diagonal\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " experiments: 71 -> 72\n", + " lengthscale[0:3]: [0.30018737965437287, 0.4054900134182153, 13.665555661586506]\n", + "\n", + "--- Iteration 6 ---\n", + " n_experiments: 72\n", + " acqf probe at fixed point: -46.0112759163\n", + " acq_val 
from get_candidate: -6.4480336673718055\n", + " new y value: [-0.00106805]\n", + " experiments: 72 -> 73\n", + " lengthscale[0:3]: [0.2953749217570681, 0.4102799011484572, 13.598938811874865]\n", + "\n", + "--- Iteration 7 ---\n", + " n_experiments: 73\n", + " acqf probe at fixed point: -46.0108940378\n", + " acq_val from get_candidate: -39.15175009566557\n", + " new y value: [-3.00207657]\n", + " experiments: 73 -> 74\n", + " lengthscale[0:3]: [0.2970526387781363, 1.0787512449468613, 3.351634232309813]\n", + "\n", + "--- Iteration 8 ---\n", + " n_experiments: 74\n", + " acqf probe at fixed point: -46.0192865010\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " acq_val from 
get_candidate: -12.164814561040743\n", + " new y value: [-0.01646253]\n", + " experiments: 74 -> 75\n", + " lengthscale[0:3]: [0.39385986766547937, 0.3448380489990462, 8.83546958707795]\n", + "\n", + "--- Iteration 9 ---\n", + " n_experiments: 75\n", + " acqf probe at fixed point: -46.0094130418\n", + " acq_val from get_candidate: -5.745335865515177\n", + " new y value: [-0.17019157]\n", + " experiments: 75 -> 76\n", + " lengthscale[0:3]: [0.3044109295808983, 0.3922305605059308, 14.09836815838481]\n", + "\n", + "--- Iteration 10 ---\n", + " n_experiments: 76\n", + " acqf probe at fixed point: -46.0100629356\n", + " acq_val from get_candidate: -45.89948778074979\n", + " new y value: [-0.16469201]\n", + " experiments: 76 -> 77\n", + " lengthscale[0:3]: [0.3055344125405825, 0.40015436028875384, 15.101534607314072]\n", + "\n", + "--- Iteration 11 ---\n", + " n_experiments: 77\n", + " acqf probe at fixed point: -46.0129252589\n", + " acq_val from get_candidate: -45.86736603648621\n", + " new y value: [-0.3761452]\n", + " experiments: 77 -> 78\n", + " lengthscale[0:3]: [0.03089655749767049, 108.44260342490087, 7.957079269897791]\n", + "\n", + "--- Iteration 12 ---\n", + " n_experiments: 78\n", + " acqf probe at fixed point: -45.9379630881\n", + " acq_val from get_candidate: -0.7785376643322541\n", + " new y value: [-0.00878947]\n", + " experiments: 78 -> 79\n", + " lengthscale[0:3]: [0.3050036490876489, 0.39526463962698116, 10.838574498067157]\n", + "\n", + "--- Iteration 13 ---\n", + " n_experiments: 79\n", + " acqf probe at fixed point: -46.0129851173\n", + " acq_val from get_candidate: -45.75768827975949\n", + " new y value: [-0.37977354]\n", + " experiments: 79 -> 80\n", + " lengthscale[0:3]: [0.3006522276959701, 0.3935762903405451, 14.842249471453774]\n", + "\n", + "--- Iteration 14 ---\n", + " n_experiments: 80\n", + " acqf probe at fixed point: -46.0135545348\n", + " acq_val from get_candidate: -11.812887247140306\n", + " new y value: [-0.0033548]\n", + " 
experiments: 80 -> 81\n", + " lengthscale[0:3]: [0.2196389815319199, 0.5123909611274864, 7.623324726501774]\n", + "\n", + "--- Iteration 15 ---\n", + " n_experiments: 81\n", + " acqf probe at fixed point: -46.0121977271\n", + " acq_val from get_candidate: -7.900848122053444\n", + " new y value: [-2.46071964]\n", + " experiments: 81 -> 82\n", + " lengthscale[0:3]: [0.3160045940298298, 0.3706480418244254, 14.861831419103268]\n", + "\n", + "--- Iteration 16 ---\n", + " n_experiments: 82\n", + " acqf probe at fixed point: -46.0142045800\n", + " acq_val from get_candidate: -45.75633524187079\n", + " new y value: [-0.37986192]\n", + " experiments: 82 -> 83\n", + " lengthscale[0:3]: [0.31168116688795683, 0.37187935816620016, 14.805823721437385]\n", + "\n", + "--- Iteration 17 ---\n", + " n_experiments: 83\n", + " acqf probe at fixed point: -46.0145541933\n", + " acq_val from get_candidate: -45.7564354110594\n", + " new y value: [-0.37953567]\n", + " experiments: 83 -> 84\n", + " lengthscale[0:3]: [0.3077234584278307, 0.3712324781376418, 11.870060602968802]\n", + "\n", + "--- Iteration 18 ---\n", + " n_experiments: 84\n", + " acqf probe at fixed point: -46.0133323711\n", + " acq_val from get_candidate: -45.75837938393406\n", + " new y value: [-0.37769154]\n", + " experiments: 84 -> 85\n", + " lengthscale[0:3]: [0.17143206289072715, 0.545940037880173, 0.8365451921167117]\n", + "\n", + "--- Iteration 19 ---\n", + " n_experiments: 85\n", + " acqf probe at fixed point: -45.3993822367\n", + " acq_val from get_candidate: -0.33510238302079465\n", + " new y value: [-2.99120487]\n", + " experiments: 85 -> 86\n", + " lengthscale[0:3]: [0.31513088488426055, 0.38201061105153294, 18.104990428689263]\n", + "\n", + "--- Iteration 20 ---\n", + " n_experiments: 86\n", + " acqf probe at fixed point: -46.0112114395\n", + " acq_val from get_candidate: -45.75640901072899\n", + " new y value: [-0.38357929]\n", + " experiments: 86 -> 87\n", + " lengthscale[0:3]: [0.301989578174714, 
0.35224097931143306, 1.4549593860058594]\n", + "\n", + "--- Iteration 21 ---\n", + " n_experiments: 87\n", + " acqf probe at fixed point: -45.9529142233\n", + " acq_val from get_candidate: -2.9594742503756333\n", + " new y value: [-2.90055904]\n", + " experiments: 87 -> 88\n", + " lengthscale[0:3]: [0.12414469998203487, 0.4944761822780233, 0.9127515924007126]\n", + "\n", + "--- Iteration 22 ---\n", + " n_experiments: 88\n", + " acqf probe at fixed point: -9.2162054586\n", + " acq_val from get_candidate: -1.5305870841600848\n", + " new y value: [-2.9749772]\n", + " experiments: 88 -> 89\n", + " lengthscale[0:3]: [0.22503487233265027, 0.2691016257809672, 2.1812105898031273]\n", + "\n", + "--- Iteration 23 ---\n", + " n_experiments: 89\n", + " acqf probe at fixed point: -45.9539942651\n", + " acq_val from get_candidate: -5.513861895133451\n", + " new y value: [-0.00023194]\n", + " experiments: 89 -> 90\n", + " lengthscale[0:3]: [0.3195103170518832, 0.38685066238250077, 8.449608313147367]\n", + "\n", + "--- Iteration 24 ---\n", + " n_experiments: 90\n", + " acqf probe at fixed point: -46.0118833014\n", + " acq_val from get_candidate: -38.843144708007856\n", + " new y value: [-3.01303979]\n", + " experiments: 90 -> 91\n", + " lengthscale[0:3]: [0.4148201136989792, 0.2717726187241554, 1.2394691956390447]\n", + "\n", + "--- Iteration 25 ---\n", + " n_experiments: 91\n", + " acqf probe at fixed point: -45.9352343733\n", + " acq_val from get_candidate: -4.374878298424731\n", + " new y value: [-0.00096781]\n", + " experiments: 91 -> 92\n", + " lengthscale[0:3]: [0.32032115103464204, 0.39106494722973356, 13.63965179862593]\n", + "\n", + "--- Iteration 26 ---\n", + " n_experiments: 92\n", + " acqf probe at fixed point: -46.0111084633\n", + " acq_val from get_candidate: -45.75534578242555\n", + " new y value: [-0.38687496]\n", + " experiments: 92 -> 93\n", + " lengthscale[0:3]: [0.32200136414226704, 0.3856479669857732, 14.93704574089937]\n", + "\n", + "--- Iteration 27 ---\n", 
+ " n_experiments: 93\n", + " acqf probe at fixed point: -46.0109705722\n", + " acq_val from get_candidate: -45.752005521140276\n", + " new y value: [-0.38132764]\n", + " experiments: 93 -> 94\n", + " lengthscale[0:3]: [0.3171100781183023, 0.39110930535487876, 16.031086296578092]\n", + "\n", + "--- Iteration 28 ---\n", + " n_experiments: 94\n", + " acqf probe at fixed point: -46.0103654171\n", + " acq_val from get_candidate: -38.87649071044438\n", + " new y value: [-3.01310733]\n", + " experiments: 94 -> 95\n", + " lengthscale[0:3]: [0.32143656231313106, 0.3899184664723671, 20.657003633942974]\n", + "\n", + "--- Iteration 29 ---\n", + " n_experiments: 95\n", + " acqf probe at fixed point: -46.0091122417\n", + " acq_val from get_candidate: -45.752113163467726\n", + " new y value: [-0.3860681]\n", + " experiments: 95 -> 96\n", + " lengthscale[0:3]: [0.3920155869694801, 0.025147022994274456, 5.885963226058412]\n" + ] + } + ], + "source": [ + "for iteration in range(30):\n", + " acqf_check = strategy._get_acqfs(n=1)[0]\n", + "\n", + " # Probe: evaluate acqf at a fixed test point to see if model actually changed\n", + " test_point = torch.zeros(1, 1, bounds.shape[1], **tkwargs)\n", + " test_point[..., 0] = 0.5\n", + " probe_val = acqf_check(test_point).item()\n", + "\n", + " print(f\"\\n--- Iteration {iteration} ---\")\n", + " print(f\" n_experiments: {len(strategy.experiments)}\")\n", + " print(f\" acqf probe at fixed point: {probe_val:.10f}\")\n", + "\n", + " candidate, acq_val = get_candidate()\n", + " print(f\" acq_val from get_candidate: {acq_val}\")\n", + "\n", + " new_experiments = benchmark.f(candidate, return_complete=True)\n", + " print(f\" new y value: {new_experiments['y'].values}\")\n", + "\n", + " n_before = len(strategy.experiments)\n", + " strategy.tell(new_experiments, replace=False)\n", + " n_after = len(strategy.experiments)\n", + " print(f\" experiments: {n_before} -> {n_after}\")\n", + "\n", + " # Check if model params actually changed\n", + " ls = 
strategy.model.covar_module.lengthscale\n", + " print(f\" lengthscale[0:3]: {ls[0, :3].tolist()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.00568271, -0.09673502, -0.09673502, -0.09673502, -0.09673502,\n", + " -0.09713184, -0.61958705, -0.61958705, -0.61958705, -0.61958705,\n", + " -0.61958705, -0.61958705, -0.61958705, -0.61958705, -0.61958705,\n", + " -0.61958705, -0.61958705, -0.61958705, -1.00350188, -1.00350188,\n", + " -1.00350188, -1.00350188, -1.00350188, -1.00350188, -1.00350188,\n", + " -1.11410792, -1.11410792, -1.74733591, -1.74733591, -1.78129139,\n", + " -2.66609372, -2.70018492, -2.70018492, -2.70018492, -2.70018492,\n", + " -2.70018492, -2.70018492, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.95942063, -2.95942063, -2.95942063, -2.95942063, -2.95942063,\n", + " -2.95942063, -2.95942063, -2.96859675, -2.9747409 , -2.9747409 ,\n", + " -2.9747409 , -2.97790422, -2.99606393, -2.99606393, -2.99606393,\n", + " -2.99606393, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_y_per_iteration = strategy.experiments[\"y\"].cummin()\n", + "best_y_per_iteration.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "Text(0.5, 1.0, 'Best y per iteration')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQgdJREFUeJzt3Qd4VFX6x/E3vZICgYTQm4YOi4ggAgqCCiqChaIUWVkQ1gIuwt9dFV1FXF1RUBQLRUFFVxHQxUVAVHoRBOkK0jtJIIHU+T/vgRkTSEISZzJzZ76f57kmM3Mnc3ITmV/Oec85fjabzSYAAAA+yN/dDQAAAHAXghAAAPBZBCEAAOCzCEIAAMBnEYQAAIDPIggBAACfRRACAAA+iyAEAAB8FkEIAAD4LIIQABTTt99+K35+fuajVUybNs20ec+ePe5uCuCRCEKAF7C/2eU9KlWqJNdff73897//ddnrpqeny9NPP22pYOBss2bNkgkTJri7GfL888/LnDlz3N0MwHL82GsM8I4gNHDgQHnmmWekVq1aolsIHjlyxNz/888/y7x586Rbt25Of93jx49LxYoV5amnnjKByNvl5uZKZmamBAcHi7//+b8j9bpu3rzZ7T0ukZGRcuedd5qfeV45OTmSlZUlISEhJiADyC/wotsALOzmm2+Wq666ynF70KBBEh8fLx9++KFLgpA3SktLk4iIiAIf0/ATGhpaZoHLGa8VEBBgDgAFY2gM8GIxMTESFhYmgYGBl7zR6nBOw4YNzZuthqW//OUvcurUqXznrV27Vrp06SJxcXHm62hv0/33328e0x4Q7Q1SY8eOdQzJFdYz9Ouvv5rHX3nllUseW758uXlMA9vl6nM+/vhj+b//+z9JSEgwgeW2226Tffv2XXL+qlWr5KabbpLo6GgJDw+X9u3by7Jly/Kdo23Vr7llyxbp06ePxMbGStu2bS/bBvtQYIcOHeTLL7+U3377zfH916xZ03F+RkaG6S2rW7eu6ZGpVq2ajBo1ytyflz5v+PDhMnPmTPMz0XMXLFhgHnvppZekTZs2UqFCBfMzaNGihXz66aeXPF8D3PTp0x3tGDBgQJE1Qm+88YbjtRITE2XYsGGSnJyc7xz9/ho1amSujw6z6nWsUqWKvPjii4VeI8Bq6BECvEhKSooZrtKhsaNHj8rEiRPlzJkzcu+99+Y7T0OPfTjtoYcekt27d8ukSZPkxx9/NGEhKCjIPL9z584m7IwePdqEKn0z/eyzz8zX0PsnT54sQ4cOlTvuuEN69Ohh7m/SpEmBbatdu7Zce+215s3+0UcfzfeY3leuXDm5/fbbL/s9Pvfcc+aN/fHHHzdt1EDXqVMn2bBhgwkKavHixaZ3TEODBhHtyZk6darccMMN8v3338vVV1+d72veddddUq9ePVNno9euuJ544glzzffv3+8IeDpEZQ+bGtJ++OEHGTx4sNSvX182bdpkztuxY8cl9Tza5tmzZ5tApMHTHqheffVV83X69u1reok++ugj09758+dL165dzTnvv/++/PnPfzbfl76WqlOnTqHt1gCo4VWvm/78tm/fbn6Wa9ascfz87TQca6DUn+/dd99tQphe+8aNG5trDFie1ggBsLapU6fqu/clR0hIiG3atGn5zv3+++/NYzNnzsx3/4IFC/Ld//nnn5vba9asKfR1jx07Zs556qmnitXOt956y5y/detWx32ZmZm2uLg4W//+/Yt87pIlS8xzq1SpYktNTXXcP3v2bHP/q6++am7n5uba6tWrZ+vSpYv53C49Pd1Wq1Yt24033ui4T9ut
z+3du3ex2m9vg36069q1q61GjRqXnPv+++/b/P39zfXO68033zRfY9myZY779Lae+/PPP1/ydbTdeen1atSoke2GG27Id39ERESB19D+u7F7925z++jRo7bg4GBb586dbTk5OY7zJk2aZM577733HPe1b9/e3DdjxgzHfRkZGbaEhARbz549i7hSgHUwNAZ4kddff10WLlxojg8++MAMZ2hPgb0XR33yySdmuOjGG280vUf2Q3tPtDdjyZIl5jztAVLa86DFts6gPQo6FKc9QHZff/21ef2Le60K069fP9N7ZKcFwpUrV5avvvrK3NaeoZ07d5qhrhMnTji+Px066tixo3z33XemtyavIUOGiLPpddZeoKSkpHzXWXullP062+nQXYMGDS75OvZeLnvvjPZAXXfddbJ+/fpSteubb74xPUuPPPKIo+BbPfDAAxIVFWWG+vLS34m8PxstFNeeJx3qBLwBQ2OAF9E3qLzF0r1795bmzZub4RYtltY3MQ0J+maq0+sLosNN9jfmnj17miEUHc7RepHu3bubgKF1JaWh4erWW281U86fffZZc5+GIq07sQeEy9EhrLx0mExrcOw1MPr9qf79+xf6NfT713ogO619cjZtx9atWx11VIVd58u1QYPoP//5TxPw8tYWlXYGmNYzqSuvvDLf/fq7ocOX9sftqlateslr6bX76aefSvX6gKchCAFeTP/i114hrTPRN2YtjtXeEA1BeXtl8rK/ceubn9aDrFy50ky/154bLZR++eWXzX32WpiS0h4d7S3RAmmtM5k7d648+OCD+Xon/gh7b8+//vUvadasWYHnXNz2vL0uzqLt0O/v3//+d4GPa+H05dqg9UxaH9SuXTtT3Kw9X1q/o/VOGibLQmEzzkpSSwV4MoIQ4OWys7PNRy2athfR6vCIFi4XJwBcc8015tAiZX3z1aJdLdjVIbfS9Epo4a2GLQ1irVq1Mosy3nfffcV+vr3HJ+8b8q5duxxF2vYiYR3m0WJgVyvsGmg7Nm7caIbjStt785///McMJWoIzdsLp0GouO24WI0aNcxHLZDWHiA7HS7TovmyuGaAJ6FGCPBiWtvzv//9zwx7aL2KvU5HF9mzD01dHJrsU6i1HuXiv/rtPSz2IRqdTq0unnZdFJ3Kr0N2OkNKZ65pr0lhM80KMmPGDDl9+rTjtvZaHTp0yDGDSWudNITotHN7+Mvr2LFj4kw6hV+H2i6m1/nAgQPy9ttvX/LY2bNnTc1ScXpjNODoz8tOhwALWkFa21Gcn4MGHf19eO211/L9fN99913zfdhnogG+gh4hwIvodhrbtm1z1KBoD472oOj0d+0hsdf+6PT5cePGmboTnSKvwy16ng5Z6TCaFiDrmjQ6HKNT4zVYaPjQN3X9Orfccov5WtqjpAW+urbPFVdcIeXLlzfrzuhxueExfSPWguHx48eX6HvU19C1fnTqv66erdPntUZIi32VDrG98847JhjpUKCepzVIGkr09bT9OtTnLBq89PsfMWKEtGzZ0gy7aR2U9nJp2NNCbH1d7YHTQKM/H71fe3ny1nMVREOJDq1pL5rWZunPVAvi9fu9uEZH26E9fXq+rgukNUfa43Yx7Y0bM2aMqf3Sr6tDb9o7pD9rbX9xi9YBr+HuaWsAXDN9PjQ01NasWTPb5MmT800jt5syZYqtRYsWtrCwMFu5cuVsjRs3to0aNcp28OBB8/j69evNtPLq1aubafiVKlWydevWzbZ27dp8X2f58uXm6+iU7JJMpW/YsKGZMr5///4STV3/8MMPbWPGjDHt0bbr9PXffvvtkvN//PFHW48ePWwVKlQw7dcp7nfffbdt0aJFl0yf12UASjt9/syZM7Y+ffrYYmJizGN5p9LrVPfx48eb71XbEBsba67V2LFjbSkpKY7z9HnDhg0r8DXfffddsxyAPj8pKcn8rO3tzmvbtm22du3amWuij9mn0l88fT7vdHn9ekFBQbb4+Hjb
0KFDbadOncp3jk6f17ZfTL92QUsGAFbEXmMA3EJns2nvzqJFi4p1vq7mrIXf2mulPVYA4AzUCAEoc7p1hw7L6RAZALgTNUIAyozu0r5u3TozBV+ngt9zzz1cfQBuRY8QgDKjM7y0eFlns+kGq2WxkzsAFIUaIQAA4LPoEQIAAD6LIAQAAHwWxdLF2C/o4MGDZrfr0i6TDwAAypYu0aULweoCo0XtZUgQugwNQRdvjggAAKxh3759UrVq1UIfJwhdhvYE2S+kfYsCAADg2VJTU01Hhv19vDAEocuwD4dpCCIIAQBgLZcra6FYGgAA+CyCEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsywWh119/XWrWrCmhoaHSqlUrWb16dZHnf/LJJ5KUlGTOb9y4sXz11Vdl1lYAAODZLBWEPv74YxkxYoQ89dRTsn79emnatKl06dJFjh49WuD5y5cvl969e8ugQYPkxx9/lO7du5tj8+bNZd52AADgefxsuk+9RWgPUMuWLWXSpEnmdm5urtlQ7a9//auMHj36kvPvueceSUtLk/nz5zvuu+aaa6RZs2by5ptvFnvTtujoaElJSWGvMQAALKK479+W6RHKzMyUdevWSadOnRz3+fv7m9srVqwo8Dl6f97zlfYgFXa+ysjIMBcv7+EKmj+X7zouGdk5Lvn6AADg8iwThI4fPy45OTkSHx+f7369ffjw4QKfo/eX5Hw1btw4kyDth/Y4ucKDM9dLn3dWyX/WHXDJ1wcAAF4UhMrKmDFjTDea/di3b59LXqdlzfLm4+SluyQ7J9clrwEAALwkCMXFxUlAQIAcOXIk3/16OyEhocDn6P0lOV+FhISYscS8hyv0vrq6VIgIln0nz8rcjQdd8hoAAMBLglBwcLC0aNFCFi1a5LhPi6X1duvWrQt8jt6f93y1cOHCQs8vS2HBATLoulrm8ze+/UVycy1Tsw4AgNewTBBSOnX+7bfflunTp8vWrVtl6NChZlbYwIEDzeP9+vUzQ1t2Dz/8sCxYsEBefvll2bZtmzz99NOydu1aGT58uHiC+66pIVGhgbLr6Bn5+ufC65YAAIBrWCoI6XT4l156SZ588kkzBX7Dhg0m6NgLovfu3SuHDh1ynN+mTRuZNWuWTJkyxaw59Omnn8qcOXOkUaNG4gnKhQbJgGvP9wpNWrLLzCQDAABlx1LrCLmDq9cROpWWKdeOXyzpmTkydUBLuT6pktNfAwAAX5PqbesIeavYiGC595oa5vOJi3fSKwQAQBkiCHmAP19XS4ID/WX93mRZ8esJdzcHAACfEejuBkCkUrlQ6dWymsxY8ZtMWLhTcku5rJC/n4ifn5/5GODvJ1VjwyUhOpRLDABAIagR8pC9xg4kn5X2Ly6RbCdOow/095PvRl0viTFhTvuaAAB40/s3PUIeokpMmIy5pb58um5/qeuE9Gk5Npvk2myy/9RZyczOlR1HThOEAAAoBEHIgwxqW8scznDfu6vk+53H5cSZTKd8PQAAvBHF0l5Kt+9QJ9MIQgAAFIYg5KXKR4SYj8fTMtzdFAAAPBZByEtViLzQI8TQGAAAhSIIefnQ2AmGxgAAKBRByEtViDw/NEYQAgCgcAQhL1Xe3iN0hhohAAAKQxDyUnH2GiGGxgAAKBRByMt7hHRX+7OZOe5uDgAAHokg5KUiQwLNRq7qBFPoAQAoEEHIS+nmq46ZY0yhBwCgQAQhX1hLiDohAAAKRBDyhdWlmTkGAECBCEJeLI79xgAAKBJByAdmjjE0BgBAwQhCPrC69HGKpQEAKBBByIvZZ42dZPo8AAAFIgj5wKwx9hsDAKBgBCGf2G8s091NAQDAIxGEvFicYwd6Nl4FAKAgBCEf6BE6l5Ur6ZnZ7m4OAAAehyDkxcKDAyQ06MJ+YwyPAQBwCYKQ1+83Zh8eo04IAICLEYR8ZeYY22wAAHAJgpCvzByjRwgAgEsQ
hLycY2iMGiEAAC5BEPKRoTFWlwYA4FIEIR/ZZoMeIQAALkUQ8nLUCAEAUDiCkJdjdWkAAApHEPKRHqGTFEsDAHAJgpCPFEsfT8sUm83m7uYAAOBRCEI+Mn0+MztX0jJz3N0cAAA8CkHIy4UFB5g9xxSrSwMAkB9ByAcwcwwAgIIRhHxAhUhWlwYAoCAEIR9aVJHVpQEAyI8g5ENB6DhT6AEAsGYQOnnypPTt21eioqIkJiZGBg0aJGfOnCnyOR06dBA/P798x5AhQ8TXlHfsN5bp7qYAAOBRAt3dgOLSEHTo0CFZuHChZGVlycCBA2Xw4MEya9asIp/3wAMPyDPPPOO4HR4eLr47NEYQAgDAckFo69atsmDBAlmzZo1cddVV5r6JEyfKLbfcIi+99JIkJiYW+lwNPgkJCeLL7GsJHT+T4e6mAADgUSwxNLZixQozHGYPQapTp07i7+8vq1atKvK5M2fOlLi4OGnUqJGMGTNG0tPTizw/IyNDUlNT8x1Wx9AYAAAW7hE6fPiwVKpUKd99gYGBUr58efNYYfr06SM1atQwPUY//fSTPP7447J9+3b57LPPCn3OuHHjZOzYseJN4i70CJ2gWBoAAM8JQqNHj5bx48dfdlistLSGyK5x48ZSuXJl6dixo/zyyy9Sp06dAp+jvUYjRoxw3NYeoWrVqom39AjpfmNaNA4AANwchEaOHCkDBgwo8pzatWubGp+jR4/muz87O9vMJCtJ/U+rVq3Mx127dhUahEJCQszhjcXSmTm5cjojW6JCg9zdJAAAPIJbg1DFihXNcTmtW7eW5ORkWbdunbRo0cLct3jxYsnNzXWEm+LYsGGD+ag9Q74kNChAIoIDzKarJ89kEoQAALBSsXT9+vXlpptuMlPhV69eLcuWLZPhw4dLr169HDPGDhw4IElJSeZxpcNfzz77rAlPe/bskblz50q/fv2kXbt20qRJE/HZbTbSmDkGAIClgpB99pcGHa3x0Wnzbdu2lSlTpjge17WFtBDaPissODhYvvnmG+ncubN5ng7D9ezZU+bNmyc+vfEqBdMAAFhr1pjSGWJFLZ5Ys2ZNUwhspwXOS5cuLaPWeb64CwXTJ1hUEQAA6/UIwTk9QqwuDQDA7whCPlYjxOrSAAD8jiDkI9hvDACASxGEfEQFe40QxdIAADgQhHxEefs2GxRLAwBgvVljcM7Q2O7jZ2TwjLVecznrxUfKY52vZNsQAECpEIR8RNXYMAkK8JNzWbnyvy1HxFvo93Jzo8rSqEq0u5sCALAggpCPiAkPlo//0lq2HTot3uKt736R306ky/5TZwlCAIBSIQj5kD9VjzWHt/h+5zEThA6lnHV3UwAAFkWxNCwrMSbMfDyUcs7dTQEAWBRBCJZVOTrUfDyQTI8QAKB0CEKwfo8QQQgAUEoEIVi+R4ihMQBAaRGEYFlVLvQIHUk9J9k5ue5uDgDAgghCsKy4yBCzNlKuTeTo6Qx3NwcAYEEEIViWv7+fxEedHx47SJ0QAKAUCEKwtMTo88NjB5lCDwAoBYIQLC0x5kLBND1CAIBSIAjB0iqzqCIA4A8gCMHSEllUEQDwBxCEYGmVL9QIsd8YAKA0CEKwtMqOGiH2GwMAlBxBCF6xqOKJtEw5l5Xj7uYAACyGIARLiw4LkrCgAPM5W20AAEqKIARL8/PzyzM8xi70AICSIQjB8lhUEQBQWgQhWB6LKgIASosgBK+ZQn8whaExAEDJEITgNT1CB5lCDwAoIYIQLI9FFQEApUUQguUl2vcbo0cIAFBCBCF4zdDY6YxsST2X5e7mAAAshCAEywsPDjQLKyp6hQAAJUEQgleofGEXemaOAQBKgiAEr9pzjB4hAEBJEITgFezbbBxkmw0AQAkQhOAVWFQRAFAaBCF42TYb59zdFACAhRCE4FUbrx5imw0AQAkQhOBViyoeTDknNpvN3c0BAFgEQQheIT4qVPz8RDKzc+VEWqa7mwMAsAiCELxCcKC/
xEWGmM+pEwIAFBdBCF44PHbW3U0BAFiEZYLQc889J23atJHw8HCJiYkp1nO0VuTJJ5+UypUrS1hYmHTq1El27tzp8rbCPRLtq0uzlhAAwNuCUGZmptx1110ydOjQYj/nxRdflNdee03efPNNWbVqlUREREiXLl3k3DmmWHvzWkKHUvj5AgCKJ1AsYuzYsebjtGnTit0bNGHCBPn73/8ut99+u7lvxowZEh8fL3PmzJFevXq5tL1w31pC9AgBALwuCJXU7t275fDhw2Y4zC46OlpatWolK1asIAh5cY3QbyfS5bcTaWX2upEhgVLhQqE2AMBavDYIaQhS2gOUl962P1aQjIwMc9ilpqa6sJVwxQ70mw6kSPt/fVumF3fG/VdLuysqlulrAgAsXiM0evRo8fPzK/LYtm1bmbZp3LhxpufIflSrVq1MXx+l1yAxSlrUiDU9NGV1hASe/1/onR9286MDAAtya4/QyJEjZcCAAUWeU7t27VJ97YSEBPPxyJEjZtaYnd5u1qxZoc8bM2aMjBgxIl+PEGHIGkICA+Q/Q9uU6WvuPZEu7f61RL7feUz2nUyXauXDy/T1AQAWDkIVK1Y0hyvUqlXLhKFFixY5go+GGp09VtTMs5CQEHMAxVG9Qri0rRsnP+w6LrPX7pORna/kwgGAhVhm+vzevXtlw4YN5mNOTo75XI8zZ844zklKSpLPP//cfK7Dao888oj885//lLlz58qmTZukX79+kpiYKN27d3fjdwJv0+vq88OnGoSyc3Ld3RwAgDcWS+vCiNOnT3fcbt68ufm4ZMkS6dChg/l8+/btkpKS4jhn1KhRkpaWJoMHD5bk5GRp27atLFiwQEJDzxfVAs7QuUGCVIgIliOpGbJk+zG5sUH+An0AgOfys7FVd5F0OE2LpjVgRUVFldXPBRbz/FdbZcp3v8oNSZXkvQEt3d0cAPB5qcV8/7bM0BjgyXq1PD889u32oyzoCAAWQhACnKB2xUhpVau85NrO1woBAKyBIAQ4SZ9W1c3H2Wv2SY4mIgCAxyMIAU7SpWGCxIQHycGUc/LdjmNcVwCwAIIQ4CShQQHSo3lV8/n0FXtkx5HTjmPnkdOSxdR6APA4lpk+D1hB76uryXvLdsu324+ZIy9dePGDP7dyW9sAAJeiRwhwonrx5UytUPmIYMcRHRZkHtu4L5lrDQAehh4hwMmev6OxOexOpmXKn55dKKczss3K04EB/P0BAJ6Cf5EBF4sK/f3vjdRz2VxvAPAgBCHAxbQHqFzI+TCUnJ7J9QYAD0IQAspAdPj5OqHks1lcbwCwahDSbcl09/dz5865rkWAF9L1hVRKOkEIACwdhOrWrSv79rGFAFAS9pljKfQIAYB1g5C/v7/Uq1dPTpw44boWAV4oJizYfKRGCAAsXiP0wgsvyN/+9jfZvHmza1oEeCFqhADAS9YR6tevn6Snp0vTpk0lODhYwsLC8j1+8uRJZ7YP8AoxF4bGkqkRAgBrB6EJEya4piWAD9QIpVIjBADWDkL9+/d3TUsAH5g1xvR5APCiLTZ0Gn1mZv4F4qKiov5omwCvE02xNAB4R7F0WlqaDB8+XCpVqiQRERESGxub7wBQ+NAYPUIAYPEgNGrUKFm8eLFMnjxZQkJC5J133pGxY8dKYmKizJgxwzWtBLxkaIwaIQCw+NDYvHnzTODp0KGDDBw4UK677jqzyGKNGjVk5syZ0rdvX9e0FPCGGqH0LLMwqZ+fn7ubBAAoTY+QTo+vXbu2ox7IPl2+bdu28t1333FRgSIWVMzOtUlaZg7XCACsGoQ0BO3evdt8npSUJLNnz3b0FMXExDi/hYAXCA3yl+CA8/+7sc0GAFg4COlw2MaNG83no0ePltdff11CQ0Pl0UcfNStOA7iUDoU5VpdOzz/TEgBgoRohDTx2nTp1km3btsm6detMnVCTJk2c3T7Aq1aXPnY6gx3oAcBqPULly5eX48ePm8/vv/9+OX36tOMxLZLu
0aMHIQi4DBZVBACLBiFdNDE1NdV8Pn36dLOQIoDSrSVEjRAAWGxorHXr1tK9e3dp0aKFmfr70EMPXbLZqt17773n7DYCXra6dJa7mwIAKEkQ+uCDD+SVV16RX375xRR9pqSk0CsElHpojGJpALBUEIqPj5cXXnjBfF6rVi15//33pUKFCq5uG+B1xdIqhR4hALDurDH7GkIASsY+fZ4aIQCw8DpCAP7gxqv0CAGAxyAIAWUkJvxCsfRZiqUBwFMQhIAyrxGiWBoAPAVBCCgjrCMEAF5QLK1ycnJkzpw5snXrVnO7YcOGctttt0lAQICz2wd43fR53X0+MztXggP5OwQALBeEdu3aJV27dpX9+/fLlVdeae4bN26cVKtWTb788kupU6eOK9oJWF650CDx8xOx2c7PHKtYLsTdTQIAn1fiP0l1VenatWvLvn37ZP369ebYu3evWV9IHwNQsAB/PykXcv5vD6bQA4BFe4SWLl0qK1euNBux2uniirrg4rXXXuvs9gFeN3Ms9Vy2pLC6NABYs0coJCQk3+7zdmfOnJHg4PPTgwFcZpsN1hICAGsGoW7dusngwYNl1apVZgNWPbSHaMiQIaZgGkDhWFQRACwehF577TVTEK070oeGhppDh8Tq1q0rr776qmtaCXgJptADgMWDUExMjHzxxReyY8cO+fTTT82xfft2+fzzzyU6Oto1rRSR5557Ttq0aSPh4eGmDcUxYMAA8fPzy3fcdNNNLmsjUPwd6FldGgAsu46Q0h4gPXRNoU2bNsmpU6ckNjZWXCUzM1Puuusu0xP17rvvFvt5GnymTp2ar8YJcJeYsPN1dKwuDQAWDUKPPPKING7cWAYNGmRCUPv27WX58uWmp2b+/PnSoUMHlzR07Nix5uO0adNK9DwNPgkJCS5pE1BS9AgBgMWHxnQorGnTpubzefPmya+//irbtm2TRx99VJ544gnxNN9++61UqlTJLP44dOhQOXHihLubBB8WZd9vjKExALBmEDp+/Lijh+Wrr76Su+++W6644gq5//77zRCZJ9FhsRkzZsiiRYtk/PjxZg2km2++2fRkFSYjI0NSU1PzHYCzN15l+jwAWDQIxcfHy5YtW0yYWLBggdx4443m/vT09BLvNTZ69OhLipkvPrS3qbR69eplpvTrUF737t3N0N2aNWtML1FhdLsQLfq2H7p1CODMBRUVPUIAYNEaoYEDB5peoMqVK5ug0qlTJ3O/riuUlJRUoq81cuRIM7OrKLqdh7Po14qLizP7pXXs2LHAc8aMGSMjRoxw3NYeIcIQnL+gYiYXFQCsGISefvppadSokdlrTGdx2WdhaW+Q9vCURMWKFc1RVnSjWK0R0hBXGP1+mFmGslhHKDfXJv7+flxsALDa9Pk777zzkvv69+8vrqQbu548edJ81GG5DRs2mPt1Cn9kZKT5XHukdGjrjjvuMFt+6Eyznj17mpqmX375RUaNGmXO79Kli0vbClwuCOXaRM5kZktU6PnbAACLrSNU1p588kmZPn2643bz5s3NxyVLljim7OvCjikpKY4eqp9++sk8Jzk5WRITE6Vz587y7LPP0uMDtwkNCpDQIH85l5UrKelZBCEAcDPLBCFdP+hyawjpvmd2YWFh8vXXX5dBy4CS9wqdy8oww2OU4gOAxWaNAXDO6tJMoQcA9yMIAWUs2rHfGDPHAMByQUi31NBFCs+ePeuaFgFejkUVAcDCQUiLlB977DEzE+uBBx6QlStXuqZlgA9MoQcAWCwITZgwQQ4ePGh2dD969Ki0a9dOGjRoIC+99JIcOXLENa0EvHBRRYIQAFi0RigwMFB69OghX3zxhVmksE+fPvKPf/zDrMCsW1ksXrzY+S0FvGybDVaXBgCLF0uvXr1annrqKXn55ZfNDu+6PYVuYdGtWzczfAag8KExZo0BgAXXEdLhsPfff98Mje3cuVNuvfVW+fDDD81qzbr3mNL9w3Tndx0uA5AfNUIAYOEg
VLVqValTp47cf//9JvAUtFdYkyZNpGXLls5qI+BVqBECAAsHoUWLFsl1111X5DlRUVFm6wsAl2JBRQCwcI3Q5UIQgOL1CLGgIgC4HytLA2Us6kKxtG68ei4rh+sPAG5EEALKWLmQQPE/P69AUllUEQDciiAElPX/dP5+v0+hJwgBgLWCkO4zlpGRccn9mZmZ5jEAl8cUegCwaBAaOHCgpKSkXHL/6dOnzWMALi/asbo0+40BgKWCkM1mcyycmJdutREdHe2sdgE+sgN9prubAgA+LbAku85rANKjY8eOZr8xu5ycHNm9e7dZTRrA5bGoIgBYLAjpZqpqw4YNZjuNyMhIx2PBwcFSs2ZN6dmzp2taCXgZaoQAwGJBSDdXVRp4evXqJSEhIa5sF+ATQ2Nfbjok+06m/35/eLD8pX1tqRwd5sbWAYDvKPEWGzfccIMcO3bM7Dlm34F+1qxZ0qBBAxk8eLAr2gh4naqx4ebjr8fSzJHX3I0HZWLv5nJt3Tg3tQ4AfEeJg1CfPn1M4Lnvvvvk8OHD0qlTJ2nUqJHMnDnT3H7yySdd01LAi9zePFF0zkFKnnWEbDaRz388IFsOpcp9766SkZ2vlKHt65h1hwAAruFn02lgJRAbGysrV66UK6+8Ul577TX5+OOPZdmyZfK///1PhgwZIr/++qt4k9TUVDMbTpcM0M1kAVfSLTee/GKzzF6739zuVL+SvHxXM4m+sD8ZAMC5798lnj6flZXlqA/65ptv5LbbbjOfJyUlyaFDh0r65QDkERoUIC/e2VTG92wswYH+8s3Wo3LPlBWSnZPLdQIAFyhxEGrYsKG8+eab8v3338vChQsdU+YPHjwoFSpUcEUbAZ9zT8vq8tnQNhIS6C/bDp+WPSd+L6gGALgxCI0fP17eeust6dChg/Tu3VuaNm1q7p87d65cffXVTmwa4NsaVYmWuMjzva9pGdnubg4AeKUSF0trADp+/LgZe9N6ITstoA4PPz8TBoBzhAcHmI9pmQQhAPCY3ee1vnrdunWmZ0j3GLMvqkgQApwrIuT83yppGTlcWgDwhB6h3377zdQF7d271+xCf+ONN0q5cuXMkJne1vohAM4R6QhC9AgBgEf0CD388MNy1VVXyalTpyQs7PfVb++44w5ZtGiRs9sH+DSGxgDAw3qEdLbY8uXLzVBYXrr1xoEDB5zZNsDn0SMEAB7WI5Sbm2t2m7/Y/v37zRAZAOfXCJ2hRggAPCMIde7cWSZMmOC47efnJ2fOnDGbst5yyy3Obh/g08JDzs8aS6dGCAA8Y2js5Zdfli5duphNVs+dO2f2Htu5c6fExcXJhx9+6JpWAj4qMvhCsTTT5wHAM4KQ7jq/ceNGs8eYftTeoEGDBknfvn3zFU8D+OMYGgMADwtC5kmBgSb46AHAdSIYGgMAzwpCJ06ccOwptm/fPnn77bfl7Nmzcuutt0q7du1c0UbAZ/3eI8Q6QgDg1mLpTZs2mSnylSpVMjvNb9iwQVq2bCmvvPKKTJkyRW644QaZM2eOSxoJiK+vLE2NEAC4NwiNGjVKGjduLN99953Zb6xbt27StWtXSUlJMYsr/uUvf5EXXniBHxPgRBEXiqXTmT4PAO4dGluzZo0sXrxYmjRpYnac116gBx98UPz9z2epv/71r3LNNde4ppWAj9cIMTQGAG7uETp58qQkJCSYzyMjIyUiIiLf7vP6uX0DVgDOwcrSAOBBCyrq4olF3QbgqhqhHMnNtXF5AcCds8YGDBggISEh5nNdTHHIkCGmZ0jpzvMAnCviQo2QOpuV4whGAIAy7hHq37+/mTEWHR1tjnvvvVcSExMdt/Wxfv36iSvs2bPHLNpYq1Yts2hjnTp1zJYemZmZRT5Pw9qwYcPMdH8dzuvZs6ccOXLEJW0EXCE0yF/8L3S8pjGFHgCcrth/Xk6dOlXcZdu2bWaz17feekvq1q0rmzdvlgceeEDS0tLkpZdeKvR5jz76qHz55ZfyySef
mLA2fPhw6dGjhyxbtqxM2w+Ulg4/ay/Q6XPZpmC6EpcSAJzKEv3sN910kznsateuLdu3b5fJkycXGoR0Wv+7774rs2bNMmsc2cNc/fr1ZeXKlcxwg2VEBJ8PQumZOe5uCgB4nRLvPu8pNOiUL1++0MfXrVsnWVlZ0qlTJ8d9uhBk9erVZcWKFWXUSuCPYwo9APh4j9DFdu3aJRMnTixyWOzw4cMSHBwsMTEx+e6Pj483jxVGi77zFn6npqY6qdVA6TCFHgC8tEdo9OjRpgaiqEPrg/I6cOCAGSa76667TJ2Qs40bN85RAK5HtWrVnP4aQEmEB/8+hR4A4EU9QiNHjjRT8oui9UB2Bw8elOuvv17atGljVrYuii7+qLPKkpOT8/UK6awx+8KQBRkzZoyMGDEiX48QYQgesZYQs8YAwLuCUMWKFc1RHNoTpCGoRYsWpujZvrVHYfS8oKAgWbRokZk2r7TAeu/evdK6detCn6frJNnXSgI8QeSFbTYIQgDgo8XSGoJ0o1ctdNa6oGPHjpk6n7y1PnqOFkOvXr3a3NZhLV17SHt3lixZYoqnBw4caEIQe6LBSsIdPUIMjQGATxZLL1y40BRI61G1atV8j9ls57cd0Bli2uOTnp7ueOyVV14xPUfaI6QF0F26dJE33nijzNsPOKVYOjObCwkATuZnsycJFEhrhLR3SafrR0VFcZVQ5l79Zqe88s0O6dOqujx/R2N+AgDgxPdvSwyNAb7Mvo4QNUIA4HwEIcAys8aoEQIAZyMIAR6O6fMA4DoEIcAq0+cplgYApyMIAVZZWZoFFQHA6QhCgGX2GqNGCACcjSAEeDhqhADAdQhCgIeLCP69RohlvwDAuQhCgEV6hHJtIueyct3dHADwKgQhwMOFBQWIn9/5z89QMA0ATkUQAjycv7+fhAedHx5LZwo9ADgVQQiw0PAYPUIA4FwEIcACmEIPAK5BEAIsIJyNVwHAJQhCgAVE2FeXpkYIAJyKIARYamgs291NAQCvQhACLFUszTYbAOBMBCHAAiIu1Ail0yMEAE5FEAIsIOJCjdAZaoQAwKkIQoAFsPEqALgGQQiw1NAYNUIA4EwEIcACWFkaAFyDIARYafo8NUIA4FQEIcACwu0LKjI0BgBORRACLFQjxIKKAOBcBCHAAlhZGgBcgyAEWIBjaCyTWWMA4EwEIcBiPUI2m83dzQEAr0EQAixUI5Sda5OM7Fx3NwcAvAZBCLDQ0JiiYBoAnIcgBFhAgL+fhAVdWF2aOiEAcBqCEGARrC4NAM5HEAIsIpK1hADA6QhCgEUwhR4AnI8gBFgEiyoCgPMRhACLTaE/k5Ht7qYAgNcgCAEWEX5hUcV0ghAAOA1BCLCISLbZAACnIwgBFsH0eQBwPoIQYLEaIYbGAMB5CEKA5XqE2IEeAJyFIARYLAix1xgAOA9BCLCIiODzQ2NpmUyfBwBnIQgBFkGPEAD4aBDas2ePDBo0SGrVqiVhYWFSp04deeqppyQzM7PI53Xo0EH8/PzyHUOGDCmzdgOuWVmaGiEAcJbz/7J6uG3btklubq689dZbUrduXdm8ebM88MADkpaWJi+99FKRz9XznnnmGcft8PDwMmgx4HxMnwcAHw1CN910kznsateuLdu3b5fJkydfNghp8ElISCiDVgJlUyOUTo0QAPjW0FhBUlJSpHz58pc9b+bMmRIXFyeNGjWSMWPGSHp6epHnZ2RkSGpqar4D8KwaIYbGAMCneoQutmvXLpk4ceJle4P69OkjNWrUkMTERPnpp5/k8ccfNz1Jn332WaHPGTdunIwdO9YFrQacE4Qyc3IlMztXggMt+3cMAHgMP5vNZnPXi48ePVrGjx9f5Dlbt26VpKQkx+0DBw5I+/btTSH0O++8U6LXW7x4sXTs2NEEKS24LqxHSA877RGqVq2a6YGKiooq0esBzpSdkyt1n/iv+XzDkzdKTHgwFxgACqHv39HR0Zd9/3Zrj9DIkSNlwIAB
RZ6j9UB2Bw8elOuvv17atGkjU6ZMKfHrtWrVynwsKgiFhISYA/A0gQH+EhLoLxnZuXImI5sgBABO4NYgVLFiRXMUh/YEaQhq0aKFTJ06Vfz9Sz4ssGHDBvOxcuXKJX4u4ClT6DOyM6kTAgAnsUSRgYYgHQqrXr26qQs6duyYHD582Bx5z9EhtNWrV5vbv/zyizz77LOybt06sw7R3LlzpV+/ftKuXTtp0qSJG78boPTCL2y8yurSAOBDxdILFy40w1l6VK1aNd9j9hKnrKwsUwhtnxUWHBws33zzjUyYMMGsN6R1Pj179pS///3vbvkeAGeICGa/MQDwuSCkdUSXqyWqWbOmIxQpDT5Lly4tg9YB7lhdmv3GAMBnhsYAnBfOWkIA4FQEIcBCIqkRAgCnIggBFhJxoUZIp88DAP44ghBgyW02CEIA4AwEIcBCIuxDY+w3BgBOQRACLIQeIQBwLoIQYMXp85kMjQGAMxCEAAsJdyyomOPupgCAVyAIAVacPk+xNAA4BUEIsGCNENPnAcA5CEKABYPQvpPpsu63U+5uDgBYHkEIsJDGVaIlKaGcpGXmyD1vrZBpy3bn22MPAFAyBCHAQoIC/OXToW2ka5PKkp1rk6fnbZGHP9og6cwiAwDv3X0eQP4p9JN6N5c/VY+VcV9tlbkbD8q2w6nSqX68yy5T+Yhg6d+mpgliAOBNCEKABfn5+cmgtrWkSdVoGTZzvew4csYcrlSxXIjc3qyKS18DAMoaQQiwsJY1y8v8h9rKByv3yulzWS55jfW/nZKN+1Nk84EUghAAr0MQAiyuUrlQGXHjFS77+h+u3isb92+SbYdPu+w1AMBdGPAHUKQrE8qZj9sJQgC8EEEIQJGuiD8fhI6ezpCTaZlcLQBehSAE4LKz1KqVDzOf6+w0APAmBCEAl5WUEGU+MjwGwNsQhABclq5mrQhCALwNQQhAsQummTkGwNsQhAAUu0dox5HTkpvL3mYAvAdBCMBl1awQIcGB/pKemSP7T53ligHwGgQhAJcVGOAvdStGms+ZOQbAmxCEABQLBdMAvBFBCEDJCqaPsNUGAO9BEAJQsiB0iEUVAXgPghCAEi2quOdEupzLyuGqAfAKBCEAxRIfFSLRYUGSk2uTXUfPcNUAeAWCEIBi8fPzYyd6AF6HIASg5DPHKJgG4CUIQgCKja02AHgbghCAUqwlxMwxAN6BIASg2K6IPx+EjqRmSHJ6JlcOgOURhAAUW7nQIKkaG2Y+Zyd6AN6AIASgRNhqA4A3IQgBKGXBNHVCAKyPIASgRK68sMI0Q2MAvEGguxsAwJpDY1sPpcrYeT+7uzmWVrNChLSoESv1K0dJgL+fu5sD+CSCEIASqRUXIRHBAZKWmSNTl+3h6jmBXs/m1WOlWbUYCQ8JKNNrGhcRIk2qRUu9SuUIY/BJfjabzebuRniy1NRUiY6OlpSUFImKOj8kAPi65buOyw+7jru7GZame7ZtOZQqP+5NljMZ2e5ujoQHB0ijxGhpUjVarkgoZ3qralYIl4rlQsz2KoC3vn9bJgjddtttsmHDBjl69KjExsZKp06dZPz48ZKYmFjoc86dOycjR46Ujz76SDIyMqRLly7yxhtvSHx8fLFflyAEwNWBaPvh07Lut5Oy5dBpycnNLbMLrv/67z91Vn7an2x6+AoSFhRglkwICrh8SWlIkL/EhgdLTFiQxIQHS2x4kLSuU8EM/xGmUNa8Lgi98sor0rp1a6lcubIcOHBAHnvsMXP/8uXLC33O0KFD5csvv5Rp06aZizF8+HDx9/eXZcuWFft1CUIAfCGM/XrsjGzcn2JC0e7jabLnRJocOHVWcm3OqSu7r3UN6d6sikSEUJGBsuF1Qehic+fOle7du5uenqCgoEse12+8YsWKMmvWLLnzzjvNfdu2bZP69evLihUr5JprrinW6xCEAPiqzOxcOZB8Vg4mnzVhqSj66LmsHLPieHJ6lpxKzzLP/d/PhyUj+3wv
V2RIoNzaNNEMt5VUgJ+f3Nq0stSuGFnq7we+JbWYQciS0fzkyZMyc+ZMadOmTYEhSK1bt06ysrLMEJpdUlKSVK9evcggpMFKj7wXEgB8UXCgvymO16O0NBh9um6/zFy11/Q0fbh6b6m/1qJtR2Tu8Lalfj5g+SD0+OOPy6RJkyQ9Pd0Emfnz5xd67uHDhyU4OFhiYmLy3a/1QfpYYcaNGydjx451arsBwFdprdCfr6st919bS5b9clwWbzsq2TklG4iwiU0+XrNPftqfIlsOpkqDRCauwEuC0OjRo03Bc1G2bt1qenLU3/72Nxk0aJD89ttvJqz069fPhCFnFuGNGTNGRowYka9HqFq1ak77+gDgi/z9/eS6ehXNURon0zLlq02H5ZN1++SpxIZObx98l1uDkM7oGjBgQJHn1K5d2/F5XFycOa644gpT66MBZeXKlaaI+mIJCQmSmZkpycnJ+XqFjhw5Yh4rTEhIiDkAAJ7jrquqmSA058cDMvrmJAkJLNv1luC93BqEtJhZj9LIvTDFNG89T14tWrQw9UOLFi2Snj17mvu2b98ue/fuLTA4AQA8V7t6FSUhKlQOp56TRVuPyi2NK7u7SfASlthrbNWqVaY2SNcR0mGxxYsXS+/evaVOnTqOUKNT6nUIbfXq1ea2VorrMJoOcy1ZssQUTw8cONCcX9wZYwAAz6BbkPRsUcV8PnvtPnc3B17EEkEoPDxcPvvsM+nYsaNceeWVJuA0adJEli5d6hjG0hli2uOjhdR51x7q1q2b6RFq166dGRLTrwMAsJ67Wpyv1/xuxzE5lHLW3c2Bl7DsOkJlhXWEAMBz3P3WClm9+6Q81vkKGX5DPXc3B17w/m2JHiEAANQ9V53vFZq9dr/kOmPZa/g8ghAAwDJubpxgVqjeezJdVu856e7mwAsQhAAAlhEerNt0nJ8xRtE0fG5laQAAdE2hD1fvk682HZIHO9SV0CD+pveGFcgj3bQhL0EIAGApzavFSN1KkbLr6Bnp9O+l7m4OnOD5OxpLn1bVxR0IQgAAS9FtlR7qWE/+/vkmx872sLYAN3bqEYQAAJZzW9NEcwB/FAOrAADAZxGEAACAzyIIAQAAn0UQAgAAPosgBAAAfBZBCAAA+CyCEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPisQHc3wNPZbDbzMTU11d1NAQAAxWR/37a/jxeGIHQZp0+fNh+rVatW3GsPAAA86H08Ojq60Mf9bJeLSj4uNzdXDh48KOXKlRM/Pz+nJlUNV/v27ZOoqCinfV1w7T0Zv/dce1/E7717aLzREJSYmCj+/oVXAtEjdBl68apWrSquoiGIIOQeXHv34dpz7X0Rv/dlr6ieIDuKpQEAgM8iCAEAAJ9FEHKTkJAQeeqpp8xHcO19Bb/3XHtfxO+9Z6NYGgAA+Cx6hAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQchNXn/9dalZs6aEhoZKq1atZPXq1e5qitcaN26ctGzZ0qwKXqlSJenevbts37493znnzp2TYcOGSYUKFSQyMlJ69uwpR44ccVubvdELL7xgVmV/5JFHHPdx3V3nwIEDcu+995rf6bCwMGncuLGsXbs232q7Tz75pFSuXNk83qlTJ9m5c6cLW+QbcnJy5B//+IfUqlXLXNc6derIs88+m2+fK669ZyIIucHHH38sI0aMMNPn169fL02bNpUuXbrI0aNH3dEcr7V06VITclauXCkLFy6UrKws6dy5s6SlpTnOefTRR2XevHnyySefmPN1O5UePXq4td3eZM2aNfLWW29JkyZN8t3PdXeNU6dOybXXXitBQUHy3//+V7Zs2SIvv/yyxMbGOs558cUX5bXXXpM333xTVq1aJREREebfHw2nKL3x48fL5MmTZdKkSbJ161ZzW6/1xIkTufaeTvcaQ9m6+uqrbcOGDXPczsnJsSUmJtrGjRvHj8KF
jh49qn+a2ZYuXWpuJycn24KCgmyffPKJ45ytW7eac1asWMHP4g86ffq0rV69eraFCxfa2rdvb3v44Ye57i72+OOP29q2bVvo47m5ubaEhATbv/71L8d9+v9BSEiI7cMPP3R187xa165dbffff3+++3r06GHr27ev+Zxr77noESpjmZmZsm7dOtMdnXc/M729YsWKsm6OT0lJSTEfy5cvbz7qz0F7ifL+LJKSkqR69er8LJxAe+O6du2a7/py3V1r7ty5ctVVV8ldd91lhoObN28ub7/9tuPx3bt3y+HDh/P9THQvJh2e59+fP6ZNmzayaNEi2bFjh7m9ceNG+eGHH+Tmm2/m2ns4Nl0tY8ePHzdjyfHx8fnu19vbtm0r6+b4jNzcXFOjosMGjRo1MvfpG0JwcLDExMRc8rPQx1B6H330kRn21aGxi3HdXefXX381wzM69P5///d/5vo/9NBD5ve8f//+jt/rgv794Xf+jxk9erTZZV7/mAoICDD/zj/33HPSt29f8zjX3nMRhOAzvRObN282f6HBtfbt2ycPP/ywqcvSyQAo28CvPULPP/+8ua09Qvp7r/VAGoTgOrNnz5aZM2fKrFmzpGHDhrJhwwbzx1diYiLX3sMxNFbG4uLizF8LF89M0tsJCQll3RyfMHz4cJk/f74sWbJEqlat6rhfr7cOVSYnJ+c7n5/FH6NDjlr4/6c//UkCAwPNoYXoWqCrn2vvA9fdNXQmWIMGDfLdV79+fdm7d6/53P5vDP/+ON/f/vY30yvUq1cvM1PvvvvuM5MCdPYq196zEYTKmHZRt2jRwowl5/0rTm+3bt26rJvj1XSqqoagzz//XBYvXmymtealPwedXZP3Z6HT6/VNg59F6XXs2FE2bdpk/iK2H9pLoUME9s+57q6hQ78XLxGhNSs1atQwn+v/AxqG8v7O63COzh7jd/6PSU9PN/WeeekfvfrvO9few7m7WtsXffTRR2aWxrRp02xbtmyxDR482BYTE2M7fPiwu5vmVYYOHWqLjo62ffvtt7ZDhw45jvT0dMc5Q4YMsVWvXt22ePFi29q1a22tW7c2B5wr76wxrrvrrF692hYYGGh77rnnbDt37rTNnDnTFh4ebvvggw8c57zwwgvm35svvvjC9tNPP9luv/12W61atWxnz551Ycu8X//+/W1VqlSxzZ8/37Z7927bZ599ZouLi7ONGjXKcQ7X3jMRhNxk4sSJ5g04ODjYTKdfuXKlu5ritTTnF3RMnTrVcY7+4//ggw/aYmNjzRvGHXfcYcISXBuEuO6uM2/ePFujRo3MH1tJSUm2KVOm5Htcp3H/4x//sMXHx5tzOnbsaNu+fbsLW+QbUlNTze+4/rseGhpqq127tu2JJ56wZWRkOM7h2nsmP/2Pu3ulAAAA3IEaIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsghAAXEbNmjVlwoQJXCfACxGEAHiUAQMGSPfu3c3nHTp0MDt4l5Vp06ZJTEzMJfevWbNGBg8eXGbtAFB2AsvwtQDALXS3e93wuLQqVqzo1PYA8Bz0CAHw2J6hpUuXyquvvip+fn7m2LNnj3ls8+bNcvPNN0tkZKTEx8fLfffdJ8ePH3c8V3uShg8fbnqT4uLipEuXLub+f//739K4cWOJiIiQatWqyYMPPihnzpwxj3377bcycOBASUlJcbze008/XeDQ2N69e+X22283rx8VFSV33323HDlyxPG4Pq9Zs2by/vvvm+dGR0dLr1695PTp02V2/QAUD0EIgEfSANS6dWt54IEH5NChQ+bQ8JKcnCw33HCDNG/eXNauXSsLFiwwIUTDSF7Tp083vUDLli2TN99809zn7+8vr732mvz888/m8cWLF8uoUaPMY23atDFhR4ON/fUee+yxS9qVm5trQtDJkydNUFu4cKH8+uuvcs899+Q775dffpE5c+bI/PnzzaHnvvDCCy69ZgBK
jqExAB5Je1E0yISHh0tCQoLj/kmTJpkQ9Pzzzzvue++990xI2rFjh1xxxRXmvnr16smLL76Y72vmrTfSnpp//vOfMmTIEHnjjTfMa+lrak9Q3te72KJFi2TTpk2ye/du85pqxowZ0rBhQ1NL1LJlS0dg0pqjcuXKmdvaa6XPfe6555x2jQD8cfQIAbCUjRs3ypIlS8ywlP1ISkpy9MLYtWjR4pLnfvPNN9KxY0epUqWKCSgaTk6cOCHp6enFfv2tW7eaAGQPQapBgwamyFofyxu07CFIVa5cWY4ePVqq7xmA69AjBMBStKbn1ltvlfHjx1/ymIYNO60Dykvri7p16yZDhw41vTLly5eXH374QQYNGmSKqbXnyZmCgoLy3daeJu0lAuBZCEIAPJYOV+Xk5OS7709/+pP85z//MT0ugYHF/yds3bp1Joi8/PLLplZIzZ49+7Kvd7H69evLvn37zGHvFdqyZYupXdKeIQDWwtAYAI+lYWfVqlWmN0dnhWmQGTZsmClU7t27t6nJ0eGwr7/+2sz4KirE1K1bV7KysmTixImmuFlndNmLqPO+nvY4aS2Pvl5BQ2adOnUyM8/69u0r69evl9WrV0u/fv2kffv2ctVVV7nkOgBwHYIQAI+ls7YCAgJMT4uu5aPT1hMTE81MMA09nTt3NqFEi6C1Rsfe01OQpk2bmunzOqTWqFEjmTlzpowbNy7fOTpzTIundQaYvt7Fxdb2Ia4vvvhCYmNjpV27diYY1a5dWz7++GOXXAMAruVns9lsLn4NAAAAj0SPEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsghAAAPBZBCEAAOCzCEIAAMBnEYQAAID4qv8HzYNusut+DJYAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(best_y_per_iteration.values)\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"Best y so far\")\n", + "plt.title(\"Best y per iteration\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(best_y_per_iteration)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(strategy.experiments)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
710.3999520.8877421.00.5504820.00.00.00.00.00.01.00.0-3.110493True
940.3980060.8851970.00.5726160.00.01.00.00.00.00.01.0-3.013107True
900.3977890.8849290.00.5733320.00.01.00.00.00.00.00.0-3.013040True
730.3958580.8817980.00.5564070.00.00.00.01.00.01.01.0-3.002077True
670.4018520.8889790.00.5496050.00.00.01.00.00.01.01.0-2.996064True
.............................................
691.0000000.6017400.00.0000000.00.01.00.00.01.00.00.0-0.000671True
440.0000000.0000000.00.0000001.00.01.00.00.00.00.00.0-0.000255True
891.0000001.0000001.00.0000001.00.00.00.00.00.01.00.0-0.000232True
681.0000000.4827590.01.0000000.01.01.00.00.00.00.00.0-0.000157True
581.0000000.0000000.01.0000000.00.00.00.00.00.01.00.0-0.000017True
\n", + "

96 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "71 0.399952 0.887742 1.0 0.550482 0.0 0.0 0.0 0.0 \n", + "94 0.398006 0.885197 0.0 0.572616 0.0 0.0 1.0 0.0 \n", + "90 0.397789 0.884929 0.0 0.573332 0.0 0.0 1.0 0.0 \n", + "73 0.395858 0.881798 0.0 0.556407 0.0 0.0 0.0 0.0 \n", + "67 0.401852 0.888979 0.0 0.549605 0.0 0.0 0.0 1.0 \n", + ".. ... ... ... ... ... ... ... ... \n", + "69 1.000000 0.601740 0.0 0.000000 0.0 0.0 1.0 0.0 \n", + "44 0.000000 0.000000 0.0 0.000000 1.0 0.0 1.0 0.0 \n", + "89 1.000000 1.000000 1.0 0.000000 1.0 0.0 0.0 0.0 \n", + "68 1.000000 0.482759 0.0 1.000000 0.0 1.0 1.0 0.0 \n", + "58 1.000000 0.000000 0.0 1.000000 0.0 0.0 0.0 0.0 \n", + "\n", + " x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 y valid_y \n", + "71 0.0 0.0 1.0 0.0 -3.110493 True \n", + "94 0.0 0.0 0.0 1.0 -3.013107 True \n", + "90 0.0 0.0 0.0 0.0 -3.013040 True \n", + "73 1.0 0.0 1.0 1.0 -3.002077 True \n", + "67 0.0 0.0 1.0 1.0 -2.996064 True \n", + ".. ... ... ... ... ... ... \n", + "69 0.0 1.0 0.0 0.0 -0.000671 True \n", + "44 0.0 0.0 0.0 0.0 -0.000255 True \n", + "89 0.0 0.0 1.0 0.0 -0.000232 True \n", + "68 0.0 0.0 0.0 0.0 -0.000157 True \n", + "58 0.0 0.0 1.0 0.0 -0.000017 True \n", + "\n", + "[96 rows x 14 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.sort_values(by=\"y\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.40246971, 0.49047062, 0.83129865, 0. , 0. ,\n", + " 0. , 0.25661052, 0.92627053, 0. , 0.97370585,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0.29692436, 0.30276136,\n", + " 0.31188112, 0. , 0.31113767, 0.31041889, 0.33260343,\n", + " 0.33121978, 0.33797472, 0.34426884, 0. , 0.42236647,\n", + " 0.38995489, 0.40404801, 0.40527138, 0.40636354, 0.37343461,\n", + " 0. 
, 0.46571731, 0.42500977, 0.81058484, 0.44997919,\n", + " 0. , 0.4405271 , 0.43008012, 0. , 0. ,\n", + " 0.4443915 , 0. , 0. , 0. , 0.43516446,\n", + " 0.43388985, 0.44600393, 0. , 0. , 0. ,\n", + " 0.42146618, 0.48546936, 0.40232409, 1. , 0.39738112,\n", + " 1. , 0.41107153, 0.40042298, 0.39872447, 0. ,\n", + " 0.39733587, 0.39918283, 0.40185191, 1. , 1. ,\n", + " 0.66228013, 0.39995177, 1. , 0.39585761, 1. ,\n", + " 0. , 0. , 0. , 0.45403372, 0. ,\n", + " 1. , 0.44424961, 0. , 0. , 0. ,\n", + " 0.4010512 , 0. , 0.38435384, 0.41094965, 1. ,\n", + " 0.39778853, 0. , 0. , 0. , 0.39800625,\n", + " 0. ])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.x_0.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.3920, 0.0251, 5.8860, 1.1981, 7.8354, 0.9001, 15.2773, 13.5841,\n", + " 13.1932, 14.4943, 13.7058, 15.2569]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.4024700.0000000.0000000.0000000.0000000.0000000.0000000.6374530.1399190.4218240.6713710.331502-0.005683True
10.4904710.2944280.0000000.0000000.0000000.9252190.1794820.7729880.4809430.0000000.0000000.000000-0.096735True
20.8312990.0000000.1528290.2337560.0000000.2376550.2170320.0000000.5361540.0000000.0000000.000000-0.041713True
30.0000000.8293810.0000000.2879120.0000000.0000000.7741490.0000000.0000000.5073140.7311190.425413-0.082572True
40.0000000.8670440.1348860.0000000.2504430.2510050.7662670.0000000.0214130.0000000.0000000.000000-0.045721True
50.0000000.3077290.8427450.1164260.0000000.2726150.0000000.0000000.0000000.7787310.0000000.155062-0.097132True
60.2566110.4944050.0000000.6031580.0000000.0000000.8366570.0000000.0000000.8080900.0000000.134337-0.619587True
70.9262710.0000000.0000000.3188560.0000000.0000000.6645890.0000000.4928020.0531630.0000000.259605-0.003226True
80.0000000.2191180.6250430.7884450.0000000.0000000.0000000.4703430.0000000.0000000.3855200.006767-0.006505True
90.9737060.0000000.0000000.0623920.0000000.6765590.0000000.8051040.0000000.2609770.4023370.000000-0.054280True
100.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.9186150.0000000.058531-0.005089True
110.0000000.0000000.0000000.6817350.0000000.0000000.0000000.0000000.0000000.8729310.0000000.110071-0.002782True
120.0000000.5359470.0000000.0000000.0000000.0000001.0000000.0000000.0000000.0000000.0000000.157608-0.005837True
130.0000000.0000000.0000000.0000000.0000000.0000000.9638820.0000000.0000000.8681810.0000000.102231-0.005089True
140.0000000.5218990.0000000.6834860.0000000.0010560.0000000.0000000.0000001.0000000.0000000.144456-0.059981True
150.0000000.0000000.0000000.6328840.0000000.0000000.9731650.0000000.0000001.0000000.0000000.141852-0.003873True
160.0000000.4853940.0190510.0000000.0000000.0000000.8809520.0000000.0000000.9239620.0000000.111996-0.005705True
170.0000000.0000000.0000000.7615450.0000000.0000000.8266870.0000000.0000000.1617930.0000000.000000-0.001490True
180.2969240.5858560.0000000.7151310.0000000.0000000.9460410.0000000.0000000.0000000.0000000.000000-1.003502True
190.3027610.0000000.0000000.0000000.3969290.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.024736True
200.3118810.0000000.0000000.8194800.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.003597True
210.0000000.6611510.0000000.8140610.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.071057True
220.3111380.0000000.0000000.0000000.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.006050True
230.3104190.6607250.0000000.0000000.0000000.0000000.9965630.0000000.0000000.0000000.0000000.000000-0.067044True
240.3326030.0000000.0000000.8009070.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.004195True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.402470 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.490471 0.294428 0.000000 0.000000 0.000000 0.925219 0.179482 \n", + "2 0.831299 0.000000 0.152829 0.233756 0.000000 0.237655 0.217032 \n", + "3 0.000000 0.829381 0.000000 0.287912 0.000000 0.000000 0.774149 \n", + "4 0.000000 0.867044 0.134886 0.000000 0.250443 0.251005 0.766267 \n", + "5 0.000000 0.307729 0.842745 0.116426 0.000000 0.272615 0.000000 \n", + "6 0.256611 0.494405 0.000000 0.603158 0.000000 0.000000 0.836657 \n", + "7 0.926271 0.000000 0.000000 0.318856 0.000000 0.000000 0.664589 \n", + "8 0.000000 0.219118 0.625043 0.788445 0.000000 0.000000 0.000000 \n", + "9 0.973706 0.000000 0.000000 0.062392 0.000000 0.676559 0.000000 \n", + "10 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "11 0.000000 0.000000 0.000000 0.681735 0.000000 0.000000 0.000000 \n", + "12 0.000000 0.535947 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "13 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.963882 \n", + "14 0.000000 0.521899 0.000000 0.683486 0.000000 0.001056 0.000000 \n", + "15 0.000000 0.000000 0.000000 0.632884 0.000000 0.000000 0.973165 \n", + "16 0.000000 0.485394 0.019051 0.000000 0.000000 0.000000 0.880952 \n", + "17 0.000000 0.000000 0.000000 0.761545 0.000000 0.000000 0.826687 \n", + "18 0.296924 0.585856 0.000000 0.715131 0.000000 0.000000 0.946041 \n", + "19 0.302761 0.000000 0.000000 0.000000 0.396929 0.000000 0.000000 \n", + "20 0.311881 0.000000 0.000000 0.819480 0.000000 0.000000 0.000000 \n", + "21 0.000000 0.661151 0.000000 0.814061 0.000000 0.000000 1.000000 \n", + "22 0.311138 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "23 0.310419 0.660725 0.000000 0.000000 0.000000 0.000000 0.996563 \n", + "24 0.332603 0.000000 0.000000 0.800907 0.000000 0.000000 1.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + 
"0 0.637453 0.139919 0.421824 0.671371 0.331502 \n", + "1 0.772988 0.480943 0.000000 0.000000 0.000000 \n", + "2 0.000000 0.536154 0.000000 0.000000 0.000000 \n", + "3 0.000000 0.000000 0.507314 0.731119 0.425413 \n", + "4 0.000000 0.021413 0.000000 0.000000 0.000000 \n", + "5 0.000000 0.000000 0.778731 0.000000 0.155062 \n", + "6 0.000000 0.000000 0.808090 0.000000 0.134337 \n", + "7 0.000000 0.492802 0.053163 0.000000 0.259605 \n", + "8 0.470343 0.000000 0.000000 0.385520 0.006767 \n", + "9 0.805104 0.000000 0.260977 0.402337 0.000000 \n", + "10 0.000000 0.000000 0.918615 0.000000 0.058531 \n", + "11 0.000000 0.000000 0.872931 0.000000 0.110071 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.157608 \n", + "13 0.000000 0.000000 0.868181 0.000000 0.102231 \n", + "14 0.000000 0.000000 1.000000 0.000000 0.144456 \n", + "15 0.000000 0.000000 1.000000 0.000000 0.141852 \n", + "16 0.000000 0.000000 0.923962 0.000000 0.111996 \n", + "17 0.000000 0.000000 0.161793 0.000000 0.000000 \n", + "18 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "19 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "20 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "21 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "22 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "24 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.005683 True \n", + "1 -0.096735 True \n", + "2 -0.041713 True \n", + "3 -0.082572 True \n", + "4 -0.045721 True \n", + "5 -0.097132 True \n", + "6 -0.619587 True \n", + "7 -0.003226 True \n", + "8 -0.006505 True \n", + "9 -0.054280 True \n", + "10 -0.005089 True \n", + "11 -0.002782 True \n", + "12 -0.005837 True \n", + "13 -0.005089 True \n", + "14 -0.059981 True \n", + "15 -0.003873 True \n", + "16 -0.005705 True \n", + "17 -0.001490 True \n", + "18 -1.003502 True \n", + "19 -0.024736 True \n", + "20 -0.003597 True \n", + "21 -0.071057 True \n", + "22 -0.006050 True 
\n", + "23 -0.067044 True \n", + "24 -0.004195 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.2561, 0.3035, 1.2614, 1.3028, 1.5076, 1.4546, 1.6381, 1.6320, 1.7198,\n", + " 1.9160, 1.4628, 0.4608]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.7022250.1195760.0000000.0000000.5141660.0000000.3982430.0000000.0000000.9116470.0000000.506541-0.007801True
10.0000000.1354060.8432900.0000000.0000000.0000000.0795280.2151900.7617940.0000000.7555720.000000-0.008543True
20.0000000.6351610.0000000.7574720.0403000.0000000.0585810.0000000.0000000.0000000.5398340.529786-0.082457True
30.0000000.0000000.0000000.5976710.1564120.0000000.9289830.9367160.0000000.6583720.0000000.631700-0.015096True
40.8185240.0000000.0000000.0000000.9194440.6829780.0000000.0000000.6966650.6427770.0000000.947762-0.000610True
50.0000000.0000000.5590090.4596630.0000000.0000000.0647950.0000000.6381050.9473500.3983820.000000-0.017627True
60.0000000.0068260.0000000.3473980.6099630.0000000.2067430.9459320.0000000.8443830.0000000.000000-0.011483True
70.4624120.7461500.0438960.6079460.0000000.0000000.0000000.0000000.0000000.0000000.3490180.801749-2.436279True
80.0000000.0000000.0000000.7446480.0000000.0000000.2045490.7418980.5717230.4318280.0000000.892975-0.001715True
90.0000000.7577750.0000000.0000000.0000000.9877370.0000000.0000000.0826880.0030900.4656500.160617-0.021754True
100.4143310.8977280.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.4302510.000000-0.111822True
111.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000000.000000-0.001019True
120.5108750.0000000.0000000.6945000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.006296True
130.4373670.0000000.0000000.6891500.0000000.0000000.0000000.0000000.0000000.0000000.4106880.910016-0.007494True
141.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000001.000000-0.001019True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.702225 0.119576 0.000000 0.000000 0.514166 0.000000 0.398243 \n", + "1 0.000000 0.135406 0.843290 0.000000 0.000000 0.000000 0.079528 \n", + "2 0.000000 0.635161 0.000000 0.757472 0.040300 0.000000 0.058581 \n", + "3 0.000000 0.000000 0.000000 0.597671 0.156412 0.000000 0.928983 \n", + "4 0.818524 0.000000 0.000000 0.000000 0.919444 0.682978 0.000000 \n", + "5 0.000000 0.000000 0.559009 0.459663 0.000000 0.000000 0.064795 \n", + "6 0.000000 0.006826 0.000000 0.347398 0.609963 0.000000 0.206743 \n", + "7 0.462412 0.746150 0.043896 0.607946 0.000000 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.744648 0.000000 0.000000 0.204549 \n", + "9 0.000000 0.757775 0.000000 0.000000 0.000000 0.987737 0.000000 \n", + "10 0.414331 0.897728 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "11 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "12 0.510875 0.000000 0.000000 0.694500 0.000000 0.000000 0.000000 \n", + "13 0.437367 0.000000 0.000000 0.689150 0.000000 0.000000 0.000000 \n", + "14 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.000000 0.911647 0.000000 0.506541 \n", + "1 0.215190 0.761794 0.000000 0.755572 0.000000 \n", + "2 0.000000 0.000000 0.000000 0.539834 0.529786 \n", + "3 0.936716 0.000000 0.658372 0.000000 0.631700 \n", + "4 0.000000 0.696665 0.642777 0.000000 0.947762 \n", + "5 0.000000 0.638105 0.947350 0.398382 0.000000 \n", + "6 0.945932 0.000000 0.844383 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.349018 0.801749 \n", + "8 0.741898 0.571723 0.431828 0.000000 0.892975 \n", + "9 0.000000 0.082688 0.003090 0.465650 0.160617 \n", + "10 0.000000 0.000000 0.000000 0.430251 0.000000 \n", + "11 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "13 0.000000 0.000000 
0.000000 0.410688 0.910016 \n", + "14 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "\n", + " y valid_y \n", + "0 -0.007801 True \n", + "1 -0.008543 True \n", + "2 -0.082457 True \n", + "3 -0.015096 True \n", + "4 -0.000610 True \n", + "5 -0.017627 True \n", + "6 -0.011483 True \n", + "7 -2.436279 True \n", + "8 -0.001715 True \n", + "9 -0.021754 True \n", + "10 -0.111822 True \n", + "11 -0.001019 True \n", + "12 -0.006296 True \n", + "13 -0.007494 True \n", + "14 -0.001019 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.7022250.1195760.0000000.0000000.5141660.0000000.3982430.0000000.0000000.9116470.0000000.506541-0.007801True
10.0000000.1354060.8432900.0000000.0000000.0000000.0795280.2151900.7617940.0000000.7555720.000000-0.008543True
20.0000000.6351610.0000000.7574720.0403000.0000000.0585810.0000000.0000000.0000000.5398340.529786-0.082457True
30.0000000.0000000.0000000.5976710.1564120.0000000.9289830.9367160.0000000.6583720.0000000.631700-0.015096True
40.8185240.0000000.0000000.0000000.9194440.6829780.0000000.0000000.6966650.6427770.0000000.947762-0.000610True
50.0000000.0000000.5590090.4596630.0000000.0000000.0647950.0000000.6381050.9473500.3983820.000000-0.017627True
60.0000000.0068260.0000000.3473980.6099630.0000000.2067430.9459320.0000000.8443830.0000000.000000-0.011483True
70.4624120.7461500.0438960.6079460.0000000.0000000.0000000.0000000.0000000.0000000.3490180.801749-2.436279True
80.0000000.0000000.0000000.7446480.0000000.0000000.2045490.7418980.5717230.4318280.0000000.892975-0.001715True
90.0000000.7577750.0000000.0000000.0000000.9877370.0000000.0000000.0826880.0030900.4656500.160617-0.021754True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.702225 0.119576 0.000000 0.000000 0.514166 0.000000 0.398243 \n", + "1 0.000000 0.135406 0.843290 0.000000 0.000000 0.000000 0.079528 \n", + "2 0.000000 0.635161 0.000000 0.757472 0.040300 0.000000 0.058581 \n", + "3 0.000000 0.000000 0.000000 0.597671 0.156412 0.000000 0.928983 \n", + "4 0.818524 0.000000 0.000000 0.000000 0.919444 0.682978 0.000000 \n", + "5 0.000000 0.000000 0.559009 0.459663 0.000000 0.000000 0.064795 \n", + "6 0.000000 0.006826 0.000000 0.347398 0.609963 0.000000 0.206743 \n", + "7 0.462412 0.746150 0.043896 0.607946 0.000000 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.744648 0.000000 0.000000 0.204549 \n", + "9 0.000000 0.757775 0.000000 0.000000 0.000000 0.987737 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.000000 0.911647 0.000000 0.506541 \n", + "1 0.215190 0.761794 0.000000 0.755572 0.000000 \n", + "2 0.000000 0.000000 0.000000 0.539834 0.529786 \n", + "3 0.936716 0.000000 0.658372 0.000000 0.631700 \n", + "4 0.000000 0.696665 0.642777 0.000000 0.947762 \n", + "5 0.000000 0.638105 0.947350 0.398382 0.000000 \n", + "6 0.945932 0.000000 0.844383 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.349018 0.801749 \n", + "8 0.741898 0.571723 0.431828 0.000000 0.892975 \n", + "9 0.000000 0.082688 0.003090 0.465650 0.160617 \n", + "\n", + " y valid_y \n", + "0 -0.007801 True \n", + "1 -0.008543 True \n", + "2 -0.082457 True \n", + "3 -0.015096 True \n", + "4 -0.000610 True \n", + "5 -0.017627 True \n", + "6 -0.011483 True \n", + "7 -2.436279 True \n", + "8 -0.001715 True \n", + "9 -0.021754 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5, 
35, 42, 47, 48), {}, -4.6231695559517085)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=reward_fn,\n", + " use_cache=False,\n", + " rollout_mode=\"uniform_subset\",\n", + ")\n", + "mcts.run(n_iterations=2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_10x_spurious_11...x_spurious_41x_spurious_42x_spurious_43x_spurious_5x_spurious_6x_spurious_7x_spurious_8x_spurious_9yvalid_y
00.0000000.0000000.0290300.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.1170840.00.000000-0.005313True
10.3088960.0000000.0000000.5516740.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.011710True
20.0000000.0000000.8619450.0000000.00.8948000.0000000.00.0000000.0...0.0000000.00.000000.00.00.1033970.00.000000-0.279134True
30.0000000.7871880.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.007694True
40.0000000.0000000.8330570.9394680.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.000441True
50.0000000.4647900.0000000.1868110.00.0000000.0000000.00.0000000.0...0.8330060.00.000000.00.00.0000000.00.693337-0.018105True
60.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
70.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
80.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.903890.00.00.0000000.00.000000-0.005089True
90.0000000.0000000.0000000.0000000.00.0000000.9407920.00.7665960.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
100.0000000.0000000.0000000.0000000.01.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.070364True
110.0000000.0000000.0000000.0000000.01.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.070364True
120.0000000.0000000.0000000.0000000.00.5946860.0000000.00.0000001.0...0.0000000.00.000000.00.00.0000000.00.000000-0.166469True
130.0000000.0000000.0000000.0000000.00.5779770.0000000.00.0000000.0...1.0000000.00.000000.00.01.0000001.00.000000-0.163030True
140.0000000.0000000.0000000.0000000.00.5978800.0000000.00.0000000.0...0.0000000.00.000000.01.00.0000000.00.000000-0.167050True
\n", + "

15 rows × 52 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.000000 0.000000 0.029030 0.000000 0.0 0.000000 0.000000 \n", + "1 0.308896 0.000000 0.000000 0.551674 0.0 0.000000 0.000000 \n", + "2 0.000000 0.000000 0.861945 0.000000 0.0 0.894800 0.000000 \n", + "3 0.000000 0.787188 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "4 0.000000 0.000000 0.833057 0.939468 0.0 0.000000 0.000000 \n", + "5 0.000000 0.464790 0.000000 0.186811 0.0 0.000000 0.000000 \n", + "6 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "9 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.940792 \n", + "10 0.000000 0.000000 0.000000 0.000000 0.0 1.000000 0.000000 \n", + "11 0.000000 0.000000 0.000000 0.000000 0.0 1.000000 0.000000 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.0 0.594686 0.000000 \n", + "13 0.000000 0.000000 0.000000 0.000000 0.0 0.577977 0.000000 \n", + "14 0.000000 0.000000 0.000000 0.000000 0.0 0.597880 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_10 x_spurious_11 ... x_spurious_41 \\\n", + "0 0.0 0.000000 0.0 ... 0.000000 \n", + "1 0.0 0.000000 0.0 ... 0.000000 \n", + "2 0.0 0.000000 0.0 ... 0.000000 \n", + "3 0.0 0.000000 0.0 ... 0.000000 \n", + "4 0.0 0.000000 0.0 ... 0.000000 \n", + "5 0.0 0.000000 0.0 ... 0.833006 \n", + "6 0.0 0.000000 0.0 ... 0.000000 \n", + "7 0.0 0.000000 0.0 ... 0.000000 \n", + "8 0.0 0.000000 0.0 ... 0.000000 \n", + "9 0.0 0.766596 0.0 ... 0.000000 \n", + "10 0.0 0.000000 0.0 ... 0.000000 \n", + "11 0.0 0.000000 0.0 ... 0.000000 \n", + "12 0.0 0.000000 1.0 ... 0.000000 \n", + "13 0.0 0.000000 0.0 ... 1.000000 \n", + "14 0.0 0.000000 0.0 ... 
0.000000 \n", + "\n", + " x_spurious_42 x_spurious_43 x_spurious_5 x_spurious_6 x_spurious_7 \\\n", + "0 0.0 0.00000 0.0 0.0 0.117084 \n", + "1 0.0 0.00000 0.0 0.0 0.000000 \n", + "2 0.0 0.00000 0.0 0.0 0.103397 \n", + "3 0.0 0.00000 0.0 0.0 0.000000 \n", + "4 0.0 0.00000 0.0 0.0 0.000000 \n", + "5 0.0 0.00000 0.0 0.0 0.000000 \n", + "6 0.0 0.00000 0.0 0.0 0.000000 \n", + "7 0.0 0.00000 0.0 0.0 0.000000 \n", + "8 0.0 0.90389 0.0 0.0 0.000000 \n", + "9 0.0 0.00000 0.0 0.0 0.000000 \n", + "10 0.0 0.00000 0.0 0.0 0.000000 \n", + "11 0.0 0.00000 0.0 0.0 0.000000 \n", + "12 0.0 0.00000 0.0 0.0 0.000000 \n", + "13 0.0 0.00000 0.0 0.0 1.000000 \n", + "14 0.0 0.00000 0.0 1.0 0.000000 \n", + "\n", + " x_spurious_8 x_spurious_9 y valid_y \n", + "0 0.0 0.000000 -0.005313 True \n", + "1 0.0 0.000000 -0.011710 True \n", + "2 0.0 0.000000 -0.279134 True \n", + "3 0.0 0.000000 -0.007694 True \n", + "4 0.0 0.000000 -0.000441 True \n", + "5 0.0 0.693337 -0.018105 True \n", + "6 0.0 0.000000 -0.005089 True \n", + "7 0.0 0.000000 -0.005089 True \n", + "8 0.0 0.000000 -0.005089 True \n", + "9 0.0 0.000000 -0.005089 True \n", + "10 0.0 0.000000 -0.070364 True \n", + "11 0.0 0.000000 -0.070364 True \n", + "12 0.0 0.000000 -0.166469 True \n", + "13 1.0 0.000000 -0.163030 True \n", + "14 0.0 0.000000 -0.167050 True \n", + "\n", + "[15 rows x 52 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2510" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import math\n", + "\n", + "\n", + "sum([math.comb(12, i) for i in range(7)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 
'strategy' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mstrategy\u001b[49m\n", + "\u001b[31mNameError\u001b[39m: name 'strategy' is not defined" + ] + } + ], + "source": [ + "strategy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "strategy.acqf_optimizer._candidates_tensor_to_dataframe(candidates, strategy.domain)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(5, 37, 48)\n", + "-4.677423986870725\n", + "-4.664086891505056\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-39.6933963563105\n", + "-10.756249958648718\n" + ] + } + ], + "source": [ + "print(\n", + " reward_fn(\n", + " (\n", + " 0,\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " ),\n", + " cat_selections={},\n", + " )\n", + ")\n", + "print(\n", + " reward_fn2(\n", + " (\n", + " 0,\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " ),\n", + " cat_selections={},\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"tensor([[ 8.7714, 3.3755, 0.2549, 0.0688, 2.2844, 1.9798, 13.7240, 6.4729,\n", + " 3.0276, 0.1636, 12.4373, 12.4482]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(-0.2791335862611258)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.y.min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(18,)\n", + "-32.97498755616551\n" + ] + } + ], + "source": [ + "selected_features, _, _ = mcts._rollout(mcts.root)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(-0.4820260292929051)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.y.min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-8.419034388799723" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reward_fn(\n", + " [\n", + " 31,\n", + " 49,\n", + " ],\n", + " cat_selections={},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x_3 0.0\n", + "dtype: float64" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"experiments[[\"x_3\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'x_3'" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "benchmark.domain.inputs.get_keys()[3]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 50])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bounds.shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "feature-mcts", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mcts-report/test_acqf_hartmann.py b/mcts-report/test_acqf_hartmann.py new file mode 100644 index 000000000..459125883 --- /dev/null +++ b/mcts-report/test_acqf_hartmann.py @@ -0,0 +1,263 @@ +"""Test MCTS acquisition optimization on Hartmann(dim=6, allowed_k=4). + +Compares MCTS-guided optimization against exhaustive enumeration of all +NChooseK subsets to verify that MCTS finds the best (or near-best) +combinatorial structure when optimizing a real acquisition function. + +Uses bofire's SingleTaskGPSurrogate and SoboStrategy for proper GP fitting +with data transforms, and generates NChooseK-respecting initial data. 
+ +Usage: + python mcts-report/test_acqf_hartmann.py +""" + +import itertools +import sys +import time +import warnings +from pathlib import Path + +import torch +from botorch.optim import optimize_acqf + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +import bofire.data_models.strategies.api as data_models +from bofire.benchmarks.api import Hartmann +from bofire.data_models.strategies.predictives.acqf_optimization import BotorchOptimizer +from bofire.strategies.predictives.optimize_mcts import optimize_acqf_mcts +from bofire.strategies.predictives.sobo import SoboStrategy +from bofire.strategies.random import RandomStrategy + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + +DIM = 6 +ALLOWED_K = 4 +N_INITIAL = 20 +N_MCTS_ITERATIONS = 20 +N_RESTARTS = 20 +RAW_SAMPLES = 2048 +N_SEEDS = 5 + + +def make_strategy_and_acqf(benchmark: Hartmann, seed: int): + """Create a fitted SoboStrategy and extract acqf + bounds. + + Generates NChooseK-respecting initial data via RandomStrategy, + evaluates on Hartmann6, fits a GP via bofire's SingleTaskGPSurrogate, + and returns the acquisition function and bounds. 
+ """ + domain = benchmark.domain + + # Generate NChooseK-respecting initial data + random_strategy = RandomStrategy( + data_model=data_models.RandomStrategy(domain=domain, seed=seed), + ) + candidates = random_strategy.ask(N_INITIAL) + experiments = benchmark.f(candidates, return_complete=True) + + # Create SoboStrategy with custom optimizer settings + strategy = SoboStrategy( + data_model=data_models.SoboStrategy( + domain=domain, + acquisition_optimizer=BotorchOptimizer( + n_restarts=N_RESTARTS, + n_raw_samples=RAW_SAMPLES, + ), + ), + ) + strategy.tell(experiments) + + # Extract the fitted acqf and bounds + acqf = strategy._get_acqfs(1)[0] + # Get bounds in the same way the optimizer does + from bofire.strategies.utils import get_torch_bounds_from_domain + + bounds = get_torch_bounds_from_domain(domain, strategy.input_preprocessing_specs) + + best_f = experiments["y"].min() # Hartmann is minimized + return strategy, acqf, bounds, best_f, experiments + + +# --------------------------------------------------------------------------- +# Exhaustive enumeration +# --------------------------------------------------------------------------- + + +def enumerate_all_subsets(dim: int, max_k: int) -> list[frozenset[int]]: + """Generate all subsets of {0, ..., dim-1} with size 0..max_k.""" + subsets = [] + for k in range(0, max_k + 1): + for combo in itertools.combinations(range(dim), k): + subsets.append(frozenset(combo)) + return subsets + + +def exhaustive_optimize(acqf, bounds: torch.Tensor, subsets: list[frozenset[int]]): + """Run optimize_acqf for every subset, return best result. + + Returns: + (best_candidates, best_acq_val, best_subset, all_results) + where all_results is a list of (subset, acq_val) sorted descending. 
+ """ + dim = bounds.shape[1] + results = [] + + for subset in subsets: + # Fix inactive features to 0 + fixed = {i: 0.0 for i in range(dim) if i not in subset} + + if len(subset) == 0: + # All features fixed to 0 — evaluate directly + candidate = torch.zeros(1, dim, dtype=bounds.dtype) + with torch.no_grad(): + val = acqf(candidate.unsqueeze(0)).item() + results.append((subset, val, candidate)) + continue + + candidates, acq_val = optimize_acqf( + acq_function=acqf, + bounds=bounds, + q=1, + num_restarts=N_RESTARTS, + raw_samples=RAW_SAMPLES, + fixed_features=fixed, + ) + results.append((subset, acq_val.item(), candidates)) + + # Sort by acquisition value descending + results.sort(key=lambda x: x[1], reverse=True) + + best_subset, best_val, best_cand = results[0] + return best_cand, best_val, best_subset, [(s, v) for s, v, _ in results] + + +# --------------------------------------------------------------------------- +# MCTS optimization +# --------------------------------------------------------------------------- + + +def mcts_optimize(strategy: SoboStrategy, acqf, bounds: torch.Tensor, seed: int): + """Run optimize_acqf_mcts and return best result.""" + candidates, acq_val = optimize_acqf_mcts( + acq_function=acqf, + bounds=bounds, + nchooseks=[(list(range(DIM)), 0, ALLOWED_K)], + num_iterations=N_MCTS_ITERATIONS, + q=1, + raw_samples=RAW_SAMPLES, + num_restarts=N_RESTARTS, + seed=seed, + ) + + # Determine which features MCTS selected (non-zero in candidates) + selected = frozenset(i for i in range(DIM) if candidates[0, i].abs().item() > 1e-6) + + return candidates, acq_val, selected + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + warnings.filterwarnings("ignore", message=".*InputDataWarning.*") + warnings.filterwarnings("ignore", message=".*model inputs.*") + + 
benchmark = Hartmann(dim=DIM, allowed_k=ALLOWED_K) + + print("MCTS vs Exhaustive Enumeration on Hartmann(dim=6, allowed_k=4)") + print("=" * 70) + print(f" Initial points: {N_INITIAL} (NChooseK-respecting)") + print(f" MCTS iterations: {N_MCTS_ITERATIONS}") + print(f" Restarts/raw_samples per optimize_acqf: {N_RESTARTS}/{RAW_SAMPLES}") + print(f" Seeds: {N_SEEDS}") + print(" Surrogate: bofire SingleTaskGPSurrogate (with data transforms)") + print() + + subsets = enumerate_all_subsets(DIM, ALLOWED_K) + print(f" Total subsets to enumerate: {len(subsets)}") + print() + + mcts_ranks = [] + mcts_gaps = [] + mcts_found_best = 0 + + for seed in range(N_SEEDS): + print(f"--- Seed {seed} ---") + + strategy, acqf, bounds, best_f, experiments = make_strategy_and_acqf( + benchmark, seed + ) + print(f" GP trained, best_f = {best_f:.4f}") + + # Exhaustive + t0 = time.time() + exh_cand, exh_val, exh_subset, all_results = exhaustive_optimize( + acqf, bounds, subsets + ) + exh_time = time.time() - t0 + + print( + f" Exhaustive: best_acq = {exh_val:.4f}, " + f"subset = {sorted(exh_subset)}, time = {exh_time:.1f}s" + ) + print(" Top 5 subsets:") + for i, (s, v) in enumerate(all_results[:5]): + print(f" #{i + 1}: {str(sorted(s)):>20s} acq = {v:.4f}") + + # MCTS + t0 = time.time() + mcts_cand, mcts_val, mcts_subset = mcts_optimize(strategy, acqf, bounds, seed) + mcts_time = time.time() - t0 + + # Find MCTS subset rank in exhaustive results + rank = next( + (i + 1 for i, (s, _) in enumerate(all_results) if s == mcts_subset), + len(all_results), + ) + + gap = exh_val - mcts_val + mcts_ranks.append(rank) + mcts_gaps.append(gap) + if rank == 1: + mcts_found_best += 1 + + print( + f" MCTS: best_acq = {mcts_val:.4f}, " + f"subset = {sorted(mcts_subset)}, time = {mcts_time:.1f}s" + ) + print( + f" MCTS rank: #{rank}/{len(all_results)}, " + f"gap = {gap:.4f}, " + f"speedup = {exh_time / max(mcts_time, 0.01):.1f}x" + ) + print() + + # Summary + print("=" * 70) + print("SUMMARY") + print("=" * 
70) + print(f" Seeds: {N_SEEDS}") + print( + f" MCTS found best subset: {mcts_found_best}/{N_SEEDS} " + f"({100 * mcts_found_best / N_SEEDS:.0f}%)" + ) + print( + f" Mean MCTS rank: {sum(mcts_ranks) / len(mcts_ranks):.1f} " + f"/ {len(subsets)}" + ) + print( + f" Mean acq gap (exhaustive - MCTS): " f"{sum(mcts_gaps) / len(mcts_gaps):.4f}" + ) + print(f" MCTS ranks: {mcts_ranks}") + + +if __name__ == "__main__": + main() diff --git a/mcts-report/test_dag.ipynb b/mcts-report/test_dag.ipynb new file mode 100644 index 000000000..b78ff9f5a --- /dev/null +++ b/mcts-report/test_dag.ipynb @@ -0,0 +1,7295 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "import time\n", + "\n", + "import pandas as pd\n", + "import torch\n", + "from botorch.optim import optimize_acqf\n", + "from botorch.utils.sampling import draw_sobol_samples\n", + "from optimize_mcts_dag import MCTS_DAG\n", + "\n", + "import bofire.surrogates.api as surrogates\n", + "from bofire.benchmarks.api import Hartmann, SpuriousFeaturesWrapper\n", + "from bofire.data_models.surrogates.api import (\n", + " BotorchSurrogates,\n", + " EnsembleMapSaasSingleTaskGPSurrogate,\n", + " SingleTaskGPSurrogate,\n", + ")\n", + "from bofire.strategies import utils\n", + "from bofire.strategies.api import RandomStrategy, SoboStrategy\n", + "from bofire.strategies.predictives.optimize_mcts import (\n", + " Groups,\n", + " NChooseK,\n", + " _SelectionTracker,\n", + ")\n", + "from bofire.utils.torch_tools import tkwargs\n", + "\n", + "\n", + "benchmark = SpuriousFeaturesWrapper(Hartmann(dim=6), n_spurious_features=6, max_count=6)\n", + "random_strategy = RandomStrategy.make(domain=benchmark.domain)\n", + "\n", + "experiments = pd.read_csv(\"experiments.csv\")\n", + "\n", + "strategy = SoboStrategy.make(\n", + " domain=benchmark.domain,\n", + " surrogate_specs=BotorchSurrogates(\n", + " surrogates=[\n", + " 
EnsembleMapSaasSingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + " )\n", + " ]\n", + " ),\n", + ")\n", + "\n", + "strategy.tell(experiments.loc[:9].copy())\n", + "acqf = strategy._get_acqfs(n=1)[0]\n", + "\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=64, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "def reward_fn3(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " 
fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "# candidates = random_strategy.ask(10)\n", + "# experiments = benchmark.f(candidates, return_complete=True)\n", + "\n", + "# experiments[[\"x_3\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
50.9716560.3410370.2473110.1170810.0000000.0000000.1014650.0000000.0000000.6124490.0000000.000000-0.002416True
60.0097270.3187120.0000000.0000000.5629500.5014310.0000000.0000000.1304470.7760440.0000000.000000-0.216006True
70.8733560.2732550.0000000.0000000.3955370.0000000.0000000.9134770.4222420.5581640.0000000.000000-0.007566True
80.0000000.0000000.0000000.9213050.4787820.5553820.8431830.0000000.6743220.0000000.8404980.000000-0.019061True
90.0000000.8613230.0000000.0000000.2463090.0000000.3676990.5743210.0000000.7715720.2532700.000000-0.010974True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + "5 0.971656 0.341037 0.247311 0.117081 0.000000 0.000000 0.101465 \n", + "6 0.009727 0.318712 0.000000 0.000000 0.562950 0.501431 0.000000 \n", + "7 0.873356 0.273255 0.000000 0.000000 0.395537 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.921305 0.478782 0.555382 0.843183 \n", + "9 0.000000 0.861323 0.000000 0.000000 0.246309 0.000000 0.367699 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + "5 0.000000 0.000000 0.612449 0.000000 0.000000 \n", + "6 0.000000 0.130447 0.776044 0.000000 0.000000 \n", + "7 0.913477 0.422242 0.558164 0.000000 0.000000 \n", + "8 0.000000 0.674322 0.000000 0.840498 0.000000 \n", + "9 0.574321 0.000000 0.771572 0.253270 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + "5 -0.002416 True \n", + "6 -0.216006 True \n", + "7 -0.007566 True \n", + "8 -0.019061 True \n", + "9 -0.010974 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.loc[:9]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAEMSDR2MAPEPEARSONSPEARMANFISHER
00.0554610.017940.97192410.7924820.986920.960975.674047e-18
\n", + "
" + ], + "text/plain": [ + " MAE MSD R2 MAPE PEARSON SPEARMAN FISHER\n", + "0 0.055461 0.01794 0.971924 10.792482 0.98692 0.96097 5.674047e-18" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "surrogate_data = EnsembleMapSaasSingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MAEMSDR2MAPEPEARSONSPEARMANFISHER
00.0684960.0261620.95905817.9382450.9814980.9556965.674047e-18
\n", + "
" + ], + "text/plain": [ + " MAE MSD R2 MAPE PEARSON SPEARMAN FISHER\n", + "0 0.068496 0.026162 0.959058 17.938245 0.981498 0.955696 5.674047e-18" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "surrogate_data = SingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "surrogate_data = EnsembleMapSaasSingleTaskGPSurrogate(\n", + " inputs=benchmark.domain.inputs, outputs=benchmark.domain.outputs\n", + ")\n", + "surrogate = surrogates.map(surrogate_data)\n", + "\n", + "cv_train, cv_test, _ = surrogate.cross_validate(experiments, folds=5)\n", + "display(cv_test.get_metrics())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[7.9170e+03, 1.6480e+00, 7.0719e-01, 4.3675e-01, 3.9497e-01,\n", + " 5.8420e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[1.0000e+04, 1.4312e+00, 6.6053e-01, 4.0018e-01, 3.6135e-01,\n", + " 5.3718e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[8.8197e+03, 1.5363e+00, 6.8223e-01, 4.1844e-01, 3.7854e-01,\n", + " 5.5984e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[1.0000e+04, 1.4326e+00, 6.6081e-01, 4.0044e-01, 3.6159e-01,\n", + " 5.3748e-01, 1.0000e+04, 1.0000e+04, 1.0000e+04, 1.0000e+04,\n", + " 1.0000e+04, 1.0000e+04]]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"strategy.model.covar_module.base_kernel.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 1.7057, 1.8292, 29.8056, 0.3237, 0.2182, 72.7240, 85.6520,\n", + " 7.8904, 199.8355, 15.3883, 7.4227, 141.4410]],\n", + " dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[2.6226e+00, 8.2782e+01, 4.1018e-01, 1.1438e+02, 3.3914e-01, 1.3650e-01,\n", + " 2.0361e+00, 8.3981e+00, 2.5162e+00, 3.5898e+02, 1.1046e+01, 4.6532e+00]],\n", + " dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_gold_standard(selections, acqf, bounds, n_features):\n", + " \"\"\"Evaluate each selection with full optimize_acqf (gold standard).\"\"\"\n", + " values = []\n", + " for sel in selections:\n", + " fixed = {i: 0.0 for i in range(n_features) if i not in sel}\n", + " if len(sel) == 0:\n", + " cand = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " val = acqf(cand.unsqueeze(-2)).max().item()\n", + " else:\n", + " cand, acq_val = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " val = acq_val.item()\n", + " values.append(val)\n", + " return values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"[-45.3893631573142]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluate_gold_standard([(0, 1, 2, 3, 11)], acqf, bounds, n_features=12)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-5.114861134566571" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reward_fn2((0, 1, 2, 3, 4, 5), cat_selections={})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x_5 44.561275\n", + "dtype: float64" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments[[\"x_5\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.4024700.0000000.0000000.0000000.0000000.0000000.0000000.6374530.1399190.4218240.6713710.331502-0.0056831
10.4904710.2944280.0000000.0000000.0000000.9252190.1794820.7729880.4809430.0000000.0000000.000000-0.0967351
20.8312990.0000000.1528290.2337560.0000000.2376550.2170320.0000000.5361540.0000000.0000000.000000-0.0417131
30.0000000.8293810.0000000.2879120.0000000.0000000.7741490.0000000.0000000.5073140.7311190.425413-0.0825721
40.0000000.8670440.1348860.0000000.2504430.2510050.7662670.0000000.0214130.0000000.0000000.000000-0.0457211
50.0000000.3077290.8427450.1164260.0000000.2726150.0000000.0000000.0000000.7787310.0000000.155062-0.0971321
60.2566110.4944050.0000000.6031580.0000000.0000000.8366570.0000000.0000000.8080900.0000000.134337-0.6195871
70.9262710.0000000.0000000.3188560.0000000.0000000.6645890.0000000.4928020.0531630.0000000.259605-0.0032261
80.0000000.2191180.6250430.7884450.0000000.0000000.0000000.4703430.0000000.0000000.3855200.006767-0.0065051
90.9737060.0000000.0000000.0623920.0000000.6765590.0000000.8051040.0000000.2609770.4023370.000000-0.0542801
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.402470 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.490471 0.294428 0.000000 0.000000 0.000000 0.925219 0.179482 \n", + "2 0.831299 0.000000 0.152829 0.233756 0.000000 0.237655 0.217032 \n", + "3 0.000000 0.829381 0.000000 0.287912 0.000000 0.000000 0.774149 \n", + "4 0.000000 0.867044 0.134886 0.000000 0.250443 0.251005 0.766267 \n", + "5 0.000000 0.307729 0.842745 0.116426 0.000000 0.272615 0.000000 \n", + "6 0.256611 0.494405 0.000000 0.603158 0.000000 0.000000 0.836657 \n", + "7 0.926271 0.000000 0.000000 0.318856 0.000000 0.000000 0.664589 \n", + "8 0.000000 0.219118 0.625043 0.788445 0.000000 0.000000 0.000000 \n", + "9 0.973706 0.000000 0.000000 0.062392 0.000000 0.676559 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.637453 0.139919 0.421824 0.671371 0.331502 \n", + "1 0.772988 0.480943 0.000000 0.000000 0.000000 \n", + "2 0.000000 0.536154 0.000000 0.000000 0.000000 \n", + "3 0.000000 0.000000 0.507314 0.731119 0.425413 \n", + "4 0.000000 0.021413 0.000000 0.000000 0.000000 \n", + "5 0.000000 0.000000 0.778731 0.000000 0.155062 \n", + "6 0.000000 0.000000 0.808090 0.000000 0.134337 \n", + "7 0.000000 0.492802 0.053163 0.000000 0.259605 \n", + "8 0.470343 0.000000 0.000000 0.385520 0.006767 \n", + "9 0.805104 0.000000 0.260977 0.402337 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.005683 1 \n", + "1 -0.096735 1 \n", + "2 -0.041713 1 \n", + "3 -0.082572 1 \n", + "4 -0.045721 1 \n", + "5 -0.097132 1 \n", + "6 -0.619587 1 \n", + "7 -0.003226 1 \n", + "8 -0.006505 1 \n", + "9 -0.054280 1 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_10x_spurious_11...x_spurious_41x_spurious_42x_spurious_43x_spurious_5x_spurious_6x_spurious_7x_spurious_8x_spurious_9yvalid_y
00.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.2026600.0000000.0000000.0000000.0000000.0000000.280977-0.0050891
10.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.5171240.0000000.7295480.7651010.0000000.030946-0.0050891
20.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.0000000.0128140.8652780.0000000.0000000.377278-0.0050891
30.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.8330690.2585300.0000000.0000000.0000000.0000000.885180-0.0050891
40.00.00.00.00.00.3797580.00.0000000.00.0...0.0000000.0000000.0000000.0000000.0000000.0516230.0000000.207810-0.0905331
..................................................................
2510.00.00.00.00.00.0000000.00.0000000.00.0...0.6772340.0000000.0000000.8742940.0000000.0000000.8935000.198800-0.0050891
2520.00.00.00.00.00.0000000.00.9617930.00.0...0.0000000.4916440.0000000.0000000.6685070.0000000.0000000.449542-0.0050891
2530.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.0000000.0000000.3692730.0000000.0000000.862634-0.0050891
2540.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0000000.7771760.9883990.0000000.6791650.0000000.205755-0.0050891
2550.00.00.00.00.00.0000000.00.0000000.00.0...0.0000000.0286050.0000000.0000000.0000000.0000000.5994050.957740-0.0050891
\n", + "

256 rows × 52 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "0 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "1 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "2 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "3 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "4 0.0 0.0 0.0 0.0 0.0 0.379758 0.0 0.000000 \n", + ".. ... ... ... ... ... ... ... ... \n", + "251 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "252 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.961793 \n", + "253 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "254 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "255 0.0 0.0 0.0 0.0 0.0 0.000000 0.0 0.000000 \n", + "\n", + " x_spurious_10 x_spurious_11 ... x_spurious_41 x_spurious_42 \\\n", + "0 0.0 0.0 ... 0.000000 0.202660 \n", + "1 0.0 0.0 ... 0.000000 0.000000 \n", + "2 0.0 0.0 ... 0.000000 0.000000 \n", + "3 0.0 0.0 ... 0.000000 0.833069 \n", + "4 0.0 0.0 ... 0.000000 0.000000 \n", + ".. ... ... ... ... ... \n", + "251 0.0 0.0 ... 0.677234 0.000000 \n", + "252 0.0 0.0 ... 0.000000 0.491644 \n", + "253 0.0 0.0 ... 0.000000 0.000000 \n", + "254 0.0 0.0 ... 0.000000 0.000000 \n", + "255 0.0 0.0 ... 0.000000 0.028605 \n", + "\n", + " x_spurious_43 x_spurious_5 x_spurious_6 x_spurious_7 x_spurious_8 \\\n", + "0 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.517124 0.000000 0.729548 0.765101 0.000000 \n", + "2 0.000000 0.012814 0.865278 0.000000 0.000000 \n", + "3 0.258530 0.000000 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.000000 0.000000 0.051623 0.000000 \n", + ".. ... ... ... ... ... 
\n", + "251 0.000000 0.874294 0.000000 0.000000 0.893500 \n", + "252 0.000000 0.000000 0.668507 0.000000 0.000000 \n", + "253 0.000000 0.000000 0.369273 0.000000 0.000000 \n", + "254 0.777176 0.988399 0.000000 0.679165 0.000000 \n", + "255 0.000000 0.000000 0.000000 0.000000 0.599405 \n", + "\n", + " x_spurious_9 y valid_y \n", + "0 0.280977 -0.005089 1 \n", + "1 0.030946 -0.005089 1 \n", + "2 0.377278 -0.005089 1 \n", + "3 0.885180 -0.005089 1 \n", + "4 0.207810 -0.090533 1 \n", + ".. ... ... ... \n", + "251 0.198800 -0.005089 1 \n", + "252 0.449542 -0.005089 1 \n", + "253 0.862634 -0.005089 1 \n", + "254 0.205755 -0.005089 1 \n", + "255 0.957740 -0.005089 1 \n", + "\n", + "[256 rows x 52 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((3, 8, 17, 18, 20, 49), {}, [])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "groups = Groups(groups=[NChooseK(features=list(range(50)), max_count=6, min_count=0)])\n", + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=lambda x, y: 0.0,\n", + " rollout_mode=\"uniform_subset\",\n", + " p_stop_rollout=0.0,\n", + ")\n", + "mcts._rollout(mcts.root)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "strategy = SoboStrategy.make(domain=benchmark.domain)\n", + "\n", + "strategy.tell(experiments)\n", + "acqf = strategy._get_acqfs(n=1)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean time per acqf(X) call: 
1.853 ms\n", + "Min/Max per call: 1.054 / 14.629 ms\n" + ] + } + ], + "source": [ + "transformed = strategy.domain.inputs.transform(\n", + " candidates,\n", + " strategy.input_preprocessing_specs,\n", + ")\n", + "X = torch.from_numpy(transformed.values).to(**tkwargs)\n", + "X = X.unsqueeze(-2)\n", + "\n", + "n_calls = 100\n", + "call_times = []\n", + "\n", + "with torch.no_grad():\n", + " for _ in range(n_calls):\n", + " t0 = time.perf_counter()\n", + " _ = acqf(X)\n", + " call_times.append(time.perf_counter() - t0)\n", + "\n", + "mean_time = sum(call_times) / n_calls\n", + "print(f\"Mean time per acqf(X) call: {mean_time * 1e3:.3f} ms\")\n", + "print(f\"Min/Max per call: {min(call_times) * 1e3:.3f} / {max(call_times) * 1e3:.3f} ms\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-33.6191, -32.9385, -33.6550, -33.6543, -31.8275, -33.6546, -33.6545,\n", + " -33.6547, -33.6544, -33.6546, -33.6545, -31.5978, -33.6528, -33.6542,\n", + " -33.6542, -33.6536, -33.6542, -32.0027, -33.6554, -33.6563, -33.6516,\n", + " -33.6549, -33.5973, -33.6491, -33.6549, -33.6451, -33.6546, -33.6551,\n", + " -33.6543, -33.6530, -33.6553, -33.6550, -33.6532, -33.6542, -33.6548,\n", + " -33.6539, -33.6533, -33.6483, -33.6543, -33.6543, -33.6543, -33.6373,\n", + " -33.6253, -33.6540, -33.6504, -33.6526, -33.6544, -33.6551, -33.6510,\n", + " -33.6544, -33.6547, -33.6532, -33.6543, -33.6543, -33.6550, -33.6545,\n", + " -33.6505, -33.6542, -33.6544, -33.6712, -33.6541, -33.6545, -33.6533,\n", + " -33.6550], dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "acqf(X)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-10.2082, dtype=torch.float64, grad_fn=)" + ] + }, + "execution_count": 
null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "domain2 = copy.deepcopy(benchmark.domain)\n", + "domain2.constraints.constraints = []\n", + "for i in range(6, 50):\n", + " domain2.inputs.features[i].bounds = (0.0, 0.0)\n", + "\n", + "candidates2 = domain2.inputs.sample(64)\n", + "\n", + "transformed = strategy.domain.inputs.transform(\n", + " candidates2,\n", + " strategy.input_preprocessing_specs,\n", + ")\n", + "X = torch.from_numpy(transformed.values).to(**tkwargs)\n", + "X = X.unsqueeze(-2)\n", + "\n", + "acqf(X).max()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5, 9), {}, -4.918443727224067)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from bofire.strategies import utils\n", + "\n", + "\n", + "groups = Groups(\n", + " groups=[\n", + " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + ")\n", + "\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=64, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, 
float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.inputs.domain.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "mcts = MCTS_DAG(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=reward_fn,\n", + " use_cache=False,\n", + " rollout_mode=\"ts_group_action\",\n", + " adaptive_prior_var=True,\n", + " separate_stop=True,\n", + ")\n", + "mcts.run(n_iterations=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 1, 6, 8)\n", + "-2.5978651482116755\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "# print(reward_fn2(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-7.509994527934811\n" + ] + } + ], + "source": [ + "print(reward_fn((8, 9, 10), cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "22", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5
00.4115960.9102150.00.00.00.00.00.00.00.00.00.878617
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "0 0.411596 0.910215 0.0 0.0 0.0 0.0 0.0 0.0 \n", + "\n", + " x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \n", + "0 0.0 0.0 0.0 0.878617 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-2.5506, dtype=torch.float64)\n" + ] + } + ], + "source": [ + "candidate, acqf_val = get_candidate()\n", + "\n", + "display(candidate)\n", + "print(acqf_val)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Iteration 0 ---\n", + " n_experiments: 10\n", + "[(4, 5, 10, 11), (2, 4, 5, 10, 11), (0, 4, 5, 10, 11), (5, 9, 11), (5, 9), (5, 9), (5, 9, 11), (5, 9, 11), (5, 9), (5, 9), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10), (1, 4, 6, 7, 9, 10)]\n", + ". chosen subset: (0, 1, 4, 5, 8, 9)\n", + " gold standard acq value: -3.272324\n", + " acq_val from get_candidate: -3.0824304931611106\n", + " new y value: [-0.00190784]\n", + " experiments: 10 -> 11\n", + "\n", + "--- Iteration 1 ---\n", + " n_experiments: 11\n", + "[(4, 5, 9, 10, 11), (4, 5), (5, 9, 10, 11), (5, 9, 10), (5, 9, 11), (5, 9, 11), (5, 9, 11), (5, 9, 11), (5, 9, 10), (5, 8, 9, 11), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10), (1, 4, 6, 7, 9, 10)]\n", + ". chosen subset: (4, 5, 9, 10, 11)\n", + " gold standard acq value: -3.497068\n", + " acq_val from get_candidate: -3.2354980634590476\n", + " new y value: [-0.00106425]\n", + " experiments: 11 -> 12\n", + "\n", + "--- Iteration 2 ---\n", + " n_experiments: 12\n", + "[(5, 8), (5, 6, 8), (5, 8, 11), (5, 8, 9), (5, 8, 9, 11), (5, 8), (5, 8), (5, 8), (5, 8), (5, 8, 11), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10), (1, 4, 6, 7, 9, 10)]\n", + ". 
chosen subset: (3, 4, 5, 6, 8, 10)\n", + " gold standard acq value: -5.730687\n", + " acq_val from get_candidate: -5.479729315100673\n", + " new y value: [-0.00127744]\n", + " experiments: 12 -> 13\n", + "\n", + "--- Iteration 3 ---\n", + " n_experiments: 13\n", + "[(0, 1, 2, 3, 5, 6), (5, 7, 11), (5, 7, 9, 11), (5, 11), (5, 9, 11), (5,), (5, 11), (5, 8, 9, 11), (5, 9), (5,), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10), (1, 4, 6, 7, 9, 10)]\n", + ". chosen subset: (3, 4, 5, 6, 8, 10)\n", + " gold standard acq value: -3.921437\n", + " acq_val from get_candidate: -3.921436410084564\n", + " new y value: [-0.00063887]\n", + " experiments: 13 -> 14\n", + "\n", + "--- Iteration 4 ---\n", + " n_experiments: 14\n", + "[(0, 2, 4, 5, 8, 10), (0, 4, 5, 10, 11), (4, 5, 7, 8, 10), (5, 6, 9, 11), (5, 9, 11), (5, 8, 9, 11), (5, 8, 11), (5, 7, 11), (5, 11), (5, 8, 10), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10), (1, 4, 6, 7, 9, 10)]\n", + ". chosen subset: (1, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -5.187396\n", + " acq_val from get_candidate: -5.187395310570116\n", + " new y value: [-0.43914631]\n", + " experiments: 14 -> 15\n", + "\n", + "--- Iteration 5 ---\n", + " n_experiments: 15\n", + "[(0, 2, 5), (3, 4, 5, 6, 7, 11), (5, 6, 9), (5, 6, 11), (5, 6, 8, 11), (5, 6, 11), (5, 6, 7, 9, 10), (5, 6, 10, 11), (5, 6, 8, 11), (5, 6, 7, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10)]\n", + ". 
chosen subset: (1, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -3.562293\n", + " acq_val from get_candidate: -3.5622875803820016\n", + " new y value: [-0.00643616]\n", + " experiments: 15 -> 16\n", + "\n", + "--- Iteration 6 ---\n", + " n_experiments: 16\n", + "[(1, 2, 5, 7, 9), (2, 5), (2, 5, 11), (2, 4, 5, 6, 8, 11), (5, 11), (5, 11), (5, 11), (5, 11), (5, 11), (5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (3, 4, 5, 6, 8, 10)]\n", + ". chosen subset: (1, 2, 5, 7, 9)\n", + " gold standard acq value: -4.191990\n", + " acq_val from get_candidate: -4.191989563429685\n", + " new y value: [-0.02081094]\n", + " experiments: 16 -> 17\n", + "\n", + "--- Iteration 7 ---\n", + " n_experiments: 17\n", + "[(1, 2, 5, 10), (3, 4, 5, 8, 10, 11), (3, 5, 10), (1, 5), (1, 5, 7, 10), (5, 10), (5, 11), (5, 8, 9, 10, 11), (5, 11), (5, 10), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11), (1, 2, 5, 9)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ". 
chosen subset: (1, 2, 5, 10)\n", + " gold standard acq value: -4.091967\n", + " acq_val from get_candidate: -4.091966906328363\n", + " new y value: [-0.1143337]\n", + " experiments: 17 -> 18\n", + "\n", + "--- Iteration 8 ---\n", + " n_experiments: 18\n", + "[(2, 4, 8, 10, 11), (2, 4, 9), (2, 4, 10), (2, 4), (2, 4, 8), (2, 4, 8, 10), (2, 4), (2, 4, 11), (2, 4, 5, 11), (2, 4, 9, 10), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (1, 2, 5, 10), (0, 1, 4, 8, 9, 10), (1, 2, 4, 5, 7, 11)]\n", + ". chosen subset: (1, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -4.352177\n", + " acq_val from get_candidate: -4.352175311070443\n", + " new y value: [-0.23929266]\n", + " experiments: 18 -> 19\n", + "\n", + "--- Iteration 9 ---\n", + " n_experiments: 19\n", + "[(2, 4, 5, 7, 10), (2, 3, 4, 6, 9, 10), (4, 5), (4, 5, 10), (4, 5, 6), (4, 5, 10), (4, 5), (4, 5, 10), (4, 5), (4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (1, 2, 5, 10), (0, 1, 4, 8, 9, 10)]\n", + ". chosen subset: (1, 2, 4, 5, 11)\n", + " gold standard acq value: -4.059375\n", + " acq_val from get_candidate: -4.0593750473351005\n", + " new y value: [-0.47169123]\n", + " experiments: 19 -> 20\n", + "\n", + "--- Iteration 10 ---\n", + " n_experiments: 20\n", + "[(4, 5), (4, 5, 8, 9, 10), (4, 5, 10), (4, 5, 6, 8, 10), (4, 5, 8), (4, 5), (4, 5), (4, 5, 10, 11), (4, 5), (4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (1, 2, 5, 10)]\n", + ". chosen subset: (2, 4, 5)\n", + " gold standard acq value: -4.426326\n", + " acq_val from get_candidate: -4.4263254279920075\n", + " new y value: [-0.01327007]\n", + " experiments: 20 -> 21\n", + "\n", + "--- Iteration 11 ---\n", + " n_experiments: 21\n", + "[(1, 2, 4, 5), (5, 6), (5, 6, 7), (5, 6, 9), (5, 6, 7, 9), (5,), (5,), (5,), (5,), (5, 9, 10, 11), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9), (1, 2, 5, 10)]\n", + ". 
chosen subset: (1, 2, 4, 5, 11)\n", + " gold standard acq value: -4.433573\n", + " acq_val from get_candidate: -4.433572743130247\n", + " new y value: [-0.49692316]\n", + " experiments: 21 -> 22\n", + "\n", + "--- Iteration 12 ---\n", + " n_experiments: 22\n", + "[(2, 5), (2, 5, 6), (2, 5, 6, 7, 9), (2, 5, 6, 9), (2, 5, 7, 9), (2, 5), (2, 5), (2, 5), (2, 5), (2, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9)]\n", + ". chosen subset: (2, 5, 6, 7, 9)\n", + " gold standard acq value: -4.873673\n", + " acq_val from get_candidate: -4.872499826321248\n", + " new y value: [-0.11402308]\n", + " experiments: 22 -> 23\n", + "\n", + "--- Iteration 13 ---\n", + " n_experiments: 23\n", + "[(2, 5), (2, 5, 7), (2, 5, 9), (2, 5, 7, 11), (2, 5, 7, 9), (2, 5), (2, 5), (2, 5), (2, 5), (2, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9)]\n", + ". chosen subset: (2, 5, 9)\n", + " gold standard acq value: -4.965188\n", + " acq_val from get_candidate: -4.965202329122498\n", + " new y value: [-0.11341447]\n", + " experiments: 23 -> 24\n", + "\n", + "--- Iteration 14 ---\n", + " n_experiments: 24\n", + "[(4, 5, 7, 9, 10), (2, 4, 5, 6, 9, 10), (5,), (5, 7), (5, 6), (5,), (5,), (5,), (5,), (5,), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11), (0, 1, 4, 5, 8, 9)]\n", + ". chosen subset: (0, 1, 4, 5, 8, 9)\n", + " gold standard acq value: -4.168549\n", + " acq_val from get_candidate: -4.168548793473345\n", + " new y value: [-0.64283099]\n", + " experiments: 24 -> 25\n", + "\n", + "--- Iteration 15 ---\n", + " n_experiments: 25\n", + "[(5, 6), (5, 6, 7, 8), (5, 6, 7), (5, 6, 8), (5, 6, 8, 11), (5,), (5,), (5,), (5,), (5,), (0, 4, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11)]\n", + ". 
chosen subset: (5, 6, 7, 8)\n", + " gold standard acq value: -5.568869\n", + " acq_val from get_candidate: -5.30377895237886\n", + " new y value: [-0.15934426]\n", + " experiments: 25 -> 26\n", + "\n", + "--- Iteration 16 ---\n", + " n_experiments: 26\n", + "[(5, 6), (5, 7, 8), (5, 8), (5,), (5, 7), (5,), (5,), (5,), (5,), (5,), (0, 4, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 4, 5)\n", + " gold standard acq value: -5.228153\n", + " acq_val from get_candidate: -5.2281530548772235\n", + " new y value: [-7.3588042e-05]\n", + " experiments: 26 -> 27\n", + "\n", + "--- Iteration 17 ---\n", + " n_experiments: 27\n", + "[(4, 5, 11), (4, 5, 6), (4, 5, 8, 11), (4, 5, 10, 11), (4, 5, 7, 11), (4, 5), (4, 5, 7, 10), (4, 5), (4, 5), (4, 5, 7, 9, 11), (0, 4, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11), (1, 2, 4, 5, 11)]\n", + ". chosen subset: (1, 2, 4, 5)\n", + " gold standard acq value: -5.259865\n", + " acq_val from get_candidate: -5.259873205703482\n", + " new y value: [-0.64124025]\n", + " experiments: 27 -> 28\n", + "\n", + "--- Iteration 18 ---\n", + " n_experiments: 28\n", + "[(3, 4, 5), (3, 4, 5, 11), (3, 4, 5, 6, 9), (3, 5), (3, 5, 7, 8, 9), (3, 5, 9, 11), (3, 5, 10), (3, 5, 6, 11), (3, 5, 6), (6, 11), (0, 4, 5), (1, 4, 5), (1, 2, 4, 5), (2, 4, 5), (1, 2, 4, 5, 11)]\n", + ". chosen subset: (3, 4, 5, 6, 9)\n", + " gold standard acq value: -5.096295\n", + " acq_val from get_candidate: -5.096293312431652\n", + " new y value: [-0.82025242]\n", + " experiments: 28 -> 29\n", + "\n", + "--- Iteration 19 ---\n", + " n_experiments: 29\n", + "[(4, 5, 6), (3, 4, 5, 6), (3, 4, 5), (3, 5, 6, 11), (3, 5, 6, 7, 9, 11), (9, 10), (3, 8), (3, 5, 6, 9, 11), (3, 5, 6, 8, 11), (3, 5, 6, 8, 11), (4, 5, 6, 9), (0, 4, 5), (1, 4, 5), (1, 2, 4, 5), (2, 4, 5)]\n", + ". 
chosen subset: (3, 4, 5, 6)\n", + " gold standard acq value: -9.524383\n", + " acq_val from get_candidate: -5.0343722531026796\n", + " new y value: [-0.82429871]\n", + " experiments: 29 -> 30\n", + "\n", + "--- Iteration 20 ---\n", + " n_experiments: 30\n", + "[(3, 5, 6, 8, 9, 10), (3, 5, 6, 11), (3, 5, 6, 9, 10), (3, 5, 6, 8, 11), (3, 5, 6, 10), (3, 11), (3, 5, 6, 11), (3, 5, 6, 10, 11), (3, 5, 6, 11), (3, 5, 6, 9), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5), (1, 4, 5), (1, 2, 4, 5)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .'), OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ". 
chosen subset: (4, 5, 6)\n", + " gold standard acq value: -10.231045\n", + " acq_val from get_candidate: -5.504064077636816\n", + " new y value: [-0.83019756]\n", + " experiments: 30 -> 31\n", + "\n", + "--- Iteration 21 ---\n", + " n_experiments: 31\n", + "[(1, 4, 5, 6, 11), (1, 4, 5, 6, 8, 9), (1, 4, 5, 6, 8, 10), (1, 4, 5, 6, 8, 11), (1, 4, 5, 6, 8), (1, 4, 5, 6, 10, 11), (1, 4, 5, 6, 8, 10), (1, 4, 5, 6, 10), (1, 4, 5, 6, 11), (1, 4, 5, 6, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5), (1, 4, 5)]\n", + ". chosen subset: (1, 4, 5, 6, 8)\n", + " gold standard acq value: -40.221332\n", + " acq_val from get_candidate: -6.113333182857572\n", + " new y value: [-0.60362828]\n", + " experiments: 31 -> 32\n", + "\n", + "--- Iteration 22 ---\n", + " n_experiments: 32\n", + "[(3, 4, 5, 6, 9, 11), (3, 4, 5, 6), (3, 4, 5, 6, 10, 11), (3, 4, 5, 6, 7, 11), (3, 4, 5, 6, 10), (3, 4, 5), (3, 4, 5, 6, 7, 10), (3, 4, 5, 6, 10), (3, 4, 5, 6, 11), (3, 4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5), (1, 4, 5)]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + ". 
chosen subset: (3, 4, 5, 6, 10, 11)\n", + " gold standard acq value: -10.387241\n", + " acq_val from get_candidate: -5.657902067201425\n", + " new y value: [-0.83526548]\n", + " experiments: 32 -> 33\n", + "\n", + "--- Iteration 23 ---\n", + " n_experiments: 33\n", + "[(4, 5, 6, 10), (4, 5, 6, 9, 10), (4, 5, 6, 11), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6, 9, 10, 11), (4, 5, 6, 10), (4, 5, 6, 10), (4, 5, 6, 9, 11), (4, 5, 6, 11), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5)]\n", + ". chosen subset: (4, 5, 6, 10, 11)\n", + " gold standard acq value: -40.243085\n", + " acq_val from get_candidate: -8.215889445948243\n", + " new y value: [-0.33540195]\n", + " experiments: 33 -> 34\n", + "\n", + "--- Iteration 24 ---\n", + " n_experiments: 34\n", + "[(0, 4, 5, 6), (0, 4, 5, 6, 10, 11), (0, 4, 5, 6, 11), (0, 4, 5, 6, 10), (0, 4, 5, 6, 9, 11), (0, 4, 5, 6, 10, 11), (0, 4, 5, 6, 11), (0, 4, 5, 6), (0, 4, 5, 6, 10, 11), (0, 4, 5, 6, 10, 11), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5)]\n", + ". chosen subset: (0, 4, 5, 6)\n", + " gold standard acq value: -9.798223\n", + " acq_val from get_candidate: -7.05930176242321\n", + " new y value: [-0.20200808]\n", + " experiments: 34 -> 35\n", + "\n", + "--- Iteration 25 ---\n", + " n_experiments: 35\n", + "[(4, 5, 6, 8), (4, 5, 11), (4, 5, 9, 11), (4, 5, 6, 9), (4, 5, 10, 11), (4, 5, 7, 8), (4, 5, 7, 8), (4, 5, 6, 10), (4, 5, 6, 7, 11), (4, 5, 6, 8), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9), (0, 4, 5)]\n", + ". 
chosen subset: (0, 4, 5)\n", + " gold standard acq value: -3.633982\n", + " acq_val from get_candidate: -3.6339824213356966\n", + " new y value: [-0.96792611]\n", + " experiments: 35 -> 36\n", + "\n", + "--- Iteration 26 ---\n", + " n_experiments: 36\n", + "[(3, 4, 5, 10, 11), (3, 4, 5, 6, 7, 8), (3, 4, 5, 7, 9, 10), (3, 4, 5, 8, 10, 11), (3, 4, 5, 11), (3, 4, 5, 6), (3, 4, 5, 10, 11), (3, 4, 5, 6, 8), (3, 4, 5, 11), (3, 4, 5, 10, 11), (0, 4, 5), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6), (4, 5, 6, 9)]\n", + ". chosen subset: (0, 4, 5)\n", + " gold standard acq value: -4.511042\n", + " acq_val from get_candidate: -4.511042099621564\n", + " new y value: [-0.98494359]\n", + " experiments: 36 -> 37\n", + "\n", + "--- Iteration 27 ---\n", + " n_experiments: 37\n", + "[(3, 4, 5, 10), (3, 4, 5), (3, 4, 5, 11), (3, 4, 5, 7, 11), (3, 4, 5, 7), (3, 4, 5, 7, 10), (3, 4, 5, 10, 11), (3, 4, 5), (3, 4, 5, 7, 10), (3, 4, 5, 9, 10), (0, 4, 5), (0, 4, 5), (4, 5, 6, 10, 11), (4, 5, 6), (4, 5, 6)]\n", + ". chosen subset: (0, 4, 5)\n", + " gold standard acq value: -5.442976\n", + " acq_val from get_candidate: -6.035241018591719\n", + " new y value: [-0.97803858]\n", + " experiments: 37 -> 38\n", + "\n", + "--- Iteration 28 ---\n", + " n_experiments: 38\n", + "[(0, 4, 5, 9), (0, 4, 5, 11), (0, 4, 5, 9, 11), (0, 4, 5, 6, 10, 11), (0, 4, 5, 7, 9, 11), (0, 4, 5), (0, 4, 5, 7, 8, 9), (0, 4, 5, 7), (0, 4, 5, 7), (0, 4, 5, 8), (0, 4, 5), (0, 4, 5), (0, 4, 5), (4, 5, 6, 10, 11), (4, 5, 6)]\n", + ". chosen subset: (0, 4, 5, 6, 10, 11)\n", + " gold standard acq value: -5.642574\n", + " acq_val from get_candidate: -6.097594757948485\n", + " new y value: [-0.9802264]\n", + " experiments: 38 -> 39\n", + "\n", + "--- Iteration 29 ---\n", + " n_experiments: 39\n", + "[(4, 5), (4, 5, 7, 9, 10, 11), (4, 5, 10, 11), (4, 5, 6), (4, 5, 9, 11), (4, 5, 6, 9, 11), (4, 5, 9, 10), (4, 5, 9), (4, 5, 9), (4, 5, 9), (0, 4, 5), (0, 4, 5, 6, 10, 11), (0, 4, 5), (0, 4, 5), (4, 5, 6, 10, 11)]\n", + ". 
chosen subset: (0, 4, 5, 6, 10, 11)\n", + " gold standard acq value: -5.489166\n", + " acq_val from get_candidate: -7.30397498375664\n", + " new y value: [-0.98425541]\n", + " experiments: 39 -> 40\n", + "\n", + "--- Iteration 30 ---\n", + " n_experiments: 40\n", + "[(0, 4, 5, 10), (0, 4, 5, 6, 9), (0, 4, 5, 6, 9, 10), (0, 4, 5, 11), (0, 4, 5, 8, 11), (0, 4, 5, 8), (0, 4, 8), (0, 4, 5, 10), (0, 4, 5, 10), (0, 4, 5, 8), (0, 4, 5), (0, 4, 5, 6, 10, 11), (0, 4, 5, 6, 10, 11), (0, 4, 5), (0, 4, 5)]\n", + ". chosen subset: (0, 4, 5, 8, 11)\n", + " gold standard acq value: -5.430853\n", + " acq_val from get_candidate: -7.56769278457636\n", + " new y value: [-0.9837223]\n", + " experiments: 40 -> 41\n", + "\n", + "--- Iteration 31 ---\n", + " n_experiments: 41\n", + "[(0, 2, 4, 5, 11), (0, 2, 4, 5, 6, 9), (0, 2, 4, 5, 9), (0, 2, 4, 5, 10, 11), (0, 2, 4, 5), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 9, 10), (0, 2, 4, 5, 11), (0, 2, 4, 5), (0, 2, 4, 5, 10), (0, 4, 5), (0, 4, 5, 6, 10, 11), (0, 4, 5, 8, 11), (0, 4, 5, 6, 10, 11), (0, 4, 5)]\n", + ". chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -5.376202\n", + " acq_val from get_candidate: -6.934757907122603\n", + " new y value: [-1.1601459]\n", + " experiments: 41 -> 42\n", + "\n", + "--- Iteration 32 ---\n", + " n_experiments: 42\n", + "[(2, 4, 5, 6, 11), (2, 4, 5, 8, 9, 10), (2, 4, 5, 9, 11), (2, 4, 5, 7, 8, 11), (2, 4, 5, 9), (2, 4, 5, 9), (2, 4, 5, 11), (2, 4, 5, 7, 10), (2, 4, 5, 11), (2, 4, 5), (0, 2, 4, 5, 7, 11), (0, 4, 5), (0, 4, 5, 6, 10, 11), (0, 4, 5, 8, 11), (0, 4, 5, 6, 10, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -3.602083\n", + " acq_val from get_candidate: -3.594316803214877\n", + " new y value: [-1.30581201]\n", + " experiments: 42 -> 43\n", + "\n", + "--- Iteration 33 ---\n", + " n_experiments: 43\n", + "[(2, 4, 5, 11), (2, 4, 5, 10, 11), (2, 4, 5, 6, 7, 11), (2, 4, 5, 8, 11), (2, 4, 5, 7, 11), (2, 4, 5, 8, 10), (2, 4, 5, 6, 10, 11), (2, 4, 5, 8), (2, 4, 5, 8, 11), (2, 4, 5, 7, 9), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 7, 11), (0, 4, 5), (0, 4, 5, 6, 10, 11), (0, 4, 5, 8, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -2.519226\n", + " acq_val from get_candidate: -2.5192241745907564\n", + " new y value: [-1.53492888]\n", + " experiments: 43 -> 44\n", + "\n", + "--- Iteration 34 ---\n", + " n_experiments: 44\n", + "[(2, 4, 5, 11), (2, 4, 5, 8, 10), (2, 4, 5, 8), (2, 4, 5, 10, 11), (2, 4, 5), (2, 4, 5, 10, 11), (2, 4, 5, 8, 9, 10), (2, 4, 5), (2, 4, 5), (2, 4, 5, 8, 10), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 7, 11), (0, 4, 5), (0, 4, 5, 6, 10, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -3.300021\n", + " acq_val from get_candidate: -3.3000141248176336\n", + " new y value: [-1.45187633]\n", + " experiments: 44 -> 45\n", + "\n", + "--- Iteration 35 ---\n", + " n_experiments: 45\n", + "[(0, 4, 5), (0, 4, 5, 6, 9), (0, 4, 5, 6, 7), (0, 4, 5, 6, 7, 9), (4, 5, 8, 11), (4, 5, 9, 11), (4, 5, 9), (4, 5, 9, 11), (4, 5, 10), (4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 7, 11), (0, 4, 5)]\n", + ". 
chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -4.832028\n", + " acq_val from get_candidate: -4.832023729248723\n", + " new y value: [-1.58299409]\n", + " experiments: 45 -> 46\n", + "\n", + "--- Iteration 36 ---\n", + " n_experiments: 46\n", + "[(2, 4, 5, 6, 9), (2, 3, 5, 9), (2, 5, 11), (2, 5, 6), (2, 5, 6, 9, 11), (2, 5, 10, 11), (2, 5, 11), (2, 5), (2, 5), (2, 5), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 7, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 7, 11)\n", + " gold standard acq value: -4.043174\n", + " acq_val from get_candidate: -4.043173073089793\n", + " new y value: [-1.64984846]\n", + " experiments: 46 -> 47\n", + "\n", + "--- Iteration 37 ---\n", + " n_experiments: 47\n", + "[(2, 3, 4, 5, 8, 11), (2, 3, 4, 5, 11), (2, 3, 4, 5), (2, 3, 4, 5, 10), (2, 3, 4, 5, 9, 11), (2, 3, 4, 5), (2, 3, 4, 5, 11), (2, 3, 4, 5), (2, 3, 4, 5), (2, 3, 4, 5, 7), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -3.986907\n", + " acq_val from get_candidate: -3.986905091764227\n", + " new y value: [-1.65134562]\n", + " experiments: 47 -> 48\n", + "\n", + "--- Iteration 38 ---\n", + " n_experiments: 48\n", + "[(2, 4, 5, 11), (2, 4, 5, 8, 9), (2, 4, 5, 10, 11), (2, 4, 5, 10), (2, 4, 5, 7), (2, 4, 5, 6, 8, 11), (2, 4, 5, 7), (2, 4, 5, 7), (2, 4, 5, 7, 10), (2, 4, 5), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11), (0, 2, 4, 5, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.423856\n", + " acq_val from get_candidate: -4.423855322494039\n", + " new y value: [-1.65329982]\n", + " experiments: 48 -> 49\n", + "\n", + "--- Iteration 39 ---\n", + " n_experiments: 49\n", + "[(4, 5, 10), (4, 5, 11), (4, 5, 9), (4, 5, 10, 11), (4, 5, 6, 9, 11), (4, 5, 11), (4, 5, 10, 11), (4, 5), (4, 5, 11), (4, 5, 8, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 7, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.685702\n", + " acq_val from get_candidate: -5.163080123571689\n", + " new y value: [-1.66594167]\n", + " experiments: 49 -> 50\n", + "\n", + "--- Iteration 40 ---\n", + " n_experiments: 50\n", + "[(2, 3, 4, 5, 6, 7), (2, 3, 4, 5), (2, 3, 4, 5, 7, 8), (2, 3, 4, 5, 7), (2, 3, 4, 5, 10, 11), (2, 3, 4, 5, 10, 11), (2, 3, 4, 5, 7), (2, 3, 4, 5, 7), (2, 3, 4, 5, 7, 11), (2, 3, 4, 5, 10, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.934374\n", + " acq_val from get_candidate: -5.873189109380627\n", + " new y value: [-1.60185088]\n", + " experiments: 50 -> 51\n", + "\n", + "--- Iteration 41 ---\n", + " n_experiments: 51\n", + "[(2, 4, 5, 9, 10), (2, 4, 5, 10), (2, 5, 9, 11), (2, 5, 9), (2, 5, 8, 9, 10), (2, 5, 6, 11), (2, 5, 7, 10, 11), (2, 5, 11), (2, 5, 9), (2, 5, 8, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.002764\n", + " acq_val from get_candidate: -4.0027540397503785\n", + " new y value: [-1.20011115]\n", + " experiments: 51 -> 52\n", + "\n", + "--- Iteration 42 ---\n", + " n_experiments: 52\n", + "[(1, 2, 4, 5, 7, 9), (1, 2, 4, 5, 6, 10), (1, 2, 4, 5, 7, 10), (1, 2, 4, 5, 6, 11), (1, 2, 4, 5, 6, 9), (1, 2, 4, 5), (1, 2, 4, 5, 8, 9), (1, 2, 10, 11), (1, 2, 4, 5, 6, 8), (1, 2, 4, 5, 7, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -3.890156\n", + " acq_val from get_candidate: -3.8901501835564867\n", + " new y value: [-1.69266021]\n", + " experiments: 52 -> 53\n", + "\n", + "--- Iteration 43 ---\n", + " n_experiments: 53\n", + "[(2, 4, 5, 10), (2, 4, 5, 9, 10), (2, 4, 5, 6, 7, 10), (2, 4, 5, 9, 11), (2, 4, 5, 6, 10), (2, 4, 5, 9, 11), (2, 4, 5, 6, 9, 11), (2, 4, 5, 6, 8, 10), (2, 4, 5, 8, 10, 11), (2, 4, 5, 6), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.881730\n", + " acq_val from get_candidate: -5.975339633359058\n", + " new y value: [-1.69285623]\n", + " experiments: 53 -> 54\n", + "\n", + "--- Iteration 44 ---\n", + " n_experiments: 54\n", + "[(0, 4, 5, 7, 8, 10), (4, 5, 7, 11), (4, 5, 11), (4, 5, 10), (4, 5, 7, 8, 10, 11), (4, 5), (4, 6, 8), (4, 5), (4, 5), (4, 5, 8, 10, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.979595\n", + " acq_val from get_candidate: -6.5417658956841915\n", + " new y value: [-1.68063219]\n", + " experiments: 54 -> 55\n", + "\n", + "--- Iteration 45 ---\n", + " n_experiments: 55\n", + "[(4, 5, 7, 9, 10, 11), (4, 5, 10, 11), (4, 5, 6, 11), (4, 5), (4, 5, 11), (4, 5), (4, 5, 6, 8, 10), (4, 5, 7, 10, 11), (4, 5, 8, 10, 11), (4, 5, 6, 8), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.963915\n", + " acq_val from get_candidate: -6.891408896557104\n", + " new y value: [-1.69262525]\n", + " experiments: 55 -> 56\n", + "\n", + "--- Iteration 46 ---\n", + " n_experiments: 56\n", + "[(2, 4, 5, 9, 10), (2, 4, 5), (2, 4, 5, 6, 9), (2, 4, 5, 6), (2, 4, 5, 8, 9, 10), (2, 4, 5), (2, 4, 5, 11), (2, 4, 5, 8, 9, 11), (2, 4, 5, 11), (2, 4, 5, 7), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.983611\n", + " acq_val from get_candidate: -7.543342672934277\n", + " new y value: [-1.69413447]\n", + " experiments: 56 -> 57\n", + "\n", + "--- Iteration 47 ---\n", + " n_experiments: 57\n", + "[(4, 5, 11), (4, 5, 9, 11), (4, 5, 8, 11), (4, 5, 6, 7, 9), (4, 5, 8, 9, 10, 11), (4, 5, 11), (4, 5), (4, 5, 8, 10), (4, 5, 10), (4, 5, 7, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -4.957923\n", + " acq_val from get_candidate: -8.116538773904608\n", + " new y value: [-1.69197899]\n", + " experiments: 57 -> 58\n", + "\n", + "--- Iteration 48 ---\n", + " n_experiments: 58\n", + "[(2, 5, 11), (2, 5, 7, 10), (2, 5, 6), (2, 5, 8, 11), (2, 5, 7, 8, 11), (2, 5, 7, 8, 11), (2, 5, 10, 11), (2, 5, 9), (2, 5), (2, 5, 8, 10), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -5.169048\n", + " acq_val from get_candidate: -8.810839177360982\n", + " new y value: [-0.02139907]\n", + " experiments: 58 -> 59\n", + "\n", + "--- Iteration 49 ---\n", + " n_experiments: 59\n", + "[(4, 5, 10, 11), (4, 5, 10), (4, 5, 6, 7, 10, 11), (4, 5, 11), (4, 5, 9), (4, 5, 11), (4, 5, 11), (4, 5, 11), (4, 5, 10, 11), (4, 5, 9, 10, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11), (0, 2, 4, 5, 11)]\n", + ". 
chosen subset: (0, 2, 4, 5, 11)\n", + " gold standard acq value: -5.024432\n", + " acq_val from get_candidate: -8.772387196375274\n", + " new y value: [-1.69077765]\n", + " experiments: 59 -> 60\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "\n", + "def get_candidate():\n", + " acqf = strategy._get_acqfs(n=1)[0]\n", + " groups = Groups(\n", + " groups=[\n", + " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + " )\n", + " bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + " )\n", + "\n", + " def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + " ) -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=256, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + " tracker = _SelectionTracker(inner_fn=reward_fn)\n", + "\n", + " mcts = MCTS_DAG(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=tracker,\n", + " use_cache=False,\n", + " rollout_mode=\"ts_group_action\",\n", + " adaptive_prior_var=True,\n", + " separate_stop=True,\n", + " )\n", + " mcts.run(n_iterations=1500)\n", + " best_valuue = -np.inf\n", + " best_subset = None\n", + " best_candidate = None\n", + " # here we should gather candidates based on the top eval in the MCTS search\n", + " # and we should always re-evaluate the best performing selections from the current 
set\n", + " # of experiments to have proper exploitation in the search\n", + " # (currently we only have exploration through the MCTS search)\n", + " subsets_to_optimize = [i[0] for i in tracker.top_k(k=5)]\n", + " for _ in range(5):\n", + " leaf, _ = mcts._select_and_expand()\n", + " selected_features, cat_selections = mcts._get_selection(leaf)\n", + " subsets_to_optimize.append(selected_features)\n", + " # also add the top 5 performing subsets from the current experiments\n", + " # we should use here always 5 distinct subsets to get more exploration in\n", + " # based on the history\n", + " top_experiments = strategy.experiments.nsmallest(5, \"y\")\n", + " top_subsets = [\n", + " tuple(\n", + " i\n", + " for i, val in enumerate(row[benchmark.domain.inputs.get_keys()].to_numpy())\n", + " if val > 0.0\n", + " )\n", + " for _, row in top_experiments.iterrows()\n", + " ]\n", + " subsets_to_optimize.extend(top_subsets)\n", + " # now see if there are duplicates in subsets_to_optimize and remove them\n", + " print(subsets_to_optimize)\n", + " subsets_to_optimize = list(set(subsets_to_optimize))\n", + "\n", + " for selected_features in subsets_to_optimize:\n", + " # leaf, path = mcts._select_and_expand()\n", + " # selected_features, cat_selections = mcts._get_selection(leaf)\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " if acq_value > best_valuue:\n", + " best_valuue = acq_value\n", + " best_candidate = candidates\n", + " best_subset = selected_features\n", + " _, gold_acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features={6: 0.0, 7: 0.0, 8: 0.0, 9: 0.0, 10: 
0.0, 11: 0.0},\n", + " )\n", + " print(f\". chosen subset: {best_subset}\")\n", + " print(f\" gold standard acq value: {gold_acq_value:.6f}\")\n", + " return strategy.acqf_optimizer._candidates_tensor_to_dataframe(\n", + " best_candidate, strategy.domain\n", + " ), best_valuue\n", + "\n", + "\n", + "# for _ in range(5):\n", + "# candidate, acq_val = get_candidate()\n", + "# new_experiments = benchmark.f(candidate, return_complete=True)\n", + "# strategy.tell(new_experiments,replace=False)\n", + "# print(acqf_val.item())\n", + "# print(strategy.experiments.y.min())\n", + "\n", + "for iteration in range(50):\n", + " acqf_check = strategy._get_acqfs(n=1)[0]\n", + "\n", + " # Probe: evaluate acqf at a fixed test point to see if model actually changed\n", + " # test_point = torch.zeros(1, 1, bounds.shape[1], **tkwargs)\n", + " # test_point[..., 0] = 0.5\n", + " # probe_val = acqf_check(test_point).item()\n", + "\n", + " print(f\"\\n--- Iteration {iteration} ---\")\n", + " print(f\" n_experiments: {len(strategy.experiments)}\")\n", + " # print(f\" acqf probe at fixed point: {probe_val:.10f}\")\n", + "\n", + " candidate, acq_val = get_candidate()\n", + " print(f\" acq_val from get_candidate: {acq_val}\")\n", + "\n", + " new_experiments = benchmark.f(candidate, return_complete=True)\n", + " print(f\" new y value: {new_experiments['y'].values}\")\n", + "\n", + " n_before = len(strategy.experiments)\n", + " strategy.tell(new_experiments, replace=False)\n", + " n_after = len(strategy.experiments)\n", + " print(f\" experiments: {n_before} -> {n_after}\")\n", + "\n", + " # Check if model params actually changed\n", + " # ls = strategy.model.covar_module.lengthscale\n", + " # print(f\" lengthscale[0:3]: {ls[0, :3].tolist()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
200.0000000.0000001.0000000.00.4137190.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.013270True
210.0000000.0655051.0000000.00.3854150.6082150.0000000.0000000.0000000.0000000.0000000.000000-0.496923True
220.0000000.0000001.0000000.00.0000000.5683161.0000000.7330900.0000001.0000000.0000000.000000-0.114023True
230.0000000.0000001.0000000.00.0000000.5659850.0000000.0000000.0000000.9654120.0000000.000000-0.113414True
240.5845230.0000000.0000000.00.3363160.5930070.0000000.0000000.0000000.0000000.0000000.000000-0.642831True
250.0000000.0000000.0000000.00.0000000.5627801.0000000.9715010.7090410.0000000.0000000.000000-0.159344True
261.0000000.0000000.0000000.00.9770200.5577780.0000000.0000000.0000000.0000000.0000000.000000-0.000074True
270.0000000.4475530.0000000.00.3180410.6206190.0000000.0000000.0000000.0000000.0000000.000000-0.641240True
280.0000000.0000000.0000000.00.3030890.6167930.9674080.0000000.0000000.3767680.0000000.000000-0.820252True
290.0000000.0000000.0000000.00.3289790.6433151.0000000.0000000.0000000.0000000.0000000.000000-0.824299True
300.0000000.0000000.0000000.00.3019950.6928111.0000000.0000000.0000000.0000000.0000000.000000-0.830198True
310.0000000.4821890.0000000.00.3119730.6735871.0000000.0000000.9872340.0000000.0000000.000000-0.603628True
320.0000000.0000000.0000000.00.3069540.6704781.0000000.0000000.0000000.0000001.0000001.000000-0.835265True
330.0000000.0000000.0000000.00.2760731.0000001.0000000.0000000.0000000.0000001.0000000.920440-0.335402True
340.9608260.0000000.0000000.00.2974400.6613721.0000000.0000000.0000000.0000000.0000000.000000-0.202008True
350.1597400.0000000.0000000.00.3108340.6544220.0000000.0000000.0000000.0000000.0000000.000000-0.967926True
360.2253340.0000000.0000000.00.3069480.6584110.0000000.0000000.0000000.0000000.0000000.000000-0.984944True
370.2393490.0000000.0000000.00.3202000.6859150.0000000.0000000.0000000.0000000.0000000.000000-0.978039True
380.2358960.0000000.0000000.00.2898600.6789970.8706880.0000000.0000000.0000000.9358820.789118-0.980226True
390.2223870.0000000.0000000.00.3038870.6754480.6717070.0000000.0000000.0000000.7534820.926419-0.984255True
400.2485640.0000000.0000000.00.3035690.6526040.0000000.0000001.0000000.0000000.0000000.999482-0.983722True
410.2394560.0000000.1489110.00.3034530.6639190.0000001.0000000.0000000.0000000.0000000.787317-1.160146True
420.2348280.0000000.2590110.00.3058150.6604200.0000001.0000000.0000000.0000000.0000000.972869-1.305812True
430.2423870.0000000.4007710.00.3073310.6525690.0000000.0662850.0000000.0000000.0000000.949719-1.534929True
440.3475790.0000000.5050230.00.2949800.6492920.0000000.0000000.0000000.0000000.0000000.987207-1.451876True
450.2145850.0000000.4288410.00.2979920.6259180.0000000.0000000.0000000.0000000.0000000.956518-1.582994True
460.1626420.0000000.4753600.00.3040140.6160970.0000000.0000000.0000000.0000000.0000000.988898-1.649848True
470.1076680.0000000.4996690.00.3175100.6399590.0000000.0000000.0000000.0000000.0000000.949427-1.651346True
480.1449160.0000000.4960690.00.3364870.5956290.0000000.0000000.0000000.0000000.0000000.943952-1.653300True
490.1370740.0000000.4996180.00.3190830.6123130.0000000.0000000.0000000.0000000.0000000.953744-1.665942True
500.1029370.0000000.5116030.00.3084610.5826050.0000000.0000000.0000000.0000000.0000000.132402-1.601851True
510.1071510.0000000.5103520.00.5337340.6177340.0000000.0000000.0000000.0000000.0000000.990703-1.200111True
520.1670970.0000000.5110260.00.3358150.6430760.0000000.0000000.0000000.0000000.0000000.968920-1.692660True
530.1739090.0000000.5360530.00.3319790.6516480.0000000.0000000.0000000.0000000.0000000.994026-1.692856True
540.1731160.0000000.5163340.00.3449650.6675500.0000000.0000000.0000000.0000000.0000000.022189-1.680632True
550.1784130.0000000.5342270.00.3373750.6343360.0000000.0000000.0000000.0000000.0000000.218449-1.692625True
560.1717460.0000000.5326250.00.3224420.6444290.0000000.0000000.0000000.0000000.0000000.039795-1.694134True
570.1598230.0000000.5393980.00.3371930.6429760.0000000.0000000.0000000.0000000.0000000.998248-1.691979True
580.9104270.0000001.0000000.00.6573920.7134680.0000000.0000000.0000000.0000000.0000000.974299-0.021399True
590.1858910.0000000.5225660.00.3410130.6394540.0000000.0000000.0000000.0000000.0000000.941160-1.690778True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "20 0.000000 0.000000 1.000000 0.0 0.413719 0.000000 0.000000 \n", + "21 0.000000 0.065505 1.000000 0.0 0.385415 0.608215 0.000000 \n", + "22 0.000000 0.000000 1.000000 0.0 0.000000 0.568316 1.000000 \n", + "23 0.000000 0.000000 1.000000 0.0 0.000000 0.565985 0.000000 \n", + "24 0.584523 0.000000 0.000000 0.0 0.336316 0.593007 0.000000 \n", + "25 0.000000 0.000000 0.000000 0.0 0.000000 0.562780 1.000000 \n", + "26 1.000000 0.000000 0.000000 0.0 0.977020 0.557778 0.000000 \n", + "27 0.000000 0.447553 0.000000 0.0 0.318041 0.620619 0.000000 \n", + "28 0.000000 0.000000 0.000000 0.0 0.303089 0.616793 0.967408 \n", + "29 0.000000 0.000000 0.000000 0.0 0.328979 0.643315 1.000000 \n", + "30 0.000000 0.000000 0.000000 0.0 0.301995 0.692811 1.000000 \n", + "31 0.000000 0.482189 0.000000 0.0 0.311973 0.673587 1.000000 \n", + "32 0.000000 0.000000 0.000000 0.0 0.306954 0.670478 1.000000 \n", + "33 0.000000 0.000000 0.000000 0.0 0.276073 1.000000 1.000000 \n", + "34 0.960826 0.000000 0.000000 0.0 0.297440 0.661372 1.000000 \n", + "35 0.159740 0.000000 0.000000 0.0 0.310834 0.654422 0.000000 \n", + "36 0.225334 0.000000 0.000000 0.0 0.306948 0.658411 0.000000 \n", + "37 0.239349 0.000000 0.000000 0.0 0.320200 0.685915 0.000000 \n", + "38 0.235896 0.000000 0.000000 0.0 0.289860 0.678997 0.870688 \n", + "39 0.222387 0.000000 0.000000 0.0 0.303887 0.675448 0.671707 \n", + "40 0.248564 0.000000 0.000000 0.0 0.303569 0.652604 0.000000 \n", + "41 0.239456 0.000000 0.148911 0.0 0.303453 0.663919 0.000000 \n", + "42 0.234828 0.000000 0.259011 0.0 0.305815 0.660420 0.000000 \n", + "43 0.242387 0.000000 0.400771 0.0 0.307331 0.652569 0.000000 \n", + "44 0.347579 0.000000 0.505023 0.0 0.294980 0.649292 0.000000 \n", + "45 0.214585 0.000000 0.428841 0.0 0.297992 0.625918 0.000000 \n", + "46 0.162642 0.000000 0.475360 0.0 0.304014 0.616097 0.000000 \n", + "47 0.107668 0.000000 0.499669 0.0 0.317510 0.639959 0.000000 
\n", + "48 0.144916 0.000000 0.496069 0.0 0.336487 0.595629 0.000000 \n", + "49 0.137074 0.000000 0.499618 0.0 0.319083 0.612313 0.000000 \n", + "50 0.102937 0.000000 0.511603 0.0 0.308461 0.582605 0.000000 \n", + "51 0.107151 0.000000 0.510352 0.0 0.533734 0.617734 0.000000 \n", + "52 0.167097 0.000000 0.511026 0.0 0.335815 0.643076 0.000000 \n", + "53 0.173909 0.000000 0.536053 0.0 0.331979 0.651648 0.000000 \n", + "54 0.173116 0.000000 0.516334 0.0 0.344965 0.667550 0.000000 \n", + "55 0.178413 0.000000 0.534227 0.0 0.337375 0.634336 0.000000 \n", + "56 0.171746 0.000000 0.532625 0.0 0.322442 0.644429 0.000000 \n", + "57 0.159823 0.000000 0.539398 0.0 0.337193 0.642976 0.000000 \n", + "58 0.910427 0.000000 1.000000 0.0 0.657392 0.713468 0.000000 \n", + "59 0.185891 0.000000 0.522566 0.0 0.341013 0.639454 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "20 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "21 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "22 0.733090 0.000000 1.000000 0.000000 0.000000 \n", + "23 0.000000 0.000000 0.965412 0.000000 0.000000 \n", + "24 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "25 0.971501 0.709041 0.000000 0.000000 0.000000 \n", + "26 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "27 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "28 0.000000 0.000000 0.376768 0.000000 0.000000 \n", + "29 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "30 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "31 0.000000 0.987234 0.000000 0.000000 0.000000 \n", + "32 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "33 0.000000 0.000000 0.000000 1.000000 0.920440 \n", + "34 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "35 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "36 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "37 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "38 0.000000 0.000000 0.000000 0.935882 0.789118 \n", + "39 0.000000 
0.000000 0.000000 0.753482 0.926419 \n", + "40 0.000000 1.000000 0.000000 0.000000 0.999482 \n", + "41 1.000000 0.000000 0.000000 0.000000 0.787317 \n", + "42 1.000000 0.000000 0.000000 0.000000 0.972869 \n", + "43 0.066285 0.000000 0.000000 0.000000 0.949719 \n", + "44 0.000000 0.000000 0.000000 0.000000 0.987207 \n", + "45 0.000000 0.000000 0.000000 0.000000 0.956518 \n", + "46 0.000000 0.000000 0.000000 0.000000 0.988898 \n", + "47 0.000000 0.000000 0.000000 0.000000 0.949427 \n", + "48 0.000000 0.000000 0.000000 0.000000 0.943952 \n", + "49 0.000000 0.000000 0.000000 0.000000 0.953744 \n", + "50 0.000000 0.000000 0.000000 0.000000 0.132402 \n", + "51 0.000000 0.000000 0.000000 0.000000 0.990703 \n", + "52 0.000000 0.000000 0.000000 0.000000 0.968920 \n", + "53 0.000000 0.000000 0.000000 0.000000 0.994026 \n", + "54 0.000000 0.000000 0.000000 0.000000 0.022189 \n", + "55 0.000000 0.000000 0.000000 0.000000 0.218449 \n", + "56 0.000000 0.000000 0.000000 0.000000 0.039795 \n", + "57 0.000000 0.000000 0.000000 0.000000 0.998248 \n", + "58 0.000000 0.000000 0.000000 0.000000 0.974299 \n", + "59 0.000000 0.000000 0.000000 0.000000 0.941160 \n", + "\n", + " y valid_y \n", + "20 -0.013270 True \n", + "21 -0.496923 True \n", + "22 -0.114023 True \n", + "23 -0.113414 True \n", + "24 -0.642831 True \n", + "25 -0.159344 True \n", + "26 -0.000074 True \n", + "27 -0.641240 True \n", + "28 -0.820252 True \n", + "29 -0.824299 True \n", + "30 -0.830198 True \n", + "31 -0.603628 True \n", + "32 -0.835265 True \n", + "33 -0.335402 True \n", + "34 -0.202008 True \n", + "35 -0.967926 True \n", + "36 -0.984944 True \n", + "37 -0.978039 True \n", + "38 -0.980226 True \n", + "39 -0.984255 True \n", + "40 -0.983722 True \n", + "41 -1.160146 True \n", + "42 -1.305812 True \n", + "43 -1.534929 True \n", + "44 -1.451876 True \n", + "45 -1.582994 True \n", + "46 -1.649848 True \n", + "47 -1.651346 True \n", + "48 -1.653300 True \n", + "49 -1.665942 True \n", + "50 -1.601851 True \n", + "51 
-1.200111 True \n", + "52 -1.692660 True \n", + "53 -1.692856 True \n", + "54 -1.680632 True \n", + "55 -1.692625 True \n", + "56 -1.694134 True \n", + "57 -1.691979 True \n", + "58 -0.021399 True \n", + "59 -1.690778 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.tail(40)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[5.6497e-01, 1.5572e+00, 5.3921e-01, 7.7308e-01, 3.6603e-01,\n", + " 4.7035e-01, 1.0000e+04, 3.3364e+03, 4.6948e+03, 1.1356e+03,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[5.5754e-01, 1.5224e+00, 5.2977e-01, 7.3852e-01, 3.5700e-01,\n", + " 4.5979e-01, 1.0000e+04, 3.2414e+03, 4.5201e+03, 1.1079e+03,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[5.5478e-01, 1.5095e+00, 5.2631e-01, 7.2562e-01, 3.5340e-01,\n", + " 4.5590e-01, 1.0000e+04, 3.1990e+03, 4.4434e+03, 1.1016e+03,\n", + " 1.0000e+04, 1.0000e+04]],\n", + "\n", + " [[5.8321e-01, 1.6395e+00, 5.6162e-01, 8.5048e-01, 3.8500e-01,\n", + " 4.9547e-01, 1.0000e+04, 3.5157e+03, 5.0937e+03, 1.2091e+03,\n", + " 1.0000e+04, 1.0000e+04]]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.base_kernel.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
50.9716560.3410370.2473110.1170810.0000000.0000000.1014650.0000000.0000000.6124490.0000000.000000-0.002416True
60.0097270.3187120.0000000.0000000.5629500.5014310.0000000.0000000.1304470.7760440.0000000.000000-0.216006True
70.8733560.2732550.0000000.0000000.3955370.0000000.0000000.9134770.4222420.5581640.0000000.000000-0.007566True
80.0000000.0000000.0000000.9213050.4787820.5553820.8431830.0000000.6743220.0000000.8404980.000000-0.019061True
90.0000000.8613230.0000000.0000000.2463090.0000000.3676990.5743210.0000000.7715720.2532700.000000-0.010974True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + "5 0.971656 0.341037 0.247311 0.117081 0.000000 0.000000 0.101465 \n", + "6 0.009727 0.318712 0.000000 0.000000 0.562950 0.501431 0.000000 \n", + "7 0.873356 0.273255 0.000000 0.000000 0.395537 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.921305 0.478782 0.555382 0.843183 \n", + "9 0.000000 0.861323 0.000000 0.000000 0.246309 0.000000 0.367699 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + "5 0.000000 0.000000 0.612449 0.000000 0.000000 \n", + "6 0.000000 0.130447 0.776044 0.000000 0.000000 \n", + "7 0.913477 0.422242 0.558164 0.000000 0.000000 \n", + "8 0.000000 0.674322 0.000000 0.840498 0.000000 \n", + "9 0.574321 0.000000 0.771572 0.253270 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + "5 -0.002416 True \n", + "6 -0.216006 True \n", + "7 -0.007566 True \n", + "8 -0.019061 True \n", + "9 -0.010974 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ 
+ "((2, 4, 5, 6, 8), {}, -40.48350703005781)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "acqf = strategy._get_acqfs(n=1)[0]\n", + "groups = Groups(\n", + " groups=[\n", + " NChooseK(\n", + " features=list(range(len(benchmark.domain.inputs.get_keys()))),\n", + " max_count=6,\n", + " min_count=0,\n", + " )\n", + " ]\n", + ")\n", + "bounds = utils.get_torch_bounds_from_domain(\n", + " benchmark.domain, strategy.input_preprocessing_specs\n", + ")\n", + "\n", + "\n", + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=256, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "def reward_fn2(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=1,\n", + " raw_samples=64,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()\n", + "\n", + "\n", + "tracker = _SelectionTracker(inner_fn=reward_fn)\n", + "\n", + "mcts = MCTS_DAG(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=tracker,\n", + " use_cache=False,\n", + " 
rollout_mode=\"ts_group_action\",\n", + " adaptive_prior_var=True,\n", + " separate_stop=True,\n", + ")\n", + "mcts.run(n_iterations=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[((2, 4, 5, 6, 8), {}),\n", + " ((2, 4, 5, 7, 11), {}),\n", + " ((2, 4, 5, 10), {}),\n", + " ((2, 4, 5, 6, 9, 10), {}),\n", + " ((2, 4, 5, 6, 10, 11), {})]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tracker.top_k(k=5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-40.47646914212883" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def reward_fn(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " if len(selected_features) == 0:\n", + " candidates = torch.zeros((1, bounds.shape[1]), **tkwargs)\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + " local_bounds = bounds[:, selected_features]\n", + " sobol_samples = draw_sobol_samples(bounds=local_bounds, n=1024, q=1).squeeze(1)\n", + " candidates = torch.zeros((sobol_samples.shape[0], bounds.shape[1]), **tkwargs)\n", + " candidates[:, selected_features] = sobol_samples\n", + " return acqf(candidates.unsqueeze(-2)).max().item()\n", + "\n", + "\n", + "reward_fn2((2, 4, 5, 6, 8), cat_selections={})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
550.1784130.00.5342270.00.3373750.6343360.00.00.00.00.00.218449-1.692625True
560.1717460.00.5326250.00.3224420.6444290.00.00.00.00.00.039795-1.694134True
570.1598230.00.5393980.00.3371930.6429760.00.00.00.00.00.998248-1.691979True
580.9104270.01.0000000.00.6573920.7134680.00.00.00.00.00.974299-0.021399True
590.1858910.00.5225660.00.3410130.6394540.00.00.00.00.00.941160-1.690778True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "55 0.178413 0.0 0.534227 0.0 0.337375 0.634336 0.0 \n", + "56 0.171746 0.0 0.532625 0.0 0.322442 0.644429 0.0 \n", + "57 0.159823 0.0 0.539398 0.0 0.337193 0.642976 0.0 \n", + "58 0.910427 0.0 1.000000 0.0 0.657392 0.713468 0.0 \n", + "59 0.185891 0.0 0.522566 0.0 0.341013 0.639454 0.0 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "55 0.0 0.0 0.0 0.0 0.218449 \n", + "56 0.0 0.0 0.0 0.0 0.039795 \n", + "57 0.0 0.0 0.0 0.0 0.998248 \n", + "58 0.0 0.0 0.0 0.0 0.974299 \n", + "59 0.0 0.0 0.0 0.0 0.941160 \n", + "\n", + " y valid_y \n", + "55 -1.692625 True \n", + "56 -1.694134 True \n", + "57 -1.691979 True \n", + "58 -0.021399 True \n", + "59 -1.690778 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.tail()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={}),\n", + " [TSNode(partial_by_group=((),), stopped_by_group=(False,), group_idx=0, n_obs=1000, sum_rewards=-44064.18923151626, sum_sq_rewards=1941751.7644583723, n_visits=1000, children={-1: TSNode(partial_by_group=((),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-224.94021040439745, sum_sq_rewards=10119.619651354918, n_visits=5, children={}), 8: TSNode(partial_by_group=((8,),), stopped_by_group=(False,), group_idx=0, n_obs=6, sum_rewards=-269.91724563250455, sum_sq_rewards=12142.55325008863, n_visits=6, children={-1: TSNode(partial_by_group=((8,),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-89.97261537395104, sum_sq_rewards=4047.5357586146, n_visits=2, children={}), 11: 
TSNode(partial_by_group=((8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.986396061316135, sum_sq_rewards=2023.7758305855998, n_visits=1, children={}), 9: TSNode(partial_by_group=((8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98502118668674, sum_sq_rewards=2023.6521311666545, n_visits=1, children={}), 10: TSNode(partial_by_group=((8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98662671146576, sum_sq_rewards=2023.7965828767647, n_visits=1, children={})}), 7: TSNode(partial_by_group=((7,),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.93141084670353, sum_sq_rewards=10118.82791894983, n_visits=5, children={-1: TSNode(partial_by_group=((7,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.98650740308012, sum_sq_rewards=2023.7858483273824, n_visits=1, children={}), 8: TSNode(partial_by_group=((7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.985069711353, sum_sq_rewards=2023.656496935289, n_visits=1, children={}), 11: TSNode(partial_by_group=((7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98665192518045, sum_sq_rewards=2023.7988514373421, n_visits=1, children={}), 10: TSNode(partial_by_group=((7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98663552719344, sum_sq_rewards=2023.7973760565433, n_visits=1, children={})}), 3: TSNode(partial_by_group=((3,),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.23151234808296, sum_sq_rewards=10056.156572886164, n_visits=5, children={-1: TSNode(partial_by_group=((3,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.949456136088514, sum_sq_rewards=2020.4536069301455, n_visits=1, children={}), 6: TSNode(partial_by_group=((3, 6),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.94415484716498, sum_sq_rewards=2019.9770549259435, n_visits=1, children={}), 7: 
TSNode(partial_by_group=((3, 7),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.94554321592671, sum_sq_rewards=2020.1018549747357, n_visits=1, children={}), 5: TSNode(partial_by_group=((3, 5),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.44398039831759, sum_sq_rewards=1975.2673936460383, n_visits=1, children={})}), 9: TSNode(partial_by_group=((9,),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.93275755147988, sum_sq_rewards=10118.949083979358, n_visits=5, children={-1: TSNode(partial_by_group=((9,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.986446128847994, sum_sq_rewards=2023.7803353037427, n_visits=1, children={}), 11: TSNode(partial_by_group=((9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98659471978125, sum_sq_rewards=2023.7937044818505, n_visits=1, children={}), 10: TSNode(partial_by_group=((9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.97304135379608, sum_sq_rewards=4047.5740852325316, n_visits=2, children={-1: TSNode(partial_by_group=((9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.986463312302426, sum_sq_rewards=2023.7818813491322, n_visits=1, children={})})}), 2: TSNode(partial_by_group=((2,),), stopped_by_group=(False,), group_idx=0, n_obs=6, sum_rewards=-269.8572339995221, sum_sq_rewards=12137.154463860978, n_visits=6, children={-1: TSNode(partial_by_group=((2,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.977521802756044, sum_sq_rewards=2022.9774675173953, n_visits=1, children={}), 6: TSNode(partial_by_group=((2, 6),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.97400137693486, sum_sq_rewards=2022.6607998525385, n_visits=1, children={}), 8: TSNode(partial_by_group=((2, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.976441490100626, sum_sq_rewards=2022.880289112445, n_visits=1, children={}), 9: 
TSNode(partial_by_group=((2, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.976477758224455, sum_sq_rewards=2022.883551536059, n_visits=1, children={}), 7: TSNode(partial_by_group=((2, 7),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.97660688529078, sum_sq_rewards=2022.8951669139858, n_visits=1, children={})}), 11: TSNode(partial_by_group=((11,),), stopped_by_group=(False,), group_idx=0, n_obs=6, sum_rewards=-269.928276671906, sum_sq_rewards=12143.545757844211, n_visits=6, children={-1: TSNode(partial_by_group=((11,),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-224.94023360919533, sum_sq_rewards=10119.621739231898, n_visits=5, children={})}), 4: TSNode(partial_by_group=((4,),), stopped_by_group=(False,), group_idx=0, n_obs=926, sum_rewards=-40739.97326770851, sum_sq_rewards=1792414.37179124, n_visits=926, children={-1: TSNode(partial_by_group=((4,),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-179.71237238492932, sum_sq_rewards=8074.134208274124, n_visits=4, children={}), 9: TSNode(partial_by_group=((4, 9),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64147522964618, sum_sq_rewards=10092.758486977678, n_visits=5, children={-1: TSNode(partial_by_group=((4, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92602886047754, sum_sq_rewards=2018.3480691724606, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.85822122290605, sum_sq_rewards=4037.249961126774, n_visits=2, children={-1: TSNode(partial_by_group=((4, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.9286339499872, sum_sq_rewards=2018.5821486119426, n_visits=1, children={})}), 10: TSNode(partial_by_group=((4, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92938243834094, sum_sq_rewards=2018.649406290699, n_visits=1, children={})}), 8: 
TSNode(partial_by_group=((4, 8),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.63801919832252, sum_sq_rewards=10092.447957339235, n_visits=5, children={-1: TSNode(partial_by_group=((4, 8),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92610698430661, sum_sq_rewards=2018.3550887653628, n_visits=1, children={}), 10: TSNode(partial_by_group=((4, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92895818413118, sum_sq_rewards=2018.6112835114084, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.9253558188764, sum_sq_rewards=2018.2875954526514, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92640282640292, sum_sq_rewards=2018.3816709202245, n_visits=1, children={})}), 6: TSNode(partial_by_group=((4, 6),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64620054718833, sum_sq_rewards=10093.183096768624, n_visits=5, children={-1: TSNode(partial_by_group=((4, 6),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.93141911884064, sum_sq_rewards=2018.832424032918, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.929658386626386, sum_sq_rewards=2018.6742027389466, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 6, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.929975725775975, sum_sq_rewards=2018.7027187188185, n_visits=1, children={}), 7: TSNode(partial_by_group=((4, 6, 7),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92665365135187, sum_sq_rewards=2018.4042083085283, n_visits=1, children={})}), 10: TSNode(partial_by_group=((4, 10),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.6453332671047, sum_sq_rewards=10093.105161513364, n_visits=5, 
children={-1: TSNode(partial_by_group=((4, 10),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-89.8589212417643, sum_sq_rewards=4037.312871881153, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.8569506631153, sum_sq_rewards=4037.135791332535, n_visits=2, children={-1: TSNode(partial_by_group=((4, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92825650978432, sum_sq_rewards=2018.548233008977, n_visits=1, children={})})}), 7: TSNode(partial_by_group=((4, 7),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64246752734923, sum_sq_rewards=10092.847666325106, n_visits=5, children={-1: TSNode(partial_by_group=((4, 7),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92866631320418, sum_sq_rewards=2018.5850566832482, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.93065632824331, sum_sq_rewards=2018.7638780867105, n_visits=1, children={}), 10: TSNode(partial_by_group=((4, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.930696493988805, sum_sq_rewards=2018.767487434938, n_visits=1, children={}), 8: TSNode(partial_by_group=((4, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92497729584855, sum_sq_rewards=2018.2535850325075, n_visits=1, children={})}), 5: TSNode(partial_by_group=((4, 5),), stopped_by_group=(False,), group_idx=0, n_obs=891, sum_rewards=-39167.47720350492, sum_sq_rewards=1721764.546765213, n_visits=891, children={-1: TSNode(partial_by_group=((4, 5),), stopped_by_group=(True,), group_idx=1, n_obs=89, sum_rewards=-3912.7613868103085, sum_sq_rewards=172019.22774163214, n_visits=89, children={}), 10: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(False,), group_idx=0, n_obs=113, sum_rewards=-4967.487764600544, sum_sq_rewards=218371.24035568372, 
n_visits=113, children={-1: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(True,), group_idx=1, n_obs=59, sum_rewards=-2593.495125985222, sum_sq_rewards=114003.73086613679, n_visits=59, children={}), 11: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2330.0600732954176, sum_sq_rewards=102437.4391939591, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2286.1179279582257, sum_sq_rewards=100506.52705712414, n_visits=52, children={})})}), 11: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(False,), group_idx=0, n_obs=89, sum_rewards=-3912.7125718756874, sum_sq_rewards=172014.92435577657, n_visits=89, children={-1: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(True,), group_idx=1, n_obs=88, sum_rewards=-3868.7714891661, sum_sq_rewards=170084.10560608577, n_visits=88, children={})}), 6: TSNode(partial_by_group=((4, 5, 6),), stopped_by_group=(False,), group_idx=0, n_obs=74, sum_rewards=-3253.47243268099, sum_sq_rewards=143041.75710217294, n_visits=74, children={-1: TSNode(partial_by_group=((4, 5, 6),), stopped_by_group=(True,), group_idx=1, n_obs=10, sum_rewards=-439.66411802568587, sum_sq_rewards=19330.465439171287, n_visits=10, children={}), 8: TSNode(partial_by_group=((4, 5, 6, 8),), stopped_by_group=(False,), group_idx=0, n_obs=16, sum_rewards=-703.3535630544754, sum_sq_rewards=30919.16275750841, n_visits=16, children={-1: TSNode(partial_by_group=((4, 5, 6, 8),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.8363854988221, sum_sq_rewards=7729.615020124167, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.84781106588187, sum_sq_rewards=7730.616296158963, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=3, 
sum_rewards=-131.85283673154302, sum_sq_rewards=5795.058529479829, n_visits=3, children={})}), 10: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-219.8227981519812, sum_sq_rewards=9664.422789140159, n_visits=5, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-88.02325681303296, sum_sq_rewards=3874.046928997293, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.82381529890702, sum_sq_rewards=3856.5113793941046, n_visits=2, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.8621293322072, sum_sq_rewards=3859.8777774492482, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.95218402328926, sum_sq_rewards=1931.7944804170838, n_visits=1, children={})})}), 11: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=17, sum_rewards=-747.2567353148665, sum_sq_rewards=32846.64274766709, n_visits=17, children={-1: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2868673295923, sum_sq_rewards=30913.29345702465, n_visits=16, children={})}), 7: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(False,), group_idx=0, n_obs=12, sum_rewards=-527.798461848819, sum_sq_rewards=23214.284110934812, n_visits=12, children={-1: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.93915992300722, sum_sq_rewards=3866.6480880716545, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.99082540147066, sum_sq_rewards=5807.193616973811, n_visits=3, children={-1: 
TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.0023435605371, sum_sq_rewards=1936.2062388195407, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=1, sum_rewards=-43.972912323197995, sum_sq_rewards=1933.617018183658, n_visits=1, children={})}), 8: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.89291074303364, sum_sq_rewards=3862.5830404748212, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.922361457541186, sum_sq_rewards=1929.1738360068996, n_visits=1, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-88.01791838251545, sum_sq_rewards=3873.5779551264786, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.98685793673277, sum_sq_rewards=1934.8436711463105, n_visits=1, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.94349197260667, sum_sq_rewards=3867.0355324076445, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.029375144195555, sum_sq_rewards=1938.5858755883053, n_visits=1, children={})})}), 10: TSNode(partial_by_group=((4, 5, 6, 10),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.66987745161964, sum_sq_rewards=19330.969558831595, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 6, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.78454656081044, sum_sq_rewards=11597.051932874065, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=3, 
sum_rewards=-131.9134151478395, sum_sq_rewards=5800.388251850703, n_visits=3, children={-1: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.88453758728016, sum_sq_rewards=3861.84619260798, n_visits=2, children={})})}), 9: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8109890573494, sum_sq_rewards=15471.381338727459, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.96851731041005, sum_sq_rewards=3869.232754546095, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.91496034420223, sum_sq_rewards=7736.521615340711, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.00361414890658, sum_sq_rewards=1936.3180581658512, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.9515428139761, sum_sq_rewards=3867.739243850584, n_visits=2, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.004609718471485, sum_sq_rewards=1936.405676474995, n_visits=1, children={})})}), 7: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(False,), group_idx=0, n_obs=261, sum_rewards=-11472.15312187526, sum_sq_rewards=504254.31677490985, n_visits=261, children={-1: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(True,), group_idx=1, n_obs=45, sum_rewards=-1978.0154399179673, sum_sq_rewards=86945.49803190352, n_visits=45, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=54, sum_rewards=-2373.663953734188, sum_sq_rewards=104338.60066699992, n_visits=54, children={-1: 
TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9045444714175, sum_sq_rewards=50239.67937641118, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.815408600552, sum_sq_rewards=52167.846096387606, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.8797542310706, sum_sq_rewards=50237.50437151309, n_visits=26, children={})})}), 9: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=38, sum_rewards=-1670.4877740498455, sum_sq_rewards=73435.03131497768, n_visits=38, children={-1: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=14, sum_rewards=-615.3066670101945, sum_sq_rewards=27043.037949194786, n_visits=14, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=15, sum_rewards=-659.3723206412016, sum_sq_rewards=28984.806842262173, n_visits=15, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.6360004678156, sum_sq_rewards=11583.993595170117, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7372275893211, sum_sq_rewards=15464.893098870949, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8172262134047, sum_sq_rewards=15471.929156006257, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=7, sum_rewards=-307.8029281771427, sum_sq_rewards=13534.670724381365, n_visits=7, children={})})}), 11: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, 
sum_rewards=-2329.4620246822856, sum_sq_rewards=102384.84021208857, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={})}), 8: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=70, sum_rewards=-3076.565577991281, sum_sq_rewards=135218.00988236978, n_visits=70, children={-1: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.127493578221, sum_sq_rewards=32835.28378186662, n_visits=17, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=22, sum_rewards=-966.8382846117236, sum_sq_rewards=42489.845954838915, n_visits=22, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-922.920341731806, sum_sq_rewards=40561.06024803523, n_visits=21, children={})}), 9: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.62629242740445, sum_sq_rewards=19327.13755483925, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-219.89046816143235, sum_sq_rewards=9670.368169661724, n_visits=5, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 9, 10),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.92074668263248, sum_sq_rewards=3865.02911757097, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 9, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.91360021280931, sum_sq_rewards=3864.400552291376, n_visits=2, children={})}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=20, sum_rewards=-878.9950879943008, sum_sq_rewards=38631.64121969437, n_visits=20, children={-1: 
TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=9, sum_rewards=-395.73839266807806, sum_sq_rewards=17400.996240694472, n_visits=9, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=10, sum_rewards=-439.34521488544067, sum_sq_rewards=19302.426864498713, n_visits=10, children={})})})}), 9: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(False,), group_idx=0, n_obs=94, sum_rewards=-4132.514515339099, sum_sq_rewards=181677.5143163395, n_visits=94, children={-1: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-923.2496630550982, sum_sq_rewards=40590.03751593179, n_visits=21, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=31, sum_rewards=-1362.9331668007874, sum_sq_rewards=59922.18876577904, n_visits=31, children={-1: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=30, sum_rewards=-1318.9713445609393, sum_sq_rewards=57989.546951131044, n_visits=30, children={})}), 10: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=41, sum_rewards=-1802.414938877525, sum_sq_rewards=79236.60740220043, n_visits=41, children={-1: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=22, sum_rewards=-967.1973728850329, sum_sq_rewards=42521.41229007139, n_visits=22, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=18, sum_rewards=-791.2899300890582, sum_sq_rewards=34785.5579160644, n_visits=18, children={-1: TSNode(partial_by_group=((4, 5, 9, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.3741921983864, sum_sq_rewards=32856.96588158221, n_visits=17, children={})})})}), 8: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(False,), group_idx=0, n_obs=170, 
sum_rewards=-7472.359537464019, sum_sq_rewards=328448.16905515944, n_visits=170, children={-1: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(True,), group_idx=1, n_obs=47, sum_rewards=-2066.0088489035124, sum_sq_rewards=90816.91547065707, n_visits=47, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.8973845709677, sum_sq_rewards=52175.054707278956, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9747436443677, sum_sq_rewards=50245.85632131191, n_visits=26, children={})}), 10: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=56, sum_rewards=-2460.9770037747253, sum_sq_rewards=108150.18593153136, n_visits=56, children={-1: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.5694461425167, sum_sq_rewards=50210.21206557458, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=29, sum_rewards=-1274.4010901642278, sum_sq_rewards=56003.4046869463, n_visits=29, children={-1: TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=28, sum_rewards=-1230.45112241069, sum_sq_rewards=54071.80502140927, n_visits=28, children={})})}), 9: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=39, sum_rewards=-1714.4664991327732, sum_sq_rewards=75369.15035441122, n_visits=39, children={-1: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2787708937077, sum_sq_rewards=30912.583875859167, n_visits=16, children={}), 10: TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=13, sum_rewards=-571.512793859451, sum_sq_rewards=25125.15452660411, n_visits=13, children={-1: 
TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.85435917582294, sum_sq_rewards=7731.1903163133575, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7437208786052, sum_sq_rewards=15465.462121713665, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=9, sum_rewards=-395.70104514964174, sum_sq_rewards=17397.709017938054, n_visits=9, children={-1: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=8, sum_rewards=-351.6898334141943, sum_sq_rewards=15460.722259515675, n_visits=8, children={})})})})}), 11: TSNode(partial_by_group=((4, 11),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.6420447794053, sum_sq_rewards=10092.809672320658, n_visits=5, children={-1: TSNode(partial_by_group=((4, 11),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-179.71615062026032, sum_sq_rewards=8074.473706321962, n_visits=4, children={})})}), 10: TSNode(partial_by_group=((10,),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.94059775966116, sum_sq_rewards=10119.654504112066, n_visits=5, children={-1: TSNode(partial_by_group=((10,),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-89.97609798267504, sum_sq_rewards=4047.8491040940626, n_visits=2, children={}), 11: TSNode(partial_by_group=((10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.97632738615033, sum_sq_rewards=4047.8697449504484, n_visits=2, children={-1: TSNode(partial_by_group=((10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.98814644975891, sum_sq_rewards=2023.9333209849551, n_visits=1, children={})})}), 5: TSNode(partial_by_group=((5,),), stopped_by_group=(False,), group_idx=0, n_obs=11, sum_rewards=-494.1673679075095, 
sum_sq_rewards=22200.126171741067, n_visits=11, children={-1: TSNode(partial_by_group=((5,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92745202360811, sum_sq_rewards=2018.4759453336087, n_visits=1, children={}), 9: TSNode(partial_by_group=((5, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.85090668325084, sum_sq_rewards=4036.592722659957, n_visits=2, children={-1: TSNode(partial_by_group=((5, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.927291660332195, sum_sq_rewards=2018.4615359325549, n_visits=1, children={})}), 7: TSNode(partial_by_group=((5, 7),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.84703752655294, sum_sq_rewards=4036.2450855788056, n_visits=2, children={-1: TSNode(partial_by_group=((5, 7),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.925690156785095, sum_sq_rewards=2018.317636063457, n_visits=1, children={})}), 6: TSNode(partial_by_group=((5, 6),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.84763643697634, sum_sq_rewards=4036.298886676833, n_visits=2, children={-1: TSNode(partial_by_group=((5, 6),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92392140069303, sum_sq_rewards=2018.1587140156455, n_visits=1, children={})}), 11: TSNode(partial_by_group=((5, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.924768000987555, sum_sq_rewards=2018.2347799425554, n_visits=1, children={}), 10: TSNode(partial_by_group=((5, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92323991729833, sum_sq_rewards=2018.0974846671463, n_visits=1, children={}), 8: TSNode(partial_by_group=((5, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92368351675967, sum_sq_rewards=2018.1373407139843, n_visits=1, children={})}), 1: TSNode(partial_by_group=((1,),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-179.92250534429655, 
sum_sq_rewards=8093.026990798879, n_visits=4, children={-1: TSNode(partial_by_group=((1,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.98178713645662, sum_sq_rewards=2023.361173989494, n_visits=1, children={}), 6: TSNode(partial_by_group=((1, 6),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.978517248864, sum_sq_rewards=2023.0670139063566, n_visits=1, children={}), 10: TSNode(partial_by_group=((1, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98215201222307, sum_sq_rewards=2023.3939996507438, n_visits=1, children={})}), 6: TSNode(partial_by_group=((6,),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-179.9366615474778, sum_sq_rewards=8094.300542564167, n_visits=4, children={-1: TSNode(partial_by_group=((6,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.984208829998195, sum_sq_rewards=2023.5790440608876, n_visits=1, children={}), 10: TSNode(partial_by_group=((6, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.984455672746996, sum_sq_rewards=2023.6012521733394, n_visits=1, children={}), 11: TSNode(partial_by_group=((6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.984322120217264, sum_sq_rewards=2023.5892366154683, n_visits=1, children={})}), 0: TSNode(partial_by_group=((0,),), stopped_by_group=(False,), group_idx=0, n_obs=12, sum_rewards=-536.5101837942188, sum_sq_rewards=23993.477758952868, n_visits=12, children={-1: TSNode(partial_by_group=((0,),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.989224182123394, sum_sq_rewards=2024.0302925093563, n_visits=1, children={}), 10: TSNode(partial_by_group=((0, 10),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.97905357297097, sum_sq_rewards=4048.1150415423776, n_visits=2, children={-1: TSNode(partial_by_group=((0, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.990073863788155, 
sum_sq_rewards=2024.106746269114, n_visits=1, children={})}), 8: TSNode(partial_by_group=((0, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98829349193363, sum_sq_rewards=2023.946551316358, n_visits=1, children={}), 3: TSNode(partial_by_group=((0, 3),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.8450186696531, sum_sq_rewards=5798.578587883327, n_visits=3, children={-1: TSNode(partial_by_group=((0, 3),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.800797441052126, sum_sq_rewards=2007.1114513541827, n_visits=1, children={}), 8: TSNode(partial_by_group=((0, 3, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.77089239714546, sum_sq_rewards=2004.432806036777, n_visits=1, children={})}), 6: TSNode(partial_by_group=((0, 6),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.98718745685651, sum_sq_rewards=2023.8470352783477, n_visits=1, children={}), 1: TSNode(partial_by_group=((0, 1),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.8306338483135, sum_sq_rewards=2009.7857312415517, n_visits=1, children={}), 7: TSNode(partial_by_group=((0, 7),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.988844900716444, sum_sq_rewards=2023.99616550072, n_visits=1, children={}), 5: TSNode(partial_by_group=((0, 5),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.956166787710295, sum_sq_rewards=2021.0569322444262, n_visits=1, children={})})}),\n", + " TSNode(partial_by_group=((4,),), stopped_by_group=(False,), group_idx=0, n_obs=926, sum_rewards=-40739.97326770851, sum_sq_rewards=1792414.37179124, n_visits=926, children={-1: TSNode(partial_by_group=((4,),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-179.71237238492932, sum_sq_rewards=8074.134208274124, n_visits=4, children={}), 9: TSNode(partial_by_group=((4, 9),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64147522964618, 
sum_sq_rewards=10092.758486977678, n_visits=5, children={-1: TSNode(partial_by_group=((4, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92602886047754, sum_sq_rewards=2018.3480691724606, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.85822122290605, sum_sq_rewards=4037.249961126774, n_visits=2, children={-1: TSNode(partial_by_group=((4, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.9286339499872, sum_sq_rewards=2018.5821486119426, n_visits=1, children={})}), 10: TSNode(partial_by_group=((4, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92938243834094, sum_sq_rewards=2018.649406290699, n_visits=1, children={})}), 8: TSNode(partial_by_group=((4, 8),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.63801919832252, sum_sq_rewards=10092.447957339235, n_visits=5, children={-1: TSNode(partial_by_group=((4, 8),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92610698430661, sum_sq_rewards=2018.3550887653628, n_visits=1, children={}), 10: TSNode(partial_by_group=((4, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92895818413118, sum_sq_rewards=2018.6112835114084, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.9253558188764, sum_sq_rewards=2018.2875954526514, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92640282640292, sum_sq_rewards=2018.3816709202245, n_visits=1, children={})}), 6: TSNode(partial_by_group=((4, 6),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64620054718833, sum_sq_rewards=10093.183096768624, n_visits=5, children={-1: TSNode(partial_by_group=((4, 6),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.93141911884064, 
sum_sq_rewards=2018.832424032918, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.929658386626386, sum_sq_rewards=2018.6742027389466, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 6, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.929975725775975, sum_sq_rewards=2018.7027187188185, n_visits=1, children={}), 7: TSNode(partial_by_group=((4, 6, 7),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92665365135187, sum_sq_rewards=2018.4042083085283, n_visits=1, children={})}), 10: TSNode(partial_by_group=((4, 10),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.6453332671047, sum_sq_rewards=10093.105161513364, n_visits=5, children={-1: TSNode(partial_by_group=((4, 10),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-89.8589212417643, sum_sq_rewards=4037.312871881153, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-89.8569506631153, sum_sq_rewards=4037.135791332535, n_visits=2, children={-1: TSNode(partial_by_group=((4, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92825650978432, sum_sq_rewards=2018.548233008977, n_visits=1, children={})})}), 7: TSNode(partial_by_group=((4, 7),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.64246752734923, sum_sq_rewards=10092.847666325106, n_visits=5, children={-1: TSNode(partial_by_group=((4, 7),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.92866631320418, sum_sq_rewards=2018.5850566832482, n_visits=1, children={}), 9: TSNode(partial_by_group=((4, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.93065632824331, sum_sq_rewards=2018.7638780867105, n_visits=1, children={}), 10: TSNode(partial_by_group=((4, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=1, 
sum_rewards=-44.930696493988805, sum_sq_rewards=2018.767487434938, n_visits=1, children={}), 8: TSNode(partial_by_group=((4, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.92497729584855, sum_sq_rewards=2018.2535850325075, n_visits=1, children={})}), 5: TSNode(partial_by_group=((4, 5),), stopped_by_group=(False,), group_idx=0, n_obs=891, sum_rewards=-39167.47720350492, sum_sq_rewards=1721764.546765213, n_visits=891, children={-1: TSNode(partial_by_group=((4, 5),), stopped_by_group=(True,), group_idx=1, n_obs=89, sum_rewards=-3912.7613868103085, sum_sq_rewards=172019.22774163214, n_visits=89, children={}), 10: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(False,), group_idx=0, n_obs=113, sum_rewards=-4967.487764600544, sum_sq_rewards=218371.24035568372, n_visits=113, children={-1: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(True,), group_idx=1, n_obs=59, sum_rewards=-2593.495125985222, sum_sq_rewards=114003.73086613679, n_visits=59, children={}), 11: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2330.0600732954176, sum_sq_rewards=102437.4391939591, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2286.1179279582257, sum_sq_rewards=100506.52705712414, n_visits=52, children={})})}), 11: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(False,), group_idx=0, n_obs=89, sum_rewards=-3912.7125718756874, sum_sq_rewards=172014.92435577657, n_visits=89, children={-1: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(True,), group_idx=1, n_obs=88, sum_rewards=-3868.7714891661, sum_sq_rewards=170084.10560608577, n_visits=88, children={})}), 6: TSNode(partial_by_group=((4, 5, 6),), stopped_by_group=(False,), group_idx=0, n_obs=74, sum_rewards=-3253.47243268099, sum_sq_rewards=143041.75710217294, n_visits=74, children={-1: TSNode(partial_by_group=((4, 5, 6),), 
stopped_by_group=(True,), group_idx=1, n_obs=10, sum_rewards=-439.66411802568587, sum_sq_rewards=19330.465439171287, n_visits=10, children={}), 8: TSNode(partial_by_group=((4, 5, 6, 8),), stopped_by_group=(False,), group_idx=0, n_obs=16, sum_rewards=-703.3535630544754, sum_sq_rewards=30919.16275750841, n_visits=16, children={-1: TSNode(partial_by_group=((4, 5, 6, 8),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.8363854988221, sum_sq_rewards=7729.615020124167, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.84781106588187, sum_sq_rewards=7730.616296158963, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=3, sum_rewards=-131.85283673154302, sum_sq_rewards=5795.058529479829, n_visits=3, children={})}), 10: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-219.8227981519812, sum_sq_rewards=9664.422789140159, n_visits=5, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-88.02325681303296, sum_sq_rewards=3874.046928997293, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.82381529890702, sum_sq_rewards=3856.5113793941046, n_visits=2, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.8621293322072, sum_sq_rewards=3859.8777774492482, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.95218402328926, sum_sq_rewards=1931.7944804170838, n_visits=1, children={})})}), 11: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=17, sum_rewards=-747.2567353148665, 
sum_sq_rewards=32846.64274766709, n_visits=17, children={-1: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2868673295923, sum_sq_rewards=30913.29345702465, n_visits=16, children={})}), 7: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(False,), group_idx=0, n_obs=12, sum_rewards=-527.798461848819, sum_sq_rewards=23214.284110934812, n_visits=12, children={-1: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.93915992300722, sum_sq_rewards=3866.6480880716545, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.99082540147066, sum_sq_rewards=5807.193616973811, n_visits=3, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.0023435605371, sum_sq_rewards=1936.2062388195407, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=1, sum_rewards=-43.972912323197995, sum_sq_rewards=1933.617018183658, n_visits=1, children={})}), 8: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.89291074303364, sum_sq_rewards=3862.5830404748212, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.922361457541186, sum_sq_rewards=1929.1738360068996, n_visits=1, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-88.01791838251545, sum_sq_rewards=3873.5779551264786, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.98685793673277, sum_sq_rewards=1934.8436711463105, n_visits=1, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 7, 9),), 
stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.94349197260667, sum_sq_rewards=3867.0355324076445, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.029375144195555, sum_sq_rewards=1938.5858755883053, n_visits=1, children={})})}), 10: TSNode(partial_by_group=((4, 5, 6, 10),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.66987745161964, sum_sq_rewards=19330.969558831595, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 6, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.78454656081044, sum_sq_rewards=11597.051932874065, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.9134151478395, sum_sq_rewards=5800.388251850703, n_visits=3, children={-1: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.88453758728016, sum_sq_rewards=3861.84619260798, n_visits=2, children={})})}), 9: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8109890573494, sum_sq_rewards=15471.381338727459, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.96851731041005, sum_sq_rewards=3869.232754546095, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.91496034420223, sum_sq_rewards=7736.521615340711, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.00361414890658, sum_sq_rewards=1936.3180581658512, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.9515428139761, sum_sq_rewards=3867.739243850584, 
n_visits=2, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.004609718471485, sum_sq_rewards=1936.405676474995, n_visits=1, children={})})}), 7: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(False,), group_idx=0, n_obs=261, sum_rewards=-11472.15312187526, sum_sq_rewards=504254.31677490985, n_visits=261, children={-1: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(True,), group_idx=1, n_obs=45, sum_rewards=-1978.0154399179673, sum_sq_rewards=86945.49803190352, n_visits=45, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=54, sum_rewards=-2373.663953734188, sum_sq_rewards=104338.60066699992, n_visits=54, children={-1: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9045444714175, sum_sq_rewards=50239.67937641118, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.815408600552, sum_sq_rewards=52167.846096387606, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.8797542310706, sum_sq_rewards=50237.50437151309, n_visits=26, children={})})}), 9: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=38, sum_rewards=-1670.4877740498455, sum_sq_rewards=73435.03131497768, n_visits=38, children={-1: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=14, sum_rewards=-615.3066670101945, sum_sq_rewards=27043.037949194786, n_visits=14, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=15, sum_rewards=-659.3723206412016, sum_sq_rewards=28984.806842262173, n_visits=15, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(True,), 
group_idx=1, n_obs=6, sum_rewards=-263.6360004678156, sum_sq_rewards=11583.993595170117, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7372275893211, sum_sq_rewards=15464.893098870949, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8172262134047, sum_sq_rewards=15471.929156006257, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=7, sum_rewards=-307.8029281771427, sum_sq_rewards=13534.670724381365, n_visits=7, children={})})}), 11: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2329.4620246822856, sum_sq_rewards=102384.84021208857, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={})}), 8: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=70, sum_rewards=-3076.565577991281, sum_sq_rewards=135218.00988236978, n_visits=70, children={-1: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.127493578221, sum_sq_rewards=32835.28378186662, n_visits=17, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=22, sum_rewards=-966.8382846117236, sum_sq_rewards=42489.845954838915, n_visits=22, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-922.920341731806, sum_sq_rewards=40561.06024803523, n_visits=21, children={})}), 9: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.62629242740445, sum_sq_rewards=19327.13755483925, n_visits=10, 
children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-219.89046816143235, sum_sq_rewards=9670.368169661724, n_visits=5, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 9, 10),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.92074668263248, sum_sq_rewards=3865.02911757097, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 9, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.91360021280931, sum_sq_rewards=3864.400552291376, n_visits=2, children={})}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=20, sum_rewards=-878.9950879943008, sum_sq_rewards=38631.64121969437, n_visits=20, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=9, sum_rewards=-395.73839266807806, sum_sq_rewards=17400.996240694472, n_visits=9, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=10, sum_rewards=-439.34521488544067, sum_sq_rewards=19302.426864498713, n_visits=10, children={})})})}), 9: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(False,), group_idx=0, n_obs=94, sum_rewards=-4132.514515339099, sum_sq_rewards=181677.5143163395, n_visits=94, children={-1: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-923.2496630550982, sum_sq_rewards=40590.03751593179, n_visits=21, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=31, sum_rewards=-1362.9331668007874, sum_sq_rewards=59922.18876577904, n_visits=31, children={-1: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=30, sum_rewards=-1318.9713445609393, sum_sq_rewards=57989.546951131044, n_visits=30, children={})}), 10: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(False,), group_idx=0, 
n_obs=41, sum_rewards=-1802.414938877525, sum_sq_rewards=79236.60740220043, n_visits=41, children={-1: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=22, sum_rewards=-967.1973728850329, sum_sq_rewards=42521.41229007139, n_visits=22, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=18, sum_rewards=-791.2899300890582, sum_sq_rewards=34785.5579160644, n_visits=18, children={-1: TSNode(partial_by_group=((4, 5, 9, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.3741921983864, sum_sq_rewards=32856.96588158221, n_visits=17, children={})})})}), 8: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(False,), group_idx=0, n_obs=170, sum_rewards=-7472.359537464019, sum_sq_rewards=328448.16905515944, n_visits=170, children={-1: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(True,), group_idx=1, n_obs=47, sum_rewards=-2066.0088489035124, sum_sq_rewards=90816.91547065707, n_visits=47, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.8973845709677, sum_sq_rewards=52175.054707278956, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9747436443677, sum_sq_rewards=50245.85632131191, n_visits=26, children={})}), 10: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=56, sum_rewards=-2460.9770037747253, sum_sq_rewards=108150.18593153136, n_visits=56, children={-1: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.5694461425167, sum_sq_rewards=50210.21206557458, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=29, sum_rewards=-1274.4010901642278, sum_sq_rewards=56003.4046869463, n_visits=29, children={-1: 
TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=28, sum_rewards=-1230.45112241069, sum_sq_rewards=54071.80502140927, n_visits=28, children={})})}), 9: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=39, sum_rewards=-1714.4664991327732, sum_sq_rewards=75369.15035441122, n_visits=39, children={-1: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2787708937077, sum_sq_rewards=30912.583875859167, n_visits=16, children={}), 10: TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=13, sum_rewards=-571.512793859451, sum_sq_rewards=25125.15452660411, n_visits=13, children={-1: TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.85435917582294, sum_sq_rewards=7731.1903163133575, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7437208786052, sum_sq_rewards=15465.462121713665, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=9, sum_rewards=-395.70104514964174, sum_sq_rewards=17397.709017938054, n_visits=9, children={-1: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=8, sum_rewards=-351.6898334141943, sum_sq_rewards=15460.722259515675, n_visits=8, children={})})})})}), 11: TSNode(partial_by_group=((4, 11),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-224.6420447794053, sum_sq_rewards=10092.809672320658, n_visits=5, children={-1: TSNode(partial_by_group=((4, 11),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-179.71615062026032, sum_sq_rewards=8074.473706321962, n_visits=4, children={})})}),\n", + " TSNode(partial_by_group=((4, 5),), stopped_by_group=(False,), group_idx=0, n_obs=891, 
sum_rewards=-39167.47720350492, sum_sq_rewards=1721764.546765213, n_visits=891, children={-1: TSNode(partial_by_group=((4, 5),), stopped_by_group=(True,), group_idx=1, n_obs=89, sum_rewards=-3912.7613868103085, sum_sq_rewards=172019.22774163214, n_visits=89, children={}), 10: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(False,), group_idx=0, n_obs=113, sum_rewards=-4967.487764600544, sum_sq_rewards=218371.24035568372, n_visits=113, children={-1: TSNode(partial_by_group=((4, 5, 10),), stopped_by_group=(True,), group_idx=1, n_obs=59, sum_rewards=-2593.495125985222, sum_sq_rewards=114003.73086613679, n_visits=59, children={}), 11: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2330.0600732954176, sum_sq_rewards=102437.4391939591, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2286.1179279582257, sum_sq_rewards=100506.52705712414, n_visits=52, children={})})}), 11: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(False,), group_idx=0, n_obs=89, sum_rewards=-3912.7125718756874, sum_sq_rewards=172014.92435577657, n_visits=89, children={-1: TSNode(partial_by_group=((4, 5, 11),), stopped_by_group=(True,), group_idx=1, n_obs=88, sum_rewards=-3868.7714891661, sum_sq_rewards=170084.10560608577, n_visits=88, children={})}), 6: TSNode(partial_by_group=((4, 5, 6),), stopped_by_group=(False,), group_idx=0, n_obs=74, sum_rewards=-3253.47243268099, sum_sq_rewards=143041.75710217294, n_visits=74, children={-1: TSNode(partial_by_group=((4, 5, 6),), stopped_by_group=(True,), group_idx=1, n_obs=10, sum_rewards=-439.66411802568587, sum_sq_rewards=19330.465439171287, n_visits=10, children={}), 8: TSNode(partial_by_group=((4, 5, 6, 8),), stopped_by_group=(False,), group_idx=0, n_obs=16, sum_rewards=-703.3535630544754, sum_sq_rewards=30919.16275750841, n_visits=16, children={-1: TSNode(partial_by_group=((4, 5, 6, 8),), 
stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.8363854988221, sum_sq_rewards=7729.615020124167, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.84781106588187, sum_sq_rewards=7730.616296158963, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=3, sum_rewards=-131.85283673154302, sum_sq_rewards=5795.058529479829, n_visits=3, children={})}), 10: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=5, sum_rewards=-219.8227981519812, sum_sq_rewards=9664.422789140159, n_visits=5, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-88.02325681303296, sum_sq_rewards=3874.046928997293, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.82381529890702, sum_sq_rewards=3856.5113793941046, n_visits=2, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.8621293322072, sum_sq_rewards=3859.8777774492482, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.95218402328926, sum_sq_rewards=1931.7944804170838, n_visits=1, children={})})}), 11: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(False,), group_idx=0, n_obs=17, sum_rewards=-747.2567353148665, sum_sq_rewards=32846.64274766709, n_visits=17, children={-1: TSNode(partial_by_group=((4, 5, 6, 11),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2868673295923, sum_sq_rewards=30913.29345702465, n_visits=16, children={})}), 7: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(False,), group_idx=0, n_obs=12, sum_rewards=-527.798461848819, 
sum_sq_rewards=23214.284110934812, n_visits=12, children={-1: TSNode(partial_by_group=((4, 5, 6, 7),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.93915992300722, sum_sq_rewards=3866.6480880716545, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.99082540147066, sum_sq_rewards=5807.193616973811, n_visits=3, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.0023435605371, sum_sq_rewards=1936.2062388195407, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=1, sum_rewards=-43.972912323197995, sum_sq_rewards=1933.617018183658, n_visits=1, children={})}), 8: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.89291074303364, sum_sq_rewards=3862.5830404748212, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.922361457541186, sum_sq_rewards=1929.1738360068996, n_visits=1, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-88.01791838251545, sum_sq_rewards=3873.5779551264786, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-43.98685793673277, sum_sq_rewards=1934.8436711463105, n_visits=1, children={})}), 9: TSNode(partial_by_group=((4, 5, 6, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=2, sum_rewards=-87.94349197260667, sum_sq_rewards=3867.0355324076445, n_visits=2, children={-1: TSNode(partial_by_group=((4, 5, 6, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.029375144195555, sum_sq_rewards=1938.5858755883053, n_visits=1, children={})})}), 10: TSNode(partial_by_group=((4, 5, 6, 10),), 
stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.66987745161964, sum_sq_rewards=19330.969558831595, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 6, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.78454656081044, sum_sq_rewards=11597.051932874065, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=3, sum_rewards=-131.9134151478395, sum_sq_rewards=5800.388251850703, n_visits=3, children={-1: TSNode(partial_by_group=((4, 5, 6, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.88453758728016, sum_sq_rewards=3861.84619260798, n_visits=2, children={})})}), 9: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8109890573494, sum_sq_rewards=15471.381338727459, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 6, 9),), stopped_by_group=(True,), group_idx=1, n_obs=2, sum_rewards=-87.96851731041005, sum_sq_rewards=3869.232754546095, n_visits=2, children={}), 10: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=4, sum_rewards=-175.91496034420223, sum_sq_rewards=7736.521615340711, n_visits=4, children={-1: TSNode(partial_by_group=((4, 5, 6, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=1, sum_rewards=-44.00361414890658, sum_sq_rewards=1936.3180581658512, n_visits=1, children={}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.9515428139761, sum_sq_rewards=3867.739243850584, n_visits=2, children={})}), 11: TSNode(partial_by_group=((4, 5, 6, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=1, sum_rewards=-44.004609718471485, sum_sq_rewards=1936.405676474995, n_visits=1, children={})})}), 7: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(False,), group_idx=0, n_obs=261, sum_rewards=-11472.15312187526, 
sum_sq_rewards=504254.31677490985, n_visits=261, children={-1: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(True,), group_idx=1, n_obs=45, sum_rewards=-1978.0154399179673, sum_sq_rewards=86945.49803190352, n_visits=45, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=54, sum_rewards=-2373.663953734188, sum_sq_rewards=104338.60066699992, n_visits=54, children={-1: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9045444714175, sum_sq_rewards=50239.67937641118, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.815408600552, sum_sq_rewards=52167.846096387606, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.8797542310706, sum_sq_rewards=50237.50437151309, n_visits=26, children={})})}), 9: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=38, sum_rewards=-1670.4877740498455, sum_sq_rewards=73435.03131497768, n_visits=38, children={-1: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=14, sum_rewards=-615.3066670101945, sum_sq_rewards=27043.037949194786, n_visits=14, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=15, sum_rewards=-659.3723206412016, sum_sq_rewards=28984.806842262173, n_visits=15, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.6360004678156, sum_sq_rewards=11583.993595170117, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7372275893211, sum_sq_rewards=15464.893098870949, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 
11),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8172262134047, sum_sq_rewards=15471.929156006257, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=7, sum_rewards=-307.8029281771427, sum_sq_rewards=13534.670724381365, n_visits=7, children={})})}), 11: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2329.4620246822856, sum_sq_rewards=102384.84021208857, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={})}), 8: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=70, sum_rewards=-3076.565577991281, sum_sq_rewards=135218.00988236978, n_visits=70, children={-1: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.127493578221, sum_sq_rewards=32835.28378186662, n_visits=17, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=22, sum_rewards=-966.8382846117236, sum_sq_rewards=42489.845954838915, n_visits=22, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-922.920341731806, sum_sq_rewards=40561.06024803523, n_visits=21, children={})}), 9: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.62629242740445, sum_sq_rewards=19327.13755483925, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-219.89046816143235, sum_sq_rewards=9670.368169661724, n_visits=5, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 9, 10),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.92074668263248, 
sum_sq_rewards=3865.02911757097, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 9, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.91360021280931, sum_sq_rewards=3864.400552291376, n_visits=2, children={})}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=20, sum_rewards=-878.9950879943008, sum_sq_rewards=38631.64121969437, n_visits=20, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=9, sum_rewards=-395.73839266807806, sum_sq_rewards=17400.996240694472, n_visits=9, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=10, sum_rewards=-439.34521488544067, sum_sq_rewards=19302.426864498713, n_visits=10, children={})})})}), 9: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(False,), group_idx=0, n_obs=94, sum_rewards=-4132.514515339099, sum_sq_rewards=181677.5143163395, n_visits=94, children={-1: TSNode(partial_by_group=((4, 5, 9),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-923.2496630550982, sum_sq_rewards=40590.03751593179, n_visits=21, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=31, sum_rewards=-1362.9331668007874, sum_sq_rewards=59922.18876577904, n_visits=31, children={-1: TSNode(partial_by_group=((4, 5, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=30, sum_rewards=-1318.9713445609393, sum_sq_rewards=57989.546951131044, n_visits=30, children={})}), 10: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=41, sum_rewards=-1802.414938877525, sum_sq_rewards=79236.60740220043, n_visits=41, children={-1: TSNode(partial_by_group=((4, 5, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=22, sum_rewards=-967.1973728850329, sum_sq_rewards=42521.41229007139, n_visits=22, children={}), 11: TSNode(partial_by_group=((4, 5, 9, 10, 
11),), stopped_by_group=(False,), group_idx=0, n_obs=18, sum_rewards=-791.2899300890582, sum_sq_rewards=34785.5579160644, n_visits=18, children={-1: TSNode(partial_by_group=((4, 5, 9, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.3741921983864, sum_sq_rewards=32856.96588158221, n_visits=17, children={})})})}), 8: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(False,), group_idx=0, n_obs=170, sum_rewards=-7472.359537464019, sum_sq_rewards=328448.16905515944, n_visits=170, children={-1: TSNode(partial_by_group=((4, 5, 8),), stopped_by_group=(True,), group_idx=1, n_obs=47, sum_rewards=-2066.0088489035124, sum_sq_rewards=90816.91547065707, n_visits=47, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.8973845709677, sum_sq_rewards=52175.054707278956, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9747436443677, sum_sq_rewards=50245.85632131191, n_visits=26, children={})}), 10: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=56, sum_rewards=-2460.9770037747253, sum_sq_rewards=108150.18593153136, n_visits=56, children={-1: TSNode(partial_by_group=((4, 5, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.5694461425167, sum_sq_rewards=50210.21206557458, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(False,), group_idx=0, n_obs=29, sum_rewards=-1274.4010901642278, sum_sq_rewards=56003.4046869463, n_visits=29, children={-1: TSNode(partial_by_group=((4, 5, 8, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=28, sum_rewards=-1230.45112241069, sum_sq_rewards=54071.80502140927, n_visits=28, children={})})}), 9: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=39, sum_rewards=-1714.4664991327732, 
sum_sq_rewards=75369.15035441122, n_visits=39, children={-1: TSNode(partial_by_group=((4, 5, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=16, sum_rewards=-703.2787708937077, sum_sq_rewards=30912.583875859167, n_visits=16, children={}), 10: TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=13, sum_rewards=-571.512793859451, sum_sq_rewards=25125.15452660411, n_visits=13, children={-1: TSNode(partial_by_group=((4, 5, 8, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=4, sum_rewards=-175.85435917582294, sum_sq_rewards=7731.1903163133575, n_visits=4, children={}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7437208786052, sum_sq_rewards=15465.462121713665, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=9, sum_rewards=-395.70104514964174, sum_sq_rewards=17397.709017938054, n_visits=9, children={-1: TSNode(partial_by_group=((4, 5, 8, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=8, sum_rewards=-351.6898334141943, sum_sq_rewards=15460.722259515675, n_visits=8, children={})})})})}),\n", + " TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(False,), group_idx=0, n_obs=261, sum_rewards=-11472.15312187526, sum_sq_rewards=504254.31677490985, n_visits=261, children={-1: TSNode(partial_by_group=((4, 5, 7),), stopped_by_group=(True,), group_idx=1, n_obs=45, sum_rewards=-1978.0154399179673, sum_sq_rewards=86945.49803190352, n_visits=45, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(False,), group_idx=0, n_obs=54, sum_rewards=-2373.663953734188, sum_sq_rewards=104338.60066699992, n_visits=54, children={-1: TSNode(partial_by_group=((4, 5, 7, 10),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.9045444714175, sum_sq_rewards=50239.67937641118, n_visits=26, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 10, 
11),), stopped_by_group=(False,), group_idx=0, n_obs=27, sum_rewards=-1186.815408600552, sum_sq_rewards=52167.846096387606, n_visits=27, children={-1: TSNode(partial_by_group=((4, 5, 7, 10, 11),), stopped_by_group=(True,), group_idx=1, n_obs=26, sum_rewards=-1142.8797542310706, sum_sq_rewards=50237.50437151309, n_visits=26, children={})})}), 9: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(False,), group_idx=0, n_obs=38, sum_rewards=-1670.4877740498455, sum_sq_rewards=73435.03131497768, n_visits=38, children={-1: TSNode(partial_by_group=((4, 5, 7, 9),), stopped_by_group=(True,), group_idx=1, n_obs=14, sum_rewards=-615.3066670101945, sum_sq_rewards=27043.037949194786, n_visits=14, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(False,), group_idx=0, n_obs=15, sum_rewards=-659.3723206412016, sum_sq_rewards=28984.806842262173, n_visits=15, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 10),), stopped_by_group=(True,), group_idx=1, n_obs=6, sum_rewards=-263.6360004678156, sum_sq_rewards=11583.993595170117, n_visits=6, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=8, sum_rewards=-351.7372275893211, sum_sq_rewards=15464.893098870949, n_visits=8, children={})}), 11: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(False,), group_idx=0, n_obs=8, sum_rewards=-351.8172262134047, sum_sq_rewards=15471.929156006257, n_visits=8, children={-1: TSNode(partial_by_group=((4, 5, 7, 9, 11),), stopped_by_group=(True,), group_idx=1, n_obs=7, sum_rewards=-307.8029281771427, sum_sq_rewards=13534.670724381365, n_visits=7, children={})})}), 11: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2329.4620246822856, sum_sq_rewards=102384.84021208857, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, 
sum_sq_rewards=100451.42783509253, n_visits=52, children={})}), 8: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(False,), group_idx=0, n_obs=70, sum_rewards=-3076.565577991281, sum_sq_rewards=135218.00988236978, n_visits=70, children={-1: TSNode(partial_by_group=((4, 5, 7, 8),), stopped_by_group=(True,), group_idx=1, n_obs=17, sum_rewards=-747.127493578221, sum_sq_rewards=32835.28378186662, n_visits=17, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(False,), group_idx=0, n_obs=22, sum_rewards=-966.8382846117236, sum_sq_rewards=42489.845954838915, n_visits=22, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 11),), stopped_by_group=(True,), group_idx=1, n_obs=21, sum_rewards=-922.920341731806, sum_sq_rewards=40561.06024803523, n_visits=21, children={})}), 9: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(False,), group_idx=0, n_obs=10, sum_rewards=-439.62629242740445, sum_sq_rewards=19327.13755483925, n_visits=10, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 9),), stopped_by_group=(True,), group_idx=1, n_obs=5, sum_rewards=-219.89046816143235, sum_sq_rewards=9670.368169661724, n_visits=5, children={}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 9, 10),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.92074668263248, sum_sq_rewards=3865.02911757097, n_visits=2, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 9, 11),), stopped_by_group=(False,), group_idx=1, n_obs=2, sum_rewards=-87.91360021280931, sum_sq_rewards=3864.400552291376, n_visits=2, children={})}), 10: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(False,), group_idx=0, n_obs=20, sum_rewards=-878.9950879943008, sum_sq_rewards=38631.64121969437, n_visits=20, children={-1: TSNode(partial_by_group=((4, 5, 7, 8, 10),), stopped_by_group=(True,), group_idx=1, n_obs=9, sum_rewards=-395.73839266807806, sum_sq_rewards=17400.996240694472, n_visits=9, children={}), 11: TSNode(partial_by_group=((4, 5, 7, 8, 
10, 11),), stopped_by_group=(False,), group_idx=1, n_obs=10, sum_rewards=-439.34521488544067, sum_sq_rewards=19302.426864498713, n_visits=10, children={})})})}),\n", + " TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(False,), group_idx=0, n_obs=53, sum_rewards=-2329.4620246822856, sum_sq_rewards=102384.84021208857, n_visits=53, children={-1: TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={})}),\n", + " TSNode(partial_by_group=((4, 5, 7, 11),), stopped_by_group=(True,), group_idx=1, n_obs=52, sum_rewards=-2285.4914393212052, sum_sq_rewards=100451.42783509253, n_visits=52, children={})])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcts._select_and_expand()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[((2, 3, 5, 11), {}),\n", + " ((3, 5, 6, 8, 9, 10), {}),\n", + " ((3, 5, 8, 9, 10), {}),\n", + " ((3, 5, 7, 8, 9, 10), {}),\n", + " ((3, 5, 8, 9, 10, 11), {}),\n", + " ((3, 5, 6, 9, 10), {}),\n", + " ((3, 5, 6, 9, 10, 11), {}),\n", + " ((3, 5, 6, 8, 9, 11), {}),\n", + " ((3, 5, 6, 8, 9), {}),\n", + " ((3, 5, 6, 7, 8, 9), {})]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tracker.top_k(k=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.3403870.9393840.0000000.0000000.9888050.0000000.0000000.0000000.0591740.2835080.6511150.000000-0.093813True
10.0000000.0000000.2157870.6949210.6070900.0000000.6598540.8646030.0000000.0000000.0000000.683445-0.004209True
20.1622490.0000000.0000000.0000000.6701020.0000000.0000000.8375380.3137730.0000000.1942530.959623-0.003265True
30.6468610.1510700.0000000.0239440.0000000.0000000.0000000.0000000.1645060.1478150.0000000.930434-0.005103True
40.0000000.4927420.1909610.0000000.0581920.1650810.0000000.1383720.0000000.0000000.0000000.281132-0.040627True
.............................................
750.0000000.0000000.7512170.0000000.6833750.4044430.0000000.0000000.0000000.0000000.0000000.000000-0.338749True
760.0000000.0000001.0000001.0000000.0000000.5332760.0000000.0000000.0000001.0000000.0000000.000000-0.007243True
770.0000000.0000000.4359850.2888460.3115910.6914390.0000000.0000001.0000001.0000000.0000000.000000-2.591009True
780.0000000.0000000.4243290.4090350.3438970.7186910.0000000.0000000.5095360.6181830.0000000.000000-2.122833True
790.0000000.0000001.0000001.0000000.9290300.0000000.0000000.0000000.0000001.0000001.0000000.000000-0.000114True
\n", + "

80 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.340387 0.939384 0.000000 0.000000 0.988805 0.000000 0.000000 \n", + "1 0.000000 0.000000 0.215787 0.694921 0.607090 0.000000 0.659854 \n", + "2 0.162249 0.000000 0.000000 0.000000 0.670102 0.000000 0.000000 \n", + "3 0.646861 0.151070 0.000000 0.023944 0.000000 0.000000 0.000000 \n", + "4 0.000000 0.492742 0.190961 0.000000 0.058192 0.165081 0.000000 \n", + ".. ... ... ... ... ... ... ... \n", + "75 0.000000 0.000000 0.751217 0.000000 0.683375 0.404443 0.000000 \n", + "76 0.000000 0.000000 1.000000 1.000000 0.000000 0.533276 0.000000 \n", + "77 0.000000 0.000000 0.435985 0.288846 0.311591 0.691439 0.000000 \n", + "78 0.000000 0.000000 0.424329 0.409035 0.343897 0.718691 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.929030 0.000000 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.059174 0.283508 0.651115 0.000000 \n", + "1 0.864603 0.000000 0.000000 0.000000 0.683445 \n", + "2 0.837538 0.313773 0.000000 0.194253 0.959623 \n", + "3 0.000000 0.164506 0.147815 0.000000 0.930434 \n", + "4 0.138372 0.000000 0.000000 0.000000 0.281132 \n", + ".. ... ... ... ... ... \n", + "75 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "76 0.000000 0.000000 1.000000 0.000000 0.000000 \n", + "77 0.000000 1.000000 1.000000 0.000000 0.000000 \n", + "78 0.000000 0.509536 0.618183 0.000000 0.000000 \n", + "79 0.000000 0.000000 1.000000 1.000000 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.093813 True \n", + "1 -0.004209 True \n", + "2 -0.003265 True \n", + "3 -0.005103 True \n", + "4 -0.040627 True \n", + ".. ... ... 
\n", + "75 -0.338749 True \n", + "76 -0.007243 True \n", + "77 -2.591009 True \n", + "78 -2.122833 True \n", + "79 -0.000114 True \n", + "\n", + "[80 rows x 14 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34", + "metadata": {}, + "outputs": [], + "source": [ + "def reward_fn3(\n", + " selected_features: tuple[int, ...], cat_selections: dict[int, float]\n", + ") -> float:\n", + " fixed = {\n", + " i: 0.0\n", + " for i in range(len(benchmark.domain.inputs.get_keys()))\n", + " if i not in selected_features\n", + " }\n", + " candidates, acq_value = optimize_acqf(\n", + " acq_function=acqf,\n", + " bounds=bounds,\n", + " q=1,\n", + " num_restarts=20,\n", + " raw_samples=2048,\n", + " fixed_features=fixed,\n", + " )\n", + " return acq_value.item()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(2, 3, 4, 5, 9)\n", + "-41.97450586826323\n", + "-4.2207785897443415\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))\n", + "# print(reward_fn3(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 11)\n", + "-45.67180087515713\n", + "-45.67175933157002\n", + "-45.67174110576808\n" + ] + } + ], + "source": [ + "# leaf, path = mcts._select_and_expand()\n", + "# selected_features, cat_selections = mcts._get_selection(leaf)\n", + "selected_features = (0, 11)\n", + 
"print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))\n", + "print(reward_fn3(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "80" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(strategy.experiments)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'SpuriousFeaturesWrapper' object has no attribute 'inputs'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[223]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[43mreward_fn2\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m7\u001b[39;49m\u001b[43m,\u001b[49m\u001b[32;43m11\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcat_selections\u001b[49m\u001b[43m=\u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m)\n", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[65]\u001b[39m\u001b[32m, line 27\u001b[39m, in \u001b[36mreward_fn2\u001b[39m\u001b[34m(selected_features, cat_selections)\u001b[39m\n\u001b[32m 26\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mreward_fn2\u001b[39m(selected_features: \u001b[38;5;28mtuple\u001b[39m[\u001b[38;5;28mint\u001b[39m, 
...], cat_selections: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mint\u001b[39m, \u001b[38;5;28mfloat\u001b[39m]) -> \u001b[38;5;28mfloat\u001b[39m:\n\u001b[32m---> \u001b[39m\u001b[32m27\u001b[39m fixed = {i: \u001b[32m0.0\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mlen\u001b[39m(\u001b[43mbenchmark\u001b[49m\u001b[43m.\u001b[49m\u001b[43minputs\u001b[49m.domain.get_keys())) \u001b[38;5;28;01mif\u001b[39;00m i \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m selected_features}\n\u001b[32m 28\u001b[39m candidates, acq_value = optimize_acqf(\n\u001b[32m 29\u001b[39m acq_function=acqf,\n\u001b[32m 30\u001b[39m bounds=bounds,\n\u001b[32m (...)\u001b[39m\u001b[32m 34\u001b[39m fixed_features=fixed,\n\u001b[32m 35\u001b[39m )\n\u001b[32m 36\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m acq_value.item()\n", + "\u001b[31mAttributeError\u001b[39m: 'SpuriousFeaturesWrapper' object has no attribute 'inputs'" + ] + } + ], + "source": [ + "print(reward_fn2((0, 2, 3, 7, 11), cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.93938361, 0. , 0. , 0.15107026, 0.49274244,\n", + " 0.34103735, 0.31871182, 0.27325514, 0. , 0.86132336,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 1. , 0. , 0. , 0. ,\n", + " 1. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0. , 0. 
])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.x_1.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.09381345, -0.09381345, -0.09381345, -0.09381345, -0.09381345,\n", + " -0.09381345, -0.21600644, -0.21600644, -0.21600644, -0.21600644,\n", + " -0.21600644, -0.21600644, -0.27450036, -0.27450036, -0.42447708,\n", + " -0.5281177 , -0.69931485, -0.69931485, -0.69931485, -0.8353837 ,\n", + " -0.8353837 , -0.8353837 , -0.83551668, -1.61784249, -1.61784249,\n", + " -1.61784249, -1.91181011, -1.91181011, -1.91306396, -1.91306396,\n", + " -1.91306396, -1.91306396, -1.91306396, -1.91306396, -1.91306396,\n", + " -1.91306396, -1.91306396, -1.91306396, -1.91306396, -1.91306396,\n", + " -1.91306396, -1.91306396, -1.91306396, -1.91306396, -1.91306396,\n", + " -1.91306396, -1.91306396, -1.91306396, -1.91306396, -1.91306396,\n", + " -1.91306396, -1.91306396, -1.91306396, -1.91306396, -1.91339879,\n", + " -1.91339879, -1.91438567, -1.91438567, -1.91438567, -1.91456979,\n", + " -1.91456979, -1.91456979, -1.91456979, -1.91456979, -1.91462882,\n", + " -1.91462882, -1.91462882, -1.91462882, -1.91462882, -1.91462882,\n", + " -1.91462882, -1.91462882, -1.91462882, -1.91462882, -1.91462882,\n", + " -1.91462882, -1.91462882, -1.91462882, -1.91462882, -1.91462882,\n", + " -1.91462882, -1.91462882, -1.91462882, -1.91462882, -2.56435736,\n", + " -2.56435736, -2.56770656, -2.56770656, -2.56770656, -2.57161001,\n", + " -2.57161001, -2.58754068, -2.61216745, -2.61216745, -2.6217891 ,\n", + " -2.62238646, -2.62238646, -2.62238646, -2.62238646, -2.62238646,\n", + " -2.62238646, -2.62930695, -2.62930695, -2.62930695, -2.62930695,\n", + " -2.62930695, -2.62930695, -2.62930695, -2.62930695, -2.62930695,\n", + " -2.62930695, -2.62930695, -2.62930695, -2.62930695, -2.62930695,\n", + " -2.62930695, 
-2.62930695, -2.62930695, -2.62951138, -2.62951138,\n", + " -2.62951138, -2.62951138, -2.62951138, -2.62951138, -2.62951138,\n", + " -2.62951138, -2.62951138, -2.62951138, -2.62951138, -2.62951138,\n", + " -2.62959877, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234,\n", + " -2.62964234, -2.62964234, -2.62964234, -2.62964234, -2.62964234])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_y_per_iteration = strategy.experiments[\"y\"].cummin()\n", + "best_y_per_iteration.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Best y per iteration')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkkAAAHHCAYAAACr0swBAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQRNJREFUeJzt3Qd8VFX2wPGTTihJgEBCCSWAgDQRLFgQAQEFBWWx4EqRBUF01y64KmLD3hV0VymKBXcVRXfxT7fRBBFBQEGa9JYECKS+/+dcmNkMmSSTkMnMvPf7fj6Pybx5k7nz8pI5nHvuvWGWZVkCAAAAD+GedwEAAECQBAAAUAQySQAAAF4QJAEAAHhBkAQAAOAFQRIAAIAXBEkAAABeECQBAAB4QZAEAADgBUESAPho4cKFEhYWZm5DxZQpU0ybt2zZEuimACGHIAmwAdcHYcGtdu3acumll8p///tfv71uZmamPPLIIyEVNJS3999/X1566aVAN0OefPJJmTlzZqCbAdhKGGu3AfYIkoYOHSqPPvqoNG7cWHRJxj179pj9a9eulVmzZkmfPn3K/XX3798vtWrVknHjxplgye7y8/MlOztboqOjJTz8xP8x9byuWbMm4JmaqlWryp/+9CfzMy8oLy9PcnJyJCYmxgTPAHwXWYpjAQS5yy+/XDp27Oi+P2zYMElKSpIPPvjAL0GSHR09elSqVKni9TENjCpVqlRhwVh5vFZERITZAJQe3W2AjSUkJEhsbKxERkYW+hDWLqJWrVqZD2INpG655RY5dOiQx3E//PCD9OzZUxITE8330SzVzTffbB7TzIlmkdT48ePd3XxFZZR+//138/iLL75Y6LHvv//ePKbBXEn1QB999JE88MADkpycbIKZq666SrZv317o+KVLl0qvXr0kPj5eKleuLJdccol89913HsdoW/V7/vLLLzJw4ECpXr26XHTRRSW2wdW92KVLF/nyyy9l69at7vffqFEj9/FZWVkmy9a0aVOTyUlJSZH77rvP7C9In3fbbbfJ9OnTzc9Ej509e7Z57LnnnpMLLrhAatasaX4GHTp0kH/961+Fnq/B3dSpU93tGDJkSLE1SW+88Yb7terWrSujR4+WtLQ0j2P0/bVu3dqcH+261fNYr149eeaZZ4o8R4CdkEkCbCQ9Pd10gWl32969e+XVV1+VI0eOyJ///GeP4zQgcnXR/fWvf5XNmzfLa6+9Jj/++KMJJKKioszze/ToYQKhMWPGmIBLP2g/+eQT8z10/8SJE2XUqFFy9dVXyzXXXGP2t23b1mvbUlNT5cILLzSBwJ133unxmO6rVq2a9O3bt8T3+MQTT5gP/fvvv9+0UYO97t27y6pVq0wQoebPn2+yahpQaJCiGaDJkydL165d5ZtvvpFzzz3X43sOGDBAmjVrZup69Nz56u9//7s553/88Yc7+NNuL1cgqgHct99+KyNGjJCWLVvKzz//bI779ddfC9UPaZtnzJhhgiUNSl3B1ssvv2y+z4033miySx9++KFp7xdffCG9e/c2x7z77rvyl7/8xbwvfS3VpEmTItutwaEGtnre9Oe3YcMG87Ncvny5++fvooGzBpv687322mtNgKbnvk2bNuYcA7amNUkAQtvkyZP1k73QFhMTY02ZMsXj2G+++cY8Nn36dI/9s2fP9tj/6aefmvvLly8v8nX37dtnjhk3bpxP7XzzzTfN8evWrXPvy87OthITE63BgwcX+9wFCxaY59arV8/KyMhw758xY4bZ//LLL5v7+fn5VrNmzayePXuar10yMzOtxo0bW5dddpl7n7Zbn3vDDTf41H5XG/TWpXfv3lbDhg0LHfvuu+9a4eHh5nwXNGnSJPM9vvvuO/c+va/Hrl27ttD30XYXpOerdevWVteuXT32V6lSxes5dF0bmzdvNvf37t1rRUdHWz169LDy8vLcx7322mvmuHfeece975JLLjH7pk2b5t6XlZVlJScnW/379y/
mTAH2QHcbYCOvv/66zJkzx2zvvfee6SLRDIMr+6M+/vhj0wV12WWXmayTa9Osi2ZBFixYYI7TzJHSjIUW/pYHzURo955mjly++uor8/qnZruKMmjQIJN1ctFi5Tp16sh//vMfc18zSr/99pvpPjtw4ID7/Wl3VLdu3eTrr782WZ6CRo4cKeVNz7Nmj1q0aOFxnjWbpVzn2UW7A88888xC38eVHXNldTRzdfHFF8vKlSvL1K65c+eajNQdd9zhLj5Xw4cPl7i4ONN9WJBeEwV/Nlq0rhkr7T4F7I7uNsBG9MOrYOH2DTfcIO3btzddOFq4rR9wGkDoB61OEeCNdmG5PrT79+9vumW0i0jrU/r162eCD61jKQsNvK688kozbP6xxx4z+zRg0joXV/BQEu0WK0i73rTmx1Vzo+9PDR48uMjvoe9f649ctNaqvGk71q1b567bKuo8l9QGDVIff/xxE/wVrGUq60g1rZ9SzZs399iv14Z2iboed6lfv36h19Jzt3r16jK9PhBKCJIAG9NMgWaTtK5FP7S1UFezKBogFczmFOT6UNcPRq0/WbJkiZlCQDM+WrT9/PPPm32u2pvS0kyQZlm0WFvrWj7//HO59dZbPbIap8OVJXr22WflrLPO8nrMqW0vmK0pL9oOfX8vvPCC18e1iLukNmj9lNYjde7c2RRaa8ZM64W0vkoDzYpQ1Mi40tRuAaGKIAmwudzcXHOrBdyugl7tctEial+Cg/PPP99sWjCtH8xaQKzFw9qNV5ZshhYBayCmQdp5551nJqS86aabfH6+K1NU8MN648aN7oJxV8Gydh1pYbK/FXUOtB0//fST6eIra9bn3//+t+me1AC1YPZOgyRf23Gqhg0bmlst1tbMkYt2wWkBf0WcMyBUUJME2JjWEv3f//2f6UrR+hhXXZBOMOjq7jo1oHINA9f6l1OzBa7MjKvbR4eEq1OHjhdHpyPQbkAdyaUj7DTbUtSIOG+mTZsmhw8fdt/XbNeuXbvcI620tkoDFB067woMC9q3b5+UJ52GQLvvTqXneceOHfKPf/yj0GPHjh0zNVK+ZHE0+NGfl4t2K3qbWVvb4cvPQYMgvR5eeeUVj5/v22+/bd6Ha8QcADJJgK3oEiTr169317xo5kczLzqEXzMrrlojnQJgwoQJps5Fh/lrF44ep91g2jWnxdA654528ejwfg06NDDRD3z9PldccYX5XpqJ0mJjnbvojDPOkBo1aph5dXQrqctNP6S1ePnpp58u1XvU19C5jHT6Ap1VXKcA0JokLTxW2m33z3/+0wRN2r2ox2nNkwYs+nrafu0+LC8alOn7v+uuu+Scc84xXXlad6XZMQ0EtShcX1czdxrs6M9H92t2qGD9mDcasGh3nWbftBZMf6ZanK/v99SaIG2HZgj1eJ33SGucNFN3Ks3ijR071tSa6ffV7jzNKunPWtvvawE94AiBHl4HwD9TAFSqVMk666yzrIkTJ3oMhXd56623rA4dOlixsbFWtWrVrDZt2lj33XeftXPnTvP4ypUrzdD4Bg0amKkEateubfXp08f64YcfPL7P999/b76PDisvzXQArVq1MsPe//jjj1INv//ggw+ssWPHmvZo23UI/tatWwsd/+OPP1rXXHONVbNmTdN+HaZ/7bXXWvPmzSs0BYBOZVDWKQCOHDliDRw40EpISDCPFZwOQIfrP/300+a9ahuqV69uztX48eOt9PR093H6vNGjR3t9zbfffttMaaDPb9GihflZu9pd0Pr1663OnTubc6KPuaYDOHUKgIJD/vX7RUVFWUlJSdaoUaOsQ4cOeRyjUwBo20+l39vbtAeA3bB2G4CA0FF3mhWaN2+eT8frLNdahK7ZLs10AYC/UZMEoMLpcifa1afdbgAQrBjdBqDCrFmzRlasWGGmEdDh7Ndddx1nH0DQIpMEoMLoSDQtpNZRd7qYbXmscg8A/kJNEgAAgBdkkgAAALwgSAIAAPCCwm0f1l/
auXOnWXW8rEsLAACAiqVTkOkkuDq5apnXhrRCjE6AppOY6cRq5557rrV06dJij58xY4bVvHlzc3zr1q2tL7/8slSvt3379kKT9LFxDrgGuAa4BrgGuAYkJM6Bfo6XVUhlklxT/0+aNMlMt6/LEfTs2dNMqa+rmp9KVxnXNaJ0+YU+ffqYJRr69esnK1euLHHZBBfNIKnt27e7l3UAAADBLSMjQ1JSUtyf47Yf3aaBka4t9Nprr7m7wvQE3H777WZtqlPpHCy6iOQXX3zh3qermesinRpo+XqS4+PjzcKPBEkAAISG8vj8DpnC7ezsbDMJna5g7aJ9jHp/8eLFXp+j+wserzTzVNTxrtXN9cQW3AAAgPOETJC0f/9+s4J2UlKSx369v3v3bq/P0f2lOV5p15xGnq5NM1UAAMB5QiZIqihjx441qTnXprVIAADAeUKmcDsxMVEiIiJkz549Hvv1fnJystfn6P7SHK9iYmLMBgAAnC1kMknR0dHSoUMHmTdvnnufFm7r/U6dOnl9ju4veLyaM2dOkccDAACEXCZJ6fD/wYMHS8eOHeXcc881UwDo6DVdMFMNGjRI6tWrZ+qK1N/+9je55JJLzIrjvXv3lg8//FB++OEHeeuttwL8TgAAQLALqSBJh/Tv27dPHn74YVN8rUP5Z8+e7S7O3rZtm8esmhdccIGZG+nBBx+UBx54QJo1ayYzZ870eY4kAADgXCE1T1IgME8SAAChx1HzJAEAAFQkgiQAAAAvCJIAAAC8IEgCAAAI9dFtdnLoaLYczc4t9fPCw8KkTnwlCQsL80u7AADACQRJAfLs/22Q95duK9Nz+55VV16+vn25twkAAPwPQVKARIWHSUxk6Xo7da6G7Nx8mb9+r+jMDWSTAADwH4KkABnft7XZSiMrN0/OfPgrOXw8V/YezpKkuEp+ax8AAE5H4XYIiYmMkIY1K5uvf9tzJNDNAQDA1giSQkyz2lXN7a97Dge6KQAA2BpBUog5I6mauf1tL5kkAAD8iSApxDQ9mUnauJdMEgAA/kSQFGKa1T6RSfp1zxEzwg0AAPgHQVKISa1VRcLDRNKP5cj+I9mBbg4AALZFkBRiKkVFSIMaJ0e40eUGAIDfECSFoKYnu9yYBgAAAP8hSApBZySdKN4mkwQAgP8QJIWgZq4giQklAQDwG4KkEB7htpG5kgAA8BuCpBDUpFZVCQsTOXA0Ww4cyQp0cwAAsCWCpBAUGx0h9avHmq+ZeRsAAP+I9NP3RQV0uW0/eEyGTl4u0ZG+x7qagfrzeQ3lnp7N/do+AABCHUFSiOrcLFHmr98rx3LyzFYab339uwzvnCrxsVF+ax8AAKGOIClEDbmwsXQ/M0mO5+SX6nmj3lthuuj+b+1uGdAxxW/tAwAg1BEkhbD61U/MvF0aV7WrK8/P+VVmrd5FkAQAQDEo3HaYPu3qmtvvNu5nZBwAAMUgSHKYxolVpE29eMnLt+S/a3YHujkAAAQtgiQHurJdHXM766edgW4KAABBiyDJgXq3PdHltmzLQVn6+wFZtytD0jKzA90sAACCCoXbDlQvIVY6NqwuP2w9JNe9tcTsqxIdId/c31VqVIkOdPMAAAgKZJIcanTXpmbW7sSqMRIZHiZHs/Pkl50ZgW4WAABBgyDJoS5tXlu+vb+r/PBgd7nkjFpm39aDRwPdLAAAggZBEiSlxon5lrYdzORsAABwEkESpMHJIGk7QRIAAG4ESZCGNU8ESVsPkEkCAMCFIAnuTNK2A5liWRZnBAAAgiQUrEk6nJUraZk5nBQAAAiSoCpFRUhSXIz5muJtAABOoLsNnl1uFG8DAGAQJMFoUKOKuSVIAgDgBIIkFCreBgAABEk4qUHN2BNBEt1tAAAYZJJg0N0GAIAngiR4dLftTD8m2bn5nBUAgOMRJMFIrBotlaMjROeS/OMQdUkAABAkwQgLC2MaAAAACiBIQqGZt1noFgAAgiQU0PBkkMRCtwAAECShgAY
1mXUbAAAXuttQqLuNTBIAAARJKKB5UjVzu2nfETmek8e5AQA4GpkkuNWJryQ1qkRLbr4lv+45zJkBADgaQRI8pgFoVTfOfL1mRwZnBgDgaARJ8NCqbry5XbsznTMDAHA0giR4aF3vZCZpJ5kkAICzhUyQdPDgQbnxxhslLi5OEhISZNiwYXLkyJFin9OlSxfThVRwGzlyZIW1ORS1PplJWrcrQ3LyWMMNAOBcIRMkaYC0du1amTNnjnzxxRfy9ddfy4gRI0p83vDhw2XXrl3u7ZlnnqmQ9obyQrfVYiLNIrc6yg0AAKeKlBCwbt06mT17tixfvlw6duxo9r366qtyxRVXyHPPPSd169Yt8rmVK1eW5OTkCmxtaAsPD5OWdeNk2eaDsnZHhrRIPtH9BgCA04REJmnx4sWmi80VIKnu3btLeHi4LF26tNjnTp8+XRITE6V169YyduxYycxkhXtfu9zWULwNAHCwkMgk7d69W2rXru2xLzIyUmrUqGEeK8rAgQOlYcOGJtO0evVquf/++2XDhg3yySefFPmcrKwss7lkZGQ4tnhbM0kAADhVQIOkMWPGyNNPP11iV1tZFaxZatOmjdSpU0e6desmmzZtkiZNmnh9zoQJE2T8+PHiZK3r/W8agPx8y3TBAQDgNAENku6++24ZMmRIscekpqaamqK9e/d67M/NzTUj3kpTb3TeeeeZ240bNxYZJGmX3F133eWRSUpJSREnSU2sIjGR4XI0O0+2HDgqqbWqBrpJAAA4K0iqVauW2UrSqVMnSUtLkxUrVkiHDh3Mvvnz50t+fr478PHFqlWrzK1mlIoSExNjNieLjAiXlnXiZNX2NFm7M4MgCQDgSCFRuN2yZUvp1auXGc6/bNky+e677+S2226T66+/3j2ybceOHdKiRQvzuNIutccee8wEVlu2bJHPP/9cBg0aJJ07d5a2bdsG+B0FvzNPLk+yfjd1SQAAZwqJIMk1Sk2DIK0p0qH/F110kbz11lvux3NyckxRtmv0WnR0tMydO1d69Ohhnqdde/3795dZs2YF8F2EjoY1KpvbnWnHA90UAAACIiRGtykdyfb+++8X+XijRo3Esiz3fa0jWrRoUQW1zn7qJsSa2x1pxwLdFAAAAiJkMkkIUJB0iCAJAOBMBEnwqt7JIGl3xnHJy/9fhg4AAKcgSIJXtarFSGR4mAmQ9h6mLgkA4DwESfAqIjxMkuMrma93UpcEAHAggiT4ULxNJgkA4DwESSixLolMEgDAiQiSUKS6CXS3AQCciyAJJXa3kUkCADgRQRKKRE0SAMDJCJJQJGqSAABORpCEEjNJ6cdy5EhWLmcKAOAoBEkoUtWYSImPjTJfU5cEAHAagiQUi4VuAQBORZCEYtVjGgAAgEMRJKFYTAMAAHAqgiT4GCSxNAkAwFkIklAsapIAAE5FkIRiUZMEAHAqgiT4lEnanX5c8vItzhYAwDEiA90ABLfa1SpJRHiY5OZb0mnCPAkPC3M/ViUmQp4b0E7aN6ge0DYCAOAPBEkolgZIZzdIkOVbDsnew1mFHv/0xx0ESQAAWyJIQone+8t58tueIx775q3bKy/O/VU27z/KGQQA2BJBEkoUExkhrevFe+w7lpMnL84VgiQAgG1RuI0yaZxYxdzuSDsmx3PyOIsAANshSEKZ1KwSLdUqRYpliWw7mMlZBADYDkESyiQsLExST2aTft9HXRIAwH4IknDaXW4UbwMA7IggCWXWOLGqud2833PkGwAAdkCQhDJrXItMEgDAvgiSUGaNaxIkAQDsiyAJZdYosbK53X8kW9KP5XAmAQC2QpCEMqtWKUpqVYsxX29h5m0AgM0QJOG0MMINAGBXBEk4La65kpgGAABgNwRJOC1kkgAAdkWQhNNCkAQAsCuCJJyW1AJzJVm6kBsAADYRGegGILSl1Kgs4WEiR7JypfOzCyQ8LEzCCqzvZm5FpGntqvLGjWdLZARxOQAgNBAk4bTEREbIWSk
JsnJbmmw/eKzI437ff1TW7z4srevFc8YBACGBIAmnbfpfzpdfdmWIyInuNlevm6vz7bb3V8qejCzJys3nbAMAQgZBEk5bbHSEdGhYvdhJJzVIyskjSAIAhA4KROB3USfrkLLJJAEAQghBEvwuOvLEZUYmCQAQSgiS4HfRESdGuZFJAgCEEoIkVFgmKZuaJABACCFIgt9RkwQACEUESfC7aFfhNpkkAEAIIUhCxRVuM7oNABBCCJLgd2SSAAChiCAJFTgFAAvgAgBCB0ESKqxwm2VJAAChhCAJFTcFADVJAIAQQpAEv2PGbQBAKCJIgt8xTxIAIBQRJMHvYli7DQAQggiS4HdRrN0GAAhBIRMkPfHEE3LBBRdI5cqVJSEhwafnWJYlDz/8sNSpU0diY2Ole/fu8ttvv/m9rfDEPEkAgFAUMkFSdna2DBgwQEaNGuXzc5555hl55ZVXZNKkSbJ06VKpUqWK9OzZU44fP+7XtsJTFKPbAAAhKFJCxPjx483tlClTfM4ivfTSS/Lggw9K3759zb5p06ZJUlKSzJw5U66//nq/thf/QyYJABCKQiaTVFqbN2+W3bt3my42l/j4eDnvvPNk8eLFRT4vKytLMjIyPDacHqYAAACEItsGSRogKc0cFaT3XY95M2HCBBNMubaUlBS/t9UxmSQmkwQAhJCABkljxoyRsLCwYrf169dXaJvGjh0r6enp7m379u0V+vq2nnGbtdsAACEkoDVJd999twwZMqTYY1JTU8v0vZOTk83tnj17zOg2F71/1llnFfm8mJgYs6H8MJkkACAUBTRIqlWrltn8oXHjxiZQmjdvnjso0voiHeVWmhFyKM+12/I4nQCAkBEyNUnbtm2TVatWmdu8vDzztW5HjhxxH9OiRQv59NNPzdfaVXfHHXfI448/Lp9//rn8/PPPMmjQIKlbt67069cvgO/EyYXbVqCbAgCAfzJJOqxea3Rq164tlSpVkoqkk0JOnTrVfb99+/bmdsGCBdKlSxfz9YYNG0wdkct9990nR48elREjRkhaWppcdNFFMnv27Apvu9NRuA0ACEVhlkY+PsrPzzcBxtq1a6VZs2biBNpFp6PcNPiKi4sLdHNC0q97DkuPF7+WmlWiZcVDlwW6OQAAB8goh8/vUnW3hYeHm+DowIEDZXoxOBOF2wAAR9QkPfXUU3LvvffKmjVr/NMi2LYmKSsvP9BNAQDAf6PbtPg5MzNT2rVrJ9HR0Wbh2IIOHjxY2m8Jh9Qk5eTlm7o2LaoHAMB2QZKuhwaUJUjS6rfcfEuiIgiSAAA2DJIGDx7sn5bA9t1trmySq0YJAADbTiZ5/Phxyc7O9tjHCDCcqmDmSNdvqxzNOQIABL9S/5de5x267bbbzFxJVapUkerVq3tswKkiI8Il/GSclE3xNgDArkGSTtA4f/58mThxolnj7J///KeMHz/ezGQ9bdo0/7QSIY9pAAAAtu9umzVrlgmGdJbroUOHysUXXyxNmzaVhg0byvTp0+XGG2/0T0sR8nVJWbn5prsNAABbZpJ0iH9qaqq7/sg15F+X/Pj666/Lv4WwhRjWbwMA2D1I0gBp8+bN7gVlZ8yY4c4wJSQklH8LYQt0twEAbB8kaRfbTz/9ZL4eM2aMvP7662Y9tzvvvNPMxA0UNw0AhdsAANvWJGkw5NK9e3dZv369rFixwtQltW3btrzbB5sgkwQAsGUmqUaNGrJ//37z9c033yyHDx92P6YF29dccw0BEnyadZtMEgDAVkGSThiZkZFhvp46daqZRBIoS3dbDqPbAAB26m7r1KmT9OvXTzp06GAWKP3rX/9aaGFbl3feeae82wgbIJMEALBlkPTee+/Jiy++KJs2bTIruKenp5NNQtkyScy4DQCwU5CUlJQkTz31lPm6cePG8u6770rNmjX93TbYcP02nVASAABbjm5zzZEElAaZJACA7edJAsqCKQAAAKGGIAkVO5kk3W0AgBBBkIQKXruNmiQ
AQGggSEKFoLsNAGD7wm2Vl5cnM2fOlHXr1pn7rVq1kquuukoiIiLKu32w3TxJVqCbAgCAf4KkjRs3Su/eveWPP/6Q5s2bm30TJkyQlJQU+fLLL6VJkyal/ZZwgChqkgAAdu9u09m2U1NTZfv27bJy5Uqzbdu2zcyfpI8BxWeS8jhBAAB7ZpIWLVokS5YsMYveuujEkjrZ5IUXXlje7YPt1m6juw0AYNNMUkxMjBw+fLjQ/iNHjkh0dHR5tQs2w9ptAADbB0l9+vSRESNGyNKlS81it7ppZmnkyJGmeBsodp4kpgAAANg1SHrllVdMcXanTp2kUqVKZtNutqZNm8rLL7/sn1Yi5DEFAADA9jVJCQkJ8tlnn5lRbq4pAFq2bGmCJKAozLgNAHDEPElKgyLddM6kn3/+WQ4dOiTVq1cv39bBNqIiwswtM24DAGzb3XbHHXfI22+/bb7WAOmSSy6Rs88+28yTtHDhQn+0ETZaloS12wAAtg2S/vWvf0m7du3M17NmzZLff/9d1q9fL3feeaf8/e9/90cbYacpACjcBgDYNUjav3+/JCcnm6//85//yLXXXitnnHGG3HzzzabbDSiucDsrlwVuAQA2DZKSkpLkl19+MV1ts2fPlssuu8zsz8zMZO02lDhPEpkkAIBtC7eHDh1qskd16tSRsLAw6d69u9mv8ya1aNHCH22EndZuo7sNAGDXIOmRRx6R1q1bm7XbBgwYYGbgVhERETJmzBh/tBF2mnGb7jYAgJ2nAPjTn/5UaN/gwYPLoz2w+ei2nDzWbgMA2LQmCSgLZtwGAIQagiRUCNZuAwCEGoIkVHgmSRdFBgAg2BEkoUIzSYq6JACALYMkXYZk2rRpcuzYMf+0CLYu3FbMlQQAsGWQ1L59e7nnnnvMrNvDhw+XJUuW+KdlsGV3m2IaAACALYOkl156SXbu3CmTJ0+WvXv3SufOneXMM8+U5557Tvbs2eOfViLkRYSHmU2RSQIA2LYmKTIyUq655hr57LPP5I8//pCBAwfKQw89JCkpKdKvXz+ZP39++bcUIS8q4kSQxPptAADbF24vW7ZMxo0bJ88//7zUrl1bxo4dK4mJidKnTx/TJQcUxPptAABbz7itXWzvvvuu6W777bff5Morr5QPPvhAevbsadZyU0OGDJFevXqZLjjAhbmSAAC2DpLq168vTZo0kZtvvtkEQ7Vq1Sp0TNu2beWcc84przbCJli/DQBg6yBp3rx5cvHFFxd7TFxcnCxYsOB02gUbZ5Io3AYA2LImqaQACShpGgAKtwEAoYAZtxGATBLLkgAAgh9BEgKyfhsAAMGOIAkVP7qNIAkAYMcgSddty8rKKrQ/OzvbPOYvTzzxhFxwwQVSuXJlSUhI8Ok5OvpOpyUouOnUBAjs+m0UbgMAbBkkDR06VNLT0wvtP3z4sHnMXzQIGzBggIwaNapUz9OgaNeuXe5N53RCYNDdBgCw9RQAlmW5J40sSJcniY+PF38ZP368uZ0yZUqpnhcTE2MW40UQzZOUR00SAMBGQVL79u3dXVbdunUz67e55OXlyebNm4OyK2vhwoVmyZTq1atL165d5fHHH5eaNWsWebx2JRbsTszIyKigltpfFDVJAAA7Bkm6cK1atWqVWYKkatWq7seio6OlUaNG0r9/fwkmGrTpQryNGzeWTZs2yQMPPCCXX365LF68WCIiIrw+Z8KECe6sFcoXmSQAgC2DJF3IVmkwdP3115turNM1ZswYefrpp4s9Zt26ddKiRYsyfX9tp0ubNm3Mcim6pIpmlzQb5o0u0nvXXXd5ZJJSUlLK9PrwFB15ops2h9FtAAA71iRpl9W+ffvMGm5q2bJl8v7778uZZ54pI0aMKNX3uvvuu80ItOKkpqaWtonFfq/ExETZuHFjkUGSBn/lEQCiMDJJAABbB0kDBw40wdBNN90ku3fvlu7du0vr1q1l+vTp5v7DDz/s8/fSxXG9LZDrL1pcfuDAAalTp06
FvSa8zJNE4TYAwI5TAKxZs0bOPfdc8/WMGTNMN9b3339vgqTSjjwrjW3btpl6KL3VQnH9WrcjR464j9FuuU8//dR8rfvvvfdeWbJkiWzZssUszNu3b19p2rSpqalCxWMKAACArTNJOTk57u6ouXPnylVXXeUOUHQeIn/RDNXUqVM9RtupBQsWSJcuXczXGzZscM/hpIXZq1evNs9JS0uTunXrSo8ePeSxxx6jOy3ga7cxBQAAwIZBUqtWrWTSpEnSu3dvmTNnjgk61M6dO4sdWn+6NEtVUqZK53ByiY2Nla+++spv7UHpkUkCANi6u01Ho7355psme3PDDTdIu3btzP7PP//c3Q0HFLcsCWu3AQBsmUnS4Gj//v1maLxO0Oiixdy6rhpQcnfb/zJ+AADYJpPk6tZasWKFySjpmm2uCSUJkuBLd1sW8yQBAOyYSdq6dauZyVpHmenyHZdddplUq1bNdMPpfa1XAoqbJ4nCbQCALTNJf/vb36Rjx45y6NAhUxztcvXVV5th9kBRWLsNAGDrTNI333xj5kXS7rWCdLmSHTt2lGfbYDPMuA0AsHUmKT8/30zm6G02a+12A0oa3UZ3GwDAlkGSTsj40ksvue+HhYWZ2a11AdwrrriivNsHGxZuHz6eK7/szPDY0jKzA908AAA8hFkFZ2D0gWaMdFkPfdpvv/1m6pP0VheO/frrr6V27dpiJzrVQXx8vJnJOy4uLtDNCWnLtxyUAZMWe30sNipCvr3/UqlZlcWFAQDB8fld6pqk+vXry08//SQfffSRudUs0rBhw+TGG2/0KOQGTtW6bryc26iGbDlw1GP/gaPZciwnTzbtO0qQBAAI3UyS05BJ8r9+r38nq7anyVs3dZAerZIr4BUBAHaXEYhM0oEDB9xrtG3fvl3+8Y9/yLFjx+TKK6+Uzp07l6kRcLb42Chzm34sJ9BNAQCg9IXbP//8sxnmrzVHLVq0kFWrVsk555wjL774orz11lvStWtXmTlzpq/fDnBLqEyQBAAI4SDpvvvukzZt2pjibF2/rU+fPtK7d2+TxtKJJW+55RZ56qmn/Nta2BKZJABAMPK5u2358uUyf/58adu2rbRr185kj2699VYJDz8RZ91+++1y/vnn+7OtsKmEk91taZl0twEAQjCTdPDgQUlOPlFUW7VqValSpYpUr17d/bh+7VrsFiiNOGqSAAChPpmkThxZ3H2gLBIqn1jiJo3CbQBAECnV6LYhQ4ZITMyJyf6OHz8uI0eONBkllZWV5Z8WwjHdbenMug0ACMUgafDgwR73//znPxc6ZtCgQeXTKjhKPKPbAAChHCRNnjzZvy2BY7kLt+luAwCE8gK3gL+mAMg4liP5+UwADwAIDgRJCJrRbRofHc7KDXRzAAAwCJIQcJWiIqRSVLg7mwQAQDAgSEJQSIg9OQ0AE0oCAIIEQRKCAkuTAACCDUESgmoagLRj2YFuCgAABkESggKZJABAsCFIQlBgkVsAQLAhSELQzZUEAEAwIEhCUEhw1SQxug0AECQIkhAU4iufnAKAwm0AQJAgSEJQoHAbABBsCJIQFCjcBgAEG4IkBAUKtwEAwYYgCcFVuM3oNgBAkCBIQlBlkjKz8yQ7Nz/QzQEAgCAJwaFapSgJCzvxdTrZJABAECCThKAQER4m1WIizdcESQCAYECQhKCRcHKupHTmSgIABAGCJAQN5koCAAQTgiQEDZYmAQAEE4IkBI24kyPcqEkCAAQDgiQEDWbdBgAEE4IkBA1qkgAAwYQgCUFXk0R3GwAgGBAkIWgkxJ6YAiAtMzvQTQEAQE7M3gcEUeH2zrTjsnzLQdEJuKtWipS4SlFSOTpCwsweL4rYbR4q/VMkrIgnFf+cYh4r4pnFPae0rxMZHm4m5AQAlB+CJARdd9uGPYdlwKTFgW5OyImJDJeqMZESFVFygtiXAM2XkKuogLK0fGqPT20OK/F79G1XV+7q0bwUrQPgVARJCBpnpSRIl+a
1ZOuBTHM/L9+So1m5knE8R3LyrEA3L+hl5eZLVi5dlSV5+9vNBEkAfEKQhKBRKSpCpgw9t9B+y7JMwORNUaGTVcQDVhHPKOp48fP3L7r9pXu/Obn5kpmdJ0ezcyW3HAJKX85HUe+19N/Hl/b48FolPH7oaLYMm/qDCSYBwBcESQh62qUTGUG9TUlqVshPI3S5BgTk5luSm5cvkT50SwJwNv5KAHCEmMgI99fZeWSTAJSMIAmAI0RH/u/PXVYOQRKAkhEkAXAEnSIh8uQ0CdQlAfAFQRIAR02ToLJy8wLdFAAhICSCpC1btsiwYcOkcePGEhsbK02aNJFx48ZJdnbxw52PHz8uo0ePlpo1a0rVqlWlf//+smfPngprN4DgEhN1oi6JTBIA2wRJ69evl/z8fHnzzTdl7dq18uKLL8qkSZPkgQceKPZ5d955p8yaNUs+/vhjWbRokezcuVOuueaaCms3gCDNJFGTBMAuUwD06tXLbC6pqamyYcMGmThxojz33HNen5Oeni5vv/22vP/++9K1a1ezb/LkydKyZUtZsmSJnH/++RXWfgDBge42ALbLJBUVBNWoUaPIx1esWCE5OTnSvXt3974WLVpIgwYNZPHiope8yMrKkoyMDI8NgL2mAaC7DYBtg6SNGzfKq6++KrfcckuRx+zevVuio6MlISHBY39SUpJ5rCgTJkyQ+Ph495aSklKubQcQODFRFG4DCJEgacyYMWY25eI2rUcqaMeOHabrbcCAATJ8+PByb9PYsWNNlsq1bd++vdxfA0BgUJMEIGRqku6++24ZMmRIscdo/ZGLFl5feumlcsEFF8hbb71V7POSk5PN6Le0tDSPbJKObtPHihITE2M2APbtbmPGbQBBHyTVqlXLbL7QDJIGSB06dDAF2OHhxSfB9LioqCiZN2+eGfqvtNh727Zt0qlTp3JpP4DQQiYJgO1qkjRA6tKliym61tFs+/btM3VFBWuL9BgtzF62bJm5r/VEOrfSXXfdJQsWLDCF3EOHDjUBEiPbAGeiJgmA7aYAmDNnjinW1q1+/foej1mWZW51JJtmijIzM92P6XxKmnHSTJKOWuvZs6e88cYbFd5+AMGB0W0ASiPMckUZ8EqnANCslBZxx8XFcZaAEDbm36vlw+Xb5d6ezWX0pU0D3RwAQf75HRLdbQBQvjVJrN0GoGQESQAcI9q9wG1+oJsCIAQQJAFwDGqSAJQGQRIAx2DtNgClQZAEwHlTAOTQ3QagZARJAByD7jYApUGQBMAx6G4DUBoESQAcOOM23W0ASkaQBMB53W3UJAHwAUESAMeguw1AaRAkAXAMCrcBlAZBEgDH1SRlU5MEwAcESQAc2N1G4TaAkhEkAXBgdxsL3AIoGUESAOdlkhjdBsAHBEkAHIN5kgCUBkESAMeIjjhZuJ2XL/n5VqCbAyDIESQBcIyYqBM1Sa5ACQCKQ5AEwHE1SYq6JAAlIUgC4BiR4WESHnbia0a4ASgJQRIAxwgLC2PWbQA+I0gC4NARbsyVBKB4BEkAHFmXdJy5kgCUgCAJgKOwyC0AXxEkAXDo+m10twEoHkESAEfWJGWzyC2AEhAkAXAUutsA+IogCYBDu9uYcRtA8QiSADgzSMqhJglA8QiSADgK3W0AfEWQBMChk0nS3QageARJABwlOoIpAAD4hiAJgDMzScy4DaAEBEkAHIWaJAC+IkgC4CjMuA3AVwRJAByFTBIAXxEkAXAUapIA+IogCYCj0N0GwFcESQAche42AL4iSALgKKzdBsBXBEkAHFmTlJ3L2m0AikeQBMBR6G4D4CuCJADO7G5jxm0AJSBIAuAojG4D4CuCJACOEhMVYW6zcvMD3RQAQY4gCYCjMLoNgK8IkgA4SrS7JonRbQCKR5AEwFHIJAHwFUESAMdOAWBZVqCbAyCIESQBcORkkio7j+JtAEUjSALgyO42xQg3AMUhSALgKNERBYIkJpQEUAyCJACOEhYWxoSSAHxCkAT
AcRjhBsAXBEkAnDvrNt1tAEI9SNqyZYsMGzZMGjduLLGxsdKkSRMZN26cZGdnF/u8Ll26mNR6wW3kyJEV1m4AwZ1JYnQbgOJESghYv3695Ofny5tvvilNmzaVNWvWyPDhw+Xo0aPy3HPPFftcPe7RRx91369cuXIFtBhASHS3Mes2gFAPknr16mU2l9TUVNmwYYNMnDixxCBJg6Lk5OQKaCWAUJxQEgBCurvNm/T0dKlRo0aJx02fPl0SExOldevWMnbsWMnMzKyQ9gEI/gklCZIAhHwm6VQbN26UV199tcQs0sCBA6Vhw4ZSt25dWb16tdx///0mA/XJJ58U+ZysrCyzuWRkZJRr2wEE0+g2FrkFEKRB0pgxY+Tpp58u9ph169ZJixYt3Pd37Nhhut4GDBhg6o2KM2LECPfXbdq0kTp16ki3bt1k06ZNpvjbmwkTJsj48eNL/V4AhGB3G6PbABQjzArgCo/79u2TAwcOFHuM1h9FR0ebr3fu3GlGrJ1//vkyZcoUCQ8vXW+hFnpXrVpVZs+eLT179vQ5k5SSkmK69+Li4kr1egCC0/BpP8icX/bIk1e3kYHnNQh0cwD4gX5+x8fHn9bnd0AzSbVq1TKbLzSDdOmll0qHDh1k8uTJpQ6Q1KpVq8ytZpSKEhMTYzYA9kV3GwDbFG5rgKQZpAYNGpg6JM1A7d6922wFj9FuuWXLlpn72qX22GOPyYoVK8w8S59//rkMGjRIOnfuLG3btg3guwEQaIxuA2Cbwu05c+aYYm3d6tev7/GYq7cwJyfHFGW7Rq9pF93cuXPlpZdeMt1s2mXWv39/efDBBwPyHgAE4eg2apIAhHqQNGTIELMVp1GjRu6ASWlQtGjRogpoHYBQQ3cbANt0twFAeaK7DYBtMkkA4I9M0uc/7ZRV29MkNipCmtSqIk2TqklilROjaQuqFB0hLZPjJCkuxqwBCcAZCJIAOE5KjRNrOO47nGU29e3G/SU+L7FqtNSqVkk0TIqODJchFzSSfu3r+b29AAKDIAmA41zdvp7US4iVjOM55n7GsRzZuO+IbNxzRA4fzy10/KHMbNm074jsP5JtNpc7PlolK7cdkgd7n2mCJgD2QpAEwHEiwsOkU5OapXrO8Zw8Wb/7sAmo1NLNB+T1BZtk2uKtsm5Xhrw95ByJqxTlpxYDcNyM206ZsROAPc1bt0fu+HCVHM7KlS7Na8nbg88xARgAe3x+kx8GgDLq1jJJPhhxvlSKCpeFG/bJM1+t51wCNkJ3GwCchtb14uXZP7WT2z/4Ud5c9LvZl1StkkRFhEmX5rXdReIAQg9BEgCcpivb1TV1SW8s3OQOlFRE+C/St11dualTQ0moHC2R4WGmO861+doxFxsdIZWj+XMNVDRqkkpATRIAX+TnW/L2t5vl5x3p5v7ujOOybPPBcjl5GlBd1DTRBGNt68f7HFwVpXymegoLeBtO/zycfiPK41SebjN8D7f914bTVb1KtFSNiQy6z2+CpAo4yQCcafUfaTJx4SZZ8vsByc2zJM+yJDffkryTG4ATnry6jQw8r4EE2+c3+VsA8JO29RNk4p87eH2sNAOLN+8/Kl+s3iX/+XmX7D05+WVZnM5g5tMJ6U5nDHWg2ny63+D0zlfo/ZyUdRqvHhGkw8jIJJWATBIAAKGHKQAAAAD8JEgTXAAAAIFFkAQAAOAFQRIAAIAXBEkAAABeECQBAAB4QZAEAADgBUESAACAFwRJAAAAXhAkAQAAeEGQBAAA4AVBEgAAgBcESQAAAF4QJAEAAHgR6W0n/seyLHObkZHBaQEAIES4Prddn+NlQZBUgsOHD5vblJSUMp9kAAAQuM/x+Pj4Mj03zDqdEMsB8vPzZefOnVKtWjUJCwsr1whXA6/t27dLXFycOBnngnPBdcHvB38r+LtZ3p8fGt5ogFS3bl0JDy9bdRGZpBLoia1fv774i/5QnR4kuXAuOBdcF/x+8LeCv5vl+flR1gySC4XbAAA
AXhAkAQAAeEGQFCAxMTEybtw4c+t0nAvOBdcFvx/8reDvZjB+flC4DQAA4AWZJAAAAC8IkgAAALwgSAIAAPCCIAkAAMALgqQAef3116VRo0ZSqVIlOe+882TZsmViZxMmTJBzzjnHzFxeu3Zt6devn2zYsMHjmC5duphZzQtuI0eOFLt55JFHCr3PFi1auB8/fvy4jB49WmrWrClVq1aV/v37y549e8SO9Hfg1HOhm75/u18TX3/9tVx55ZVmNmB9XzNnziw0W/DDDz8sderUkdjYWOnevbv89ttvHsccPHhQbrzxRjOJXkJCggwbNkyOHDkidjkPOTk5cv/990ubNm2kSpUq5phBgwaZVRBKuo6eeuopsds1MWTIkELvs1evXra7Jnw5F97+buj27LPPlut1QZAUAB999JHcddddZtjiypUrpV27dtKzZ0/Zu3ev2NWiRYvMB9+SJUtkzpw55o9fjx495OjRox7HDR8+XHbt2uXennnmGbGjVq1aebzPb7/91v3YnXfeKbNmzZKPP/7YnDf9QLjmmmvEjpYvX+5xHvTaUAMGDLD9NaHXvv7u63+YvNH3+corr8ikSZNk6dKlJkjQvxMaRLvoh+HatWvNefviiy/MB8uIESPELuchMzPT/I186KGHzO0nn3xi/nN11VVXFTr20Ucf9bhObr/9drHbNaE0KCr4Pj/44AOPx+1wTfhyLgqeA93eeecdEwTpfyrL9brQtdtQsc4991xr9OjR7vt5eXlW3bp1rQkTJjjmR7F3715dM9BatGiRe98ll1xi/e1vf7Psbty4cVa7du28PpaWlmZFRUVZH3/8sXvfunXrzLlavHixZXf682/SpImVn5/vqGtCf76ffvqp+76+/+TkZOvZZ5/1uDZiYmKsDz74wNz/5ZdfzPOWL1/uPua///2vFRYWZu3YscOyw3nwZtmyZea4rVu3uvc1bNjQevHFFy078XYuBg8ebPXt27fI59jxmvD1utDz0rVrV4995XFdkEmqYNnZ2bJixQqTOi+4PpzeX7x4sThFenq6ua1Ro4bH/unTp0tiYqK0bt1axo4da/4naUfabaJp5NTUVPM/v23btpn9em1olq3g9aFdcQ0aNLD99aG/G++9957cfPPNHotJO+WaKGjz5s2ye/duj+tA16DSrnnXdaC32p3SsWNH9zF6vP490cyTnf926PWh770g7UbRLur27dubLpfc3Fyxo4ULF5qShebNm8uoUaPkwIED7secek3s2bNHvvzyS9O1eKrTvS5Y4LaC7d+/X/Ly8iQpKcljv95fv369OEF+fr7ccccdcuGFF5oPPpeBAwdKw4YNTfCwevVqU4ugqXVNsduJftBNmTLF/JHT9O/48ePl4osvljVr1pgPxujo6EIfAHp96GN2pjUHaWlppu7CadfEqVw/a29/J1yP6a1+WBYUGRlp/uNh12tFuxr1Grjhhhs8FjP961//KmeffbZ5799//70JpvV364UXXhA70a427Xpv3LixbNq0SR544AG5/PLLTXAUERHhyGtCTZ061dS7nlqWUB7XBUESKpzWJmlAULAORxXsN9dCTS1Y7datm/lj0KRJE9v8pPSPmkvbtm1N0KSBwIwZM0yBrlO9/fbb5txoQOS0awIl0wzrtddeawraJ06c6PGY1ngW/J3S/2jccsstZsCInZZ+uv766z1+H/S96u+BZpf098Kp3nnnHZOR14FQ5X1d0N1WwbTbQCP+U0cr6f3k5GSxu9tuu80UEy5YsEDq169f7LEaPKiNGzeKnWnW6IwzzjDvU68B7XbSjIqTro+tW7fK3Llz5S9/+UuxxznlmnD9rIv7O6G3pw720K4EHd1kt2vFFSDpdaIFyQWzSEVdJ3outmzZInam3fX6meL6fXDSNeHyzTffmOxySX87ynpdECRVMI1kO3ToIPPmzfPoftL7nTp1ErvS//1pgPTpp5/K/PnzTbq4JKtWrTK3mj2wMx2eq5kRfZ9
6bURFRXlcH/oHQGuW7Hx9TJ482XQT9O7du9jjnHJN6O+HfqgVvA4yMjJMXYnrOtBbDaa1js1Ff7f074krmLRTgKR1fBpIa31JSfQ60TqcU7ue7OaPP/4wNUmu3wenXBOnZqD176aOhPPLdXFaZd8okw8//NCMUpkyZYoZjTBixAgrISHB2r17t23P6KhRo6z4+Hhr4cKF1q5du9xbZmameXzjxo3Wo48+av3www/W5s2brc8++8xKTU21OnfubNnN3Xffbc6Dvs/vvvvO6t69u5WYmGhG/KmRI0daDRo0sObPn2/OR6dOncxmVzq6U9/v/fff77Hf7tfE4cOHrR9//NFs+qf4hRdeMF+7Rm099dRT5u+Cvu/Vq1eb0TuNGze2jh075v4evXr1stq3b28tXbrU+vbbb61mzZpZN9xwg2WX85CdnW1dddVVVv369a1Vq1Z5/O3Iysoyz//+++/NCCZ9fNOmTdZ7771n1apVyxo0aJAVaoo7F/rYPffcY0a56u/D3LlzrbPPPtv8zI8fP26ra8KX3w+Vnp5uVa5c2Zo4caJ1qvK6LgiSAuTVV181HwzR0dFmSoAlS5ZYdqYXubdt8uTJ5vFt27aZD78aNWqYALJp06bWvffea34J7Oa6666z6tSpY3729erVM/c1IHDRD8Fbb73Vql69uvkDcPXVV5sPBbv66quvzLWwYcMGj/12vyYWLFjg9XdCh3m7pgF46KGHrKSkJPP+u3XrVugcHThwwHwAVq1a1YqLi7OGDh1qPlzsch40GCjqb4c+T61YscI677zzzH/CKlWqZLVs2dJ68sknPQIHO5wL/Q9ljx49zAe9ThOiw9uHDx9e6D/XdrgmfPn9UG+++aYVGxtrpsc4VXldF2H6j+95JwAAAGegJgkAAMALgiQAAAAvCJIAAAC8IEgCAADwgiAJAADAC4IkAAAALwiSAAAAvCBIAoASNGrUSF566SXOE+AwBEkAgsqQIUOkX79+5usuXbrIHXfcUWGvPWXKFLPg8KmWL18uI0aMqLB2AAgOkYFuAAD4W3Z2tllcuqxq1apVru0BEBrIJAEI2ozSokWL5OWXX5awsDCzbdmyxTy2Zs0aufzyy6Vq1aqSlJQkN910k+zfv9/9XM1A3XbbbSYLlZiYKD179jT7X3jhBWnTpo1UqVJFUlJS5NZbb5UjR46YxxYuXChDhw6V9PR09+s98sgjXrvbtm3bJn379jWvHxcXZ1ap37Nnj/txfd5ZZ50l7777rnlufHy8XH/99XL48OEKO38ATh9BEoCgpMFRp06dZPjw4bJr1y6zaWCTlpYmXbt2lfbt28sPP/wgs2fPNgGKBioFTZ061WSPvvvuO5k0aZLZFx4eLq+88oqsXbvWPD5//ny57777zGMXXHCBCYQ06HG93j333FOoXfn5+SZAOnjwoAni5syZI7///rtcd911Hsdt2rRJZs6cKV988YXZ9NinnnrKr+cMQPmiuw1AUNLsiwY5lStXluTkZPf+1157zQRITz75pHvfO++8YwKoX3/9Vc444wyzr1mzZvLMM894fM+C9U2a4Xn88cdl5MiR8sYbb5jX0tfUDFLB1zvVvHnz5Oeff5bNmzeb11TTpk2TVq1amdqlc845xx1MaY1TtWrVzH3Ndulzn3jiiXI7RwD8i0wSgJDy008/yYIFC0xXl2tr0aKFO3vj0qFDh0LPnTt3rnTr1k3q1atnghcNXA4cOCCZmZk+v/66detMcOQKkNSZZ55pCr71sYJBmCtAUnXq1JG9e/eW6T0DCAwySQBCitYQXXnllfL0008XekwDERetOypI65n69Okjo0aNMtmcGjVqyLfffivDhg0zhd2asSpPUVFRHvc1Q6XZJQChgyAJQNDSLrC8vDyPfWeffbb8+9//NpmayEjf/4StWLHCBCnPP/+8qU1SM2bMKPH1TtWyZUvZvn272VzZpF9++cXUSmlGCYB90N0GIGhpILR06VKTBdLRaxr
kjB492hRN33DDDaYGSLvYvvrqKzMyrbgAp2nTppKTkyOvvvqqKbTWkWeugu6Cr6eZKq0d0tfz1g3XvXt3M0LuxhtvlJUrV8qyZctk0KBBcskll0jHjh39ch4ABAZBEoCgpaPLIiIiTIZG5yrSofd169Y1I9Y0IOrRo4cJWLQgW2uCXBkib9q1a2emANBuutatW8v06dNlwoQJHsfoCDct5NaRavp6pxZ+u7rNPvvsM6levbp07tzZBE2pqany0Ucf+eUcAAicMMuyrAC+PgAAQFAikwQAAOAFQRIAAIAXBEkAAABeECQBAAB4QZAEAADgBUESAACAFwRJAAAAXhAkAQAAeEGQBAAA4AVBEgAAgBcESQAAAF4QJAEAAEhh/w8GjTJFVr4w6AAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(best_y_per_iteration.values)\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"Best y so far\")\n", + "plt.title(\"Best y per iteration\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Iteration 0 ---\n", + " n_experiments: 66\n", + " acqf probe at fixed point: -45.9420604508\n", + " acq_val from get_candidate: -5.48109119224731\n", + " new y value: [-2.97790422]\n", + " experiments: 66 -> 67\n", + " lengthscale[0:3]: [0.2937835994912956, 0.4041666833510878, 4.41966538654436]\n", + "\n", + "--- Iteration 1 ---\n", + " n_experiments: 67\n", + " acqf probe at fixed point: -45.9419739465\n", + " acq_val from get_candidate: -5.645335943157333\n", + " new y value: [-2.99606393]\n", + " experiments: 67 -> 68\n", + " lengthscale[0:3]: [0.30036564796021076, 0.3919503617733176, 4.339199268257635]\n", + "\n", + "--- Iteration 2 ---\n", + " n_experiments: 68\n", + " acqf probe at fixed point: -45.9430990850\n", + " acq_val from get_candidate: -7.4356479767381565\n", + " new y value: [-0.00015728]\n", + " experiments: 68 -> 69\n", + " lengthscale[0:3]: [0.3312261463319473, 0.39313917655484865, 3.5347083776042476]\n", + "\n", + "--- Iteration 3 ---\n", + " n_experiments: 69\n", + " acqf probe at fixed point: -45.9446675469\n", + " acq_val from get_candidate: -6.929772202048997\n", + " new y value: [-0.00067149]\n", + " experiments: 69 -> 70\n", + " lengthscale[0:3]: [0.29295722403909413, 0.4010253404562717, 4.663433093634748]\n", + "\n", + "--- Iteration 4 ---\n", + " n_experiments: 70\n", + " acqf probe at fixed point: -45.9402736516\n", + " acq_val from get_candidate: -7.836574244144488\n", + " new y value: [-0.0006946]\n", + " experiments: 70 -> 71\n", + " lengthscale[0:3]: 
[0.30060401163820644, 0.3944575465842801, 4.207188642906171]\n", + "\n", + "--- Iteration 5 ---\n", + " n_experiments: 71\n", + " acqf probe at fixed point: -45.9433603961\n", + " acq_val from get_candidate: -2.6929366641172288\n", + " new y value: [-3.11049333]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-08 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-07 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-06 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-05 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-04 to the diagonal\n", + " warnings.warn(\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-03 to the diagonal\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " experiments: 71 -> 72\n", + " lengthscale[0:3]: [0.30018737965437287, 0.4054900134182153, 13.665555661586506]\n", + "\n", + "--- Iteration 6 ---\n", + " n_experiments: 72\n", + " acqf probe at fixed point: -46.0112759163\n", + " acq_val 
from get_candidate: -6.4480336673718055\n", + " new y value: [-0.00106805]\n", + " experiments: 72 -> 73\n", + " lengthscale[0:3]: [0.2953749217570681, 0.4102799011484572, 13.598938811874865]\n", + "\n", + "--- Iteration 7 ---\n", + " n_experiments: 73\n", + " acqf probe at fixed point: -46.0108940378\n", + " acq_val from get_candidate: -39.15175009566557\n", + " new y value: [-3.00207657]\n", + " experiments: 73 -> 74\n", + " lengthscale[0:3]: [0.2970526387781363, 1.0787512449468613, 3.351634232309813]\n", + "\n", + "--- Iteration 8 ---\n", + " n_experiments: 74\n", + " acqf probe at fixed point: -46.0192865010\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed on the second try, after generating a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n", + "/Users/j30607/sandbox/bofire-worktrees/feature-mcts/.venv/lib/python3.13/site-packages/botorch/optim/optimize.py:795: RuntimeWarning: Optimization failed in `gen_candidates_scipy` with the following warning(s):\n", + "[OptimizationWarning('Optimization failed within `scipy.optimize.minimize` with status 2 and message ABNORMAL: .')]\n", + "Trying again with a new set of initial conditions.\n", + " return _optimize_acqf_batch(opt_inputs=opt_inputs)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " acq_val from 
get_candidate: -12.164814561040743\n", + " new y value: [-0.01646253]\n", + " experiments: 74 -> 75\n", + " lengthscale[0:3]: [0.39385986766547937, 0.3448380489990462, 8.83546958707795]\n", + "\n", + "--- Iteration 9 ---\n", + " n_experiments: 75\n", + " acqf probe at fixed point: -46.0094130418\n", + " acq_val from get_candidate: -5.745335865515177\n", + " new y value: [-0.17019157]\n", + " experiments: 75 -> 76\n", + " lengthscale[0:3]: [0.3044109295808983, 0.3922305605059308, 14.09836815838481]\n", + "\n", + "--- Iteration 10 ---\n", + " n_experiments: 76\n", + " acqf probe at fixed point: -46.0100629356\n", + " acq_val from get_candidate: -45.89948778074979\n", + " new y value: [-0.16469201]\n", + " experiments: 76 -> 77\n", + " lengthscale[0:3]: [0.3055344125405825, 0.40015436028875384, 15.101534607314072]\n", + "\n", + "--- Iteration 11 ---\n", + " n_experiments: 77\n", + " acqf probe at fixed point: -46.0129252589\n", + " acq_val from get_candidate: -45.86736603648621\n", + " new y value: [-0.3761452]\n", + " experiments: 77 -> 78\n", + " lengthscale[0:3]: [0.03089655749767049, 108.44260342490087, 7.957079269897791]\n", + "\n", + "--- Iteration 12 ---\n", + " n_experiments: 78\n", + " acqf probe at fixed point: -45.9379630881\n", + " acq_val from get_candidate: -0.7785376643322541\n", + " new y value: [-0.00878947]\n", + " experiments: 78 -> 79\n", + " lengthscale[0:3]: [0.3050036490876489, 0.39526463962698116, 10.838574498067157]\n", + "\n", + "--- Iteration 13 ---\n", + " n_experiments: 79\n", + " acqf probe at fixed point: -46.0129851173\n", + " acq_val from get_candidate: -45.75768827975949\n", + " new y value: [-0.37977354]\n", + " experiments: 79 -> 80\n", + " lengthscale[0:3]: [0.3006522276959701, 0.3935762903405451, 14.842249471453774]\n", + "\n", + "--- Iteration 14 ---\n", + " n_experiments: 80\n", + " acqf probe at fixed point: -46.0135545348\n", + " acq_val from get_candidate: -11.812887247140306\n", + " new y value: [-0.0033548]\n", + " 
experiments: 80 -> 81\n", + " lengthscale[0:3]: [0.2196389815319199, 0.5123909611274864, 7.623324726501774]\n", + "\n", + "--- Iteration 15 ---\n", + " n_experiments: 81\n", + " acqf probe at fixed point: -46.0121977271\n", + " acq_val from get_candidate: -7.900848122053444\n", + " new y value: [-2.46071964]\n", + " experiments: 81 -> 82\n", + " lengthscale[0:3]: [0.3160045940298298, 0.3706480418244254, 14.861831419103268]\n", + "\n", + "--- Iteration 16 ---\n", + " n_experiments: 82\n", + " acqf probe at fixed point: -46.0142045800\n", + " acq_val from get_candidate: -45.75633524187079\n", + " new y value: [-0.37986192]\n", + " experiments: 82 -> 83\n", + " lengthscale[0:3]: [0.31168116688795683, 0.37187935816620016, 14.805823721437385]\n", + "\n", + "--- Iteration 17 ---\n", + " n_experiments: 83\n", + " acqf probe at fixed point: -46.0145541933\n", + " acq_val from get_candidate: -45.7564354110594\n", + " new y value: [-0.37953567]\n", + " experiments: 83 -> 84\n", + " lengthscale[0:3]: [0.3077234584278307, 0.3712324781376418, 11.870060602968802]\n", + "\n", + "--- Iteration 18 ---\n", + " n_experiments: 84\n", + " acqf probe at fixed point: -46.0133323711\n", + " acq_val from get_candidate: -45.75837938393406\n", + " new y value: [-0.37769154]\n", + " experiments: 84 -> 85\n", + " lengthscale[0:3]: [0.17143206289072715, 0.545940037880173, 0.8365451921167117]\n", + "\n", + "--- Iteration 19 ---\n", + " n_experiments: 85\n", + " acqf probe at fixed point: -45.3993822367\n", + " acq_val from get_candidate: -0.33510238302079465\n", + " new y value: [-2.99120487]\n", + " experiments: 85 -> 86\n", + " lengthscale[0:3]: [0.31513088488426055, 0.38201061105153294, 18.104990428689263]\n", + "\n", + "--- Iteration 20 ---\n", + " n_experiments: 86\n", + " acqf probe at fixed point: -46.0112114395\n", + " acq_val from get_candidate: -45.75640901072899\n", + " new y value: [-0.38357929]\n", + " experiments: 86 -> 87\n", + " lengthscale[0:3]: [0.301989578174714, 
0.35224097931143306, 1.4549593860058594]\n", + "\n", + "--- Iteration 21 ---\n", + " n_experiments: 87\n", + " acqf probe at fixed point: -45.9529142233\n", + " acq_val from get_candidate: -2.9594742503756333\n", + " new y value: [-2.90055904]\n", + " experiments: 87 -> 88\n", + " lengthscale[0:3]: [0.12414469998203487, 0.4944761822780233, 0.9127515924007126]\n", + "\n", + "--- Iteration 22 ---\n", + " n_experiments: 88\n", + " acqf probe at fixed point: -9.2162054586\n", + " acq_val from get_candidate: -1.5305870841600848\n", + " new y value: [-2.9749772]\n", + " experiments: 88 -> 89\n", + " lengthscale[0:3]: [0.22503487233265027, 0.2691016257809672, 2.1812105898031273]\n", + "\n", + "--- Iteration 23 ---\n", + " n_experiments: 89\n", + " acqf probe at fixed point: -45.9539942651\n", + " acq_val from get_candidate: -5.513861895133451\n", + " new y value: [-0.00023194]\n", + " experiments: 89 -> 90\n", + " lengthscale[0:3]: [0.3195103170518832, 0.38685066238250077, 8.449608313147367]\n", + "\n", + "--- Iteration 24 ---\n", + " n_experiments: 90\n", + " acqf probe at fixed point: -46.0118833014\n", + " acq_val from get_candidate: -38.843144708007856\n", + " new y value: [-3.01303979]\n", + " experiments: 90 -> 91\n", + " lengthscale[0:3]: [0.4148201136989792, 0.2717726187241554, 1.2394691956390447]\n", + "\n", + "--- Iteration 25 ---\n", + " n_experiments: 91\n", + " acqf probe at fixed point: -45.9352343733\n", + " acq_val from get_candidate: -4.374878298424731\n", + " new y value: [-0.00096781]\n", + " experiments: 91 -> 92\n", + " lengthscale[0:3]: [0.32032115103464204, 0.39106494722973356, 13.63965179862593]\n", + "\n", + "--- Iteration 26 ---\n", + " n_experiments: 92\n", + " acqf probe at fixed point: -46.0111084633\n", + " acq_val from get_candidate: -45.75534578242555\n", + " new y value: [-0.38687496]\n", + " experiments: 92 -> 93\n", + " lengthscale[0:3]: [0.32200136414226704, 0.3856479669857732, 14.93704574089937]\n", + "\n", + "--- Iteration 27 ---\n", 
+ " n_experiments: 93\n", + " acqf probe at fixed point: -46.0109705722\n", + " acq_val from get_candidate: -45.752005521140276\n", + " new y value: [-0.38132764]\n", + " experiments: 93 -> 94\n", + " lengthscale[0:3]: [0.3171100781183023, 0.39110930535487876, 16.031086296578092]\n", + "\n", + "--- Iteration 28 ---\n", + " n_experiments: 94\n", + " acqf probe at fixed point: -46.0103654171\n", + " acq_val from get_candidate: -38.87649071044438\n", + " new y value: [-3.01310733]\n", + " experiments: 94 -> 95\n", + " lengthscale[0:3]: [0.32143656231313106, 0.3899184664723671, 20.657003633942974]\n", + "\n", + "--- Iteration 29 ---\n", + " n_experiments: 95\n", + " acqf probe at fixed point: -46.0091122417\n", + " acq_val from get_candidate: -45.752113163467726\n", + " new y value: [-0.3860681]\n", + " experiments: 95 -> 96\n", + " lengthscale[0:3]: [0.3920155869694801, 0.025147022994274456, 5.885963226058412]\n" + ] + } + ], + "source": [ + "for iteration in range(30):\n", + " acqf_check = strategy._get_acqfs(n=1)[0]\n", + "\n", + " # Probe: evaluate acqf at a fixed test point to see if model actually changed\n", + " test_point = torch.zeros(1, 1, bounds.shape[1], **tkwargs)\n", + " test_point[..., 0] = 0.5\n", + " probe_val = acqf_check(test_point).item()\n", + "\n", + " print(f\"\\n--- Iteration {iteration} ---\")\n", + " print(f\" n_experiments: {len(strategy.experiments)}\")\n", + " print(f\" acqf probe at fixed point: {probe_val:.10f}\")\n", + "\n", + " candidate, acq_val = get_candidate()\n", + " print(f\" acq_val from get_candidate: {acq_val}\")\n", + "\n", + " new_experiments = benchmark.f(candidate, return_complete=True)\n", + " print(f\" new y value: {new_experiments['y'].values}\")\n", + "\n", + " n_before = len(strategy.experiments)\n", + " strategy.tell(new_experiments, replace=False)\n", + " n_after = len(strategy.experiments)\n", + " print(f\" experiments: {n_before} -> {n_after}\")\n", + "\n", + " # Check if model params actually changed\n", + " ls = 
strategy.model.covar_module.lengthscale\n", + " print(f\" lengthscale[0:3]: {ls[0, :3].tolist()}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([-0.00568271, -0.09673502, -0.09673502, -0.09673502, -0.09673502,\n", + " -0.09713184, -0.61958705, -0.61958705, -0.61958705, -0.61958705,\n", + " -0.61958705, -0.61958705, -0.61958705, -0.61958705, -0.61958705,\n", + " -0.61958705, -0.61958705, -0.61958705, -1.00350188, -1.00350188,\n", + " -1.00350188, -1.00350188, -1.00350188, -1.00350188, -1.00350188,\n", + " -1.11410792, -1.11410792, -1.74733591, -1.74733591, -1.78129139,\n", + " -2.66609372, -2.70018492, -2.70018492, -2.70018492, -2.70018492,\n", + " -2.70018492, -2.70018492, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.94932638, -2.94932638, -2.94932638, -2.94932638, -2.94932638,\n", + " -2.95942063, -2.95942063, -2.95942063, -2.95942063, -2.95942063,\n", + " -2.95942063, -2.95942063, -2.96859675, -2.9747409 , -2.9747409 ,\n", + " -2.9747409 , -2.97790422, -2.99606393, -2.99606393, -2.99606393,\n", + " -2.99606393, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333, -3.11049333, -3.11049333, -3.11049333, -3.11049333,\n", + " -3.11049333])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "best_y_per_iteration = strategy.experiments[\"y\"].cummin()\n", + "best_y_per_iteration.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44", + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/plain": [ + "Text(0.5, 1.0, 'Best y per iteration')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAHHCAYAAABTMjf2AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQgdJREFUeJzt3Qd4VFX6x/E3vZICgYTQm4YOi4ggAgqCCiqChaIUWVkQ1gIuwt9dFV1FXF1RUBQLRUFFVxHQxUVAVHoRBOkK0jtJIIHU+T/vgRkTSEISZzJzZ76f57kmM3Mnc3ITmV/Oec85fjabzSYAAAA+yN/dDQAAAHAXghAAAPBZBCEAAOCzCEIAAMBnEYQAAIDPIggBAACfRRACAAA+iyAEAAB8FkEIAAD4LIIQABTTt99+K35+fuajVUybNs20ec+ePe5uCuCRCEKAF7C/2eU9KlWqJNdff73897//ddnrpqeny9NPP22pYOBss2bNkgkTJri7GfL888/LnDlz3N0MwHL82GsM8I4gNHDgQHnmmWekVq1aolsIHjlyxNz/888/y7x586Rbt25Of93jx49LxYoV5amnnjKByNvl5uZKZmamBAcHi7//+b8j9bpu3rzZ7T0ukZGRcuedd5qfeV45OTmSlZUlISEhJiADyC/wotsALOzmm2+Wq666ynF70KBBEh8fLx9++KFLgpA3SktLk4iIiAIf0/ATGhpaZoHLGa8VEBBgDgAFY2gM8GIxMTESFhYmgYGBl7zR6nBOw4YNzZuthqW//OUvcurUqXznrV27Vrp06SJxcXHm62hv0/33328e0x4Q7Q1SY8eOdQzJFdYz9Ouvv5rHX3nllUseW758uXlMA9vl6nM+/vhj+b//+z9JSEgwgeW2226Tffv2XXL+qlWr5KabbpLo6GgJDw+X9u3by7Jly/Kdo23Vr7llyxbp06ePxMbGStu2bS/bBvtQYIcOHeTLL7+U3377zfH916xZ03F+RkaG6S2rW7eu6ZGpVq2ajBo1ytyflz5v+PDhMnPmTPMz0XMXLFhgHnvppZekTZs2UqFCBfMzaNGihXz66aeXPF8D3PTp0x3tGDBgQJE1Qm+88YbjtRITE2XYsGGSnJyc7xz9/ho1amSujw6z6nWsUqWKvPjii4VeI8Bq6BECvEhKSooZrtKhsaNHj8rEiRPlzJkzcu+99+Y7T0OPfTjtoYcekt27d8ukSZPkxx9/NGEhKCjIPL9z584m7IwePdqEKn0z/eyzz8zX0PsnT54sQ4cOlTvuuEN69Ohh7m/SpEmBbatdu7Zce+215s3+0UcfzfeY3leuXDm5/fbbL/s9Pvfcc+aN/fHHHzdt1EDXqVMn2bBhgwkKavHixaZ3TEODBhHtyZk6darccMMN8v3338vVV1+d72veddddUq9ePVNno9euuJ544glzzffv3+8IeDpEZQ+bGtJ++OEHGTx4sNSvX182bdpkztuxY8cl9Tza5tmzZ5tApMHTHqheffVV83X69u1reok++ugj09758+dL165dzTnvv/++/PnPfzbfl76WqlOnTqHt1gCo4VWvm/78tm/fbn6Wa9ascfz87TQca6DUn+/dd99tQphe+8aNG5trDFie1ggBsLapU6fqu/clR0hIiG3atGn5zv3+++/NYzNnzsx3/4IFC/Ld//nnn5vba9asKfR1jx07Zs556qmnitXOt956y5y/detWx32ZmZm2uLg4W//+/Yt87pIlS8xzq1SpYktNTXXcP3v2bHP/q6++am7n5uba6tWrZ+vSpYv53C49Pd1Wq1Yt24033ui4T9ut
z+3du3ex2m9vg36069q1q61GjRqXnPv+++/b/P39zfXO68033zRfY9myZY779Lae+/PPP1/ydbTdeen1atSoke2GG27Id39ERESB19D+u7F7925z++jRo7bg4GBb586dbTk5OY7zJk2aZM577733HPe1b9/e3DdjxgzHfRkZGbaEhARbz549i7hSgHUwNAZ4kddff10WLlxojg8++MAMZ2hPgb0XR33yySdmuOjGG280vUf2Q3tPtDdjyZIl5jztAVLa86DFts6gPQo6FKc9QHZff/21ef2Le60K069fP9N7ZKcFwpUrV5avvvrK3NaeoZ07d5qhrhMnTji+Px066tixo3z33XemtyavIUOGiLPpddZeoKSkpHzXWXullP062+nQXYMGDS75OvZeLnvvjPZAXXfddbJ+/fpSteubb74xPUuPPPKIo+BbPfDAAxIVFWWG+vLS34m8PxstFNeeJx3qBLwBQ2OAF9E3qLzF0r1795bmzZub4RYtltY3MQ0J+maq0+sLosNN9jfmnj17miEUHc7RepHu3bubgKF1JaWh4erWW281U86fffZZc5+GIq07sQeEy9EhrLx0mExrcOw1MPr9qf79+xf6NfT713ogO619cjZtx9atWx11VIVd58u1QYPoP//5TxPw8tYWlXYGmNYzqSuvvDLf/fq7ocOX9sftqlateslr6bX76aefSvX6gKchCAFeTP/i114hrTPRN2YtjtXeEA1BeXtl8rK/ceubn9aDrFy50ky/154bLZR++eWXzX32WpiS0h4d7S3RAmmtM5k7d648+OCD+Xon/gh7b8+//vUvadasWYHnXNz2vL0uzqLt0O/v3//+d4GPa+H05dqg9UxaH9SuXTtT3Kw9X1q/o/VOGibLQmEzzkpSSwV4MoIQ4OWys7PNRy2athfR6vCIFi4XJwBcc8015tAiZX3z1aJdLdjVIbfS9Epo4a2GLQ1irVq1Mosy3nfffcV+vr3HJ+8b8q5duxxF2vYiYR3m0WJgVyvsGmg7Nm7caIbjStt785///McMJWoIzdsLp0GouO24WI0aNcxHLZDWHiA7HS7TovmyuGaAJ6FGCPBiWtvzv//9zwx7aL2KvU5HF9mzD01dHJrsU6i1HuXiv/rtPSz2IRqdTq0unnZdFJ3Kr0N2OkNKZ65pr0lhM80KMmPGDDl9+rTjtvZaHTp0yDGDSWudNITotHN7+Mvr2LFj4kw6hV+H2i6m1/nAgQPy9ttvX/LY2bNnTc1ScXpjNODoz8tOhwALWkFa21Gcn4MGHf19eO211/L9fN99913zfdhnogG+gh4hwIvodhrbtm1z1KBoD472oOj0d+0hsdf+6PT5cePGmboTnSKvwy16ng5Z6TCaFiDrmjQ6HKNT4zVYaPjQN3X9Orfccov5WtqjpAW+urbPFVdcIeXLlzfrzuhxueExfSPWguHx48eX6HvU19C1fnTqv66erdPntUZIi32VDrG98847JhjpUKCepzVIGkr09bT9OtTnLBq89PsfMWKEtGzZ0gy7aR2U9nJp2NNCbH1d7YHTQKM/H71fe3ny1nMVREOJDq1pL5rWZunPVAvi9fu9uEZH26E9fXq+rgukNUfa43Yx7Y0bM2aMqf3Sr6tDb9o7pD9rbX9xi9YBr+HuaWsAXDN9PjQ01NasWTPb5MmT800jt5syZYqtRYsWtrCwMFu5cuVsjRs3to0aNcp28OBB8/j69evNtPLq1aubafiVKlWydevWzbZ27dp8X2f58uXm6+iU7JJMpW/YsKGZMr5///4STV3/8MMPbWPGjDHt0bbr9PXffvvtkvN//PFHW48ePWwVKlQw7dcp7nfffbdt0aJFl0yf12UASjt9/syZM7Y+ffrYYmJizGN5p9LrVPfx48eb71XbEBsba67V2LFjbSkpKY7z9HnDhg0r8DXfffddsxyAPj8pKcn8rO3tzmvbtm22du3amWuij9mn0l88fT7vdHn9ekFBQbb4+Hjb
0KFDbadOncp3jk6f17ZfTL92QUsGAFbEXmMA3EJns2nvzqJFi4p1vq7mrIXf2mulPVYA4AzUCAEoc7p1hw7L6RAZALgTNUIAyozu0r5u3TozBV+ngt9zzz1cfQBuRY8QgDKjM7y0eFlns+kGq2WxkzsAFIUaIQAA4LPoEQIAAD6LIAQAAHwWxdLF2C/o4MGDZrfr0i6TDwAAypYu0aULweoCo0XtZUgQugwNQRdvjggAAKxh3759UrVq1UIfJwhdhvYE2S+kfYsCAADg2VJTU01Hhv19vDAEocuwD4dpCCIIAQBgLZcra6FYGgAA+CyCEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsywWh119/XWrWrCmhoaHSqlUrWb16dZHnf/LJJ5KUlGTOb9y4sXz11Vdl1lYAAODZLBWEPv74YxkxYoQ89dRTsn79emnatKl06dJFjh49WuD5y5cvl969e8ugQYPkxx9/lO7du5tj8+bNZd52AADgefxsuk+9RWgPUMuWLWXSpEnmdm5urtlQ7a9//auMHj36kvPvueceSUtLk/nz5zvuu+aaa6RZs2by5ptvFnvTtujoaElJSWGvMQAALKK479+W6RHKzMyUdevWSadOnRz3+fv7m9srVqwo8Dl6f97zlfYgFXa+ysjIMBcv7+EKmj+X7zouGdk5Lvn6AADg8iwThI4fPy45OTkSHx+f7369ffjw4QKfo/eX5Hw1btw4kyDth/Y4ucKDM9dLn3dWyX/WHXDJ1wcAAF4UhMrKmDFjTDea/di3b59LXqdlzfLm4+SluyQ7J9clrwEAALwkCMXFxUlAQIAcOXIk3/16OyEhocDn6P0lOV+FhISYscS8hyv0vrq6VIgIln0nz8rcjQdd8hoAAMBLglBwcLC0aNFCFi1a5LhPi6X1duvWrQt8jt6f93y1cOHCQs8vS2HBATLoulrm8ze+/UVycy1Tsw4AgNewTBBSOnX+7bfflunTp8vWrVtl6NChZlbYwIEDzeP9+vUzQ1t2Dz/8sCxYsEBefvll2bZtmzz99NOydu1aGT58uHiC+66pIVGhgbLr6Bn5+ufC65YAAIBrWCoI6XT4l156SZ588kkzBX7Dhg0m6NgLovfu3SuHDh1ynN+mTRuZNWuWTJkyxaw59Omnn8qcOXOkUaNG4gnKhQbJgGvP9wpNWrLLzCQDAABlx1LrCLmDq9cROpWWKdeOXyzpmTkydUBLuT6pktNfAwAAX5PqbesIeavYiGC595oa5vOJi3fSKwQAQBkiCHmAP19XS4ID/WX93mRZ8esJdzcHAACfEejuBkCkUrlQ6dWymsxY8ZtMWLhTcku5rJC/n4ifn5/5GODvJ1VjwyUhOpRLDABAIagR8pC9xg4kn5X2Ly6RbCdOow/095PvRl0viTFhTvuaAAB40/s3PUIeokpMmIy5pb58um5/qeuE9Gk5Npvk2myy/9RZyczOlR1HThOEAAAoBEHIgwxqW8scznDfu6vk+53H5cSZTKd8PQAAvBHF0l5Kt+9QJ9MIQgAAFIYg5KXKR4SYj8fTMtzdFAAAPBZByEtViLzQI8TQGAAAhSIIefnQ2AmGxgAAKBRByEtViDw/NEYQAgCgcAQhL1Xe3iN0hhohAAAKQxDyUnH2GiGGxgAAKBRByMt7hHRX+7OZOe5uDgAAHokg5KUiQwLNRq7qBFPoAQAoEEHIS+nmq46ZY0yhBwCgQAQhX1hLiDohAAAKRBDyhdWlmTkGAECBCEJeLI79xgAAKBJByAdmjjE0BgBAwQhCPrC69HGKpQEAKBBByIvZZ42dZPo8AAAFIgj5wKwx9hsDAKBgBCGf2G8s091NAQDAIxGEvFicYwd6Nl4FAKAgBCEf6BE6l5Ur6ZnZ7m4OAAAehyDkxcKDAyQ06MJ+YwyPAQBwCYKQ1+83Zh8eo04IAICLEYR8ZeYY22wAAHAJgpCvzByjRwgAgEsQ
hLycY2iMGiEAAC5BEPKRoTFWlwYA4FIEIR/ZZoMeIQAALkUQ8nLUCAEAUDiCkJdjdWkAAApHEPKRHqGTFEsDAHAJgpCPFEsfT8sUm83m7uYAAOBRCEI+Mn0+MztX0jJz3N0cAAA8CkHIy4UFB5g9xxSrSwMAkB9ByAcwcwwAgIIRhHxAhUhWlwYAoCAEIR9aVJHVpQEAyI8g5ENB6DhT6AEAsGYQOnnypPTt21eioqIkJiZGBg0aJGfOnCnyOR06dBA/P798x5AhQ8TXlHfsN5bp7qYAAOBRAt3dgOLSEHTo0CFZuHChZGVlycCBA2Xw4MEya9asIp/3wAMPyDPPPOO4HR4eLr47NEYQAgDAckFo69atsmDBAlmzZo1cddVV5r6JEyfKLbfcIi+99JIkJiYW+lwNPgkJCeLL7GsJHT+T4e6mAADgUSwxNLZixQozHGYPQapTp07i7+8vq1atKvK5M2fOlLi4OGnUqJGMGTNG0tPTizw/IyNDUlNT8x1Wx9AYAAAW7hE6fPiwVKpUKd99gYGBUr58efNYYfr06SM1atQwPUY//fSTPP7447J9+3b57LPPCn3OuHHjZOzYseJN4i70CJ2gWBoAAM8JQqNHj5bx48dfdlistLSGyK5x48ZSuXJl6dixo/zyyy9Sp06dAp+jvUYjRoxw3NYeoWrVqom39AjpfmNaNA4AANwchEaOHCkDBgwo8pzatWubGp+jR4/muz87O9vMJCtJ/U+rVq3Mx127dhUahEJCQszhjcXSmTm5cjojW6JCg9zdJAAAPIJbg1DFihXNcTmtW7eW5ORkWbdunbRo0cLct3jxYsnNzXWEm+LYsGGD+ag9Q74kNChAIoIDzKarJ89kEoQAALBSsXT9+vXlpptuMlPhV69eLcuWLZPhw4dLr169HDPGDhw4IElJSeZxpcNfzz77rAlPe/bskblz50q/fv2kXbt20qRJE/HZbTbSmDkGAIClgpB99pcGHa3x0Wnzbdu2lSlTpjge17WFtBDaPissODhYvvnmG+ncubN5ng7D9ezZU+bNmyc+vfEqBdMAAFhr1pjSGWJFLZ5Ys2ZNUwhspwXOS5cuLaPWeb64CwXTJ1hUEQAA6/UIwTk9QqwuDQDA7whCPlYjxOrSAAD8jiDkI9hvDACASxGEfEQFe40QxdIAADgQhHxEefs2GxRLAwBgvVljcM7Q2O7jZ2TwjLVecznrxUfKY52vZNsQAECpEIR8RNXYMAkK8JNzWbnyvy1HxFvo93Jzo8rSqEq0u5sCALAggpCPiAkPlo//0lq2HTot3uKt736R306ky/5TZwlCAIBSIQj5kD9VjzWHt/h+5zEThA6lnHV3UwAAFkWxNCwrMSbMfDyUcs7dTQEAWBRBCJZVOTrUfDyQTI8QAKB0CEKwfo8QQQgAUEoEIVi+R4ihMQBAaRGEYFlVLvQIHUk9J9k5ue5uDgDAgghCsKy4yBCzNlKuTeTo6Qx3NwcAYEEEIViWv7+fxEedHx47SJ0QAKAUCEKwtMTo88NjB5lCDwAoBYIQLC0x5kLBND1CAIBSIAjB0iqzqCIA4A8gCMHSEllUEQDwBxCEYGmVL9QIsd8YAKA0CEKwtMqOGiH2GwMAlBxBCF6xqOKJtEw5l5Xj7uYAACyGIARLiw4LkrCgAPM5W20AAEqKIARL8/PzyzM8xi70AICSIQjB8lhUEQBQWgQhWB6LKgIASosgBK+ZQn8whaExAEDJEITgNT1CB5lCDwAoIYIQLI9FFQEApUUQguUl2vcbo0cIAFBCBCF4zdDY6YxsST2X5e7mAAAshCAEywsPDjQLKyp6hQAAJUEQgleofGEXemaOAQBKgiAEr9pzjB4hAEBJEITgFezbbBxkmw0AQAkQhOAVWFQRAFAaBCF42TYb59zdFACAhRCE4FUbrx5imw0AQAkQhOBViyoeTDknNpvN3c0BAFgEQQheIT4qVPz8RDKzc+VEWqa7mwMAsAiCELxCcKC/
xEWGmM+pEwIAFBdBCF44PHbW3U0BAFiEZYLQc889J23atJHw8HCJiYkp1nO0VuTJJ5+UypUrS1hYmHTq1El27tzp8rbCPRLtq0uzlhAAwNuCUGZmptx1110ydOjQYj/nxRdflNdee03efPNNWbVqlUREREiXLl3k3DmmWHvzWkKHUvj5AgCKJ1AsYuzYsebjtGnTit0bNGHCBPn73/8ut99+u7lvxowZEh8fL3PmzJFevXq5tL1w31pC9AgBALwuCJXU7t275fDhw2Y4zC46OlpatWolK1asIAh5cY3QbyfS5bcTaWX2upEhgVLhQqE2AMBavDYIaQhS2gOUl962P1aQjIwMc9ilpqa6sJVwxQ70mw6kSPt/fVumF3fG/VdLuysqlulrAgAsXiM0evRo8fPzK/LYtm1bmbZp3LhxpufIflSrVq1MXx+l1yAxSlrUiDU9NGV1hASe/1/onR9286MDAAtya4/QyJEjZcCAAUWeU7t27VJ97YSEBPPxyJEjZtaYnd5u1qxZoc8bM2aMjBgxIl+PEGHIGkICA+Q/Q9uU6WvuPZEu7f61RL7feUz2nUyXauXDy/T1AQAWDkIVK1Y0hyvUqlXLhKFFixY5go+GGp09VtTMs5CQEHMAxVG9Qri0rRsnP+w6LrPX7pORna/kwgGAhVhm+vzevXtlw4YN5mNOTo75XI8zZ844zklKSpLPP//cfK7Dao888oj885//lLlz58qmTZukX79+kpiYKN27d3fjdwJv0+vq88OnGoSyc3Ld3RwAgDcWS+vCiNOnT3fcbt68ufm4ZMkS6dChg/l8+/btkpKS4jhn1KhRkpaWJoMHD5bk5GRp27atLFiwQEJDzxfVAs7QuUGCVIgIliOpGbJk+zG5sUH+An0AgOfys7FVd5F0OE2LpjVgRUVFldXPBRbz/FdbZcp3v8oNSZXkvQEt3d0cAPB5qcV8/7bM0BjgyXq1PD889u32oyzoCAAWQhACnKB2xUhpVau85NrO1woBAKyBIAQ4SZ9W1c3H2Wv2SY4mIgCAxyMIAU7SpWGCxIQHycGUc/LdjmNcVwCwAIIQ4CShQQHSo3lV8/n0FXtkx5HTjmPnkdOSxdR6APA4lpk+D1hB76uryXvLdsu324+ZIy9dePGDP7dyW9sAAJeiRwhwonrx5UytUPmIYMcRHRZkHtu4L5lrDQAehh4hwMmev6OxOexOpmXKn55dKKczss3K04EB/P0BAJ6Cf5EBF4sK/f3vjdRz2VxvAPAgBCHAxbQHqFzI+TCUnJ7J9QYAD0IQAspAdPj5OqHks1lcbwCwahDSbcl09/dz5865rkWAF9L1hVRKOkEIACwdhOrWrSv79rGFAFAS9pljKfQIAYB1g5C/v7/Uq1dPTpw44boWAV4oJizYfKRGCAAsXiP0wgsvyN/+9jfZvHmza1oEeCFqhADAS9YR6tevn6Snp0vTpk0lODhYwsLC8j1+8uRJZ7YP8AoxF4bGkqkRAgBrB6EJEya4piWAD9QIpVIjBADWDkL9+/d3TUsAH5g1xvR5APCiLTZ0Gn1mZv4F4qKiov5omwCvE02xNAB4R7F0WlqaDB8+XCpVqiQRERESGxub7wBQ+NAYPUIAYPEgNGrUKFm8eLFMnjxZQkJC5J133pGxY8dKYmKizJgxwzWtBLxkaIwaIQCw+NDYvHnzTODp0KGDDBw4UK677jqzyGKNGjVk5syZ0rdvX9e0FPCGGqH0LLMwqZ+fn7ubBAAoTY+QTo+vXbu2ox7IPl2+bdu28t1333FRgSIWVMzOtUlaZg7XCACsGoQ0BO3evdt8npSUJLNnz3b0FMXExDi/hYAXCA3yl+CA8/+7sc0GAFg4COlw2MaNG83no0ePltdff11CQ0Pl0UcfNStOA7iUDoU5VpdOzz/TEgBgoRohDTx2nTp1km3btsm6detMnVCTJk2c3T7Aq1aXPnY6gx3oAcBqPULly5eX48ePm8/vv/9+OX36tOMxLZLu
0aMHIQi4DBZVBACLBiFdNDE1NdV8Pn36dLOQIoDSrSVEjRAAWGxorHXr1tK9e3dp0aKFmfr70EMPXbLZqt17773n7DYCXra6dJa7mwIAKEkQ+uCDD+SVV16RX375xRR9pqSk0CsElHpojGJpALBUEIqPj5cXXnjBfF6rVi15//33pUKFCq5uG+B1xdIqhR4hALDurDH7GkIASsY+fZ4aIQCw8DpCAP7gxqv0CAGAxyAIAWUkJvxCsfRZiqUBwFMQhIAyrxGiWBoAPAVBCCgjrCMEAF5QLK1ycnJkzpw5snXrVnO7YcOGctttt0lAQICz2wd43fR53X0+MztXggP5OwQALBeEdu3aJV27dpX9+/fLlVdeae4bN26cVKtWTb788kupU6eOK9oJWF650CDx8xOx2c7PHKtYLsTdTQIAn1fiP0l1VenatWvLvn37ZP369ebYu3evWV9IHwNQsAB/PykXcv5vD6bQA4BFe4SWLl0qK1euNBux2uniirrg4rXXXuvs9gFeN3Ms9Vy2pLC6NABYs0coJCQk3+7zdmfOnJHg4PPTgwFcZpsN1hICAGsGoW7dusngwYNl1apVZgNWPbSHaMiQIaZgGkDhWFQRACwehF577TVTEK070oeGhppDh8Tq1q0rr776qmtaCXgJptADgMWDUExMjHzxxReyY8cO+fTTT82xfft2+fzzzyU6Oto1rRSR5557Ttq0aSPh4eGmDcUxYMAA8fPzy3fcdNNNLmsjUPwd6FldGgAsu46Q0h4gPXRNoU2bNsmpU6ckNjZWXCUzM1Puuusu0xP17rvvFvt5GnymTp2ar8YJcJeYsPN1dKwuDQAWDUKPPPKING7cWAYNGmRCUPv27WX58uWmp2b+/PnSoUMHlzR07Nix5uO0adNK9DwNPgkJCS5pE1BS9AgBgMWHxnQorGnTpubzefPmya+//irbtm2TRx99VJ544gnxNN9++61UqlTJLP44dOhQOXHihLubBB8WZd9vjKExALBmEDp+/Lijh+Wrr76Su+++W6644gq5//77zRCZJ9FhsRkzZsiiRYtk/PjxZg2km2++2fRkFSYjI0NSU1PzHYCzN15l+jwAWDQIxcfHy5YtW0yYWLBggdx4443m/vT09BLvNTZ69OhLipkvPrS3qbR69eplpvTrUF737t3N0N2aNWtML1FhdLsQLfq2H7p1CODMBRUVPUIAYNEaoYEDB5peoMqVK5ug0qlTJ3O/riuUlJRUoq81cuRIM7OrKLqdh7Po14qLizP7pXXs2LHAc8aMGSMjRoxw3NYeIcIQnL+gYiYXFQCsGISefvppadSokdlrTGdx2WdhaW+Q9vCURMWKFc1RVnSjWK0R0hBXGP1+mFmGslhHKDfXJv7+flxsALDa9Pk777zzkvv69+8vrqQbu548edJ81GG5DRs2mPt1Cn9kZKT5XHukdGjrjjvuMFt+6Eyznj17mpqmX375RUaNGmXO79Kli0vbClwuCOXaRM5kZktU6PnbAACLrSNU1p588kmZPn2643bz5s3NxyVLljim7OvCjikpKY4eqp9++sk8Jzk5WRITE6Vz587y7LPP0uMDtwkNCpDQIH85l5UrKelZBCEAcDPLBCFdP+hyawjpvmd2YWFh8vXXX5dBy4CS9wqdy8oww2OU4gOAxWaNAXDO6tJMoQcA9yMIAWUs2rHfGDPHAMByQUi31NBFCs+ePeuaFgFejkUVAcDCQUiLlB977DEzE+uBBx6QlStXuqZlgA9MoQcAWCwITZgwQQ4ePGh2dD969Ki0a9dOGjRoIC+99JIcOXLENa0EvHBRRYIQAFi0RigwMFB69OghX3zxhVmksE+fPvKPf/zDrMCsW1ksXrzY+S0FvGybDVaXBgCLF0uvXr1annrqKXn55ZfNDu+6PYVuYdGtWzczfAag8KExZo0BgAXXEdLhsPfff98Mje3cuVNuvfVW+fDDD81qzbr3mNL9w3Tndx0uA5AfNUIAYOEg
VLVqValTp47cf//9JvAUtFdYkyZNpGXLls5qI+BVqBECAAsHoUWLFsl1111X5DlRUVFm6wsAl2JBRQCwcI3Q5UIQgOL1CLGgIgC4HytLA2Us6kKxtG68ei4rh+sPAG5EEALKWLmQQPE/P69AUllUEQDciiAElPX/dP5+v0+hJwgBgLWCkO4zlpGRccn9mZmZ5jEAl8cUegCwaBAaOHCgpKSkXHL/6dOnzWMALi/asbo0+40BgKWCkM1mcyycmJdutREdHe2sdgE+sgN9prubAgA+LbAku85rANKjY8eOZr8xu5ycHNm9e7dZTRrA5bGoIgBYLAjpZqpqw4YNZjuNyMhIx2PBwcFSs2ZN6dmzp2taCXgZaoQAwGJBSDdXVRp4evXqJSEhIa5sF+ATQ2Nfbjok+06m/35/eLD8pX1tqRwd5sbWAYDvKPEWGzfccIMcO3bM7Dlm34F+1qxZ0qBBAxk8eLAr2gh4naqx4ebjr8fSzJHX3I0HZWLv5nJt3Tg3tQ4AfEeJg1CfPn1M4Lnvvvvk8OHD0qlTJ2nUqJHMnDnT3H7yySdd01LAi9zePFF0zkFKnnWEbDaRz388IFsOpcp9766SkZ2vlKHt65h1hwAAruFn02lgJRAbGysrV66UK6+8Ul577TX5+OOPZdmyZfK///1PhgwZIr/++qt4k9TUVDMbTpcM0M1kAVfSLTee/GKzzF6739zuVL+SvHxXM4m+sD8ZAMC5798lnj6flZXlqA/65ptv5LbbbjOfJyUlyaFDh0r65QDkERoUIC/e2VTG92wswYH+8s3Wo3LPlBWSnZPLdQIAFyhxEGrYsKG8+eab8v3338vChQsdU+YPHjwoFSpUcEUbAZ9zT8vq8tnQNhIS6C/bDp+WPSd+L6gGALgxCI0fP17eeust6dChg/Tu3VuaNm1q7p87d65cffXVTmwa4NsaVYmWuMjzva9pGdnubg4AeKUSF0trADp+/LgZe9N6ITstoA4PPz8TBoBzhAcHmI9pmQQhAPCY3ee1vnrdunWmZ0j3GLMvqkgQApwrIuT83yppGTlcWgDwhB6h3377zdQF7d271+xCf+ONN0q5cuXMkJne1vohAM4R6QhC9AgBgEf0CD388MNy1VVXyalTpyQs7PfVb++44w5ZtGiRs9sH+DSGxgDAw3qEdLbY8uXLzVBYXrr1xoEDB5zZNsDn0SMEAB7WI5Sbm2t2m7/Y/v37zRAZAOfXCJ2hRggAPCMIde7cWSZMmOC47efnJ2fOnDGbst5yyy3Obh/g08JDzs8aS6dGCAA8Y2js5Zdfli5duphNVs+dO2f2Htu5c6fExcXJhx9+6JpWAj4qMvhCsTTT5wHAM4KQ7jq/ceNGs8eYftTeoEGDBknfvn3zFU8D+OMYGgMADwtC5kmBgSb46AHAdSIYGgMAzwpCJ06ccOwptm/fPnn77bfl7Nmzcuutt0q7du1c0UbAZ/3eI8Q6QgDg1mLpTZs2mSnylSpVMjvNb9iwQVq2bCmvvPKKTJkyRW644QaZM2eOSxoJiK+vLE2NEAC4NwiNGjVKGjduLN99953Zb6xbt27StWtXSUlJMYsr/uUvf5EXXniBHxPgRBEXiqXTmT4PAO4dGluzZo0sXrxYmjRpYnac116gBx98UPz9z2epv/71r3LNNde4ppWAj9cIMTQGAG7uETp58qQkJCSYzyMjIyUiIiLf7vP6uX0DVgDOwcrSAOBBCyrq4olF3QbgqhqhHMnNtXF5AcCds8YGDBggISEh5nNdTHHIkCGmZ0jpzvMAnCviQo2QOpuV4whGAIAy7hHq37+/mTEWHR1tjnvvvVcSExMdt/Wxfv36iSvs2bPHLNpYq1Yts2hjnTp1zJYemZmZRT5Pw9qwYcPMdH8dzuvZs6ccOXLEJW0EXCE0yF/8L3S8pjGFHgCcrth/Xk6dOlXcZdu2bWaz17feekvq1q0rmzdvlgceeEDS0tLkpZdeKvR5jz76qHz55ZfyySef
mLA2fPhw6dGjhyxbtqxM2w+Ulg4/ay/Q6XPZpmC6EpcSAJzKEv3sN910kznsateuLdu3b5fJkycXGoR0Wv+7774rs2bNMmsc2cNc/fr1ZeXKlcxwg2VEBJ8PQumZOe5uCgB4nRLvPu8pNOiUL1++0MfXrVsnWVlZ0qlTJ8d9uhBk9erVZcWKFWXUSuCPYwo9APh4j9DFdu3aJRMnTixyWOzw4cMSHBwsMTEx+e6Pj483jxVGi77zFn6npqY6qdVA6TCFHgC8tEdo9OjRpgaiqEPrg/I6cOCAGSa76667TJ2Qs40bN85RAK5HtWrVnP4aQEmEB/8+hR4A4EU9QiNHjjRT8oui9UB2Bw8elOuvv17atGljVrYuii7+qLPKkpOT8/UK6awx+8KQBRkzZoyMGDEiX48QYQgesZYQs8YAwLuCUMWKFc1RHNoTpCGoRYsWpujZvrVHYfS8oKAgWbRokZk2r7TAeu/evdK6detCn6frJNnXSgI8QeSFbTYIQgDgo8XSGoJ0o1ctdNa6oGPHjpk6n7y1PnqOFkOvXr3a3NZhLV17SHt3lixZYoqnBw4caEIQe6LBSsIdPUIMjQGATxZLL1y40BRI61G1atV8j9ls57cd0Bli2uOTnp7ueOyVV14xPUfaI6QF0F26dJE33nijzNsPOKVYOjObCwkATuZnsycJFEhrhLR3SafrR0VFcZVQ5l79Zqe88s0O6dOqujx/R2N+AgDgxPdvSwyNAb7Mvo4QNUIA4HwEIcAys8aoEQIAZyMIAR6O6fMA4DoEIcAq0+cplgYApyMIAVZZWZoFFQHA6QhCgGX2GqNGCACcjSAEeDhqhADAdQhCgIeLCP69RohlvwDAuQhCgEV6hHJtIueyct3dHADwKgQhwMOFBQWIn9/5z89QMA0ATkUQAjycv7+fhAedHx5LZwo9ADgVQQiw0PAYPUIA4FwEIcACmEIPAK5BEAIsIJyNVwHAJQhCgAVE2FeXpkYIAJyKIARYamgs291NAQCvQhACLFUszTYbAOBMBCHAAiIu1Ail0yMEAE5FEAIsIOJCjdAZaoQAwKkIQoAFsPEqALgGQQiw1NAYNUIA4EwEIcACWFkaAFyDIARYafo8NUIA4FQEIcACwu0LKjI0BgBORRACLFQjxIKKAOBcBCHAAlhZGgBcgyAEWIBjaCyTWWMA4EwEIcBiPUI2m83dzQEAr0EQAixUI5Sda5OM7Fx3NwcAvAZBCLDQ0JiiYBoAnIcgBFhAgL+fhAVdWF2aOiEAcBqCEGARrC4NAM5HEAIsIpK1hADA6QhCgEUwhR4AnI8gBFgEiyoCgPMRhACLTaE/k5Ht7qYAgNcgCAEWEX5hUcV0ghAAOA1BCLCISLbZAACnIwgBFsH0eQBwPoIQYLEaIYbGAMB5CEKA5XqE2IEeAJyFIARYLAix1xgAOA9BCLCIiODzQ2NpmUyfBwBnIQgBFkGPEAD4aBDas2ePDBo0SGrVqiVhYWFSp04deeqppyQzM7PI53Xo0EH8/PzyHUOGDCmzdgOuWVmaGiEAcJbz/7J6uG3btklubq689dZbUrduXdm8ebM88MADkpaWJi+99FKRz9XznnnmGcft8PDwMmgx4HxMnwcAHw1CN910kznsateuLdu3b5fJkydfNghp8ElISCiDVgJlUyOUTo0QAPjW0FhBUlJSpHz58pc9b+bMmRIXFyeNGjWSMWPGSHp6epHnZ2RkSGpqar4D8KwaIYbGAMCneoQutmvXLpk4ceJle4P69OkjNWrUkMTERPnpp5/k8ccfNz1Jn332WaHPGTdunIwdO9YFrQacE4Qyc3IlMztXggMt+3cMAHgMP5vNZnPXi48ePVrGjx9f5Dlbt26VpKQkx+0DBw5I+/btTSH0O++8U6LXW7x4sXTs2NEEKS24LqxHSA877RGqVq2a6YGKiooq0esBzpSdkyt1n/iv+XzDkzdKTHgwFxgACqHv39HR0Zd9/3Zrj9DIkSNlwIAB
RZ6j9UB2Bw8elOuvv17atGkjU6ZMKfHrtWrVynwsKgiFhISYA/A0gQH+EhLoLxnZuXImI5sgBABO4NYgVLFiRXMUh/YEaQhq0aKFTJ06Vfz9Sz4ssGHDBvOxcuXKJX4u4ClT6DOyM6kTAgAnsUSRgYYgHQqrXr26qQs6duyYHD582Bx5z9EhtNWrV5vbv/zyizz77LOybt06sw7R3LlzpV+/ftKuXTtp0qSJG78boPTCL2y8yurSAOBDxdILFy40w1l6VK1aNd9j9hKnrKwsUwhtnxUWHBws33zzjUyYMMGsN6R1Pj179pS///3vbvkeAGeICGa/MQDwuSCkdUSXqyWqWbOmIxQpDT5Lly4tg9YB7lhdmv3GAMBnhsYAnBfOWkIA4FQEIcBCIqkRAgCnIggBFhJxoUZIp88DAP44ghBgyW02CEIA4AwEIcBCIuxDY+w3BgBOQRACLIQeIQBwLoIQYMXp85kMjQGAMxCEAAsJdyyomOPupgCAVyAIAVacPk+xNAA4BUEIsGCNENPnAcA5CEKABYPQvpPpsu63U+5uDgBYHkEIsJDGVaIlKaGcpGXmyD1vrZBpy3bn22MPAFAyBCHAQoIC/OXToW2ka5PKkp1rk6fnbZGHP9og6cwiAwDv3X0eQP4p9JN6N5c/VY+VcV9tlbkbD8q2w6nSqX68yy5T+Yhg6d+mpgliAOBNCEKABfn5+cmgtrWkSdVoGTZzvew4csYcrlSxXIjc3qyKS18DAMoaQQiwsJY1y8v8h9rKByv3yulzWS55jfW/nZKN+1Nk84EUghAAr0MQAiyuUrlQGXHjFS77+h+u3isb92+SbYdPu+w1AMBdGPAHUKQrE8qZj9sJQgC8EEEIQJGuiD8fhI6ezpCTaZlcLQBehSAE4LKz1KqVDzOf6+w0APAmBCEAl5WUEGU+MjwGwNsQhABclq5mrQhCALwNQQhAsQummTkGwNsQhAAUu0dox5HTkpvL3mYAvAdBCMBl1awQIcGB/pKemSP7T53ligHwGgQhAJcVGOAvdStGms+ZOQbAmxCEABQLBdMAvBFBCEDJCqaPsNUGAO9BEAJQsiB0iEUVAXgPghCAEi2quOdEupzLyuGqAfAKBCEAxRIfFSLRYUGSk2uTXUfPcNUAeAWCEIBi8fPzYyd6AF6HIASg5DPHKJgG4CUIQgCKja02AHgbghCAUqwlxMwxAN6BIASg2K6IPx+EjqRmSHJ6JlcOgOURhAAUW7nQIKkaG2Y+Zyd6AN6AIASgRNhqA4A3IQgBKGXBNHVCAKyPIASgRK68sMI0Q2MAvEGguxsAwJpDY1sPpcrYeT+7uzmWVrNChLSoESv1K0dJgL+fu5sD+CSCEIASqRUXIRHBAZKWmSNTl+3h6jmBXs/m1WOlWbUYCQ8JKNNrGhcRIk2qRUu9SuUIY/BJfjabzebuRniy1NRUiY6OlpSUFImKOj8kAPi65buOyw+7jru7GZame7ZtOZQqP+5NljMZ2e5ujoQHB0ijxGhpUjVarkgoZ3qralYIl4rlQsz2KoC3vn9bJgjddtttsmHDBjl69KjExsZKp06dZPz48ZKYmFjoc86dOycjR46Ujz76SDIyMqRLly7yxhtvSHx8fLFflyAEwNWBaPvh07Lut5Oy5dBpycnNLbMLrv/67z91Vn7an2x6+AoSFhRglkwICrh8SWlIkL/EhgdLTFiQxIQHS2x4kLSuU8EM/xGmUNa8Lgi98sor0rp1a6lcubIcOHBAHnvsMXP/8uXLC33O0KFD5csvv5Rp06aZizF8+HDx9/eXZcuWFft1CUIAfCGM/XrsjGzcn2JC0e7jabLnRJocOHVWcm3OqSu7r3UN6d6sikSEUJGBsuF1Qehic+fOle7du5uenqCgoEse12+8YsWKMmvWLLnzzjvNfdu2bZP69evLihUr5JprrinW6xCEAPiqzOxcOZB8Vg4mnzVhqSj66LmsHLPieHJ6lpxKzzLP/d/PhyUj+3wv
V2RIoNzaNNEMt5VUgJ+f3Nq0stSuGFnq7we+JbWYQciS0fzkyZMyc+ZMadOmTYEhSK1bt06ysrLMEJpdUlKSVK9evcggpMFKj7wXEgB8UXCgvymO16O0NBh9um6/zFy11/Q0fbh6b6m/1qJtR2Tu8Lalfj5g+SD0+OOPy6RJkyQ9Pd0Emfnz5xd67uHDhyU4OFhiYmLy3a/1QfpYYcaNGydjx451arsBwFdprdCfr6st919bS5b9clwWbzsq2TklG4iwiU0+XrNPftqfIlsOpkqDRCauwEuC0OjRo03Bc1G2bt1qenLU3/72Nxk0aJD89ttvJqz069fPhCFnFuGNGTNGRowYka9HqFq1ak77+gDgi/z9/eS6ehXNURon0zLlq02H5ZN1++SpxIZObx98l1uDkM7oGjBgQJHn1K5d2/F5XFycOa644gpT66MBZeXKlaaI+mIJCQmSmZkpycnJ+XqFjhw5Yh4rTEhIiDkAAJ7jrquqmSA058cDMvrmJAkJLNv1luC93BqEtJhZj9LIvTDFNG89T14tWrQw9UOLFi2Snj17mvu2b98ue/fuLTA4AQA8V7t6FSUhKlQOp56TRVuPyi2NK7u7SfASlthrbNWqVaY2SNcR0mGxxYsXS+/evaVOnTqOUKNT6nUIbfXq1ea2VorrMJoOcy1ZssQUTw8cONCcX9wZYwAAz6BbkPRsUcV8PnvtPnc3B17EEkEoPDxcPvvsM+nYsaNceeWVJuA0adJEli5d6hjG0hli2uOjhdR51x7q1q2b6RFq166dGRLTrwMAsJ67Wpyv1/xuxzE5lHLW3c2Bl7DsOkJlhXWEAMBz3P3WClm9+6Q81vkKGX5DPXc3B17w/m2JHiEAANQ9V53vFZq9dr/kOmPZa/g8ghAAwDJubpxgVqjeezJdVu856e7mwAsQhAAAlhEerNt0nJ8xRtE0fG5laQAAdE2hD1fvk682HZIHO9SV0CD+pveGFcgj3bQhL0EIAGApzavFSN1KkbLr6Bnp9O+l7m4OnOD5OxpLn1bVxR0IQgAAS9FtlR7qWE/+/vkmx872sLYAN3bqEYQAAJZzW9NEcwB/FAOrAADAZxGEAACAzyIIAQAAn0UQAgAAPosgBAAAfBZBCAAA+CyCEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPisQHc3wNPZbDbzMTU11d1NAQAAxWR/37a/jxeGIHQZp0+fNh+rVatW3GsPAAA86H08Ojq60Mf9bJeLSj4uNzdXDh48KOXKlRM/Pz+nJlUNV/v27ZOoqCinfV1w7T0Zv/dce1/E7717aLzREJSYmCj+/oVXAtEjdBl68apWrSquoiGIIOQeXHv34dpz7X0Rv/dlr6ieIDuKpQEAgM8iCAEAAJ9FEHKTkJAQeeqpp8xHcO19Bb/3XHtfxO+9Z6NYGgAA+Cx6hAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQchNXn/9dalZs6aEhoZKq1atZPXq1e5qitcaN26ctGzZ0qwKXqlSJenevbts37493znnzp2TYcOGSYUKFSQyMlJ69uwpR44ccVubvdELL7xgVmV/5JFHHPdx3V3nwIEDcu+995rf6bCwMGncuLGsXbs232q7Tz75pFSuXNk83qlTJ9m5c6cLW+QbcnJy5B//+IfUqlXLXNc6derIs88+m2+fK669ZyIIucHHH38sI0aMMNPn169fL02bNpUuXbrI0aNH3dEcr7V06VITclauXCkLFy6UrKws6dy5s6SlpTnOefTRR2XevHnyySefmPN1O5UePXq4td3eZM2aNfLWW29JkyZN8t3PdXeNU6dOybXXXitBQUHy3//+V7Zs2SIvv/yyxMbGOs558cUX5bXXXpM333xTVq1aJREREebfHw2nKL3x48fL5MmTZdKkSbJ161ZzW6/1xIkTufaeTvcaQ9m6+uqrbcOGDXPczsnJsSUmJtrGjRvHj8KF
jh49qn+a2ZYuXWpuJycn24KCgmyffPKJ45ytW7eac1asWMHP4g86ffq0rV69eraFCxfa2rdvb3v44Ye57i72+OOP29q2bVvo47m5ubaEhATbv/71L8d9+v9BSEiI7cMPP3R187xa165dbffff3+++3r06GHr27ev+Zxr77noESpjmZmZsm7dOtMdnXc/M729YsWKsm6OT0lJSTEfy5cvbz7qz0F7ifL+LJKSkqR69er8LJxAe+O6du2a7/py3V1r7ty5ctVVV8ldd91lhoObN28ub7/9tuPx3bt3y+HDh/P9THQvJh2e59+fP6ZNmzayaNEi2bFjh7m9ceNG+eGHH+Tmm2/m2ns4Nl0tY8ePHzdjyfHx8fnu19vbtm0r6+b4jNzcXFOjosMGjRo1MvfpG0JwcLDExMRc8rPQx1B6H330kRn21aGxi3HdXefXX381wzM69P5///d/5vo/9NBD5ve8f//+jt/rgv794Xf+jxk9erTZZV7/mAoICDD/zj/33HPSt29f8zjX3nMRhOAzvRObN282f6HBtfbt2ycPP/ywqcvSyQAo28CvPULPP/+8ua09Qvp7r/VAGoTgOrNnz5aZM2fKrFmzpGHDhrJhwwbzx1diYiLX3sMxNFbG4uLizF8LF89M0tsJCQll3RyfMHz4cJk/f74sWbJEqlat6rhfr7cOVSYnJ+c7n5/FH6NDjlr4/6c//UkCAwPNoYXoWqCrn2vvA9fdNXQmWIMGDfLdV79+fdm7d6/53P5vDP/+ON/f/vY30yvUq1cvM1PvvvvuM5MCdPYq196zEYTKmHZRt2jRwowl5/0rTm+3bt26rJvj1XSqqoagzz//XBYvXmymtealPwedXZP3Z6HT6/VNg59F6XXs2FE2bdpk/iK2H9pLoUME9s+57q6hQ78XLxGhNSs1atQwn+v/AxqG8v7O63COzh7jd/6PSU9PN/WeeekfvfrvO9few7m7WtsXffTRR2aWxrRp02xbtmyxDR482BYTE2M7fPiwu5vmVYYOHWqLjo62ffvtt7ZDhw45jvT0dMc5Q4YMsVWvXt22ePFi29q1a22tW7c2B5wr76wxrrvrrF692hYYGGh77rnnbDt37rTNnDnTFh4ebvvggw8c57zwwgvm35svvvjC9tNPP9luv/12W61atWxnz551Ycu8X//+/W1VqlSxzZ8/37Z7927bZ599ZouLi7ONGjXKcQ7X3jMRhNxk4sSJ5g04ODjYTKdfuXKlu5ritTTnF3RMnTrVcY7+4//ggw/aYmNjzRvGHXfcYcISXBuEuO6uM2/ePFujRo3MH1tJSUm2KVOm5Htcp3H/4x//sMXHx5tzOnbsaNu+fbsLW+QbUlNTze+4/rseGhpqq127tu2JJ56wZWRkOM7h2nsmP/2Pu3ulAAAA3IEaIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsghAAXEbNmjVlwoQJXCfACxGEAHiUAQMGSPfu3c3nHTp0MDt4l5Vp06ZJTEzMJfevWbNGBg8eXGbtAFB2AsvwtQDALXS3e93wuLQqVqzo1PYA8Bz0CAHw2J6hpUuXyquvvip+fn7m2LNnj3ls8+bNcvPNN0tkZKTEx8fLfffdJ8ePH3c8V3uShg8fbnqT4uLipEuXLub+f//739K4cWOJiIiQatWqyYMPPihnzpwxj3377bcycOBASUlJcbze008/XeDQ2N69e+X22283rx8VFSV33323HDlyxPG4Pq9Zs2by/vvvm+dGR0dLr1695PTp02V2/QAUD0EIgEfSANS6dWt54IEH5NChQ+bQ8JKcnCw33HCDNG/eXNauXSsLFiwwIUTDSF7Tp083vUDLli2TN99809zn7+8vr732mvz888/m8cWLF8uoUaPMY23atDFhR4ON/fUee+yxS9qVm5trQtDJkydNUFu4cKH8+uuvcs899+Q775dffpE5c+bI/PnzzaHnvvDCCy69ZgBK
jqExAB5Je1E0yISHh0tCQoLj/kmTJpkQ9Pzzzzvue++990xI2rFjh1xxxRXmvnr16smLL76Y72vmrTfSnpp//vOfMmTIEHnjjTfMa+lrak9Q3te72KJFi2TTpk2ye/du85pqxowZ0rBhQ1NL1LJlS0dg0pqjcuXKmdvaa6XPfe6555x2jQD8cfQIAbCUjRs3ypIlS8ywlP1ISkpy9MLYtWjR4pLnfvPNN9KxY0epUqWKCSgaTk6cOCHp6enFfv2tW7eaAGQPQapBgwamyFofyxu07CFIVa5cWY4ePVqq7xmA69AjBMBStKbn1ltvlfHjx1/ymIYNO60Dykvri7p16yZDhw41vTLly5eXH374QQYNGmSKqbXnyZmCgoLy3daeJu0lAuBZCEIAPJYOV+Xk5OS7709/+pP85z//MT0ugYHF/yds3bp1Joi8/PLLplZIzZ49+7Kvd7H69evLvn37zGHvFdqyZYupXdKeIQDWwtAYAI+lYWfVqlWmN0dnhWmQGTZsmClU7t27t6nJ0eGwr7/+2sz4KirE1K1bV7KysmTixImmuFlndNmLqPO+nvY4aS2Pvl5BQ2adOnUyM8/69u0r69evl9WrV0u/fv2kffv2ctVVV7nkOgBwHYIQAI+ls7YCAgJMT4uu5aPT1hMTE81MMA09nTt3NqFEi6C1Rsfe01OQpk2bmunzOqTWqFEjmTlzpowbNy7fOTpzTIundQaYvt7Fxdb2Ia4vvvhCYmNjpV27diYY1a5dWz7++GOXXAMAruVns9lsLn4NAAAAj0SPEAAA8FkEIQAA4LMIQgAAwGcRhAAAgM8iCAEAAJ9FEAIAAD6LIAQAAHwWQQgAAPgsghAAAPBZBCEAAOCzCEIAAMBnEYQAAID4qv8HzYNusut+DJYAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "plt.plot(best_y_per_iteration.values)\n", + "plt.xlabel(\"Iteration\")\n", + "plt.ylabel(\"Best y so far\")\n", + "plt.title(\"Best y per iteration\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(best_y_per_iteration)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "96" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(strategy.experiments)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
710.3999520.8877421.00.5504820.00.00.00.00.00.01.00.0-3.110493True
940.3980060.8851970.00.5726160.00.01.00.00.00.00.01.0-3.013107True
900.3977890.8849290.00.5733320.00.01.00.00.00.00.00.0-3.013040True
730.3958580.8817980.00.5564070.00.00.00.01.00.01.01.0-3.002077True
670.4018520.8889790.00.5496050.00.00.01.00.00.01.01.0-2.996064True
.............................................
691.0000000.6017400.00.0000000.00.01.00.00.01.00.00.0-0.000671True
440.0000000.0000000.00.0000001.00.01.00.00.00.00.00.0-0.000255True
891.0000001.0000001.00.0000001.00.00.00.00.00.01.00.0-0.000232True
681.0000000.4827590.01.0000000.01.01.00.00.00.00.00.0-0.000157True
581.0000000.0000000.01.0000000.00.00.00.00.00.01.00.0-0.000017True
\n", + "

96 rows × 14 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 x_spurious_1 \\\n", + "71 0.399952 0.887742 1.0 0.550482 0.0 0.0 0.0 0.0 \n", + "94 0.398006 0.885197 0.0 0.572616 0.0 0.0 1.0 0.0 \n", + "90 0.397789 0.884929 0.0 0.573332 0.0 0.0 1.0 0.0 \n", + "73 0.395858 0.881798 0.0 0.556407 0.0 0.0 0.0 0.0 \n", + "67 0.401852 0.888979 0.0 0.549605 0.0 0.0 0.0 1.0 \n", + ".. ... ... ... ... ... ... ... ... \n", + "69 1.000000 0.601740 0.0 0.000000 0.0 0.0 1.0 0.0 \n", + "44 0.000000 0.000000 0.0 0.000000 1.0 0.0 1.0 0.0 \n", + "89 1.000000 1.000000 1.0 0.000000 1.0 0.0 0.0 0.0 \n", + "68 1.000000 0.482759 0.0 1.000000 0.0 1.0 1.0 0.0 \n", + "58 1.000000 0.000000 0.0 1.000000 0.0 0.0 0.0 0.0 \n", + "\n", + " x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 y valid_y \n", + "71 0.0 0.0 1.0 0.0 -3.110493 True \n", + "94 0.0 0.0 0.0 1.0 -3.013107 True \n", + "90 0.0 0.0 0.0 0.0 -3.013040 True \n", + "73 1.0 0.0 1.0 1.0 -3.002077 True \n", + "67 0.0 0.0 1.0 1.0 -2.996064 True \n", + ".. ... ... ... ... ... ... \n", + "69 0.0 1.0 0.0 0.0 -0.000671 True \n", + "44 0.0 0.0 0.0 0.0 -0.000255 True \n", + "89 0.0 0.0 1.0 0.0 -0.000232 True \n", + "68 0.0 0.0 0.0 0.0 -0.000157 True \n", + "58 0.0 0.0 1.0 0.0 -0.000017 True \n", + "\n", + "[96 rows x 14 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.sort_values(by=\"y\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "48", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.40246971, 0.49047062, 0.83129865, 0. , 0. ,\n", + " 0. , 0.25661052, 0.92627053, 0. , 0.97370585,\n", + " 0. , 0. , 0. , 0. , 0. ,\n", + " 0. , 0. , 0. , 0.29692436, 0.30276136,\n", + " 0.31188112, 0. , 0.31113767, 0.31041889, 0.33260343,\n", + " 0.33121978, 0.33797472, 0.34426884, 0. , 0.42236647,\n", + " 0.38995489, 0.40404801, 0.40527138, 0.40636354, 0.37343461,\n", + " 0. 
, 0.46571731, 0.42500977, 0.81058484, 0.44997919,\n", + " 0. , 0.4405271 , 0.43008012, 0. , 0. ,\n", + " 0.4443915 , 0. , 0. , 0. , 0.43516446,\n", + " 0.43388985, 0.44600393, 0. , 0. , 0. ,\n", + " 0.42146618, 0.48546936, 0.40232409, 1. , 0.39738112,\n", + " 1. , 0.41107153, 0.40042298, 0.39872447, 0. ,\n", + " 0.39733587, 0.39918283, 0.40185191, 1. , 1. ,\n", + " 0.66228013, 0.39995177, 1. , 0.39585761, 1. ,\n", + " 0. , 0. , 0. , 0.45403372, 0. ,\n", + " 1. , 0.44424961, 0. , 0. , 0. ,\n", + " 0.4010512 , 0. , 0.38435384, 0.41094965, 1. ,\n", + " 0.39778853, 0. , 0. , 0. , 0.39800625,\n", + " 0. ])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments.x_0.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 0.3920, 0.0251, 5.8860, 1.1981, 7.8354, 0.9001, 15.2773, 13.5841,\n", + " 13.1932, 14.4943, 13.7058, 15.2569]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.4024700.0000000.0000000.0000000.0000000.0000000.0000000.6374530.1399190.4218240.6713710.331502-0.005683True
10.4904710.2944280.0000000.0000000.0000000.9252190.1794820.7729880.4809430.0000000.0000000.000000-0.096735True
20.8312990.0000000.1528290.2337560.0000000.2376550.2170320.0000000.5361540.0000000.0000000.000000-0.041713True
30.0000000.8293810.0000000.2879120.0000000.0000000.7741490.0000000.0000000.5073140.7311190.425413-0.082572True
40.0000000.8670440.1348860.0000000.2504430.2510050.7662670.0000000.0214130.0000000.0000000.000000-0.045721True
50.0000000.3077290.8427450.1164260.0000000.2726150.0000000.0000000.0000000.7787310.0000000.155062-0.097132True
60.2566110.4944050.0000000.6031580.0000000.0000000.8366570.0000000.0000000.8080900.0000000.134337-0.619587True
70.9262710.0000000.0000000.3188560.0000000.0000000.6645890.0000000.4928020.0531630.0000000.259605-0.003226True
80.0000000.2191180.6250430.7884450.0000000.0000000.0000000.4703430.0000000.0000000.3855200.006767-0.006505True
90.9737060.0000000.0000000.0623920.0000000.6765590.0000000.8051040.0000000.2609770.4023370.000000-0.054280True
100.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.9186150.0000000.058531-0.005089True
110.0000000.0000000.0000000.6817350.0000000.0000000.0000000.0000000.0000000.8729310.0000000.110071-0.002782True
120.0000000.5359470.0000000.0000000.0000000.0000001.0000000.0000000.0000000.0000000.0000000.157608-0.005837True
130.0000000.0000000.0000000.0000000.0000000.0000000.9638820.0000000.0000000.8681810.0000000.102231-0.005089True
140.0000000.5218990.0000000.6834860.0000000.0010560.0000000.0000000.0000001.0000000.0000000.144456-0.059981True
150.0000000.0000000.0000000.6328840.0000000.0000000.9731650.0000000.0000001.0000000.0000000.141852-0.003873True
160.0000000.4853940.0190510.0000000.0000000.0000000.8809520.0000000.0000000.9239620.0000000.111996-0.005705True
170.0000000.0000000.0000000.7615450.0000000.0000000.8266870.0000000.0000000.1617930.0000000.000000-0.001490True
180.2969240.5858560.0000000.7151310.0000000.0000000.9460410.0000000.0000000.0000000.0000000.000000-1.003502True
190.3027610.0000000.0000000.0000000.3969290.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.024736True
200.3118810.0000000.0000000.8194800.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.003597True
210.0000000.6611510.0000000.8140610.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.071057True
220.3111380.0000000.0000000.0000000.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.006050True
230.3104190.6607250.0000000.0000000.0000000.0000000.9965630.0000000.0000000.0000000.0000000.000000-0.067044True
240.3326030.0000000.0000000.8009070.0000000.0000001.0000000.0000000.0000000.0000000.0000000.000000-0.004195True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.402470 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "1 0.490471 0.294428 0.000000 0.000000 0.000000 0.925219 0.179482 \n", + "2 0.831299 0.000000 0.152829 0.233756 0.000000 0.237655 0.217032 \n", + "3 0.000000 0.829381 0.000000 0.287912 0.000000 0.000000 0.774149 \n", + "4 0.000000 0.867044 0.134886 0.000000 0.250443 0.251005 0.766267 \n", + "5 0.000000 0.307729 0.842745 0.116426 0.000000 0.272615 0.000000 \n", + "6 0.256611 0.494405 0.000000 0.603158 0.000000 0.000000 0.836657 \n", + "7 0.926271 0.000000 0.000000 0.318856 0.000000 0.000000 0.664589 \n", + "8 0.000000 0.219118 0.625043 0.788445 0.000000 0.000000 0.000000 \n", + "9 0.973706 0.000000 0.000000 0.062392 0.000000 0.676559 0.000000 \n", + "10 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "11 0.000000 0.000000 0.000000 0.681735 0.000000 0.000000 0.000000 \n", + "12 0.000000 0.535947 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "13 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.963882 \n", + "14 0.000000 0.521899 0.000000 0.683486 0.000000 0.001056 0.000000 \n", + "15 0.000000 0.000000 0.000000 0.632884 0.000000 0.000000 0.973165 \n", + "16 0.000000 0.485394 0.019051 0.000000 0.000000 0.000000 0.880952 \n", + "17 0.000000 0.000000 0.000000 0.761545 0.000000 0.000000 0.826687 \n", + "18 0.296924 0.585856 0.000000 0.715131 0.000000 0.000000 0.946041 \n", + "19 0.302761 0.000000 0.000000 0.000000 0.396929 0.000000 0.000000 \n", + "20 0.311881 0.000000 0.000000 0.819480 0.000000 0.000000 0.000000 \n", + "21 0.000000 0.661151 0.000000 0.814061 0.000000 0.000000 1.000000 \n", + "22 0.311138 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 \n", + "23 0.310419 0.660725 0.000000 0.000000 0.000000 0.000000 0.996563 \n", + "24 0.332603 0.000000 0.000000 0.800907 0.000000 0.000000 1.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + 
"0 0.637453 0.139919 0.421824 0.671371 0.331502 \n", + "1 0.772988 0.480943 0.000000 0.000000 0.000000 \n", + "2 0.000000 0.536154 0.000000 0.000000 0.000000 \n", + "3 0.000000 0.000000 0.507314 0.731119 0.425413 \n", + "4 0.000000 0.021413 0.000000 0.000000 0.000000 \n", + "5 0.000000 0.000000 0.778731 0.000000 0.155062 \n", + "6 0.000000 0.000000 0.808090 0.000000 0.134337 \n", + "7 0.000000 0.492802 0.053163 0.000000 0.259605 \n", + "8 0.470343 0.000000 0.000000 0.385520 0.006767 \n", + "9 0.805104 0.000000 0.260977 0.402337 0.000000 \n", + "10 0.000000 0.000000 0.918615 0.000000 0.058531 \n", + "11 0.000000 0.000000 0.872931 0.000000 0.110071 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.157608 \n", + "13 0.000000 0.000000 0.868181 0.000000 0.102231 \n", + "14 0.000000 0.000000 1.000000 0.000000 0.144456 \n", + "15 0.000000 0.000000 1.000000 0.000000 0.141852 \n", + "16 0.000000 0.000000 0.923962 0.000000 0.111996 \n", + "17 0.000000 0.000000 0.161793 0.000000 0.000000 \n", + "18 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "19 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "20 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "21 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "22 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "23 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "24 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " y valid_y \n", + "0 -0.005683 True \n", + "1 -0.096735 True \n", + "2 -0.041713 True \n", + "3 -0.082572 True \n", + "4 -0.045721 True \n", + "5 -0.097132 True \n", + "6 -0.619587 True \n", + "7 -0.003226 True \n", + "8 -0.006505 True \n", + "9 -0.054280 True \n", + "10 -0.005089 True \n", + "11 -0.002782 True \n", + "12 -0.005837 True \n", + "13 -0.005089 True \n", + "14 -0.059981 True \n", + "15 -0.003873 True \n", + "16 -0.005705 True \n", + "17 -0.001490 True \n", + "18 -1.003502 True \n", + "19 -0.024736 True \n", + "20 -0.003597 True \n", + "21 -0.071057 True \n", + "22 -0.006050 True 
\n", + "23 -0.067044 True \n", + "24 -0.004195 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.2561, 0.3035, 1.2614, 1.3028, 1.5076, 1.4546, 1.6381, 1.6320, 1.7198,\n", + " 1.9160, 1.4628, 0.4608]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.7022250.1195760.0000000.0000000.5141660.0000000.3982430.0000000.0000000.9116470.0000000.506541-0.007801True
10.0000000.1354060.8432900.0000000.0000000.0000000.0795280.2151900.7617940.0000000.7555720.000000-0.008543True
20.0000000.6351610.0000000.7574720.0403000.0000000.0585810.0000000.0000000.0000000.5398340.529786-0.082457True
30.0000000.0000000.0000000.5976710.1564120.0000000.9289830.9367160.0000000.6583720.0000000.631700-0.015096True
40.8185240.0000000.0000000.0000000.9194440.6829780.0000000.0000000.6966650.6427770.0000000.947762-0.000610True
50.0000000.0000000.5590090.4596630.0000000.0000000.0647950.0000000.6381050.9473500.3983820.000000-0.017627True
60.0000000.0068260.0000000.3473980.6099630.0000000.2067430.9459320.0000000.8443830.0000000.000000-0.011483True
70.4624120.7461500.0438960.6079460.0000000.0000000.0000000.0000000.0000000.0000000.3490180.801749-2.436279True
80.0000000.0000000.0000000.7446480.0000000.0000000.2045490.7418980.5717230.4318280.0000000.892975-0.001715True
90.0000000.7577750.0000000.0000000.0000000.9877370.0000000.0000000.0826880.0030900.4656500.160617-0.021754True
100.4143310.8977280.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.4302510.000000-0.111822True
111.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000000.000000-0.001019True
120.5108750.0000000.0000000.6945000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.000000-0.006296True
130.4373670.0000000.0000000.6891500.0000000.0000000.0000000.0000000.0000000.0000000.4106880.910016-0.007494True
141.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000000.0000001.0000001.000000-0.001019True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.702225 0.119576 0.000000 0.000000 0.514166 0.000000 0.398243 \n", + "1 0.000000 0.135406 0.843290 0.000000 0.000000 0.000000 0.079528 \n", + "2 0.000000 0.635161 0.000000 0.757472 0.040300 0.000000 0.058581 \n", + "3 0.000000 0.000000 0.000000 0.597671 0.156412 0.000000 0.928983 \n", + "4 0.818524 0.000000 0.000000 0.000000 0.919444 0.682978 0.000000 \n", + "5 0.000000 0.000000 0.559009 0.459663 0.000000 0.000000 0.064795 \n", + "6 0.000000 0.006826 0.000000 0.347398 0.609963 0.000000 0.206743 \n", + "7 0.462412 0.746150 0.043896 0.607946 0.000000 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.744648 0.000000 0.000000 0.204549 \n", + "9 0.000000 0.757775 0.000000 0.000000 0.000000 0.987737 0.000000 \n", + "10 0.414331 0.897728 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "11 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "12 0.510875 0.000000 0.000000 0.694500 0.000000 0.000000 0.000000 \n", + "13 0.437367 0.000000 0.000000 0.689150 0.000000 0.000000 0.000000 \n", + "14 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.000000 0.911647 0.000000 0.506541 \n", + "1 0.215190 0.761794 0.000000 0.755572 0.000000 \n", + "2 0.000000 0.000000 0.000000 0.539834 0.529786 \n", + "3 0.936716 0.000000 0.658372 0.000000 0.631700 \n", + "4 0.000000 0.696665 0.642777 0.000000 0.947762 \n", + "5 0.000000 0.638105 0.947350 0.398382 0.000000 \n", + "6 0.945932 0.000000 0.844383 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.349018 0.801749 \n", + "8 0.741898 0.571723 0.431828 0.000000 0.892975 \n", + "9 0.000000 0.082688 0.003090 0.465650 0.160617 \n", + "10 0.000000 0.000000 0.000000 0.430251 0.000000 \n", + "11 0.000000 0.000000 0.000000 1.000000 0.000000 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.000000 \n", + "13 0.000000 0.000000 
0.000000 0.410688 0.910016 \n", + "14 0.000000 0.000000 0.000000 1.000000 1.000000 \n", + "\n", + " y valid_y \n", + "0 -0.007801 True \n", + "1 -0.008543 True \n", + "2 -0.082457 True \n", + "3 -0.015096 True \n", + "4 -0.000610 True \n", + "5 -0.017627 True \n", + "6 -0.011483 True \n", + "7 -2.436279 True \n", + "8 -0.001715 True \n", + "9 -0.021754 True \n", + "10 -0.111822 True \n", + "11 -0.001019 True \n", + "12 -0.006296 True \n", + "13 -0.007494 True \n", + "14 -0.001019 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_2x_spurious_3x_spurious_4x_spurious_5yvalid_y
00.7022250.1195760.0000000.0000000.5141660.0000000.3982430.0000000.0000000.9116470.0000000.506541-0.007801True
10.0000000.1354060.8432900.0000000.0000000.0000000.0795280.2151900.7617940.0000000.7555720.000000-0.008543True
20.0000000.6351610.0000000.7574720.0403000.0000000.0585810.0000000.0000000.0000000.5398340.529786-0.082457True
30.0000000.0000000.0000000.5976710.1564120.0000000.9289830.9367160.0000000.6583720.0000000.631700-0.015096True
40.8185240.0000000.0000000.0000000.9194440.6829780.0000000.0000000.6966650.6427770.0000000.947762-0.000610True
50.0000000.0000000.5590090.4596630.0000000.0000000.0647950.0000000.6381050.9473500.3983820.000000-0.017627True
60.0000000.0068260.0000000.3473980.6099630.0000000.2067430.9459320.0000000.8443830.0000000.000000-0.011483True
70.4624120.7461500.0438960.6079460.0000000.0000000.0000000.0000000.0000000.0000000.3490180.801749-2.436279True
80.0000000.0000000.0000000.7446480.0000000.0000000.2045490.7418980.5717230.4318280.0000000.892975-0.001715True
90.0000000.7577750.0000000.0000000.0000000.9877370.0000000.0000000.0826880.0030900.4656500.160617-0.021754True
\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.702225 0.119576 0.000000 0.000000 0.514166 0.000000 0.398243 \n", + "1 0.000000 0.135406 0.843290 0.000000 0.000000 0.000000 0.079528 \n", + "2 0.000000 0.635161 0.000000 0.757472 0.040300 0.000000 0.058581 \n", + "3 0.000000 0.000000 0.000000 0.597671 0.156412 0.000000 0.928983 \n", + "4 0.818524 0.000000 0.000000 0.000000 0.919444 0.682978 0.000000 \n", + "5 0.000000 0.000000 0.559009 0.459663 0.000000 0.000000 0.064795 \n", + "6 0.000000 0.006826 0.000000 0.347398 0.609963 0.000000 0.206743 \n", + "7 0.462412 0.746150 0.043896 0.607946 0.000000 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.744648 0.000000 0.000000 0.204549 \n", + "9 0.000000 0.757775 0.000000 0.000000 0.000000 0.987737 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_2 x_spurious_3 x_spurious_4 x_spurious_5 \\\n", + "0 0.000000 0.000000 0.911647 0.000000 0.506541 \n", + "1 0.215190 0.761794 0.000000 0.755572 0.000000 \n", + "2 0.000000 0.000000 0.000000 0.539834 0.529786 \n", + "3 0.936716 0.000000 0.658372 0.000000 0.631700 \n", + "4 0.000000 0.696665 0.642777 0.000000 0.947762 \n", + "5 0.000000 0.638105 0.947350 0.398382 0.000000 \n", + "6 0.945932 0.000000 0.844383 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.349018 0.801749 \n", + "8 0.741898 0.571723 0.431828 0.000000 0.892975 \n", + "9 0.000000 0.082688 0.003090 0.465650 0.160617 \n", + "\n", + " y valid_y \n", + "0 -0.007801 True \n", + "1 -0.008543 True \n", + "2 -0.082457 True \n", + "3 -0.015096 True \n", + "4 -0.000610 True \n", + "5 -0.017627 True \n", + "6 -0.011483 True \n", + "7 -2.436279 True \n", + "8 -0.001715 True \n", + "9 -0.021754 True " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((5, 
35, 42, 47, 48), {}, -4.6231695559517085)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mcts = MCTS(\n", + " groups=groups,\n", + " # Dummy reward function — we only use MCTS for tree traversal\n", + " # to sample valid NChooseK combinations.\n", + " reward_fn=reward_fn,\n", + " use_cache=False,\n", + " rollout_mode=\"uniform_subset\",\n", + ")\n", + "mcts.run(n_iterations=2000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
x_0x_1x_2x_3x_4x_5x_spurious_0x_spurious_1x_spurious_10x_spurious_11...x_spurious_41x_spurious_42x_spurious_43x_spurious_5x_spurious_6x_spurious_7x_spurious_8x_spurious_9yvalid_y
00.0000000.0000000.0290300.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.1170840.00.000000-0.005313True
10.3088960.0000000.0000000.5516740.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.011710True
20.0000000.0000000.8619450.0000000.00.8948000.0000000.00.0000000.0...0.0000000.00.000000.00.00.1033970.00.000000-0.279134True
30.0000000.7871880.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.007694True
40.0000000.0000000.8330570.9394680.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.000441True
50.0000000.4647900.0000000.1868110.00.0000000.0000000.00.0000000.0...0.8330060.00.000000.00.00.0000000.00.693337-0.018105True
60.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
70.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
80.0000000.0000000.0000000.0000000.00.0000000.0000000.00.0000000.0...0.0000000.00.903890.00.00.0000000.00.000000-0.005089True
90.0000000.0000000.0000000.0000000.00.0000000.9407920.00.7665960.0...0.0000000.00.000000.00.00.0000000.00.000000-0.005089True
100.0000000.0000000.0000000.0000000.01.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.070364True
110.0000000.0000000.0000000.0000000.01.0000000.0000000.00.0000000.0...0.0000000.00.000000.00.00.0000000.00.000000-0.070364True
120.0000000.0000000.0000000.0000000.00.5946860.0000000.00.0000001.0...0.0000000.00.000000.00.00.0000000.00.000000-0.166469True
130.0000000.0000000.0000000.0000000.00.5779770.0000000.00.0000000.0...1.0000000.00.000000.00.01.0000001.00.000000-0.163030True
140.0000000.0000000.0000000.0000000.00.5978800.0000000.00.0000000.0...0.0000000.00.000000.01.00.0000000.00.000000-0.167050True
\n", + "

15 rows × 52 columns

\n", + "
" + ], + "text/plain": [ + " x_0 x_1 x_2 x_3 x_4 x_5 x_spurious_0 \\\n", + "0 0.000000 0.000000 0.029030 0.000000 0.0 0.000000 0.000000 \n", + "1 0.308896 0.000000 0.000000 0.551674 0.0 0.000000 0.000000 \n", + "2 0.000000 0.000000 0.861945 0.000000 0.0 0.894800 0.000000 \n", + "3 0.000000 0.787188 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "4 0.000000 0.000000 0.833057 0.939468 0.0 0.000000 0.000000 \n", + "5 0.000000 0.464790 0.000000 0.186811 0.0 0.000000 0.000000 \n", + "6 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "7 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "8 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.000000 \n", + "9 0.000000 0.000000 0.000000 0.000000 0.0 0.000000 0.940792 \n", + "10 0.000000 0.000000 0.000000 0.000000 0.0 1.000000 0.000000 \n", + "11 0.000000 0.000000 0.000000 0.000000 0.0 1.000000 0.000000 \n", + "12 0.000000 0.000000 0.000000 0.000000 0.0 0.594686 0.000000 \n", + "13 0.000000 0.000000 0.000000 0.000000 0.0 0.577977 0.000000 \n", + "14 0.000000 0.000000 0.000000 0.000000 0.0 0.597880 0.000000 \n", + "\n", + " x_spurious_1 x_spurious_10 x_spurious_11 ... x_spurious_41 \\\n", + "0 0.0 0.000000 0.0 ... 0.000000 \n", + "1 0.0 0.000000 0.0 ... 0.000000 \n", + "2 0.0 0.000000 0.0 ... 0.000000 \n", + "3 0.0 0.000000 0.0 ... 0.000000 \n", + "4 0.0 0.000000 0.0 ... 0.000000 \n", + "5 0.0 0.000000 0.0 ... 0.833006 \n", + "6 0.0 0.000000 0.0 ... 0.000000 \n", + "7 0.0 0.000000 0.0 ... 0.000000 \n", + "8 0.0 0.000000 0.0 ... 0.000000 \n", + "9 0.0 0.766596 0.0 ... 0.000000 \n", + "10 0.0 0.000000 0.0 ... 0.000000 \n", + "11 0.0 0.000000 0.0 ... 0.000000 \n", + "12 0.0 0.000000 1.0 ... 0.000000 \n", + "13 0.0 0.000000 0.0 ... 1.000000 \n", + "14 0.0 0.000000 0.0 ... 
0.000000 \n", + "\n", + " x_spurious_42 x_spurious_43 x_spurious_5 x_spurious_6 x_spurious_7 \\\n", + "0 0.0 0.00000 0.0 0.0 0.117084 \n", + "1 0.0 0.00000 0.0 0.0 0.000000 \n", + "2 0.0 0.00000 0.0 0.0 0.103397 \n", + "3 0.0 0.00000 0.0 0.0 0.000000 \n", + "4 0.0 0.00000 0.0 0.0 0.000000 \n", + "5 0.0 0.00000 0.0 0.0 0.000000 \n", + "6 0.0 0.00000 0.0 0.0 0.000000 \n", + "7 0.0 0.00000 0.0 0.0 0.000000 \n", + "8 0.0 0.90389 0.0 0.0 0.000000 \n", + "9 0.0 0.00000 0.0 0.0 0.000000 \n", + "10 0.0 0.00000 0.0 0.0 0.000000 \n", + "11 0.0 0.00000 0.0 0.0 0.000000 \n", + "12 0.0 0.00000 0.0 0.0 0.000000 \n", + "13 0.0 0.00000 0.0 0.0 1.000000 \n", + "14 0.0 0.00000 0.0 1.0 0.000000 \n", + "\n", + " x_spurious_8 x_spurious_9 y valid_y \n", + "0 0.0 0.000000 -0.005313 True \n", + "1 0.0 0.000000 -0.011710 True \n", + "2 0.0 0.000000 -0.279134 True \n", + "3 0.0 0.000000 -0.007694 True \n", + "4 0.0 0.000000 -0.000441 True \n", + "5 0.0 0.693337 -0.018105 True \n", + "6 0.0 0.000000 -0.005089 True \n", + "7 0.0 0.000000 -0.005089 True \n", + "8 0.0 0.000000 -0.005089 True \n", + "9 0.0 0.000000 -0.005089 True \n", + "10 0.0 0.000000 -0.070364 True \n", + "11 0.0 0.000000 -0.070364 True \n", + "12 0.0 0.000000 -0.166469 True \n", + "13 1.0 0.000000 -0.163030 True \n", + "14 0.0 0.000000 -0.167050 True \n", + "\n", + "[15 rows x 52 columns]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.experiments" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "2510" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import math\n", + "\n", + "\n", + "sum([math.comb(12, i) for i in range(7)])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57", + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 
'strategy' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mNameError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[4]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mstrategy\u001b[49m\n", + "\u001b[31mNameError\u001b[39m: name 'strategy' is not defined" + ] + } + ], + "source": [ + "strategy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58", + "metadata": {}, + "outputs": [], + "source": [ + "strategy.acqf_optimizer._candidates_tensor_to_dataframe(candidates, strategy.domain)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "59", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(5, 37, 48)\n", + "-4.677423986870725\n", + "-4.664086891505056\n" + ] + } + ], + "source": [ + "leaf, path = mcts._select_and_expand()\n", + "selected_features, cat_selections = mcts._get_selection(leaf)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))\n", + "print(reward_fn2(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-39.6933963563105\n", + "-10.756249958648718\n" + ] + } + ], + "source": [ + "print(\n", + " reward_fn(\n", + " (\n", + " 0,\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " ),\n", + " cat_selections={},\n", + " )\n", + ")\n", + "print(\n", + " reward_fn2(\n", + " (\n", + " 0,\n", + " 1,\n", + " 2,\n", + " 3,\n", + " 4,\n", + " ),\n", + " cat_selections={},\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"tensor([[ 8.7714, 3.3755, 0.2549, 0.0688, 2.2844, 1.9798, 13.7240, 6.4729,\n", + " 3.0276, 0.1636, 12.4373, 12.4482]], dtype=torch.float64,\n", + " grad_fn=)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "strategy.model.covar_module.lengthscale" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(-0.2791335862611258)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.y.min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(18,)\n", + "-32.97498755616551\n" + ] + } + ], + "source": [ + "selected_features, _, _ = mcts._rollout(mcts.root)\n", + "print(selected_features)\n", + "print(reward_fn(selected_features, cat_selections={}))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "np.float64(-0.4820260292929051)" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiments.y.min()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-8.419034388799723" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reward_fn(\n", + " [\n", + " 31,\n", + " 49,\n", + " ],\n", + " cat_selections={},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x_3 0.0\n", + "dtype: float64" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"experiments[[\"x_3\"]].sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'x_3'" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "benchmark.domain.inputs.get_keys()[3]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 50])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bounds.shape" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "feature-mcts", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/mcts-report/unique_evals.png b/mcts-report/unique_evals.png new file mode 100644 index 000000000..545cab830 Binary files /dev/null and b/mcts-report/unique_evals.png differ diff --git a/mcts-report/unique_evals_nig.png b/mcts-report/unique_evals_nig.png new file mode 100644 index 000000000..ccf2551f3 Binary files /dev/null and b/mcts-report/unique_evals_nig.png differ diff --git a/mcts-report/unique_evals_nig_adaptive.png b/mcts-report/unique_evals_nig_adaptive.png new file mode 100644 index 000000000..ff390faa4 Binary files /dev/null and b/mcts-report/unique_evals_nig_adaptive.png differ diff --git a/mcts-report/unique_evals_nig_adaptive_n0.png b/mcts-report/unique_evals_nig_adaptive_n0.png new file mode 100644 index 000000000..bbc41fb9f Binary files /dev/null and b/mcts-report/unique_evals_nig_adaptive_n0.png differ diff --git a/tests/bofire/benchmarks/test_benchmark.py 
b/tests/bofire/benchmarks/test_benchmark.py index b5cd96c02..1d0695bd7 100644 --- a/tests/bofire/benchmarks/test_benchmark.py +++ b/tests/bofire/benchmarks/test_benchmark.py @@ -13,8 +13,11 @@ SyntheticBoTorch, ) from bofire.benchmarks.multi import ZDT1 -from bofire.benchmarks.single import Himmelblau -from bofire.data_models.constraints.api import LinearInequalityConstraint +from bofire.benchmarks.single import Hartmann, Himmelblau +from bofire.data_models.constraints.api import ( + LinearInequalityConstraint, + NChooseKConstraint, +) from bofire.data_models.features.api import ContinuousDescriptorInput from bofire.data_models.objectives.api import MinimizeObjective from bofire.data_models.strategies.api import RandomStrategy @@ -182,3 +185,51 @@ def test_SpuriousFeaturesWrapper(): candidates[benchmark.domain.inputs.get_keys()], return_complete=False ), ) + + +def test_SpuriousFeaturesWrapper_with_max_count(): + # Use Hartmann (bounds [0,1]) since NChooseK requires lower bound >= 0 + benchmark = Hartmann(dim=6) + wrapped = SpuriousFeaturesWrapper( + benchmark=benchmark, n_spurious_features=4, max_count=3 + ) + # 6 original + 4 spurious = 10 inputs + assert len(wrapped.domain.inputs) == 10 + # Should have an NChooseKConstraint + nchoosek_constraints = [ + c + for c in wrapped.domain.constraints.constraints + if isinstance(c, NChooseKConstraint) + ] + assert len(nchoosek_constraints) == 1 + nchoosek = nchoosek_constraints[0] + assert nchoosek.max_count == 3 + assert nchoosek.min_count == 0 + assert nchoosek.none_also_valid is True + assert len(nchoosek.features) == 10 # all features included + + # Without max_count, no NChooseK constraint + wrapped_no_nchoosek = SpuriousFeaturesWrapper( + benchmark=benchmark, n_spurious_features=4 + ) + nchoosek_constraints = [ + c + for c in wrapped_no_nchoosek.domain.constraints.constraints + if isinstance(c, NChooseKConstraint) + ] + assert len(nchoosek_constraints) == 0 + + # Evaluation still works (ignores spurious features) + 
candidates = pd.DataFrame( + { + **{f"x_{i}": [0.5, 0.3] for i in range(6)}, + **{f"x_spurious_{i}": [0.0, 0.0] for i in range(4)}, + } + ) + evaled = wrapped.f(candidates, return_complete=False) + assert_frame_equal( + evaled, + benchmark.f( + candidates[benchmark.domain.inputs.get_keys()], return_complete=False + ), + ) diff --git a/tests/bofire/strategies/test_optimize_mcts.py b/tests/bofire/strategies/test_optimize_mcts.py new file mode 100644 index 000000000..b48fa222d --- /dev/null +++ b/tests/bofire/strategies/test_optimize_mcts.py @@ -0,0 +1,2203 @@ +"""Tests for optimize_mcts module (NIG Thompson Sampling version).""" + +import math +import random as stdlib_random + +import pytest +import torch + +import bofire.strategies.predictives.optimize_mcts as optimize_mcts_mod +from bofire.strategies.predictives.optimize_mcts import ( + MCTS, + STOP, + ActionStats, + Categorical, + Groups, + NChooseK, + Node, + TrajectoryStep, + optimize_acqf_mcts, +) + + +# ============================================================================= +# NChooseK tests +# ============================================================================= + + +class TestNChooseK: + def test_basic_construction(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=2) + assert g.features == [0, 1, 2] + assert g.min_count == 1 + assert g.max_count == 2 + + def test_n_options(self): + g = NChooseK(features=[0, 1, 2, 3], min_count=1, max_count=3) + assert g.n_options == 4 + + def test_n_features(self): + g = NChooseK(features=[0, 2, 4], min_count=0, max_count=2) + assert g.n_features == 3 + + def test_non_contiguous_features(self): + g = NChooseK(features=[0, 5, 10], min_count=1, max_count=2) + assert g.n_features == 3 + assert g.n_options == 3 + + # --- Validation --- + + def test_min_count_negative_raises(self): + with pytest.raises(ValueError, match="Invalid NChooseK constraint"): + NChooseK(features=[0, 1, 2], min_count=-1, max_count=2) + + def 
test_min_count_greater_than_max_count_raises(self): + with pytest.raises(ValueError, match="Invalid NChooseK constraint"): + NChooseK(features=[0, 1, 2], min_count=3, max_count=2) + + def test_max_count_greater_than_n_raises(self): + with pytest.raises(ValueError, match="Invalid NChooseK constraint"): + NChooseK(features=[0, 1], min_count=1, max_count=3) + + def test_min_count_equals_max_count_valid(self): + g = NChooseK(features=[0, 1, 2], min_count=2, max_count=2) + assert g.min_count == 2 + assert g.max_count == 2 + + def test_zero_min_zero_max_valid(self): + g = NChooseK(features=[0, 1], min_count=0, max_count=0) + assert g.min_count == 0 + assert g.max_count == 0 + + def test_empty_features_zero_counts(self): + g = NChooseK(features=[], min_count=0, max_count=0) + assert g.n_features == 0 + + def test_empty_features_nonzero_counts_raises(self): + with pytest.raises(ValueError, match="Invalid NChooseK constraint"): + NChooseK(features=[], min_count=1, max_count=1) + + # --- legal_actions --- + + def test_legal_actions_empty_partial(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=2) + actions = g.legal_actions(partial=(), stopped=False) + # Can pick indices 0, 1, 2; cannot STOP yet (0 < min_count=1) + assert actions == [0, 1, 2] + + def test_legal_actions_stop_available_after_min_count(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=3) + actions = g.legal_actions(partial=(0,), stopped=False) + # Picked 1 item, min_count=1 met; can pick more (1, 2) or STOP + assert 1 in actions + assert 2 in actions + assert STOP in actions + + def test_legal_actions_stop_not_available_below_min_count(self): + g = NChooseK(features=[0, 1, 2, 3], min_count=2, max_count=3) + actions = g.legal_actions(partial=(0,), stopped=False) + # Only 1 picked, min_count=2 not met; STOP should NOT be available + assert STOP not in actions + + def test_legal_actions_enforces_increasing_order(self): + g = NChooseK(features=[0, 1, 2, 3, 4], min_count=1, 
max_count=5) + actions = g.legal_actions(partial=(2,), stopped=False) + # After picking index 2, only indices > 2 are legal (plus STOP) + non_stop = [a for a in actions if a != STOP] + assert all(a > 2 for a in non_stop) + + def test_legal_actions_when_stopped(self): + g = NChooseK(features=[0, 1, 2], min_count=0, max_count=2) + actions = g.legal_actions(partial=(0,), stopped=True) + assert actions == [] + + def test_legal_actions_at_max_count(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=2) + actions = g.legal_actions(partial=(0, 1), stopped=False) + assert actions == [] + + def test_legal_actions_min_count_constrains_upper_bound(self): + """When min_count requires more picks, high indices become unavailable + because there wouldn't be enough remaining indices to satisfy min_count.""" + g = NChooseK(features=[0, 1, 2, 3], min_count=3, max_count=4) + actions = g.legal_actions(partial=(), stopped=False) + # Need at least 3 picks. After picking index i with 0 picks so far, + # need 2 more from indices > i. So need n - (i+1) >= 2, i.e. i <= 1. 
+ # Legal: 0, 1 (not 2, 3 because not enough room for 3 total) + assert actions == [0, 1] + + def test_legal_actions_min_equals_max(self): + """When min_count == max_count, STOP is never available until max is reached.""" + g = NChooseK(features=[0, 1, 2], min_count=2, max_count=2) + actions = g.legal_actions(partial=(), stopped=False) + # Must pick exactly 2; first pick must leave room for 1 more + assert STOP not in actions + actions2 = g.legal_actions(partial=(0,), stopped=False) + assert STOP not in actions2 + # After picking 2 items, max is reached + actions3 = g.legal_actions(partial=(0, 1), stopped=False) + assert actions3 == [] + + def test_legal_actions_zero_min_count_stop_immediate(self): + g = NChooseK(features=[0, 1, 2], min_count=0, max_count=2) + actions = g.legal_actions(partial=(), stopped=False) + # min_count=0, so STOP is immediately available + assert STOP in actions + assert 0 in actions + + def test_legal_actions_zero_max_count(self): + g = NChooseK(features=[0, 1], min_count=0, max_count=0) + actions = g.legal_actions(partial=(), stopped=False) + # max_count=0 and len(partial)==0 means already at max; no actions available + assert actions == [] + + # --- is_complete --- + + def test_is_complete_not_yet(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=3) + assert g.is_complete(partial=(0,), stopped=False) is False + + def test_is_complete_when_stopped(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=3) + assert g.is_complete(partial=(0,), stopped=True) is True + + def test_is_complete_at_max_count(self): + g = NChooseK(features=[0, 1, 2], min_count=1, max_count=2) + assert g.is_complete(partial=(0, 1), stopped=False) is True + + def test_is_complete_empty_partial_stopped(self): + g = NChooseK(features=[0, 1, 2], min_count=0, max_count=2) + assert g.is_complete(partial=(), stopped=True) is True + + def test_is_complete_zero_max(self): + g = NChooseK(features=[0, 1], min_count=0, max_count=0) + assert 
g.is_complete(partial=(), stopped=False) is True + + # --- frozen dataclass --- + + def test_frozen(self): + g = NChooseK(features=[0, 1], min_count=0, max_count=1) + with pytest.raises(AttributeError): + g.min_count = 5 # type: ignore[misc] + + +# ============================================================================= +# Categorical tests +# ============================================================================= + + +class TestCategorical: + def test_basic_construction(self): + c = Categorical(dim=3, values=[0.0, 1.0, 2.0]) + assert c.dim == 3 + assert c.values == [0.0, 1.0, 2.0] + + def test_n_options(self): + c = Categorical(dim=0, values=[10.0, 20.0, 30.0, 40.0]) + assert c.n_options == 4 + + # --- Validation --- + + def test_single_value_raises(self): + with pytest.raises(ValueError, match="at least two values"): + Categorical(dim=0, values=[1.0]) + + def test_empty_values_raises(self): + with pytest.raises(ValueError, match="at least two values"): + Categorical(dim=0, values=[]) + + def test_two_values_valid(self): + c = Categorical(dim=0, values=[0.0, 1.0]) + assert c.n_options == 2 + + # --- legal_actions --- + + def test_legal_actions_empty_partial(self): + c = Categorical(dim=0, values=[0.0, 1.0, 2.0]) + actions = c.legal_actions(partial=(), stopped=False) + assert actions == [0, 1, 2] + + def test_legal_actions_after_selection(self): + c = Categorical(dim=0, values=[0.0, 1.0, 2.0]) + actions = c.legal_actions(partial=(1,), stopped=False) + assert actions == [] + + def test_legal_actions_ignores_stopped(self): + """Categorical doesn't use the stopped flag, but should handle it.""" + c = Categorical(dim=0, values=[0.0, 1.0]) + actions = c.legal_actions(partial=(), stopped=True) + # Still returns all options since stopped is not used for Categorical + assert actions == [0, 1] + + # --- is_complete --- + + def test_is_complete_empty(self): + c = Categorical(dim=0, values=[0.0, 1.0]) + assert c.is_complete(partial=(), stopped=False) is False + + 
class TestGroups:
    """Unit tests for the Groups container (ordered list of group specs)."""

    def _make_mixed_groups(self):
        """Fixture: two NChooseK groups followed by two Categorical groups."""
        return Groups(
            groups=[
                NChooseK(features=[0, 1, 2], min_count=1, max_count=2),
                NChooseK(features=[5, 6], min_count=1, max_count=2),
                Categorical(dim=3, values=[0.0, 1.0]),
                Categorical(dim=4, values=[10.0, 20.0, 30.0]),
            ]
        )

    def test_len(self):
        assert len(self._make_mixed_groups()) == 4

    def test_len_empty(self):
        assert len(Groups(groups=[])) == 0

    def test_nchooseks_property(self):
        ncks = self._make_mixed_groups().nchooseks
        assert len(ncks) == 2
        assert all(isinstance(member, NChooseK) for member in ncks)

    def test_categoricals_property(self):
        cats = self._make_mixed_groups().categoricals
        assert len(cats) == 2
        assert all(isinstance(member, Categorical) for member in cats)

    def test_all_nchoosek_features(self):
        feats = self._make_mixed_groups().all_nchoosek_features
        assert sorted(feats) == [0, 1, 2, 5, 6]

    def test_all_categorical_dims(self):
        assert self._make_mixed_groups().all_categorical_dims == [3, 4]

    def test_only_nchoosek(self):
        grps = Groups(groups=[NChooseK(features=[0, 1], min_count=1, max_count=1)])
        assert len(grps.nchooseks) == 1
        assert len(grps.categoricals) == 0
        assert grps.all_nchoosek_features == [0, 1]
        assert grps.all_categorical_dims == []

    def test_only_categorical(self):
        grps = Groups(groups=[Categorical(dim=7, values=[1.0, 2.0])])
        assert len(grps.nchooseks) == 0
        assert len(grps.categoricals) == 1
        assert grps.all_nchoosek_features == []
        assert grps.all_categorical_dims == [7]

    def test_empty_groups(self):
        grps = Groups(groups=[])
        assert grps.nchooseks == []
        assert grps.categoricals == []
        assert grps.all_nchoosek_features == []
        assert grps.all_categorical_dims == []

    def test_feature_ordering_preserved(self):
        """all_nchoosek_features preserves order from groups list."""
        grps = Groups(
            groups=[
                NChooseK(features=[5, 3], min_count=1, max_count=1),
                NChooseK(features=[1, 0], min_count=1, max_count=1),
            ]
        )
        assert grps.all_nchoosek_features == [5, 3, 1, 0]
class TestNode:
    """Tests for the MCTS Node dataclass."""

    @staticmethod
    def _two_group_groups() -> Groups:
        """Helper: two groups (1 NChooseK + 1 Categorical)."""
        nck = NChooseK(features=[0, 1, 2], min_count=1, max_count=2)
        cat = Categorical(dim=3, values=[0.0, 1.0])
        return Groups(groups=[nck, cat])

    # --- Construction & defaults ---

    def test_default_values(self):
        # NIG statistics and children all default to empty/zero on a fresh node.
        node = Node(
            partial_by_group=((), ()),
            stopped_by_group=(False, False),
            group_idx=0,
        )
        assert node.n_visits == 0
        assert node.n_obs == 0
        assert node.sum_rewards == 0.0
        assert node.sum_sq_rewards == 0.0
        assert node.children == {}

    def test_explicit_values(self):
        # All fields can be set via keyword and are stored unchanged.
        node = Node(
            partial_by_group=((0, 1), (0,)),
            stopped_by_group=(True, False),
            group_idx=2,
            n_obs=5,
            sum_rewards=10.0,
            sum_sq_rewards=30.0,
            n_visits=10,
        )
        assert node.partial_by_group == ((0, 1), (0,))
        assert node.stopped_by_group == (True, False)
        assert node.group_idx == 2
        assert node.n_obs == 5
        assert node.sum_rewards == 10.0
        assert node.sum_sq_rewards == 30.0
        assert node.n_visits == 10

    # --- is_terminal ---
    # A node is terminal once group_idx has advanced past the last group
    # (these tests pin that contract; implementation not visible here).

    def test_is_terminal_false_at_start(self):
        gs = self._two_group_groups()
        node = Node(
            partial_by_group=((), ()),
            stopped_by_group=(False, False),
            group_idx=0,
        )
        assert node.is_terminal(gs) is False

    def test_is_terminal_false_mid_group(self):
        gs = self._two_group_groups()
        node = Node(
            partial_by_group=((0,), ()),
            stopped_by_group=(False, False),
            group_idx=1,
        )
        assert node.is_terminal(gs) is False

    def test_is_terminal_true_past_last_group(self):
        gs = self._two_group_groups()
        node = Node(
            partial_by_group=((0,), (0,)),
            stopped_by_group=(True, False),
            group_idx=2,
        )
        assert node.is_terminal(gs) is True

    def test_is_terminal_true_when_group_idx_exceeds_len(self):
        # Even an out-of-range index (5 > 2 groups) counts as terminal.
        gs = self._two_group_groups()
        node = Node(
            partial_by_group=((), ()),
            stopped_by_group=(False, False),
            group_idx=5,
        )
        assert node.is_terminal(gs) is True

    def test_is_terminal_empty_groups(self):
        # With zero groups, index 0 is already past the last group.
        gs = Groups(groups=[])
        node = Node(
            partial_by_group=(),
            stopped_by_group=(),
            group_idx=0,
        )
        assert node.is_terminal(gs) is True

    def test_is_terminal_single_group(self):
        gs = Groups(groups=[NChooseK(features=[0, 1], min_count=1, max_count=1)])
        node_before = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
        )
        node_after = Node(
            partial_by_group=((0,),),
            stopped_by_group=(False,),
            group_idx=1,
        )
        assert node_before.is_terminal(gs) is False
        assert node_after.is_terminal(gs) is True

    # --- children dict ---

    def test_children_independent_per_node(self):
        """Each node gets its own children dict (not shared via mutable default)."""
        node_a = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
        )
        node_b = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
        )
        child = Node(
            partial_by_group=((0,),),
            stopped_by_group=(False,),
            group_idx=1,
        )
        node_a.children[0] = child
        assert 0 in node_a.children
        assert 0 not in node_b.children
    def test_children_keyed_by_action(self):
        # Children are keyed by the action that produced them, including the
        # STOP sentinel; identity of stored children is preserved.
        parent = Node(
            partial_by_group=((), ()),
            stopped_by_group=(False, False),
            group_idx=0,
        )
        child_0 = Node(
            partial_by_group=((0,), ()),
            stopped_by_group=(False, False),
            group_idx=0,
        )
        child_stop = Node(
            partial_by_group=((), ()),
            stopped_by_group=(True, False),
            group_idx=1,
        )
        parent.children[0] = child_0
        parent.children[STOP] = child_stop
        assert parent.children[0] is child_0
        assert parent.children[STOP] is child_stop
        assert len(parent.children) == 2

    # --- mutability ---

    def test_mutable_nig_stats(self):
        """Node is a regular (non-frozen) dataclass, so stats are mutable."""
        node = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
        )
        node.n_visits += 1
        node.n_obs += 1
        node.sum_rewards += 3.5
        node.sum_sq_rewards += 3.5 * 3.5
        assert node.n_visits == 1
        assert node.n_obs == 1
        assert node.sum_rewards == 3.5
        assert node.sum_sq_rewards == pytest.approx(12.25)


# =============================================================================
# MCTS tests
# =============================================================================


class TestMCTS:
    """Tests for the MCTS class."""

    # ---- helpers ----

    @staticmethod
    def _nck_only_groups() -> Groups:
        """Single NChooseK group: pick 1-2 from [0,1,2]."""
        return Groups(groups=[NChooseK(features=[0, 1, 2], min_count=1, max_count=2)])

    @staticmethod
    def _cat_only_groups() -> Groups:
        """Single Categorical group: dim 0, values [10, 20, 30]."""
        return Groups(groups=[Categorical(dim=0, values=[10.0, 20.0, 30.0])])

    @staticmethod
    def _mixed_groups() -> Groups:
        """NChooseK([0,1,2], 1, 2) + Categorical(dim=3, [0.0, 1.0])."""
        nck = NChooseK(features=[0, 1, 2], min_count=1, max_count=2)
        cat = Categorical(dim=3, values=[0.0, 1.0])
        return Groups(groups=[nck, cat])

    @staticmethod
    def _constant_reward(value: float = 1.0):
        """Reward function that always returns a constant."""
        return lambda _feats, _cats: value

    # ---- __init__ ----

    def test_init_root_node(self):
        gs = self._mixed_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0)
        assert mcts.root.group_idx == 0
        assert mcts.root.partial_by_group == ((), ())
        assert mcts.root.stopped_by_group == (False, False)
        assert mcts.root.n_visits == 0

    def test_init_defaults(self):
        # Pins the documented default hyperparameters of MCTS.__init__.
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        assert mcts.nig_alpha0 == 1.0
        assert mcts.ts_prior_var == 1.0
        assert mcts.adaptive_prior_var is True
        assert mcts.cache_hit_mode == "variance_inflation"
        assert mcts.rollout_mode == "ts_group_action"
        assert mcts.p_stop_rollout == 0.35
        assert mcts.pw_k0 == 2.0
        assert mcts.pw_alpha == 0.6
        assert mcts.best_value == float("-inf")
        assert mcts.best_selection is None

    def test_init_custom_params(self):
        gs = self._nck_only_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=self._constant_reward(),
            nig_alpha0=2.0,
            ts_prior_var=0.5,
            cache_hit_mode="pessimistic",
            rollout_mode="uniform",
            p_stop_rollout=0.5,
            pw_k0=3.0,
            pw_alpha=0.7,
            seed=42,
        )
        assert mcts.nig_alpha0 == 2.0
        assert mcts.ts_prior_var == 0.5
        assert mcts.cache_hit_mode == "pessimistic"
        assert mcts.rollout_mode == "uniform"
        assert mcts.p_stop_rollout == 0.5
        assert mcts.pw_k0 == 3.0
        assert mcts.pw_alpha == 0.7

    # ---- _make_cache_key ----

    def test_cache_key_deterministic(self):
        gs = self._mixed_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        key1 = mcts._make_cache_key((0, 1), {3: 1.0})
        key2 = mcts._make_cache_key((0, 1), {3: 1.0})
        assert key1 == key2

    def test_cache_key_different_for_different_selections(self):
        gs = self._mixed_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        key_a = mcts._make_cache_key((0,), {3: 1.0})
        key_b = mcts._make_cache_key((1,), {3: 1.0})
        key_c = mcts._make_cache_key((0,), {3: 0.0})
        assert key_a != key_b
        assert key_a != key_c

    def test_cache_key_order_independent_for_cat_dict(self):
        """frozenset of dict items makes key independent of insertion order."""
        gs = Groups(
            groups=[
                Categorical(dim=0, values=[0.0, 1.0]),
                Categorical(dim=1, values=[2.0, 3.0]),
            ]
        )
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        cat_a = {0: 1.0, 1: 2.0}
        cat_b = {1: 2.0, 0: 1.0}
        assert mcts._make_cache_key((), cat_a) == mcts._make_cache_key((), cat_b)

    # ---- _cached_reward ----

    def test_cached_reward_calls_reward_fn_once(self):
        # The expensive reward_fn must be invoked once per distinct key.
        call_count = 0

        def counting_reward(_feats, _cats):
            nonlocal call_count
            call_count += 1
            return 5.0

        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=counting_reward)

        val1 = mcts._cached_reward((0,), {})
        val2 = mcts._cached_reward((0,), {})
        assert val1 == 5.0
        assert val2 == 5.0
        assert call_count == 1

    def test_cached_reward_different_keys_call_separately(self):
        call_count = 0

        def counting_reward(_feats, _cats):
            nonlocal call_count
            call_count += 1
            return float(call_count)

        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=counting_reward)

        mcts._cached_reward((0,), {})
        mcts._cached_reward((1,), {})
        assert call_count == 2

    def test_cache_stats_after_cached_reward(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        mcts._cached_reward((0,), {})
        mcts._cached_reward((0,), {})  # hit
        mcts._cached_reward((1,), {})  # miss
        stats = mcts.cache_stats()
        assert stats == {"hits": 1, "misses": 2, "size": 2}

    # ---- _child_limit (progressive widening) ----

    def test_child_limit_zero_visits(self):
        gs = self._nck_only_groups()
        mcts = MCTS(
            groups=gs, reward_fn=self._constant_reward(), pw_k0=2.0, pw_alpha=0.6
        )
        node = Node(
            partial_by_group=((),), stopped_by_group=(False,), group_idx=0, n_visits=0
        )
        # max(1, int(2.0 * max(1, 0)**0.6)) = max(1, int(2.0 * 1)) = 2
        assert mcts._child_limit(node) == 2

    def test_child_limit_increases_with_visits(self):
        gs = self._nck_only_groups()
        mcts = MCTS(
            groups=gs, reward_fn=self._constant_reward(), pw_k0=2.0, pw_alpha=0.6
        )
        limits = []
        for v in [1, 10, 100]:
            node = Node(
                partial_by_group=((),),
                stopped_by_group=(False,),
                group_idx=0,
                n_visits=v,
            )
            limits.append(mcts._child_limit(node))
        # Should be monotonically non-decreasing
        assert limits[0] <= limits[1] <= limits[2]
        # With more visits, limit should grow
        assert limits[2] > limits[0]

    # ---- _legal_actions ----

    def test_legal_actions_delegates_to_group(self):
        gs = self._nck_only_groups()  # features [0,1,2], min=1, max=2
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        node = Node(partial_by_group=((),), stopped_by_group=(False,), group_idx=0)
        actions = mcts._legal_actions(node)
        assert actions == [0, 1, 2]

    def test_legal_actions_terminal_returns_empty(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        node = Node(partial_by_group=((0,),), stopped_by_group=(False,), group_idx=1)
        assert mcts._legal_actions(node) == []

    # ---- _apply_action ----

    def test_apply_action_regular(self):
        gs = self._nck_only_groups()  # features [0,1,2], min=1, max=2
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        root = mcts.root  # group_idx=0, partial=((),)
        child = mcts._apply_action(root, 0)
        assert child.partial_by_group[0] == (0,)
        assert child.stopped_by_group[0] is False
        # min=1 met but max=2 not reached; group not complete, stays at group 0
        assert child.group_idx == 0

    def test_apply_action_completes_group(self):
        gs = self._nck_only_groups()  # max=2
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        # After picking 0, pick 1 => 2 picked == max_count, group complete
        node = Node(partial_by_group=((0,),), stopped_by_group=(False,), group_idx=0)
        child = mcts._apply_action(node, 1)
        assert child.partial_by_group[0] == (0, 1)
        assert child.group_idx == 1  # advanced past group 0

    def test_apply_action_stop(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        node = Node(partial_by_group=((0,),), stopped_by_group=(False,), group_idx=0)
        child = mcts._apply_action(node, STOP)
        assert child.stopped_by_group[0] is True
        assert child.partial_by_group[0] == (0,)  # unchanged
        assert child.group_idx == 1  # advanced

    def test_apply_action_categorical(self):
        gs = self._cat_only_groups()  # dim=0, values=[10,20,30]
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        root = mcts.root
        child = mcts._apply_action(root, 2)  # pick index 2 (value 30)
        assert child.partial_by_group[0] == (2,)
        # Categorical is complete after 1 pick
        assert child.group_idx == 1

    def test_apply_action_does_not_mutate_parent(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        root = mcts.root
        original_partial = root.partial_by_group
        original_stopped = root.stopped_by_group
        mcts._apply_action(root, 0)
        assert root.partial_by_group == original_partial
        assert root.stopped_by_group == original_stopped

    def test_apply_action_mixed_groups_advances_correctly(self):
        gs = self._mixed_groups()  # NChooseK + Categorical
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        # Complete NChooseK by stopping after 1 pick
        node = Node(
            partial_by_group=((0,), ()),
            stopped_by_group=(False, False),
            group_idx=0,
        )
        after_stop = mcts._apply_action(node, STOP)
        assert after_stop.group_idx == 1  # now on categorical group
        # Pick categorical value
        after_cat = mcts._apply_action(after_stop, 0)
        assert after_cat.group_idx == 2  # past all groups (terminal)
        assert after_cat.partial_by_group == ((0,), (0,))

    # ---- _get_selection ----

    def test_get_selection_nchoosek_only(self):
        gs = self._nck_only_groups()  # features=[0,1,2]
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        # Picked local indices 0 and 2 => features 0 and 2
        node = Node(partial_by_group=((0, 2),), stopped_by_group=(False,), group_idx=1)
        feats, cats = mcts._get_selection(node)
        assert feats == (0, 2)
        assert cats == {}

    def test_get_selection_cat_only(self):
        gs = self._cat_only_groups()  # dim=0, values=[10,20,30]
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        node = Node(partial_by_group=((1,),), stopped_by_group=(False,), group_idx=1)
        feats, cats = mcts._get_selection(node)
        assert feats == ()
        assert cats == {0: 20.0}

    def test_get_selection_mixed(self):
        gs = self._mixed_groups()  # NChooseK features=[0,1,2], Cat dim=3 vals=[0,1]
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        # NChooseK picked local 1 => feature 1; Cat picked local 1 => value 1.0
        node = Node(
            partial_by_group=((1,), (1,)),
            stopped_by_group=(True, False),
            group_idx=2,
        )
        feats, cats = mcts._get_selection(node)
        assert feats == (1,)
        assert cats == {3: 1.0}

    def test_get_selection_sorts_features(self):
        nck = NChooseK(features=[5, 1, 3], min_count=2, max_count=2)
        gs = Groups(groups=[nck])
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        # Picked local indices 0, 2 => features 5, 3 => sorted (3, 5)
        node = Node(partial_by_group=((0, 2),), stopped_by_group=(False,), group_idx=1)
        feats, _cats = mcts._get_selection(node)
        assert feats == (3, 5)

    def test_get_selection_empty_partial(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        node = Node(partial_by_group=((),), stopped_by_group=(True,), group_idx=1)
        feats, cats = mcts._get_selection(node)
        assert feats == ()
        assert cats == {}

    # ---- _rollout ----

    def test_rollout_reaches_terminal(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=42)
        feats, cats, _traj = mcts._rollout(mcts.root)
        # Should return a valid selection with at least min_count=1 features
        assert len(feats) >= 1
        assert all(f in [0, 1, 2] for f in feats)
        assert cats == {}

    def test_rollout_categorical_selects_one(self):
        gs = self._cat_only_groups()  # dim=0, values=[10,20,30]
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=7)
        feats, cats, _traj = mcts._rollout(mcts.root)
        assert feats == ()
        assert 0 in cats
        assert cats[0] in [10.0, 20.0, 30.0]

    def test_rollout_mixed(self):
        gs = self._mixed_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=99)
        feats, cats, _traj = mcts._rollout(mcts.root)
        assert len(feats) >= 1
        assert all(f in [0, 1, 2] for f in feats)
        assert 3 in cats
        assert cats[3] in [0.0, 1.0]

    def test_rollout_deterministic_with_seed(self):
        gs = self._mixed_groups()
        results = []
        for _ in range(2):
            mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=123)
            results.append(mcts._rollout(mcts.root))
        assert results[0] == results[1]

    def test_rollout_does_not_mutate_input_node(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0)
        root = mcts.root
        original_partial = root.partial_by_group
        original_stopped = root.stopped_by_group
        original_idx = root.group_idx
        mcts._rollout(root)
        assert root.partial_by_group == original_partial
        assert root.stopped_by_group == original_stopped
        assert root.group_idx == original_idx

    # ---- _backpropagate ----

    def test_backpropagate_novel_updates_all_stats(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        root = mcts.root
        child = Node(partial_by_group=((0,),), stopped_by_group=(False,), group_idx=0)
        path = [root, child]
        mcts._backpropagate(path, reward=3.0, is_novel=True)
        assert root.n_visits == 1
        assert root.n_obs == 1
        assert root.sum_rewards == 3.0
        assert root.sum_sq_rewards == pytest.approx(9.0)
        assert child.n_visits == 1
        assert child.n_obs == 1

    def test_backpropagate_novel_accumulates(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward())
        root = mcts.root
        path = [root]
        mcts._backpropagate(path, reward=2.0, is_novel=True)
        mcts._backpropagate(path, reward=5.0, is_novel=True)
        assert root.n_visits == 2
        assert root.n_obs == 2
        assert root.sum_rewards == pytest.approx(7.0)
        assert root.sum_sq_rewards == pytest.approx(4.0 + 25.0)

    def test_backpropagate_cache_hit_no_update(self):
        """cache_hit_mode='no_update' only increments n_visits."""
        gs = self._nck_only_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=self._constant_reward(),
            cache_hit_mode="no_update",
        )
        root = mcts.root
        path = [root]
        # First a novel observation
        mcts._backpropagate(path, reward=5.0, is_novel=True)
        assert root.n_obs == 1
        assert root.n_visits == 1
        # Cache hit
        mcts._backpropagate(path, reward=5.0, is_novel=False)
        assert root.n_obs == 1  # unchanged
        assert root.n_visits == 2  # incremented

    # ---- _select_and_expand ----

    def test_select_and_expand_first_call_expands_root(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0)
        leaf, path = mcts._select_and_expand()
        # Should have expanded one child from root
        assert len(path) == 2
        assert path[0] is mcts.root
        assert len(mcts.root.children) == 1

    def test_select_and_expand_grows_tree(self):
        gs = self._cat_only_groups()  # 3 options, single step per branch
        mcts = MCTS(
            groups=gs,
            reward_fn=self._constant_reward(),
            seed=0,
            pw_k0=10.0,  # large widening to allow many children
            pw_alpha=0.0,  # constant limit = max(1, int(10 * 1)) = 10
        )
        # Expand several times, backpropagating to allow further expansion
        for _ in range(3):
            leaf, path = mcts._select_and_expand()
            mcts._backpropagate(path, reward=1.0, is_novel=True)
        # Root should have up to 3 children (one per expansion)
        assert len(mcts.root.children) >= 2

    # ---- run (integration) ----

    def test_run_returns_valid_selection_nchoosek(self):
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(1.0), seed=42)
        feats, cats, val = mcts.run(n_iterations=20)
        assert len(feats) >= 1
        assert len(feats) <= 2
        assert all(f in [0, 1, 2] for f in feats)
        assert cats == {}
        assert val == 1.0

    def test_run_returns_valid_selection_categorical(self):
        gs = self._cat_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(2.0), seed=42)
        feats, cats, val = mcts.run(n_iterations=20)
        assert feats == ()
        assert 0 in cats
        assert cats[0] in [10.0, 20.0, 30.0]
        assert val == 2.0

    def test_run_returns_valid_selection_mixed(self):
        gs = self._mixed_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(3.0), seed=42)
        feats, cats, val = mcts.run(n_iterations=30)
        assert 1 <= len(feats) <= 2
        assert all(f in [0, 1, 2] for f in feats)
        assert 3 in cats
        assert cats[3] in [0.0, 1.0]
        assert val == 3.0

    def test_run_finds_best_reward(self):
        """MCTS should track the best reward across iterations."""
        gs = self._nck_only_groups()

        def reward_by_count(feats, _cats):
            return float(len(feats))

        mcts = MCTS(groups=gs, reward_fn=reward_by_count, seed=42)
        feats, _cats, val = mcts.run(n_iterations=50)
        # Best possible: pick 2 features (max_count=2) => reward 2.0
        assert val == 2.0
        assert len(feats) == 2

    def test_run_deterministic_with_seed(self):
        gs = self._mixed_groups()
        results = []
        for _ in range(2):
            mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=77)
            results.append(mcts.run(n_iterations=30))
        assert results[0] == results[1]

    def test_run_updates_root_visits(self):
        gs = self._nck_only_groups()
        n_iter = 25
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0)
        mcts.run(n_iterations=n_iter)
        assert 0 < mcts.root.n_visits <= n_iter

    def test_run_populates_cache(self):
        # NOTE(review): assumes run() evaluates the reward exactly once per
        # iteration, so hits + misses == n_iterations — confirm against run().
        gs = self._nck_only_groups()
        mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0)
        mcts.run(n_iterations=30)
        stats = mcts.cache_stats()
        assert stats["size"] > 0
        assert stats["misses"] > 0
        # With repeated selections, should have some hits
        assert stats["hits"] + stats["misses"] == 30
self._nck_only_groups() + mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0) + mcts.run(n_iterations=30) + stats = mcts.cache_stats() + assert stats["size"] > 0 + assert stats["misses"] > 0 + # With repeated selections, should have some hits + assert stats["hits"] + stats["misses"] == 30 + + def test_run_zero_iterations(self): + gs = self._nck_only_groups() + mcts = MCTS(groups=gs, reward_fn=self._constant_reward(), seed=0) + feats, cats, val = mcts.run(n_iterations=0) + assert feats == () + assert cats == {} + assert val == float("-inf") + + def test_run_prefers_higher_reward(self): + """Given a reward function that varies by selection, MCTS should prefer better ones.""" + nck = NChooseK(features=[0, 1, 2], min_count=1, max_count=1) + gs = Groups(groups=[nck]) + + def reward_fn(feats, _cats): + # Feature 2 gives the best reward + if 2 in feats: + return 10.0 + if 1 in feats: + return 5.0 + return 1.0 + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=42) + feats, _cats, val = mcts.run(n_iterations=100) + # With enough iterations, MCTS should find the best + assert val == 10.0 + assert 2 in feats + + # ---- cache_stats ---- + + def test_cache_stats_initial(self): + gs = self._nck_only_groups() + mcts = MCTS(groups=gs, reward_fn=self._constant_reward()) + assert mcts.cache_stats() == {"hits": 0, "misses": 0, "size": 0} + + +# ============================================================================= +# NIG Sampling tests +# ============================================================================= + + +class TestNIGSampling: + """Tests for the NIG Thompson Sampling methods.""" + + @staticmethod + def _simple_groups() -> Groups: + return Groups(groups=[NChooseK(features=[0, 1, 2], min_count=1, max_count=2)]) + + def test_student_t_sample_returns_float(self): + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=42) + sample = mcts._student_t_sample(df=4.0, loc=5.0, scale=1.0) + assert isinstance(sample, float) + 
+ def test_student_t_sample_deterministic_seed(self): + gs = self._simple_groups() + samples = [] + for _ in range(2): + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=42) + samples.append(mcts._student_t_sample(df=4.0, loc=5.0, scale=1.0)) + assert samples[0] == samples[1] + + def test_student_t_sample_centered(self): + """Student-t samples should be centered around loc.""" + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + loc = 10.0 + samples = [ + mcts._student_t_sample(df=10.0, loc=loc, scale=1.0) for _ in range(5000) + ] + mean = sum(samples) / len(samples) + assert abs(mean - loc) < 0.5 # should be close to loc + + def test_nig_sample_prior_only(self): + """With no observations, _nig_sample returns from the prior.""" + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=42) + sample = mcts._nig_sample(n_obs=0, sum_rewards=0.0, sum_sq_rewards=0.0) + assert isinstance(sample, float) + + def test_nig_sample_with_data_concentrates(self): + """With many consistent observations, samples should cluster near mean.""" + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + # Simulate 100 observations of reward=5.0 + n = 100 + mean_reward = 5.0 + sum_r = n * mean_reward + sum_sq = n * mean_reward * mean_reward + samples = [mcts._nig_sample(n, sum_r, sum_sq) for _ in range(1000)] + sample_mean = sum(samples) / len(samples) + assert abs(sample_mean - mean_reward) < 0.5 + + def test_nig_sample_wide_with_few_observations(self): + """With few observations, samples should be more spread out.""" + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + # 2 observations with different values + samples_few = [mcts._nig_sample(2, 10.0, 52.0) for _ in range(1000)] + # 100 observations tightly clustered around 5.0 + samples_many = [mcts._nig_sample(100, 500.0, 2500.0) for _ in range(1000)] + var_few = sum((s - sum(samples_few) / 
1000) ** 2 for s in samples_few) / 1000 + var_many = sum((s - sum(samples_many) / 1000) ** 2 for s in samples_many) / 1000 + assert var_few > var_many + + def test_global_mean_no_data(self): + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + assert mcts._global_mean() == 0.0 + + def test_global_mean_with_data(self): + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + mcts._novel_reward_sum = 30.0 + mcts._novel_reward_count = 3 + assert mcts._global_mean() == pytest.approx(10.0) + + def test_prior_var_fixed(self): + gs = self._simple_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + adaptive_prior_var=False, + ts_prior_var=2.5, + seed=0, + ) + assert mcts._prior_var() == 2.5 + + def test_prior_var_adaptive_insufficient_data(self): + gs = self._simple_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + adaptive_prior_var=True, + ts_prior_var=2.5, + seed=0, + ) + # Only 1 observation => not enough for empirical variance + mcts._novel_reward_count = 1 + mcts._novel_reward_sum = 5.0 + mcts._novel_reward_sq_sum = 25.0 + assert mcts._prior_var() == 2.5 + + def test_prior_var_adaptive_with_data(self): + gs = self._simple_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + adaptive_prior_var=True, + ts_prior_var=2.5, + seed=0, + ) + # 4 observations: [2, 4, 6, 8] => mean=5, var=5 + mcts._novel_reward_count = 4 + mcts._novel_reward_sum = 20.0 + mcts._novel_reward_sq_sum = 120.0 # 4+16+36+64 + # empirical_var = 120/4 - 5^2 = 30 - 25 = 5 + assert mcts._prior_var() == pytest.approx(5.0) + + def test_compute_n0_fixed(self): + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, adaptive_n0=False, seed=0) + assert mcts._compute_n0(10) == 1.0 + + def test_compute_n0_adaptive(self): + gs = self._simple_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, adaptive_n0=True, seed=0) + n0 = mcts._compute_n0(10) + 
class TestNIGBackpropagation:
    """Tests for NIG-aware backpropagation with various cache-hit modes."""

    @staticmethod
    def _simple_groups() -> Groups:
        return Groups(groups=[NChooseK(features=[0, 1, 2], min_count=1, max_count=2)])

    def test_novel_updates_all_fields(self):
        # A novel (non-cached) reward updates every NIG statistic on the node.
        gs = self._simple_groups()
        mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0)
        node = Node(partial_by_group=((),), stopped_by_group=(False,), group_idx=0)
        path = [node]
        mcts._backpropagate(path, reward=3.0, is_novel=True)
        assert node.n_obs == 1
        assert node.sum_rewards == 3.0
        assert node.sum_sq_rewards == pytest.approx(9.0)
        assert node.n_visits == 1

    def test_cache_hit_variance_inflation(self):
        # variance_inflation: a cache hit decays n_obs by variance_decay and
        # rescales sums to keep the node's mean reward unchanged.
        gs = self._simple_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=lambda f, c: 0.0,
            cache_hit_mode="variance_inflation",
            variance_decay=0.5,
            seed=0,
        )
        node = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
            n_obs=10,
            sum_rewards=50.0,
            sum_sq_rewards=300.0,
            n_visits=10,
        )
        path = [node]
        mcts._backpropagate(path, reward=5.0, is_novel=False)
        # n_visits incremented
        assert node.n_visits == 11
        # n_obs decayed: int(10 * 0.5) = 5
        assert node.n_obs == 5
        # sum_rewards rescaled: mean=5.0, new_sum=5.0*5=25.0
        assert node.sum_rewards == pytest.approx(25.0)

    def test_cache_hit_pessimistic(self):
        # pessimistic: a cache hit is recorded as an observation of the
        # pessimistic value (global mean minus one std) instead of the reward.
        gs = self._simple_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=lambda f, c: 0.0,
            cache_hit_mode="pessimistic",
            ts_prior_var=1.0,
            seed=0,
        )
        # Set some novel reward stats so pessimistic_value is defined
        mcts._novel_reward_count = 10
        mcts._novel_reward_sum = 100.0
        mcts._novel_reward_sq_sum = 1010.0  # var=1, std=1
        pess = mcts._pessimistic_value()  # 10 - 1 = 9

        node = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
            n_obs=5,
            sum_rewards=50.0,
            sum_sq_rewards=510.0,
            n_visits=5,
        )
        path = [node]
        mcts._backpropagate(path, reward=10.0, is_novel=False)
        assert node.n_visits == 6
        assert node.n_obs == 6
        assert node.sum_rewards == pytest.approx(50.0 + pess)
        assert node.sum_sq_rewards == pytest.approx(510.0 + pess * pess)

    def test_cache_hit_no_update_leaves_stats(self):
        # no_update: a cache hit touches only n_visits.
        gs = self._simple_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=lambda f, c: 0.0,
            cache_hit_mode="no_update",
            seed=0,
        )
        node = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
            n_obs=5,
            sum_rewards=50.0,
            sum_sq_rewards=300.0,
            n_visits=5,
        )
        path = [node]
        mcts._backpropagate(path, reward=10.0, is_novel=False)
        assert node.n_visits == 6
        assert node.n_obs == 5  # unchanged
        assert node.sum_rewards == 50.0  # unchanged

    def test_cache_hit_combined(self):
        """Combined mode does both variance inflation and pessimistic."""
        gs = self._simple_groups()
        mcts = MCTS(
            groups=gs,
            reward_fn=lambda f, c: 0.0,
            cache_hit_mode="combined",
            variance_decay=0.5,
            ts_prior_var=1.0,
            seed=0,
        )
        mcts._novel_reward_count = 10
        mcts._novel_reward_sum = 100.0
        mcts._novel_reward_sq_sum = 1010.0  # var=1, pessimistic_value = 10 - 1 = 9

        node = Node(
            partial_by_group=((),),
            stopped_by_group=(False,),
            group_idx=0,
            n_obs=10,
            sum_rewards=100.0,
            sum_sq_rewards=1010.0,
            n_visits=10,
        )
        path = [node]
        mcts._backpropagate(path, reward=10.0, is_novel=False)
        assert node.n_visits == 11
        # Variance inflation: n_obs 10 -> 5, then pessimistic adds 1 -> 6
        assert node.n_obs == 6
pessimistic_value = 10 - 1 = 9 + + node = Node( + partial_by_group=((),), + stopped_by_group=(False,), + group_idx=0, + n_obs=10, + sum_rewards=100.0, + sum_sq_rewards=1010.0, + n_visits=10, + ) + path = [node] + mcts._backpropagate(path, reward=10.0, is_novel=False) + assert node.n_visits == 11 + # Variance inflation: n_obs 10 -> 5, then pessimistic adds 1 -> 6 + assert node.n_obs == 6 + + +# ============================================================================= +# NIG Rollout tests +# ============================================================================= + + +class TestNIGRollout: + """Tests for NIG Thompson Sampling rollout.""" + + @staticmethod + def _nck_groups() -> Groups: + return Groups( + groups=[NChooseK(features=[0, 1, 2, 3, 4], min_count=1, max_count=3)] + ) + + @staticmethod + def _mixed_groups() -> Groups: + nck = NChooseK(features=[0, 1, 2], min_count=1, max_count=2) + cat = Categorical(dim=3, values=[0.0, 1.0]) + return Groups(groups=[nck, cat]) + + def test_ts_rollout_returns_legal_action(self): + gs = self._nck_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="ts_group_action", + seed=42, + ) + legal = [0, 1, 2, STOP] + for _ in range(20): + action = mcts._ts_sample_rollout_action(0, legal) + assert action in legal + + def test_ts_rollout_biases_toward_high_reward(self): + """After seeding TS stats, rollout should favor high-reward actions.""" + gs = Groups(groups=[NChooseK(features=[0, 1, 2], min_count=1, max_count=1)]) + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="ts_group_action", + seed=None, + ) + # Make action 2 strongly preferred (many observations with high reward) + mcts.rollout_ts_stats[(0, 2)] = ActionStats(100, 1000.0, 10000.0) # mean=10 + mcts.rollout_ts_stats[(0, 0)] = ActionStats(100, 100.0, 100.0) # mean=1 + mcts.rollout_ts_stats[(0, 1)] = ActionStats(100, 100.0, 100.0) # mean=1 + # Set global stats so NIG prior is reasonable + mcts._novel_reward_count = 300 
+ mcts._novel_reward_sum = 1200.0 + mcts._novel_reward_sq_sum = 10200.0 + + counts = {0: 0, 1: 0, 2: 0} + for _ in range(500): + feats, _cats, _traj = mcts._rollout(mcts.root) + for f in feats: + counts[f] += 1 + assert counts[2] > counts[0] + assert counts[2] > counts[1] + + def test_update_rollout_ts_stats(self): + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + trajectory = [TrajectoryStep(0, 1), TrajectoryStep(0, 2)] + mcts._update_rollout_ts_stats(trajectory, 5.0) + assert mcts.rollout_ts_stats[(0, 1)] == ActionStats(1, 5.0, 25.0) + assert mcts.rollout_ts_stats[(0, 2)] == ActionStats(1, 5.0, 25.0) + + def test_update_rollout_ts_stats_accumulates(self): + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + mcts._update_rollout_ts_stats([TrajectoryStep(0, 1)], 3.0) + mcts._update_rollout_ts_stats([TrajectoryStep(0, 1)], 7.0) + stats = mcts.rollout_ts_stats[(0, 1)] + assert stats.n_obs == 2 + assert stats.sum_rewards == 10.0 + assert stats.sum_sq_rewards == pytest.approx(9.0 + 49.0) + + def test_update_rollout_ts_stats_empty_trajectory(self): + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + mcts._update_rollout_ts_stats([], 5.0) + assert mcts.rollout_ts_stats == {} + + def test_rollout_returns_three_tuple(self): + gs = self._mixed_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=42) + result = mcts._rollout(mcts.root) + assert len(result) == 3 + feats, cats, trajectory = result + assert isinstance(feats, tuple) + assert isinstance(cats, dict) + assert isinstance(trajectory, list) + + def test_rollout_trajectory_records_actions(self): + gs = self._nck_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="ts_group_action", + seed=42, + ) + _feats, _cats, trajectory = mcts._rollout(mcts.root) + assert len(trajectory) > 0 + for g, a in trajectory: + assert isinstance(g, int) + assert isinstance(a, int) + + 
def test_uniform_rollout_records_trajectory(self): + """Trajectory collected in uniform rollout mode too.""" + gs = self._nck_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="uniform", + seed=42, + ) + _feats, _cats, trajectory = mcts._rollout(mcts.root) + assert len(trajectory) > 0 + + def test_rollout_ts_stats_populated_after_run(self): + """rollout_ts_stats is non-empty after run() with ts_group_action mode.""" + gs = self._nck_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: float(len(f)), + rollout_mode="ts_group_action", + seed=0, + ) + mcts.run(n_iterations=50) + assert len(mcts.rollout_ts_stats) > 0 + + def test_run_with_uniform_rollout(self): + """Verify MCTS works with uniform rollout mode.""" + gs = Groups(groups=[NChooseK(features=[0, 1, 2], min_count=1, max_count=2)]) + + def reward_fn(feats, _cats): + return float(len(feats)) * 10.0 + + mcts = MCTS( + groups=gs, + reward_fn=reward_fn, + rollout_mode="uniform", + seed=42, + ) + _f, _c, val = mcts.run(n_iterations=50) + assert val == 20.0 + + def test_run_with_ts_rollout_converges(self): + """Convergence on needle problem with TS rollout.""" + g = NChooseK(features=list(range(10)), min_count=2, max_count=3) + gs = Groups(groups=[g]) + target = {3, 7} + + def reward_fn(feats, _cats): + feat_set = set(feats) + if feat_set == target: + return 100.0 + overlap = len(feat_set & target) + extras = len(feat_set - target) + return overlap * 20.0 - extras * 5.0 + + mcts = MCTS( + groups=gs, + reward_fn=reward_fn, + rollout_mode="ts_group_action", + seed=42, + ) + _feats, _cats, val = mcts.run(n_iterations=300) + assert val == 100.0 + + def test_rollout_mode_default_is_ts_group_action(self): + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=0) + assert mcts.rollout_mode == "ts_group_action" + + # ---- use_cache=False (no-cache mode) ---- + + def test_no_cache_default_is_true(self): + """Default use_cache=True preserves existing 
behavior.""" + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=lambda f, c: 0.0, seed=42) + assert mcts.use_cache is True + + def test_no_cache_always_novel(self): + """With use_cache=False, every evaluation is novel (no cache hits).""" + call_count = 0 + + def counting_reward(feats, cats): + nonlocal call_count + call_count += 1 + return sum(feats) * 1.0 + + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=counting_reward, use_cache=False, seed=42) + mcts.run(n_iterations=20) + stats = mcts.cache_stats() + assert stats["hits"] == 0 # no cache hits ever + assert stats["size"] == 0 # cache stays empty + assert stats["misses"] == call_count # every call tracked as miss + assert call_count == 20 # reward_fn called every iteration + + def test_no_cache_noisy_rewards(self): + """No-cache mode accumulates noisy observations correctly.""" + rng = stdlib_random.Random(42) + + def noisy_reward(feats, _cats): + base = sum(feats) * 10.0 + return base + rng.gauss(0, 1.0) + + gs = self._nck_groups() + mcts = MCTS(groups=gs, reward_fn=noisy_reward, use_cache=False, seed=42) + mcts.run(n_iterations=50) + # Should still find reasonable best value + assert mcts.best_value > 0 + # Novel count should equal total iterations + assert mcts._novel_reward_count == 50 + + def test_cache_true_has_hits(self): + """With use_cache=True (default), the cache is populated and hit.""" + gs = self._nck_groups() + mcts = MCTS( + groups=gs, + reward_fn=lambda f, c: sum(f) * 1.0, + seed=42, + ) + mcts.run(n_iterations=30) + stats = mcts.cache_stats() + assert stats["size"] > 0 # cache populated + # With 5-choose-1-to-3 and 30 iterations, should have hits + assert stats["hits"] > 0 + + def test_no_cache_reward_fn_called_every_time(self): + """With use_cache=False, reward_fn is called on every iteration even for + the same selection.""" + calls = [] + + def tracking_reward(feats, cats): + calls.append((feats, cats)) + return 1.0 + + # Deterministic group with only one possible 
selection + gs = Groups(groups=[NChooseK(features=[0], min_count=1, max_count=1)]) + mcts = MCTS(groups=gs, reward_fn=tracking_reward, use_cache=False, seed=0) + mcts.run(n_iterations=10) + # Every iteration should call reward_fn even though same selection + assert len(calls) == 10 + + +# ============================================================================= +# MCTS convergence / integration tests +# ============================================================================= + + +class TestMCTSConvergence: + """Integration tests verifying that MCTS converges to known optima.""" + + def test_nchoosek_single_optimum(self): + """5-choose-2 problem where exactly one pair is optimal.""" + nck = NChooseK(features=[0, 1, 2, 3, 4], min_count=2, max_count=2) + gs = Groups(groups=[nck]) + + def reward_fn(feats, _cats): + return 100.0 if set(feats) == {1, 3} else 0.0 + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=0) + feats, cats, val = mcts.run(n_iterations=200) + assert val == 100.0 + assert set(feats) == {1, 3} + assert cats == {} + + def test_categorical_single_optimum(self): + """Two categorical dimensions; only one combination is optimal.""" + cat1 = Categorical(dim=0, values=[0.0, 1.0, 2.0]) + cat2 = Categorical(dim=1, values=[10.0, 20.0]) + gs = Groups(groups=[cat1, cat2]) + + def reward_fn(_feats, cats): + if cats.get(0) == 2.0 and cats.get(1) == 10.0: + return 50.0 + return 1.0 + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=42) + feats, cats, val = mcts.run(n_iterations=100) + assert val == 50.0 + assert cats == {0: 2.0, 1: 10.0} + assert feats == () + + def test_mixed_nchoosek_and_categorical_optimum(self): + """Joint NChooseK + Categorical problem with a single optimum.""" + nck = NChooseK(features=[0, 1, 2, 3], min_count=1, max_count=2) + cat = Categorical(dim=5, values=[0.0, 1.0, 2.0]) + gs = Groups(groups=[nck, cat]) + + def reward_fn(feats, cats): + if set(feats) == {0, 3} and cats.get(5) == 2.0: + return 100.0 + return 1.0 + + mcts = 
MCTS(groups=gs, reward_fn=reward_fn, seed=7) + feats, cats, val = mcts.run(n_iterations=300) + assert val == 100.0 + assert set(feats) == {0, 3} + assert cats == {5: 2.0} + + def test_best_value_improves_over_iterations(self): + """Run MCTS in stages and verify the best value never decreases.""" + nck = NChooseK(features=[0, 1, 2, 3, 4], min_count=1, max_count=3) + gs = Groups(groups=[nck]) + + def reward_fn(feats, _cats): + if set(feats) == {1, 2, 4}: + return 100.0 + return float(len(feats)) + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=42) + + best_values = [] + for _ in range(5): + mcts.run(n_iterations=40) + best_values.append(mcts.best_value) + + for i in range(1, len(best_values)): + assert best_values[i] >= best_values[i - 1] + assert best_values[-1] == 100.0 + + def test_variable_count_optimum(self): + """Optimum requires a specific count of features (not max).""" + nck = NChooseK(features=[0, 1, 2, 3, 4, 5], min_count=1, max_count=4) + gs = Groups(groups=[nck]) + + def reward_fn(feats, _cats): + if set(feats) == {2, 5}: + return 80.0 + return float(len(feats)) + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=42) + feats, _cats, val = mcts.run(n_iterations=300) + assert val == 80.0 + assert set(feats) == {2, 5} + + def test_large_categorical_space(self): + """Three categoricals with 4 values each (64 combinations).""" + cat1 = Categorical(dim=0, values=[0.0, 1.0, 2.0, 3.0]) + cat2 = Categorical(dim=1, values=[0.0, 1.0, 2.0, 3.0]) + cat3 = Categorical(dim=2, values=[0.0, 1.0, 2.0, 3.0]) + gs = Groups(groups=[cat1, cat2, cat3]) + + target = {0: 3.0, 1: 1.0, 2: 2.0} + + def reward_fn(_feats, cats): + if cats == target: + return 100.0 + return sum(10.0 for d in target if cats.get(d) == target[d]) + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=99) + feats, cats, val = mcts.run(n_iterations=200) + assert val == 100.0 + assert cats == target + assert feats == () + + def test_multiple_nchoosek_groups(self): + """Two independent NChooseK groups, 
each with its own optimum.""" + nck1 = NChooseK(features=[0, 1, 2], min_count=1, max_count=1) + nck2 = NChooseK(features=[10, 11, 12], min_count=1, max_count=1) + gs = Groups(groups=[nck1, nck2]) + + def reward_fn(feats, _cats): + if set(feats) == {2, 11}: + return 100.0 + score = 0.0 + if 2 in feats: + score += 40.0 + if 11 in feats: + score += 40.0 + return score + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=0) + feats, _cats, val = mcts.run(n_iterations=200) + assert val == 100.0 + assert set(feats) == {2, 11} + + def test_convergence_with_noisy_reward(self): + """Reward has a noise floor but the optimum still dominates.""" + rng = stdlib_random.Random(12345) + nck = NChooseK(features=[0, 1, 2], min_count=1, max_count=1) + gs = Groups(groups=[nck]) + + def reward_fn(feats, _cats): + noise = rng.uniform(-1.0, 1.0) + if 2 in feats: + return 50.0 + noise + if 1 in feats: + return 20.0 + noise + return 5.0 + noise + + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=42) + feats, _cats, val = mcts.run(n_iterations=150) + assert 2 in feats + assert val >= 49.0 + + def test_mcts_outperforms_random_rollouts(self): + """MCTS with NIG-TS should achieve higher average best reward + than pure random rollouts on a structured problem.""" + nck = NChooseK(features=list(range(10)), min_count=1, max_count=4) + gs = Groups(groups=[nck]) + optimal = {2, 7} + budget = 60 + + def reward_fn(feats, _cats): + feat_set = set(feats) + if feat_set == optimal: + return 100.0 + overlap = len(feat_set & optimal) + extras = len(feat_set - optimal) + return float(overlap * 30 - extras * 10) + + def random_search(seed: int) -> float: + mcts_tmp = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="uniform", + seed=seed, + ) + rng = stdlib_random.Random(seed) + best = float("-inf") + for _ in range(budget): + mcts_tmp.rng = rng + feats, cats, _traj = mcts_tmp._rollout(mcts_tmp.root) + val = reward_fn(feats, cats) + if val > best: + best = val + return best + + n_trials = 20 + 
mcts_best_vals = [] + random_best_vals = [] + + for trial_seed in range(n_trials): + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=trial_seed) + _feats, _cats, mcts_val = mcts.run(n_iterations=budget) + mcts_best_vals.append(mcts_val) + + random_val = random_search(seed=trial_seed + 1000) + random_best_vals.append(random_val) + + mcts_mean = sum(mcts_best_vals) / n_trials + random_mean = sum(random_best_vals) / n_trials + + assert mcts_mean > random_mean, ( + f"MCTS mean best {mcts_mean:.1f} should beat " + f"random mean best {random_mean:.1f}" + ) + + def test_mcts_outperforms_random_mixed_problem(self): + """Mixed NChooseK + Categorical problem where MCTS should beat random.""" + nck1 = NChooseK(features=[0, 1, 2, 3], min_count=1, max_count=2) + nck2 = NChooseK(features=[10, 11, 12, 13], min_count=1, max_count=2) + cat1 = Categorical(dim=20, values=[0.0, 1.0, 2.0, 3.0]) + cat2 = Categorical(dim=21, values=[0.0, 1.0, 2.0]) + gs = Groups(groups=[nck1, nck2, cat1, cat2]) + + target_feats = {1, 3, 11} + target_cats = {20: 2.0, 21: 1.0} + budget = 60 + + def reward_fn(feats, cats): + feat_set = set(feats) + score = sum(10.0 for f in target_feats if f in feat_set) + score += sum(15.0 for d, v in target_cats.items() if cats.get(d) == v) + score -= 5.0 * len(feat_set - target_feats) + if feat_set == target_feats and cats == target_cats: + score = 100.0 + return score + + n_trials = 15 + mcts_best_vals = [] + random_best_vals = [] + + for trial_seed in range(n_trials): + mcts = MCTS(groups=gs, reward_fn=reward_fn, seed=trial_seed) + _f, _c, mcts_val = mcts.run(n_iterations=budget) + mcts_best_vals.append(mcts_val) + + mcts_tmp = MCTS( + groups=gs, + reward_fn=lambda f, c: 0.0, + rollout_mode="uniform", + seed=trial_seed + 500, + ) + rng = stdlib_random.Random(trial_seed + 500) + best_rand = float("-inf") + for _ in range(budget): + mcts_tmp.rng = rng + feats, cats, _traj = mcts_tmp._rollout(mcts_tmp.root) + val = reward_fn(feats, cats) + if val > best_rand: + best_rand 
= val + random_best_vals.append(best_rand) + + mcts_mean = sum(mcts_best_vals) / n_trials + random_mean = sum(random_best_vals) / n_trials + + assert mcts_mean > random_mean, ( + f"MCTS mean best {mcts_mean:.1f} should beat " + f"random mean best {random_mean:.1f}" + ) + + +# ============================================================================= +# optimize_acqf_mcts tests +# ============================================================================= + + +def _make_mock_optimize_acqf(d: int, q: int = 1): + """Create a mock for botorch.optim.optimize_acqf.""" + call_log: list[dict] = [] + + def mock_fn(**kwargs): + call_log.append(kwargs) + fixed = kwargs.get("fixed_features") or {} + cand = torch.rand(q, d) + for dim, val in fixed.items(): + cand[:, dim] = val + acq_val = cand.sum() + return cand, acq_val + + return mock_fn, call_log + + +class TestOptimizeAcqfMcts: + """Tests for the top-level optimize_acqf_mcts function.""" + + # ---- output shape and type ---- + + def test_returns_correct_shape_nchoosek(self, monkeypatch): + d = 5 + mock_fn, _ = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + num_iterations=10, + seed=0, + ) + assert candidates.shape == (1, d) + assert isinstance(acq_val, float) + + def test_returns_correct_shape_q_greater_than_one(self, monkeypatch): + d, q = 4, 3 + mock_fn, _ = _make_mock_optimize_acqf(d, q=q) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + q=q, + num_iterations=10, + seed=0, + ) + assert candidates.shape == (q, d) + + # ---- zero iterations fallback ---- + + def test_zero_iterations_returns_zeros(self, 
monkeypatch): + d = 4 + call_log: list[dict] = [] + + def mock_fn(**kwargs): + call_log.append(kwargs) + return torch.zeros(1, d), torch.tensor(0.0) + + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + num_iterations=0, + seed=0, + ) + assert len(call_log) == 0 + assert candidates.shape == (1, d) + assert torch.all(candidates == 0) + assert acq_val == float("-inf") + + # ---- NChooseK: inactive features fixed to zero ---- + + def test_inactive_features_fixed_to_zero(self, monkeypatch): + d = 5 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + nchoosek_features = [0, 1, 2, 3] + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[(nchoosek_features, 1, 2)], + num_iterations=15, + seed=0, + ) + + assert len(call_log) > 0 + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + for f in nchoosek_features: + if f in fixed: + assert fixed[f] == 0.0 + + def test_active_features_not_fixed(self, monkeypatch): + d = 4 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + nchoosek_features = [0, 1, 2, 3] + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[(nchoosek_features, 1, 3)], + num_iterations=20, + seed=0, + ) + + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + n_fixed_nck = sum(1 for f in nchoosek_features if f in fixed) + n_free = len(nchoosek_features) - n_fixed_nck + assert n_free >= 1 + + # ---- Categorical: dims fixed to selected value ---- + + def test_categorical_dims_fixed_to_allowed_value(self, monkeypatch): + d = 5 + cat_dims = {3: [0.0, 1.0, 2.0], 
4: [10.0, 20.0]} + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + cat_dims=cat_dims, + num_iterations=15, + seed=0, + ) + + assert len(call_log) > 0 + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + for dim, allowed in cat_dims.items(): + assert dim in fixed, f"Categorical dim {dim} not in fixed_features" + assert fixed[dim] in allowed, ( + f"Categorical dim {dim} fixed to {fixed[dim]}, " + f"expected one of {allowed}" + ) + + # ---- Mixed NChooseK + Categorical ---- + + def test_mixed_nchoosek_and_categorical(self, monkeypatch): + d = 6 + nchoosek_features = [0, 1, 2] + cat_dims = {4: [0.0, 1.0], 5: [10.0, 20.0, 30.0]} + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[(nchoosek_features, 1, 2)], + cat_dims=cat_dims, + num_iterations=15, + seed=0, + ) + + assert candidates.shape == (1, d) + assert len(call_log) > 0 + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + assert 4 in fixed + assert 5 in fixed + for f in nchoosek_features: + if f in fixed: + assert fixed[f] == 0.0 + + # ---- User-provided fixed_features are forwarded ---- + + def test_user_fixed_features_forwarded(self, monkeypatch): + d = 5 + user_fixed = {4: 99.0} + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + fixed_features=user_fixed, + num_iterations=10, + seed=0, + ) + + for call_kwargs in call_log: + fixed = 
call_kwargs["fixed_features"] + assert fixed[4] == 99.0 + + # ---- Constraints are forwarded ---- + + def test_constraints_forwarded(self, monkeypatch): + d = 4 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + ineq = [(torch.tensor([0]), torch.tensor([1.0]), 0.5)] + eq = [(torch.tensor([1]), torch.tensor([1.0]), 1.0)] + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + inequality_constraints=ineq, + equality_constraints=eq, + num_iterations=5, + seed=0, + ) + + assert len(call_log) > 0 + for call_kwargs in call_log: + assert call_kwargs["inequality_constraints"] is ineq + assert call_kwargs["equality_constraints"] is eq + + # ---- BoTorch params forwarded ---- + + def test_botorch_params_forwarded(self, monkeypatch): + d = 4 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + optimize_acqf_mcts( + acq_function="sentinel_acqf", + bounds=bounds, + nchooseks=[([0, 1], 1, 1)], + q=2, + raw_samples=512, + num_restarts=8, + num_iterations=5, + seed=0, + ) + + for call_kwargs in call_log: + assert call_kwargs["acq_function"] == "sentinel_acqf" + assert call_kwargs["q"] == 2 + assert call_kwargs["raw_samples"] == 512 + assert call_kwargs["num_restarts"] == 8 + + # ---- Bounds forwarded correctly ---- + + def test_bounds_forwarded(self, monkeypatch): + d = 3 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.tensor([[0.0, -1.0, 0.0], [1.0, 1.0, 5.0]]) + + optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1], 1, 1)], + num_iterations=5, + seed=0, + ) + + for call_kwargs in call_log: + assert torch.equal(call_kwargs["bounds"], bounds) + + # ---- Multiple NChooseK groups ---- + + 
def test_multiple_nchoosek_groups(self, monkeypatch): + d = 6 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + nchooseks = [ + ([0, 1, 2], 1, 2), + ([3, 4, 5], 1, 1), + ] + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=nchooseks, + num_iterations=15, + seed=0, + ) + + assert candidates.shape == (1, d) + all_nck_features = {0, 1, 2, 3, 4, 5} + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + for f in all_nck_features: + if f in fixed: + assert fixed[f] == 0.0 + n_free_g2 = sum(1 for f in [3, 4, 5] if f not in fixed) + assert n_free_g2 == 1 + + # ---- No constraints at all (only bounds) ---- + + def test_no_nchoosek_no_categoricals(self, monkeypatch): + d = 3 + mock_fn, call_log = _make_mock_optimize_acqf(d) + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + num_iterations=5, + seed=0, + ) + + assert candidates.shape == (1, d) + assert len(call_log) > 0 + for call_kwargs in call_log: + fixed = call_kwargs["fixed_features"] + assert fixed == {} + + # ---- Best candidate is from the call with highest acq value ---- + + def test_best_candidate_tracks_highest_acq_value(self, monkeypatch): + d = 3 + call_count = 0 + best_cand = torch.tensor([[0.1, 0.2, 0.3]]) + + def mock_fn(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 3: + return best_cand, torch.tensor(999.0) + return torch.rand(1, d), torch.tensor(1.0) + + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", mock_fn) + bounds = torch.stack([torch.zeros(d), torch.ones(d)]) + + candidates, acq_val = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1, 2], 1, 2)], + num_iterations=10, + seed=0, + ) + + assert acq_val == 
999.0 + assert torch.equal(candidates, best_cand) + + # ---- dtype preserved ---- + + def test_dtype_preserved_on_zero_iterations(self, monkeypatch): + d = 3 + + def noop(**kwargs): + return torch.zeros(1, d), torch.tensor(0.0) + + monkeypatch.setattr(optimize_mcts_mod, "optimize_acqf", noop) + bounds = torch.stack( + [torch.zeros(d, dtype=torch.float64), torch.ones(d, dtype=torch.float64)] + ) + + candidates, _ = optimize_acqf_mcts( + acq_function=None, + bounds=bounds, + nchooseks=[([0, 1], 1, 1)], + num_iterations=0, + ) + assert candidates.dtype == torch.float64 diff --git a/tests/bofire/strategies/test_random.py b/tests/bofire/strategies/test_random.py index 1a02f5911..aa555afc7 100644 --- a/tests/bofire/strategies/test_random.py +++ b/tests/bofire/strategies/test_random.py @@ -213,8 +213,8 @@ def test_nchoosek(): If7 = ContinuousInput(bounds=(1, 1), key="If7") c2 = LinearInequalityConstraint.from_greater_equal( - features=["if1", "if2"], - coefficients=[1.0, 1.0], + features=["if1", "if2", "if3"], + coefficients=[1.0, 1.0, 1.0], rhs=0.2, ) @@ -225,8 +225,8 @@ def test_nchoosek(): none_also_valid=False, ) c7 = LinearEqualityConstraint( - features=["if1", "if2"], - coefficients=[1.0, 1.0], + features=["if1", "if2", "if3"], + coefficients=[1.0, 1.0, 1.0], rhs=1.0, ) domain = Domain.from_lists( @@ -265,6 +265,47 @@ def test_sample_from_polytope(): assert_frame_equal(samples2, samples) +def test_allow_zero_without_nchoosek(): + """Test random sampling with allow_zero features but no NChooseK constraint.""" + if1 = ContinuousInput(bounds=(0.1, 1), key="if1", allow_zero=True) + if2 = ContinuousInput(bounds=(0.1, 1), key="if2", allow_zero=True) + if3 = ContinuousInput(bounds=(0.1, 1), key="if3") + domain = Domain.from_lists(inputs=[if1, if2, if3]) + data_model = data_models.RandomStrategy(domain=domain) + sampler = strategies.RandomStrategy(data_model=data_model) + samples = sampler.ask(50) + assert len(samples) == 50 + # if3 should never be zero (not allow_zero) + 
assert (samples["if3"] != 0.0).all() + # if1 and if2 should have some zeros (allow_zero) + assert (samples["if1"] == 0.0).any() or (samples["if2"] == 0.0).any() + + +def test_allow_zero_with_nchoosek(): + """Test that allow_zero features already in NChooseK don't get duplicate groups.""" + if1 = ContinuousInput(bounds=(0, 1), key="if1") + if2 = ContinuousInput(bounds=(0, 1), key="if2") + if3 = ContinuousInput(bounds=(0, 1), key="if3") + if4 = ContinuousInput(bounds=(0.1, 1), key="if4", allow_zero=True) + c = NChooseKConstraint( + features=["if1", "if2", "if3"], + min_count=1, + max_count=2, + none_also_valid=False, + ) + domain = Domain.from_lists(inputs=[if1, if2, if3, if4], constraints=[c]) + data_model = data_models.RandomStrategy(domain=domain) + sampler = strategies.RandomStrategy(data_model=data_model) + samples = sampler.ask(50) + assert len(samples) == 50 + # At most 2 features should be non-zero per sample (from NChooseK) + nonzero_counts = (samples[["if1", "if2", "if3"]] != 0.0).sum(axis=1) + assert (nonzero_counts >= 1).all() + assert (nonzero_counts <= 2).all() + # if4 (allow_zero, not in NChooseK) should have some zeros + assert (samples["if4"] == 0.0).any() + + @pytest.mark.parametrize( "method,kwargs,n_samples", [