forked from NVIDIA/cuda-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
57 lines (46 loc) · 1.98 KB
/
__init__.py
File metadata and controls
57 lines (46 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import os
from typing import Union
from cuda.core._utils.cuda_utils import handle_return
from cuda.pathfinder import get_cuda_path_or_home
from cuda_python_test_helpers import *
# Toolkit root as reported by cuda.pathfinder; it may be None, in which case
# both include-path constants below keep their None defaults.
CUDA_PATH = get_cuda_path_or_home()
CUDA_INCLUDE_PATH = None
CCCL_INCLUDE_PATHS = None

if CUDA_PATH is not None:
    path = os.path.join(CUDA_PATH, "include")
    if os.path.isdir(path):
        CUDA_INCLUDE_PATH = path
        # When <include>/cccl exists it is placed ahead of the plain include
        # directory; otherwise the plain include directory is the only entry.
        path = os.path.join(CUDA_INCLUDE_PATH, "cccl")
        CCCL_INCLUDE_PATHS = (
            (path, CUDA_INCLUDE_PATH) if os.path.isdir(path) else (CUDA_INCLUDE_PATH,)
        )
@functools.cache
def supports_ipc_mempool(device_id: Union[int, object]) -> bool:
    """Report whether mempool IPC through a POSIX file descriptor is available.

    The driver attribute CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES is
    queried with cuDeviceGetAttribute and its bitmask is tested for
    CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR. No active CUDA context is
    required. Results are memoized per ``device_id`` argument by
    ``functools.cache``.

    Args:
        device_id: A device ordinal, or an object exposing a ``device_id``
            attribute that holds one.

    Returns:
        True when the POSIX file-descriptor handle type bit is set in the
        reported mask. False on WSL, and False whenever any step of the
        probe raises an exception.
    """
    if IS_WSL:
        return False

    try:
        # Import the driver bindings lazily so there is no hard dependency on
        # them when GPU tests are not being run.
        try:
            from cuda.bindings import driver  # type: ignore
        except Exception:
            from cuda import cuda as driver  # type: ignore

        handle_return(driver.cuInit(0))

        # Accept either a plain ordinal or a Device-like object.
        ordinal = int(getattr(device_id, "device_id", device_id))

        # Bitmask of the handle types the device's memory pools support.
        supported_types = handle_return(
            driver.cuDeviceGetAttribute(
                driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES,
                ordinal,
            )
        )

        fd_flag = driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
        return bool(int(supported_types) & int(fd_flag))
    except Exception:
        # Best-effort probe: any failure is reported as "not supported".
        return False