Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions custom_worker_tuner/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Custom Worker Tuner

A `CustomSlotSupplier` lets you gate slot grants on any condition you choose.
This sample gates on a fake DB pool: the worker only polls for a new
activity when the pool has a free connection.

**Note:** This sample is illustrative only. It shouldn't be used for production grade use-cases.

## What this sample is
- `db_pool.py` - A fixed-capacity fake pool backed by an asyncio semaphore. Three methods: `acquire()` (awaits until a connection is free), `try_acquire()` (non-blocking; returns `False` if the pool is full), and `release()` (returns a connection to the pool).
- `supplier.py` - The custom slot supplier. `reserve_slot` awaits `connection_pool.acquire()` until a connection is free; `try_reserve_slot` calls `connection_pool.try_acquire()` and returns `None` instead of waiting when the pool is full; `release_slot` calls `connection_pool.release()`.
- `shared.py` - A `RunBatch` workflow that runs N `do_work` activities in parallel. The activity just sleeps.
- `worker.py` - Wires `FakeDatabaseConnectionPool` + `PoolSlotSupplier` into a `WorkerTuner`.
- `starter.py` - Drives load.

The flow:

When the pool is at capacity, `reserve_slot` blocks until a
connection frees up. The excess work piles up on the Temporal server, not
inside the worker.

## Run

In three terminals from `samples-python/`:

```bash
temporal server start-dev # terminal 1
uv run custom_worker_tuner/worker.py # terminal 2
uv run custom_worker_tuner/starter.py # terminal 3
```

## What you'll see

The worker prints one line per slot lifecycle event:

```
TIME EVENT COUNT QUEUE DETAIL
(COUNT shows before→after / capacity; QUEUE = tasks parked waiting)
─────────────────────────────────────────────────────────────────
12:30:32.591 reserve 0→ 1/10 0 ready to poll
12:30:32.591 reserve 1→ 2/10 0 ready to poll
12:30:32.592 reserve 2→ 3/10 0 ready to poll
12:30:32.592 reserve 3→ 4/10 0 ready to poll
12:30:32.592 reserve 4→ 5/10 0 ready to poll
12:30:32.592 reserve 5→ 6/10 0 ready to poll
12:30:40.501 reserve 6→ 7/10 0 eager dispatch
12:30:40.502 reserve 7→ 8/10 0 eager dispatch
12:30:40.502 reserve 8→ 9/10 0 eager dispatch
12:30:40.505 release 9→ 8/10 0 no task arrived
12:30:40.506 release 8→ 7/10 0 no task arrived
12:30:40.506 release 7→ 6/10 0 no task arrived
12:30:40.510 used 6→ 6/10 0 activity running
12:30:40.510 reserve 6→ 7/10 0 eager dispatch
12:30:40.511 reserve 7→ 8/10 0 eager dispatch
12:30:40.511 reserve 8→ 9/10 0 eager dispatch
12:30:40.514 reserve 9→10/10 0 ready to poll
12:30:40.520 release 10→ 9/10 0 no task arrived
12:30:40.520 release 9→ 8/10 0 no task arrived
12:30:40.520 release 8→ 7/10 0 no task arrived
12:30:40.520 used 7→ 7/10 0 activity running
12:30:40.520 reserve 7→ 8/10 0 eager dispatch
12:30:40.520 reserve 8→ 9/10 0 eager dispatch
12:30:40.520 reserve 9→10/10 0 eager dispatch
12:30:40.525 release 10→10/10 0 no task arrived
12:30:40.525 release 10→ 9/10 0 no task arrived
12:30:40.525 release 9→ 8/10 0 no task arrived
12:30:40.528 reserve 7→ 8/10 0 ready to poll
12:30:40.530 used 8→ 8/10 0 activity running
12:30:40.535 reserve 8→ 9/10 0 eager dispatch
12:30:40.537 reserve 9→10/10 0 eager dispatch
12:30:40.539 used 10→10/10 1 activity running
12:30:40.540 used 10→10/10 1 activity running
12:30:40.541 used 10→10/10 1 activity running
```

