-
Notifications
You must be signed in to change notification settings - Fork 268
cuda.bindings latency benchmarks - part 2 #1856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
danielfrg
wants to merge
12
commits into
main
Choose a base branch
from
cuda-bindings-bench-more
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,360
−43
Open
Changes from 8 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
a843f75
Add bench_ctx_device and fix JSON output
danielfrg 780b435
Remove prefix so we can compare benchmarks
danielfrg 90b5e0b
Add bench_event and bench_stream and compare script for a summary table
danielfrg 8126ab7
Add bench_event and bench_stream and compare script for a summary table
danielfrg a3f0678
Add Launch benchmarks
danielfrg e4762ed
Lint
danielfrg 170578c
Add motivation to readme
danielfrg c64dada
Add cuStreamSyncrhonize
danielfrg 6b821c3
Apply suggestion from @mdboom
rparolin 5364daa
Simplify kernel params
danielfrg 49c2651
Linting
danielfrg 5727e60
Merge branch 'main' into cuda-bindings-bench-more
danielfrg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import time | ||
|
|
||
| from runner.runtime import ensure_context | ||
|
|
||
| from cuda.bindings import driver as cuda | ||
|
|
||
| CTX = ensure_context() | ||
|
|
||
| _, DEVICE = cuda.cuDeviceGet(0) | ||
| ATTRIBUTE = cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR | ||
|
|
||
|
|
||
| def bench_ctx_get_current(loops: int) -> float: | ||
| _cuCtxGetCurrent = cuda.cuCtxGetCurrent | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuCtxGetCurrent() | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_ctx_set_current(loops: int) -> float: | ||
| _cuCtxSetCurrent = cuda.cuCtxSetCurrent | ||
| _ctx = CTX | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuCtxSetCurrent(_ctx) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_ctx_get_device(loops: int) -> float: | ||
| _cuCtxGetDevice = cuda.cuCtxGetDevice | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuCtxGetDevice() | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_device_get(loops: int) -> float: | ||
| _cuDeviceGet = cuda.cuDeviceGet | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuDeviceGet(0) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_device_get_attribute(loops: int) -> float: | ||
| _cuDeviceGetAttribute = cuda.cuDeviceGetAttribute | ||
| _attr = ATTRIBUTE | ||
| _dev = DEVICE | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuDeviceGetAttribute(_attr, _dev) | ||
| return time.perf_counter() - t0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import time | ||
|
|
||
| from runner.runtime import ensure_context | ||
|
|
||
| from cuda.bindings import driver as cuda | ||
|
|
||
| ensure_context() | ||
|
|
||
| _err, STREAM = cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value) | ||
| _err, EVENT = cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING.value) | ||
|
|
||
| cuda.cuEventRecord(EVENT, STREAM) | ||
| cuda.cuStreamSynchronize(STREAM) | ||
|
|
||
| EVENT_FLAGS = cuda.CUevent_flags.CU_EVENT_DISABLE_TIMING.value | ||
|
|
||
|
|
||
| def bench_event_create_destroy(loops: int) -> float: | ||
| _cuEventCreate = cuda.cuEventCreate | ||
| _cuEventDestroy = cuda.cuEventDestroy | ||
| _flags = EVENT_FLAGS | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _, e = _cuEventCreate(_flags) | ||
| _cuEventDestroy(e) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_event_record(loops: int) -> float: | ||
| _cuEventRecord = cuda.cuEventRecord | ||
| _event = EVENT | ||
| _stream = STREAM | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuEventRecord(_event, _stream) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_event_query(loops: int) -> float: | ||
| _cuEventQuery = cuda.cuEventQuery | ||
| _event = EVENT | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuEventQuery(_event) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_event_synchronize(loops: int) -> float: | ||
| _cuEventSynchronize = cuda.cuEventSynchronize | ||
| _event = EVENT | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuEventSynchronize(_event) | ||
| return time.perf_counter() - t0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import ctypes | ||
| import time | ||
|
|
||
| from runner.runtime import alloc_persistent, compile_and_load, ensure_context | ||
|
|
||
| from cuda.bindings import driver as cuda | ||
|
|
||
| ensure_context() | ||
|
|
||
| # Compile kernels | ||
| KERNEL_SOURCE = """\ | ||
| extern "C" __global__ void empty_kernel() { return; } | ||
| extern "C" __global__ void small_kernel(float *f) { *f = 0.0f; } | ||
|
|
||
| #define ITEM_PARAM(x, T) T x | ||
| #define REP1(x, T) , ITEM_PARAM(x, T) | ||
| #define REP2(x, T) REP1(x##0, T) REP1(x##1, T) | ||
| #define REP4(x, T) REP2(x##0, T) REP2(x##1, T) | ||
| #define REP8(x, T) REP4(x##0, T) REP4(x##1, T) | ||
| #define REP16(x, T) REP8(x##0, T) REP8(x##1, T) | ||
|
|
||
| extern "C" __global__ | ||
| void small_kernel_16_args( | ||
| ITEM_PARAM(F, int*) | ||
| REP1(A, int*) | ||
| REP2(A, int*) | ||
| REP4(A, int*) | ||
| REP8(A, int*)) | ||
| { *F = 0; } | ||
| """ | ||
|
|
||
| MODULE = compile_and_load(KERNEL_SOURCE) | ||
|
|
||
| # Get kernel handles | ||
| _err, EMPTY_KERNEL = cuda.cuModuleGetFunction(MODULE, b"empty_kernel") | ||
| _err, SMALL_KERNEL = cuda.cuModuleGetFunction(MODULE, b"small_kernel") | ||
| _err, KERNEL_16_ARGS = cuda.cuModuleGetFunction(MODULE, b"small_kernel_16_args") | ||
|
|
||
| # Create a non-blocking stream for launches | ||
| _err, STREAM = cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value) | ||
|
|
||
| # Allocate device memory for kernel arguments | ||
| FLOAT_PTR = alloc_persistent(ctypes.sizeof(ctypes.c_float)) | ||
| INT_PTRS = [alloc_persistent(ctypes.sizeof(ctypes.c_int)) for _ in range(16)] | ||
|
|
||
| # Pre-pack ctypes params for the pre-packed benchmark | ||
| _val_ps = [ctypes.c_void_p(int(p)) for p in INT_PTRS] | ||
| PACKED_16 = (ctypes.c_void_p * 16)() | ||
| for _i in range(16): | ||
| PACKED_16[_i] = ctypes.addressof(_val_ps[_i]) | ||
|
|
||
|
|
||
| def bench_launch_empty_kernel(loops: int) -> float: | ||
| _cuLaunchKernel = cuda.cuLaunchKernel | ||
| _kernel = EMPTY_KERNEL | ||
| _stream = STREAM | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuLaunchKernel(_kernel, 1, 1, 1, 1, 1, 1, 0, _stream, 0, 0) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_launch_small_kernel(loops: int) -> float: | ||
| _cuLaunchKernel = cuda.cuLaunchKernel | ||
| _kernel = SMALL_KERNEL | ||
| _stream = STREAM | ||
| _args = (FLOAT_PTR,) | ||
| _arg_types = (None,) | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuLaunchKernel(_kernel, 1, 1, 1, 1, 1, 1, 0, _stream, (_args, _arg_types), 0) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_launch_16_args(loops: int) -> float: | ||
| _cuLaunchKernel = cuda.cuLaunchKernel | ||
| _kernel = KERNEL_16_ARGS | ||
| _stream = STREAM | ||
| _args = tuple(INT_PTRS) | ||
| _arg_types = tuple([None] * 16) | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuLaunchKernel(_kernel, 1, 1, 1, 1, 1, 1, 0, _stream, (_args, _arg_types), 0) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_launch_16_args_pre_packed(loops: int) -> float: | ||
| _cuLaunchKernel = cuda.cuLaunchKernel | ||
| _kernel = KERNEL_16_ARGS | ||
| _stream = STREAM | ||
| _packed = PACKED_16 | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuLaunchKernel(_kernel, 1, 1, 1, 1, 1, 1, 0, _stream, _packed, 0) | ||
| return time.perf_counter() - t0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| import time | ||
|
|
||
| from runner.runtime import ensure_context | ||
|
|
||
| from cuda.bindings import driver as cuda | ||
|
|
||
| ensure_context() | ||
|
|
||
| _err, STREAM = cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value) | ||
|
|
||
|
|
||
| def bench_stream_create_destroy(loops: int) -> float: | ||
| _cuStreamCreate = cuda.cuStreamCreate | ||
| _cuStreamDestroy = cuda.cuStreamDestroy | ||
| _flags = cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _, s = _cuStreamCreate(_flags) | ||
| _cuStreamDestroy(s) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_stream_query(loops: int) -> float: | ||
| _cuStreamQuery = cuda.cuStreamQuery | ||
| _stream = STREAM | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuStreamQuery(_stream) | ||
| return time.perf_counter() - t0 | ||
|
|
||
|
|
||
| def bench_stream_synchronize(loops: int) -> float: | ||
| _cuStreamSynchronize = cuda.cuStreamSynchronize | ||
| _stream = STREAM | ||
|
|
||
| t0 = time.perf_counter() | ||
| for _ in range(loops): | ||
| _cuStreamSynchronize(_stream) | ||
| return time.perf_counter() - t0 |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.