Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions rag/prompts/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,18 @@ def count():

ll = num_tokens_from_string(msg_[0]["content"])
ll2 = num_tokens_from_string(msg_[-1]["content"])
if ll / (ll + ll2) > 0.8:
total = ll + ll2
if total <= 0:
return 0, msg
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if ll / total > 0.8:
m = msg_[0]["content"]
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
msg[0]["content"] = m
return max_length, msg

m = msg_[-1]["content"]
m = encoder.decode(encoder.encode(m)[: max_length - ll2])
m = encoder.decode(encoder.encode(m)[: max_length - ll])
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
msg[-1]["content"] = m
return max_length, msg

Expand Down
115 changes: 115 additions & 0 deletions test/unit_test/rag/prompts/test_generator_message_fit_in.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace

import pytest


class _CharEncoder:
@staticmethod
def encode(text):
return list(text)

@staticmethod
def decode(tokens):
return "".join(tokens)


def _load_generator_module(monkeypatch):
repo_root = Path(__file__).resolve().parents[4]

json_repair = ModuleType("json_repair")
json_repair.repair_json = lambda text, **_kwargs: text
monkeypatch.setitem(sys.modules, "json_repair", json_repair)

common_pkg = ModuleType("common")
common_pkg.__path__ = [str(repo_root / "common")]
monkeypatch.setitem(sys.modules, "common", common_pkg)

misc_utils = ModuleType("common.misc_utils")
misc_utils.hash_str2int = lambda value, _mod=500: 0
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils)

constants = ModuleType("common.constants")
constants.TAG_FLD = "tag"
monkeypatch.setitem(sys.modules, "common.constants", constants)

token_utils = ModuleType("common.token_utils")
token_utils.encoder = _CharEncoder()
token_utils.num_tokens_from_string = lambda text: len(text)
monkeypatch.setitem(sys.modules, "common.token_utils", token_utils)

rag_pkg = ModuleType("rag")
rag_pkg.__path__ = [str(repo_root / "rag")]
monkeypatch.setitem(sys.modules, "rag", rag_pkg)

rag_nlp = ModuleType("rag.nlp")
rag_nlp.rag_tokenizer = SimpleNamespace(tokenize=lambda text: text.split())
monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp)

rag_prompts_pkg = ModuleType("rag.prompts")
rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")]
monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg)

template_mod = ModuleType("rag.prompts.template")
template_mod.load_prompt = lambda *_args, **_kwargs: ""
monkeypatch.setitem(sys.modules, "rag.prompts.template", template_mod)

spec = importlib.util.spec_from_file_location(
"rag.prompts.generator", repo_root / "rag" / "prompts" / "generator.py"
)
module = importlib.util.module_from_spec(spec)
monkeypatch.setitem(sys.modules, "rag.prompts.generator", module)
spec.loader.exec_module(module)
return module


@pytest.mark.p1
def test_message_fit_in_truncates_user_message_by_system_token_budget(monkeypatch):
generator = _load_generator_module(monkeypatch)
monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text))
monkeypatch.setattr(generator, "encoder", _CharEncoder())

messages = [
{"role": "system", "content": "1234"},
{"role": "user", "content": "abcdefghij"},
]

used_tokens, trimmed = generator.message_fit_in(messages, max_length=8)

assert used_tokens == 8
assert trimmed[0]["content"] == "1234"
assert trimmed[-1]["content"] == "abcd"


@pytest.mark.p1
def test_message_fit_in_handles_zero_token_messages(monkeypatch):
generator = _load_generator_module(monkeypatch)
monkeypatch.setattr(generator, "num_tokens_from_string", lambda _text: 0)
monkeypatch.setattr(generator, "encoder", _CharEncoder())

messages = [
{"role": "system", "content": ""},
{"role": "user", "content": ""},
]

used_tokens, trimmed = generator.message_fit_in(messages, max_length=0)

assert used_tokens == 0
assert trimmed == messages