sample for custom worker tuner #314
base: main
Changes from 1 commit
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Custom Worker Tuner | ||
|
|
||
| A `CustomSlotSupplier` lets you gate slot grants on any condition you choose. | ||
| This sample gates on a fake DB pool: the worker only polls for a new | ||
| activity when the pool has a free connection. | ||
|
|
||
| ## What this sample is | ||
| - `downstream.py`: A static-capacity counter that pretends to be a DB pool. Two methods: `increment()` (claim a slot; returns False if full) and `decrement()` (release). | ||
| - `supplier.py`: The custom slot supplier. On `reserve_slot` it polls `downstream.increment()` until it succeeds. On `release_slot` it calls `downstream.decrement()`. | ||
| - `shared.py`: A `RunBatch` workflow that runs N `do_work` activities in parallel. The activity just sleeps. | ||
| - `worker.py`: Wires `Downstream` + `DownstreamAwareSupplier` into a `WorkerTuner`. | ||
| - `starter.py`: Drives load. | ||
|
|
||
| The flow: | ||
|
|
||
| When the downstream is at capacity, `reserve_slot` blocks until a | ||
| slot frees up. The excess work piles up on the Temporal server, not | ||
| inside the worker. | ||
|
|
||
| ## Run | ||
|
|
||
| In three terminals from `samples-python/`: | ||
|
|
||
| ```bash | ||
| temporal server start-dev # terminal 1 | ||
| uv run custom_worker_tuner/worker.py # terminal 2 | ||
| uv run custom_worker_tuner/starter.py # terminal 3 | ||
| ``` | ||
|
|
||
| ## What you'll see | ||
|
|
||
| The worker prints one line per slot lifecycle event: | ||
|
|
||
| ``` | ||
|
|
||
| TIME EVENT SLOT COUNT DETAIL | ||
| ──────────────────────────────────────────────────────────── | ||
| 10:31:49.842 reserve #1 1/10 ready to poll | ||
| 10:31:49.842 reserve #2 2/10 ready to poll | ||
| 10:31:49.843 reserve #3 3/10 ready to poll | ||
| 10:31:49.843 reserve #4 4/10 ready to poll | ||
| 10:31:49.843 reserve #5 5/10 ready to poll | ||
| 10:31:49.843 reserve #6 6/10 ready to poll | ||
| 10:31:56.763 reserve #7 7/10 eager dispatch | ||
| 10:31:56.763 reserve #8 8/10 eager dispatch | ||
| 10:31:56.764 reserve #9 9/10 eager dispatch | ||
| 10:31:56.766 reserve #10 10/10 eager dispatch | ||
| 10:31:56.767 release #7 9/10 no task arrived | ||
| 10:31:56.768 release #8 8/10 no task arrived | ||
| 10:31:56.768 release #9 7/10 no task arrived | ||
| 10:31:56.768 reserve #11 8/10 eager dispatch | ||
| 10:31:56.768 reserve #12 9/10 eager dispatch | ||
| 10:31:56.768 reserve #13 10/10 eager dispatch | ||
| 10:31:56.771 used #1 10/10 activity running | ||
| 10:31:56.771 release #10 9/10 no task arrived | ||
| ``` | ||
|
|
||
| Under load, with more activities than capacity, COUNT pins at | ||
| 10/10: the supplier is refusing to poll past the gate. | ||
| We chose 10 because the Python SDK runs 5 pollers by default. | ||
|
|
||
| ## Knobs | ||
|
|
||
| worker.py: | ||
|
|
||
| CAPACITY — downstream capacity (the gate) | ||
| POLL_INTERVAL_MS — how often the supplier rechecks when full | ||
|
|
||
| starter.py: | ||
|
|
||
| WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load |
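To get a feel for how these knobs interact, here is a rough back-of-the-envelope sketch. It assumes every activity holds exactly one slot for `SECONDS_PER_ACTIVITY` and ignores polling, scheduling latency, and slots that are reserved but released with no task. The helper name `min_wall_seconds` is made up for illustration and is not part of the sample.

```python
import math


def min_wall_seconds(
    workflows: int,
    activities_per_workflow: int,
    seconds_per_activity: float,
    capacity: int,
) -> float:
    """Idealized lower bound on wall time when at most `capacity` activities run at once."""
    total = workflows * activities_per_workflow
    # Identical-duration activities can at best run in full "waves" of `capacity`.
    waves = math.ceil(total / capacity)
    return waves * seconds_per_activity


if __name__ == "__main__":
    # Defaults from starter.py (10 workflows, 20 activities, 2.0 s) and worker.py (CAPACITY = 10).
    print(min_wall_seconds(10, 20, 2.0, 10))  # 40.0
```

With the defaults, the starter should therefore not report much less than 40 seconds; a real run will likely be slower because of the overheads ignored above.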
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import logging | ||
| import threading | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class Downstream: | ||
| """A counter with a fixed capacity. Thread-safe.""" | ||
|
|
||
| def __init__(self, allowed_connections: int, name: str = "downstream") -> None: | ||
| self.allowed_connections = allowed_connections | ||
| self.name = name | ||
| self.currently_connected = 0 | ||
| self.connection_pool = threading.Lock() | ||
|
Member
This could in fact literally be a `threading.Semaphore`: https://docs.python.org/3/library/threading.html#threading.Semaphore
Author
Done. The class is now just a thin wrapper over threading.BoundedSemaphore. The counter and lock are gone; in_use is derived from the semaphore (allowed_connections - self._sem._value).
|
||
| logger.info( | ||
| "Downstream ready: name=%s allowed_connections=%d", | ||
| name, | ||
| allowed_connections, | ||
| ) | ||
|
|
||
| def increment(self) -> bool: | ||
| """Claim one connection. Returns False if at capacity.""" | ||
| with self.connection_pool: | ||
| if self.currently_connected >= self.allowed_connections: | ||
| return False | ||
| self.currently_connected += 1 | ||
| return True | ||
|
|
||
| def decrement(self) -> None: | ||
| """Release one slot. Floored at 0 so a buggy caller can't go negative.""" | ||
| with self.connection_pool: | ||
| self.currently_connected = max(0, self.currently_connected - 1) | ||
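The review suggestion above (back the fake pool with a semaphore) could look roughly like the sketch below. This is not the author's follow-up commit: `SemaphorePool`, `try_acquire`, and `in_use` are hypothetical names, and the in-use count is tracked in a separate counter rather than read from the semaphore's private state.

```python
import threading


class SemaphorePool:
    """Hypothetical stand-in for Downstream, backed by a BoundedSemaphore."""

    def __init__(self, allowed_connections: int) -> None:
        self.allowed_connections = allowed_connections
        self._sem = threading.BoundedSemaphore(allowed_connections)
        # Counter kept only for logging; guarded by its own lock.
        self._lock = threading.Lock()
        self._in_use = 0

    @property
    def in_use(self) -> int:
        with self._lock:
            return self._in_use

    def try_acquire(self) -> bool:
        """Claim one connection without blocking. Returns False if the pool is full."""
        with self._lock:
            if not self._sem.acquire(blocking=False):
                return False
            self._in_use += 1
            return True

    def release(self) -> None:
        """Return one connection. Raises ValueError if released more often than acquired."""
        with self._lock:
            self._sem.release()
            self._in_use -= 1


if __name__ == "__main__":
    pool = SemaphorePool(allowed_connections=2)
    print(pool.try_acquire(), pool.try_acquire(), pool.try_acquire())  # True True False
    pool.release()
    print(pool.in_use)  # 1
```

One behavioral difference from `decrement()` in the diff: an extra release raises `ValueError` instead of being silently floored at 0, which surfaces a buggy caller rather than hiding it.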
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| from dataclasses import dataclass | ||
| from datetime import timedelta | ||
|
|
||
| from temporalio import activity, workflow | ||
|
|
||
| TASK_QUEUE = "custom-worker-tuner" | ||
|
|
||
|
|
||
| @dataclass | ||
| class BatchInput: | ||
| activities: int | ||
| seconds: float | ||
|
|
||
|
|
||
| @activity.defn | ||
| async def do_work(seconds: float) -> None: | ||
| """Sleep, simulating an I/O-bound activity.""" | ||
| await asyncio.sleep(seconds) | ||
|
|
||
|
|
||
| @workflow.defn | ||
| class RunBatch: | ||
| """Runs N do_work activities in parallel.""" | ||
|
|
||
| @workflow.run | ||
| async def run(self, inp: BatchInput) -> None: | ||
| await asyncio.gather( | ||
| *( | ||
| workflow.execute_activity( | ||
| do_work, | ||
| inp.seconds, | ||
| start_to_close_timeout=timedelta(minutes=2), | ||
| ) | ||
| for _ in range(inp.activities) | ||
| ) | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,49 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import time | ||
| import uuid | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
|
|
||
| from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch | ||
|
|
||
| # Tweak these to push more or less load. | ||
| WORKFLOWS = 10 | ||
| ACTIVITIES_PER_WORKFLOW = 20 | ||
| SECONDS_PER_ACTIVITY = 2.0 | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
| run_id = uuid.uuid4().hex[:8] | ||
| inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY) | ||
| total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW | ||
|
|
||
| print( | ||
| f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s" | ||
| ) | ||
| t0 = time.perf_counter() | ||
|
|
||
| handles = await asyncio.gather( | ||
| *( | ||
| client.start_workflow( | ||
| RunBatch.run, | ||
| inp, | ||
| id=f"batch-{run_id}-{i}", | ||
| task_queue=TASK_QUEUE, | ||
| ) | ||
| for i in range(WORKFLOWS) | ||
| ) | ||
| ) | ||
| await asyncio.gather(*(h.result() for h in handles)) | ||
|
|
||
| wall = time.perf_counter() - t0 | ||
| print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import itertools | ||
| import logging | ||
|
|
||
| from temporalio.worker import ( | ||
| CustomSlotSupplier, | ||
| SlotMarkUsedContext, | ||
| SlotPermit, | ||
| SlotReleaseContext, | ||
| SlotReserveContext, | ||
| ) | ||
|
|
||
| from custom_worker_tuner.downstream import Downstream | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _slot_id_gen = itertools.count(1) | ||
|
|
||
|
|
||
| class _Permit(SlotPermit): | ||
| """SlotPermit subclass that just carries a sequential id for logs.""" | ||
|
|
||
| def __init__(self, slot_id: int) -> None: | ||
| super().__init__() | ||
| self.slot_id = slot_id | ||
|
|
||
|
|
||
| class DownstreamAwareSupplier(CustomSlotSupplier): | ||
| def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None: | ||
| self.downstream = downstream | ||
| self.poll_interval_ms = poll_interval_ms | ||
| logger.info( | ||
| "DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d", | ||
| downstream.name, | ||
| poll_interval_ms, | ||
| ) | ||
|
|
||
| async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit: | ||
| """Block until the downstream has capacity, then grant a slot.""" | ||
| slot_id = next(_slot_id_gen) | ||
| while not self.downstream.increment(): | ||
| await asyncio.sleep(self.poll_interval_ms / 1000.0) | ||
|
Member
Spin-looping here isn't great. Per my other suggestion, if the fake pool actually uses a semaphore, the call itself can block and there is no need for a sleep.
Author
Fixed. reserve_slot now does await asyncio.to_thread(self.connection_pool.acquire), so it blocks until a slot is free instead of polling. poll_interval_ms is removed.
|
||
| self._log("reserve", slot_id, "ready to poll") | ||
| return _Permit(slot_id) | ||
|
|
||
| def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None: | ||
| """Eager path: can I run this activity right now?""" | ||
| if self.downstream.increment(): | ||
| slot_id = next(_slot_id_gen) | ||
| self._log("reserve", slot_id, "eager dispatch") | ||
| return _Permit(slot_id) | ||
| return None | ||
|
|
||
| def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None: | ||
| """A task arrived for a reserved slot.""" | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| self._log("used", slot_id, "activity running") | ||
|
|
||
| def release_slot(self, ctx: SlotReleaseContext) -> None: | ||
| """Return the slot to the downstream.""" | ||
| slot_id = getattr(ctx.permit, "slot_id", "?") | ||
| detail = "no task arrived" if ctx.slot_info is None else "activity done" | ||
| self.downstream.decrement() | ||
| self._log("release", slot_id, detail) | ||
|
|
||
| def _log(self, event: str, slot_id, note: str) -> None: | ||
| count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}" | ||
| logger.info(f"{event:<8} #{slot_id!s:<4} {count:>5} {note}")
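The blocking alternative to the sleep loop that is discussed in the review thread above can be sketched without Temporal at all. The `BlockingGate` class below is a hypothetical stand-in, not a `CustomSlotSupplier`; it only shows how `asyncio.to_thread` parks the wait in a worker thread so the event loop stays free.

```python
import asyncio
import threading


class BlockingGate:
    """Hypothetical supplier-like gate; not a temporalio CustomSlotSupplier."""

    def __init__(self, capacity: int) -> None:
        self._sem = threading.BoundedSemaphore(capacity)

    async def reserve(self) -> None:
        # The blocking acquire runs in a worker thread, so the event loop keeps running.
        await asyncio.to_thread(self._sem.acquire)

    def try_reserve(self) -> bool:
        # Non-blocking path, analogous to try_reserve_slot.
        return self._sem.acquire(blocking=False)

    def release(self) -> None:
        self._sem.release()


async def demo() -> list[str]:
    gate = BlockingGate(capacity=1)
    events: list[str] = []

    await gate.reserve()
    events.append("first reserved")

    async def second() -> None:
        await gate.reserve()
        events.append("second reserved")

    waiter = asyncio.create_task(second())
    await asyncio.sleep(0.05)  # give the waiter time to block on the full gate
    events.append("second done early" if waiter.done() else "second waiting")

    gate.release()  # frees the slot, which unblocks the waiter
    await waiter
    gate.release()
    return events


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

One caveat with this approach: cancelling a task that is awaiting `reserve()` does not interrupt the thread blocked in `acquire`, so a cancelled waiter can still take a permit later. Each pending reservation also occupies a thread in the default executor, so a real supplier would probably need to account for both.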
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import logging | ||
|
|
||
| from temporalio.client import Client | ||
| from temporalio.envconfig import ClientConfig | ||
| from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner | ||
|
|
||
| from custom_worker_tuner.downstream import Downstream | ||
| from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work | ||
| from custom_worker_tuner.supplier import DownstreamAwareSupplier | ||
|
|
||
| CAPACITY = 10 # number of connections allowed at a time | ||
| POLL_INTERVAL_MS = 500 | ||
| LOG_LEVEL = "INFO" # flip to "DEBUG" to see every increment/decrement | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| logging.basicConfig( | ||
| level=getattr(logging, LOG_LEVEL.upper(), logging.INFO), | ||
| format="%(asctime)s.%(msecs)03d %(message)s", | ||
| datefmt="%H:%M:%S", | ||
| ) | ||
|
|
||
| config = ClientConfig.load_client_connect_config() | ||
| config.setdefault("target_host", "localhost:7233") | ||
| client = await Client.connect(**config) | ||
|
|
||
| downstream = Downstream(allowed_connections=CAPACITY, name="db") | ||
| supplier = DownstreamAwareSupplier(downstream, poll_interval_ms=POLL_INTERVAL_MS) | ||
| tuner = WorkerTuner.create_composite( | ||
| workflow_supplier=FixedSizeSlotSupplier(100), | ||
| activity_supplier=supplier, | ||
| local_activity_supplier=FixedSizeSlotSupplier(100), | ||
| nexus_supplier=FixedSizeSlotSupplier(100), | ||
| ) | ||
|
|
||
| worker = Worker( | ||
| client, | ||
| task_queue=TASK_QUEUE, | ||
| workflows=[RunBatch], | ||
| activities=[do_work], | ||
| tuner=tuner, | ||
| ) | ||
|
|
||
| print(f"\nworker started — capacity={CAPACITY}, poll={POLL_INTERVAL_MS}ms\n") | ||
| print("TIME EVENT SLOT COUNT DETAIL") | ||
| print("─" * 60) | ||
| await worker.run() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| try: | ||
| asyncio.run(main()) | ||
| except KeyboardInterrupt: | ||
| pass |
`Downstream` is an odd name for something that is essentially a Semaphore. Maybe `FakeDatabaseConnectionPool` would be more descriptive.

Renamed to `FakeDatabaseConnectionPool` (and moved to db_pool.py).