Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 92 additions & 25 deletions src/chat_sdk/adapters/teams/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
import re
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any, Literal, NoReturn
from typing import Any, Literal, NoReturn, cast
from urllib.parse import urlparse

from chat_sdk.adapters.teams.cards import card_to_adaptive_card
from chat_sdk.adapters.teams.format_converter import TeamsFormatConverter
from chat_sdk.adapters.teams.types import (
TeamsAdapterConfig,
TeamsChannelContext,
TeamsDmContext,
TeamsGraphContext,
TeamsThreadId,
)
from chat_sdk.emoji import convert_emoji_placeholders
Expand Down Expand Up @@ -311,9 +313,25 @@ async def _cache_user_context(self, activity: dict[str, Any]) -> None:
context: TeamsChannelContext = {
"team_id": team_aad_group_id,
"channel_id": channel_data["channel"]["id"],
"type": "channel",
}
await state.set(f"teams:channelContext:{base_channel_id}", json.dumps(context), ttl)

# Cache DM context for Microsoft Graph chat ID resolution
# (vercel/chat#403). Bot Framework hands out opaque DM conversation
# IDs that Graph's ``/chats/{chat-id}/messages`` endpoint rejects;
# the canonical Graph chat ID for a 1:1 DM is
# ``19:{userAadId}_{botId}@unq.gbl.spaces``. ``aadObjectId`` is
# only present for real Teams users (not bots), and DM conversation
# IDs do not start with ``19:`` (channel/group chats do).
aad_object_id = from_user.get("aadObjectId")
if aad_object_id and self._app_id and not base_channel_id.startswith("19:") and state:
dm_context: TeamsDmContext = {
"graph_chat_id": f"19:{aad_object_id}_{self._app_id}@unq.gbl.spaces",
"type": "dm",
}
await state.set(f"teams:channelContext:{base_channel_id}", json.dumps(dm_context), ttl)

