Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions nanoproof/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp

# Create attention masks if padding is needed
decode_mask = None
decode_mask_len = max_prompt_len # tracks current valid length in pre-allocated decode_mask
prefill_attn_mask = None
if any(length != max_prompt_len for length in prompt_lengths):
# prompt_mask[b, t] = True if position t is a real token (not padding) for prompt b
Expand All @@ -165,8 +166,11 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp
# prefill_attn_mask combines prompt_mask and causal_mask: attend only to non-padding keys before the query position
# shape: (num_prompts, 1, max_prompt_len, max_prompt_len) - the 1 broadcasts across heads
prefill_attn_mask = (causal_mask.unsqueeze(0) & prompt_mask.unsqueeze(1)).unsqueeze(1)
# decode_mask tracks which positions are valid for each row during generation (will be updated after each step)
decode_mask = prompt_mask.repeat_interleave(num_samples, dim=0)
# decode_mask tracks which positions are valid for each row during generation
# Pre-allocate to max size to avoid repeated concatenation in the loop
decode_max_len = max_prompt_len + (max_tokens if max_tokens is not None else 1024)
decode_mask = torch.zeros((total_rows, decode_max_len), dtype=torch.bool, device=device)
decode_mask[:, :max_prompt_len] = prompt_mask.repeat_interleave(num_samples, dim=0)

# 2) Run batched prefill
m = self.model.config
Expand All @@ -189,12 +193,11 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp
**kv_model_kwargs,
)
# Initialize the decode cache from prefill cache, replicating for each sample
# Use repeat_interleave for efficient GPU-side replication (single kernel vs nested loop)
dtype, dev = kv_cache_prefill.kv_cache.dtype, kv_cache_prefill.kv_cache.device
kv_cache_decode.kv_cache = torch.empty(kv_cache_decode.kv_shape, dtype=dtype, device=dev)
for i in range(num_prompts):
src = kv_cache_prefill.kv_cache[:, :, i:i + 1, :, :max_prompt_len, :]
for j in range(num_samples):
kv_cache_decode.kv_cache[:, :, i * num_samples + j:i * num_samples + j + 1, :, :max_prompt_len, :] = src
kv_cache_decode.kv_cache[:, :, :, :, :max_prompt_len, :] = \
kv_cache_prefill.kv_cache[:, :, :, :, :max_prompt_len, :].repeat_interleave(num_samples, dim=2)
kv_cache_decode.pos = max_prompt_len
del kv_cache_prefill # no need to keep this memory around

Expand Down Expand Up @@ -245,16 +248,14 @@ def generate(self, tokens, num_samples=1, max_tokens=None, min_tokens=None, temp
# Prepare logits for next iteration
ids = torch.tensor(token_column, dtype=torch.long, device=device).unsqueeze(1)


if decode_mask is not None:
# Extend decode_mask with True for the new tokens
decode_mask = torch.cat(
[decode_mask, torch.ones((total_rows, 1), dtype=torch.bool, device=device)], dim=1
)
# Mark the new token position as valid (pre-allocated, no concatenation needed)
decode_mask[:, decode_mask_len] = True
decode_mask_len += 1
logits = self.model.forward(
ids,
kv_cache=kv_cache_decode,
attention_mask=decode_mask.unsqueeze(1).unsqueeze(1), # (B, 1, 1, T)
attention_mask=decode_mask[:, :decode_mask_len].unsqueeze(1).unsqueeze(1), # (B, 1, 1, T)
)
else:
logits = self.model.forward(ids, kv_cache=kv_cache_decode)
Expand Down
10 changes: 10 additions & 0 deletions nanoproof/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@

# Create the tactic model with batching support for parallel actors
inner_tactic_model = TacticModel.create(num_samples=num_sampled_tactics)

# Compile the model for faster forward passes (10-30% speedup on modern GPUs)
# Using "reduce-overhead" mode which is optimized for variable batch sizes
if device_type == "cuda":
compiled_network = torch.compile(inner_tactic_model.network, mode="reduce-overhead")
# Update all references to use the compiled model
inner_tactic_model.network = compiled_network
inner_tactic_model.engine.model = compiled_network
log("Model compiled with torch.compile (reduce-overhead mode)", component="Config")

tactic_model = BatchedTacticModel(
inner_model=inner_tactic_model,
batch_size=num_actors,
Expand Down
8 changes: 5 additions & 3 deletions nanoproof/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,24 +473,26 @@ def progressive_sample(node: Node, config: Config) -> bool:

def select_child(config: Config, node: Node) -> tuple[Action, Node]:
"""Selects the child with the highest UCB score."""
# Cache prior_sum once for all children (avoids O(children^2) recomputation)
prior_sum = node.prior_sum()
_, action, child = max(
(ucb_score(config, node, child), action, child)
(ucb_score(config, node, child, prior_sum), action, child)
for action, child in node.children.items()
)
return action, child


# The score for a node is based on its value, plus an exploration bonus based on
# the prior.
def ucb_score(config: Config, parent: Node, child: Node) -> float:
def ucb_score(config: Config, parent: Node, child: Node, prior_sum: float) -> float:
pb_c = (
math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base)
+ config.pb_c_init
)
pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)

# Due to progressive sampling, we normalise priors here.
prior_score = pb_c * child.prior / parent.prior_sum()
prior_score = pb_c * child.prior / prior_sum
if child.visit_count > 0:
value = child.reward + child.value()
value_score = config.value_discount ** (- 1 - value)
Expand Down