Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions custom_worker_tuner/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Custom Worker Tuner

A `CustomSlotSupplier` is a sample that lets you gate slot grants on whatever you want.
This sample gates on a fake DB pool: the worker only polls for a new
activity when the pool has a free connection.

## What this sample is
downstream.py - A static-capacity counter. Pretends to be a DB pool. Two methods: increment() (claim a slot, returns False if full), decrement() (release)
supplier.py - The custom slot supplier. On reserve_slot it polls downstream.increment() until it succeeds. On release_slot it calls downstream.decrement()
shared.py - A RunBatch workflow that runs N do_work activities in parallel. The activity just sleeps
worker.py - Wires Downstream + DownstreamAwareSupplier into a WorkerTuner
starter.py - Drives load

The flow:

When the downstream is at capacity, `reserve_slot` blocks until a
slot frees up. The excess work piles up on the Temporal server, not
inside the worker.

## Run

In three terminals from `samples-python/`:

```bash
temporal server start-dev # terminal 1
uv run custom_worker_tuner/worker.py # terminal 2
uv run custom_worker_tuner/starter.py # terminal 3
```

## What you'll see

The worker prints one line per slot lifecycle event:

```

TIME EVENT SLOT COUNT DETAIL
────────────────────────────────────────────────────────────
10:31:49.842 reserve #1 1/10 ready to poll
10:31:49.842 reserve #2 2/10 ready to poll
10:31:49.843 reserve #3 3/10 ready to poll
10:31:49.843 reserve #4 4/10 ready to poll
10:31:49.843 reserve #5 5/10 ready to poll
10:31:49.843 reserve #6 6/10 ready to poll
10:31:56.763 reserve #7 7/10 eager dispatch
10:31:56.763 reserve #8 8/10 eager dispatch
10:31:56.764 reserve #9 9/10 eager dispatch
10:31:56.766 reserve #10 10/10 eager dispatch
10:31:56.767 release #7 9/10 no task arrived
10:31:56.768 release #8 8/10 no task arrived
10:31:56.768 release #9 7/10 no task arrived
10:31:56.768 reserve #11 8/10 eager dispatch
10:31:56.768 reserve #12 9/10 eager dispatch
10:31:56.768 reserve #13 10/10 eager dispatch
10:31:56.771 used #1 10/10 activity running
10:31:56.771 release #10 9/10 no task arrived
```

Under load, with more activities than capacity, COUNT pins at
10/10 — that's the supplier refusing to poll past the gate.
we chose 10 because default there are 5 pollers for python sdk

## Knobs

worker.py:

CAPACITY — downstream capacity (the gate)
POLL_INTERVAL_MS — how often the supplier rechecks when full

starter.py:

WORKFLOWS, ACTIVITIES_PER_WORKFLOW, SECONDS_PER_ACTIVITY — amount and duration of load
Empty file.
34 changes: 34 additions & 0 deletions custom_worker_tuner/downstream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import logging
import threading

logger = logging.getLogger(__name__)


class Downstream:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Downstream is an odd name for something that is essentially a Semaphore.

Maybe FakeDatabaseConnectionPool would be more descriptive

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to FakeDatabaseConnectionPool (and moved to db_pool.py)

"""A counter with a fixed capacity. Thread-safe."""

def __init__(self, allowed_connections: int, name: str = "downstream") -> None:
self.allowed_connections = allowed_connections
self.name = name
self.currently_connected = 0
self.connection_pool = threading.Lock()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done - the class is now just a thin wrapper over threading.BoundedSemaphore. The counter and lock are gone

in_use is derived from the semaphore (allowed_connections - self._sem._value)

logger.info(
"Downstream ready: name=%s allowed_connections=%d",
name,
allowed_connections,
)

def increment(self) -> bool:
"""allow one connection. Returns False if at capacity."""
with self.connection_pool:
if self.currently_connected >= self.allowed_connections:
return False
self.currently_connected += 1
return True

def decrement(self) -> None:
"""Release one slot. Floored at 0 so a buggy caller can't go negative."""
with self.connection_pool:
self.currently_connected = max(0, self.currently_connected - 1)
39 changes: 39 additions & 0 deletions custom_worker_tuner/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from __future__ import annotations

import asyncio
from dataclasses import dataclass
from datetime import timedelta

from temporalio import activity, workflow

TASK_QUEUE = "custom-worker-tuner"


@dataclass
class BatchInput:
activities: int
seconds: float


@activity.defn
async def do_work(seconds: float) -> None:
"""Sleep, simulating an I/O-bound activity."""
await asyncio.sleep(seconds)


@workflow.defn
class RunBatch:
"""Runs N do_work activities in parallel."""

@workflow.run
async def run(self, inp: BatchInput) -> None:
await asyncio.gather(
*(
workflow.execute_activity(
do_work,
inp.seconds,
start_to_close_timeout=timedelta(minutes=2),
)
for _ in range(inp.activities)
)
)
49 changes: 49 additions & 0 deletions custom_worker_tuner/starter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

import asyncio
import time
import uuid

from temporalio.client import Client
from temporalio.envconfig import ClientConfig

from custom_worker_tuner.shared import TASK_QUEUE, BatchInput, RunBatch

# Tweak these to push more or less load.
WORKFLOWS = 10
ACTIVITIES_PER_WORKFLOW = 20
SECONDS_PER_ACTIVITY = 2.0


