diff --git a/setup.py b/setup.py index fa61b72e7..58234c683 100644 --- a/setup.py +++ b/setup.py @@ -532,11 +532,17 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int # CUDA group norm supports from SM70 arch_flags = [] - # FIXME: this needs to be done more cleanly - for arch in [70, 75, 80, 86, 90, 100, 120]: - arch_flag = f"-gencode=arch=compute_{arch},code=sm_{arch}" - arch_flags.append(arch_flag) - arch_flags.append(arch_flag) + # Gate architectures by available CUDA version to avoid unsupported targets on older toolkits + allowed_arches = [70, 75, 80] + if bare_metal_version >= Version("11.1"): + allowed_arches.append(86) + if bare_metal_version >= Version("11.8"): + allowed_arches.append(90) + if bare_metal_version >= Version("12.8"): + allowed_arches.extend([100, 120]) + + for arch in allowed_arches: + arch_flags.append(f"-gencode=arch=compute_{arch},code=sm_{arch}") ext_modules.append( CUDAExtension(