Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions posthog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def group_identify(
timestamp=None, # type: Optional[datetime.datetime]
uuid=None, # type: Optional[str]
disable_geoip=None, # type: Optional[bool]
distinct_id=None, # type: Optional[str]
):
# type: (...) -> Optional[str]
"""
Expand All @@ -403,6 +404,7 @@ def group_identify(
timestamp: Optional timestamp for the event
uuid: Optional UUID for the event
disable_geoip: Whether to disable GeoIP lookup
distinct_id: Optional distinct ID of the user performing the action

Examples:
```python
Expand All @@ -425,6 +427,7 @@ def group_identify(
timestamp=timestamp,
uuid=uuid,
disable_geoip=disable_geoip,
distinct_id=distinct_id,
)


Expand Down Expand Up @@ -611,6 +614,7 @@ def get_all_flags(
only_evaluate_locally=False, # type: bool
disable_geoip=None, # type: Optional[bool]
device_id=None, # type: Optional[str]
flag_keys_to_evaluate=None, # type: Optional[list[str]]
) -> Optional[dict[str, FeatureFlag]]:
"""
Get all flags for a given user.
Expand All @@ -622,6 +626,7 @@ def get_all_flags(
group_properties: Group properties
only_evaluate_locally: Whether to evaluate only locally
disable_geoip: Whether to disable GeoIP lookup
flag_keys_to_evaluate: Optional list of flag keys to evaluate (evaluates all if None)

Details:
Flags are key-value pairs where the key is the flag key and the value is the flag variant, or True, or False.
Expand All @@ -644,6 +649,7 @@ def get_all_flags(
only_evaluate_locally=only_evaluate_locally,
disable_geoip=disable_geoip,
device_id=device_id,
flag_keys_to_evaluate=flag_keys_to_evaluate,
)


Expand Down Expand Up @@ -747,6 +753,7 @@ def get_all_flags_and_payloads(
only_evaluate_locally=False,
disable_geoip=None, # type: Optional[bool]
device_id=None, # type: Optional[str]
flag_keys_to_evaluate=None, # type: Optional[list[str]]
) -> FlagsAndPayloads:
return _proxy(
"get_all_flags_and_payloads",
Expand All @@ -757,6 +764,7 @@ def get_all_flags_and_payloads(
only_evaluate_locally=only_evaluate_locally,
disable_geoip=disable_geoip,
device_id=device_id,
flag_keys_to_evaluate=flag_keys_to_evaluate,
)


Expand Down
63 changes: 63 additions & 0 deletions posthog/test/test_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import unittest
from unittest import mock

from parameterized import parameterized

import posthog
from posthog import Posthog


Expand Down Expand Up @@ -30,3 +34,62 @@ def test_alias(self):

def test_flush(self):
self.posthog.flush()


class TestModuleLevelWrappers(unittest.TestCase):
"""Test that module-level wrapper functions in posthog/__init__.py
correctly propagate all parameters to the Client methods."""

def setUp(self):
self.mock_client = mock.MagicMock()
self._original_client = posthog.default_client
posthog.default_client = self.mock_client

def tearDown(self):
posthog.default_client = self._original_client

def test_group_identify_propagates_distinct_id(self):
posthog.group_identify(
"company",
"company_123",
{"name": "Awesome Inc."},
distinct_id="user_456",
)
self.mock_client.group_identify.assert_called_once_with(
group_type="company",
group_key="company_123",
properties={"name": "Awesome Inc."},
timestamp=None,
uuid=None,
disable_geoip=None,
distinct_id="user_456",
)

def test_group_identify_distinct_id_defaults_to_none(self):
posthog.group_identify("company", "company_123")
call_kwargs = self.mock_client.group_identify.call_args[1]
self.assertIsNone(call_kwargs["distinct_id"])

@parameterized.expand(
[
("get_all_flags", "get_all_flags"),
("get_all_flags_and_payloads", "get_all_flags_and_payloads"),
]
)
def test_flag_keys_to_evaluate_propagated(self, _name, method_name):
fn = getattr(posthog, method_name)
fn("user_123", flag_keys_to_evaluate=["flag-1", "flag-2"])
call_kwargs = getattr(self.mock_client, method_name).call_args[1]
self.assertEqual(call_kwargs["flag_keys_to_evaluate"], ["flag-1", "flag-2"])

@parameterized.expand(
[
("get_all_flags", "get_all_flags"),
("get_all_flags_and_payloads", "get_all_flags_and_payloads"),
]
)
def test_flag_keys_to_evaluate_defaults_to_none(self, _name, method_name):
fn = getattr(posthog, method_name)
fn("user_123")
call_kwargs = getattr(self.mock_client, method_name).call_args[1]
self.assertIsNone(call_kwargs["flag_keys_to_evaluate"])
Loading