diff --git a/dotflow/abc/storage.py b/dotflow/abc/storage.py index 51ac1f3e..d7c17d94 100644 --- a/dotflow/abc/storage.py +++ b/dotflow/abc/storage.py @@ -1,7 +1,10 @@ """Storage ABC""" +from __future__ import annotations + from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Iterable +from typing import Any from dotflow.core.context import Context @@ -13,21 +16,62 @@ def __init__(self, *args, **kwargs): pass @abstractmethod - def post(self, key: str, context: Context) -> None: - """Post context somewhere""" + def post( + self, + key: str, + context: Context, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> None: + """Persist context under key.""" @abstractmethod def get(self, key: str) -> Context: - """Get context somewhere""" + """Return stored context or empty Context().""" @abstractmethod - def key(self, task: Callable): - """Function that returns a key to get and post storage""" + def key(self, task: Callable) -> str: + """Storage key for task.""" - @abstractmethod def clear(self, workflow_id: str) -> None: - """Remove every persisted entry under ``workflow_id``. + """Remove every entry under workflow_id. + + Default implementation delegates to ``delete_prefix``. External + subclasses may override directly. + """ + self.delete_prefix(f"{workflow_id}-") + + def delete(self, key: str) -> bool: + """Remove key. Returns True when present. + + Optional in 1.x. Becomes abstract in 2.0. + """ + raise NotImplementedError + + def delete_prefix(self, prefix: str) -> int: + """Remove keys starting with prefix. Returns count. + + Optional in 1.x. Becomes abstract in 2.0. + """ + raise NotImplementedError + + def list_keys(self, prefix: str) -> Iterable[str]: + """Iterate keys starting with prefix. + + Optional in 1.x. Becomes abstract in 2.0. + """ + raise NotImplementedError + + def atomic_swap( + self, + key: str, + expected: Any, + new: Any, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> bool: + """Replace value when current equals expected. - Used by the input-fingerprint reset path when - ``on_input_change='reset'``. + Optional in 1.x. Becomes abstract in 2.0. """ + raise NotImplementedError diff --git a/dotflow/cloud/aws/services/s3.py b/dotflow/cloud/aws/services/s3.py index 0b412415..9ca8c89f 100644 --- a/dotflow/cloud/aws/services/s3.py +++ b/dotflow/cloud/aws/services/s3.py @@ -7,6 +7,11 @@ from dotflow.cloud.core import ObjectStorage from dotflow.core.exception import ModuleNotFound +_PRECONDITION_CODES = { + "PreconditionFailed", + "ConditionalRequestConflict", +} + class S3(ObjectStorage): """Amazon S3 object storage.""" @@ -46,6 +51,79 @@ def write(self, key: str, data: list) -> None: ContentType="application/json", ) + def delete(self, key: str) -> bool: + """Delete a single object.""" + from botocore.exceptions import ClientError + + full_key = f"{self.prefix}{key}" + + try: + self._s3.head_object(Bucket=self.bucket, Key=full_key) + except ClientError as error: + code = error.response.get("Error", {}).get("Code") + + if code in ("404", "NoSuchKey", "NotFound"): + return False + + raise + + self._s3.delete_object(Bucket=self.bucket, Key=full_key) + + return True + + def read_with_etag(self, key: str) -> tuple[list, str | None]: + """Return (data, etag). Etag is None when key is missing.""" + try: + response = self._s3.get_object( + Bucket=self.bucket, + Key=f"{self.prefix}{key}", + ) + data = response["Body"].read().decode("utf-8") + + return loads(data), response.get("ETag") + except self._s3.exceptions.NoSuchKey: + return [], None + + def write_if_match(self, key: str, data: list, etag: str | None) -> bool: + """Conditional PutObject. Returns False on precondition failure.""" + from botocore.exceptions import ClientError + + kwargs = { + "Bucket": self.bucket, + "Key": f"{self.prefix}{key}", + "Body": dumps(data), + "ContentType": "application/json", + } + + if etag is None: + kwargs["IfNoneMatch"] = "*" + else: + kwargs["IfMatch"] = etag + + try: + self._s3.put_object(**kwargs) + return True + except ClientError as error: + code = error.response.get("Error", {}).get("Code") + + if code in _PRECONDITION_CODES: + return False + + raise + + def list_keys(self, sub_prefix: str) -> list[str]: + """Return keys starting with sub_prefix.""" + full_prefix = f"{self.prefix}{sub_prefix}" + paginator = self._s3.get_paginator("list_objects_v2") + names = [] + offset = len(self.prefix) + + for page in paginator.paginate(Bucket=self.bucket, Prefix=full_prefix): + for item in page.get("Contents", []): + names.append(item["Key"][offset:]) + + return names + def delete_prefix(self, sub_prefix: str) -> None: """Delete every object whose key starts with prefix + sub_prefix. diff --git a/dotflow/cloud/gcp/services/gcs.py b/dotflow/cloud/gcp/services/gcs.py index 36c40f18..f92c5217 100644 --- a/dotflow/cloud/gcp/services/gcs.py +++ b/dotflow/cloud/gcp/services/gcs.py @@ -44,6 +44,57 @@ def write(self, key: str, data: list) -> None: content_type="application/json", ) + def delete(self, key: str) -> bool: + """Delete a single blob.""" + blob = self._bucket.blob(f"{self.prefix}{key}") + + try: + blob.delete() + return True + except self._not_found: + return False + + def read_with_generation(self, key: str) -> tuple[list, int | None]: + """Return (data, generation). Generation is None when missing.""" + blob = self._bucket.blob(f"{self.prefix}{key}") + + try: + data = blob.download_as_text() + return loads(data), blob.generation + except self._not_found: + return [], None + + def write_if_generation_match( + self, key: str, data: list, generation: int | None + ) -> bool: + """Conditional upload. Returns False on precondition failure.""" + from google.api_core.exceptions import PreconditionFailed + + blob = self._bucket.blob(f"{self.prefix}{key}") + precondition = generation if generation is not None else 0 + + try: + blob.upload_from_string( + dumps(data), + content_type="application/json", + if_generation_match=precondition, + ) + return True + except PreconditionFailed: + return False + + def list_keys(self, sub_prefix: str) -> list[str]: + """Return blob names starting with sub_prefix.""" + full_prefix = f"{self.prefix}{sub_prefix}" + offset = len(self.prefix) + + return [ + blob.name[offset:] + for blob in self._client.list_blobs( + self._bucket, prefix=full_prefix + ) + ] + def delete_prefix(self, sub_prefix: str) -> None: """Delete every blob whose name starts with prefix + sub_prefix. diff --git a/dotflow/providers/storage_default.py b/dotflow/providers/storage_default.py index efcbb259..2c086738 100644 --- a/dotflow/providers/storage_default.py +++ b/dotflow/providers/storage_default.py @@ -1,29 +1,118 @@ """Storage Default""" -from collections.abc import Callable +from __future__ import annotations + +import threading +import time +from collections.abc import Callable, Iterable +from typing import Any from dotflow.abc.storage import Storage from dotflow.core.context import Context class StorageDefault(Storage): - """In-memory storage using a dictionary.""" + """In-memory storage.""" def __init__(self): self._store: dict[str, Context] = {} + self._fingerprints: dict[str, str] = {} + self._expirations: dict[str, float] = {} + self._lock = threading.RLock() + + def post( + self, + key: str, + context: Context, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> None: + with self._lock: + self._store[key] = context + + if fingerprint is not None: + self._fingerprints[key] = fingerprint - def post(self, key: str, context: Context) -> None: - self._store[key] = context + if ttl is not None: + self._expirations[key] = time.monotonic() + ttl + else: + self._expirations.pop(key, None) def get(self, key: str) -> Context: - return self._store.get(key, Context()) + with self._lock: + self._evict_if_expired(key) + + return self._store.get(key, Context()) + + def delete(self, key: str) -> bool: + with self._lock: + existed = key in self._store + self._store.pop(key, None) + self._fingerprints.pop(key, None) + self._expirations.pop(key, None) + + return existed + + def delete_prefix(self, prefix: str) -> int: + with self._lock: + stale = [k for k in self._store if k.startswith(prefix)] + + for key in stale: + self._store.pop(key, None) + self._fingerprints.pop(key, None) + self._expirations.pop(key, None) + + return len(stale) + + def list_keys(self, prefix: str) -> Iterable[str]: + with self._lock: + for key in list(self._store): + self._evict_if_expired(key) + + return [k for k in self._store if k.startswith(prefix)] + + def atomic_swap( + self, + key: str, + expected: Any, + new: Any, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> bool: + with self._lock: + current = self._store.get(key) + current_value = ( + current.storage if isinstance(current, Context) else current + ) + + if current_value != expected: + return False + + payload = new if isinstance(new, Context) else Context(storage=new) + self._store[key] = payload + self._fingerprints.pop(key, None) + self._expirations.pop(key, None) + + if fingerprint is not None: + self._fingerprints[key] = fingerprint + + if ttl is not None: + self._expirations[key] = time.monotonic() + ttl + + return True def key(self, task: Callable) -> str: return f"{task.workflow_id}-{task.task_id}" - def clear(self, workflow_id: str) -> None: - prefix = f"{workflow_id}-" - stale = [k for k in self._store if k.startswith(prefix)] + def _evict_if_expired(self, key: str) -> None: + expiry = self._expirations.get(key) + + if expiry is None: + return + + if time.monotonic() < expiry: + return - for key in stale: - del self._store[key] + self._store.pop(key, None) + self._fingerprints.pop(key, None) + self._expirations.pop(key, None) diff --git a/dotflow/providers/storage_file.py b/dotflow/providers/storage_file.py index b5eec108..344e926b 100644 --- a/dotflow/providers/storage_file.py +++ b/dotflow/providers/storage_file.py @@ -1,6 +1,10 @@ """Storage File""" -from collections.abc import Callable +from __future__ import annotations + +import threading +import time +from collections.abc import Callable, Iterable from json import dumps, loads from pathlib import Path from typing import Any @@ -12,59 +16,182 @@ class StorageFile(Storage): - """Storage""" + """File-backed storage.""" def __init__(self, *args, path: str = settings.START_PATH, **kwargs): self.path = Path(path, "tasks") self.path.mkdir(parents=True, exist_ok=True) + self._lock = threading.RLock() + + def post( + self, + key: str, + context: Context, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> None: + with self._lock: + task_context = [] + + if Path(self.path, key).exists(): + data = read_file(path=Path(self.path, key)) + + if isinstance(data, list): + task_context = data - def post(self, key: str, context: Context) -> None: - task_context = [] + if isinstance(context.storage, list): + for item in context.storage: + if isinstance(item, Context): + task_context.append(self._dumps(storage=item.storage)) + else: + task_context.append(self._dumps(storage=context.storage)) - if Path(self.path, key).exists(): - data = read_file(path=Path(self.path, key)) - if isinstance(data, list): - task_context = data + write_file(path=Path(self.path, key), content=task_context) - if isinstance(context.storage, list): - for item in context.storage: - if isinstance(item, Context): - task_context.append(self._dumps(storage=item.storage)) - else: - task_context.append(self._dumps(storage=context.storage)) + meta = {} - write_file(path=Path(self.path, key), content=task_context) - return None + if fingerprint is not None: + meta["fingerprint"] = fingerprint + + if ttl is not None: + meta["expires_at"] = time.time() + ttl + + if meta: + self._write_meta(key=key, meta=meta) def get(self, key: str) -> Context: - task_context = [] + with self._lock: + if self._is_expired(key): + self.delete(key) - if Path(self.path, key).exists(): - data = read_file(path=Path(self.path, key)) - if isinstance(data, list): - task_context = data + return Context() - if not task_context: - return Context() + task_context = [] - if len(task_context) == 1: - return self._loads(storage=task_context[0]) + if Path(self.path, key).exists(): + data = read_file(path=Path(self.path, key)) - contexts = Context(storage=[]) - for context in task_context: - contexts.storage.append(self._loads(storage=context)) + if isinstance(data, list): + task_context = data - return contexts + if not task_context: + return Context() - def key(self, task: Callable): - return f"{task.workflow_id}-{task.task_id}.json" + if len(task_context) == 1: + return self._loads(storage=task_context[0]) + + contexts = Context(storage=[]) + + for context in task_context: + contexts.storage.append(self._loads(storage=context)) + + return contexts + + def delete(self, key: str) -> bool: + with self._lock: + target = Path(self.path, key) + existed = target.exists() + + target.unlink(missing_ok=True) + Path(self.path, f"{key}.meta").unlink(missing_ok=True) + + return existed - def clear(self, workflow_id: str) -> None: - prefix = f"{workflow_id}-" + def delete_prefix(self, prefix: str) -> int: + with self._lock: + removed = 0 + + for entry in self.path.iterdir(): + if not entry.is_file(): + continue + + if entry.name.endswith(".meta"): + continue + + if not entry.name.startswith(prefix): + continue - for entry in self.path.iterdir(): - if entry.is_file() and entry.name.startswith(prefix): entry.unlink(missing_ok=True) + Path(self.path, f"{entry.name}.meta").unlink(missing_ok=True) + removed += 1 + + return removed + + def list_keys(self, prefix: str) -> Iterable[str]: + with self._lock: + names = [] + + for entry in self.path.iterdir(): + if not entry.is_file(): + continue + + if entry.name.endswith(".meta"): + continue + + if not entry.name.startswith(prefix): + continue + + if self._is_expired(entry.name): + self.delete(entry.name) + continue + + names.append(entry.name) + + return names + + def atomic_swap( + self, + key: str, + expected: Any, + new: Any, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> bool: + with self._lock: + current = self.get(key) + current_value = ( + current.storage if isinstance(current, Context) else None + ) + + if current_value != expected: + return False + + self.delete(key) + payload = new if isinstance(new, Context) else Context(storage=new) + self.post( + key=key, + context=payload, + ttl=ttl, + fingerprint=fingerprint, + ) + + return True + + def key(self, task: Callable) -> str: + return f"{task.workflow_id}-{task.task_id}.json" + + def _write_meta(self, key: str, meta: dict) -> None: + Path(self.path, f"{key}.meta").write_text(dumps(meta)) + + def _read_meta(self, key: str) -> dict: + meta_path = Path(self.path, f"{key}.meta") + + if not meta_path.exists(): + return {} + + try: + return loads(meta_path.read_text()) + except Exception: + return {} + + def _is_expired(self, key: str) -> bool: + meta = self._read_meta(key=key) + expires_at = meta.get("expires_at") + + if expires_at is None: + return False + + return time.time() >= expires_at def _loads(self, storage: Any) -> Context: try: diff --git a/dotflow/providers/storage_gcs.py b/dotflow/providers/storage_gcs.py index e26b0541..c1c71078 100644 --- a/dotflow/providers/storage_gcs.py +++ b/dotflow/providers/storage_gcs.py @@ -1,6 +1,8 @@ """Storage GCS""" -from collections.abc import Callable +from __future__ import annotations + +from collections.abc import Callable, Iterable from json import dumps, loads from typing import Any @@ -48,7 +50,13 @@ def __init__( ): self._gcs = GCS(bucket=bucket, prefix=prefix, project=project) - def post(self, key: str, context: Context) -> None: + def post( + self, + key: str, + context: Context, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> None: task_context = [] if isinstance(context.storage, list): @@ -60,6 +68,9 @@ def post(self, key: str, context: Context) -> None: self._gcs.write(key=key, data=task_context) + if fingerprint is not None: + self._gcs.write(key=f"{key}.fingerprint", data=[fingerprint]) + def get(self, key: str) -> Context: task_context = self._gcs.read(key) @@ -70,16 +81,71 @@ def get(self, key: str) -> Context: return self._loads(storage=task_context[0]) contexts = Context(storage=[]) + for context in task_context: contexts.storage.append(self._loads(storage=context)) return contexts - def key(self, task: Callable): - return f"{task.workflow_id}-{task.task_id}" + def delete(self, key: str) -> bool: + existed = self._gcs.delete(key=key) + self._gcs.delete(key=f"{key}.fingerprint") + + return existed + + def delete_prefix(self, prefix: str) -> int: + names = self._gcs.list_keys(prefix) + + if not names: + return 0 + + self._gcs.delete_prefix(prefix) + + return sum(1 for n in names if not n.endswith(".fingerprint")) - def clear(self, workflow_id: str) -> None: - self._gcs.delete_prefix(f"{workflow_id}-") + def list_keys(self, prefix: str) -> Iterable[str]: + return [ + n + for n in self._gcs.list_keys(prefix) + if not n.endswith(".fingerprint") + ] + + def atomic_swap( + self, + key: str, + expected: Any, + new: Any, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> bool: + data, generation = self._gcs.read_with_generation(key) + + if not data: + current_value = None + elif len(data) == 1: + current_value = self._loads(storage=data[0]).storage + else: + current_value = [self._loads(storage=d).storage for d in data] + + if current_value != expected: + return False + + payload = new if isinstance(new, Context) else Context(storage=new) + new_data = [self._dumps(storage=payload.storage)] + + swapped = self._gcs.write_if_generation_match( + key=key, + data=new_data, + generation=generation, + ) + + if swapped and fingerprint is not None: + self._gcs.write(key=f"{key}.fingerprint", data=[fingerprint]) + + return swapped + + def key(self, task: Callable) -> str: + return f"{task.workflow_id}-{task.task_id}" def _loads(self, storage: Any) -> Context: try: diff --git a/dotflow/providers/storage_s3.py b/dotflow/providers/storage_s3.py index e4516727..45d98d38 100644 --- a/dotflow/providers/storage_s3.py +++ b/dotflow/providers/storage_s3.py @@ -1,6 +1,8 @@ """Storage S3""" -from collections.abc import Callable +from __future__ import annotations + +from collections.abc import Callable, Iterable from json import dumps, loads from typing import Any @@ -48,7 +50,13 @@ def __init__( ): self._s3 = S3(bucket=bucket, prefix=prefix, region=region) - def post(self, key: str, context: Context) -> None: + def post( + self, + key: str, + context: Context, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> None: task_context = [] if isinstance(context.storage, list): @@ -60,6 +68,9 @@ def post(self, key: str, context: Context) -> None: self._s3.write(key=key, data=task_context) + if fingerprint is not None: + self._s3.write(key=f"{key}.fingerprint", data=[fingerprint]) + def get(self, key: str) -> Context: task_context = self._s3.read(key) @@ -70,16 +81,67 @@ def get(self, key: str) -> Context: return self._loads(storage=task_context[0]) contexts = Context(storage=[]) + for context in task_context: contexts.storage.append(self._loads(storage=context)) return contexts - def key(self, task: Callable): - return f"{task.workflow_id}-{task.task_id}" + def delete(self, key: str) -> bool: + existed = self._s3.delete(key=key) + self._s3.delete(key=f"{key}.fingerprint") + + return existed + + def delete_prefix(self, prefix: str) -> int: + names = self._s3.list_keys(prefix) + + if not names: + return 0 + + self._s3.delete_prefix(prefix) + + return sum(1 for n in names if not n.endswith(".fingerprint")) - def clear(self, workflow_id: str) -> None: - self._s3.delete_prefix(f"{workflow_id}-") + def list_keys(self, prefix: str) -> Iterable[str]: + return [ + n + for n in self._s3.list_keys(prefix) + if not n.endswith(".fingerprint") + ] + + def atomic_swap( + self, + key: str, + expected: Any, + new: Any, + ttl: int | None = None, + fingerprint: str | None = None, + ) -> bool: + data, etag = self._s3.read_with_etag(key) + + if not data: + current_value = None + elif len(data) == 1: + current_value = self._loads(storage=data[0]).storage + else: + current_value = [self._loads(storage=d).storage for d in data] + + if current_value != expected: + return False + + payload = new if isinstance(new, Context) else Context(storage=new) + new_data = [self._dumps(storage=payload.storage)] + + swapped = self._s3.write_if_match(key=key, data=new_data, etag=etag) + + if swapped and fingerprint is not None: + self._s3.write(key=f"{key}.fingerprint", data=[fingerprint]) + + return swapped + + def key(self, task: Callable) -> str: + return f"{task.workflow_id}-{task.task_id}" def _loads(self, storage: Any) -> Context: try: diff --git a/dotflow/testing/__init__.py b/dotflow/testing/__init__.py new file mode 100644 index 00000000..e534293b --- /dev/null +++ b/dotflow/testing/__init__.py @@ -0,0 +1,5 @@ +"""Public test helpers for dotflow consumers.""" + +from dotflow.testing.storage_contract import StorageContract + +__all__ = ["StorageContract"] diff --git a/dotflow/testing/storage_contract.py b/dotflow/testing/storage_contract.py new file mode 100644 index 00000000..328edd77 --- /dev/null +++ b/dotflow/testing/storage_contract.py @@ -0,0 +1,135 @@ +"""Storage contract suite.""" + +from __future__ import annotations + +import time + +from dotflow.abc.storage import Storage +from dotflow.core.context import Context + + +class StorageContract: + """Mix into a unittest.TestCase and override make_storage.""" + + supports_ttl: bool = True + + def make_storage(self) -> Storage: + raise NotImplementedError + + def setUp(self) -> None: + self.storage = self.make_storage() + + def test_post_get_roundtrip(self): + self.storage.post(key="k1", context=Context(storage={"v": 1})) + + result = self.storage.get(key="k1") + + self.assertEqual(result.storage, {"v": 1}) + + def test_get_missing_key_returns_empty_context(self): + result = self.storage.get(key="missing") + + self.assertIsNone(result.storage) + + def test_delete_returns_true_when_present(self): + self.storage.post(key="k1", context=Context(storage="a")) + + self.assertTrue(self.storage.delete(key="k1")) + self.assertIsNone(self.storage.get(key="k1").storage) + + def test_delete_returns_false_when_absent(self): + self.assertFalse(self.storage.delete(key="missing")) + + def test_delete_prefix_returns_count(self): + self.storage.post(key="wf-1-a", context=Context(storage="a")) + self.storage.post(key="wf-1-b", context=Context(storage="b")) + self.storage.post(key="wf-2-c", context=Context(storage="c")) + + removed = self.storage.delete_prefix("wf-1-") + + self.assertEqual(removed, 2) + self.assertEqual(self.storage.get(key="wf-2-c").storage, "c") + + def test_list_keys_filters_by_prefix(self): + self.storage.post(key="wf-1-a", context=Context(storage="a")) + self.storage.post(key="wf-1-b", context=Context(storage="b")) + self.storage.post(key="wf-2-c", context=Context(storage="c")) + + keys = sorted(self.storage.list_keys("wf-1-")) + + self.assertEqual(keys, ["wf-1-a", "wf-1-b"]) + + def test_atomic_swap_succeeds_when_expected_matches(self): + self.storage.post(key="k", context=Context(storage="old")) + + ok = self.storage.atomic_swap(key="k", expected="old", new="new") + + self.assertTrue(ok) + self.assertEqual(self.storage.get(key="k").storage, "new") + + def test_atomic_swap_fails_when_expected_does_not_match(self): + self.storage.post(key="k", context=Context(storage="old")) + + ok = self.storage.atomic_swap( + key="k", + expected="other", + new="new", + ) + + self.assertFalse(ok) + self.assertEqual(self.storage.get(key="k").storage, "old") + + def test_atomic_swap_clears_inherited_ttl(self): + if not self.supports_ttl: + self.skipTest("driver does not support TTL") + + self.storage.post(key="k", context=Context(storage="old"), ttl=1) + + ok = self.storage.atomic_swap(key="k", expected="old", new="new") + + self.assertTrue(ok) + + time.sleep(1.2) + + self.assertEqual(self.storage.get(key="k").storage, "new") + + def test_atomic_swap_applies_new_ttl(self): + if not self.supports_ttl: + self.skipTest("driver does not support TTL") + + self.storage.post(key="k", context=Context(storage="old")) + + ok = self.storage.atomic_swap( + key="k", + expected="old", + new="new", + ttl=1, + ) + + self.assertTrue(ok) + self.assertEqual(self.storage.get(key="k").storage, "new") + + time.sleep(1.2) + + self.assertIsNone(self.storage.get(key="k").storage) + + def test_ttl_expiration(self): + if not self.supports_ttl: + self.skipTest("driver does not support TTL") + + self.storage.post(key="k", context=Context(storage="x"), ttl=1) + + self.assertEqual(self.storage.get(key="k").storage, "x") + + time.sleep(1.1) + + self.assertIsNone(self.storage.get(key="k").storage) + + def test_clear_delegates_to_delete_prefix(self): + self.storage.post(key="wf-A-task-1", context=Context(storage="a")) + self.storage.post(key="wf-B-task-1", context=Context(storage="c")) + + self.storage.clear(workflow_id="wf-A") + + self.assertIsNone(self.storage.get(key="wf-A-task-1").storage) + self.assertEqual(self.storage.get(key="wf-B-task-1").storage, "c") diff --git a/tests/providers/test_storage_default_contract.py b/tests/providers/test_storage_default_contract.py new file mode 100644 index 00000000..d065adcf --- /dev/null +++ b/tests/providers/test_storage_default_contract.py @@ -0,0 +1,11 @@ +"""StorageDefault contract suite.""" + +import unittest + +from dotflow.providers.storage_default import StorageDefault +from dotflow.testing.storage_contract import StorageContract + + +class TestStorageDefaultContract(StorageContract, unittest.TestCase): + def make_storage(self): + return StorageDefault() diff --git a/tests/providers/test_storage_file_contract.py b/tests/providers/test_storage_file_contract.py new file mode 100644 index 00000000..e9ac33a4 --- /dev/null +++ b/tests/providers/test_storage_file_contract.py @@ -0,0 +1,19 @@ +"""StorageFile contract suite.""" + +import tempfile +import unittest +from pathlib import Path +from shutil import rmtree + +from dotflow.providers.storage_file import StorageFile +from dotflow.testing.storage_contract import StorageContract + + +class TestStorageFileContract(StorageContract, unittest.TestCase): + def make_storage(self): + self._tmp = Path(tempfile.mkdtemp(prefix="dotflow-storage-")) + + return StorageFile(path=self._tmp) + + def tearDown(self): + rmtree(self._tmp, ignore_errors=True)