MCTS based ACQF Optimization for Handling of NChooseKs.#693
MCTS based ACQF Optimization for Handling of NChooseKs.#693
Conversation
…onstraints Add full implementation of Monte Carlo Tree Search for optimizing acquisition functions with NChooseK constraints. The MCTS selects which features are active, then runs BoTorch optimization with inactive features fixed to zero. Features: - UCT selection with RAVE action value estimation - Progressive widening for controlled tree growth - Canonical ordering to ensure feasibility (no dead-end states) - Value caching for terminal evaluations - Support for non-contiguous feature indices in NChooseK constraints Co-Authored-By: Claude Opus 4.5 <[email protected]>
Extend the MCTS implementation to handle both NChooseK constraints and categorical dimensions. The MCTS now explores: 1. Which features are active (NChooseK) - with min/max count flexibility 2. Which categorical values to use - exactly one per categorical dimension Key changes: - Add Group abstract base class for unified MCTS group handling - Add CategoricalGroup for categorical dimensions (select exactly one value) - Update Constraints to hold both nchooseks and categoricals - Extend MCTS to process NChooseK groups first, then categorical groups - cat_dims signature matches botorch.optim.optimize_acqf_mixed_alternating - Categorical values are fixed in BoTorch optimization alongside inactive features Co-Authored-By: Claude Opus 4.5 <[email protected]>
|
@TobyBoyne: I adjusted the random sampler for also using the MCTS Tree just using random rollouts. In principle the conditional input features should also work there now. But I need to test it. I will also write tmr a bit more about my thoughts, because, I think we can really do super crazy stuff with this new formulation. |
…eeds proper review
Learn per-group stop probability from cardinality-reward statistics instead of using a single fixed p_stop_rollout for all groups. Tracks (group_idx, cardinality) -> (visits, total_reward), computes sigmoid on normalized (E_stop - E_continue), and blends with the fixed prior during a warmup period. Enabled by default (adaptive_p_stop=True). Benchmark shows "no RAVE + adaptive p_stop" achieves 100% optimum rate on needle_in_haystack (up from 97%) and best results on large_sparse (93.0 mean vs 83.8). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Benchmark script tests 13 MCTS configurations across 5 combinatorial NChooseK problems. Report documents findings: RAVE is harmful, adaptive p_stop + no RAVE is the best overall configuration (100% optimum rate on needle_in_haystack, best on large_sparse). Figures are not tracked — regenerate with: python mcts-report/benchmark.py Co-Authored-By: Claude Opus 4.6 <[email protected]>
Implement min-max reward normalization to [0,1] before backpropagation, making c_uct scale-independent across problems with different reward ranges. Combined with no-RAVE + adaptive p_stop, this is the new best config: 63% optimum rate on mixed (+20pp), 40% on large_sparse (+13pp). - Add normalize_rewards parameter (default False) and _normalize_reward() - Move reward range tracking from _update_cardinality_stats to run() - Add TestRewardNormalization with 9 unit tests - Add benchmark configs with c_uct=0.01 for normalized rewards - Update report with Section 4.6 analysis and new recommendations Co-Authored-By: Claude Opus 4.6 <[email protected]>
Replace uniform-random rollouts with a learned softmax policy that blends per-(group, action) statistics with uniform exploration. STOP is treated as a regular scored action, unifying the rollout decision-making. Default rollout_policy=True with ε=0.3, τ=1.0. Benchmark results show +14pp on mixed_nchoosek_categorical (63%→77%) and +10pp on large_sparse (40%→50%) over the previous best config, with no regressions on other problems. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Replace global RAVE's context-independent action value estimates with context-dependent statistics keyed by (group_idx, cardinality, action). This allows RAVE to learn that a feature's value depends on how many features are already selected, fixing the fundamental issue that made global RAVE harmful on combinatorial problems. Benchmarks show context RAVE k=300 achieves 80% optimum rate on mixed problems (vs 77% for the rollout policy baseline), and k=100 matches the baseline on all other problems while using fewer evaluations. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add the simplest possible NChooseK problem: 12 features pick 1-4 where each feature contributes a fixed positive value with no interactions. Confirms MCTS handles this basic case reliably — best configs achieve 83% optimum rate (vs 0% for random). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Change defaults from old settings (c_uct=1.0, k_rave=300, no adaptive p_stop, no normalization, no rollout policy) to the benchmarked best: c_uct=0.01, k_rave=0, adaptive_p_stop=True, normalize_rewards=True, rollout_policy=True. Also forward all MCTS parameters through both optimize_acqf_mcts() and the _OptimizeAcqfMctsInput Pydantic model. Co-Authored-By: Claude Opus 4.6 <[email protected]>
_rollout now returns 3 values (feats, cats, trajectory) after the context RAVE change. Update the caller in RandomStrategy to unpack all three. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Explicitly disable rollout_policy, adaptive_p_stop, and normalize_rewards when constructing the MCTS instance used for random combination sampling. These features are irrelevant for pure random rollouts with a dummy reward function. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Co-Authored-By: Claude Opus 4.6 <[email protected]>
jduerholt
left a comment
There was a problem hiding this comment.
Ok, this is a first sweep.
- Remove large comment block of improvement ideas from optimize_mcts.py - Add data structure example comment for best_selection - Clarify adaptive p_stop warmup docstring (blends, not falls back) - Document normalization edge case for first rollouts (returns 0.5) - Add UCB exploration rationale to _score_rollout_actions docstring - Replace manual softmax with torch.softmax and manual sampling with rng.choices in _sample_rollout_action - Simplify _apply_action partials append to one-liner - Fix lambda formatting in random.py - Fix allow_zero deduplication: skip features already in NChooseK constraints (in both random.py and acqf_optimization.py) - Fix allow_zero features getting Pydantic validation error when bounds are set to [0, 0] by clearing allow_zero flag first - Add tests for allow_zero with and without NChooseK constraints - Add Hartmann acqf test script for MCTS vs exhaustive enumeration Co-Authored-By: Claude Opus 4.6 <[email protected]>
Benchmarking showed RAVE (both global and context-aware) adds complexity without improving performance — best configs all use k_rave=0. Strip all RAVE code from the production module and simplify trajectories from 3-tuples to 2-tuples. A full copy with RAVE intact is kept in mcts-report/ for paper reproducibility. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Introduce ActionStats, TrajectoryStep, and Selection NamedTuples to replace raw tuple types for stats dicts, rollout trajectories, and terminal selections. This improves readability (stats.visits instead of stats[0]) while preserving full backward compatibility since NamedTuples support unpacking and equality like regular tuples. Co-Authored-By: Claude Opus 4.6 <[email protected]>
… burn-in Document three proposed improvements to the MCTS algorithm: - Section 8.3: Thompson Sampling for tree selection, eliminating c_uct and reward normalization via Bayesian posteriors - Section 8.4: Thompson Sampling for rollouts, replacing the softmax policy and subsuming dead adaptive p_stop code (3 fewer hyperparams) - Section 8.5: Two-phase burn-in with cheap random-sample evaluations during early iterations, enabled by TS handling heteroscedastic observations naturally Combined, these proposals would reduce 9 tunable hyperparameters to 0. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Implement MCTS_TS with Normal-Normal conjugate posteriors for tree selection and rollout, replacing UCT and softmax policies. Benchmark 6 TS variants against best UCT configs on 6 problems (30 trials each). Key finding: TS+TS(g,a)+var_infl doubles UCT on multigroup_interaction (47% vs 23%) but UCT remains superior on large search spaces (50% vs 20% on large_sparse). Variance inflation is essential for TS viability. Report updated with full results tables, convergence analysis, prediction validation against §8.3, and 7 proposed improvements including pessimistic pseudo-observations, adaptive prior variance, and two-phase burn-in. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Implement adaptive prior variance for MCTS_TS (auto-calibrates σ₀² from running empirical variance), benchmark 3 new configs, and document results (+7pp simple_additive, +13pp large_sparse, +4pp needle). Add warm-starting trees for batch candidate generation as future direction in §11.10.8. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Implement two new cache_hit_mode options for MCTS_TS: - "pessimistic": injects pessimistic pseudo-observations (mean - std) on cache hits for asymmetric downward pressure away from exhausted subtrees - "combined": applies variance inflation then pessimistic, preserving posterior width while adding directional pressure Benchmark results (30 trials × 6 problems): - combined: 97% on graduated (best of any config), robust across all problems - combined+apv: 37% on large_sparse (best TS result) - pessimistic: 93% graduated, 90% needle; best for small smooth landscapes Document findings in REPORT.md §11.10 (pessimistic) and §11.11 (combined). Update recommendations: TS + TS(g,a) + comb is now the default TS config. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Add three new future directions to REPORT.md §11.12: - §11.12.6: Normal-Inverse-Gamma posterior — proper conjugate for unknown mean and variance, samples from Student-t instead of Normal, fixes premature commitment at low observation counts - §11.12.7: Adaptive pseudo-count n₀ from branching factor — automatic confidence calibration, keeps posteriors wider at high-branching nodes - §11.12.8: Adaptive pessimistic strength from local exhaustion rate — scales pessimistic offset by (1 - n_obs/n_visits) per node Restructure prioritized improvements list into three tiers: implemented, highest-priority single-fidelity, and structural/production changes. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Replace Normal-Normal conjugate with Normal-Inverse-Gamma posterior, whose marginal for the mean is a Student-t distribution with heavier tails at low observation counts. This fixes premature commitment at n=1 where sample variance collapses to near zero. Best NIG config (NIG + TS(g,a) + vi + apv) matches or exceeds UCT on 5 of 6 problems: 80% on multigroup (vs UCT's 23%), 100% on needle and mixed (vs 100% and 77%), and 47% on large_sparse (vs 50%, within noise). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Scale pessimistic pseudo-observations by node exhaustion rate (1 - n_obs/n_visits) so fresh subtrees get mild pessimism and exhausted subtrees get full pessimism. Two new cache-hit modes: adaptive_pessimistic and adaptive_combined. Key finding: no-APV adaptive modes achieve 53% on large_sparse, first NIG configs to surpass UCT's 50%. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Implements n₀ = 1 + log(branching_factor) to make posteriors more conservative at high-branching nodes. Benchmarked across 5 configs on all 6 problems: adaptive n₀ hurts on hard problems (multigroup -23pp, large_sparse -6pp) because the NIG Student-t already handles low-n and higher n₀ over-corrects within the available budget. Co-Authored-By: Claude Opus 4.6 <[email protected]>
Integrate the Normal-Inverse-Gamma posterior from the prototype (mcts-report/optimize_mcts_nig.py) into the production MCTS at bofire/strategies/predictives/optimize_mcts.py. The NIG model treats both mean and variance as unknown, yielding a Student-t marginal with heavier tails at low observation counts that naturally prevents premature commitment. Key changes: - ActionStats: 2 fields -> 3 (n_obs, sum_rewards, sum_sq_rewards) - Node: replace w_total/mean_value with NIG sufficient statistics - MCTS.__init__: remove UCT params (c_uct, normalize_rewards, rollout_policy, rollout_epsilon, rollout_tau, rollout_novelty_weight); add NIG params (nig_alpha0, ts_prior_var, adaptive_prior_var, cache_hit_mode, variance_decay, rollout_mode, adaptive_n0) - Tree selection: UCT score replaced by _nig_sample() Student-t draws - Rollout: only uniform and ts_group_action modes (softmax removed) - Backpropagation: novel/cache-hit split with 6 cache-hit modes - Unified _nig_sample() shared between tree selection and rollout - Defaults: NIG + TS(g,a) + variance_inflation + adaptive_prior_var Verified exact match with prototype across all benchmark problems (multigroup_interaction, needle_in_haystack, mixed, simple_additive). Co-Authored-By: Claude Opus 4.6 <[email protected]>
In ts_group_action mode (the default), STOP is scored like any other action via NIG Thompson Sampling on per-(group, action) stats. The NIG posterior naturally learns when stopping is beneficial, making the separate adaptive p_stop machinery redundant. The only consumer was uniform rollout mode, which is only used in random.py with a dummy reward function where adaptive p_stop can't learn anything useful. Uniform mode now uses fixed p_stop_rollout. Removes: adaptive_p_stop, p_stop_warmup, p_stop_temperature params; _compute_adaptive_p_stop(), _update_cardinality_stats() methods; cardinality_stats, group_rollout_counts, reward_min/max state; per-iteration stats update in run(); TestAdaptivePStop (14 tests). Co-Authored-By: Claude Opus 4.6 <[email protected]>
Empirical investigation comparing polytope-sampled acqf values against optimize_acqf on Hartmann(6, k=4) and FormulationWrapper(Hartmann(), max_count=4). Key finding: Spearman rho > 0.97 at just 64 samples/subset (26x cheaper), validating that cheap samples can replace expensive optimization for MCTS tree burn-in. Documents results in REPORT.md §11.16. Co-Authored-By: Claude Opus 4.6 <[email protected]>
TobyBoyne
left a comment
There was a problem hiding this comment.
Nice work on all of this @jduerholt! Finally got a chance to read through this, and have written up some thoughts. My main area of concern is to focus on what we are actually maximizing downstream.
The MCTS is used to maximize the acqf value. At each node, the reward is the solution to the purely continuous acqf problem, conditioned on the categorical features being fixed for that node. That means:
- The reward at each node is stochastic, since we may not find the global optimum. The report often implies that the rewards are deterministic, which isn't the case.
- We can't evaluate
optimize_acqfparticularly cheaply. So the budgets of several hundred may be too high - we might be limited to a budget of only 100, or less. - I don't currently have much intuition of what the loss landscape looks like. It would be good for us to get some mixed benchmarks, sample
$N$ datapoints, train a GP, and get the actual acqf loss landscape.
I hope my comments are useful, happy to discuss it all further :)
Best,
Toby
|
|
||
| #### multigroup_interaction (search space ~4.25M, optimum = 150.0) | ||
|
|
||
| | Config | Mean Best | ±Std | Opt Rate | Unique Evals | |
There was a problem hiding this comment.
Nit: I can't actually tell what the bolded lines mean (eg., why are MCTS (no RAVE+adpt) and MCTS (+rpol ε=0.1) not highlighted?)
| | MCTS (+crave k=300) | 103.5 | 17.3 | 7% | 370 | | ||
| | MCTS (+crave k=500) | 100.2 | 19.5 | 7% | 310 | | ||
|
|
||
| #### needle_in_haystack (search space ~4,928, optimum = 100.0) |
There was a problem hiding this comment.
I think this benchmark is probably too easy, given most of the algorithms discussed later in the report reach close enough to 100%. Either that, or the iteration budget for this benchmark should be reduced.
Claude focused far too much on this benchmark, which I think ended up with somewhat meaningless distinctions being made (eg. "A is better than B because 97% success is better than 100% success).
| #### Recommended usage | ||
|
|
||
| Normalization should be enabled together with `c_uct=0.01` (or more generally, `c_uct ≈ 1/typical_reward_range`). This combination produces the best overall results and removes the need to tune `c_uct` per problem. |
There was a problem hiding this comment.
We should be standardizing rewards, not normalizing. Think logEI, which can get very negative values that would skew the normalization bounds.
| 4. Blend with uniform: `p[a] = (1 - ε) * p_policy[a] + ε / |legal_actions|` | ||
| 5. STOP is treated as a regular action with its own learned statistics — no special handling needed | ||
|
|
||
| #### `MCTS (no RAVE+adpt+norm)` vs `MCTS (+rpol)` — the key comparison |
There was a problem hiding this comment.
It's unclear to me if the second option includes RAVE, adapt, and norm.
|
|
||
| --- | ||
|
|
||
| ## 1. Experimental Setup |
There was a problem hiding this comment.
Each benchmark reports Opt Rate, which is the percentage of runs for which abs(mcts.best_value - problem.optimal_value) < 1e-6. To me, this threshold is too low - an optimizer that gets within 5% of the optimal value is good enough for acqf optimization in my opinion, since your model is going to be wrong anyway. We should be selecting for a consistently good algorithm, not a sometimes perfect algorithm
|
|
||
| #### 11.10.1 Motivation | ||
|
|
||
| Variance inflation (§11.5.1) widens posteriors symmetrically on cache hits — the posterior could sample higher *or* lower — so ~50% of samples from an inflated exhausted node still select it. Pessimistic pseudo-observations provide *asymmetric* downward pressure: on each cache hit, we inject a pseudo-observation at `global_mean - global_std` into every node along the backpropagation path. This shifts the posterior mean downward, actively pushing the algorithm away from exhausted subtrees, analogous to UCT's virtual loss. |
There was a problem hiding this comment.
Why global_mean - global_std? To me, it seems that is too punishing, and would push the algorithm away from promising subtrees? Shouldn't it be leaf_mean - leaf_std?
|
|
||
| #### 11.11.5 Analysis | ||
|
|
||
| **Combined mode achieves the highest exploration efficiency of any config.** On every problem, `comb` evaluates more unique selections than either `var_infl` or `pess` alone. On large_sparse, `comb` reaches 671 unique evaluations vs 575 for `var_infl` and 639 for `pess`. The dual mechanism — decay followed by pessimistic injection — creates stronger pressure to leave exhausted subtrees than either mechanism alone. |
There was a problem hiding this comment.
Again, I'm a bit confused about this idea of "exploration efficiency". This report is making it out to be a key important metric, but I think you can make deterministic changes to the algorithm to enforce it to 100% evaluate a novel leaf. Why don't we?
|
|
||
| #### 11.12.3 Two-Phase Burn-in with Cheap Evaluations | ||
|
|
||
| **Problem**: TS's exploration efficiency gap (575 vs 750 unique evals on large_sparse) exists because each evaluation is expensive (`optimize_acqf` with multi-start L-BFGS), so wasted iterations on cache hits are costly. The cache itself exists because evaluations are expensive and deterministic. |
There was a problem hiding this comment.
TS's exploration efficiency gap (575 vs 750 unique evals on large_sparse) exists because each evaluation is expensive
I don't understand this statement.
|
|
||
| The benchmark data suggests the burn-in length should scale with search space size: ~50 iterations for small spaces (graduated_landscape, 375 combinations), ~200 for large (large_sparse, 960M combinations). | ||
|
|
||
| #### 11.12.4 Depth-Dependent Cache-Hit Handling |
There was a problem hiding this comment.
This section introduces a new heuristic to address another heuristic. I'm a bit sceptical.
| **Original problem statement**: The current TS implementation uses a Normal-Normal conjugate update that treats the reward variance σ² as a known plug-in estimate (`s² = max(sum_sq/n - x̄², 1e-8)`). This is reasonable when n is moderate, but it breaks down at low observation counts — exactly the regime that matters most for exploration: | ||
|
|
||
| - With n=1: `s² = max(x²/1 - x², 1e-8) = 1e-8` — variance collapses to the floor. The posterior becomes absurdly tight around a single observation. | ||
| - With n=2: sample variance is based on just 2 points — unreliable. | ||
| - With n=0: we fall back to the prior, which is a Normal distribution. |
There was a problem hiding this comment.
I might be missing something here, but I don't think the posterior variance is correct here (or in the code). If we are being Bayesian, then when we have one datapoint shouldn't the posterior variance be some function of the prior variance (posterior = prior * likelihood)? I will think about this a bit deeper.
There was a problem hiding this comment.
Let
Let the reward
If we have a single observation, ie. our dataset
So I think the variance should be
|
Some other thoughts I had: In The RAVE/CRAVE didn't seem particularly effective. However, I do think there is a setting where some context on categorical values is useful: |
Adds uniform_subset rollout that completes NChooseK groups with uniform random subsets (bypassing step-by-step action selection), and two-phase Sobol screening infrastructure for cheaper MCTS evaluation. Co-Authored-By: Claude Opus 4.6 <[email protected]>
|
@TobyBoyne : I added some stuff, especially have a look at the two notebooks |
Motivation
This PR implements an MCTS algorithm for optimitzing over conditional search spaces comprised of NChooseK constraints and/or conditional condinuous input features (
allow_zero).It uses the following features:
This is still WIP and most of the code is (as a test) generated by a wild combo of LLMs.
@TobyBoyne
Have you read the Contributing Guidelines on pull requests?
Yes.
Test Plan
Unit tests, not yet done.