async def _handle_message_activity(
self,
activity: dict[str, Any],
Expand Down Expand Up @@ -1036,29 +1054,38 @@ async def fetch_messages(
thread_message_id = message_id_match.group(1) if message_id_match else None
base_conversation_id = MESSAGEID_STRIP_PATTERN.sub("", conversation_id)

channel_context = await self._get_channel_context(base_conversation_id) if thread_message_id else None
# vercel/chat#403: always look up the Graph context, not just for
# channel threads — DMs need it to map the opaque Bot Framework
# conversation ID to the canonical Graph chat ID
# (``19:{aadId}_{botId}@unq.gbl.spaces``) that ``/chats/{id}``
# accepts.
graph_context = await self._get_graph_context(base_conversation_id)
context_type: str | None = graph_context.get("type") if graph_context else None

try:
self._logger.debug(
"Teams Graph API: fetching messages",
{
"conversationId": base_conversation_id,
"threadMessageId": thread_message_id,
"hasChannelContext": channel_context is not None,
"contextType": context_type or "none",
"limit": limit,
"cursor": cursor,
"direction": direction,
},
)

if channel_context and thread_message_id:
if graph_context and context_type != "dm" and thread_message_id:
# Narrowed: channel context for a channel thread.
channel_context = cast(TeamsChannelContext, graph_context)
return await self._fetch_channel_thread_messages(
channel_context,
thread_message_id,
thread_id,
options,
)

chat_id = self._chat_id_from_context(graph_context, base_conversation_id)
graph_messages: list[dict[str, Any]]
has_more = False

Expand All @@ -1069,7 +1096,7 @@ async def fetch_messages(
}
if cursor:
params["$filter"] = f"createdDateTime gt {cursor}"
graph_messages = await self._graph_list_chat_messages(base_conversation_id, params)
graph_messages = await self._graph_list_chat_messages(chat_id, params)
has_more = len(graph_messages) >= limit
else:
params = {
Expand All @@ -1078,11 +1105,11 @@ async def fetch_messages(
}
if cursor:
params["$filter"] = f"createdDateTime lt {cursor}"
graph_messages = await self._graph_list_chat_messages(base_conversation_id, params)
graph_messages = await self._graph_list_chat_messages(chat_id, params)
graph_messages.reverse()
has_more = len(graph_messages) >= limit

if thread_message_id and not channel_context:
if thread_message_id and not graph_context:
graph_messages = [msg for msg in graph_messages if msg.get("id") and msg["id"] >= thread_message_id]
self._logger.debug(
"Filtered group chat messages to thread",
Expand Down Expand Up @@ -1134,13 +1161,14 @@ async def fetch_channel_messages(
direction = options.direction or "backward"

try:
channel_context = await self._get_channel_context(base_conversation_id)
graph_context = await self._get_graph_context(base_conversation_id)
context_type = graph_context.get("type") if graph_context else None

self._logger.debug(
"Teams Graph API: fetchChannelMessages",
{
"conversationId": base_conversation_id,
"hasChannelContext": channel_context is not None,
"contextType": context_type or "none",
"limit": limit,
"direction": direction,
},
Expand All @@ -1149,7 +1177,8 @@ async def fetch_channel_messages(
graph_messages: list[dict[str, Any]]
has_more = False

if channel_context:
if graph_context and context_type != "dm":
channel_context = cast(TeamsChannelContext, graph_context)
if direction == "forward":
graph_messages = await self._graph_list_channel_messages(
channel_context["team_id"],
Expand All @@ -1176,19 +1205,24 @@ async def fetch_channel_messages(
)
graph_messages.reverse()
has_more = len(graph_messages) >= limit
elif direction == "forward":
params = {"$top": limit, "$orderby": "createdDateTime asc"}
if options.cursor:
params["$filter"] = f"createdDateTime gt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(base_conversation_id, params)
has_more = len(graph_messages) >= limit
else:
params = {"$top": limit, "$orderby": "createdDateTime desc"}
if options.cursor:
params["$filter"] = f"createdDateTime lt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(base_conversation_id, params)
graph_messages.reverse()
has_more = len(graph_messages) >= limit
# vercel/chat#403: DM contexts substitute the canonical Graph
# chat ID for the opaque Bot Framework conversation ID; no
# context (group chat) falls through to the raw ID.
chat_id = self._chat_id_from_context(graph_context, base_conversation_id)
if direction == "forward":
params = {"$top": limit, "$orderby": "createdDateTime asc"}
if options.cursor:
params["$filter"] = f"createdDateTime gt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(chat_id, params)
has_more = len(graph_messages) >= limit
else:
params = {"$top": limit, "$orderby": "createdDateTime desc"}
if options.cursor:
params["$filter"] = f"createdDateTime lt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(chat_id, params)
graph_messages.reverse()
has_more = len(graph_messages) >= limit
Comment on lines +1212 to +1225
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for handling forward and backward pagination is duplicated here and in fetch_messages (lines 1092-1110). To improve maintainability and reduce code duplication, you could refactor this block and apply a similar change in fetch_messages.

Here's a suggested refactoring for this block:

Suggested change
chat_id = self._chat_id_from_context(graph_context, base_conversation_id)
if direction == "forward":
params = {"$top": limit, "$orderby": "createdDateTime asc"}
if options.cursor:
params["$filter"] = f"createdDateTime gt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(chat_id, params)
has_more = len(graph_messages) >= limit
else:
params = {"$top": limit, "$orderby": "createdDateTime desc"}
if options.cursor:
params["$filter"] = f"createdDateTime lt {options.cursor}"
graph_messages = await self._graph_list_chat_messages(chat_id, params)
graph_messages.reverse()
has_more = len(graph_messages) >= limit
chat_id = self._chat_id_from_context(graph_context, base_conversation_id)
order_by = "createdDateTime asc"
filter_op = "gt"
if direction != "forward":
order_by = "createdDateTime desc"
filter_op = "lt"
params = {"$top": limit, "$orderby": order_by}
if options.cursor:
params["$filter"] = f"createdDateTime {filter_op} {options.cursor}"
graph_messages = await self._graph_list_chat_messages(chat_id, params)
has_more = len(graph_messages) >= limit
if direction != "forward":
graph_messages.reverse()


messages = [self._map_graph_message(msg, channel_id) for msg in graph_messages if msg.get("id")]

Expand Down Expand Up @@ -1225,7 +1259,14 @@ async def fetch_channel_info(self, channel_id: str) -> ChannelInfo:
base_conversation_id = MESSAGEID_STRIP_PATTERN.sub("", conversation_id)
is_dm = not conversation_id.startswith("19:")

channel_context = await self._get_channel_context(base_conversation_id) if not is_dm else None
# vercel/chat#403: only call into the Graph teams/channels
# endpoint for true channel contexts. A cached DM context (now
# possible when ``aadObjectId`` was present on the activity)
# must not be treated as a channel.
graph_context = await self._get_graph_context(base_conversation_id) if not is_dm else None
channel_context: TeamsChannelContext | None = None
if graph_context and graph_context.get("type") != "dm":
channel_context = cast(TeamsChannelContext, graph_context)

if channel_context:
try:
Expand Down Expand Up @@ -1344,8 +1385,19 @@ async def disconnect(self) -> None:
# Graph API — internal helpers
# =========================================================================

async def _get_channel_context(self, base_conversation_id: str) -> TeamsChannelContext | None:
"""Look up cached channel context (team_id, channel_id) for a conversation."""
async def _get_graph_context(self, base_conversation_id: str) -> TeamsGraphContext | None:
"""Look up cached Microsoft Graph context for a conversation.

Returns either a :class:`TeamsChannelContext` (channel/team
thread) or a :class:`TeamsDmContext` (1:1 DM with a resolved
Graph chat ID). For group chats, no entry is cached — the raw
conversation ID works as-is with Graph's ``/chats`` endpoints.

Backwards compat: cached entries written before vercel/chat#403
omit the ``type`` discriminator and are treated as
``"channel"`` by :meth:`_chat_id_from_context` and the call
sites that branch on context type.
"""
if not self._chat:
return None
state = self._chat.get_state()
Expand All @@ -1359,6 +1411,21 @@ async def _get_channel_context(self, base_conversation_id: str) -> TeamsChannelC
pass
return None

@staticmethod
def _chat_id_from_context(
context: TeamsGraphContext | None,
base_conversation_id: str,
) -> str:
"""Resolve the Microsoft Graph chat ID for a non-channel conversation.

Uses the DM context's ``graph_chat_id`` when present, otherwise
falls back to the raw Bot Framework conversation ID (which works
as-is for group chats and the legacy pre-#403 cache shape).
"""
if context is not None and context.get("type") == "dm":
return context["graph_chat_id"]
return base_conversation_id

async def _graph_list_chat_messages(
self,
chat_id: str,
Expand Down
30 changes: 28 additions & 2 deletions src/chat_sdk/adapters/teams/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,37 @@ class TeamsThreadId:
# =============================================================================


class TeamsChannelContext(TypedDict):
"""Teams channel context extracted from activity.channelData."""
class TeamsChannelContext(TypedDict, total=False):
"""Teams channel context extracted from activity.channelData.

The ``type`` discriminator is optional for backwards-compatibility:
cached entries written before vercel/chat#403 omit it, and downstream
code treats a missing ``type`` as ``"channel"``.
"""

channel_id: str
team_id: str
type: str # Literal["channel"] when present


class TeamsDmContext(TypedDict):
"""Teams DM context with the resolved Microsoft Graph chat ID.

Bot Framework hands out opaque DM conversation IDs (e.g.
``a:1xWhatever``) which are *not* accepted by Graph's
``/chats/{chat-id}/messages`` endpoint. The canonical Graph chat ID
for a 1:1 DM is ``19:{userAadId}_{botId}@unq.gbl.spaces`` — derive
and cache it from the incoming activity's ``from.aadObjectId``.
"""

graph_chat_id: str
type: str # Literal["dm"]


# Discriminated union for Microsoft Graph API resolution context.
# Group chats are not represented — their conversation ID works as-is
# with Graph's chat endpoints.
TeamsGraphContext = TeamsChannelContext | TeamsDmContext


# =============================================================================
Expand Down
52 changes: 49 additions & 3 deletions src/chat_sdk/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,48 @@ def _is_async_iterable(value: Any) -> bool:
return value is not None and hasattr(value, "__aiter__")


def _extract_slack_recipient_team_id(raw: Any) -> str | None:
"""Resolve the Slack workspace ID from a ``currentMessage.raw`` payload.

Slack carries the workspace ID in different shapes depending on the
webhook envelope:

- Message events (``message`` / ``app_mention``): top-level ``team_id``
(string) or ``team`` (string).
- Interactive payloads (``block_actions``, ``view_submission``, etc.):
nested ``team.id`` on a ``team`` object, with ``user.team_id`` as a
final fallback when ``team`` is missing entirely.

Upstream parity: vercel/chat#330. The previous extraction
(``raw.get("team_id") or raw.get("team")``) returned the entire ``team``
dict for block_actions, which then traveled to the Slack adapter as the
``recipient_team_id`` and either crashed downstream API calls or sent
them to the wrong workspace.
"""
if not isinstance(raw, dict):
return None

team_id = raw.get("team_id")
if isinstance(team_id, str) and team_id:
return team_id

team = raw.get("team")
if isinstance(team, str) and team:
return team
if isinstance(team, dict):
nested = team.get("id")
if isinstance(nested, str) and nested:
return nested

user = raw.get("user")
if isinstance(user, dict):
user_team_id = user.get("team_id")
if isinstance(user_team_id, str) and user_team_id:
return user_team_id

return None


def _extract_message_content(
message: AdapterPostableMessage,
) -> tuple[str, FormattedContent, list[Attachment]]:
Expand Down Expand Up @@ -615,9 +657,13 @@ async def _handle_stream(
options = StreamOptions()
if self._current_message is not None:
options.recipient_user_id = self._current_message.author.user_id
raw = self._current_message.raw
if isinstance(raw, dict):
options.recipient_team_id = raw.get("team_id") or raw.get("team")
# recipient_team_id is only consumed by the Slack adapter; other
# adapters ignore it. Slack carries the workspace ID in different
# shapes depending on webhook type — message events use the
# top-level ``team_id`` / ``team`` (string), block_actions / view
# payloads use ``team.id`` (object), with ``user.team_id`` as a
# final fallback. Upstream parity: vercel/chat#330.
options.recipient_team_id = _extract_slack_recipient_team_id(self._current_message.raw)

# Merge caller-supplied StreamingPlan options on top. Explicit fields win.
if extra_options is not None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_coverage_gaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ async def test_fetch_messages_dm_chat(self):

# Mock the Graph API methods
adapter._get_graph_token = AsyncMock(return_value="fake-token")
adapter._get_channel_context = AsyncMock(return_value=None)
adapter._get_graph_context = AsyncMock(return_value=None)

fake_graph_messages = [
{
Expand Down
Loading
Loading