async def main() -> None:
config = ClientConfig.load_client_connect_config()
config.setdefault("target_host", "localhost:7233")
client = await Client.connect(**config)
run_id = uuid.uuid4().hex[:8]
inp = BatchInput(activities=ACTIVITIES_PER_WORKFLOW, seconds=SECONDS_PER_ACTIVITY)
total = WORKFLOWS * ACTIVITIES_PER_WORKFLOW

print(
f"starting {WORKFLOWS} workflows × {ACTIVITIES_PER_WORKFLOW} activities × {SECONDS_PER_ACTIVITY}s"
)
t0 = time.perf_counter()

handles = await asyncio.gather(
*(
client.start_workflow(
RunBatch.run,
inp,
id=f"batch-{run_id}-{i}",
task_queue=TASK_QUEUE,
)
for i in range(WORKFLOWS)
)
)
await asyncio.gather(*(h.result() for h in handles))

wall = time.perf_counter() - t0
print(f"done in {wall:.1f}s ({total} activities, {total / wall:.0f}/s)")


if __name__ == "__main__":
asyncio.run(main())
70 changes: 70 additions & 0 deletions custom_worker_tuner/supplier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

import asyncio
import itertools
import logging

from temporalio.worker import (
CustomSlotSupplier,
SlotMarkUsedContext,
SlotPermit,
SlotReleaseContext,
SlotReserveContext,
)

from custom_worker_tuner.downstream import Downstream

logger = logging.getLogger(__name__)

_slot_id_gen = itertools.count(1)


class _Permit(SlotPermit):
"""SlotPermit subclass that just carries a sequential id for logs."""

def __init__(self, slot_id: int) -> None:
super().__init__()
self.slot_id = slot_id


class DownstreamAwareSupplier(CustomSlotSupplier):
def __init__(self, downstream: Downstream, poll_interval_ms: int = 100) -> None:
self.downstream = downstream
self.poll_interval_ms = poll_interval_ms
logger.info(
"DownstreamAwareSupplier ready: downstream=%s poll_interval_ms=%d",
downstream.name,
poll_interval_ms,
)

async def reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit:
"""block downstream until it has capacity to get incremented and then grant a slot."""
slot_id = next(_slot_id_gen)
while not self.downstream.increment():
await asyncio.sleep(self.poll_interval_ms / 1000.0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spin-looping here isn't great. Per my other suggestion if the fake pool just actually uses a semaphore the call itself can be blocking and no need for a sleep

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed -

reserve_slot now does await asyncio.to_thread(self.connection_pool.acquire), so it blocks until a slot is free instead of polling.

poll_interval_ms is removed.

self._log("reserve", slot_id, "ready to poll")
return _Permit(slot_id)

def try_reserve_slot(self, ctx: SlotReserveContext) -> SlotPermit | None:
"""Eager path: can i run this activity right now?"""
if self.downstream.increment():
slot_id = next(_slot_id_gen)
self._log("reserve", slot_id, "eager dispatch")
return _Permit(slot_id)
return None

def mark_slot_used(self, ctx: SlotMarkUsedContext) -> None:
"""A task arrived for a reserved slot"""
slot_id = getattr(ctx.permit, "slot_id", "?")
self._log("used", slot_id, "activity running")

def release_slot(self, ctx: SlotReleaseContext) -> None:
"""Return the slot to the downstream."""
slot_id = getattr(ctx.permit, "slot_id", "?")
detail = "no task arrived" if ctx.slot_info is None else "activity done"
self.downstream.decrement()
self._log("release", slot_id, detail)

def _log(self, event: str, slot_id, note: str) -> None:
count = f"{self.downstream.currently_connected}/{self.downstream.allowed_connections}"
logger.info(f"{event:<8} {count:>5} {note}")
57 changes: 57 additions & 0 deletions custom_worker_tuner/worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from __future__ import annotations

import asyncio
import logging

from temporalio.client import Client
from temporalio.envconfig import ClientConfig
from temporalio.worker import FixedSizeSlotSupplier, Worker, WorkerTuner

from custom_worker_tuner.downstream import Downstream
from custom_worker_tuner.shared import TASK_QUEUE, RunBatch, do_work
from custom_worker_tuner.supplier import DownstreamAwareSupplier

CAPACITY = 10 # number of connections allowed at a time
POLL_INTERVAL_MS = 500
LOG_LEVEL = "INFO" # flip to "DEBUG" to see every increment/decrement


async def main() -> None:
logging.basicConfig(
level=getattr(logging, LOG_LEVEL.upper(), logging.INFO),
format="%(asctime)s.%(msecs)03d %(message)s",
datefmt="%H:%M:%S",
)

config = ClientConfig.load_client_connect_config()
config.setdefault("target_host", "localhost:7233")
client = await Client.connect(**config)

downstream = Downstream(allowed_connections=CAPACITY, name="db")
supplier = DownstreamAwareSupplier(downstream, poll_interval_ms=POLL_INTERVAL_MS)
tuner = WorkerTuner.create_composite(
workflow_supplier=FixedSizeSlotSupplier(100),
activity_supplier=supplier,
local_activity_supplier=FixedSizeSlotSupplier(100),
nexus_supplier=FixedSizeSlotSupplier(100),
)

worker = Worker(
client,
task_queue=TASK_QUEUE,
workflows=[RunBatch],
activities=[do_work],
tuner=tuner,
)

print(f"\nworker started — capacity={CAPACITY}, poll={POLL_INTERVAL_MS}ms\n")
print("TIME EVENT COUNT DETAIL")
print("─" * 60)
await worker.run()


if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
pass
Loading