Under load, with more activities than capacity, COUNT pins at
10/10 — that's the supplier refusing to poll past the gate.
We chose a capacity of 10 because the Python SDK uses 5 pollers by default.

## Knobs

worker.py:

CAPACITY — pool capacity (the gate)

starter.py:

WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load
Empty file.
51 changes: 51 additions & 0 deletions custom_worker_tuner/db_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import asyncio
import logging

logger = logging.getLogger(__name__)


class FakeDatabaseConnectionPool:
    """Pretend connection pool with a fixed capacity, backed by an asyncio.BoundedSemaphore.

    Each acquired "connection" is one unit of the semaphore. The semaphore is
    bounded so that releasing more connections than were acquired raises
    ``ValueError`` instead of silently growing the pool past its capacity.
    """

    def __init__(self, allowed_connections: int, name: str = "db") -> None:
        """Create a pool that hands out at most ``allowed_connections`` at once.

        Args:
            allowed_connections: Fixed capacity of the pool.
            name: Label used in log messages.

        Raises:
            ValueError: If ``allowed_connections`` is negative (raised by the
                semaphore constructor).
        """
        self.allowed_connections = allowed_connections
        self.name = name
        self._connection_pool = asyncio.BoundedSemaphore(allowed_connections)
        logger.info(
            "FakeDatabaseConnectionPool ready: name=%s allowed_connections=%d",
            name,
            allowed_connections,
        )

    async def acquire(self) -> None:
        """Claim a connection, awaiting until one is free."""
        await self._connection_pool.acquire()

    def try_acquire(self) -> bool:
        """Claim a connection without waiting.

        Returns:
            ``True`` if a connection was claimed, ``False`` if the semaphore
            reports itself locked (no connection can be handed out right now).
        """
        if self._connection_pool.locked():
            return False
        # HACK: asyncio semaphores have no public non-blocking acquire, so the
        # private counter is decremented directly. No await happens between the
        # locked() check and the decrement, so no other task on the same event
        # loop can run in between. Revisit if the semaphore internals change.
        self._connection_pool._value -= 1
        return True

    def release(self) -> None:
        """Return a connection to the pool.

        Raises:
            ValueError: If called more times than connections were claimed.
        """
        self._connection_pool.release()

    @property
    def in_use(self) -> int:
        """Connections currently claimed, derived from the semaphore counter."""
        # Reads the semaphore's private counter so the semaphore stays the
        # single source of truth for the pool's state.
        return self.allowed_connections - self._connection_pool._value

    @property
    def queued(self) -> int:
        """How many ``acquire()`` callers are parked waiting for a free connection."""
        # Private attribute of the semaphore; it may be None/empty when nobody
        # has ever had to wait.
        waiters = self._connection_pool._waiters
        if not waiters:
            return 0
        # Waiters whose future is already done have been woken (or cancelled)
        # and are no longer waiting.
        return sum(1 for w in waiters if not w.done())
39 changes: 39 additions & 0 deletions custom_worker_tuner/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from datetime import timedelta

from temporalio import activity, workflow

# Task queue name imported by both worker.py and starter.py so they agree on it.
TASK_QUEUE = "custom-worker-tuner"


@dataclass
class BatchInput:
    """Input for the RunBatch workflow."""

    # How many do_work activities the workflow runs in parallel.
    activities: int
    # Sleep duration, in seconds, passed to each do_work activity.
    seconds: float


@activity.defn
async def do_work(seconds: float) -> None:
    """Sleep, simulating an I/O-bound activity.

    Args:
        seconds: How long to sleep, in seconds.
    """
    await asyncio.sleep(seconds)


