
Commit e54fe59
Ralf Waldukat committed
perf: vectorize hot-path ops, reduce Python overhead, fix SWA/ISWA KV cache corruption
- set_batch(): numpy bulk writes replace per-token Python loop
- _create_completion: incremental token_to_piece() accumulation replaces O(n²) re-detokenization per generated token
- _create_completion: in-place logit_bias instead of full vocab copy
- _create_completion: np.argpartition for top-k logprobs (O(V) vs O(V log V))
- reset(): call llama_memory_clear() for proper KV cache state reset
- generate(): bypass prefix-match for recurrent/SWA models
- generate(): fix tokens[:-1] off-by-one in prefix matching
- eval(): remove unconditional kv_cache_seq_rm, simplify logits assignment
- token_to_piece(): return correct byte length via actual write count
1 parent: 1b1a320

2 files changed: 59 additions & 71 deletions

llama_cpp/_internals.py (15 additions, 8 deletions)
@@ -182,9 +182,12 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool):
         return list(tokens[:n_tokens])
 
     def token_to_piece(self, token: int, special: bool = False) -> bytes:
-        buf = ctypes.create_string_buffer(32)
-        llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
-        return bytes(buf)
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece(
+            self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
+        )
+        return bytes(buffer[:n])
 
     def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
         output = b""
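
Why the new token_to_piece matters: bytes() on a fixed 32-byte ctypes buffer returns all 32 bytes, so every piece came back padded with trailing NULs; slicing by the count the C call returns fixes that. A minimal sketch with a toy stand-in for the C function (fake_token_to_piece is hypothetical, not part of llama.cpp):

import ctypes

def fake_token_to_piece(buf, size):
    # Pretend the token decodes to two bytes and return the write count,
    # like the real llama_token_to_piece does.
    piece = b"Hi"
    buf[: len(piece)] = piece
    return len(piece)

size = 32
buf = (ctypes.c_char * size)()
n = fake_token_to_piece(buf, size)

assert bytes(buf) == b"Hi" + b"\x00" * 30   # old: NUL padding leaks into the piece
assert bytes(buf[:n]) == b"Hi"              # new: slice by the actual write count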
@@ -503,13 +506,17 @@ def reset(self):
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         n_tokens = len(batch)
         self.batch.n_tokens = n_tokens
+        token_arr = np.ctypeslib.as_array(self.batch.token, shape=(n_tokens,))
+        token_arr[:] = batch
+        pos_arr = np.ctypeslib.as_array(self.batch.pos, shape=(n_tokens,))
+        pos_arr[:] = np.arange(n_past, n_past + n_tokens, dtype=pos_arr.dtype)
+        n_seq_id_arr = np.ctypeslib.as_array(self.batch.n_seq_id, shape=(n_tokens,))
+        n_seq_id_arr[:] = 1
+        logits_arr = np.ctypeslib.as_array(self.batch.logits, shape=(n_tokens,))
+        logits_arr[:] = logits_all
+        logits_arr[n_tokens - 1] = True
         for i in range(n_tokens):
-            self.batch.token[i] = batch[i]
-            self.batch.pos[i] = n_past + i
             self.batch.seq_id[i][0] = 0
-            self.batch.n_seq_id[i] = 1
-            self.batch.logits[i] = logits_all
-        self.batch.logits[n_tokens - 1] = True
 
     def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
         n_tokens = len(batch)
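
The set_batch rewrite relies on np.ctypeslib.as_array, which wraps the batch's C buffers as numpy views without copying, turning n_tokens Python-level stores into a handful of vectorized writes. A self-contained sketch of the pattern (the raw buffer stands in for llama_batch.token):

import ctypes
import numpy as np

n_tokens = 4
raw = (ctypes.c_int32 * n_tokens)()                     # toy stand-in for llama_batch.token
ptr = ctypes.cast(raw, ctypes.POINTER(ctypes.c_int32))  # what the struct actually exposes

# as_array wraps the C memory zero-copy, so one vectorized assignment
# replaces a per-element Python loop.
token_arr = np.ctypeslib.as_array(ptr, shape=(n_tokens,))
token_arr[:] = [101, 17, 9, 42]

assert list(raw) == [101, 17, 9, 42]  # the writes landed in the C buffer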

llama_cpp/llama.py (44 additions, 63 deletions)
@@ -553,6 +553,14 @@ def free_lora_adapter():
 
         self._sampler = None
 
+        # Cache model architecture flags to avoid repeated FFI calls
+        self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
+            self._model.model
+        ) or llama_cpp.llama_model_is_hybrid(self._model.model)
+        self._has_swa_model = llama_cpp.llama_model_n_swa(
+            self._model.model
+        ) > 0
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
@@ -638,13 +646,12 @@ def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
 
-    def eval(self, tokens: Sequence[int]):
-        """Evaluate a list of tokens.
+        mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+        if mem is not None:
+            llama_cpp.llama_memory_clear(mem, True)
 
-        Args:
-            tokens: The list of tokens to evaluate.
-        """
-        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+    def eval(self, tokens: Sequence[int]):
+        """Evaluate a list of tokens."""
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = self.n_tokens
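
Context for the reset change: previously only the Python-side counter was zeroed, and the backend cache was truncated later by eval()'s unconditional kv_cache_seq_rm (removed above); clearing the memory object here makes the reset explicit and complete. A minimal helper restating the new path under the same llama.cpp bindings the diff calls (clear_all_memory is a hypothetical name):

import llama_cpp  # the bindings this file already imports

def clear_all_memory(ctx) -> None:
    # `ctx` is a raw llama_context pointer (self._ctx.ctx in the diff).
    mem = llama_cpp.llama_get_memory(ctx)
    if mem is not None:  # defensive: a context may expose no memory module
        # data=True clears the stored KV/recurrent buffers as well as the
        # cell metadata, so SWA and recurrent models start truly fresh.
        llama_cpp.llama_memory_clear(mem, True)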
@@ -653,26 +660,12 @@ def eval(self, tokens: Sequence[int]):
                 batch=batch, n_past=n_past, logits_all=self._logits_all
             )
             self._ctx.decode(self._batch)
-            # Save tokens
             self.input_ids[n_past : n_past + n_tokens] = batch
-            # Save logits
             if self._logits_all:
-                rows = n_tokens
-                cols = self._n_vocab
                 logits = np.ctypeslib.as_array(
-                    self._ctx.get_logits(), shape=(rows * cols,)
+                    self._ctx.get_logits(), shape=(n_tokens, self._n_vocab)
                 )
-                self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
-            else:
-                # rows = 1
-                # cols = self._n_vocab
-                # logits = np.ctypeslib.as_array(
-                #     self._ctx.get_logits(), shape=(rows * cols,)
-                # )
-                # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
-                # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
-                pass
-            # Update n_tokens
+                self.scores[n_past : n_past + n_tokens, :] = logits
             self.n_tokens += n_tokens
 
     def _init_sampler(
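
The simplified logits assignment views the C logits buffer as (n_tokens, n_vocab) and assigns the block in one step, instead of flattening a destination view and copying through it. A small sketch with a plain array standing in for the C buffer:

import numpy as np

n_vocab, n_past, n_tokens = 5, 2, 3
scores = np.zeros((8, n_vocab), dtype=np.single)
c_logits = np.arange(n_tokens * n_vocab, dtype=np.single)  # flat C buffer stand-in

# Old: scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = c_logits
# New: view the source as (n_tokens, n_vocab), assign the 2-D block directly.
scores[n_past : n_past + n_tokens, :] = c_logits.reshape(n_tokens, n_vocab)

assert scores[n_past + 2, 4] == 14.0  # last logit of the last decoded token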
@@ -888,22 +881,20 @@ def generate(
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens):
                 if a == b:
                     longest_prefix += 1
                 else:
                     break
+
+            if (self._is_recurrent_model or self._has_swa_model) and longest_prefix < self.n_tokens:
+                longest_prefix = 0
+
             if longest_prefix > 0:
                 if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
                     reset = False
                     tokens = tokens[longest_prefix:]
                     self.n_tokens = longest_prefix
-                    if self.verbose:
-                        print(
-                            f"Llama.generate: {longest_prefix} prefix-match hit, "
-                            f"remaining {len(tokens)} prompt tokens to eval",
-                            file=sys.stderr,
-                        )
                 elif self.verbose:
                     print(
                         f"Llama.generate: {longest_prefix} prefix-match found "
@@ -1267,12 +1258,9 @@ def logit_bias_processor(
                input_ids: npt.NDArray[np.intc],
                scores: npt.NDArray[np.single],
            ) -> npt.NDArray[np.single]:
-                new_scores = np.copy(
-                    scores
-                )  # Does it make sense to copy the whole array or can we just overwrite the original one?
                for input_id, score in logit_bias_map.items():
-                    new_scores[input_id] = score + scores[input_id]
-                return new_scores
+                    scores[input_id] += score
+                return scores
 
            _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
            if logits_processor is None:
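
The removed comment asked whether copying the whole scores array was necessary; the commit answers by mutating in place, which is safe as long as no other consumer needs the pre-bias logits. Per generated token this is a few scalar adds instead of a full-vocabulary copy:

import numpy as np

scores = np.zeros(32_000, dtype=np.single)  # one row of logits
logit_bias_map = {7: -100.0, 42: 5.0}

# In-place: two scalar adds per generated token instead of copying 32k floats.
for input_id, score in logit_bias_map.items():
    scores[input_id] += score

assert scores[7] == -100.0 and scores[42] == 5.0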
@@ -1333,6 +1321,7 @@ def logit_bias_processor(
 
        finish_reason = "length"
        multibyte_fix = 0
+        accumulated_text = b""
        for token in self.generate(
            prompt_tokens,
            top_k=top_k,
@@ -1352,16 +1341,17 @@ def logit_bias_processor(
            grammar=grammar,
        ):
            if llama_cpp.llama_vocab_is_eog(self._model.vocab, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = accumulated_text
                finish_reason = "stop"
                break
 
            completion_tokens.append(token)
 
-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            new_text = self._model.token_to_piece(token)
+            accumulated_text += new_text
 
            # Contains multi-byte UTF8
-            for k, char in enumerate(all_text[-3:]):
+            for k, char in enumerate(accumulated_text[-3:]):
                k = 3 - k
                for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                    # Bitwise AND check
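
accumulated_text is the incremental-detokenization change from the commit message: appending one token_to_piece() result per step moves O(n) bytes in total, where re-detokenizing the whole completion every step moved O(n²). This assumes per-token pieces concatenate to the same bytes as a full detokenize, which holds for byte-level vocabularies. A toy illustration (the pieces dict stands in for token_to_piece):

pieces = {7: b"Hel", 3: b"lo", 9: b" wor", 5: b"ld"}
generated = [7, 3, 9, 5]

# Old: detokenize(completion_tokens) at step i re-walks all i tokens,
# so n steps touch 1 + 2 + ... + n pieces. New: one append per step.
accumulated_text = b""
for token in generated:
    accumulated_text += pieces[token]

assert accumulated_text == b"Hello world"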
@@ -1373,19 +1363,16 @@ def logit_bias_processor(
                multibyte_fix -= 1
                continue
 
-            any_stop = [s for s in stop_sequences if s in all_text]
+            any_stop = [s for s in stop_sequences if s in accumulated_text]
            if len(any_stop) > 0:
                first_stop = any_stop[0]
-                text = all_text[: all_text.index(first_stop)]
+                text = accumulated_text[: accumulated_text.index(first_stop)]
                finish_reason = "stop"
                break
 
            if stream:
                remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(
-                    remaining_tokens,
-                    prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                )
+                remaining_text = self._model.token_to_piece(token)
                remaining_length = len(remaining_text)
 
                # We want to avoid yielding any characters from
15221509
}
15231510

15241511
if len(completion_tokens) >= max_tokens:
1525-
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1512+
text = accumulated_text
15261513
finish_reason = "length"
15271514
break
15281515

15291516
if stopping_criteria is not None and stopping_criteria(
15301517
self._input_ids, self._scores[-1, :]
15311518
):
1532-
text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
1519+
text = accumulated_text
15331520
finish_reason = "stop"
15341521

15351522
if self.verbose:
15361523
self._ctx.print_timings()
15371524

15381525
if stream:
15391526
remaining_tokens = completion_tokens[returned_tokens:]
1540-
remaining_text = self.detokenize(
1541-
remaining_tokens,
1542-
prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
1527+
remaining_text = b"".join(
1528+
self._model.token_to_piece(t) for t in remaining_tokens
15431529
)
15441530
any_stop = [s for s in stop_sequences if s in remaining_text]
15451531
if len(any_stop) > 0:
@@ -1549,12 +1535,8 @@ def logit_bias_processor(
 
            token_end_position = 0
            for token in remaining_tokens:
-                token_end_position += len(
-                    self.detokenize(
-                        [token],
-                        prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                    )
-                )
+                token_piece = self._model.token_to_piece(token)
+                token_end_position += len(token_piece)
 
                logprobs_or_none: Optional[CompletionLogprobs] = None
                if logprobs is not None:
@@ -1594,7 +1576,7 @@ def logit_bias_processor(
                    }
 
                if token_end_position >= end:
-                    last_text = self.detokenize([token])
+                    last_text = token_piece
                    if token_end_position == end - 1:
                        break
                    returned_tokens += 1
@@ -1707,17 +1689,16 @@ def logit_bias_processor(
                    )
                )
                tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
+                top_k_indices = np.argpartition(logprobs_token, -logprobs)[-logprobs:]
+                top_k_indices = top_k_indices[
+                    np.argsort(logprobs_token[top_k_indices])
+                ][::-1]
                token_logprobs.append(logprobs_token[int(token)])
                top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
+                    self.detokenize([int(i)], prev_tokens=all_tokens[:idx]).decode(
                        "utf-8", errors="ignore"
-                    ): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
+                    ): logprobs_token[int(i)]
+                    for i in top_k_indices
                }
                top_logprob.update({token_str: logprobs_token[int(token)]})
                top_logprobs.append(top_logprob)
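
np.argpartition selects the k largest logprobs in O(V), and only those k are then sorted, replacing a full O(V log V) sort of all (logprob, index) pairs per token. The equivalent selection in isolation:

import numpy as np

rng = np.random.default_rng(0)
logprobs_token = rng.standard_normal(32_000).astype(np.single)
k = 5

# O(V) partition puts the k largest entries in the last k slots,
# then an O(k log k) sort orders just those by descending logprob.
top_k_indices = np.argpartition(logprobs_token, -k)[-k:]
top_k_indices = top_k_indices[np.argsort(logprobs_token[top_k_indices])][::-1]

reference = np.argsort(logprobs_token)[::-1][:k]  # the old full sort
assert np.array_equal(logprobs_token[top_k_indices], logprobs_token[reference])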
