@@ -651,13 +651,7 @@ def reset(self):
651651 llama_cpp .llama_memory_clear (mem , True )
652652
653653 def eval (self , tokens : Sequence [int ]):
654- """Evaluate a list of tokens.
655-
656- Args:
657- tokens: The list of tokens to evaluate.
658- """
659- if len (tokens ) < self .n_tokens :
660- self ._ctx .kv_cache_seq_rm (- 1 , len (tokens ), - 1 )
654+ """Evaluate a list of tokens."""
661655 for i in range (0 , len (tokens ), self .n_batch ):
662656 batch = tokens [i : min (len (tokens ), i + self .n_batch )]
663657 n_past = self .n_tokens
@@ -666,26 +660,12 @@ def eval(self, tokens: Sequence[int]):
666660 batch = batch , n_past = n_past , logits_all = self ._logits_all
667661 )
668662 self ._ctx .decode (self ._batch )
669- # Save tokens
670663 self .input_ids [n_past : n_past + n_tokens ] = batch
671- # Save logits
672664 if self ._logits_all :
673- rows = n_tokens
674- cols = self ._n_vocab
675665 logits = np .ctypeslib .as_array (
676- self ._ctx .get_logits (), shape = (rows * cols , )
666+ self ._ctx .get_logits (), shape = (n_tokens , self . _n_vocab )
677667 )
678- self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[::] = logits
679- else :
680- # rows = 1
681- # cols = self._n_vocab
682- # logits = np.ctypeslib.as_array(
683- # self._ctx.get_logits(), shape=(rows * cols,)
684- # )
685- # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
686- # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
687- pass
688- # Update n_tokens
668+ self .scores [n_past : n_past + n_tokens , :] = logits
689669 self .n_tokens += n_tokens
690670
691671 def _init_sampler (
@@ -907,34 +887,14 @@ def generate(
907887 else :
908888 break
909889
910- # Recurrent models cannot rewind state; reset if needed
911- if self ._is_recurrent_model and longest_prefix < self .n_tokens :
912- longest_prefix = 0
913- reset = True
914- if self .verbose :
915- print (
916- "Llama.generate: recurrent model requires full state reset" ,
917- file = sys .stderr ,
918- )
919-
920- # SWA/ISWA models (e.g. Gemma-4) have split KV caches whose
921- # position-tracking maps are only cleared by a full reset.
922- # Partial seq_rm leaves stale positions and causes decode failure.
923- if self ._has_swa_model and longest_prefix < self .n_tokens :
890+ if (self ._is_recurrent_model or self ._has_swa_model ) and longest_prefix < self .n_tokens :
924891 longest_prefix = 0
925- reset = True
926892
927893 if longest_prefix > 0 :
928894 if self ._ctx .kv_cache_seq_rm (- 1 , longest_prefix , - 1 ):
929895 reset = False
930896 tokens = tokens [longest_prefix :]
931897 self .n_tokens = longest_prefix
932- if self .verbose :
933- print (
934- f"Llama.generate: { longest_prefix } prefix-match hit, "
935- f"remaining { len (tokens )} prompt tokens to eval" ,
936- file = sys .stderr ,
937- )
938898 elif self .verbose :
939899 print (
940900 f"Llama.generate: { longest_prefix } prefix-match found "
(end of diff excerpt — page footer "0 commit comments" removed as a scraping artifact)