@workflow.defn
class RunBatch:
    """Workflow that fans out ``inp.activities`` do_work activities and waits for all of them."""

    @workflow.run
    async def run(self, inp: BatchInput) -> None:
        """Schedule every activity up front, then wait for the whole batch.

        Args:
            inp: Number of activities to run and how long each one sleeps.
        """
        # Every activity gets the same sleep duration and a two-minute
        # start-to-close timeout.
        pending = [
            workflow.execute_activity(
                do_work,
                inp.seconds,
                start_to_close_timeout=timedelta(minutes=2),
            )
            for _ in range(inp.activities)
        ]
        await asyncio.gather(*pending)
49 changes: 49 additions & 0 deletions custom_worker_tuner/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import asyncio
import time
import uuid

from temporalio.client import Client
from temporalio.envconfig import ClientConfig

from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch

# Tweak these to push more or less load.
WORKFLOWS = 10  # number of RunBatch workflows started by main()
ACTIVITIES_PER_WORKFLOW = 20  # do_work activities each workflow runs in parallel
SECONDS_PER_ACTIVITY = 2.0  # how long each do_work activity sleeps, in seconds


async def main() -> None:
    """Start WORKFLOWS RunBatch workflows, wait for all of them, and print the throughput."""
    connect_config = ClientConfig.load_client_connect_config()
    # Fall back to the local dev server when the loaded config names no host.
    connect_config.setdefault("target_host", "localhost:7233")
    client = await Client.connect(**connect_config)

    # A short random tag keeps workflow ids unique across runs of this script.
    batch_tag = uuid.uuid4().hex[:8]
    batch_input = BatchInput(
        activities=ACTIVITIES_PER_WORKFLOW,
        seconds=SECONDS_PER_ACTIVITY,
    )
    total_activities = WORKFLOWS * ACTIVITIES_PER_WORKFLOW

    print(
        f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s"
    )
    started_at = time.perf_counter()

    # Start every workflow concurrently, then wait for all results.
    starts = [
        client.start_workflow(
            RunBatch.run,
            batch_input,
            id=f"batch-{batch_tag}-{index}",
            task_queue=TASK_QUEUE,
        )
        for index in range(WORKFLOWS)
    ]
    handles = await asyncio.gather(*starts)
    await asyncio.gather(*[handle.result() for handle in handles])

    elapsed = time.perf_counter() - started_at
    print(
        f"done in {elapsed:.1f}s ({total_activities} activities, {total_activities / elapsed:.0f}/s)"
    )


# Drive the load only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
70 changes: 70 additions & 0 deletions custom_worker_tuner/supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import itertools
import logging

from temporalio.worker import (
CustomSlotSupplier,
SlotMarkUsedContext,
SlotPermit,
SlotReleaseContext,
SlotReserveContext,
)

from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool

logger = logging.getLogger(__name__)

# Module-level counter that gives each granted permit a sequential id, starting at 1.
_slot_id_gen = itertools.count(1)


class _Permit(SlotPermit):
    """SlotPermit subclass that just carries a sequential id for logs."""

    def __init__(self, slot_id: int) -> None:
        """Store ``slot_id`` on the permit.

        Args:
            slot_id: Sequential id assigned by the supplier when the slot was granted.
        """
        super().__init__()
        self.slot_id = slot_id


