Skip to content

Commit c9bbd6d

Browse files
Authored and committed by Ralf Waldukat
fix: remove unconditional kv_cache_seq_rm from eval() and consolidate recurrent/SWA guards
1 parent 69f4e42 commit c9bbd6d

File tree

1 file changed

+4
-44
lines changed

1 file changed

+4
-44
lines changed

llama_cpp/llama.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -651,13 +651,7 @@ def reset(self):
651651
llama_cpp.llama_memory_clear(mem, True)
652652

653653
def eval(self, tokens: Sequence[int]):
654-
"""Evaluate a list of tokens.
655-
656-
Args:
657-
tokens: The list of tokens to evaluate.
658-
"""
659-
if len(tokens) < self.n_tokens:
660-
self._ctx.kv_cache_seq_rm(-1, len(tokens), -1)
654+
"""Evaluate a list of tokens."""
661655
for i in range(0, len(tokens), self.n_batch):
662656
batch = tokens[i : min(len(tokens), i + self.n_batch)]
663657
n_past = self.n_tokens
@@ -666,26 +660,12 @@ def eval(self, tokens: Sequence[int]):
666660
batch=batch, n_past=n_past, logits_all=self._logits_all
667661
)
668662
self._ctx.decode(self._batch)
669-
# Save tokens
670663
self.input_ids[n_past : n_past + n_tokens] = batch
671-
# Save logits
672664
if self._logits_all:
673-
rows = n_tokens
674-
cols = self._n_vocab
675665
logits = np.ctypeslib.as_array(
676-
self._ctx.get_logits(), shape=(rows * cols,)
666+
self._ctx.get_logits(), shape=(n_tokens, self._n_vocab)
677667
)
678-
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
679-
else:
680-
# rows = 1
681-
# cols = self._n_vocab
682-
# logits = np.ctypeslib.as_array(
683-
# self._ctx.get_logits(), shape=(rows * cols,)
684-
# )
685-
# self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
686-
# NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
687-
pass
688-
# Update n_tokens
668+
self.scores[n_past : n_past + n_tokens, :] = logits
689669
self.n_tokens += n_tokens
690670

691671
def _init_sampler(
@@ -907,34 +887,14 @@ def generate(
907887
else:
908888
break
909889

910-
# Recurrent models cannot rewind state; reset if needed
911-
if self._is_recurrent_model and longest_prefix < self.n_tokens:
912-
longest_prefix = 0
913-
reset = True
914-
if self.verbose:
915-
print(
916-
"Llama.generate: recurrent model requires full state reset",
917-
file=sys.stderr,
918-
)
919-
920-
# SWA/ISWA models (e.g. Gemma-4) have split KV caches whose
921-
# position-tracking maps are only cleared by a full reset.
922-
# Partial seq_rm leaves stale positions and causes decode failure.
923-
if self._has_swa_model and longest_prefix < self.n_tokens:
890+
if (self._is_recurrent_model or self._has_swa_model) and longest_prefix < self.n_tokens:
924891
longest_prefix = 0
925-
reset = True
926892

927893
if longest_prefix > 0:
928894
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
929895
reset = False
930896
tokens = tokens[longest_prefix:]
931897
self.n_tokens = longest_prefix
932-
if self.verbose:
933-
print(
934-
f"Llama.generate: {longest_prefix} prefix-match hit, "
935-
f"remaining {len(tokens)} prompt tokens to eval",
936-
file=sys.stderr,
937-
)
938898
elif self.verbose:
939899
print(
940900
f"Llama.generate: {longest_prefix} prefix-match found "

Comments (0) — this commit has no comments.