From 66e931d852616b7d7f5f2c6995eeb235f139f742 Mon Sep 17 00:00:00 2001 From: Egor Konovalov Date: Mon, 5 Jan 2026 14:45:50 +0300 Subject: [PATCH] some perf optimizations --- nanoproof/engine.py | 25 +++++++++++++------------ nanoproof/rl.py | 10 ++++++++++ nanoproof/search.py | 8 +++++--- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/nanoproof/engine.py b/nanoproof/engine.py index fdfa947..e641ad8 100644 --- a/nanoproof/engine.py +++ b/nanoproof/engine.py @@ -154,6 +154,7 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp # Create attention masks if padding is needed decode_mask = None + decode_mask_len = max_prompt_len # tracks current valid length in pre-allocated decode_mask prefill_attn_mask = None if any(length != max_prompt_len for length in prompt_lengths): # prompt_mask[b, t] = True if position t is a real token (not padding) for prompt b @@ -165,8 +166,11 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp # prefill_attn_mask combines prompt_mask and causal_mask: attend only to non-padding keys before the query position # shape: (num_prompts, 1, max_prompt_len, max_prompt_len) - the 1 broadcasts across heads prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1) - # decode_mask tracks which positions are valid for each row during generation (will be updated after each step) - decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0) + # decode_mask tracks which positions are valid for each row during generation + # Pre-allocate to max size to avoid repeated concatenation in the loop + decode_max_len = max_prompt_len + (max_tokens if max_tokens is not None else 1024) + decode_mask = torch.zeros((total_rows, decode_max_len), dtype=torch.bool, device=device) + decode_mask[:, :max_prompt_len] = prompt_mask.repeat_interleave(num_samples, dim=0) # 2) Run batched prefill m = self.model.config @@ -189,12 +193,11 @@ def generate(self, tokens, 
num_samples=1, max_tokens=None, min_tokens=None, temp **kv_model_kwargs, ) # Initialize the decode cache from prefill cache, replicating for each sample + # Use repeat_interleave for efficient GPU-side replication (single kernel vs nested loop) dtype, dev = kv_cache_prefill.kv_cache.dtype, kv_cache_prefill.kv_cache.device kv_cache_decode.kv_cache = torch.empty(kv_cache_decode.kv_shape, dtype=dtype, device=dev) - for i in range(num_prompts): - src = kv_cache_prefill.kv_cache[:, :, i:i + 1, :, :max_prompt_len, :] - for j in range(num_samples): - kv_cache_decode.kv_cache[:, :, i * num_samples + j:i * num_samples + j + 1, :, :max_prompt_len, :] = src + kv_cache_decode.kv_cache[:, :, :, :, :max_prompt_len, :] = \ + kv_cache_prefill.kv_cache[:, :, :, :, :max_prompt_len, :].repeat_interleave(num_samples, dim=2) kv_cache_decode.pos = max_prompt_len del kv_cache_prefill # no need to keep this memory around @@ -245,16 +248,14 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp # Prepare logits for next iteration ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1) - if decode_mask is not None: - # Extend decode_mask with True for the new tokens - decode_mask = torch.cat( - [decode_mask, torch.ones((total_rows, 1), dtype=torch.bool, device=device)], dim=1 - ) + # Mark the new token position as valid (pre-allocated, no concatenation needed) + decode_mask[:, decode_mask_len] = True + decode_mask_len += 1 logits = self.model.forward( ids, kv_cache=kv_cache_decode, - attention_mask=decode_mask.unsqueeze(1).unsqueeze(1), # (B, 1, 1, T) + attention_mask=decode_mask[:, :decode_mask_len].unsqueeze(1).unsqueeze(1), # (B, 1, 1, T) ) else: logits = self.model.forward(ids, kv_cache=kv_cache_decode) diff --git a/nanoproof/rl.py b/nanoproof/rl.py index 63ceda0..7ac7c8a 100644 --- a/nanoproof/rl.py +++ b/nanoproof/rl.py @@ -113,6 +113,16 @@ # Create the tactic model with batching support for parallel actors inner_tactic_model = 
TacticModel.create(num_samples=num_sampled_tactics) + +# Compile the model for faster forward passes (10-30% speedup on modern GPUs) +# Using "reduce-overhead" mode, which uses CUDA graphs to cut per-call launch overhead (most useful for small batches; each new input shape triggers a re-capture) +if device_type == "cuda": + compiled_network = torch.compile(inner_tactic_model.network, mode="reduce-overhead") + # Update all references to use the compiled model + inner_tactic_model.network = compiled_network + inner_tactic_model.engine.model = compiled_network + log("Model compiled with torch.compile (reduce-overhead mode)", component="Config") + tactic_model = BatchedTacticModel( inner_model=inner_tactic_model, batch_size=num_actors, diff --git a/nanoproof/search.py b/nanoproof/search.py index f6a970e..91808c4 100644 --- a/nanoproof/search.py +++ b/nanoproof/search.py @@ -473,8 +473,10 @@ def progressive_sample(node: Node, config: Config) -> bool: def select_child(config: Config, node: Node) -> tuple[Action, Node]: """Selects the child with the highest UCB score.""" + # Cache prior_sum once for all children (avoids O(children^2) recomputation) + prior_sum = node.prior_sum() _, action, child = max( - (ucb_score(config, node, child), action, child) + (ucb_score(config, node, child, prior_sum), action, child) for action, child in node.children.items() ) return action, child @@ -482,7 +484,7 @@ def select_child(config: Config, node: Node) -> tuple[Action, Node]: # The score for a node is based on its value, plus an exploration bonus based on # the prior. -def ucb_score(config: Config, parent: Node, child: Node) -> float: +def ucb_score(config: Config, parent: Node, child: Node, prior_sum: float) -> float: pb_c = ( math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base) + config.pb_c_init @@ -490,7 +492,7 @@ def ucb_score(config: Config, parent: Node, child: Node) -> float: pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1) # Due to progressive sampling, we normalise priors here. 
- prior_score = pb_c * child.prior / parent.prior_sum() + prior_score = pb_c * child.prior / prior_sum if child.visit_count > 0: value = child.reward + child.value() value_score = config.value_discount ** (- 1 - value)