class PoolSlotSupplier(CustomSlotSupplier):
    """Slot supplier that grants a slot only while the backing pool has a free connection.

    Every granted slot holds one connection from ``connection_pool`` until
    ``release_slot`` is called for it. Each lifecycle event is logged as one
    line containing the event name, the before→after/capacity count, the number
    of queued waiters, and a short note.
    """

    def __init__(self, connection_pool: FakeDatabaseConnectionPool) -> None:
        """Remember the pool that gates slot grants.

        Args:
            connection_pool: Pool whose free connections decide whether a slot is granted.
        """
        self.connection_pool = connection_pool
        logger.info("PoolSlotSupplier ready: connection_pool=%s", connection_pool.name)

    async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
        """Wait until the pool has a free connection, then grant a slot."""
        await self.connection_pool.acquire()
        return self._grant("ready to poll")

    def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
        """Eager path: grant a slot only if a connection is free right now.

        Returns:
            A permit when a connection was claimed, otherwise ``None``.
        """
        if not self.connection_pool.try_acquire():
            return None
        return self._grant("eager dispatch")

    def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
        """Log that a reserved slot is now being used; the count does not change."""
        permit_id = getattr(ctx.permit, "slot_id", "?")
        current = self.connection_pool.in_use
        self._log("used", permit_id, "activity running", current, current)

    def release_slot(self, ctx: SlotReleaseContext) -> None:
        """Return the slot's connection to the pool and log the transition."""
        permit_id = getattr(ctx.permit, "slot_id", "?")
        # A release without slot info means the reserved slot was never used.
        outcome = "activity done" if ctx.slot_info is not None else "no task arrived"
        pool = self.connection_pool
        held_before = pool.in_use
        pool.release()
        self._log("release", permit_id, outcome, held_before, pool.in_use)

    def _grant(self, note: str) -> SlotPermit:
        """Log a reservation for a connection that was just claimed and build its permit."""
        # The connection is already claimed, so the "before" count is one less.
        held_now = self.connection_pool.in_use
        permit = _Permit(next(_slot_id_gen))
        self._log("reserve", permit.slot_id, note, held_now - 1, held_now)
        return permit

    def _log(self, event: str, slot_id, note: str, before: int, after: int) -> None:
        """Emit one table row for a slot lifecycle event.

        ``slot_id`` is accepted for call-site symmetry but is not part of the
        emitted row.
        """
        pool = self.connection_pool
        count = f"{before:>2}→{after:>2}/{pool.allowed_connections}"
        queued = pool.queued
        logger.info(f"{event:<8} {count} {queued:>5} {note}")
57 changes: 57 additions & 0 deletions custom_worker_tuner/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import asyncio
import logging

from temporalio.client import Client
from temporalio.envconfig import ClientConfig
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner

from custom_worker_tuner.db_pool import FakeDatabaseConnectionPool
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
from custom_worker_tuner.supplier import PoolSlotSupplier

CAPACITY = 10  # number of pool connections (and concurrent activities)
LOG_LEVEL = "INFO"  # logging level name; main() falls back to INFO if the name is unknown


async def main() -> None:
    """Configure logging, connect to Temporal, and run a worker whose activity slots are pool-gated."""
    logging.basicConfig(
        # Unknown level names fall back to INFO.
        level=getattr(logging, LOG_LEVEL.upper(), logging.INFO),
        format="%(asctime)s.%(msecs)03d %(message)s",
        datefmt="%H:%M:%S",
    )

    connect_config = ClientConfig.load_client_connect_config()
    # Fall back to the local dev server when the loaded config names no host.
    connect_config.setdefault("target_host", "localhost:7233")
    client = await Client.connect(**connect_config)

    db_pool = FakeDatabaseConnectionPool(allowed_connections=CAPACITY, name="db")
    activity_slots = PoolSlotSupplier(db_pool)

    # Only activity slots are gated by the pool; the other task kinds get
    # fixed-size suppliers.
    worker = Worker(
        client,
        task_queue=TASK_QUEUE,
        workflows=[RunBatch],
        activities=[do_work],
        tuner=WorkerTuner.create_composite(
            workflow_supplier=FixedSizeSlotSupplier(100),
            activity_supplier=activity_slots,
            local_activity_supplier=FixedSizeSlotSupplier(100),
            nexus_supplier=FixedSizeSlotSupplier(100),
        ),
    )

    # Header for the table rows the supplier logs.
    banner = (
        f"\nworker started — capacity={CAPACITY}\n",
        "TIME EVENT COUNT QUEUE DETAIL",
        "(COUNT shows before→after / capacity; QUEUE = tasks parked waiting)",
        "─" * 65,
    )
    for line in banner:
        print(line)

    await worker.run()


# Run the worker only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Swallow Ctrl-C so stopping the worker does not print a traceback.
        pass
Loading