diff --git a/src/chat_sdk/__init__.py b/src/chat_sdk/__init__.py index 7b121a8..7655b92 100644 --- a/src/chat_sdk/__init__.py +++ b/src/chat_sdk/__init__.py @@ -158,6 +158,8 @@ MessageContext, MessageData, MessageMetadata, + MessageSubject, + MessageSubjectParty, ModalCloseEvent, ModalResponse, ModalSubmitEvent, @@ -368,6 +370,8 @@ "MessageContext", "MessageData", "MessageMetadata", + "MessageSubject", + "MessageSubjectParty", "ModalCloseEvent", "ModalResponse", "ModalSubmitEvent", diff --git a/src/chat_sdk/chat.py b/src/chat_sdk/chat.py index a4024fb..9df5de6 100644 --- a/src/chat_sdk/chat.py +++ b/src/chat_sdk/chat.py @@ -63,6 +63,7 @@ UserInfo, WebhookOptions, _parse_iso, + set_message_adapter, ) # --------------------------------------------------------------------------- @@ -2135,6 +2136,20 @@ async def _dispatch_to_handlers( context: MessageContext | None = None, ) -> None: """Route a message to the correct handler chain.""" + # Register the owning adapter so handlers can lazily resolve + # ``message.subject`` via the adapter's optional ``fetch_subject`` hook. + # Mirrors upstream's ``setMessageAdapter`` call at the dispatch bind + # site (packages/chat/src/chat.ts). Every dispatched message flows + # through here, so this is the single registration point. + set_message_adapter(message, adapter) + # Skipped messages (queue drain / burst collapse) are surfaced to + # handlers via ``context.skipped`` but never themselves dispatched, + # so they also need their adapter bound for ``await msg.subject`` to + # work inside the handler. + if context is not None: + for skipped_msg in context.skipped: + set_message_adapter(skipped_msg, adapter) + # Detect mention message.is_mention = message.is_mention or self._detect_mention(adapter, message) diff --git a/src/chat_sdk/types.py b/src/chat_sdk/types.py index 28cb259..ae6eafc 100644 --- a/src/chat_sdk/types.py +++ b/src/chat_sdk/types.py @@ -5,6 +5,8 @@ from __future__ import annotations +import asyncio +import weakref from collections.abc import AsyncIterable, Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime @@ -387,6 +389,111 @@ class SerializedMessage(_SerializedMessageRequired, total=False): links: list[SerializedLinkPreview] +@dataclass +class MessageSubjectParty: + """A person referenced by a :class:`MessageSubject` (assignee/author). + + Mirrors the inline ``{ id: string; name: string }`` shape used by + upstream's ``MessageSubject.assignee`` / ``MessageSubject.author``. + """ + + id: str + name: str + + +@dataclass +class MessageSubject: + """The external subject a message refers to (e.g. a Linear issue or GitHub PR). + + Python port of the TS ``MessageSubject`` interface + (``packages/chat/src/types.ts``). Resolved lazily via + :attr:`Message.subject`, which delegates to the owning adapter's + optional :meth:`Adapter.fetch_subject` hook. + + Field names are snake_case per the Python port convention; ``raw`` is + the platform-specific escape hatch. + """ + + # ``id`` and ``type`` are the only required fields upstream; everything + # else is optional. ``raw`` is required upstream but defaults to ``None`` + # here so partially-populated subjects (e.g. in tests) construct cleanly. + id: str + type: str + raw: Any = None + assignee: MessageSubjectParty | None = None + author: MessageSubjectParty | None = None + description: str | None = None + labels: list[str] | None = None + status: str | None = None + title: str | None = None + url: str | None = None + + +# -------------------------------------------------------------------------- +# Message -> Adapter registry (powers ``Message.subject``) +# -------------------------------------------------------------------------- +# +# Upstream (``packages/chat/src/message.ts``) uses +# ``const adapterMap = new WeakMap()`` so a dispatched +# message can lazily ask its owning adapter to resolve its subject, without +# the message holding a hard reference to the adapter and without leaking +# messages after they fall out of scope. +# +# Python port hazard — hashability/weakref: +# ``Message`` is a plain ``@dataclass`` (``eq=True``), which makes instances +# *unhashable*. A ``weakref.WeakKeyDictionary[Message, Adapter]`` therefore +# raises ``TypeError: unhashable type: 'Message'``. We deliberately do NOT +# change ``Message`` to ``eq=False``/``frozen=True`` (that would alter its +# public equality contract). Instead we key a plain ``dict`` by +# ``id(message)`` (object identity, matching ``WeakMap`` semantics) and +# register a ``weakref.finalize`` callback per message that pops the entry +# when the message is garbage-collected. ``weakref.ref(message)`` works on a +# plain dataclass even though ``hash()`` does not, so this is safe. The +# finalizer also closes the ``id()`` reuse hole: the entry is removed before +# CPython can recycle the id for a new object. +_message_adapter_map: dict[int, Adapter] = {} + + +def set_message_adapter(message: Message, adapter: Adapter) -> None: + """Register the adapter that owns ``message`` (powers ``message.subject``). + + Called by :class:`~chat_sdk.chat.Chat` at the dispatch bind site so every + message handed to a handler can resolve its subject via the adapter's + optional :meth:`Adapter.fetch_subject` hook. + + Mirrors upstream ``setMessageAdapter`` (``packages/chat/src/message.ts``). + The mapping is keyed by object identity and weakly scoped: when ``message`` + is garbage-collected, its entry is removed automatically. + """ + key = id(message) + already_registered = key in _message_adapter_map + _message_adapter_map[key] = adapter + + # Register the GC finalizer only on first registration for a given message + # identity; re-registering the same live message just overwrites the adapter + # value above. This prevents an accumulation of redundant finalizers when a + # message is registered more than once (re-dispatch, rehydrate, multiple + # handler passes). The ``id()``-reuse hole stays closed: if a prior message + # with the same id was GC'd, its finalizer already popped the entry, so + # ``key not in _message_adapter_map`` is true again and a fresh finalizer is + # registered for the new object. + if not already_registered: + # Drop the entry when the message is GC'd. A zero-arg closure (rather + # than ``weakref.finalize(message, dict.pop, key, None)``) captures + # ``key`` and keeps the finalizer callable's type unambiguous for the + # type-checker. ``pop(key, None)`` is a no-op if the entry was already + # removed. + def _cleanup() -> None: + _message_adapter_map.pop(key, None) + + weakref.finalize(message, _cleanup) + + +def _get_message_adapter(message: Message) -> Adapter | None: + """Return the adapter registered for ``message``, or ``None``.""" + return _message_adapter_map.get(id(message)) + + def _strip_none(d: dict[str, Any]) -> dict[str, Any]: """Remove keys whose value is ``None`` from a dict. @@ -412,6 +519,72 @@ class Message: links: list[LinkPreview] | None = None raw: Any = None + # Cached awaitable for ``subject``. Mirrors upstream's ``_subjectPromise``: + # the first ``await message.subject`` stores the in-flight future here so a + # second access reuses it instead of re-calling ``fetch_subject``. + # ``init=False``/``compare=False``/``repr=False`` keep it out of the + # dataclass ``__init__``, equality, and ``repr`` — it is purely internal + # resolution state, not message data. + _subject_future: Any = field(default=None, init=False, compare=False, repr=False) + + async def _resolve_subject(self) -> MessageSubject | None: + """Resolve the subject via the owning adapter's ``fetch_subject`` hook. + + Returns ``None`` when no adapter is registered, the adapter has no + ``fetch_subject`` hook, the hook returns ``None``, or the hook raises + (failures are swallowed, mirroring upstream's ``.catch(() => null)``). + """ + adapter = _get_message_adapter(self) + if adapter is None: + return None + fetch_subject = getattr(adapter, "fetch_subject", None) + if fetch_subject is None: + return None + try: + return await fetch_subject(self.raw) + except Exception: + return None + + async def _subject(self) -> MessageSubject | None: + """Coroutine backing the :attr:`subject` accessor (caches the result). + + The first await schedules ``_resolve_subject`` once via + ``ensure_future`` and stores the shared future on the instance; every + later/concurrent await reuses it, so ``fetch_subject`` runs at most + once. Mirrors upstream's cached ``_subjectPromise``. + + The cached future is awaited through :func:`asyncio.shield` so that a + caller cancellation (e.g. ``asyncio.wait_for(msg.subject, timeout=...)`` + firing) propagates ``CancelledError`` to the caller but does *not* + cancel the shared inner task. Without shielding, the first cancelled + awaiter would poison the cache and every subsequent ``await + msg.subject`` would raise ``CancelledError``. + """ + if self._subject_future is None: + self._subject_future = asyncio.ensure_future(self._resolve_subject()) + return await asyncio.shield(self._subject_future) + + @property + def subject(self) -> Awaitable[MessageSubject | None]: + """The external subject this message refers to (issue, PR, etc.), or ``None``. + + Lazily resolved via the owning adapter's optional + :meth:`Adapter.fetch_subject` hook. The adapter is registered at + dispatch time by :func:`set_message_adapter`. + + Mirrors upstream ``Message.subject`` (``packages/chat/src/message.ts``): + it is an awaitable, the result is cached after the first access, and a + second ``await message.subject`` does NOT re-call ``fetch_subject``. + Concurrent awaits share a single in-flight resolution. + + Usage:: + + subject = await message.subject + if subject is not None: + ... + """ + return self._subject() + def to_json(self) -> dict[str, Any]: """Serialize to JSON-compatible dict. @@ -1257,6 +1430,17 @@ async def get_user(self, user_id: str) -> UserInfo | None: """ return None + # NOTE: ``fetch_subject`` is intentionally NOT declared here. Upstream's + # ``Adapter.fetchSubject`` is an *optional* member (``fetchSubject?(...)``), + # and in this Python port the established convention for optional adapter + # hooks (``stream``, ``open_dm``, ``rehydrate_attachment``, + # ``get_channel_visibility``, ...) is to declare them on :class:`BaseAdapter` + # only — NOT on this structural ``Protocol`` — so that adapters which don't + # implement them still satisfy ``Adapter`` for type-checking. Declaring it + # on the Protocol would make it a *required* attribute and break every + # adapter that doesn't define it. :attr:`Message.subject` reads the hook via + # ``getattr(adapter, "fetch_subject", None)``, so presence is fully optional. + class BaseAdapter: """Base adapter with default implementations for optional methods. @@ -1415,6 +1599,22 @@ async def get_user(self, user_id: str) -> UserInfo | None: """ raise ChatNotImplementedError(self.name, "getUser") + async def fetch_subject(self, raw: Any) -> MessageSubject | None: + """Resolve the external subject a message refers to (issue, PR, etc.). + + Optional — the default returns ``None`` (no subject). Adapters that + can resolve a backing entity (a Linear issue, a GitHub PR, etc.) from a + message's raw payload should override this. Unlike most optional + :class:`BaseAdapter` hooks it does *not* raise + :class:`~chat_sdk.errors.ChatNotImplementedError`, because + :attr:`Message.subject` is best-effort: "this adapter has no subject + concept" is a normal, non-error outcome that maps to ``None``. + + Mirrors upstream's optional ``Adapter.fetchSubject`` + (``packages/chat/src/types.ts``). + """ + return None + def rehydrate_attachment(self, attachment: Attachment) -> Attachment: """Reconstruct ``fetch_data`` on an attachment after deserialization. diff --git a/tests/test_chat_faithful.py b/tests/test_chat_faithful.py index bb5c491..61e6739 100644 --- a/tests/test_chat_faithful.py +++ b/tests/test_chat_faithful.py @@ -34,6 +34,7 @@ ConcurrencyConfig, EmojiValue, MessageContext, + MessageSubject, ModalSubmitEvent, QueueEntry, ReactionEvent, @@ -3889,3 +3890,98 @@ async def test_should_not_cache_incoming_messages_when_adapter_does_not_set_pers history_keys = [k for k in state.cache if k.startswith("msg-history:")] assert len(history_keys) == 0 + + +class TestSubjectBinding: + """Dispatch registers the owning adapter so handlers can resolve message.subject.""" + + async def test_handler_can_resolve_subject_via_adapter_hook(self): + adapter = create_mock_adapter("slack") + expected = MessageSubject(id="ENG-1", type="issue", title="Fix it", raw={}) + + async def _fetch_subject(raw): # noqa: ANN001, ANN202 + return expected + + adapter.fetch_subject = _fetch_subject # type: ignore[attr-defined] + chat, adapter, state = await _init_chat(adapter=adapter) + + resolved: list[MessageSubject | None] = [] + + @chat.on_subscribed_message + async def handler(thread, message, context=None): + resolved.append(await message.subject) + + await state.subscribe("slack:C123:1234.5678") + msg = create_test_message("msg-1", "Follow up") + await chat.handle_incoming_message(adapter, "slack:C123:1234.5678", msg) + + assert resolved == [expected] + + async def test_subject_is_none_when_adapter_has_no_fetch_subject_hook(self): + chat, adapter, state = await _init_chat() + resolved: list[MessageSubject | None] = [] + + @chat.on_subscribed_message + async def handler(thread, message, context=None): + resolved.append(await message.subject) + + await state.subscribe("slack:C123:1234.5678") + msg = create_test_message("msg-1", "Follow up") + await chat.handle_incoming_message(adapter, "slack:C123:1234.5678", msg) + + assert resolved == [None] + + # Codex P2: skipped messages were dispatched to handlers via + # ``context.skipped`` without ever being bound to their owning adapter, + # so ``await context.skipped[i].subject`` silently returned ``None`` even + # when ``adapter.fetch_subject`` was implemented. Fix binds skipped + # messages alongside the primary message in ``_dispatch_to_handlers``. + async def test_skipped_messages_subject_resolves_via_adapter_hook(self): + """Burst-drained skipped messages must also support ``.subject``. + + Load-bearing: reverting the bind-skipped fix in + ``_dispatch_to_handlers`` makes ``context.skipped[i].subject`` + return ``None`` instead of the adapter-fetched value. + """ + adapter = create_mock_adapter("slack") + # ``_resolve_subject`` invokes ``fetch_subject(self.raw)`` — the raw + # webhook payload, not the Message object — so we key on raw["id"]. + fetched: dict[str, MessageSubject] = { + "msg-burst-1": MessageSubject(id="A", type="issue", title="first", raw={}), + "msg-burst-2": MessageSubject(id="B", type="issue", title="second", raw={}), + "msg-burst-3": MessageSubject(id="C", type="issue", title="third", raw={}), + } + + async def _fetch_subject(raw): # noqa: ANN001, ANN202 + return fetched[raw["id"]] + + adapter.fetch_subject = _fetch_subject # type: ignore[attr-defined] + + chat, _, _ = await _init_chat( + adapter=adapter, + concurrency=ConcurrencyConfig(strategy="burst", debounce_ms=60), + ) + + all_subjects: list[MessageSubject | None] = [] + + @chat.on_mention + async def handler(thread, message, context=None): # noqa: ANN001 + all_subjects.append(await message.subject) + if context is not None: + for skipped in context.skipped: + all_subjects.append(await skipped.subject) + + def _mk(mid: str, text: str): # noqa: ANN202 + return create_test_message(mid, text, raw={"id": mid}) + + msg1 = _mk("msg-burst-1", "Hey @slack-bot first") + task = asyncio.create_task(chat.handle_incoming_message(adapter, "slack:C123:1234.5678", msg1)) + await asyncio.sleep(0.005) + await chat.handle_incoming_message(adapter, "slack:C123:1234.5678", _mk("msg-burst-2", "Hey @slack-bot second")) + await chat.handle_incoming_message(adapter, "slack:C123:1234.5678", _mk("msg-burst-3", "Hey @slack-bot third")) + await task + + # latest (msg-burst-3) dispatched as ``message``, the two earlier + # arrivals folded into ``context.skipped`` in order. All three must + # resolve to their adapter-fetched subjects. + assert all_subjects == [fetched["msg-burst-3"], fetched["msg-burst-1"], fetched["msg-burst-2"]] diff --git a/tests/test_types.py b/tests/test_types.py index a456c8a..9f5f605 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -2,8 +2,13 @@ from __future__ import annotations +import asyncio +import gc +import weakref from datetime import datetime, timezone +import pytest + from chat_sdk.types import ( Attachment, Author, @@ -11,7 +16,12 @@ EmojiValue, Message, MessageMetadata, + MessageSubject, + MessageSubjectParty, RawMessage, + _get_message_adapter, + _message_adapter_map, + set_message_adapter, ) @@ -366,3 +376,259 @@ def test_creation(self): assert raw.id == "raw-001" assert raw.thread_id == "thread-001" assert raw.raw["platform"] == "test" + + +def _make_message(**overrides) -> Message: + """Build a Message for subject tests (mirrors upstream ``makeMessage``).""" + defaults = { + "id": "msg-1", + "thread_id": "slack:C123:1234.5678", + "text": "Hello world", + "formatted": {"type": "root", "children": []}, + "author": Author( + user_id="U123", + user_name="testuser", + full_name="Test User", + is_bot=False, + is_me=False, + ), + "metadata": MessageMetadata( + date_sent=datetime(2024, 1, 15, 10, 30, 0, tzinfo=timezone.utc), + edited=False, + ), + "raw": {"platform": "test"}, + } + defaults.update(overrides) + return Message(**defaults) + + +class _AdapterWithSubject: + """Minimal adapter stub exposing ``fetch_subject`` that records call count.""" + + def __init__(self, result: MessageSubject | None) -> None: + self._result = result + self.calls = 0 + + async def fetch_subject(self, raw): # noqa: ANN001, ANN201 + self.calls += 1 + return self._result + + +class TestMessageSubjectDataclass: + """Tests for the MessageSubject / MessageSubjectParty dataclasses.""" + + def test_minimal_required_fields(self): + subject = MessageSubject(id="ENG-1", type="issue") + assert subject.id == "ENG-1" + assert subject.type == "issue" + assert subject.raw is None + assert subject.assignee is None + assert subject.labels is None + + def test_all_fields(self): + subject = MessageSubject( + id="ENG-123", + type="issue", + raw={"k": "v"}, + assignee=MessageSubjectParty(id="u1", name="Alice"), + author=MessageSubjectParty(id="u2", name="Bob"), + description="A bug", + labels=["bug", "p0"], + status="In Progress", + title="Fix bug", + url="https://linear.app/team/ENG-123", + ) + assert subject.assignee == MessageSubjectParty(id="u1", name="Alice") + assert subject.author.name == "Bob" + assert subject.labels == ["bug", "p0"] + assert subject.status == "In Progress" + assert subject.title == "Fix bug" + assert subject.url == "https://linear.app/team/ENG-123" + assert subject.description == "A bug" + + +class TestMessageSubject: + """Tests for the Message.subject accessor (mirrors upstream message.test.ts).""" + + async def test_returns_none_when_no_adapter_is_set(self): + msg = _make_message() + assert await msg.subject is None + + async def test_returns_none_when_adapter_has_no_fetch_subject(self): + msg = _make_message() + set_message_adapter(msg, object()) + assert await msg.subject is None + + async def test_returns_subject_from_adapter(self): + msg = _make_message() + expected = MessageSubject( + type="issue", + id="ENG-123", + title="Fix bug", + status="In Progress", + url="https://linear.app/team/ENG-123", + raw={}, + ) + set_message_adapter(msg, _AdapterWithSubject(expected)) + result = await msg.subject + assert result == expected + + async def test_should_cache_the_result(self): + msg = _make_message() + adapter = _AdapterWithSubject(MessageSubject(type="issue", id="1", raw={})) + set_message_adapter(msg, adapter) + await msg.subject + await msg.subject + assert adapter.calls == 1 + + async def test_should_cache_null_result(self): + msg = _make_message() + adapter = _AdapterWithSubject(None) + set_message_adapter(msg, adapter) + await msg.subject + await msg.subject + assert adapter.calls == 1 + + async def test_should_handle_concurrent_access(self): + msg = _make_message() + adapter = _AdapterWithSubject(MessageSubject(type="issue", id="1", raw={})) + set_message_adapter(msg, adapter) + a, b = await asyncio.gather(msg.subject, msg.subject) + assert a == b + assert adapter.calls == 1 + + async def test_swallows_fetch_subject_errors(self): + """A raising hook resolves to None (mirrors upstream .catch(() => null)).""" + msg = _make_message() + + class Boom: + async def fetch_subject(self, raw): # noqa: ANN001, ANN201 + raise RuntimeError("boom") + + set_message_adapter(msg, Boom()) + assert await msg.subject is None + + async def test_passes_raw_payload_to_fetch_subject(self): + msg = _make_message(raw={"native": "payload"}) + seen = {} + + class Capturing: + async def fetch_subject(self, raw): # noqa: ANN001, ANN201 + seen["raw"] = raw + return None + + set_message_adapter(msg, Capturing()) + await msg.subject + assert seen["raw"] == {"native": "payload"} + + async def test_subject_survives_caller_cancellation(self): + """Cancellation in one awaiter must not poison the cached future for other awaiters. + + Cancelling ``await msg.subject`` (via wait_for/timeout) used to propagate into + the shared ``_subject_future``, cancelling the inner task and causing every + subsequent ``await msg.subject`` to raise CancelledError. ``asyncio.shield()`` + prevents that. + """ + # Set up: an adapter whose fetch_subject takes longer than a tight timeout. + started = asyncio.Event() + proceed = asyncio.Event() + + class SlowAdapter: + name = "slow" + + async def fetch_subject(self, raw): # noqa: ANN001, ANN201 + started.set() + await proceed.wait() + return MessageSubject(id="s1", type="issue", raw={}, title="Done") + + msg = _make_message() + set_message_adapter(msg, SlowAdapter()) + + # First caller times out — must NOT poison the cache. + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(msg.subject, timeout=0.05) + await started.wait() # confirm the inner task started + + # Now let the inner task finish. A second caller must see the result. + proceed.set() + result = await msg.subject + assert result is not None + assert result.title == "Done" + + +class TestSetMessageAdapterWeakref: + """Tests for the identity-keyed, weakly-scoped adapter registry.""" + + def test_registration_does_not_crash_on_unhashable_message(self): + # Message is a plain dataclass (eq=True) -> unhashable. The registry + # must not rely on hashing the Message itself. + msg = _make_message() + with __import__("pytest").raises(TypeError): + hash(msg) + set_message_adapter(msg, object()) # must not raise + + def test_entry_removed_when_message_is_garbage_collected(self): + msg = _make_message() + set_message_adapter(msg, object()) + key = id(msg) + assert key in _message_adapter_map + del msg + gc.collect() + assert key not in _message_adapter_map + + def test_distinct_messages_get_distinct_adapters(self): + m1 = _make_message() + m2 = _make_message() + a1, a2 = object(), object() + set_message_adapter(m1, a1) + set_message_adapter(m2, a2) + assert _get_message_adapter(m1) is a1 + assert _get_message_adapter(m2) is a2 + + @staticmethod + def _live_finalizer_count(message: Message) -> int: + """Count live ``weakref.finalize`` callbacks attached to ``message``. + + ``weakref.finalize`` keeps a class-level registry whose keys are the + ``finalize`` instances; ``peek()`` returns ``(obj, func, args, kwargs)`` + while the finalizer is still alive. We count entries whose tracked + object is ``message`` to assert exactly one cleanup is registered. + """ + count = 0 + for finalizer in list(weakref.finalize._registry): + peeked = finalizer.peek() + if peeked is not None and peeked[0] is message: + count += 1 + return count + + def test_re_registration_does_not_add_duplicate_finalizer(self): + # Registering the same live message more than once (re-dispatch, + # rehydrate, multiple handler passes) must not accumulate finalizers. + msg = _make_message() + set_message_adapter(msg, object()) + assert self._live_finalizer_count(msg) == 1 + set_message_adapter(msg, object()) + set_message_adapter(msg, object()) + assert self._live_finalizer_count(msg) == 1 + + def test_re_registration_updates_adapter_value(self): + # The adapter VALUE is still overwritten on re-registration even though + # no second finalizer is added. + msg = _make_message() + adapter_a, adapter_b = object(), object() + set_message_adapter(msg, adapter_a) + assert _get_message_adapter(msg) is adapter_a + set_message_adapter(msg, adapter_b) + assert _get_message_adapter(msg) is adapter_b + + def test_re_registered_message_cleans_up_exactly_once(self): + # After re-registration, GC must remove the single entry without a + # double-pop and without leaving a stale finalizer behind. + msg = _make_message() + set_message_adapter(msg, object()) + set_message_adapter(msg, object()) + key = id(msg) + assert key in _message_adapter_map + del msg + gc.collect() + assert key not in _message_adapter_map