diff --git a/.sampo/changesets/propagate-missing-params-module-wrappers.md b/.sampo/changesets/propagate-missing-params-module-wrappers.md new file mode 100644 index 00000000..1112c6a9 --- /dev/null +++ b/.sampo/changesets/propagate-missing-params-module-wrappers.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: patch +--- + +fix: propagate missing params in module-level wrapper functions (`distinct_id` for `group_identify`, `flag_keys_to_evaluate` for `get_all_flags`/`get_all_flags_and_payloads`) diff --git a/posthog/__init__.py b/posthog/__init__.py index 18995d7d..b8489185 100644 --- a/posthog/__init__.py +++ b/posthog/__init__.py @@ -391,6 +391,7 @@ def group_identify( timestamp=None, # type: Optional[datetime.datetime] uuid=None, # type: Optional[str] disable_geoip=None, # type: Optional[bool] + distinct_id=None, # type: Optional[str] ): # type: (...) -> Optional[str] """ @@ -403,6 +404,7 @@ def group_identify( timestamp: Optional timestamp for the event uuid: Optional UUID for the event disable_geoip: Whether to disable GeoIP lookup + distinct_id: Optional distinct ID of the user performing the action Examples: ```python @@ -425,6 +427,7 @@ def group_identify( timestamp=timestamp, uuid=uuid, disable_geoip=disable_geoip, + distinct_id=distinct_id, ) @@ -611,6 +614,7 @@ def get_all_flags( only_evaluate_locally=False, # type: bool disable_geoip=None, # type: Optional[bool] device_id=None, # type: Optional[str] + flag_keys_to_evaluate=None, # type: Optional[list[str]] ) -> Optional[dict[str, FeatureFlag]]: """ Get all flags for a given user. @@ -622,6 +626,7 @@ def get_all_flags( group_properties: Group properties only_evaluate_locally: Whether to evaluate only locally disable_geoip: Whether to disable GeoIP lookup + flag_keys_to_evaluate: Optional list of flag keys to evaluate (evaluates all if None) Details: Flags are key-value pairs where the key is the flag key and the value is the flag variant, or True, or False. @@ -644,6 +649,7 @@ def get_all_flags( only_evaluate_locally=only_evaluate_locally, disable_geoip=disable_geoip, device_id=device_id, + flag_keys_to_evaluate=flag_keys_to_evaluate, ) @@ -747,6 +753,7 @@ def get_all_flags_and_payloads( only_evaluate_locally=False, disable_geoip=None, # type: Optional[bool] device_id=None, # type: Optional[str] + flag_keys_to_evaluate=None, # type: Optional[list[str]] ) -> FlagsAndPayloads: return _proxy( "get_all_flags_and_payloads", @@ -757,6 +764,7 @@ def get_all_flags_and_payloads( only_evaluate_locally=only_evaluate_locally, disable_geoip=disable_geoip, device_id=device_id, + flag_keys_to_evaluate=flag_keys_to_evaluate, ) diff --git a/posthog/test/test_module.py b/posthog/test/test_module.py index 79ed66c9..03bce00b 100644 --- a/posthog/test/test_module.py +++ b/posthog/test/test_module.py @@ -1,5 +1,9 @@ import unittest +from unittest import mock +from parameterized import parameterized + +import posthog from posthog import Posthog @@ -30,3 +34,62 @@ def test_alias(self): def test_flush(self): self.posthog.flush() + + +class TestModuleLevelWrappers(unittest.TestCase): + """Test that module-level wrapper functions in posthog/__init__.py + correctly propagate all parameters to the Client methods.""" + + def setUp(self): + self.mock_client = mock.MagicMock() + self._original_client = posthog.default_client + posthog.default_client = self.mock_client + + def tearDown(self): + posthog.default_client = self._original_client + + def test_group_identify_propagates_distinct_id(self): + posthog.group_identify( + "company", + "company_123", + {"name": "Awesome Inc."}, + distinct_id="user_456", + ) + self.mock_client.group_identify.assert_called_once_with( + group_type="company", + group_key="company_123", + properties={"name": "Awesome Inc."}, + timestamp=None, + uuid=None, + disable_geoip=None, + distinct_id="user_456", + ) + + def test_group_identify_distinct_id_defaults_to_none(self): + posthog.group_identify("company", "company_123") + call_kwargs = self.mock_client.group_identify.call_args[1] + self.assertIsNone(call_kwargs["distinct_id"]) + + @parameterized.expand( + [ + ("get_all_flags", "get_all_flags"), + ("get_all_flags_and_payloads", "get_all_flags_and_payloads"), + ] + ) + def test_flag_keys_to_evaluate_propagated(self, _name, method_name): + fn = getattr(posthog, method_name) + fn("user_123", flag_keys_to_evaluate=["flag-1", "flag-2"]) + call_kwargs = getattr(self.mock_client, method_name).call_args[1] + self.assertEqual(call_kwargs["flag_keys_to_evaluate"], ["flag-1", "flag-2"]) + + @parameterized.expand( + [ + ("get_all_flags", "get_all_flags"), + ("get_all_flags_and_payloads", "get_all_flags_and_payloads"), + ] + ) + def test_flag_keys_to_evaluate_defaults_to_none(self, _name, method_name): + fn = getattr(posthog, method_name) + fn("user_123") + call_kwargs = getattr(self.mock_client, method_name).call_args[1] + self.assertIsNone(call_kwargs["flag_keys_to_evaluate"])