diff --git a/flytekit/testing/__init__.py b/flytekit/testing/__init__.py index 76f89c2545..320d7c1317 100644 --- a/flytekit/testing/__init__.py +++ b/flytekit/testing/__init__.py @@ -1,6 +1,10 @@ """ -This module provides functionality related to testing +This module provides functionality related to testing. + +Provides utilities for mocking tasks and workflows, pytest fixtures for common +test setup patterns, and helpers for local workflow execution. """ from flytekit.core.context_manager import SecretsManager from flytekit.core.testing import patch, task_mock +from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir, workflow_dry_run diff --git a/flytekit/testing/conftest.py b/flytekit/testing/conftest.py new file mode 100644 index 0000000000..55d5def6e1 --- /dev/null +++ b/flytekit/testing/conftest.py @@ -0,0 +1,10 @@ +"""Pytest plugin that auto-registers flytekit testing fixtures. + +When ``flytekit`` is installed, these fixtures are automatically available in any +pytest session without needing to import them explicitly. This works via the +``pytest11`` entry point registered in ``pyproject.toml``. +""" + +from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir + +__all__ = ["flyte_cache", "flyte_context", "flyte_tmp_dir"] diff --git a/flytekit/testing/fixtures.py b/flytekit/testing/fixtures.py new file mode 100644 index 0000000000..25c7fe8f73 --- /dev/null +++ b/flytekit/testing/fixtures.py @@ -0,0 +1,96 @@ +import tempfile +import typing +from contextlib import contextmanager +from pathlib import Path + +import pytest + +from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.local_cache import LocalTaskCache + + +@pytest.fixture +def flyte_context() -> FlyteContext: + """Provide the current FlyteContext for testing. + + This eliminates the need to manually call ``FlyteContextManager.current_context()`` + in every test that needs a context for type transformations, file access, or other + context-dependent operations. + + Usage:: + + def test_type_transform(flyte_context): + from flytekit.core.type_engine import TypeEngine + lt = TypeEngine.to_literal_type(int) + lv = TypeEngine.to_literal(flyte_context, 42, int, lt) + assert lv.scalar.primitive.integer == 42 + """ + return FlyteContextManager.current_context() + + +@pytest.fixture +def flyte_cache(): + """Initialize and clear the local task cache before and after each test. + + Prevents stale cached results from prior test runs from leaking into the current test. + This addresses a common pain point where ``cache=True`` on tasks causes flaky tests + because the on-disk cache (``~/.flyte/local-cache``) persists between test runs. + + See https://github.com/flyteorg/flyte/issues/5657 + + Usage:: + + def test_cached_task(flyte_cache): + @task(cache=True, cache_version="v1") + def add(a: int, b: int) -> int: + return a + b + + assert add(a=1, b=2) == 3 + # Cache is automatically cleared after the test + """ + LocalTaskCache.initialize() + LocalTaskCache.clear() + yield + LocalTaskCache.clear() + + +@pytest.fixture +def flyte_tmp_dir() -> typing.Generator[Path, None, None]: + """Provide a temporary directory that is cleaned up after the test. + + Useful for tests involving ``FlyteFile``, ``FlyteDirectory``, or any operation + that needs to write files to disk. + + Usage:: + + def test_file_output(flyte_tmp_dir): + output_path = flyte_tmp_dir / "result.txt" + output_path.write_text("hello") + assert output_path.read_text() == "hello" + """ + with tempfile.TemporaryDirectory() as td: + yield Path(td) + + +@contextmanager +def workflow_dry_run() -> typing.Generator[None, None, None]: + """Context manager that sets up a clean local execution environment. + + Initializes and clears the local cache, then cleans up after the block completes. + Useful for running a workflow locally in tests without worrying about cached state. + + Usage:: + + from flytekit.testing.fixtures import workflow_dry_run + + def test_my_workflow(): + with workflow_dry_run(): + result = my_workflow(x=1, y=2) + assert result == 3 + """ + LocalTaskCache.initialize() + LocalTaskCache.clear() + try: + yield + finally: + LocalTaskCache.clear() diff --git a/pyproject.toml b/pyproject.toml index 82b8c6c054..806573738f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,9 @@ pyflyte-map-execute = "flytekit.bin.entrypoint:map_execute_task_cmd" pyflyte = "flytekit.clis.sdk_in_container.pyflyte:main" flyte-cli = "flytekit.clis.flyte_cli.main:_flyte_cli" +[project.entry-points.pytest11] +flytekit = "flytekit.testing.conftest" + [tool.setuptools_scm] write_to = "flytekit/_version.py" diff --git a/tests/flytekit/unit/testing/__init__.py b/tests/flytekit/unit/testing/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/testing/test_fixtures.py b/tests/flytekit/unit/testing/test_fixtures.py new file mode 100644 index 0000000000..0742317b0e --- /dev/null +++ b/tests/flytekit/unit/testing/test_fixtures.py @@ -0,0 +1,113 @@ +import pytest + +from flytekit import task, workflow +from flytekit.core.context_manager import FlyteContext +from flytekit.core.local_cache import LocalTaskCache +from flytekit.core.type_engine import TypeEngine +from flytekit.testing.fixtures import flyte_cache, flyte_context, flyte_tmp_dir, workflow_dry_run + + +class TestFlyteContextFixture: + def test_returns_flyte_context(self, flyte_context): + assert isinstance(flyte_context, FlyteContext) + + def test_context_has_file_access(self, flyte_context): + assert flyte_context.file_access is not None + + def test_type_transform_with_context(self, flyte_context): + lt = TypeEngine.to_literal_type(int) + lv = TypeEngine.to_literal(flyte_context, 42, int, lt) + assert lv.scalar.primitive.integer == 42 + + +class TestFlyteCacheFixture: + def test_cache_is_cleared(self, flyte_cache): + assert LocalTaskCache._initialized is True + + def test_cached_task_works(self, flyte_cache): + call_count = 0 + + @task(cache=True, cache_version="test-v1") + def add(a: int, b: int) -> int: + nonlocal call_count + call_count += 1 + return a + b + + result1 = add(a=1, b=2) + result2 = add(a=1, b=2) + assert result1 == 3 + assert result2 == 3 + assert call_count == 1 # second call should hit cache + + def test_cache_isolated_between_tests_a(self, flyte_cache): + """First test in a pair that verifies cache isolation.""" + + @task(cache=True, cache_version="isolation-v1") + def multiply(a: int, b: int) -> int: + return a * b + + assert multiply(a=3, b=4) == 12 + + def test_cache_isolated_between_tests_b(self, flyte_cache): + """Second test verifying the cache was cleared between tests.""" + call_count = 0 + + @task(cache=True, cache_version="isolation-v1") + def multiply(a: int, b: int) -> int: + nonlocal call_count + call_count += 1 + return a * b + + multiply(a=3, b=4) + assert call_count == 1 # should NOT hit cache from previous test + + +class TestFlyteTmpDirFixture: + def test_provides_path(self, flyte_tmp_dir): + from pathlib import Path + + assert isinstance(flyte_tmp_dir, Path) + assert flyte_tmp_dir.exists() + assert flyte_tmp_dir.is_dir() + + def test_can_write_files(self, flyte_tmp_dir): + test_file = flyte_tmp_dir / "test.txt" + test_file.write_text("hello flytekit") + assert test_file.read_text() == "hello flytekit" + + def test_can_create_subdirectories(self, flyte_tmp_dir): + sub = flyte_tmp_dir / "subdir" + sub.mkdir() + assert sub.exists() + + +class TestWorkflowDryRun: + def test_basic_workflow(self): + @task + def add_one(x: int) -> int: + return x + 1 + + @workflow + def simple_wf(x: int) -> int: + return add_one(x=x) + + with workflow_dry_run(): + result = simple_wf(x=5) + assert result == 6 + + def test_cached_workflow(self): + call_count = 0 + + @task(cache=True, cache_version="dry-run-v1") + def square(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * x + + @workflow + def square_wf(x: int) -> int: + return square(x=x) + + with workflow_dry_run(): + result = square_wf(x=4) + assert result == 16