Skip to content

Commit 9609c82

Browse files
author
Ralf Waldukat
committed
fix: prevent KV cache corruption on SWA/ISWA models (e.g. Gemma-4)
SWA/ISWA KV caches maintain global position maps (g_iswa_pos_max/min) that are only cleared by llama_memory_clear(), not by kv_cache_seq_rm(). When generate() finds a prefix match (e.g. a shared BOS token), it calls kv_cache_seq_rm, which returns True for ISWA, skipping the full reset. The stale position maps then cause batch-allocator inconsistency, and llama_decode returns -1 on subsequent prompts. Changes: - Add _has_swa property via llama_model_n_swa() > 0 - reset() now calls llama_memory_clear() unconditionally - generate() bypasses the prefix-match optimization for SWA models, forcing a full state reset (same path as recurrent models)
1 parent 1cb8b9f commit 9609c82

File tree

3 files changed

+110
-2
lines changed

3 files changed

+110
-2
lines changed

llama_cpp/llama.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,14 @@ def free_lora_adapter():
553553

554554
self._sampler = None
555555

556+
# Cache model architecture flags to avoid repeated FFI calls
557+
self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
558+
self._model.model
559+
) or llama_cpp.llama_model_is_hybrid(self._model.model)
560+
self._has_swa_model = llama_cpp.llama_model_n_swa(
561+
self._model.model
562+
) > 0
563+
556564
@property
557565
def ctx(self) -> llama_cpp.llama_context_p:
558566
return self._ctx.ctx
@@ -638,6 +646,10 @@ def reset(self):
638646
"""Reset the model state."""
639647
self.n_tokens = 0
640648

649+
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
650+
if mem is not None:
651+
llama_cpp.llama_memory_clear(mem, True)
652+
641653
def eval(self, tokens: Sequence[int]):
642654
"""Evaluate a list of tokens.
643655
@@ -889,11 +901,29 @@ def generate(
889901
# Check for kv cache prefix match
890902
if reset and self.n_tokens > 0:
891903
longest_prefix = 0
892-
for a, b in zip(self._input_ids, tokens[:-1]):
904+
for a, b in zip(self._input_ids, tokens):
893905
if a == b:
894906
longest_prefix += 1
895907
else:
896908
break
909+
910+
# Recurrent models cannot rewind state; reset if needed
911+
if self._is_recurrent_model and longest_prefix < self.n_tokens:
912+
longest_prefix = 0
913+
reset = True
914+
if self.verbose:
915+
print(
916+
"Llama.generate: recurrent model requires full state reset",
917+
file=sys.stderr,
918+
)
919+
920+
# SWA/ISWA models (e.g. Gemma-4) have split KV caches whose
921+
# position-tracking maps are only cleared by a full reset.
922+
# Partial seq_rm leaves stale positions and causes decode failure.
923+
if self._has_swa_model and longest_prefix < self.n_tokens:
924+
longest_prefix = 0
925+
reset = True
926+
897927
if longest_prefix > 0:
898928
if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
899929
reset = False
@@ -1259,6 +1289,8 @@ def _create_completion(
12591289
RuntimeWarning,
12601290
)
12611291

1292+
# NOTE: This likely doesn't work correctly for the first token in the prompt
1293+
# because of the extra space added to the start of the prompt_tokens
12621294
if logit_bias is not None:
12631295
logit_bias_map = {int(k): float(v) for k, v in logit_bias.items()}
12641296

@@ -1682,6 +1714,7 @@ def logit_bias_processor(
16821714
for i, token in enumerate(all_tokens)
16831715
]
16841716
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
1717+
# TODO: may be able to change this loop to use np.take_along_dim
16851718
for idx, (token, token_str, logprobs_token) in enumerate(
16861719
zip(all_tokens, all_token_strs, all_logprobs)
16871720
):

test_gemma4_iswa.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Test Gemma-4 ISWA model with sequential chat prompts.
2+
3+
Tests:
4+
1. ISWA fix: no 'llama_decode returned -1' on sequential prompts
5+
2. Output quality: coherent text with proper chat template
6+
"""
7+
8+
import sys
9+
import time
10+
11+
from llama_cpp import Llama
12+
13+
MODEL_PATH = "/Users/avion/Documents.nosync/projects/llama-cpp-python/vendor/llama.cpp/build/bin/../../../Downloads/models/supergemma4-26b-uncensored-fast-v2-Q4_K_M.gguf"
14+
15+
PROMPTS = [
16+
[{"role": "user", "content": "What is 2+2? Answer briefly."}],
17+
[{"role": "user", "content": "Write a Python hello world in one line."}],
18+
[{"role": "user", "content": "Explain recursion in one sentence."}],
19+
]
20+
21+
def main():
    """Load the model once, then run each prompt as an independent chat.

    Exercises the SWA/ISWA KV-cache reset path: before the fix, the second
    prompt on the same Llama instance failed with 'llama_decode returned -1'.
    Prints a per-chat result plus a summary and exits with status 1 if any
    chat failed.
    """
    print(f"Loading model: {MODEL_PATH}")
    t0 = time.time()
    llm = Llama(
        # Fixed: previously a second, hardcoded path literal was passed here,
        # inconsistent with the MODEL_PATH announced on the line above.
        model_path=MODEL_PATH,
        n_gpu_layers=-1,
        n_ctx=4096,
        verbose=False,
    )
    print(f"Model loaded in {time.time() - t0:.1f}s")

    # NOTE(review): the companion llama.py change defines `_has_swa_model` /
    # `_is_recurrent_model` (not `_has_swa` / `_is_recurrent`); use a getattr
    # fallback so this diagnostic probe never raises AttributeError regardless
    # of which spelling the installed build exposes.
    has_swa = getattr(llm, "_has_swa", getattr(llm, "_has_swa_model", None))
    is_recurrent = getattr(llm, "_is_recurrent", getattr(llm, "_is_recurrent_model", None))
    print(f"  _has_swa: {has_swa}")
    print(f"  _is_recurrent: {is_recurrent}")
    print(f"  n_ctx: {llm.n_ctx()}")
    print()

    results = []
    for i, messages in enumerate(PROMPTS):
        prompt = messages[0]["content"]
        print(f"--- Chat {i+1}: {prompt!r} ---")
        t1 = time.time()
        try:
            # Only the decode call sits in the try; result unpacking moved to
            # `else` so unrelated errors are not swallowed as decode failures.
            resp = llm.create_chat_completion(
                messages=messages,
                max_tokens=128,
                temperature=0.6,
                top_p=0.95,
                repeat_penalty=1.1,
            )
        except RuntimeError as e:
            # llama_decode failures surface as RuntimeError; record and keep
            # going so later prompts still report their own status.
            elapsed = time.time() - t1
            print(f"[FAIL] {elapsed:.1f}s: {e}")
            print()
            results.append(("FAIL", prompt, str(e)))
        else:
            elapsed = time.time() - t1
            text = resp["choices"][0]["message"]["content"]
            print(f"[OK] {elapsed:.1f}s, {len(text)} chars:")
            print(text[:300])
            print()
            results.append(("OK", prompt, text))

    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for i, (status, prompt, _detail) in enumerate(results):
        print(f"  [{status}] Chat {i+1}: {prompt!r}")

    ok = sum(1 for s, _, _ in results if s == "OK")
    fail = sum(1 for s, _, _ in results if s == "FAIL")
    print(f"\n  Passed: {ok}/{len(results)}, Failed: {fail}/{len(results)}")

    if fail > 0:
        sys.exit(1)
74+
if __name__ == "__main__":
75+
main()

vendor/llama.cpp

0 commit comments

Comments
 (0)