@@ -553,6 +553,14 @@ def free_lora_adapter():
553553
554554 self ._sampler = None
555555
556+ # Cache model architecture flags to avoid repeated FFI calls
557+ self ._is_recurrent_model = llama_cpp .llama_model_is_recurrent (
558+ self ._model .model
559+ ) or llama_cpp .llama_model_is_hybrid (self ._model .model )
560+ self ._has_swa_model = llama_cpp .llama_model_n_swa (
561+ self ._model .model
562+ ) > 0
563+
556564 @property
557565 def ctx (self ) -> llama_cpp .llama_context_p :
558566 return self ._ctx .ctx
@@ -638,13 +646,12 @@ def reset(self):
638646 """Reset the model state."""
639647 self .n_tokens = 0
640648
641- def eval (self , tokens : Sequence [int ]):
642- """Evaluate a list of tokens.
649+ mem = llama_cpp .llama_get_memory (self ._ctx .ctx )
650+ if mem is not None :
651+ llama_cpp .llama_memory_clear (mem , True )
643652
644- Args:
645- tokens: The list of tokens to evaluate.
646- """
647- self ._ctx .kv_cache_seq_rm (- 1 , self .n_tokens , - 1 )
653+ def eval (self , tokens : Sequence [int ]):
654+ """Evaluate a list of tokens."""
648655 for i in range (0 , len (tokens ), self .n_batch ):
649656 batch = tokens [i : min (len (tokens ), i + self .n_batch )]
650657 n_past = self .n_tokens
@@ -653,26 +660,12 @@ def eval(self, tokens: Sequence[int]):
653660 batch = batch , n_past = n_past , logits_all = self ._logits_all
654661 )
655662 self ._ctx .decode (self ._batch )
656- # Save tokens
657663 self .input_ids [n_past : n_past + n_tokens ] = batch
658- # Save logits
659664 if self ._logits_all :
660- rows = n_tokens
661- cols = self ._n_vocab
662665 logits = np .ctypeslib .as_array (
663- self ._ctx .get_logits (), shape = (rows * cols , )
666+ self ._ctx .get_logits (), shape = (n_tokens , self . _n_vocab )
664667 )
665- self .scores [n_past : n_past + n_tokens , :].reshape (- 1 )[::] = logits
666- else :
667- # rows = 1
668- # cols = self._n_vocab
669- # logits = np.ctypeslib.as_array(
670- # self._ctx.get_logits(), shape=(rows * cols,)
671- # )
672- # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
673- # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
674- pass
675- # Update n_tokens
668+ self .scores [n_past : n_past + n_tokens , :] = logits
676669 self .n_tokens += n_tokens
677670
678671 def _init_sampler (
@@ -888,22 +881,20 @@ def generate(
888881 # Check for kv cache prefix match
889882 if reset and self .n_tokens > 0 :
890883 longest_prefix = 0
891- for a , b in zip (self ._input_ids , tokens [: - 1 ] ):
884+ for a , b in zip (self ._input_ids , tokens ):
892885 if a == b :
893886 longest_prefix += 1
894887 else :
895888 break
889+
890+ if (self ._is_recurrent_model or self ._has_swa_model ) and longest_prefix < self .n_tokens :
891+ longest_prefix = 0
892+
896893 if longest_prefix > 0 :
897894 if self ._ctx .kv_cache_seq_rm (- 1 , longest_prefix , - 1 ):
898895 reset = False
899896 tokens = tokens [longest_prefix :]
900897 self .n_tokens = longest_prefix
901- if self .verbose :
902- print (
903- f"Llama.generate: { longest_prefix } prefix-match hit, "
904- f"remaining { len (tokens )} prompt tokens to eval" ,
905- file = sys .stderr ,
906- )
907898 elif self .verbose :
908899 print (
909900 f"Llama.generate: { longest_prefix } prefix-match found "
@@ -1267,12 +1258,9 @@ def logit_bias_processor(
12671258 input_ids : npt .NDArray [np .intc ],
12681259 scores : npt .NDArray [np .single ],
12691260 ) -> npt .NDArray [np .single ]:
1270- new_scores = np .copy (
1271- scores
1272- ) # Does it make sense to copy the whole array or can we just overwrite the original one?
12731261 for input_id , score in logit_bias_map .items ():
1274- new_scores [input_id ] = score + scores [ input_id ]
1275- return new_scores
1262+ scores [input_id ] + = score
1263+ return scores
12761264
12771265 _logit_bias_processor = LogitsProcessorList ([logit_bias_processor ])
12781266 if logits_processor is None :
@@ -1333,6 +1321,7 @@ def logit_bias_processor(
13331321
13341322 finish_reason = "length"
13351323 multibyte_fix = 0
1324+ accumulated_text = b""
13361325 for token in self .generate (
13371326 prompt_tokens ,
13381327 top_k = top_k ,
@@ -1352,16 +1341,17 @@ def logit_bias_processor(
13521341 grammar = grammar ,
13531342 ):
13541343 if llama_cpp .llama_vocab_is_eog (self ._model .vocab , token ):
1355- text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
1344+ text = accumulated_text
13561345 finish_reason = "stop"
13571346 break
13581347
13591348 completion_tokens .append (token )
13601349
1361- all_text = self .detokenize (completion_tokens , prev_tokens = prompt_tokens )
1350+ new_text = self ._model .token_to_piece (token )
1351+ accumulated_text += new_text
13621352
13631353 # Contains multi-byte UTF8
1364- for k , char in enumerate (all_text [- 3 :]):
1354+ for k , char in enumerate (accumulated_text [- 3 :]):
13651355 k = 3 - k
13661356 for num , pattern in [(2 , 192 ), (3 , 224 ), (4 , 240 )]:
13671357 # Bitwise AND check
@@ -1373,19 +1363,16 @@ def logit_bias_processor(
13731363 multibyte_fix -= 1
13741364 continue
13751365
1376- any_stop = [s for s in stop_sequences if s in all_text ]
1366+ any_stop = [s for s in stop_sequences if s in accumulated_text ]
13771367 if len (any_stop ) > 0 :
13781368 first_stop = any_stop [0 ]
1379- text = all_text [: all_text .index (first_stop )]
1369+ text = accumulated_text [: accumulated_text .index (first_stop )]
13801370 finish_reason = "stop"
13811371 break
13821372
13831373 if stream :
13841374 remaining_tokens = completion_tokens [returned_tokens :]
1385- remaining_text = self .detokenize (
1386- remaining_tokens ,
1387- prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1388- )
1375+ remaining_text = self ._model .token_to_piece (token )
13891376 remaining_length = len (remaining_text )
13901377
13911378 # We want to avoid yielding any characters from
@@ -1522,24 +1509,23 @@ def logit_bias_processor(
15221509 }
15231510
15241511 if len (completion_tokens ) >= max_tokens :
1525- text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
1512+ text = accumulated_text
15261513 finish_reason = "length"
15271514 break
15281515
15291516 if stopping_criteria is not None and stopping_criteria (
15301517 self ._input_ids , self ._scores [- 1 , :]
15311518 ):
1532- text = self . detokenize ( completion_tokens , prev_tokens = prompt_tokens )
1519+ text = accumulated_text
15331520 finish_reason = "stop"
15341521
15351522 if self .verbose :
15361523 self ._ctx .print_timings ()
15371524
15381525 if stream :
15391526 remaining_tokens = completion_tokens [returned_tokens :]
1540- remaining_text = self .detokenize (
1541- remaining_tokens ,
1542- prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1527+ remaining_text = b"" .join (
1528+ self ._model .token_to_piece (t ) for t in remaining_tokens
15431529 )
15441530 any_stop = [s for s in stop_sequences if s in remaining_text ]
15451531 if len (any_stop ) > 0 :
@@ -1549,12 +1535,8 @@ def logit_bias_processor(
15491535
15501536 token_end_position = 0
15511537 for token in remaining_tokens :
1552- token_end_position += len (
1553- self .detokenize (
1554- [token ],
1555- prev_tokens = prompt_tokens + completion_tokens [:returned_tokens ],
1556- )
1557- )
1538+ token_piece = self ._model .token_to_piece (token )
1539+ token_end_position += len (token_piece )
15581540
15591541 logprobs_or_none : Optional [CompletionLogprobs ] = None
15601542 if logprobs is not None :
@@ -1594,7 +1576,7 @@ def logit_bias_processor(
15941576 }
15951577
15961578 if token_end_position >= end :
1597- last_text = self . detokenize ([ token ])
1579+ last_text = token_piece
15981580 if token_end_position == end - 1 :
15991581 break
16001582 returned_tokens += 1
@@ -1707,17 +1689,16 @@ def logit_bias_processor(
17071689 )
17081690 )
17091691 tokens .append (token_str )
1710- sorted_logprobs = list (
1711- sorted (
1712- zip (logprobs_token , range (len (logprobs_token ))), reverse = True
1713- )
1714- )
1692+ top_k_indices = np .argpartition (logprobs_token , - logprobs )[- logprobs :]
1693+ top_k_indices = top_k_indices [
1694+ np .argsort (logprobs_token [top_k_indices ])
1695+ ][::- 1 ]
17151696 token_logprobs .append (logprobs_token [int (token )])
17161697 top_logprob : Optional [Dict [str , float ]] = {
1717- self .detokenize ([i ], prev_tokens = all_tokens [:idx ]).decode (
1698+ self .detokenize ([int ( i ) ], prev_tokens = all_tokens [:idx ]).decode (
17181699 "utf-8" , errors = "ignore"
1719- ): logprob
1720- for logprob , i in sorted_logprobs [: logprobs ]
1700+ ): logprobs_token [ int ( i )]
1701+ for i in top_k_indices
17211702 }
17221703 top_logprob .update ({token_str : logprobs_token [int (token )]})
17231704 top_logprobs .append (top_logprob )
0 commit comments