sample for custom worker tuner #314
Open
deepika-awasthi wants to merge 4 commits into main from deepika/custom-worker-tuner
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| # Custom Worker Tuner | ||
|
|
||
| A `CustomSlotSupplier` lets you gate slot grants on whatever you want. | ||
| This sample gates on a fake DB pool: the worker only polls for a new | ||
| activity when the pool has a free connection. | ||
|
|
||
| **Note:** This sample is illustrative only. It should not be used for production-grade use cases. | ||
|
|
||
| ## What this sample is | ||
| - `db_pool.py`: a fixed-capacity fake pool backed by an `asyncio.Semaphore`. Three methods: `acquire()` (awaits until a connection is free), `try_acquire()` (non-blocking claim, returns `False` if the pool is full), and `release()` (returns a connection). | ||
| - `supplier.py`: the custom slot supplier. `reserve_slot` awaits `connection_pool.acquire()` until a connection is free; `try_reserve_slot` calls `connection_pool.try_acquire()` and returns `None` if the pool is full; `release_slot` calls `connection_pool.release()`. | ||
| - `shared.py`: a `RunBatch` workflow that runs N `do_work` activities in parallel. The activity just sleeps. | ||
| - `worker.py`: wires `FakeDatabaseConnectionPool` + `PoolSlotSupplier` into a `WorkerTuner`. | ||
| - `starter.py`: drives load. | ||
|
|
||
| The flow: | ||
|
|
||
| When the pool is at capacity, `reserve_slot` blocks until a | ||
| connection frees up. The excess work piles up on the Temporal server, not | ||
| inside the worker. | ||
|
|
||
| ## Run | ||
|
|
||
| In three terminals from `samples-python/`: | ||
|
|
||
| ```bash | ||
| temporal server start-dev # terminal 1 | ||
| uv run custom_worker_tuner/worker.py # terminal 2 | ||
| uv run custom_worker_tuner/starter.py # terminal 3 | ||
| ``` | ||
|
|
||
| ## What you'll see | ||
|
|
||
| The worker prints one line per slot lifecycle event: | ||
|
|
||
| ``` | ||
| TIME EVENT COUNT QUEUE DETAIL | ||
| (COUNT shows before→after / capacity; QUEUE = tasks parked waiting) | ||
| ───────────────────────────────────────────────────────────────── | ||
| 12:30:32.591 reserve 0→ 1/10 0 ready to poll | ||
| 12:30:32.591 reserve 1→ 2/10 0 ready to poll | ||
| 12:30:32.592 reserve 2→ 3/10 0 ready to poll | ||
| 12:30:32.592 reserve 3→ 4/10 0 ready to poll | ||
| 12:30:32.592 reserve 4→ 5/10 0 ready to poll | ||
| 12:30:32.592 reserve 5→ 6/10 0 ready to poll | ||
| 12:30:40.501 reserve 6→ 7/10 0 eager dispatch | ||
| 12:30:40.502 reserve 7→ 8/10 0 eager dispatch | ||
| 12:30:40.502 reserve 8→ 9/10 0 eager dispatch | ||
| 12:30:40.505 release 9→ 8/10 0 no task arrived | ||
| 12:30:40.506 release 8→ 7/10 0 no task arrived | ||
| 12:30:40.506 release 7→ 6/10 0 no task arrived | ||
| 12:30:40.510 used 6→ 6/10 0 activity running | ||
| 12:30:40.510 reserve 6→ 7/10 0 eager dispatch | ||
| 12:30:40.511 reserve 7→ 8/10 0 eager dispatch | ||
| 12:30:40.511 reserve 8→ 9/10 0 eager dispatch | ||
| 12:30:40.514 reserve 9→10/10 0 ready to poll | ||
| 12:30:40.520 release 10→ 9/10 0 no task arrived | ||
| 12:30:40.520 release 9→ 8/10 0 no task arrived | ||
| 12:30:40.520 release 8→ 7/10 0 no task arrived | ||
| 12:30:40.520 used 7→ 7/10 0 activity running | ||
| 12:30:40.520 reserve 7→ 8/10 0 eager dispatch | ||
| 12:30:40.520 reserve 8→ 9/10 0 eager dispatch | ||
| 12:30:40.520 reserve 9→10/10 0 eager dispatch | ||
| 12:30:40.525 release 10→10/10 0 no task arrived | ||
| 12:30:40.525 release 10→ 9/10 0 no task arrived | ||
| 12:30:40.525 release 9→ 8/10 0 no task arrived | ||
| 12:30:40.528 reserve 7→ 8/10 0 ready to poll | ||
| 12:30:40.530 used 8→ 8/10 0 activity running | ||
| 12:30:40.535 reserve 8→ 9/10 0 eager dispatch | ||
| 12:30:40.537 reserve 9→10/10 0 eager dispatch | ||
| 12:30:40.539 used 10→10/10 1 activity running | ||
| 12:30:40.540 used 10→10/10 1 activity running | ||
| 12:30:40.541 used 10→10/10 1 activity running | ||
| ``` | ||
|
|
||
| Under load, with more activities than capacity, COUNT pins at | ||
| 10/10: that's the supplier refusing to poll past the gate. | ||
| We chose a capacity of 10 because the Python SDK runs 5 pollers by default. | ||
|
|
||
| ## Knobs | ||
|
|
||
| worker.py: | ||
|
|
||
| CAPACITY — pool capacity (the gate) | ||
|
|
||
| starter.py: | ||
|
|
||
| WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load |
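
As a rough sanity check on the knobs in the README above, the idealized run time can be estimated from the defaults in `starter.py` and `worker.py`. This is a back-of-the-envelope sketch, not part of the PR: the helper name `ideal_wall_seconds` is hypothetical, and it assumes every pool slot is running an activity at all times. The real worker will likely be slower, since pollers also hold slots while they wait for tasks.

```python
import math

# Defaults copied from starter.py and worker.py in this PR.
WORKFLOWS = 10
ACTIVITIES_PER_WORKFLOW = 20
SECONDS_PER_ACTIVITY = 2.0
CAPACITY = 10


def ideal_wall_seconds(total: int, capacity: int, seconds: float) -> float:
    """Lower bound on wall time (hypothetical helper).

    Assumes activities run in back-to-back waves of `capacity`,
    each wave lasting `seconds`, with no scheduling overhead.
    """
    waves = math.ceil(total / capacity)
    return waves * seconds


total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW
print(total)  # 200
print(ideal_wall_seconds(total, CAPACITY, SECONDS_PER_ACTIVITY))  # 40.0
```

Raising `CAPACITY` shrinks the number of waves; raising any of the `starter.py` knobs grows the total or the length of each wave.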
Empty file.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class FakeDatabaseConnectionPool: | ||
| """Pretend connection pool with a fixed capacity, backed by an asyncio.Semaphore.""" | ||
|
|
||
| def __init__(self, allowed_connections: int, name: str = "db") -> None: | ||
| self.allowed_connections = allowed_connections | ||
| self.name = name | ||
| self._connection_pool = asyncio.Semaphore(allowed_connections) | ||
| logger.info( | ||
| "FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d", | ||
| name, | ||
| allowed_connections, | ||
| ) | ||
|
|
||
| async def acquire(self) -> None: | ||
| """Claim a connection, awaiting until one is free.""" | ||
| await self._connection_pool.acquire() | ||
|
|
||
| def try_acquire(self) -> bool: | ||
| """Non-blocking claim, called by try_reserve_slot. | ||
| Returns False if the pool is full; otherwise takes one connection and returns True. | ||
| """ | ||
| if self._connection_pool.locked(): | ||
| return False | ||
| # asyncio.Semaphore has no public non-blocking acquire, so decrement the private counter. | ||
| self._connection_pool._value -= 1 | ||
| return True | ||
|
|
||
| def release(self) -> None: | ||
| """Return a connection to the pool.""" | ||
| self._connection_pool.release() | ||
|
|
||
| @property | ||
| def in_use(self) -> int: | ||
| """Derived from the semaphore — single source of truth.""" | ||
| return self.allowed_connections - self._connection_pool._value | ||
|
|
||
| @property | ||
| def queued(self) -> int: | ||
| """How many tasks are parked waiting for a free slot.""" | ||
| waiters = self._connection_pool._waiters | ||
| if not waiters: | ||
| return 0 | ||
| return sum(1 for w in waiters if not w.done()) | ||
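
The gating behaviour that `acquire()` and `release()` rely on can be seen in isolation with a plain `asyncio.Semaphore`. The snippet below is a minimal sketch using only the public asyncio API; the capacity of 2 and the `waiter` coroutine are made up for illustration and do not appear in the PR.

```python
import asyncio


async def main() -> list[str]:
    events: list[str] = []
    gate = asyncio.Semaphore(2)  # capacity of 2, standing in for the pool

    # Fill the gate to capacity.
    await gate.acquire()
    await gate.acquire()
    events.append(f"locked={gate.locked()}")

    async def waiter() -> None:
        await gate.acquire()  # parks here until a permit is released
        events.append("waiter acquired")

    task = asyncio.create_task(waiter())
    await asyncio.sleep(0)  # yield once so the waiter starts and parks
    events.append(f"waiter done={task.done()}")

    gate.release()  # free one permit; the parked waiter is woken
    await task
    events.append(f"waiter done={task.done()}")
    return events


events = asyncio.run(main())
print(events)
# ['locked=True', 'waiter done=False', 'waiter acquired', 'waiter done=True']
```

The parked waiter is the in-process analogue of a `reserve_slot` call that cannot return until a connection frees up.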
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| from dataclasses import dataclass | ||
| from datetime import timedelta | ||
|
|
||
| from temporalio import activity, workflow | ||
|
|
||
| TASK_QUEUE = "custom-worker-tuner" | ||
|
|
||
|
|
||
| @dataclass | ||
| class BatchInput: | ||
| activities: int | ||
| seconds: float | ||
|
|
||
|
|
||
| @activity.defn | ||
| async def do_work(seconds: float) -> None: | ||
| """Sleep, simulating an I/O-bound activity.""" | ||
| await asyncio.sleep(seconds) | ||
|
|
||
|
|
||
| @workflow.defn | ||
| class RunBatch: | ||
| """Runs N do_work activities in parallel.""" | ||
|
|
||
| @workflow.run | ||
| async def run(self, inp: BatchInput) -> None: | ||
| await asyncio.gather( | ||
| *( | ||
| workflow.execute_activity( | ||
| do_work, | ||
| inp.seconds, | ||
| start_to_close_timeout=timedelta(minutes=2), | ||
| ) | ||
| for _ in range(inp.activities) | ||
| ) | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import time | ||
| import uuid | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
|
|
||
| from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch | ||
|
|
||
| # Tweak these to push more or less load. | ||
| WORKFLOWS = 10 | ||
| ACTIVITIES_PER_WORKFLOW = 20 | ||
| SECONDS_PER_ACTIVITY = 2.0 | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
| run_id = uuid.uuid4().hex[:8] | ||
| inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY) | ||
| total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW | ||
|
|
||
| print( | ||
| f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s" | ||
| ) | ||
| t0 = time.perf_counter() | ||
|
|
||
| handles = await asyncio.gather( | ||
| *( | ||
| client.start_workflow( | ||
| RunBatch.run, | ||
| inp, | ||
| id=f"batch-{run_id}-{i}", | ||
| task_queue=TASK_QUEUE, | ||
| ) | ||
| for i in range(WORKFLOWS) | ||
| ) | ||
| ) | ||
| await asyncio.gather(*(h.result() for h in handles)) | ||
|
|
||
| wall = time.perf_counter() - t0 | ||
| print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import itertools | ||
| import logging | ||
|
|
||
| from temporalio.worker import ( | ||
| CustomSlotSupplier, | ||
| SlotMarkUsedContext, | ||
| SlotPermit, | ||
| SlotReleaseContext, | ||
| SlotReserveContext, | ||
| ) | ||
|
|
||
| from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _slot_id_gen = itertools.count(1) | ||
|
|
||
|
|
||
| class _Permit(SlotPermit): | ||
| """SlotPermit subclass that just carries a sequential id for logs.""" | ||
|
|
||
| def __init__(self, slot_id: int) -> None: | ||
| super().__init__() | ||
| self.slot_id = slot_id | ||
|
|
||
|
|
||
| class PoolSlotSupplier(CustomSlotSupplier): | ||
| """Hands out slots only when the backing pool has a free connection.""" | ||
|
|
||
| def __init__(self, connection_pool: FakeDatabaseConnectionPool) -> None: | ||
| self.connection_pool = connection_pool | ||
| logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name) | ||
|
|
||
| async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit: | ||
| """Block until the pool has capacity, then grant a slot.""" | ||
| await self.connection_pool.acquire() | ||
| after = self.connection_pool.in_use | ||
| slot_id = next(_slot_id_gen) | ||
| self._log("reserve", slot_id, "ready to poll", after - 1, after) | ||
| return _Permit(slot_id) | ||
|
|
||
| def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None: | ||
| """Eager path: try to claim a slot without blocking.""" | ||
| if self.connection_pool.try_acquire(): | ||
| after = self.connection_pool.in_use | ||
| slot_id = next(_slot_id_gen) | ||
| self._log("reserve", slot_id, "eager dispatch", after - 1, after) | ||
| return _Permit(slot_id) | ||
| return None | ||
|
|
||
| def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None: | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| in_use = self.connection_pool.in_use | ||
| self._log("used", slot_id, "activity running", in_use, in_use) | ||
|
|
||
| def release_slot(self, ctx: SlotReleaseContext) -> None: | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| detail = "no task arrived" if ctx.slot_info is None else "activity done" | ||
| before = self.connection_pool.in_use | ||
| self.connection_pool.release() | ||
| after = self.connection_pool.in_use | ||
| self._log("release", slot_id, detail, before, after) | ||
|
|
||
| def _log(self, event: str, slot_id, note: str, before: int, after: int) -> None: | ||
| cap = self.connection_pool.allowed_connections | ||
| count = f"{before:>2}→{after:>2}/{cap}" | ||
| queued = self.connection_pool.queued | ||
| logger.info(f"{event:<8} {count} {queued:>5} {note}") |
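
To make the COUNT and QUEUE columns in the README output easier to read, here is the same f-string layout that `_log` uses, pulled out into a standalone function. The name `format_event` is hypothetical and exists only for this illustration; the format specs are copied from `_log`.

```python
def format_event(
    event: str, before: int, after: int, cap: int, queued: int, note: str
) -> str:
    """Mirror the column layout of PoolSlotSupplier._log (hypothetical helper)."""
    # before and after are right-aligned to width 2, so single digits line up.
    count = f"{before:>2}→{after:>2}/{cap}"
    # event is left-aligned to width 8, queued is right-aligned to width 5.
    return f"{event:<8} {count} {queued:>5} {note}"


line = format_event("reserve", 0, 1, 10, 0, "ready to poll")
print(line)
print(format_event("release", 10, 9, 10, 0, "no task arrived"))
```

Note that `_log` accepts a `slot_id` argument but does not include it in the formatted line, so the sketch leaves it out as well.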
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
| from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner | ||
|
|
||
| from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool | ||
| from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work | ||
| from custom_worker_tuner.supplier import PoolSlotSupplier | ||
|
|
||
| CAPACITY = 10 # number of pool connections (and concurrent activities) | ||
| LOG_LEVEL = "INFO" | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| logging.basicConfig( | ||
| level=getattr(logging, LOG_LEVEL.upper(), logging.INFO), | ||
| format="%(asctime)s.%(msecs)03d %(message)s", | ||
| datefmt="%H:%M:%S", | ||
| ) | ||
|
|
||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
|
|
||
| pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db") | ||
| supplier = PoolSlotSupplier(pool) | ||
| tuner = WorkerTuner.create_composite( | ||
| workflow_supplier=FixedSizeSlotSupplier(100), | ||
| activity_supplier=supplier, | ||
| local_activity_supplier=FixedSizeSlotSupplier(100), | ||
| nexus_supplier=FixedSizeSlotSupplier(100), | ||
| ) | ||
|
|
||
| worker = Worker( | ||
| client, | ||
| task_queue=TASK_QUEUE, | ||
| workflows=[RunBatch], | ||
| activities=[do_work], | ||
| tuner=tuner, | ||
| ) | ||
|
|
||
| print(f"\nworker started — capacity={CAPACITY}\n") | ||
| print("TIME EVENT COUNT QUEUE DETAIL") | ||
| print("(COUNT shows before→after / capacity; QUEUE = tasks parked waiting)") | ||
| print("─" * 65) | ||
| await worker.run() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| try: | ||
| asyncio.run(main()) | ||
| except KeyboardInterrupt: | ||
| pass |
I hadn't appreciated that `asyncio.Semaphore` doesn't have a `try_acquire`.
Reaching into the internals like this is a bit hacky, but I'm not sure I have a great alternative.
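
One possible alternative is to drop the semaphore and have the pool keep its own counter plus a queue of waiting futures, which gives a non-blocking claim without touching private attributes. The sketch below is not the PR's implementation: the class name `CountingPool` is hypothetical, and it assumes that `acquire`, `try_acquire`, and `release` are all called on the same event loop thread, which this sketch does not verify against the SDK.

```python
import asyncio
from collections import deque


class CountingPool:
    """Hypothetical fixed-capacity pool that tracks its own counter."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.in_use = 0
        self._waiters: deque[asyncio.Future[None]] = deque()

    @property
    def queued(self) -> int:
        """How many callers are parked waiting for a connection."""
        return sum(1 for w in self._waiters if not w.done())

    def try_acquire(self) -> bool:
        """Claim a connection only if one is free right now."""
        if self.in_use >= self.capacity:
            return False
        self.in_use += 1
        return True

    async def acquire(self) -> None:
        """Wait until a connection is free, then claim it."""
        while not self.try_acquire():
            fut: asyncio.Future[None] = asyncio.get_running_loop().create_future()
            self._waiters.append(fut)
            try:
                await fut
            except asyncio.CancelledError:
                if fut in self._waiters:
                    self._waiters.remove(fut)
                else:
                    # Already woken by release(): pass the wakeup on.
                    self._wake_next()
                raise

    def release(self) -> None:
        """Return a connection and wake one waiter, if any."""
        self.in_use -= 1
        self._wake_next()

    def _wake_next(self) -> None:
        while self._waiters:
            fut = self._waiters.popleft()
            if not fut.done():
                fut.set_result(None)
                return


async def demo() -> list[object]:
    pool = CountingPool(capacity=1)
    results: list[object] = [pool.try_acquire(), pool.try_acquire()]
    waiter = asyncio.create_task(pool.acquire())
    await asyncio.sleep(0)  # let the waiter park
    results.append(waiter.done())
    pool.release()  # synchronous, like release_slot
    await waiter
    results.append(pool.in_use)
    return results


print(asyncio.run(demo()))  # [True, False, False, 1]
```

The trade-off is fairness: a `try_acquire` call can take a freed connection before a woken waiter resumes, in which case the waiter simply parks again. `release` stays synchronous, which matters because `release_slot` in the supplier is a plain method.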