Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.

from abc import ABC
from typing import ClassVar
from typing import Any, ClassVar

from google.genai import Client

from semantic_kernel.connectors.ai.google.google_ai.google_ai_settings import GoogleAISettings
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent


class GoogleAIBase(KernelBaseModel, ABC):
Expand All @@ -17,3 +18,16 @@ class GoogleAIBase(KernelBaseModel, ABC):
service_settings: GoogleAISettings

client: Client | None = None

def _get_http_options(self) -> dict[str, Any] | None:
    """Build the HTTP options passed to the Google AI ``Client``.

    When telemetry is enabled (``APP_INFO`` is truthy), the Semantic Kernel
    identifier is prepended to the user-agent headers so requests are
    attributable; otherwise no options are supplied.

    Returns:
        A dict of the form ``{"headers": {...}}``, or ``None`` when
        telemetry is disabled.
    """
    if APP_INFO:
        # Copy APP_INFO so the module-level mapping is never mutated.
        return {"headers": prepend_semantic_kernel_to_user_agent(dict(APP_INFO))}
    return None
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,14 @@ async def _generate_content(client: Client) -> GenerateContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]

return [self._create_chat_message_content(response, candidate) for candidate in response.candidates] # type: ignore
Expand Down Expand Up @@ -216,14 +220,18 @@ async def _generate_content_stream(client: Client) -> AsyncGenerator[GenerateCon
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
async for chunk in _generate_content_stream(client):
yield [
self._create_streaming_chat_message_content(chunk, candidate, function_invoke_attempt)
for candidate in chunk.candidates # type: ignore
]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
async for chunk in _generate_content_stream(client):
yield [
self._create_streaming_chat_message_content(chunk, candidate, function_invoke_attempt)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,14 @@ async def _generate_content(client: Client) -> GenerateContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: GenerateContentResponse = await _generate_content(client) # type: ignore[no-redef]

return [self._create_text_content(response, candidate) for candidate in response.candidates] # type: ignore
Expand Down Expand Up @@ -173,11 +177,15 @@ async def _generate_content_stream(client: Client) -> AsyncGenerator[GenerateCon
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
async for chunk in _generate_content_stream(client):
yield [self._create_streaming_text_content(chunk, candidate) for candidate in chunk.candidates] # type: ignore
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
async for chunk in _generate_content_stream(client):
yield [self._create_streaming_text_content(chunk, candidate) for candidate in chunk.candidates] # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,14 @@ async def _embed_content(client: Client) -> EmbedContentResponse:
vertexai=True,
project=self.service_settings.cloud_project_id,
location=self.service_settings.cloud_region,
http_options=self._get_http_options(),
) as client:
response: EmbedContentResponse = await _embed_content(client) # type: ignore[no-redef]
else:
with Client(api_key=self.service_settings.api_key.get_secret_value()) as client: # type: ignore[union-attr]
with Client(
api_key=self.service_settings.api_key.get_secret_value(),
http_options=self._get_http_options(),
) as client: # type: ignore[union-attr]
response: EmbedContentResponse = await _embed_content(client) # type: ignore[no-redef]

return [embedding.values for embedding in response.embeddings] # type: ignore
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft. All rights reserved.

from unittest.mock import AsyncMock, patch

import pytest

from semantic_kernel.connectors.ai.google.google_ai.google_ai_prompt_execution_settings import (
GoogleAIChatPromptExecutionSettings,
)
from semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion import GoogleAIChatCompletion
from semantic_kernel.const import USER_AGENT
from semantic_kernel.contents.chat_history import ChatHistory


@pytest.mark.asyncio
async def test_google_ai_chat_completion_user_agent(google_ai_unit_test_env):
    """Verify the Google AI Client is constructed with SK user-agent headers."""
    history = ChatHistory()
    history.add_user_message("hi")

    client_path = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion.Client"
    with patch(client_path) as mock_client:
        # The service uses Client as a context manager; stub the async call
        # made inside it so no network traffic occurs.
        entered = mock_client.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAIChatCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_chat_message_contents(
            chat_history=history, settings=GoogleAIChatPromptExecutionSettings()
        )

    # call_args survives the context manager exit.
    _, kwargs = mock_client.call_args
    assert "http_options" in kwargs
    http_options = kwargs["http_options"]
    assert http_options is not None
    assert "headers" in http_options
    assert USER_AGENT in http_options["headers"]
    assert "semantic-kernel-python" in http_options["headers"][USER_AGENT]


@pytest.mark.asyncio
async def test_google_ai_chat_completion_no_telemetry(google_ai_unit_test_env):
    """Verify no user-agent headers are sent when telemetry is disabled (APP_INFO is None)."""
    history = ChatHistory()
    history.add_user_message("hi")

    app_info_path = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_base.APP_INFO"
    client_path = "semantic_kernel.connectors.ai.google.google_ai.services.google_ai_chat_completion.Client"
    with patch(app_info_path, None), patch(client_path) as mock_client:
        # Stub the async generate call made inside the Client context manager.
        entered = mock_client.return_value.__enter__.return_value
        entered.aio.models.generate_content = AsyncMock()

        service = GoogleAIChatCompletion(gemini_model_id="gemini-3-flash-preview", api_key="AIza-test-key")
        await service.get_chat_message_contents(
            chat_history=history, settings=GoogleAIChatPromptExecutionSettings()
        )

    # With APP_INFO patched to None, _get_http_options returns None.
    _, kwargs = mock_client.call_args
    assert "http_options" in kwargs
    assert kwargs["http_options"] is None
Loading