diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 47c0b9f2baa..c30f75baa50 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -75,6 +75,10 @@ def count(): total += m["count"] return total + def trim_content(content, limit): + limit = max(0, limit) + return encoder.decode(encoder.encode(content)[:limit]) + c = count() if c < max_length: return c, msg @@ -89,16 +93,34 @@ def count(): ll = num_tokens_from_string(msg_[0]["content"]) ll2 = num_tokens_from_string(msg_[-1]["content"]) - if ll / (ll + ll2) > 0.8: - m = msg_[0]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[0]["content"] = m - return max_length, msg - - m = msg_[-1]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[-1]["content"] = m - return max_length, msg + total = ll + ll2 + if total <= 0: + logging.debug( + "message_fit_in degenerate token counts total=%s max_length=%s ll=%s ll2=%s preserved_roles=%s", + total, + max_length, + ll, + ll2, + [m.get("role") for m in msg], + ) + return 0, msg + + if len(msg) == 1: + msg[0]["content"] = trim_content(msg[0]["content"], max_length) + return count(), msg + + if ll / total > 0.8: + preserved_last = min(ll2, max_length) + msg[-1]["content"] = trim_content(msg_[-1]["content"], preserved_last) + remaining = max(0, max_length - preserved_last) + msg[0]["content"] = trim_content(msg_[0]["content"], remaining) + return count(), msg + + preserved_system = min(ll, max_length) + msg[0]["content"] = trim_content(msg_[0]["content"], preserved_system) + remaining = max(0, max_length - preserved_system) + msg[-1]["content"] = trim_content(msg_[-1]["content"], remaining) + return count(), msg def kb_prompt(kbinfos, max_tokens, hash_id=False): diff --git a/test/unit_test/rag/prompts/test_generator_message_fit_in.py b/test/unit_test/rag/prompts/test_generator_message_fit_in.py new file mode 100644 index 00000000000..925c203e68a --- /dev/null +++ b/test/unit_test/rag/prompts/test_generator_message_fit_in.py @@ -0,0 +1,151 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _CharEncoder: + @staticmethod + def encode(text): + return list(text) + + @staticmethod + def decode(tokens): + return "".join(tokens) + + +def _load_generator_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + json_repair = ModuleType("json_repair") + json_repair.repair_json = lambda text, **_kwargs: text + monkeypatch.setitem(sys.modules, "json_repair", json_repair) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + misc_utils = ModuleType("common.misc_utils") + misc_utils.hash_str2int = lambda value, _mod=500: 0 + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils) + + constants = ModuleType("common.constants") + constants.TAG_FLD = "tag" + monkeypatch.setitem(sys.modules, "common.constants", constants) + + token_utils = ModuleType("common.token_utils") + token_utils.encoder = _CharEncoder() + token_utils.num_tokens_from_string = lambda text: len(text) + monkeypatch.setitem(sys.modules, "common.token_utils", token_utils) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_nlp = ModuleType("rag.nlp") + rag_nlp.rag_tokenizer = SimpleNamespace(tokenize=lambda text: text.split()) + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp) + + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + + template_mod = ModuleType("rag.prompts.template") + template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", template_mod) + + spec = importlib.util.spec_from_file_location( + "rag.prompts.generator", repo_root / "rag" / "prompts" / "generator.py" + ) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "rag.prompts.generator", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p1 +def test_message_fit_in_truncates_user_message_by_system_token_budget(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "1234"}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=8) + + assert used_tokens == 8 + assert trimmed[0]["content"] == "1234" + assert trimmed[-1]["content"] == "abcd" + + +@pytest.mark.p1 +def test_message_fit_in_handles_zero_token_messages(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda _text: 0) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": ""}, + {"role": "user", "content": ""}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=0) + + assert used_tokens == 0 + assert trimmed == messages + + +@pytest.mark.p1 +def test_message_fit_in_clamps_negative_slice_lengths(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "1234"}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=2) + + assert used_tokens == 2 + assert trimmed[0]["content"] == "12" + assert trimmed[-1]["content"] == "" + + +@pytest.mark.p1 +def test_message_fit_in_clamps_dominant_last_message_to_budget(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "s" * 41}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=8) + + assert used_tokens == 8 + assert trimmed[0]["content"] == "" + assert trimmed[-1]["content"] == "abcdefgh"