Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
chat_function = kernel.add_function(
prompt_template_config=PromptTemplateConfig(
template="""{{system_message}}{{#each chat_history}}
{{#message role=role}}{{~content~}}{{/message}} {{/each}}""",
{{message_to_prompt}} {{/each}}""",
template_format="handlebars",
allow_dangerously_set_content=True,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@

chat_function = kernel.add_function(
prompt_template_config=PromptTemplateConfig(
template="""{{system_message}}{% for item in chat_history %}{{ message(item) }}{% endfor %}""",
template="""{{system_message}}{% for item in chat_history %}{{ message_to_prompt(item) }}{% endfor %}""",
template_format="jinja2",
allow_dangerously_set_content=True,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from collections.abc import Callable
from enum import Enum
from xml.etree.ElementTree import Element, tostring # nosec B405

logger: logging.Logger = logging.getLogger(__name__)

Expand All @@ -28,21 +29,20 @@ def _message_to_prompt(this, *args, **kwargs):
def _message(this, options, *args, **kwargs):
from semantic_kernel.contents.const import CHAT_MESSAGE_CONTENT_TAG

# everything in kwargs, goes to <ROOT_KEY_MESSAGE kwargs_key="kwargs_value">
# everything in options, goes in between <ROOT_KEY_MESSAGE>options</ROOT_KEY_MESSAGE>
start = f"<{CHAT_MESSAGE_CONTENT_TAG}"
# Everything in kwargs becomes an attribute, and the block output is treated as message text.
message = Element(CHAT_MESSAGE_CONTENT_TAG)
for key, value in kwargs.items():
if isinstance(value, Enum):
value = value.value
if value is not None:
start += f' {key}="{value}"'
start += ">"
end = f"</{CHAT_MESSAGE_CONTENT_TAG}>"
message.set(key, str(value))
try:
content = options["fn"](this)
content = str(options["fn"](this))
except Exception: # pragma: no cover
content = ""
return f"{start}{content}{end}"
if content:
message.text = content
return tostring(message, encoding="unicode", short_empty_elements=False)


def _set(this, *args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from collections.abc import Callable
from enum import Enum
from xml.etree.ElementTree import Element, tostring # nosec B405

logger: logging.Logger = logging.getLogger(__name__)

Expand All @@ -27,15 +28,14 @@ def _message_to_prompt(context):
def _message(item):
from semantic_kernel.contents.const import CHAT_MESSAGE_CONTENT_TAG

start = f"<{CHAT_MESSAGE_CONTENT_TAG}"
role = item.role
content = item.content
if isinstance(role, Enum):
role = role.value
start += f' role="{role}"'
start += ">"
end = f"</{CHAT_MESSAGE_CONTENT_TAG}>"
return f"{start}{content}{end}"
message = Element(CHAT_MESSAGE_CONTENT_TAG)
message.set("role", str(role))
if item.content:
message.text = item.content
return tostring(message, encoding="unicode", short_empty_elements=False)


# Wrap the _get function to safely handle calls without arguments
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,26 @@ async def test_helpers_message(kernel: Kernel):
assert "Assistant message" in rendered


async def test_helpers_message_escapes_xml_metacharacters(kernel: Kernel):
template = """
{{#each chat_history}}
{{#message role=role}}
{{~content~}}
{{/message}}
{{/each}}
"""
target = create_handlebars_prompt_template(template, allow_dangerously_set_content=True)
chat_history = ChatHistory()
chat_history.add_user_message('What does a < b & "c" mean?')

rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))

assert "&lt;" in rendered
assert "&amp;" in rendered
assert '"c"' in rendered
assert ChatHistory.from_rendered_prompt(rendered) == chat_history


async def test_helpers_message_to_prompt(kernel: Kernel):
template = """{{#each chat_history}}{{message_to_prompt}} {{/each}}"""
target = create_handlebars_prompt_template(template, allow_dangerously_set_content=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from semantic_kernel import Kernel
from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions import kernel_function
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.prompt_template.handlebars_prompt_template import HandlebarsPromptTemplate
Expand Down Expand Up @@ -100,3 +101,39 @@ async def test_chat_history_round_trip(self, kernel: Kernel):
)
chat_history2 = ChatHistory.from_rendered_prompt(rendered)
assert chat_history2 == chat_history

async def test_chat_history_round_trip_with_xml_metacharacters(self, kernel: Kernel):
# Arrange
template = """{{#each chat_history}}{{#message role=role}}{{~content~}}{{/message}} {{/each}}"""
target = create_handlebars_prompt_template(template)
chat_history = ChatHistory()
chat_history.add_user_message("What does a < b mean in Python?")
chat_history.add_assistant_message('Use "&" carefully in XML and HTML.')

rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))

assert "&lt;" in rendered
assert "&amp;" in rendered
assert '"&amp;"' in rendered
assert ChatHistory.from_rendered_prompt(rendered) == chat_history

async def test_message_helper_preserves_system_role_with_xml_metacharacters(self, kernel: Kernel):
# Arrange
template = (
"""{{system_message}}{{#each chat_history}}{{#message role=role}}{{~content~}}{{/message}} {{/each}}"""
)
target = create_handlebars_prompt_template(template)
system_message = "You are a helpful assistant."
chat_history = ChatHistory()
chat_history.add_user_message("What does a < b mean in Python?")

rendered = await target.render(
kernel,
KernelArguments(system_message=system_message, chat_history=chat_history),
)

parsed = ChatHistory.from_rendered_prompt(rendered)
assert parsed.messages[0].role == AuthorRole.SYSTEM
assert parsed.messages[0].content == system_message
assert parsed.messages[1].role == AuthorRole.USER
assert parsed.messages[1].content == "What does a < b mean in Python?"
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,20 @@ async def test_helpers_message(kernel: Kernel):
assert "Assistant message" in rendered


async def test_helpers_message_escapes_xml_metacharacters(kernel: Kernel):
template = """{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
target = create_jinja2_prompt_template(template, allow_dangerously_set_content=True)
chat_history = ChatHistory()
chat_history.add_user_message('What does a < b & "c" mean?')

rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))

assert "&lt;" in rendered
assert "&amp;" in rendered
assert '"c"' in rendered
assert ChatHistory.from_rendered_prompt(rendered) == chat_history


async def test_helpers_message_to_prompt(kernel: Kernel):
template = """
{% for chat in chat_history %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@


from semantic_kernel.contents.chat_history import ChatHistory
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.functions import kernel_function
from semantic_kernel.functions.kernel_arguments import KernelArguments
from semantic_kernel.kernel import Kernel
Expand Down Expand Up @@ -104,3 +105,37 @@ async def test_chat_history_round_trip(kernel: Kernel):
)
chat_history2 = ChatHistory.from_rendered_prompt(rendered)
assert chat_history2 == chat_history


async def test_chat_history_round_trip_with_xml_metacharacters(kernel: Kernel):
template = """{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
target = create_jinja2_prompt_template(template)
chat_history = ChatHistory()
chat_history.add_user_message("What does a < b mean in Python?")
chat_history.add_assistant_message('Use "&" carefully in XML and HTML.')

rendered = await target.render(kernel, KernelArguments(chat_history=chat_history))

assert "&lt;" in rendered
assert "&amp;" in rendered
assert '"&amp;"' in rendered
assert ChatHistory.from_rendered_prompt(rendered) == chat_history


async def test_message_helper_preserves_system_role_with_xml_metacharacters(kernel: Kernel):
template = """{{system_message}}{% for item in chat_history %}{{ message(item) }}{% endfor %}"""
target = create_jinja2_prompt_template(template)
system_message = "You are a helpful assistant."
chat_history = ChatHistory()
chat_history.add_user_message("What does a < b mean in Python?")

rendered = await target.render(
kernel,
KernelArguments(system_message=system_message, chat_history=chat_history),
)

parsed = ChatHistory.from_rendered_prompt(rendered)
assert parsed.messages[0].role == AuthorRole.SYSTEM
assert parsed.messages[0].content == system_message
assert parsed.messages[1].role == AuthorRole.USER
assert parsed.messages[1].content == "What does a < b mean in Python?"
Loading