Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .sampo/changesets/propagate-missing-params-module-wrappers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pypi/posthog: patch
---

fix: propagate missing params in module-level wrapper functions (`distinct_id` for `group_identify`, `flag_keys_to_evaluate` for `get_all_flags`/`get_all_flags_and_payloads`)
8 changes: 8 additions & 0 deletions posthog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def group_identify(
timestamp=None, # type: Optional[datetime.datetime]
uuid=None, # type: Optional[str]
disable_geoip=None, # type: Optional[bool]
distinct_id=None, # type: Optional[str]
):
# type: (...) -> Optional[str]
"""
Expand All @@ -403,6 +404,7 @@ def group_identify(
timestamp: Optional timestamp for the event
uuid: Optional UUID for the event
disable_geoip: Whether to disable GeoIP lookup
distinct_id: Optional distinct ID of the user performing the action

Examples:
```python
Expand All @@ -425,6 +427,7 @@ def group_identify(
timestamp=timestamp,
uuid=uuid,
disable_geoip=disable_geoip,
distinct_id=distinct_id,
)


Expand Down Expand Up @@ -611,6 +614,7 @@ def get_all_flags(
only_evaluate_locally=False, # type: bool
disable_geoip=None, # type: Optional[bool]
device_id=None, # type: Optional[str]
flag_keys_to_evaluate=None, # type: Optional[list[str]]
) -> Optional[dict[str, FeatureFlag]]:
"""
Get all flags for a given user.
Expand All @@ -622,6 +626,7 @@ def get_all_flags(
group_properties: Group properties
only_evaluate_locally: Whether to evaluate only locally
disable_geoip: Whether to disable GeoIP lookup
flag_keys_to_evaluate: Optional list of flag keys to evaluate (evaluates all if None)

Details:
Flags are key-value pairs where the key is the flag key and the value is the flag variant, or True, or False.
Expand All @@ -644,6 +649,7 @@ def get_all_flags(
only_evaluate_locally=only_evaluate_locally,
disable_geoip=disable_geoip,
device_id=device_id,
flag_keys_to_evaluate=flag_keys_to_evaluate,
)


Expand Down Expand Up @@ -747,6 +753,7 @@ def get_all_flags_and_payloads(
only_evaluate_locally=False,
disable_geoip=None, # type: Optional[bool]
device_id=None, # type: Optional[str]
flag_keys_to_evaluate=None, # type: Optional[list[str]]
) -> FlagsAndPayloads:
return _proxy(
"get_all_flags_and_payloads",
Expand All @@ -757,6 +764,7 @@ def get_all_flags_and_payloads(
only_evaluate_locally=only_evaluate_locally,
disable_geoip=disable_geoip,
device_id=device_id,
flag_keys_to_evaluate=flag_keys_to_evaluate,
)


Expand Down
63 changes: 63 additions & 0 deletions posthog/test/test_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import unittest
from unittest import mock

from parameterized import parameterized

import posthog
from posthog import Posthog


Expand Down Expand Up @@ -30,3 +34,62 @@ def test_alias(self):

def test_flush(self):
self.posthog.flush()


class TestModuleLevelWrappers(unittest.TestCase):
"""Test that module-level wrapper functions in posthog/__init__.py
correctly propagate all parameters to the Client methods."""

def setUp(self):
self.mock_client = mock.MagicMock()
self._original_client = posthog.default_client
posthog.default_client = self.mock_client

def tearDown(self):
posthog.default_client = self._original_client

def test_group_identify_propagates_distinct_id(self):
posthog.group_identify(
"company",
"company_123",
{"name": "Awesome Inc."},
distinct_id="user_456",
)
self.mock_client.group_identify.assert_called_once_with(
group_type="company",
group_key="company_123",
properties={"name": "Awesome Inc."},
timestamp=None,
uuid=None,
disable_geoip=None,
distinct_id="user_456",
)

def test_group_identify_distinct_id_defaults_to_none(self):
posthog.group_identify("company", "company_123")
call_kwargs = self.mock_client.group_identify.call_args[1]
self.assertIsNone(call_kwargs["distinct_id"])

@parameterized.expand(
[
("get_all_flags", "get_all_flags"),
("get_all_flags_and_payloads", "get_all_flags_and_payloads"),
]
)
def test_flag_keys_to_evaluate_propagated(self, _name, method_name):
fn = getattr(posthog, method_name)
fn("user_123", flag_keys_to_evaluate=["flag-1", "flag-2"])
call_kwargs = getattr(self.mock_client, method_name).call_args[1]
self.assertEqual(call_kwargs["flag_keys_to_evaluate"], ["flag-1", "flag-2"])

@parameterized.expand(
[
("get_all_flags", "get_all_flags"),
("get_all_flags_and_payloads", "get_all_flags_and_payloads"),
]
)
def test_flag_keys_to_evaluate_defaults_to_none(self, _name, method_name):
fn = getattr(posthog, method_name)
fn("user_123")
call_kwargs = getattr(self.mock_client, method_name).call_args[1]
self.assertIsNone(call_kwargs["flag_keys_to_evaluate"])
Loading