From a19fe0e5090a5d3f57a5e854286e89a2a2729f47 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 28 May 2026 00:15:20 +0900 Subject: [PATCH 1/4] Use dispatcher shims for APEX extensions Convert the Python-facing custom extension surface to private dispatcher-backed shims under apex._extensions. The shims load the compiled libraries with torch.ops.load_library, call torch.ops.apex registrations directly, and normalize scalar and tensor-list arguments where the dispatcher schemas are stricter than the previous pybind entry points. Move the generated Python module names out of the repository root and stop packaging top-level compatibility modules such as amp_C, fused_adam_cuda, and fused_layer_norm_cuda. Internal APEX imports and affected tests now import through apex._extensions instead, which keeps the package root clean and makes the compatibility break explicit. Replace converted C++ extension frontends with TORCH_LIBRARY dispatcher registrations and build them with py_limited_api where possible. The remaining non-stable-ABI surfaces are left out of this conversion because they still need Python object bindings or setup-time helper behavior. Preserve test-observed behavior while changing the binding layer: fp16 clip_grad falls back to PyTorch's clipping semantics, FusedDense initializes parameters like nn.Linear, and bf16 FusedDense uses the PyTorch matmul+bias path to avoid the fused cublasLt bf16 mismatch. 
Authored with codex gpt-5.5 xhigh Signed-off-by: Masaki Kozuki --- apex/_custom_ops.py | 42 ++ apex/_extensions/__init__.py | 1 + apex/_extensions/amp_C.py | 350 +++++++++++ apex/_extensions/apex_C.py | 18 + apex/_extensions/bnp.py | 263 ++++++++ apex/_extensions/cudnn_gbn_lib.py | 55 ++ apex/_extensions/distributed_adam_cuda.py | 97 +++ apex/_extensions/distributed_lamb_cuda.py | 67 ++ apex/_extensions/fast_bottleneck.py | 178 ++++++ apex/_extensions/fast_layer_norm.py | 28 + apex/_extensions/fast_multihead_attn.py | 411 +++++++++++++ apex/_extensions/fmhalib.py | 48 ++ apex/_extensions/focal_loss_cuda.py | 21 + apex/_extensions/fused_adam_cuda.py | 154 +++++ apex/_extensions/fused_conv_bias_relu.py | 67 ++ apex/_extensions/fused_dense_cuda.py | 11 + apex/_extensions/fused_index_mul_2d.py | 13 + apex/_extensions/fused_lamb_cuda.py | 45 ++ apex/_extensions/fused_layer_norm_cuda.py | 17 + .../fused_rotary_positional_embedding.py | 15 + .../fused_weight_gradient_mlp_cuda.py | 9 + .../generic_scaled_masked_softmax_cuda.py | 14 + apex/_extensions/group_norm_cuda.py | 32 + apex/_extensions/group_norm_v2_cuda.py | 33 + apex/_extensions/mlp_cuda.py | 14 + apex/_extensions/nccl_p2p_cuda.py | 53 ++ apex/_extensions/peer_memory_cuda.py | 93 +++ apex/_extensions/permutation_search_cuda.py | 76 +++ .../_extensions/scaled_masked_softmax_cuda.py | 20 + apex/_extensions/scaled_softmax_cuda.py | 14 + ...scaled_upper_triang_masked_softmax_cuda.py | 16 + apex/_extensions/syncbn.py | 17 + apex/_extensions/transducer_joint_cuda.py | 50 ++ apex/_extensions/transducer_loss_cuda.py | 53 ++ apex/_extensions/xentropy_cuda.py | 16 + apex/contrib/bottleneck/bottleneck.py | 4 +- apex/contrib/bottleneck/halo_exchangers.py | 4 +- apex/contrib/clip_grad/clip_grad.py | 13 +- apex/contrib/conv_bias_relu/conv_bias_relu.py | 2 +- apex/contrib/csrc/bottleneck/bottleneck.cpp | 258 ++++++-- .../csrc/conv_bias_relu/conv_bias_relu.cpp | 51 +- apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp | 42 +- 
apex/contrib/csrc/cudnn_gbn/norm_sample.cpp | 11 +- apex/contrib/csrc/fmha/fmha_api.cpp | 74 ++- apex/contrib/csrc/fmha/src/fmha_fill.cu | 2 +- .../csrc/focal_loss/focal_loss_cuda.cpp | 26 +- .../csrc/group_norm/group_norm_nhwc_op.cpp | 114 ++-- apex/contrib/csrc/group_norm_v2/gn.cpp | 93 ++- apex/contrib/csrc/groupbn/interface.cpp | 177 +++++- .../csrc/index_mul_2d/index_mul_2d_cuda.cpp | 90 ++- apex/contrib/csrc/layer_norm/ln_api.cpp | 85 ++- .../additive_masked_softmax_dropout_cuda.cu | 20 +- .../encdec_multihead_attn_cuda.cu | 56 +- .../encdec_multihead_attn_norm_add_cuda.cu | 88 +-- .../masked_softmax_dropout_cuda.cu | 20 +- .../multihead_attn_frontend.cpp | 570 ++++++++++++------ ..._multihead_attn_bias_additive_mask_cuda.cu | 46 +- .../self_multihead_attn_bias_cuda.cu | 46 +- .../self_multihead_attn_cuda.cu | 44 +- .../self_multihead_attn_norm_add_cuda.cu | 78 +-- .../multihead_attn/strided_batched_gemm.cuh | 2 + apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp | 60 +- apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu | 16 +- apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh | 3 +- .../csrc/optimizers/fused_adam_cuda.cpp | 115 +++- .../csrc/optimizers/fused_lamb_cuda.cpp | 26 +- .../optimizers/multi_tensor_distopt_adam.cpp | 64 +- .../optimizers/multi_tensor_distopt_lamb.cpp | 45 +- apex/contrib/csrc/peer_memory/peer_memory.cpp | 78 ++- .../csrc/peer_memory/peer_memory_cuda.cu | 14 +- .../csrc/peer_memory/peer_memory_cuda.cuh | 3 +- .../csrc/transducer/transducer_joint.cpp | 76 ++- .../transducer/transducer_joint_kernel.cu | 52 +- .../csrc/transducer/transducer_loss.cpp | 57 +- .../csrc/transducer/transducer_loss_kernel.cu | 27 +- apex/contrib/csrc/xentropy/interface.cpp | 44 +- apex/contrib/cudnn_gbn/batch_norm.py | 4 +- apex/contrib/fmha/fmha.py | 2 +- apex/contrib/focal_loss/__init__.py | 2 +- apex/contrib/focal_loss/focal_loss.py | 2 +- apex/contrib/group_norm/group_norm.py | 4 +- apex/contrib/groupbn/__init__.py | 2 +- apex/contrib/groupbn/batch_norm.py | 2 +- 
apex/contrib/index_mul_2d/index_mul_2d.py | 2 +- apex/contrib/layer_norm/layer_norm.py | 2 +- .../fast_encdec_multihead_attn_func.py | 2 +- ...ast_encdec_multihead_attn_norm_add_func.py | 2 +- .../fast_self_multihead_attn_func.py | 2 +- .../fast_self_multihead_attn_norm_add_func.py | 2 +- .../mask_softmax_dropout_func.py | 2 +- .../optimizers/distributed_fused_adam.py | 6 +- .../optimizers/distributed_fused_lamb.py | 8 +- apex/contrib/optimizers/fp16_optimizer.py | 2 +- apex/contrib/optimizers/fused_adam.py | 2 +- apex/contrib/optimizers/fused_lamb.py | 4 +- apex/contrib/optimizers/fused_sgd.py | 2 +- .../peer_memory/peer_halo_exchanger_1d.py | 2 +- apex/contrib/peer_memory/peer_memory.py | 2 +- .../permutation_search_kernels.cu | 375 +++++++----- .../permutation_utilities.py | 4 +- .../test/layer_norm/test_fast_layer_norm.py | 2 +- apex/contrib/transducer/transducer.py | 4 +- apex/contrib/xentropy/softmax_xentropy.py | 2 +- apex/fused_dense/fused_dense.py | 23 +- apex/mlp/mlp.py | 2 +- apex/multi_tensor_apply/multi_tensor_apply.py | 2 +- apex/normalization/fused_layer_norm.py | 24 +- apex/optimizers/fused_adagrad.py | 2 +- apex/optimizers/fused_adam.py | 2 +- apex/optimizers/fused_lamb.py | 2 +- apex/optimizers/fused_mixed_precision_lamb.py | 2 +- apex/optimizers/fused_novograd.py | 2 +- apex/optimizers/fused_sgd.py | 2 +- csrc/amp_C_frontend.cpp | 252 ++++++-- csrc/fused_dense.cpp | 33 +- csrc/fused_dense_cuda.cu | 1 - csrc/layer_norm_cuda.cpp | 104 +++- .../fused_rotary_positional_embedding.cpp | 92 +-- .../fused_rotary_positional_embedding.h | 2 +- .../fused_rotary_positional_embedding_cuda.cu | 57 +- csrc/megatron/fused_weight_gradient_dense.cpp | 30 +- ...d_weight_gradient_dense_16bit_prec_cuda.cu | 2 +- .../fused_weight_gradient_dense_cuda.cu | 2 +- .../generic_scaled_masked_softmax.cpp | 34 +- .../generic_scaled_masked_softmax_cuda.cu | 11 +- csrc/megatron/scaled_masked_softmax.cpp | 48 +- csrc/megatron/scaled_masked_softmax_cuda.cu | 11 +- 
csrc/megatron/scaled_softmax.cpp | 32 +- csrc/megatron/scaled_softmax_cuda.cu | 9 +- .../scaled_upper_triang_masked_softmax.cpp | 35 +- ...scaled_upper_triang_masked_softmax_cuda.cu | 9 +- csrc/mlp.cpp | 29 +- csrc/mlp_cuda.cu | 1 - csrc/syncbn.cpp | 58 +- setup.py | 104 ++-- tests/L0/run_optimizers/test_lamb.py | 2 +- .../synced_batchnorm/single_gpu_unit_test.py | 2 +- .../synced_batchnorm/test_groups.py | 2 +- .../synced_batchnorm/two_gpu_unit_test.py | 2 +- 139 files changed, 5255 insertions(+), 1326 deletions(-) create mode 100644 apex/_custom_ops.py create mode 100644 apex/_extensions/__init__.py create mode 100644 apex/_extensions/amp_C.py create mode 100644 apex/_extensions/apex_C.py create mode 100644 apex/_extensions/bnp.py create mode 100644 apex/_extensions/cudnn_gbn_lib.py create mode 100644 apex/_extensions/distributed_adam_cuda.py create mode 100644 apex/_extensions/distributed_lamb_cuda.py create mode 100644 apex/_extensions/fast_bottleneck.py create mode 100644 apex/_extensions/fast_layer_norm.py create mode 100644 apex/_extensions/fast_multihead_attn.py create mode 100644 apex/_extensions/fmhalib.py create mode 100644 apex/_extensions/focal_loss_cuda.py create mode 100644 apex/_extensions/fused_adam_cuda.py create mode 100644 apex/_extensions/fused_conv_bias_relu.py create mode 100644 apex/_extensions/fused_dense_cuda.py create mode 100644 apex/_extensions/fused_index_mul_2d.py create mode 100644 apex/_extensions/fused_lamb_cuda.py create mode 100644 apex/_extensions/fused_layer_norm_cuda.py create mode 100644 apex/_extensions/fused_rotary_positional_embedding.py create mode 100644 apex/_extensions/fused_weight_gradient_mlp_cuda.py create mode 100644 apex/_extensions/generic_scaled_masked_softmax_cuda.py create mode 100644 apex/_extensions/group_norm_cuda.py create mode 100644 apex/_extensions/group_norm_v2_cuda.py create mode 100644 apex/_extensions/mlp_cuda.py create mode 100644 apex/_extensions/nccl_p2p_cuda.py create mode 100644 
apex/_extensions/peer_memory_cuda.py create mode 100644 apex/_extensions/permutation_search_cuda.py create mode 100644 apex/_extensions/scaled_masked_softmax_cuda.py create mode 100644 apex/_extensions/scaled_softmax_cuda.py create mode 100644 apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py create mode 100644 apex/_extensions/syncbn.py create mode 100644 apex/_extensions/transducer_joint_cuda.py create mode 100644 apex/_extensions/transducer_loss_cuda.py create mode 100644 apex/_extensions/xentropy_cuda.py diff --git a/apex/_custom_ops.py b/apex/_custom_ops.py new file mode 100644 index 000000000..6a74ef22f --- /dev/null +++ b/apex/_custom_ops.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import torch + +_loaded_libraries = set() + + +def load_custom_op_library(extension_name, anchor_file): + base_dir = Path(anchor_file).resolve().parent + search_dirs = [base_dir, base_dir.parent, base_dir.parent.parent] + candidates = sorted( + { + candidate + for directory in search_dirs + for candidate in directory.glob(f"{extension_name}*.so") + }, + key=lambda path: (".cpython-" in path.name, path.name), + ) + if not candidates: + raise ImportError(f"Could not find shared library for {extension_name!r} next to {anchor_file}") + + library = str(candidates[0]) + if library not in _loaded_libraries: + torch.ops.load_library(library) + _loaded_libraries.add(library) + return library + + +def scalar_float(value): + if isinstance(value, torch.Tensor): + return float(value.item()) + return float(value) + + +def scalar_int(value): + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + +def tensor_list_arg(value): + return [list(tensors) for tensors in value] diff --git a/apex/_extensions/__init__.py b/apex/_extensions/__init__.py new file mode 100644 index 000000000..0e9ebbece --- /dev/null +++ b/apex/_extensions/__init__.py @@ -0,0 +1 @@ +"""Private Python shims for APEX dispatcher libraries.""" diff --git a/apex/_extensions/amp_C.py 
b/apex/_extensions/amp_C.py new file mode 100644 index 000000000..7fb2fce23 --- /dev/null +++ b/apex/_extensions/amp_C.py @@ -0,0 +1,350 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, tensor_list_arg + + +load_custom_op_library("_amp_C", __file__) +_ops = torch.ops.apex + + +def multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale): + return _ops.amp_multi_tensor_scale(chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale)) + + +def multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, +): + return _ops.amp_multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(wd), + scalar_float(momentum), + scalar_float(dampening), + scalar_float(lr), + nesterov, + first_run, + wd_after_momentum, + scalar_float(scale), + ) + + +def multi_tensor_axpby(chunk_size, noop_flag, tensor_lists, a, b, arg_to_check): + return _ops.amp_multi_tensor_axpby( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(a), + scalar_float(b), + arg_to_check, + ) + + +def multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm(chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python) + + +def multi_tensor_l2norm_mp(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm_mp(chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python) + + +def multi_tensor_l2norm_scale(chunk_size, noop_flag, tensor_lists, scale, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm_scale( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale), per_tensor_python + ) + + +def multi_tensor_unscale_l2norm(chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor_python=None): + return _ops.amp_multi_tensor_unscale_l2norm( + chunk_size, 
noop_flag, tensor_list_arg(tensor_lists), inv_scale, per_tensor_python + ) + + +def multi_tensor_lamb_stage1_cuda( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_decay, + step, + beta1, + beta2, + epsilon, + global_grad_norm, + max_global_grad_norm, +): + return _ops.amp_multi_tensor_lamb_stage1_cuda( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_decay, + step, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + global_grad_norm, + scalar_float(max_global_grad_norm), + ) + + +def multi_tensor_lamb_stage2_cuda( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_param_norm, + per_tensor_update_norm, + lr, + weight_decay, + use_nvlamb_python=None, +): + return _ops.amp_multi_tensor_lamb_stage2_cuda( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_param_norm, + per_tensor_update_norm, + scalar_float(lr), + scalar_float(weight_decay), + use_nvlamb_python, + ) + + +def multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.amp_multi_tensor_adam( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + ) + + +def multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, +): + return _ops.amp_multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + inv_scale, + ) + + +def multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, 
+ inv_scale, +): + return _ops.amp_multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + inv_scale, + ) + + +def multi_tensor_adagrad(chunk_size, noop_flag, tensor_lists, lr, epsilon, mode, weight_decay): + return _ops.amp_multi_tensor_adagrad( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(epsilon), + mode, + scalar_float(weight_decay), + ) + + +def multi_tensor_novograd( + chunk_size, + noop_flag, + tensor_lists, + grad_norms, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + norm_type, +): + return _ops.amp_multi_tensor_novograd( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + grad_norms, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + norm_type, + ) + + +def multi_tensor_lamb( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python=None, +): + return _ops.amp_multi_tensor_lamb( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + global_grad_norm, + scalar_float(max_grad_norm), + use_nvlamb_python, + ) + + +def multi_tensor_lamb_mp( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python, + found_inf, + inv_scale, +): + return _ops.amp_multi_tensor_lamb_mp( + 
chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python, + found_inf, + inv_scale, + ) + + +def update_scale_hysteresis( + current_scale, + growth_tracker, + hysteresis_tracker, + found_inf, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, +): + return _ops.amp_update_scale_hysteresis( + current_scale, + growth_tracker, + hysteresis_tracker, + found_inf, + scalar_float(growth_factor), + scalar_float(backoff_factor), + growth_interval, + hysteresis, + ) diff --git a/apex/_extensions/apex_C.py b/apex/_extensions/apex_C.py new file mode 100644 index 000000000..e96db76a1 --- /dev/null +++ b/apex/_extensions/apex_C.py @@ -0,0 +1,18 @@ +import torch + + +def flatten(tensors): + tensors = list(tensors) + if len(tensors) == 0: + return torch.tensor([]) + return torch.cat([tensor.contiguous().view(-1) for tensor in tensors], dim=0) + + +def unflatten(flat, tensors): + outputs = [] + offset = 0 + for tensor in tensors: + numel = tensor.numel() + outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) + offset += numel + return outputs diff --git a/apex/_extensions/bnp.py b/apex/_extensions/bnp.py new file mode 100644 index 000000000..38d1c25b5 --- /dev/null +++ b/apex/_extensions/bnp.py @@ -0,0 +1,263 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_bnp", __file__) +_ops = torch.ops.apex + + +def _optional_ptr(value): + return None if value is None else scalar_int(value) + + +def get_buffer_size(bn_sync_steps): + return _ops.bnp_get_buffer_size(scalar_int(bn_sync_steps)) + + +def get_data_ptr(data): + return _ops.bnp_get_data_ptr(data) + + +def get_remote_data_ptr(handle, offset): + return _ops.bnp_get_remote_data_ptr(handle, scalar_int(offset)) + + +def 
close_remote_data(handle): + return _ops.bnp_close_remote_data(handle) + + +def bn_fwd_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + momentum, + epsilon, + fuse_relu, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_fwd_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_fwd_eval_nhwc(x, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon, fuse_relu): + return _ops.bnp_bn_fwd_eval_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + ret_cta, + scalar_int(bn_group), + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + ) + + +def bn_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + momentum, + epsilon, + fuse_relu, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_fwd_nhwc_occupancy(): + return _ops.bnp_bn_fwd_nhwc_occupancy() + + +def bn_bwd_nhwc_occupancy(): + return _ops.bnp_bn_bwd_nhwc_occupancy() + + +def bn_addrelu_fwd_nhwc( + x, + z, + 
scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + momentum, + epsilon, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_addrelu_fwd_nhwc( + x, + z, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_addrelu_fwd_eval_nhwc(x, z, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon): + return _ops.bnp_bn_addrelu_fwd_eval_nhwc( + x, + z, + scale, + bias, + running_mean, + running_inv_var, + ret_cta, + scalar_int(bn_group), + scalar_float(momentum), + scalar_float(epsilon), + ) + + +def bn_addrelu_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + momentum, + epsilon, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_addrelu_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_addrelu_fwd_nhwc_occupancy(): + return _ops.bnp_bn_addrelu_fwd_nhwc_occupancy() + + +def bn_addrelu_bwd_nhwc_occupancy(): + return _ops.bnp_bn_addrelu_bwd_nhwc_occupancy() diff --git a/apex/_extensions/cudnn_gbn_lib.py b/apex/_extensions/cudnn_gbn_lib.py new file mode 100644 
index 000000000..96e9c5891 --- /dev/null +++ b/apex/_extensions/cudnn_gbn_lib.py @@ -0,0 +1,55 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_cudnn_gbn_lib", __file__) +_ops = torch.ops.apex + + +def _int_list(values): + return [scalar_int(value) for value in values] + + +def forward( + x, + scale, + bias, + running_mean, + running_var, + minibatch_mean, + minibatch_inv_var, + momentum, + epsilon, + bn_group, + rank_id, + peer_buffers, +): + return _ops.cudnn_gbn_forward( + x, + scale, + bias, + running_mean, + running_var, + minibatch_mean, + minibatch_inv_var, + scalar_float(momentum), + scalar_float(epsilon), + scalar_int(bn_group), + scalar_int(rank_id), + _int_list(peer_buffers), + ) + + +def backward(x, dy, scale, minibatch_mean, minibatch_inv_var, epsilon, bn_group, rank_id, peer_buffers): + return _ops.cudnn_gbn_backward( + x, + dy, + scale, + minibatch_mean, + minibatch_inv_var, + scalar_float(epsilon), + scalar_int(bn_group), + scalar_int(rank_id), + _int_list(peer_buffers), + ) diff --git a/apex/_extensions/distributed_adam_cuda.py b/apex/_extensions/distributed_adam_cuda.py new file mode 100644 index 000000000..26499e541 --- /dev/null +++ b/apex/_extensions/distributed_adam_cuda.py @@ -0,0 +1,97 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_distributed_adam_cuda", __file__) +_ops = torch.ops.apex + + +def multi_tensor_fused_adam( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + 
scalar_float(weight_decay), + ) + + +def multi_tensor_fused_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam_capturable( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + step, + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(weight_decay), + ) + + +def multi_tensor_fused_adam_with_param_remainders( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam_with_param_remainders( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(weight_decay), + ) diff --git a/apex/_extensions/distributed_lamb_cuda.py b/apex/_extensions/distributed_lamb_cuda.py new file mode 100644 index 000000000..1e2948e8b --- /dev/null +++ b/apex/_extensions/distributed_lamb_cuda.py @@ -0,0 +1,67 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_distributed_lamb_cuda", __file__) +_ops = torch.ops.apex + + +def multi_tensor_lamb_compute_update_term( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_beta1, + per_tensor_beta2, + per_tensor_beta3, + per_tensor_bias_correction, + step, + per_tensor_epsilon, + mode, + per_tensor_decay, + global_scale, + global_grad_norm, + max_grad_norm, +): + return _ops.distributed_lamb_compute_update_term( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_beta1, + per_tensor_beta2, + per_tensor_beta3, 
+ per_tensor_bias_correction, + step, + per_tensor_epsilon, + scalar_int(mode), + per_tensor_decay, + global_scale, + global_grad_norm, + scalar_float(max_grad_norm), + ) + + +def multi_tensor_lamb_update_weights( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_param_norm, + per_tensor_update_norm, + update_norm_offset, + learning_rate, + per_tensor_decay, + global_grad_norm, + use_nvlamb, +): + return _ops.distributed_lamb_update_weights( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_param_norm, + per_tensor_update_norm, + update_norm_offset, + learning_rate, + per_tensor_decay, + global_grad_norm, + use_nvlamb, + ) diff --git a/apex/_extensions/fast_bottleneck.py b/apex/_extensions/fast_bottleneck.py new file mode 100644 index 000000000..99091d2de --- /dev/null +++ b/apex/_extensions/fast_bottleneck.py @@ -0,0 +1,178 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_fast_bottleneck", __file__) +_ops = torch.ops.apex + + +def _tensor_list(values): + return list(values) + + +def forward(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_forward( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def backward(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_backward( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def forward_init(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_forward_init( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def forward_out1(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_out1( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def forward_out2(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_out2( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), 
_tensor_list(outputs) + ) + + +def forward_out2_mask(explicit_nhwc, stride_1x1, inputs, outputs, thresholdTop, thresholdBottom): + return _ops.fast_bottleneck_forward_out2_mask( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + thresholdTop, + thresholdBottom, + ) + + +def forward_out2_halo(explicit_nhwc, fat_halo_y1, inputs): + return _ops.fast_bottleneck_forward_out2_halo( + bool(explicit_nhwc), fat_halo_y1, _tensor_list(inputs) + ) + + +def forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, inputs, w1by3, out2_part_halo): + return _ops.fast_bottleneck_forward_out2_halo_corr( + bool(explicit_nhwc), slim_halo_y1, _tensor_list(inputs), w1by3, out2_part_halo + ) + + +def forward_out2_pad(explicit_nhwc, stride_1x1, inputs, outputs, out1_pad): + return _ops.fast_bottleneck_forward_out2_pad( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), out1_pad + ) + + +def forward_rest(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_rest( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_init(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_backward_init( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def backward_grad_out2(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_backward_grad_out2( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_grad_out1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): + return _ops.fast_bottleneck_backward_grad_out1( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out2 + ) + + +def backward_grad_out1_mask(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, thresholdTop, thresholdBottom): + return _ops.fast_bottleneck_backward_grad_out1_mask( + bool(explicit_nhwc), + 
scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + thresholdTop, + thresholdBottom, + ) + + +def backward_grad_out1_halo(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2_halo, relu1_halo): + return _ops.fast_bottleneck_backward_grad_out1_halo( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2_halo, + relu1_halo, + ) + + +def backward_grad_out1_halo_corr( + explicit_nhwc, stride_1x1, inputs, w1by3, outputs, grad_out2_halo, relu1_halo, part_grad_out1 +): + return _ops.fast_bottleneck_backward_grad_out1_halo_corr( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + w1by3, + _tensor_list(outputs), + grad_out2_halo, + relu1_halo, + part_grad_out1, + ) + + +def backward_wgrad2_pad(explicit_nhwc, stride_1x1, inputs, outputs, input, grad_out2): + return _ops.fast_bottleneck_backward_wgrad2_pad( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), input, grad_out2 + ) + + +def backward_wgrad2(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): + return _ops.fast_bottleneck_backward_wgrad2( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out2 + ) + + +def backward_wgrad2_halo(explicit_nhwc, stride_1x1, inputs, outputs, input, grad_out2_halo): + return _ops.fast_bottleneck_backward_wgrad2_halo( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + input, + grad_out2_halo, + ) + + +def backward_wgrad3(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_backward_wgrad3( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_wgrad1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out1): + return _ops.fast_bottleneck_backward_wgrad1( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out1 + ) + + 
+def backward_rest(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, grad_out1): + return _ops.fast_bottleneck_backward_rest( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + grad_out1, + ) diff --git a/apex/_extensions/fast_layer_norm.py b/apex/_extensions/fast_layer_norm.py new file mode 100644 index 000000000..00f84c3c2 --- /dev/null +++ b/apex/_extensions/fast_layer_norm.py @@ -0,0 +1,28 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_fast_layer_norm", __file__) +_ops = torch.ops.apex + + +def ln_fwd(x, gamma, beta, epsilon): + return _ops.fast_layer_norm_ln_fwd( + x, + gamma, + beta, + scalar_float(epsilon), + ) + + +def ln_bwd(dz, x_or_z, mu, rsigma, gamma, beta, memory_efficient): + return _ops.fast_layer_norm_ln_bwd( + dz, + x_or_z, + mu, + rsigma, + gamma, + beta, + memory_efficient, + ) diff --git a/apex/_extensions/fast_multihead_attn.py b/apex/_extensions/fast_multihead_attn.py new file mode 100644 index 000000000..bf863accb --- /dev/null +++ b/apex/_extensions/fast_multihead_attn.py @@ -0,0 +1,411 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_fast_multihead_attn", __file__) +_ops = torch.ops.apex + + +def additive_mask_softmax_dropout_forward(use_mask, is_training, heads, input, pad_mask, dropout_prob): + return _ops.fast_multihead_attn_additive_mask_softmax_dropout_forward( + bool(use_mask), bool(is_training), scalar_int(heads), input, pad_mask, scalar_float(dropout_prob) + ) + + +def additive_mask_softmax_dropout_backward(use_mask, heads, output_grads, softmax_results, dropout_mask, dropout_prob): + return _ops.fast_multihead_attn_additive_mask_softmax_dropout_backward( + bool(use_mask), scalar_int(heads), output_grads, softmax_results, dropout_mask, scalar_float(dropout_prob) + ) + + +def mask_softmax_dropout_forward(use_mask, 
is_training, heads, input, pad_mask, dropout_prob): + return _ops.fast_multihead_attn_mask_softmax_dropout_forward( + bool(use_mask), bool(is_training), scalar_int(heads), input, pad_mask, scalar_float(dropout_prob) + ) + + +def mask_softmax_dropout_backward(use_mask, heads, output_grads, softmax_results, dropout_mask, padding_mask, dropout_prob): + return _ops.fast_multihead_attn_mask_softmax_dropout_backward( + bool(use_mask), + scalar_int(heads), + output_grads, + softmax_results, + dropout_mask, + padding_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_norm_add_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + dropout_prob, +): + return 
_ops.fast_multihead_attn_encdec_multihead_attn_norm_add_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_norm_add_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_add_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_norm_add_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_add_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + 
input_weights, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_additive_mask_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_additive_mask_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_additive_mask_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + bmm1_results, + pad_mask, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_additive_mask_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + bmm1_results, + pad_mask, + input_lin_results, + inputs, + input_weights, + 
output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_norm_add_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_norm_add_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_norm_add_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + dropout_mask, + dropout_add_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_norm_add_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + dropout_mask, + dropout_add_mask, + scalar_float(dropout_prob), + ) diff --git a/apex/_extensions/fmhalib.py b/apex/_extensions/fmhalib.py new file mode 100644 index 000000000..c4267445b --- /dev/null +++ b/apex/_extensions/fmhalib.py @@ -0,0 +1,48 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_fmhalib", __file__) +_ops = torch.ops.apex + + +def fwd(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, is_nl, zero_tensors, gen): + return _ops.fmha_fwd( + qkv, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(is_training), + bool(is_nl), + bool(zero_tensors), + gen, + ) + + +def fwd_nl(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, is_nl, 
zero_tensors, gen): + return fwd(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, True, zero_tensors, gen) + + +def bwd(dout, qkv, softmax, cu_seqlens, p_dropout, max_seq_len, zero_tensors): + return _ops.fmha_bwd( + dout, + qkv, + softmax, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(zero_tensors), + ) + + +def bwd_nl(dout, qkv, softmax, cu_seqlens, p_dropout, max_seq_len, zero_tensors): + return _ops.fmha_bwd_nl( + dout, + qkv, + softmax, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(zero_tensors), + ) diff --git a/apex/_extensions/focal_loss_cuda.py b/apex/_extensions/focal_loss_cuda.py new file mode 100644 index 000000000..edf9924a0 --- /dev/null +++ b/apex/_extensions/focal_loss_cuda.py @@ -0,0 +1,21 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_focal_loss_cuda", __file__) + + +def forward(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, smoothing_factor): + return torch.ops.apex.focal_loss_forward( + cls_output, + cls_targets_at_level, + num_positives_sum, + num_real_classes, + scalar_float(alpha), + scalar_float(gamma), + scalar_float(smoothing_factor), + ) + + +backward = torch.ops.apex.focal_loss_backward diff --git a/apex/_extensions/fused_adam_cuda.py b/apex/_extensions/fused_adam_cuda.py new file mode 100644 index 000000000..36c4d2d45 --- /dev/null +++ b/apex/_extensions/fused_adam_cuda.py @@ -0,0 +1,154 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_fused_adam_cuda", __file__) + + +def strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first): + return torch.ops.apex.fused_adam_strided_check_finite( + overflow_flag, p_copy, scalar_int(stride), scalar_int(clear_overflow_first) + ) + + +def adam( + p, + p_copy, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + 
step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_adam( + p, + p_copy, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def reversible_adam( + p, + p_copy, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_reversible_adam( + p, + p_copy, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def adam_mt( + chunk_size, + overflow_flag, + tensor_lists, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_adam_mt( + scalar_int(chunk_size), + overflow_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def maybe_adam_undo( + overflow_flag, + p, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_maybe_adam_undo( + overflow_flag, + p, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def maybe_cast(overflow_flag, p_in, p_out): + return torch.ops.apex.fused_adam_maybe_cast(overflow_flag, p_in, p_out) + + +def maybe_cast_mt(chunk_size, overflow_flag, tensor_lists): + return torch.ops.apex.fused_adam_maybe_cast_mt( + scalar_int(chunk_size), 
overflow_flag, tensor_list_arg(tensor_lists) + ) diff --git a/apex/_extensions/fused_conv_bias_relu.py b/apex/_extensions/fused_conv_bias_relu.py new file mode 100644 index 000000000..64d58bbf6 --- /dev/null +++ b/apex/_extensions/fused_conv_bias_relu.py @@ -0,0 +1,67 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_fused_conv_bias_relu", __file__) +_ops = torch.ops.apex + + +def _tensor_list(inputs): + return list(inputs) + + +def forward(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_no_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_no_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward_no_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward_no_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_mask(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_mask( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_cscale_cbias_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_cscale_cbias_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward_cscale_cbias_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward_cscale_cbias_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) diff --git a/apex/_extensions/fused_dense_cuda.py b/apex/_extensions/fused_dense_cuda.py new file mode 100644 index 000000000..c189bf8c9 --- /dev/null +++ b/apex/_extensions/fused_dense_cuda.py @@ -0,0 +1,11 @@ +import torch + +from apex._custom_ops import 
load_custom_op_library + + +load_custom_op_library("_fused_dense_cuda", __file__) + +linear_bias_forward = torch.ops.apex.fused_dense_linear_bias_forward +linear_bias_backward = torch.ops.apex.fused_dense_linear_bias_backward +linear_gelu_linear_forward = torch.ops.apex.fused_dense_linear_gelu_linear_forward +linear_gelu_linear_backward = torch.ops.apex.fused_dense_linear_gelu_linear_backward diff --git a/apex/_extensions/fused_index_mul_2d.py b/apex/_extensions/fused_index_mul_2d.py new file mode 100644 index 000000000..6c475b506 --- /dev/null +++ b/apex/_extensions/fused_index_mul_2d.py @@ -0,0 +1,13 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_index_mul_2d", __file__) + +float_forward = torch.ops.apex.index_mul_2d_float_forward +float_backward = torch.ops.apex.index_mul_2d_float_backward +float_backward_backward = torch.ops.apex.index_mul_2d_float_backward_backward +half_forward = torch.ops.apex.index_mul_2d_half_forward +half_backward = torch.ops.apex.index_mul_2d_half_backward +half_backward_backward = torch.ops.apex.index_mul_2d_half_backward_backward diff --git a/apex/_extensions/fused_lamb_cuda.py b/apex/_extensions/fused_lamb_cuda.py new file mode 100644 index 000000000..2f3eeb843 --- /dev/null +++ b/apex/_extensions/fused_lamb_cuda.py @@ -0,0 +1,45 @@ +import torch + +from apex._custom_ops import ( + load_custom_op_library, + scalar_float, + scalar_int, + tensor_list_arg, +) + + +load_custom_op_library("_fused_lamb_cuda", __file__) + + +def lamb( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, +): + return torch.ops.apex.fused_lamb_lamb( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + scalar_int(step), + scalar_int(bias_correction), + 
scalar_float(weight_decay), + scalar_int(grad_averaging), + scalar_int(mode), + scalar_float(global_grad_norm), + scalar_float(max_grad_norm), + ) diff --git a/apex/_extensions/fused_layer_norm_cuda.py b/apex/_extensions/fused_layer_norm_cuda.py new file mode 100644 index 000000000..1cc17eb23 --- /dev/null +++ b/apex/_extensions/fused_layer_norm_cuda.py @@ -0,0 +1,17 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_layer_norm_cuda", __file__) + +forward_affine = torch.ops.apex.fused_layer_norm_forward_affine +forward = torch.ops.apex.fused_layer_norm_forward +backward_affine = torch.ops.apex.fused_layer_norm_backward_affine +backward = torch.ops.apex.fused_layer_norm_backward +forward_affine_mixed_dtypes = torch.ops.apex.fused_layer_norm_forward_affine_mixed_dtypes +rms_forward_affine = torch.ops.apex.fused_layer_norm_rms_forward_affine +rms_forward = torch.ops.apex.fused_layer_norm_rms_forward +rms_backward_affine = torch.ops.apex.fused_layer_norm_rms_backward_affine +rms_backward = torch.ops.apex.fused_layer_norm_rms_backward +rms_forward_affine_mixed_dtypes = torch.ops.apex.fused_layer_norm_rms_forward_affine_mixed_dtypes diff --git a/apex/_extensions/fused_rotary_positional_embedding.py b/apex/_extensions/fused_rotary_positional_embedding.py new file mode 100644 index 000000000..ca27b6d3e --- /dev/null +++ b/apex/_extensions/fused_rotary_positional_embedding.py @@ -0,0 +1,15 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_rotary_positional_embedding", __file__) + +forward = torch.ops.apex.fused_rope_forward +backward = torch.ops.apex.fused_rope_backward +forward_cached = torch.ops.apex.fused_rope_forward_cached +backward_cached = torch.ops.apex.fused_rope_backward_cached +forward_thd = torch.ops.apex.fused_rope_forward_thd +backward_thd = torch.ops.apex.fused_rope_backward_thd +forward_2d = torch.ops.apex.fused_rope_forward_2d +backward_2d = 
torch.ops.apex.fused_rope_backward_2d diff --git a/apex/_extensions/fused_weight_gradient_mlp_cuda.py b/apex/_extensions/fused_weight_gradient_mlp_cuda.py new file mode 100644 index 000000000..fc9c78c43 --- /dev/null +++ b/apex/_extensions/fused_weight_gradient_mlp_cuda.py @@ -0,0 +1,9 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_weight_gradient_mlp_cuda", __file__) + +wgrad_gemm_accum_fp32 = torch.ops.apex.fused_weight_gradient_mlp_wgrad_gemm_accum_fp32 +wgrad_gemm_accum_fp16 = torch.ops.apex.fused_weight_gradient_mlp_wgrad_gemm_accum_fp16 diff --git a/apex/_extensions/generic_scaled_masked_softmax_cuda.py b/apex/_extensions/generic_scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..76c212d48 --- /dev/null +++ b/apex/_extensions/generic_scaled_masked_softmax_cuda.py @@ -0,0 +1,14 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_generic_scaled_masked_softmax_cuda", __file__) + + +def forward(input, mask, scale_factor): + return torch.ops.apex.generic_scaled_masked_softmax_forward(input, mask, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.generic_scaled_masked_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) diff --git a/apex/_extensions/group_norm_cuda.py b/apex/_extensions/group_norm_cuda.py new file mode 100644 index 000000000..8e97f1745 --- /dev/null +++ b/apex/_extensions/group_norm_cuda.py @@ -0,0 +1,32 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_group_norm_cuda", __file__) + + +def forward(input, groups, weight, bias, eps, passes, with_swish=False): + return torch.ops.apex.group_norm_forward( + input, + scalar_int(groups), + weight, + bias, + scalar_float(eps), + scalar_int(passes), + bool(with_swish), + ) + + +def backward(grad_output, 
sums, input, groups, weight, bias, eps, passes, with_swish=False): + return torch.ops.apex.group_norm_backward( + grad_output, + sums, + input, + scalar_int(groups), + weight, + bias, + scalar_float(eps), + scalar_int(passes), + bool(with_swish), + ) diff --git a/apex/_extensions/group_norm_v2_cuda.py b/apex/_extensions/group_norm_v2_cuda.py new file mode 100644 index 000000000..f6744ff22 --- /dev/null +++ b/apex/_extensions/group_norm_v2_cuda.py @@ -0,0 +1,33 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_group_norm_v2_cuda", __file__) + + +def gn(x, w, b, eps, silu, num_groups, mean_var_out=None, sm_margin=0): + return torch.ops.apex.group_norm_v2_gn( + x, + w, + b, + scalar_float(eps), + bool(silu), + scalar_int(num_groups), + mean_var_out, + scalar_int(sm_margin), + ) + + +def gn_bwd(grad_output, x, w, b, mean_var, eps, silu, num_groups, sm_margin=0): + return torch.ops.apex.group_norm_v2_gn_bwd( + grad_output, + x, + w, + b, + mean_var, + scalar_float(eps), + bool(silu), + scalar_int(num_groups), + scalar_int(sm_margin), + ) diff --git a/apex/_extensions/mlp_cuda.py b/apex/_extensions/mlp_cuda.py new file mode 100644 index 000000000..75a228c1b --- /dev/null +++ b/apex/_extensions/mlp_cuda.py @@ -0,0 +1,14 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_mlp_cuda", __file__) + + +def forward(use_bias, activation, inputs): + return torch.ops.apex.mlp_forward(use_bias, activation, list(inputs)) + + +def backward(use_bias, activation, grad_o, fprop_outputs, inputs): + return torch.ops.apex.mlp_backward(use_bias, activation, grad_o, list(fprop_outputs), list(inputs)) diff --git a/apex/_extensions/nccl_p2p_cuda.py b/apex/_extensions/nccl_p2p_cuda.py new file mode 100644 index 000000000..c0b44c6e7 --- /dev/null +++ b/apex/_extensions/nccl_p2p_cuda.py @@ -0,0 +1,53 @@ +import torch + +from apex._custom_ops import load_custom_op_library, 
scalar_int + + +load_custom_op_library("_nccl_p2p_cuda", __file__) +_ops = torch.ops.apex + + +def get_unique_nccl_id(n): + return _ops.nccl_p2p_get_unique_nccl_id(scalar_int(n)) + + +def init_nccl_comm(unique_nccl_id, my_rank, num_ranks): + return _ops.nccl_p2p_init_nccl_comm( + unique_nccl_id, + scalar_int(my_rank), + scalar_int(num_ranks), + ) + + +def left_right_halo_exchange_inplace( + handle, + left_rank, + right_rank, + left_output_halo, + right_output_halo, + left_input_halo, + right_input_halo, +): + return _ops.nccl_p2p_left_right_halo_exchange_inplace( + scalar_int(handle), + scalar_int(left_rank), + scalar_int(right_rank), + left_output_halo, + right_output_halo, + left_input_halo, + right_input_halo, + ) + + +def left_right_halo_exchange(handle, left_rank, right_rank, left_output_halo, right_output_halo): + return _ops.nccl_p2p_left_right_halo_exchange( + scalar_int(handle), + scalar_int(left_rank), + scalar_int(right_rank), + left_output_halo, + right_output_halo, + ) + + +def add_delay(delay): + return _ops.nccl_p2p_add_delay(scalar_int(delay)) diff --git a/apex/_extensions/peer_memory_cuda.py b/apex/_extensions/peer_memory_cuda.py new file mode 100644 index 000000000..c99127905 --- /dev/null +++ b/apex/_extensions/peer_memory_cuda.py @@ -0,0 +1,93 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_peer_memory_cuda", __file__) +_ops = torch.ops.apex + + +def _shape_arg(shape): + return [scalar_int(dim) for dim in shape] + + +def allocate_raw(size): + return _ops.peer_memory_allocate_raw(scalar_int(size)) + + +def free_raw(raw): + return _ops.peer_memory_free_raw(scalar_int(raw)) + + +def zero(raw, size): + return _ops.peer_memory_zero(scalar_int(raw), scalar_int(size)) + + +def get_raw_ipc_address(raw): + return _ops.peer_memory_get_raw_ipc_address(scalar_int(raw)) + + +def get_raw_peers(ipc_addresses, peer_rank, raw): + return _ops.peer_memory_get_raw_peers( + ipc_addresses, + 
scalar_int(peer_rank), + scalar_int(raw), + ) + + +def blob_view_half(raw, shape, channels_last): + return _ops.peer_memory_blob_view_half( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def blob_view_float(raw, shape, channels_last): + return _ops.peer_memory_blob_view_float( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def blob_view_int(raw, shape, channels_last): + return _ops.peer_memory_blob_view_int( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def push_pull_halos_1d( + diagnostics, + explicit_nhwc, + numSM, + rank, + top_zero, + top_in_halo, + top_in_transfer, + top_out_transfer, + top_out_halo, + btm_zero, + btm_in_halo, + btm_in_transfer, + btm_out_transfer, + btm_out_halo, +): + return _ops.peer_memory_push_pull_halos_1d( + bool(diagnostics), + bool(explicit_nhwc), + scalar_int(numSM), + scalar_int(rank), + bool(top_zero), + top_in_halo, + top_in_transfer, + top_out_transfer, + top_out_halo, + bool(btm_zero), + btm_in_halo, + btm_in_transfer, + btm_out_transfer, + btm_out_halo, + ) diff --git a/apex/_extensions/permutation_search_cuda.py b/apex/_extensions/permutation_search_cuda.py new file mode 100644 index 000000000..aa3c93dcf --- /dev/null +++ b/apex/_extensions/permutation_search_cuda.py @@ -0,0 +1,76 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_permutation_search_cuda", __file__) + + +def _as_tensor(value): + if isinstance(value, torch.Tensor): + return value + return torch.from_numpy(value) + + +def _op(op_name): + return getattr(torch.ops.apex, op_name) + + +def sum_after_2_to_4(matrix, rows, cols, start_col, end_col, blocks, threads, output): + return _op("permutation_search_sum_after_2_to_4")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + scalar_int(start_col), + scalar_int(end_col), + scalar_int(blocks), + scalar_int(threads), + _as_tensor(output), + ) + + +def 
build_permute_map(matrix, rows, cols, stripes, num_groups, group_width, permutations, perm_length, improvements, best_indices): + return _op("permutation_search_build_permute_map")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripes), + scalar_int(num_groups), + scalar_int(group_width), + _as_tensor(permutations), + scalar_int(perm_length), + _as_tensor(improvements), + _as_tensor(best_indices), + ) + + +def check_permutations( + matrix, + rows, + cols, + stripe_groups, + group_width, + num_groups, + permutations, + num_permutations, + improvement, + permutation, +): + return _op("permutation_search_check_permutations")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripe_groups), + scalar_int(group_width), + scalar_int(num_groups), + _as_tensor(permutations), + scalar_int(num_permutations), + _as_tensor(improvement), + _as_tensor(permutation), + ) + + +def build_swap_map(matrix, rows, cols, stripe_pairs, output): + return _op("permutation_search_build_swap_map")( + _as_tensor(matrix), scalar_int(rows), scalar_int(cols), _as_tensor(stripe_pairs), _as_tensor(output) + ) diff --git a/apex/_extensions/scaled_masked_softmax_cuda.py b/apex/_extensions/scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..f6de1cca3 --- /dev/null +++ b/apex/_extensions/scaled_masked_softmax_cuda.py @@ -0,0 +1,20 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_scaled_masked_softmax_cuda", __file__) + + +def forward(input, mask, scale_factor): + return torch.ops.apex.scaled_masked_softmax_forward(input, mask, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_masked_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) + + +def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads): + return 
torch.ops.apex.scaled_masked_softmax_get_batch_per_block( + scalar_int(query_seq_len), scalar_int(key_seq_len), scalar_int(batches), scalar_int(attn_heads) + ) diff --git a/apex/_extensions/scaled_softmax_cuda.py b/apex/_extensions/scaled_softmax_cuda.py new file mode 100644 index 000000000..f9bc15596 --- /dev/null +++ b/apex/_extensions/scaled_softmax_cuda.py @@ -0,0 +1,14 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_scaled_softmax_cuda", __file__) + + +def forward(input, scale_factor): + return torch.ops.apex.scaled_softmax_forward(input, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) diff --git a/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py new file mode 100644 index 000000000..9f23bb41f --- /dev/null +++ b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py @@ -0,0 +1,16 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_scaled_upper_triang_masked_softmax_cuda", __file__) + + +def forward(input, scale_factor): + return torch.ops.apex.scaled_upper_triang_masked_softmax_forward(input, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_upper_triang_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/syncbn.py b/apex/_extensions/syncbn.py new file mode 100644 index 000000000..6dc9d8bf2 --- /dev/null +++ b/apex/_extensions/syncbn.py @@ -0,0 +1,17 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_syncbn", __file__) + +welford_mean_var = torch.ops.apex.syncbn_welford_mean_var +welford_parallel = 
torch.ops.apex.syncbn_welford_parallel
+batchnorm_forward = torch.ops.apex.syncbn_batchnorm_forward  # each alias here binds the public name directly to the dispatcher op; no argument coercion is applied
+reduce_bn = torch.ops.apex.syncbn_reduce_bn
+batchnorm_backward = torch.ops.apex.syncbn_batchnorm_backward
+welford_mean_var_c_last = torch.ops.apex.syncbn_welford_mean_var_c_last
+batchnorm_forward_c_last = torch.ops.apex.syncbn_batchnorm_forward_c_last
+reduce_bn_c_last = torch.ops.apex.syncbn_reduce_bn_c_last
+batchnorm_backward_c_last = torch.ops.apex.syncbn_batchnorm_backward_c_last
+relu_bw_c_last = torch.ops.apex.syncbn_relu_bw_c_last
diff --git a/apex/_extensions/transducer_joint_cuda.py b/apex/_extensions/transducer_joint_cuda.py
new file mode 100644
index 000000000..3533e2a5b
--- /dev/null
+++ b/apex/_extensions/transducer_joint_cuda.py
@@ -0,0 +1,50 @@
+import torch
+
+from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int
+
+
+load_custom_op_library("_transducer_joint_cuda", __file__)  # runs at import time, before any op below is looked up
+_ops = torch.ops.apex  # namespace through which the shims below reach the transducer_joint_* ops
+
+
+def forward(
+    f,
+    g,
+    f_len,
+    g_len,
+    batch_offset,
+    packed_batch,
+    opt,
+    pack_output,
+    relu,
+    dropout,
+    dropout_prob,
+    tile_size,
+):  # shim over the transducer_joint_forward op; returns the op's result unchanged
+    return _ops.transducer_joint_forward(
+        f,
+        g,
+        f_len,
+        g_len,
+        batch_offset,
+        scalar_int(packed_batch),  # integer arguments are normalized by the scalar_int helper
+        scalar_int(opt),
+        pack_output,  # NOTE(review): pack_output, relu and dropout are forwarded without bool(), unlike flag arguments in sibling shims - confirm the op schema accepts the callers' types
+        relu,
+        dropout,
+        scalar_float(dropout_prob),  # float argument is normalized by the scalar_float helper
+        scalar_int(tile_size),
+    )
+
+
+def backward(input, f_len, g_len, batch_offset, max_f_len, max_g_len, pack_output, scale):  # shim over the transducer_joint_backward op
+    return _ops.transducer_joint_backward(
+        list(input),  # materializes the iterable into a fresh list before dispatch
+        f_len,
+        g_len,
+        batch_offset,
+        scalar_int(max_f_len),
+        scalar_int(max_g_len),
+        pack_output,  # forwarded as received (no bool() coercion)
+        scalar_float(scale),
+    )
diff --git a/apex/_extensions/transducer_loss_cuda.py b/apex/_extensions/transducer_loss_cuda.py
new file mode 100644
index 000000000..72e6fe09c
--- /dev/null
+++ b/apex/_extensions/transducer_loss_cuda.py
@@ -0,0 +1,53 @@
+import torch
+
+from apex._custom_ops import load_custom_op_library, scalar_int
+
+
+load_custom_op_library("_transducer_loss_cuda", __file__)  # runs at import time, before any op below is looked up
+_ops = torch.ops.apex  # namespace through which this module's shims reach the transducer_loss_* ops
+ + +def forward(x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, opt, packed_input): + return _ops.transducer_loss_forward( + x, + label, + f_len, + y_len, + batch_offset, + scalar_int(max_f_len), + scalar_int(blank_idx), + scalar_int(opt), + packed_input, + ) + + +def backward( + x, + loss_grad, + alpha, + beta, + f_len, + y_len, + label, + batch_offset, + max_f_len, + blank_idx, + opt, + fuse_softmax_backward, + packed_input, +): + return _ops.transducer_loss_backward( + x, + loss_grad, + alpha, + beta, + f_len, + y_len, + label, + batch_offset, + scalar_int(max_f_len), + scalar_int(blank_idx), + scalar_int(opt), + fuse_softmax_backward, + packed_input, + ) diff --git a/apex/_extensions/xentropy_cuda.py b/apex/_extensions/xentropy_cuda.py new file mode 100644 index 000000000..969fad21a --- /dev/null +++ b/apex/_extensions/xentropy_cuda.py @@ -0,0 +1,16 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_xentropy_cuda", __file__) + +__version__ = torch.ops.apex.xentropy_version() + + +def forward(input, labels, smoothing, half_to_float): + return torch.ops.apex.xentropy_forward(input, labels, scalar_float(smoothing), half_to_float) + + +def backward(grad_loss, logits, max_log_sum_exp, labels, smoothing): + return torch.ops.apex.xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, scalar_float(smoothing)) diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py index e24a9251b..a15557d41 100644 --- a/apex/contrib/bottleneck/bottleneck.py +++ b/apex/contrib/bottleneck/bottleneck.py @@ -4,8 +4,8 @@ from torch import nn from apex import check_cudnn_version_and_warn -import fast_bottleneck -import nccl_p2p_cuda as inc +from apex._extensions import fast_bottleneck +from apex._extensions import nccl_p2p_cuda as inc assert check_cudnn_version_and_warn(__name__, 8400) diff --git a/apex/contrib/bottleneck/halo_exchangers.py 
b/apex/contrib/bottleneck/halo_exchangers.py index eb0224c5b..8d261bdc4 100644 --- a/apex/contrib/bottleneck/halo_exchangers.py +++ b/apex/contrib/bottleneck/halo_exchangers.py @@ -1,6 +1,6 @@ import torch -import nccl_p2p_cuda as inc -import peer_memory_cuda as pm +from apex._extensions import nccl_p2p_cuda as inc +from apex._extensions import peer_memory_cuda as pm # Communication free halo exchanger. diff --git a/apex/contrib/clip_grad/clip_grad.py b/apex/contrib/clip_grad/clip_grad.py index b34dc43bd..2290a8703 100644 --- a/apex/contrib/clip_grad/clip_grad.py +++ b/apex/contrib/clip_grad/clip_grad.py @@ -4,7 +4,7 @@ _kernel_import_succeeded = False try: - import amp_C + from apex._extensions import amp_C from apex.multi_tensor_apply import multi_tensor_applier _kernel_import_succeeded = True @@ -74,6 +74,17 @@ def clip_grad_norm_( else: grads_misc.append(grad) + if grads_fp16: + # Preserve PyTorch's current fp16 clipping semantics exactly. The fused + # norm path accumulates and returns fp32, which can perturb the fp16 + # scale enough to change rounded gradient values. 
+ return torch.nn.utils.clip_grad_norm_( + parameters, + max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, + ) + # Compute gradient L2 norms norms = [] dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device) diff --git a/apex/contrib/conv_bias_relu/conv_bias_relu.py b/apex/contrib/conv_bias_relu/conv_bias_relu.py index 533d6421f..faabdbff3 100644 --- a/apex/contrib/conv_bias_relu/conv_bias_relu.py +++ b/apex/contrib/conv_bias_relu/conv_bias_relu.py @@ -1,7 +1,7 @@ import torch from apex import check_cudnn_version_and_warn -import fused_conv_bias_relu +from apex._extensions import fused_conv_bias_relu check_cudnn_version_and_warn(__name__, 8400) diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp index 3c98a39dd..a521ca090 100644 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -1,8 +1,7 @@ #include #include // for getcudnnhandle #include -#include -#include +#include #include #include @@ -438,7 +437,7 @@ void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -583,7 +582,7 @@ void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); 
std::stringstream log_buf; try { int convDim = 2; @@ -696,7 +695,7 @@ void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstrid void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -810,7 +809,7 @@ void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convst void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -905,7 +904,7 @@ void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1821,7 +1820,7 @@ void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ 
= at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1966,7 +1965,7 @@ void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -2236,7 +2235,7 @@ void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* co int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -2367,7 +2366,7 @@ void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* c int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -3554,43 +3553,202 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector()); - m.def("backward", &bottleneck_backward, "Bottleneck block backward", py::call_guard()); - m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init", py::call_guard()); - m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward", py::call_guard()); - 
m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward", - py::call_guard()); - m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward", py::call_guard()); - m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init", - py::call_guard()); - m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward", - py::call_guard()); - m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward", - py::call_guard()); +namespace { +int as_int(int64_t value) { return static_cast(value); } + +std::vector 
apex_fast_bottleneck_forward(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_forward(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +std::vector apex_fast_bottleneck_backward(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_backward(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +std::vector apex_fast_bottleneck_forward_init(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_forward_init(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +void apex_fast_bottleneck_forward_out1(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_out1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_forward_out2(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor thresholdTop, at::Tensor thresholdBottom) { + bottleneck_forward_out2_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), thresholdTop, + thresholdBottom); +} + +at::Tensor apex_fast_bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, + std::vector inputs) { + return bottleneck_forward_out2_halo(explicit_nhwc, fat_halo_y1, std::move(inputs)); +} + +at::Tensor apex_fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, + std::vector inputs, at::Tensor w1by3, + at::Tensor out2_part_halo) { + return bottleneck_forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, std::move(inputs), w1by3, out2_part_halo); +} + +void apex_fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector 
outputs, + at::Tensor out1_pad) { + bottleneck_forward_out2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), out1_pad); +} + +void apex_fast_bottleneck_forward_rest(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_rest(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +std::vector apex_fast_bottleneck_backward_init(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_backward_init(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, + std::vector outputs) { + return bottleneck_backward_grad_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, + std::vector outputs, at::Tensor grad_out2) { + return bottleneck_backward_grad_out1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, + std::vector outputs, at::Tensor grad_out2, + at::Tensor thresholdTop, at::Tensor thresholdBottom) { + return bottleneck_backward_grad_out1_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2, thresholdTop, thresholdBottom); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1_halo_corr( + bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, at::Tensor w1by3, + std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) { + return bottleneck_backward_grad_out1_halo_corr(explicit_nhwc, as_int(stride_1X1), std::move(inputs), w1by3, + std::move(outputs), grad_out2_halo, relu1_halo, part_grad_out1); +} + +at::Tensor 
apex_fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, + std::vector outputs, at::Tensor grad_out2_halo, + at::Tensor relu1_halo) { + return bottleneck_backward_grad_out1_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2_halo, relu1_halo); +} + +void apex_fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor input, at::Tensor grad_out2) { + bottleneck_backward_wgrad2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), input, + grad_out2); +} + +void apex_fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2) { + bottleneck_backward_wgrad2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2); +} + +at::Tensor apex_fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, + std::vector outputs, at::Tensor input, + at::Tensor grad_out2_halo) { + return bottleneck_backward_wgrad2_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + input, grad_out2_halo); +} + +void apex_fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_backward_wgrad3(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out1) { + bottleneck_backward_wgrad1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out1); +} + +void apex_fast_bottleneck_backward_rest(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2, + at::Tensor grad_out1) { + bottleneck_backward_rest(explicit_nhwc, as_int(stride_1X1), 
std::move(inputs), std::move(outputs), grad_out2, + grad_out1); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fast_bottleneck_forward(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_backward(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def("fast_bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def("fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor thresholdTop, Tensor thresholdBottom) -> ()"); + m.def("fast_bottleneck_forward_out2_halo(bool explicit_nhwc, Tensor fat_halo_y1, Tensor[] inputs) -> Tensor"); + m.def("fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, Tensor slim_halo_y1, Tensor[] inputs, " + "Tensor w1by3, Tensor out2_part_halo) -> Tensor"); + m.def("fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor out1_pad) -> ()"); + m.def("fast_bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def("fast_bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> Tensor"); + m.def("fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> Tensor"); + m.def("fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2, Tensor thresholdTop, Tensor thresholdBottom) -> Tensor"); + 
m.def("fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo) -> Tensor"); + m.def("fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor w1by3, Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo, Tensor part_grad_out1) -> Tensor"); + m.def("fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2) -> ()"); + m.def("fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> ()"); + m.def("fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2_halo) -> Tensor"); + m.def("fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> ()"); + m.def("fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out1) -> ()"); + m.def("fast_bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2, Tensor grad_out1) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_bottleneck_forward", &apex_fast_bottleneck_forward); + m.impl("fast_bottleneck_backward", &apex_fast_bottleneck_backward); + m.impl("fast_bottleneck_forward_init", &apex_fast_bottleneck_forward_init); + m.impl("fast_bottleneck_forward_out1", &apex_fast_bottleneck_forward_out1); + m.impl("fast_bottleneck_forward_out2", &apex_fast_bottleneck_forward_out2); + m.impl("fast_bottleneck_forward_out2_mask", &apex_fast_bottleneck_forward_out2_mask); + m.impl("fast_bottleneck_forward_out2_halo", &apex_fast_bottleneck_forward_out2_halo); + m.impl("fast_bottleneck_forward_out2_halo_corr", &apex_fast_bottleneck_forward_out2_halo_corr); + 
m.impl("fast_bottleneck_forward_out2_pad", &apex_fast_bottleneck_forward_out2_pad); + m.impl("fast_bottleneck_forward_rest", &apex_fast_bottleneck_forward_rest); + m.impl("fast_bottleneck_backward_init", &apex_fast_bottleneck_backward_init); + m.impl("fast_bottleneck_backward_grad_out2", &apex_fast_bottleneck_backward_grad_out2); + m.impl("fast_bottleneck_backward_grad_out1", &apex_fast_bottleneck_backward_grad_out1); + m.impl("fast_bottleneck_backward_grad_out1_mask", &apex_fast_bottleneck_backward_grad_out1_mask); + m.impl("fast_bottleneck_backward_grad_out1_halo", &apex_fast_bottleneck_backward_grad_out1_halo); + m.impl("fast_bottleneck_backward_grad_out1_halo_corr", &apex_fast_bottleneck_backward_grad_out1_halo_corr); + m.impl("fast_bottleneck_backward_wgrad2_pad", &apex_fast_bottleneck_backward_wgrad2_pad); + m.impl("fast_bottleneck_backward_wgrad2", &apex_fast_bottleneck_backward_wgrad2); + m.impl("fast_bottleneck_backward_wgrad2_halo", &apex_fast_bottleneck_backward_wgrad2_halo); + m.impl("fast_bottleneck_backward_wgrad3", &apex_fast_bottleneck_backward_wgrad3); + m.impl("fast_bottleneck_backward_wgrad1", &apex_fast_bottleneck_backward_wgrad1); + m.impl("fast_bottleneck_backward_rest", &apex_fast_bottleneck_backward_rest); } diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp index 9f6924a65..90347d18d 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp @@ -1,8 +1,7 @@ #include #include // for getcudnnhandle #include -#include -#include +#include #include #include @@ -193,7 +192,7 @@ cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::strin void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = 
torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -331,7 +330,7 @@ void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, int8_t* devPtrM, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -530,7 +529,7 @@ void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrS, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -732,7 +731,7 @@ void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -896,7 +895,7 @@ void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrS, at::Half* devPtrDX) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = 
at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1032,7 +1031,7 @@ void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPt void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrDR, float* devPtrDB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1159,7 +1158,7 @@ void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtr void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrR, at::Half* devPtrRg, float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1323,7 +1322,7 @@ void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64 void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1427,7 +1426,7 @@ void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad } void run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPtrX, float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1898,16 +1897,22 @@ std::vector conv_bias_backward(std::vector inputs, int64 return outputs; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", 
&conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward", py::call_guard()); - m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward", - py::call_guard()); - m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward", py::call_guard()); - m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward", py::call_guard()); - m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward", - py::call_guard()); - m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU", - py::call_guard()); - m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, - "Fused Conv-(const)Scale-(const)Bias-ReLU backward", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_conv_bias_relu_forward(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_backward(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_no_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_backward_no_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_mask(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_cscale_cbias_relu(Tensor[] inputs, int padding, int stride) -> Tensor"); + m.def("fused_conv_bias_relu_backward_cscale_cbias_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_conv_bias_relu_forward", &conv_bias_relu_forward); + m.impl("fused_conv_bias_relu_backward", &conv_bias_relu_backward); + m.impl("fused_conv_bias_relu_forward_no_relu", &conv_bias_forward); + m.impl("fused_conv_bias_relu_backward_no_relu", &conv_bias_backward); + m.impl("fused_conv_bias_relu_forward_mask", &conv_bias_mask_relu_forward); + m.impl("fused_conv_bias_relu_forward_cscale_cbias_relu", 
&conv_cscale_cbias_relu_forward); + m.impl("fused_conv_bias_relu_backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward); } diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index 858d429d0..460046c7f 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -1,8 +1,8 @@ #include -#include -#include +#include #include +#include #include #include "norm_sample.h" @@ -117,7 +117,39 @@ std::vector gbn_backward(const at::Tensor& x, const at::Tensor& dy, return std::vector{x_grad, scale_grad, bias_grad}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &gbn_forward, "Group batch norm forward", py::call_guard()); - m.def("backward", &gbn_backward, "Group batch backward", py::call_guard()); +namespace { +std::vector to_int_vector(at::IntArrayRef values) { + return std::vector(values.begin(), values.end()); +} + +at::Tensor apex_cudnn_gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + double momentum, double epsilon, int64_t bn_group, int64_t rank_id, + at::IntArrayRef peer_buffers) { + return gbn_forward(x, scale, bias, running_mean, running_var, minibatch_mean, minibatch_inv_var, + static_cast(momentum), static_cast(epsilon), bn_group, static_cast(rank_id), + to_int_vector(peer_buffers)); +} + +std::vector apex_cudnn_gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, double epsilon, + int64_t bn_group, int64_t rank_id, at::IntArrayRef peer_buffers) { + return gbn_backward(x, dy, scale, minibatch_mean, minibatch_inv_var, static_cast(epsilon), bn_group, + static_cast(rank_id), to_int_vector(peer_buffers)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + 
m.def("cudnn_gbn_forward(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, float momentum, float epsilon, int bn_group, int rank_id, " + "int[] peer_buffers) -> Tensor"); + m.def("cudnn_gbn_backward(Tensor x, Tensor dy, Tensor scale, Tensor minibatch_mean, Tensor minibatch_inv_var, " + "float epsilon, int bn_group, int rank_id, int[] peer_buffers) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("cudnn_gbn_forward", &apex_cudnn_gbn_forward); + m.impl("cudnn_gbn_backward", &apex_cudnn_gbn_backward); } diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp index a0da49493..1a398fefb 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp @@ -22,10 +22,9 @@ #include "norm_sample.h" +#include #include // for getcudnnhandle #include -#include -#include #include "cudnn_backend.h" @@ -74,7 +73,7 @@ void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudn cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get the cudnn handle - cudnnHandle_t handle = torch::native::getCudnnHandle(); + cudnnHandle_t handle = at::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; @@ -225,7 +224,7 @@ void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPt const std::vector& peer_devPtrs, double epsilon_val, double exponential_decay_factor, size_t peer_size, int rank_id) { // get handle - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); // get stream cudaStream_t stream; @@ -275,7 +274,7 @@ void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPt cudnn_frontend::ExecutionPlan 
run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get cudnn handle - cudnnHandle_t handle = torch::native::getCudnnHandle(); + cudnnHandle_t handle = at::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; @@ -406,7 +405,7 @@ void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevP const std::vector& peer_devPtrs, void* dxDevPtr, void* dscaledevPtr, void* dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id) { // get handle - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); // get stream cudaStream_t stream; diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp index ed8b3de5f..e7193e663 100644 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ b/apex/contrib/csrc/fmha/fmha_api.cpp @@ -26,7 +26,8 @@ ******************************************************************************/ #include -#include +#include +#include #include "fmha.h" @@ -81,7 +82,6 @@ std::vector mha_fwd( const at::Tensor& cu_seqlens, // b+1 const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors, c10::optional gen_) { - using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); @@ -127,12 +127,12 @@ std::vector mha_fwd( TORCH_CHECK(head_size == 64); auto opts = qkv.options(); - auto ctx = torch::empty({total, num_heads, head_size}, opts); + auto ctx = at::empty({total, num_heads, head_size}, opts); - auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts); + auto s = at::empty({batch_size, num_heads, seq_len, seq_len}, opts); if (zero_tensors) { - 
mha_fill(ctx, cu_seqlens.index({Slice(-1, None)})); + mha_fill(ctx, cu_seqlens.slice(0, -1)); } auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -165,7 +165,6 @@ std::vector mha_bwd( const float p_dropout, // probability to drop const int max_seq_len, // max sequence length to choose the kernel const bool zero_tensors) { - using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); @@ -189,10 +188,10 @@ std::vector mha_bwd( auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(qkv.dtype() == torch::kFloat16); - TORCH_CHECK(dout.dtype() == torch::kFloat16); - TORCH_CHECK(softmax.dtype() == torch::kFloat16); - TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); + TORCH_CHECK(qkv.dtype() == at::kHalf); + TORCH_CHECK(dout.dtype() == at::kHalf); + TORCH_CHECK(softmax.dtype() == at::kHalf); + TORCH_CHECK(cu_seqlens.dtype() == at::kInt); TORCH_CHECK(qkv.is_cuda()); TORCH_CHECK(cu_seqlens.is_cuda()); @@ -213,10 +212,10 @@ std::vector mha_bwd( TORCH_CHECK(batch_size > 0); TORCH_CHECK(head_size == 64); - auto dqkv = torch::empty_like(qkv); + auto dqkv = at::empty_like(qkv); if (zero_tensors) { - mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)})); + mha_fill(dqkv, cu_seqlens.slice(0, -1)); } Fused_multihead_attention_fprop_params params; @@ -274,7 +273,7 @@ std::vector mha_bwd_nl( auto opts = qkv.options(); - auto dqkv = torch::empty_like(qkv); + auto dqkv = at::empty_like(qkv); if (zero_tensors) { dqkv.zero_(); @@ -286,7 +285,7 @@ std::vector mha_bwd_nl( } else if (batch_size == 2) { num_chunks = 3; } - auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); + auto dkv = at::empty({total, num_chunks, 2, num_heads, head_size}, opts); Fused_multihead_attention_fprop_params 
params; @@ -307,10 +306,8 @@ std::vector mha_bwd_nl( // SPLIT-K reduction of num_chunks dK, dV parts - // The equivalent of the following Pytorch code: - // using namespace torch::indexing; - // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); - // torch::sum_out(view_out, dkv, 1); + // The equivalent Python operation would reduce the split-K chunks into the + // K/V gradient slice of dqkv. const int hidden_size = num_heads * head_size; fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, @@ -319,9 +316,40 @@ std::vector mha_bwd_nl( return {dqkv, softmax, dkv}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention for BERT"; - m.def("fwd", &mha_fwd, "Forward pass", py::call_guard()); - m.def("bwd", &mha_bwd, "Backward pass", py::call_guard()); - m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard()); +namespace { +std::vector apex_fmha_fwd(const at::Tensor& qkv, const at::Tensor& cu_seqlens, double p_dropout, + int64_t max_seq_len, bool is_training, bool is_nl, bool zero_tensors, + c10::optional gen) { + return mha_fwd(qkv, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), is_training, is_nl, + zero_tensors, gen); +} + +std::vector apex_fmha_bwd(const at::Tensor& dout, const at::Tensor& qkv, at::Tensor softmax, + const at::Tensor& cu_seqlens, double p_dropout, int64_t max_seq_len, + bool zero_tensors) { + return mha_bwd(dout, qkv, softmax, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), + zero_tensors); +} + +std::vector apex_fmha_bwd_nl(const at::Tensor& dout, const at::Tensor& qkv, at::Tensor softmax, + const at::Tensor& cu_seqlens, double p_dropout, int64_t max_seq_len, + bool zero_tensors) { + return mha_bwd_nl(dout, qkv, softmax, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), + zero_tensors); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fmha_fwd(Tensor qkv, Tensor 
cu_seqlens, float p_dropout, int max_seq_len, bool is_training, bool is_nl, " + "bool zero_tensors, Generator? gen) -> Tensor[]"); + m.def("fmha_bwd(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, int max_seq_len, " + "bool zero_tensors) -> Tensor[]"); + m.def("fmha_bwd_nl(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, " + "int max_seq_len, bool zero_tensors) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fmha_fwd", &apex_fmha_fwd); + m.impl("fmha_bwd", &apex_fmha_bwd); + m.impl("fmha_bwd_nl", &apex_fmha_bwd_nl); } diff --git a/apex/contrib/csrc/fmha/src/fmha_fill.cu b/apex/contrib/csrc/fmha/src/fmha_fill.cu index f2e0d925d..00614ce6f 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fill.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fill.cu @@ -27,7 +27,7 @@ #include #include -#include +#include constexpr int block_size = 512; constexpr int ctas_per_sm = 4; diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index 4e3555c3a..3d79df36d 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -39,9 +40,22 @@ at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)", - py::call_guard()); - m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)", - py::call_guard()); +std::vector focal_loss_forward_dispatch(const at::Tensor& cls_output, + const at::Tensor& cls_targets_at_level, + const at::Tensor& num_positives_sum, int64_t num_real_classes, + double alpha, double gamma, double smoothing_factor) { + return focal_loss_forward(cls_output, cls_targets_at_level, 
num_positives_sum, num_real_classes, + static_cast(alpha), static_cast(gamma), + static_cast(smoothing_factor)); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("focal_loss_forward(Tensor cls_output, Tensor cls_targets_at_level, Tensor num_positives_sum, " + "int num_real_classes, float alpha, float gamma, float smoothing_factor) -> Tensor[]"); + m.def("focal_loss_backward(Tensor grad_output, Tensor partial_grad, Tensor num_positives_sum) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("focal_loss_forward", &focal_loss_forward_dispatch); + m.impl("focal_loss_backward", &focal_loss_backward); } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp index 2a9575d14..01a0553bc 100644 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp @@ -3,13 +3,20 @@ * SPDX-License-Identifier: BSD-3-Clause */ +#include #include -#include +#include #include "group_norm_nhwc.h" #include "group_norm_nhwc_bwd_one_pass.h" #include "group_norm_nhwc_fwd_one_pass.h" +#include +#include +#include +#include +#include + //////////////////////////////////////////////////////////////////////////////////////////////////// #define CHECK_CUDA_STATUS(call) \ @@ -21,15 +28,16 @@ } \ } while (0) -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CHANNELS_LAST(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) -#define CHECK_NHWC_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CHANNELS_LAST(x) +#define CHECK_TENSOR_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_TENSOR_CHANNELS_LAST(x) \ + 
TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") +#define CHECK_TENSOR_INPUT(x) \ + CHECK_TENSOR_CUDA(x); \ + CHECK_TENSOR_CONTIGUOUS(x) +#define CHECK_NHWC_TENSOR_INPUT(x) \ + CHECK_TENSOR_CUDA(x); \ + CHECK_TENSOR_CHANNELS_LAST(x) static bool initialized = false; static cudaDeviceProp props; @@ -39,13 +47,13 @@ const std::unordered_set supported_c_values = {128, 256, 320, 384, 448, 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096}; const std::unordered_set supported_groups_values = {16, 32}; -std::vector group_norm_fwd(torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias, - float eps, int passes, bool with_swish = false) { +std::vector group_norm_fwd(at::Tensor input, int groups, at::Tensor weight, at::Tensor bias, float eps, + int passes, bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } - CHECK_NHWC_INPUT(input); + CHECK_NHWC_TENSOR_INPUT(input); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments @@ -83,26 +91,26 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch params_fwd.with_swish = with_swish; PrecisionMode mode; - if (input.dtype() == torch::kFloat32) { - if (weight.dtype() == torch::kFloat16) { + if (input.dtype() == at::kFloat) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP32IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } - } else if (input.dtype() == torch::kBFloat16) { - if (weight.dtype() == torch::kFloat16) { + } else if (input.dtype() == at::kBFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::BF16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { - if 
(weight.dtype() == torch::kFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; @@ -126,13 +134,13 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch } // Allocate on the device. - auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); + auto red_buffer = at::empty({static_cast(red_buffer_elts)}, options.dtype(at::kFloat)); params_fwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. - auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); + auto barriers = at::zeros({static_cast(barriers_elts)}, options.dtype(at::kInt)); params_fwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({static_cast(zeroed_red_buffer_elts)}, options.dtype(at::kFloat)); params_fwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { @@ -145,14 +153,14 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch return {output, sums_d}; } -std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tensor sums, torch::Tensor input, - int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes, - bool with_swish = false) { +std::vector group_norm_bwd(at::Tensor grad_output, at::Tensor sums, at::Tensor input, int groups, + at::Tensor weight, at::Tensor bias, float eps, int passes, + bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } - CHECK_NHWC_INPUT(grad_output); + CHECK_NHWC_TENSOR_INPUT(grad_output); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments @@ -197,26 +205,26 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens params_bwd.with_swish = 
with_swish; PrecisionMode mode; - if (input.dtype() == torch::kFloat32) { - if (weight.dtype() == torch::kFloat16) { + if (input.dtype() == at::kFloat) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP32IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } - } else if (input.dtype() == torch::kBFloat16) { - if (weight.dtype() == torch::kFloat16) { + } else if (input.dtype() == at::kBFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::BF16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { - if (weight.dtype() == torch::kFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; @@ -240,13 +248,13 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens } // Allocate on the device. - auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); + auto red_buffer = at::empty({static_cast(red_buffer_elts)}, options.dtype(at::kFloat)); params_bwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. 
- auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); + auto barriers = at::zeros({static_cast(barriers_elts)}, options.dtype(at::kInt)); params_bwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({static_cast(zeroed_red_buffer_elts)}, options.dtype(at::kFloat)); params_bwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { @@ -259,7 +267,29 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens return {grad_input, grad_weight, grad_bias}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &group_norm_fwd, "NHWC group norm forward", py::call_guard()); - m.def("backward", &group_norm_bwd, "NHWC group norm backward", py::call_guard()); +namespace { +std::vector apex_group_norm_fwd(at::Tensor input, int64_t groups, at::Tensor weight, at::Tensor bias, + double eps, int64_t passes, bool with_swish) { + return group_norm_fwd(input, static_cast(groups), weight, bias, static_cast(eps), + static_cast(passes), with_swish); +} + +std::vector apex_group_norm_bwd(at::Tensor grad_output, at::Tensor sums, at::Tensor input, int64_t groups, + at::Tensor weight, at::Tensor bias, double eps, int64_t passes, + bool with_swish) { + return group_norm_bwd(grad_output, sums, input, static_cast(groups), weight, bias, static_cast(eps), + static_cast(passes), with_swish); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("group_norm_forward(Tensor input, int groups, Tensor weight, Tensor bias, float eps, int passes, " + "bool with_swish) -> Tensor[]"); + m.def("group_norm_backward(Tensor grad_output, Tensor sums, Tensor input, int groups, Tensor weight, Tensor bias, " + "float eps, int passes, bool with_swish) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("group_norm_forward", &apex_group_norm_fwd); + m.impl("group_norm_backward", &apex_group_norm_bwd); } diff --git 
a/apex/contrib/csrc/group_norm_v2/gn.cpp b/apex/contrib/csrc/group_norm_v2/gn.cpp index 23b2a8b18..bbe3e4b5c 100644 --- a/apex/contrib/csrc/group_norm_v2/gn.cpp +++ b/apex/contrib/csrc/group_norm_v2/gn.cpp @@ -1,25 +1,33 @@ #include "gn.hpp" +#include #include -#include +#include +#include +#include + +#include +#include +#include +#include namespace group_norm_v2 { -torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, - std::optional mean_var_out, int sm_margin) { - if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { +at::Tensor gn(at::Tensor x, at::Tensor w, at::Tensor b, float eps, bool silu, int num_groups, + std::optional mean_var_out, int sm_margin) { + if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != at::kFloat)) { throw std::invalid_argument("gn dtype mismatch"); } - torch::Tensor out = torch::empty_like(x); + at::Tensor out = at::empty_like(x); float* ptr_mean_var_out = mean_var_out.has_value() ? 
mean_var_out->data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, @@ -27,18 +35,18 @@ torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, b } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - torch::Tensor red_buffer = - torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; + at::Tensor red_buffer = + at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), 
(half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr(), @@ -49,24 +57,24 @@ torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, b return out; } -auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, - float eps, bool silu, int num_groups, int sm_margin) { - if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { +auto gn_bwd(at::Tensor grad_output, at::Tensor x, at::Tensor w, at::Tensor b, at::Tensor mean_var, float eps, bool silu, + int num_groups, int sm_margin) { + if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != at::kFloat) { throw std::invalid_argument("gn_bwd dtype mismatch"); } - torch::Tensor grad_input = torch::empty_like(x); - torch::Tensor grad_weight = torch::empty_like(w); - torch::Tensor grad_bias = torch::empty_like(w); + at::Tensor grad_input = at::empty_like(x); + at::Tensor grad_weight = at::empty_like(w); + at::Tensor grad_bias = at::empty_like(w); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { 
group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), @@ -75,19 +83,19 @@ auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch:: } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - torch::Tensor red_buffer = - torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; + at::Tensor red_buffer = + at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), 
num_groups, x.size(1) / num_groups, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), @@ -102,9 +110,30 @@ auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch:: } // namespace group_norm_v2 -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), - py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); - m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), - py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); +namespace { +at::Tensor apex_group_norm_v2_gn(at::Tensor x, at::Tensor w, at::Tensor b, double eps, bool silu, int64_t num_groups, + const std::optional& mean_var_out, int64_t sm_margin) { + return group_norm_v2::gn(x, w, b, static_cast(eps), silu, static_cast(num_groups), mean_var_out, + static_cast(sm_margin)); +} + +std::vector apex_group_norm_v2_gn_bwd(at::Tensor grad_output, at::Tensor x, at::Tensor w, at::Tensor b, + at::Tensor mean_var, double eps, bool silu, int64_t num_groups, + int64_t sm_margin) { + auto grads = group_norm_v2::gn_bwd(grad_output, x, w, b, mean_var, static_cast(eps), silu, + static_cast(num_groups), static_cast(sm_margin)); + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("group_norm_v2_gn(Tensor x, Tensor w, Tensor b, float 
eps, bool silu, int num_groups, " + "Tensor? mean_var_out, int sm_margin) -> Tensor"); + m.def("group_norm_v2_gn_bwd(Tensor grad_output, Tensor x, Tensor w, Tensor b, Tensor mean_var, float eps, " + "bool silu, int num_groups, int sm_margin) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("group_norm_v2_gn", &apex_group_norm_v2_gn); + m.impl("group_norm_v2_gn_bwd", &apex_group_norm_v2_gn_bwd); } diff --git a/apex/contrib/csrc/groupbn/interface.cpp b/apex/contrib/csrc/groupbn/interface.cpp index 27891bce1..b15250e92 100644 --- a/apex/contrib/csrc/groupbn/interface.cpp +++ b/apex/contrib/csrc/groupbn/interface.cpp @@ -1,18 +1,13 @@ #include #include #include -#include -#include -#include -#include +#include #include "ATen/Generator.h" #include "ATen/Scalar.h" #include "ATen/Storage.h" #include "ATen/Tensor.h" -namespace py = pybind11; - int64_t get_buffer_size(const int bn_sync_steps); void* get_data_ptr(const at::Tensor& data); @@ -72,29 +67,149 @@ int nhwc_bn_bwd_occupancy(); int nhwc_bn_addrelu_fwd_occupancy(); int nhwc_bn_addrelu_bwd_occupancy(); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_buffer_size", &get_buffer_size, "get_buffer_size", py::call_guard()); - m.def("get_data_ptr", &get_data_ptr, "get_data_ptr", py::call_guard()); - m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr", py::call_guard()); - m.def("close_remote_data", &close_remote_data, "close_remote_data", py::call_guard()); - - m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc", py::call_guard()); - m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc", py::call_guard()); - m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc", py::call_guard()); - - m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy", - py::call_guard()); - m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy", - py::call_guard()); - - m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc", - 
py::call_guard()); - m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc", - py::call_guard()); - m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc", py::call_guard()); - - m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy", - py::call_guard()); - m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy", - py::call_guard()); +namespace { +void* optional_ptr(const c10::optional& ptr) { + return ptr.has_value() ? reinterpret_cast(ptr.value()) : nullptr; +} + +int64_t apex_bnp_get_buffer_size(int64_t bn_sync_steps) { return get_buffer_size(static_cast(bn_sync_steps)); } + +int64_t apex_bnp_get_data_ptr(const at::Tensor& data) { return reinterpret_cast(get_data_ptr(data)); } + +int64_t apex_bnp_get_remote_data_ptr(const at::Tensor& handle, int64_t offset) { + return reinterpret_cast(get_remote_data_ptr(handle, offset)); +} + +int64_t apex_bnp_bn_fwd_nhwc_occupancy() { return nhwc_bn_fwd_occupancy(); } + +int64_t apex_bnp_bn_bwd_nhwc_occupancy() { return nhwc_bn_bwd_occupancy(); } + +int64_t apex_bnp_bn_addrelu_fwd_nhwc_occupancy() { return nhwc_bn_addrelu_fwd_occupancy(); } + +int64_t apex_bnp_bn_addrelu_bwd_nhwc_occupancy() { return nhwc_bn_addrelu_bwd_occupancy(); } + +at::Tensor apex_bnp_bn_fwd_nhwc(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + const at::Tensor& ret_cta, double momentum, double epsilon, bool fuse_relu, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, + int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_fwd_train(x, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, + 
static_cast(momentum), static_cast(epsilon), fuse_relu, + optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), + optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, + static_cast(occupancy), static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& ret_cta, int64_t bn_group, double momentum, double epsilon, + bool fuse_relu) { + return nhwc_bn_fwd_eval(x, scale, bias, running_mean, running_inv_var, ret_cta, static_cast(bn_group), + static_cast(momentum), static_cast(epsilon), fuse_relu); +} + +std::vector apex_bnp_bn_bwd_nhwc( + const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, double momentum, double epsilon, bool fuse_relu, + c10::optional my_data, c10::optional pair_data, c10::optional pair_data2, + c10::optional pair_data3, int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, + static_cast(momentum), static_cast(epsilon), fuse_relu, optional_ptr(my_data), + optional_ptr(pair_data), optional_ptr(pair_data2), optional_ptr(pair_data3), + static_cast(bn_group), magic_tensor, static_cast(occupancy), + static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_addrelu_fwd_nhwc( + const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, double momentum, + double 
epsilon, c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, + const at::Tensor& magic_tensor, int64_t occupancy, int64_t grid_dim_x, bool coop) { + return nhwc_bn_addrelu_fwd_train(x, z, scale, bias, running_mean, running_inv_var, minibatch_mean, + minibatch_inv_var, bitmask, ret_cta, static_cast(momentum), + static_cast(epsilon), optional_ptr(my_data), optional_ptr(pair_data), + optional_ptr(pair_data2), optional_ptr(pair_data3), static_cast(bn_group), + magic_tensor, static_cast(occupancy), static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_addrelu_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& ret_cta, + int64_t bn_group, double momentum, double epsilon) { + return nhwc_bn_addrelu_fwd_eval(x, z, scale, bias, running_mean, running_inv_var, ret_cta, + static_cast(bn_group), static_cast(momentum), + static_cast(epsilon)); +} + +std::vector apex_bnp_bn_addrelu_bwd_nhwc( + const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, double momentum, + double epsilon, c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, + const at::Tensor& magic_tensor, int64_t occupancy, int64_t grid_dim_x, bool coop) { + return nhwc_bn_addrelu_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, + bitmask, ret_cta, static_cast(momentum), static_cast(epsilon), + optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), + optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, + static_cast(occupancy), 
static_cast(grid_dim_x), coop); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("bnp_get_buffer_size(int bn_sync_steps) -> int"); + m.def("bnp_get_data_ptr(Tensor data) -> int"); + m.def("bnp_get_remote_data_ptr(Tensor handle, int offset) -> int"); + m.def("bnp_close_remote_data(Tensor handle) -> ()"); + m.def("bnp_bn_fwd_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, float epsilon, " + "bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, int bn_group, " + "Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def("bnp_bn_fwd_eval_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor ret_cta, int bn_group, float momentum, float epsilon, bool fuse_relu) -> Tensor"); + m.def("bnp_bn_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, " + "float epsilon, bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def("bnp_bn_fwd_nhwc_occupancy() -> int"); + m.def("bnp_bn_bwd_nhwc_occupancy() -> int"); + m.def("bnp_bn_addrelu_fwd_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? 
pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def("bnp_bn_addrelu_fwd_eval_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor ret_cta, int bn_group, float momentum, float epsilon) -> Tensor"); + m.def("bnp_bn_addrelu_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def("bnp_bn_addrelu_fwd_nhwc_occupancy() -> int"); + m.def("bnp_bn_addrelu_bwd_nhwc_occupancy() -> int"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("bnp_get_buffer_size", &apex_bnp_get_buffer_size); + m.impl("bnp_get_data_ptr", &apex_bnp_get_data_ptr); + m.impl("bnp_get_remote_data_ptr", &apex_bnp_get_remote_data_ptr); + m.impl("bnp_close_remote_data", &close_remote_data); + m.impl("bnp_bn_fwd_nhwc_occupancy", &apex_bnp_bn_fwd_nhwc_occupancy); + m.impl("bnp_bn_bwd_nhwc_occupancy", &apex_bnp_bn_bwd_nhwc_occupancy); + m.impl("bnp_bn_addrelu_fwd_nhwc_occupancy", &apex_bnp_bn_addrelu_fwd_nhwc_occupancy); + m.impl("bnp_bn_addrelu_bwd_nhwc_occupancy", &apex_bnp_bn_addrelu_bwd_nhwc_occupancy); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("bnp_bn_fwd_nhwc", &apex_bnp_bn_fwd_nhwc); + m.impl("bnp_bn_fwd_eval_nhwc", &apex_bnp_bn_fwd_eval_nhwc); + m.impl("bnp_bn_bwd_nhwc", &apex_bnp_bn_bwd_nhwc); + m.impl("bnp_bn_addrelu_fwd_nhwc", &apex_bnp_bn_addrelu_fwd_nhwc); + m.impl("bnp_bn_addrelu_fwd_eval_nhwc", &apex_bnp_bn_addrelu_fwd_eval_nhwc); + m.impl("bnp_bn_addrelu_bwd_nhwc", &apex_bnp_bn_addrelu_bwd_nhwc); } diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp 
index e585bc73f..c2d7d34e5 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -65,17 +66,78 @@ void index_mul_2d_half_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_grad_in2, in1, in2, idx1); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("float_forward", &index_mul_2d_float_forward, "index mul float calculation forward (CUDA)", - py::call_guard()); - m.def("float_backward", &index_mul_2d_float_backward, "index mul float calculation backward (CUDA)", - py::call_guard()); - m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, - "index mul float calculation backward backward (CUDA)", py::call_guard()); - m.def("half_forward", &index_mul_2d_half_forward, "index mul half calculation forward (CUDA)", - py::call_guard()); - m.def("half_backward", &index_mul_2d_half_backward, "index mul half calculation backward (CUDA)", - py::call_guard()); - m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, - "index mul half calculation backward backward (CUDA)", py::call_guard()); +void index_mul_2d_float_forward_dispatch(const at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor out_arg = out; + index_mul_2d_float_forward(out_arg, in1, in2, idx1); +} + +void index_mul_2d_float_backward_dispatch(const at::Tensor& grad_in1, const at::Tensor& grad_in2, + const at::Tensor& grad_out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_float_backward(grad_in1_arg, grad_in2_arg, grad_out, in1, in2, idx1); +} + +void index_mul_2d_float_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, + const at::Tensor& grad_in2, const at::Tensor& grad_out, + const at::Tensor& grad_grad_in1, + const at::Tensor& 
grad_grad_in2, const at::Tensor& in1, + const at::Tensor& in2, const at::Tensor& idx1) { + at::Tensor grad_grad_out_arg = grad_grad_out; + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_float_backwrad_backward(grad_grad_out_arg, grad_in1_arg, grad_in2_arg, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); +} + +void index_mul_2d_half_forward_dispatch(const at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor out_arg = out; + index_mul_2d_half_forward(out_arg, in1, in2, idx1); +} + +void index_mul_2d_half_backward_dispatch(const at::Tensor& grad_in1, const at::Tensor& grad_in2, + const at::Tensor& grad_out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_half_backward(grad_in1_arg, grad_in2_arg, grad_out, in1, in2, idx1); +} + +void index_mul_2d_half_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, + const at::Tensor& grad_in2, const at::Tensor& grad_out, + const at::Tensor& grad_grad_in1, + const at::Tensor& grad_grad_in2, const at::Tensor& in1, + const at::Tensor& in2, const at::Tensor& idx1) { + at::Tensor grad_grad_out_arg = grad_grad_out; + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_half_backwrad_backward(grad_grad_out_arg, grad_in1_arg, grad_in2_arg, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("index_mul_2d_float_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_float_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_float_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) 
grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_half_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_half_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_half_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("index_mul_2d_float_forward", &index_mul_2d_float_forward_dispatch); + m.impl("index_mul_2d_float_backward", &index_mul_2d_float_backward_dispatch); + m.impl("index_mul_2d_float_backward_backward", &index_mul_2d_float_backward_backward_dispatch); + m.impl("index_mul_2d_half_forward", &index_mul_2d_half_forward_dispatch); + m.impl("index_mul_2d_half_backward", &index_mul_2d_half_backward_dispatch); + m.impl("index_mul_2d_half_backward_backward", &index_mul_2d_half_backward_backward_dispatch); } diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 025fd52a8..89af4b675 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -1,6 +1,7 @@ -#include +#include +#include -#include "ATen/cuda/CUDAContext.h" +#include #include "ln.h" /* @@ -30,12 +31,12 @@ BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// -uint32_t get_type_id(torch::Dtype dtype) { - if (dtype == torch::kFloat16) { +uint32_t get_type_id(at::ScalarType dtype) { + if (dtype == at::kHalf) { return TypeId::Value; - } else if (dtype == torch::kBFloat16) { + } else if (dtype == at::kBFloat16) { return TypeId::Value; - } else if (dtype == torch::kFloat32) { + } else if (dtype == at::kFloat) { return TypeId::Value; } 
else { TORCH_CHECK(false, "Type not supported: ", dtype); @@ -44,7 +45,8 @@ uint32_t get_type_id(torch::Dtype dtype) { //////////////////////////////////////////////////////////////////////////////////////////////////// -uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { +uint64_t get_key(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, at::ScalarType ctype, + uint64_t hidden_size) { using namespace layer_norm; uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); @@ -56,8 +58,8 @@ uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, tor //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, - torch::Dtype ctype, uint32_t hidden_size) { +layer_norm::FwdFunction& get_fwd_launcher(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, + at::ScalarType ctype, uint32_t hidden_size) { auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::FWD_FUNCS.end()) { return iter->second; @@ -68,8 +70,8 @@ layer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::BwdFunction& get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, - torch::Dtype ctype, uint32_t hidden_size) { +layer_norm::BwdFunction& get_bwd_launcher(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, + at::ScalarType ctype, uint32_t hidden_size) { auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::BWD_FUNCS.end()) { return iter->second; @@ -87,7 +89,7 @@ std::vector 
ln_fwd(const at::Tensor& x, // BxSxhidden_size auto itype = x.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; - auto ctype = torch::kFloat32; + auto ctype = at::kFloat; TORCH_CHECK(beta.scalar_type() == wtype); @@ -110,10 +112,10 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size auto opts = x.options(); - auto z = torch::empty(sizes, opts.dtype(otype)); + auto z = at::empty(sizes, opts.dtype(otype)); - auto mu = torch::empty({rows}, opts.dtype(ctype)); - auto rsigma = torch::empty({rows}, opts.dtype(ctype)); + auto mu = at::empty({rows}, opts.dtype(ctype)); + auto rsigma = at::empty({rows}, opts.dtype(ctype)); layer_norm::LaunchParams launch_params; @@ -142,8 +144,8 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size if (launch_params.barrier_size > 0) { auto options = x.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + barrier = at::zeros(launch_params.barrier_size, options.dtype(at::kInt)); + workspace = at::empty(launch_params.workspace_bytes, options.dtype(at::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } @@ -157,15 +159,15 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector ln_bwd(const at::Tensor& dz, // BxSxhidden_size const at::Tensor& x_or_z, // BxSxhidden_size - c10::optional& mu_, // BxS, FP32! + const c10::optional& mu_, // BxS, FP32! const at::Tensor& rsigma, // BxS, FP32! 
const at::Tensor& gamma, // hidden_size - c10::optional& beta_, // hidden_size + const c10::optional& beta_, // hidden_size bool memory_efficient) { auto itype = x_or_z.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; - auto ctype = torch::kFloat32; + auto ctype = at::kFloat; TORCH_CHECK(dz.dtype() == otype); TORCH_CHECK(rsigma.dtype() == ctype); @@ -200,9 +202,9 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh auto options = x_or_z.options(); - auto dx = torch::empty_like(x_or_z); - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); + auto dx = at::empty_like(x_or_z); + auto dgamma = at::empty_like(gamma); + auto dbeta = at::empty_like(gamma); layer_norm::LaunchParams launch_params; launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); @@ -212,8 +214,8 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh launcher(launch_params, true); - auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); - auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + auto dgamma_part = at::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + auto dbeta_part = at::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); at::Tensor workspace, barrier; layer_norm::BwdParams& params = launch_params.params; @@ -237,8 +239,8 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh if (launch_params.barrier_size > 0) { // TODO Any way to avoid this? 
- barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + barrier = at::zeros(launch_params.barrier_size, options.dtype(at::kInt)); + workspace = at::empty(launch_params.workspace_bytes, options.dtype(at::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } @@ -250,8 +252,29 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh //////////////////////////////////////////////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; - m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel", py::call_guard()); - m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel", py::call_guard()); +namespace { +std::vector apex_fast_layer_norm_ln_fwd(const at::Tensor& x, const at::Tensor& gamma, + const at::Tensor& beta, double epsilon) { + return ln_fwd(x, gamma, beta, static_cast(epsilon)); } + +std::vector apex_fast_layer_norm_ln_bwd(const at::Tensor& dz, const at::Tensor& x_or_z, + const c10::optional& mu, const at::Tensor& rsigma, + const at::Tensor& gamma, + const c10::optional& beta, bool memory_efficient) { + return ln_bwd(dz, x_or_z, mu, rsigma, gamma, beta, memory_efficient); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fast_layer_norm_ln_fwd(Tensor x, Tensor gamma, Tensor beta, float epsilon) -> Tensor[]"); + m.def("fast_layer_norm_ln_bwd(Tensor dz, Tensor x_or_z, Tensor? mu, Tensor rsigma, Tensor gamma, Tensor? 
beta, " + "bool memory_efficient) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_layer_norm_ln_fwd", &apex_fast_layer_norm_ln_fwd); + m.impl("fast_layer_norm_ln_bwd", &apex_fast_layer_norm_ln_bwd); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu index a18cf12fc..e5fa4f413 100644 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -19,7 +21,7 @@ namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; @@ -36,11 +38,11 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = input.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor 
dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* input_ptr = static_cast(input.data_ptr()); @@ -71,8 +73,8 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob) { +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; @@ -84,7 +86,7 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tens cublasSetStream(handle, stream); // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); + // at::Tensor input_grads = at::empty_like(output_grads); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index 2f93f108a..b92647375 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -18,9 +20,9 @@ namespace multihead_attn { namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, 
torch::Tensor const& output_weights, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); @@ -50,15 +52,15 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); + at::Tensor input_lin_q_results = at::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + at::Tensor input_lin_kv_results = at::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, 
act_options); + at::Tensor outputs = at::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); @@ -141,12 +143,12 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); @@ -174,16 +176,16 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_q_grads = at::empty_like(inputs_q); + at::Tensor input_kv_grads = 
at::empty_like(inputs_kv); + at::Tensor input_weight_q_grads = at::empty_like(input_weights_q); + at::Tensor input_weight_kv_grads = at::empty_like(input_weights_kv); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_q_output_grads = at::empty_like(input_lin_q_results); + at::Tensor input_lin_kv_output_grads = at::empty_like(input_lin_kv_results); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 4376ededb..4cdb3c301 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -19,10 +21,10 @@ namespace multihead_attn { namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor 
const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); @@ -53,22 +55,22 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); - - torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); + auto lyr_nrm_options = act_options.dtype(at::kFloat); + auto mask_options = act_options.dtype(at::kByte); + + at::Tensor lyr_nrm_mean = at::empty({batches_q}, lyr_nrm_options); + at::Tensor 
lyr_nrm_invvar = at::empty({batches_q}, lyr_nrm_options); + at::Tensor lyr_nrm_results = at::empty_like(inputs_q, act_options); + + at::Tensor input_lin_q_results = at::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + at::Tensor input_lin_kv_results = at::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor output_lin_results = at::empty_like(inputs_q, act_options); + at::Tensor dropout_add_mask = at::empty_like(inputs_q, mask_options); + at::Tensor outputs = at::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); @@ -175,15 +177,15 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor 
const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); @@ -212,20 +214,20 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_q_grads = at::empty_like(inputs_q); + at::Tensor input_kv_grads = at::empty_like(inputs_kv); + at::Tensor lyr_nrm_gamma_grads = at::empty_like(lyr_nrm_gamma_weights); + at::Tensor lyr_nrm_beta_grads = at::empty_like(lyr_nrm_beta_weights); + at::Tensor input_weight_q_grads = at::empty_like(input_weights_q); + at::Tensor input_weight_kv_grads = at::empty_like(input_weights_kv); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor dropout_add_grads = torch::empty_like(output_grads); - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - 
at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); - at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); + at::Tensor dropout_add_grads = at::empty_like(output_grads); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_q_output_grads = at::empty_like(input_lin_q_results); + at::Tensor input_lin_kv_output_grads = at::empty_like(input_lin_kv_results); + at::Tensor input_lin_q_grads = at::empty_like(inputs_q); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu index bde131da5..fcb626165 100644 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -17,7 +19,7 @@ namespace multihead_attn { namespace fused_softmax { namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; @@ -34,11 +36,11 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = input.options().requires_grad(false); - auto mask_options = 
act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* input_ptr = static_cast(input.data_ptr()); @@ -69,8 +71,8 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; @@ -82,7 +84,7 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tens cublasSetStream(handle, stream); // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); + // at::Tensor input_grads = at::empty_like(output_grads); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index e1d3bcc29..1272dde59 100644 --- 
a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -1,5 +1,6 @@ #include -#include +#include +#include #include @@ -13,14 +14,14 @@ namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob); +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, - torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { @@ -32,8 +33,8 @@ std::vector fwd(bool use_mask, bool is_training, int heads, torch dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob) { +at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -48,14 +49,14 @@ torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& 
output_grads, t } // namespace additive_mask_softmax_dropout namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, - torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); @@ -68,8 +69,8 @@ std::vector fwd(bool use_mask, bool is_training, int heads, torch dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, torch::Tensor const& padding_mask, float dropout_prob) { +at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, at::Tensor const& padding_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -89,22 +90,22 @@ torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, t namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, 
bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor 
const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); @@ -126,12 +127,12 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini output_weights, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); @@ -170,28 +171,28 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, +std::vector 
fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& 
lyr_nrm_beta_weights, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); @@ -218,15 +219,15 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, 
at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); @@ -278,19 +279,19 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob); +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool 
use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -308,11 +309,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -342,23 +343,23 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_bias { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor 
const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& input_biases, - torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -376,11 +377,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -410,25 +411,25 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, - // torch::Tensor const& softmax_results, - torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - 
torch::Tensor const& input_weights, torch::Tensor const& output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& input_biases, - torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, + // at::Tensor const& softmax_results, + at::Tensor const& bmm1_results, at::Tensor const& pad_mask, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -447,11 +448,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results, - torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); @@ -481,24 +482,24 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); - -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor 
const& dropout_add_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); @@ -520,14 +521,14 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini output_weights, use_mask ? 
static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor const& dropout_add_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -569,42 +570,251 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor } // end namespace self_norm_add } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("additive_mask_softmax_dropout_forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); - m.def("additive_mask_softmax_dropout_backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, - "Self Multihead 
Attention masked softmax dropout -- Backward.", py::call_guard()); - m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); - m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); - m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::cublas_gemmex::fwd, - "Encdec Multihead Attention Forward.", py::call_guard()); - m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::cublas_gemmex::bwd, - "Encdec Multihead Attention Backward.", py::call_guard()); - m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.", - py::call_guard()); - m.def("encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.", - py::call_guard()); - m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.", - py::call_guard()); - m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.", - py::call_guard()); - m.def("self_attn_bias_forward", &multihead_attn::self_bias::cublas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward.", py::call_guard()); - m.def("self_attn_bias_backward", &multihead_attn::self_bias::cublas_gemmex::bwd, - "Self Multihead Attention with Bias -- Backward.", py::call_guard()); - m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward.", py::call_guard()); - m.def("self_attn_bias_additive_mask_backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, - "Self 
Multihead Attention with Bias -- Backward.", py::call_guard()); - m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); - m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Backward.", - py::call_guard()); +namespace { +int as_int(int64_t value) { return static_cast(value); } + +float as_float(double value) { return static_cast(value); } + +std::vector apex_fast_multihead_attn_additive_mask_softmax_dropout_forward( + bool use_mask, bool is_training, int64_t heads, at::Tensor const& input, at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd( + use_mask, is_training, as_int(heads), input, pad_mask, as_float(dropout_prob)); +} + +at::Tensor apex_fast_multihead_attn_additive_mask_softmax_dropout_backward( + bool use_mask, int64_t heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, double dropout_prob) { + return multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd( + use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_mask_softmax_dropout_forward( + bool use_mask, bool is_training, int64_t heads, at::Tensor const& input, at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::mask_softmax_dropout::fwd( + use_mask, is_training, as_int(heads), input, pad_mask, as_float(dropout_prob)); +} + +at::Tensor apex_fast_multihead_attn_mask_softmax_dropout_backward( + bool use_mask, int64_t heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, at::Tensor const& padding_mask, double dropout_prob) { + return 
multihead_attn::fused_softmax::mask_softmax_dropout::bwd( + use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, padding_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::encdec::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs_q, + inputs_kv, input_weights_q, input_weights_kv, output_weights, + pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::encdec::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, + input_lin_kv_results, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, dropout_mask, + as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& 
pad_mask, + double dropout_prob) { + return multihead_attn::encdec_norm_add::cublas_gemmex::fwd( + use_mask, use_time_mask, is_training, as_int(heads), inputs_q, inputs_kv, lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, double dropout_prob) { + return multihead_attn::encdec_norm_add::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, + input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, + lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, + dropout_add_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::self::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + input_weights, output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector 
apex_fast_multihead_attn_self_attn_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { + return multihead_attn::self::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_results, inputs, input_weights, + output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_bias::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + input_weights, output_weights, input_biases, output_biases, + pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { + return multihead_attn::self_bias::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, + input_weights, output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor 
const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd( + use_mask, use_time_mask, is_training, as_int(heads), inputs, input_weights, output_weights, input_biases, + output_biases, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, at::Tensor const& pad_mask, + at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { + return multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results, inputs, + input_weights, output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_norm_add_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::self_norm_add::cublas_gemmex::fwd( + use_mask, use_time_mask, is_training, as_int(heads), inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, + input_weights, output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_norm_add_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& 
input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, double dropout_prob) { + return multihead_attn::self_norm_add::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, + lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, + input_weights, output_weights, dropout_mask, dropout_add_mask, as_float(dropout_prob)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, " + "Tensor input, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, float dropout_prob) -> Tensor"); + m.def("fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, Tensor input, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, Tensor padding_mask, float dropout_prob) -> Tensor"); + m.def("fast_multihead_attn_encdec_multihead_attn_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, " + "Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_encdec_multihead_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor 
dropout_results, Tensor softmax_results, Tensor input_lin_q_results, Tensor input_lin_kv_results, " + "Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_encdec_multihead_attn_norm_add_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_encdec_multihead_attn_norm_add_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, " + "Tensor input_lin_kv_results, Tensor lyr_nrm_results, Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, " + "Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) " + "-> Tensor[]"); + m.def("fast_multihead_attn_self_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_bias_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor input_biases, Tensor output_biases, " + "Tensor pad_mask, float 
dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_bias_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_bias_additive_mask_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs, Tensor input_weights, Tensor output_weights, " + "Tensor input_biases, Tensor output_biases, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_bias_additive_mask_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor bmm1_results, Tensor pad_mask, " + "Tensor input_lin_results, Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_norm_add_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def("fast_multihead_attn_self_attn_norm_add_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor lyr_nrm_results, " + "Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, Tensor inputs, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_multihead_attn_additive_mask_softmax_dropout_forward", + &apex_fast_multihead_attn_additive_mask_softmax_dropout_forward); + 
m.impl("fast_multihead_attn_additive_mask_softmax_dropout_backward", + &apex_fast_multihead_attn_additive_mask_softmax_dropout_backward); + m.impl("fast_multihead_attn_mask_softmax_dropout_forward", + &apex_fast_multihead_attn_mask_softmax_dropout_forward); + m.impl("fast_multihead_attn_mask_softmax_dropout_backward", + &apex_fast_multihead_attn_mask_softmax_dropout_backward); + m.impl("fast_multihead_attn_encdec_multihead_attn_forward", + &apex_fast_multihead_attn_encdec_multihead_attn_forward); + m.impl("fast_multihead_attn_encdec_multihead_attn_backward", + &apex_fast_multihead_attn_encdec_multihead_attn_backward); + m.impl("fast_multihead_attn_encdec_multihead_attn_norm_add_forward", + &apex_fast_multihead_attn_encdec_multihead_attn_norm_add_forward); + m.impl("fast_multihead_attn_encdec_multihead_attn_norm_add_backward", + &apex_fast_multihead_attn_encdec_multihead_attn_norm_add_backward); + m.impl("fast_multihead_attn_self_attn_forward", &apex_fast_multihead_attn_self_attn_forward); + m.impl("fast_multihead_attn_self_attn_backward", &apex_fast_multihead_attn_self_attn_backward); + m.impl("fast_multihead_attn_self_attn_bias_forward", &apex_fast_multihead_attn_self_attn_bias_forward); + m.impl("fast_multihead_attn_self_attn_bias_backward", &apex_fast_multihead_attn_self_attn_bias_backward); + m.impl("fast_multihead_attn_self_attn_bias_additive_mask_forward", + &apex_fast_multihead_attn_self_attn_bias_additive_mask_forward); + m.impl("fast_multihead_attn_self_attn_bias_additive_mask_backward", + &apex_fast_multihead_attn_self_attn_bias_additive_mask_backward); + m.impl("fast_multihead_attn_self_attn_norm_add_forward", &apex_fast_multihead_attn_self_attn_norm_add_forward); + m.impl("fast_multihead_attn_self_attn_norm_add_backward", &apex_fast_multihead_attn_self_attn_norm_add_backward); } #undef CHECK_CUDA diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu 
b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index 7be40eba5..7d08be994 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -18,9 +20,9 @@ namespace multihead_attn { namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); @@ -47,14 +49,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor 
outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor bmm1_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -121,11 +123,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results, - torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); @@ -149,13 +151,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor 
input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index 8e39a28d6..1a2a4942c 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -18,9 +20,9 @@ namespace multihead_attn { namespace self_bias { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const uint8_t* pad_mask, float dropout_prob) { const int 
embed_dim = inputs.size(2); const int sequences = inputs.size(1); @@ -47,14 +49,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -132,11 +134,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, 
torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -159,13 +161,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index 
929e7e791..cd071d3a6 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -18,8 +20,8 @@ namespace multihead_attn { namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); @@ -45,14 +47,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = 
at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -125,11 +127,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -152,13 +154,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor 
output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index fa481f8c9..d20985a43 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif #include #include -#include +#include #include #include @@ -19,10 +21,10 @@ namespace multihead_attn { namespace self_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -48,21 +50,21 @@ std::vector fwd_cuda(bool 
use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options); - - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + auto lyr_nrm_options = act_options.dtype(at::kFloat); + auto mask_options = act_options.dtype(at::kByte); + + at::Tensor lyr_nrm_mean = at::empty({batches}, lyr_nrm_options); + at::Tensor lyr_nrm_invvar = at::empty({batches}, lyr_nrm_options); + at::Tensor lyr_nrm_results = at::empty_like(inputs, act_options); + + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = 
at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor output_lin_results = at::empty_like(inputs, act_options); + at::Tensor dropout_add_mask = at::empty_like(inputs, mask_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -160,14 +162,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he dropout_results, dropout_mask, matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor const& dropout_add_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -191,17 +193,17 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads 
cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor lyr_nrm_gamma_grads = at::empty_like(lyr_nrm_gamma_weights); + at::Tensor lyr_nrm_beta_grads = at::empty_like(lyr_nrm_beta_weights); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - torch::Tensor dropout_add_grads = torch::empty_like(output_grads); - torch::Tensor output_lin_grads = torch::empty_like(matmul2_results); - torch::Tensor matmul2_grads = torch::empty_like(dropout_results); - torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - torch::Tensor input_lin_grads = torch::empty_like(inputs); + at::Tensor dropout_add_grads = at::empty_like(output_grads); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); + at::Tensor input_lin_grads = at::empty_like(inputs); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh index 8537e81e7..b9184c064 100644 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh @@ -1,7 +1,9 @@ #pragma once #include #include +#if __has_include() #include +#endif #include #include diff 
--git a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp index 3db57c4e9..6d532ad19 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp @@ -15,15 +15,55 @@ */ #include "nccl_p2p_cuda.cuh" +#include -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", - py::call_guard()); - m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", - py::call_guard()); - m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, - "left_right_halo_exchange_inplace", py::call_guard()); - m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", - py::call_guard()); - m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); +namespace { +at::Tensor apex_nccl_p2p_get_unique_nccl_id(int64_t n) { + return apex::contrib::nccl_p2p::get_unique_nccl_id(static_cast(n)); +} + +int64_t apex_nccl_p2p_init_nccl_comm(at::Tensor unique_nccl_id, int64_t my_rank, int64_t num_ranks) { + return apex::contrib::nccl_p2p::init_nccl_comm(unique_nccl_id, static_cast(my_rank), + static_cast(num_ranks)); +} + +void apex_nccl_p2p_left_right_halo_exchange_inplace(int64_t handle, int64_t left_rank, int64_t right_rank, + at::Tensor left_output_halo, at::Tensor right_output_halo, + at::Tensor left_input_halo, at::Tensor right_input_halo) { + apex::contrib::nccl_p2p::left_right_halo_exchange_inplace( + static_cast(handle), static_cast(left_rank), static_cast(right_rank), left_output_halo, + right_output_halo, left_input_halo, right_input_halo); +} + +std::vector apex_nccl_p2p_left_right_halo_exchange(int64_t handle, int64_t left_rank, int64_t right_rank, + at::Tensor left_output_halo, + at::Tensor right_output_halo) { + return apex::contrib::nccl_p2p::left_right_halo_exchange( + 
static_cast(handle), static_cast(left_rank), static_cast(right_rank), left_output_halo, + right_output_halo); +} + +void apex_nccl_p2p_add_delay(int64_t delay) { apex::contrib::nccl_p2p::add_delay(static_cast(delay)); } +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("nccl_p2p_get_unique_nccl_id(int n) -> Tensor"); + m.def("nccl_p2p_init_nccl_comm(Tensor unique_nccl_id, int my_rank, int num_ranks) -> int"); + m.def("nccl_p2p_left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, " + "Tensor left_output_halo, Tensor right_output_halo, Tensor(a!) left_input_halo, " + "Tensor(b!) right_input_halo) -> ()"); + m.def("nccl_p2p_left_right_halo_exchange(int handle, int left_rank, int right_rank, Tensor left_output_halo, " + "Tensor right_output_halo) -> Tensor[]"); + m.def("nccl_p2p_add_delay(int delay) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("nccl_p2p_get_unique_nccl_id", &apex_nccl_p2p_get_unique_nccl_id); + m.impl("nccl_p2p_init_nccl_comm", &apex_nccl_p2p_init_nccl_comm); + m.impl("nccl_p2p_add_delay", &apex_nccl_p2p_add_delay); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("nccl_p2p_left_right_halo_exchange_inplace", &apex_nccl_p2p_left_right_halo_exchange_inplace); + m.impl("nccl_p2p_left_right_halo_exchange", &apex_nccl_p2p_left_right_halo_exchange); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu index 21091bd1d..01a3a0c43 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu @@ -1,6 +1,6 @@ +#include #include #include -#include #include #include @@ -84,8 +84,8 @@ class NcclCommWrapper { ncclDataType_t ncclType = get_nccl_type(left_output_halo); bool left_zero = (left_rank < 0); bool right_zero = (right_rank < 0); - size_t left_n = torch::numel(left_output_halo); - size_t right_n = torch::numel(right_output_halo); + size_t left_n = left_output_halo.numel(); + size_t right_n = 
right_output_halo.numel(); assert(left_n > 0 && left_n == right_n); if (left_zero) { left_input_halo.zero_(); @@ -119,8 +119,8 @@ class NcclCommWrapper { // after halo exchange: // left_output_halo of rank+1 ends up in right_input_halo of rank // right_output_halo of rank-1 ends up in left_input_halo of rank - auto right_input_halo = torch::empty_like(left_output_halo); - auto left_input_halo = torch::empty_like(right_output_halo); + auto right_input_halo = at::empty_like(left_output_halo); + auto left_input_halo = at::empty_like(right_output_halo); left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); return {left_input_halo, right_input_halo}; @@ -161,8 +161,8 @@ namespace nccl_p2p { at::Tensor get_unique_nccl_id(int n) { ncclUniqueId id; ncclGetUniqueId(&id); - auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, - torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false)); + auto id_tensor = at::empty({n, (int)sizeof(ncclUniqueId)}, + at::TensorOptions().dtype(at::kByte).device(at::kCPU).requires_grad(false)); auto id_ptr = id_tensor.data_ptr(); size_t offset = 0; for (int i = 0; i < n; ++i) { @@ -200,7 +200,7 @@ std::vector left_right_halo_exchange(int handle, int left_rank, int void add_delay(int delay) { auto stream = at::cuda::getCurrentCUDAStream(); - auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto t = at::empty({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)); AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr()); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh index a047bedb6..3b91e7619 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include +#include #ifndef _nccl_p2p_h_ #define _nccl_p2p_h_ diff --git 
a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp index 9d69ba01b..3652a39ac 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include // CUDA forward declaration void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first); @@ -88,18 +89,102 @@ void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) maybe_cast_cuda(overflow_flag, p_in, p_out); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", - py::call_guard()); - m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); - m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", - py::call_guard()); - m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", - py::call_guard()); - m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", - py::call_guard()); - m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", - py::call_guard()); - m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", - py::call_guard()); +void strided_check_finite_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p_copy, int64_t stride, + int64_t clear_overflow_first) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_copy_arg = p_copy; + strided_check_finite(overflow_flag_arg, p_copy_arg, static_cast(stride), static_cast(clear_overflow_first)); +} + +void adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tensor& m, const at::Tensor& v, + const at::Tensor& g, double lr, double beta1, double beta2, double eps, double grad_scale, + int64_t step, int64_t mode, int64_t bias_correction, double decay) { + 
at::Tensor p_arg = p; + at::Tensor p_copy_arg = p_copy; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor g_arg = g; + adam(p_arg, p_copy_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), static_cast(step), + static_cast(mode), static_cast(bias_correction), static_cast(decay)); +} + +void reversible_adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tensor& m, const at::Tensor& v, + const at::Tensor& g, double lr, double beta1, double beta2, double eps, + double grad_scale, int64_t step, int64_t mode, int64_t bias_correction, double decay) { + at::Tensor p_arg = p; + at::Tensor p_copy_arg = p_copy; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor g_arg = g; + reversible_adam(p_arg, p_copy_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(decay)); +} + +void fused_adam_cuda_mt_dispatch(int64_t chunk_size, at::Tensor overflow_flag, + std::vector> tensor_lists, double lr, double beta1, + double beta2, double eps, double grad_scale, int64_t step, int64_t mode, + int64_t bias_correction, double decay) { + fused_adam_cuda_mt(static_cast(chunk_size), overflow_flag, tensor_lists, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(eps), + static_cast(grad_scale), static_cast(step), static_cast(mode), + static_cast(bias_correction), static_cast(decay)); +} + +void maybe_adam_undo_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p, const at::Tensor& m, + const at::Tensor& v, const at::Tensor& g, double lr, double beta1, double beta2, + double eps, double grad_scale, int64_t step, int64_t mode, int64_t bias_correction, + double decay) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_arg = p; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor 
g_arg = g; + maybe_adam_undo(overflow_flag_arg, p_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(decay)); +} + +void maybe_cast_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p_in, const at::Tensor& p_out) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_in_arg = p_in; + at::Tensor p_out_arg = p_out; + maybe_cast(overflow_flag_arg, p_in_arg, p_out_arg); +} + +void maybe_cast_cuda_mt_dispatch(int64_t chunk_size, at::Tensor overflow_flag, + std::vector> tensor_lists) { + maybe_cast_cuda_mt(static_cast(chunk_size), overflow_flag, tensor_lists); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_adam_strided_check_finite(Tensor(a!) overflow_flag, Tensor p_copy, int stride, " + "int clear_overflow_first) -> ()"); + m.def("fused_adam_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def("fused_adam_reversible_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def("fused_adam_adam_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists, float lr, " + "float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, " + "float decay) -> ()"); + m.def("fused_adam_maybe_adam_undo(Tensor overflow_flag, Tensor(a!) p, Tensor(b!) m, Tensor(c!) v, Tensor(d!) g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def("fused_adam_maybe_cast(Tensor overflow_flag, Tensor p_in, Tensor(a!) 
p_out) -> ()"); + m.def("fused_adam_maybe_cast_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_adam_strided_check_finite", &strided_check_finite_dispatch); + m.impl("fused_adam_adam", &adam_dispatch); + m.impl("fused_adam_reversible_adam", &reversible_adam_dispatch); + m.impl("fused_adam_adam_mt", &fused_adam_cuda_mt_dispatch); + m.impl("fused_adam_maybe_adam_undo", &maybe_adam_undo_dispatch); + m.impl("fused_adam_maybe_cast", &maybe_cast_dispatch); + m.impl("fused_adam_maybe_cast_mt", &maybe_cast_cuda_mt_dispatch); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp index b0aa92082..36cf86a69 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp @@ -1,11 +1,29 @@ -#include +#include +#include void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int bias_correction, const float weight_decay, const int grad_averaging, const int mode, const float global_grad_norm, const float max_grad_norm); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); +void multi_tensor_lamb_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double lr, double beta1, + double beta2, double epsilon, int64_t step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, double global_grad_norm, + double max_grad_norm) { + multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, tensor_lists, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), 
static_cast(mode), static_cast(global_grad_norm), + static_cast(max_grad_norm)); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_lamb_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, float global_grad_norm, float max_grad_norm) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_lamb_lamb", &multi_tensor_lamb_dispatch); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 29592a6af..d8e443152 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor grad_scale, float lr, @@ -16,17 +16,53 @@ void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tens float eps, int step, int mode, int bias_correction, float weight_decay); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, - "CUDA kernels for multi-tensor Adam, " - "with param copy", - py::call_guard()); - m.def("multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_cuda, - "CUDA kernels for multi-tensor Adam, " - "with param copy, capturable for CUDA graph", - py::call_guard()); - m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, - "CUDA kernel for multi-tensor Adam, " - "with stored param remainders and param copy", - py::call_guard()); +void multi_tensor_fused_adam_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + double lr, double beta1, double beta2, double eps, int64_t step, int64_t mode, + int64_t bias_correction, 
double weight_decay) { + multi_tensor_fused_adam_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(eps), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay)); +} + +void multi_tensor_fused_adam_capturable_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, at::Tensor lr, double beta1, double beta2, + double eps, at::Tensor step, int64_t mode, int64_t bias_correction, + double weight_decay) { + multi_tensor_fused_adam_capturable_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, lr, + static_cast(beta1), static_cast(beta2), + static_cast(eps), step, static_cast(mode), + static_cast(bias_correction), static_cast(weight_decay)); +} + +void multi_tensor_fused_adam_with_param_remainders_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, double lr, double beta1, + double beta2, double eps, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { + multi_tensor_fused_adam_with_param_remainders_cuda( + static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(eps), static_cast(step), + static_cast(mode), static_cast(bias_correction), static_cast(weight_decay)); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("distributed_adam_multi_tensor_fused_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, " + "int bias_correction, float weight_decay) -> ()"); + m.def("distributed_adam_multi_tensor_fused_adam_capturable(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, Tensor lr, float beta1, float beta2, float eps, " + "Tensor step, int mode, int bias_correction, float weight_decay) -> ()"); + 
m.def("distributed_adam_multi_tensor_fused_adam_with_param_remainders(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, float lr, float beta1, float beta2, float eps, " + "int step, int mode, int bias_correction, float weight_decay) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("distributed_adam_multi_tensor_fused_adam", &multi_tensor_fused_adam_dispatch); + m.impl("distributed_adam_multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_dispatch); + m.impl("distributed_adam_multi_tensor_fused_adam_with_param_remainders", + &multi_tensor_fused_adam_with_param_remainders_dispatch); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index f74bebfc2..b4a71514e 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, @@ -14,9 +14,42 @@ void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, at::Tensor update_norm_offset, at::Tensor learning_rate, at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, - "Computes update term for LAMB optimizer", py::call_guard()); - m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, - "Applies update term for LAMB optimizer", py::call_guard()); +void multi_tensor_lamb_compute_update_term_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, + at::Tensor per_tensor_beta3, + at::Tensor per_tensor_bias_correction, at::Tensor step, + at::Tensor per_tensor_epsilon, int64_t 
mode, + at::Tensor per_tensor_decay, at::Tensor global_scale, + at::Tensor global_grad_norm, double max_grad_norm) { + multi_tensor_lamb_compute_update_term_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_beta1, + per_tensor_beta2, per_tensor_beta3, per_tensor_bias_correction, step, + per_tensor_epsilon, static_cast(mode), per_tensor_decay, + global_scale, global_grad_norm, static_cast(max_grad_norm)); +} + +void multi_tensor_lamb_update_weights_dispatch(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, + at::Tensor update_norm_offset, at::Tensor learning_rate, + at::Tensor per_tensor_decay, at::Tensor global_grad_norm, + bool use_nvlamb) { + multi_tensor_lamb_update_weights_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_param_norm, + per_tensor_update_norm, update_norm_offset, learning_rate, per_tensor_decay, + global_grad_norm, use_nvlamb); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("distributed_lamb_compute_update_term(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_beta1, Tensor per_tensor_beta2, Tensor per_tensor_beta3, " + "Tensor per_tensor_bias_correction, Tensor step, Tensor per_tensor_epsilon, int mode, " + "Tensor per_tensor_decay, Tensor global_scale, Tensor global_grad_norm, float max_grad_norm) -> ()"); + m.def("distributed_lamb_update_weights(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, Tensor update_norm_offset, " + "Tensor learning_rate, Tensor per_tensor_decay, Tensor global_grad_norm, bool use_nvlamb) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("distributed_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_dispatch); + m.impl("distributed_lamb_update_weights", &multi_tensor_lamb_update_weights_dispatch); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory.cpp 
b/apex/contrib/csrc/peer_memory/peer_memory.cpp index bc19e6206..52b2f53d4 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory.cpp +++ b/apex/contrib/csrc/peer_memory/peer_memory.cpp @@ -15,22 +15,66 @@ */ #include "peer_memory_cuda.cuh" +#include -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", - py::call_guard()); - m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard()); - m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); - m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", - py::call_guard()); - m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", - py::call_guard()); - m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", - py::call_guard()); - m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", - py::call_guard()); - m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", - py::call_guard()); - m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", - py::call_guard()); +namespace { +std::vector apex_peer_memory_get_raw_peers(at::Tensor ipc_addresses, int64_t peer_rank, int64_t raw) { + return apex::contrib::peer_memory::get_raw_peers(ipc_addresses, static_cast(peer_rank), raw); +} + +at::Tensor apex_peer_memory_blob_view_half(int64_t raw, at::IntArrayRef shape, bool channels_last) { + return apex::contrib::peer_memory::blob_view_half(raw, std::vector(shape.begin(), shape.end()), + channels_last); +} + +at::Tensor apex_peer_memory_blob_view_float(int64_t raw, at::IntArrayRef shape, bool channels_last) { + return apex::contrib::peer_memory::blob_view_float(raw, std::vector(shape.begin(), shape.end()), + channels_last); +} + +at::Tensor apex_peer_memory_blob_view_int(int64_t raw, at::IntArrayRef shape, 
bool channels_last) { + return apex::contrib::peer_memory::blob_view_int(raw, std::vector(shape.begin(), shape.end()), + channels_last); +} + +void apex_peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int64_t numSM, int64_t rank, + bool top_zero, at::Tensor top_in_halo, at::Tensor top_in_transfer, + at::Tensor top_out_transfer, at::Tensor top_out_halo, bool btm_zero, + at::Tensor btm_in_halo, at::Tensor btm_in_transfer, + at::Tensor btm_out_transfer, at::Tensor btm_out_halo) { + apex::contrib::peer_memory::push_pull_halos_1d(diagnostics, explicit_nhwc, static_cast(numSM), + static_cast(rank), top_zero, top_in_halo, top_in_transfer, + top_out_transfer, top_out_halo, btm_zero, btm_in_halo, + btm_in_transfer, btm_out_transfer, btm_out_halo); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("peer_memory_allocate_raw(int size) -> int"); + m.def("peer_memory_free_raw(int raw) -> ()"); + m.def("peer_memory_zero(int raw, int size) -> ()"); + m.def("peer_memory_get_raw_ipc_address(int raw) -> Tensor"); + m.def("peer_memory_get_raw_peers(Tensor ipc_addresses, int peer_rank, int raw) -> int[]"); + m.def("peer_memory_blob_view_half(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def("peer_memory_blob_view_float(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def("peer_memory_blob_view_int(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def("peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int numSM, int rank, bool top_zero, " + "Tensor(a!) top_in_halo, Tensor(b!) top_in_transfer, Tensor(c!) top_out_transfer, Tensor(d!) top_out_halo, " + "bool btm_zero, Tensor(e!) btm_in_halo, Tensor(f!) btm_in_transfer, Tensor(g!) btm_out_transfer, " + "Tensor(h!) 
btm_out_halo) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("peer_memory_allocate_raw", &apex::contrib::peer_memory::allocate_raw); + m.impl("peer_memory_free_raw", &apex::contrib::peer_memory::free_raw); + m.impl("peer_memory_zero", &apex::contrib::peer_memory::zero); + m.impl("peer_memory_get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address); + m.impl("peer_memory_get_raw_peers", &apex_peer_memory_get_raw_peers); + m.impl("peer_memory_blob_view_half", &apex_peer_memory_blob_view_half); + m.impl("peer_memory_blob_view_float", &apex_peer_memory_blob_view_float); + m.impl("peer_memory_blob_view_int", &apex_peer_memory_blob_view_int); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("peer_memory_push_pull_halos_1d", &apex_peer_memory_push_pull_halos_1d); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu index 974ed5d64..3a43c860f 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu @@ -1,7 +1,8 @@ +#include #include +#include #include #include -#include #include #include @@ -51,7 +52,7 @@ at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOpt size *= sizeof(T); // TODO: Implement dynamic reuse of pooled peer memory. // We provide no deleter function because all peer memory allocations are static in this implementation. 
- return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); + return at::from_blob((void*)raw_ptr, shape, strides, options); } void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) { @@ -330,7 +331,7 @@ at::Tensor get_raw_ipc_address(int64_t raw) { cudaIpcMemHandle_t mem_handle; CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw)); const int n = sizeof(cudaIpcMemHandle_t); - auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); + auto address_tensor = at::empty({n}, at::TensorOptions().dtype(at::kByte)); auto address_tensor_p = address_tensor.data_ptr(); memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); return address_tensor; @@ -354,15 +355,16 @@ std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6 } at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); + return blob_view((at::Half*)raw, shape, at::TensorOptions().dtype(at::kHalf).device(at::kCUDA), + channels_last); } at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); + return blob_view((float*)raw, shape, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA), channels_last); } at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); + return blob_view((int*)raw, shape, at::TensorOptions().dtype(at::kInt).device(at::kCUDA), channels_last); } void push_pull_halos_1d( diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 29cf7a108..92a7fd876 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -15,7 +15,8 
@@ */ #pragma once -#include +#include +#include #ifndef _peer_memory_h_ #define _peer_memory_h_ diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 1175c1676..3dd0b7d0a 100644 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -1,5 +1,5 @@ -#include -#include +#include +#include #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") @@ -7,19 +7,19 @@ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize); - -std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale); - -std::vector transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, - int opt, bool packOutput, bool relu, bool dropout, - float dropoutProb, int tileSize) { +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, + int64_t packedBatch, int opt, bool packOutput, bool relu, + bool dropout, float dropoutProb, int tileSize); + +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale); + +std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, int64_t packedBatch, + int opt, bool packOutput, bool relu, bool dropout, + float dropoutProb, int tileSize) { 
CHECK_INPUT(f); CHECK_INPUT(g); CHECK_INPUT(fLen); @@ -29,21 +29,43 @@ std::vector transducer_joint_forward(torch::Tensor f, torch::Tens dropoutProb, tileSize); } -std::vector transducer_joint_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { +std::vector transducer_joint_backward(std::vector in, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale) { for (auto t : in) { CHECK_INPUT(t); } CHECK_INPUT(fLen); CHECK_INPUT(gLen); if (packOutput) CHECK_INPUT(batchOffset); - return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", - py::call_guard()); - m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", - py::call_guard()); -} + return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); +} + +std::vector transducer_joint_forward_dispatch(at::Tensor f, at::Tensor g, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, + int64_t packedBatch, int64_t opt, bool packOutput, + bool relu, bool dropout, double dropoutProb, + int64_t tileSize) { + return transducer_joint_forward(f, g, fLen, gLen, batchOffset, packedBatch, static_cast(opt), packOutput, relu, + dropout, static_cast(dropoutProb), static_cast(tileSize)); +} + +std::vector transducer_joint_backward_dispatch(std::vector in, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, + int64_t maxFLen, int64_t maxGLen, bool packOutput, + double scale) { + return transducer_joint_backward(in, fLen, gLen, batchOffset, static_cast(maxFLen), static_cast(maxGLen), + packOutput, static_cast(scale)); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("transducer_joint_forward(Tensor f, Tensor g, Tensor fLen, 
Tensor gLen, Tensor batchOffset, int packedBatch, " + "int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) -> Tensor[]"); + m.def("transducer_joint_backward(Tensor[] input, Tensor fLen, Tensor gLen, Tensor batchOffset, int maxFLen, " + "int maxGLen, bool packOutput, float scale) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("transducer_joint_forward", &transducer_joint_forward_dispatch); + m.impl("transducer_joint_backward", &transducer_joint_backward_dispatch); +} diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index acb9f9e9e..c6df4748c 100644 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -1,8 +1,8 @@ -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #ifdef OLD_GENERATOR_PATH #include @@ -179,7 +179,7 @@ __global__ void transducer_joint_forward(const scalar_t* f, const scalar_t* g, c } } else if (packOutput == false and t < maxFLen and u < maxGLen) { // Need to write finite data to don't-care region because we instantiate the result tensor -// with torch::empty for performance reasons. Even though it is don't-care region, the +// with at::empty for performance reasons. Even though it is don't-care region, the // contents need to be finite, otherwise could lead to NaN in WGRAD. // In packing mode, this write is no longer necessary as we remove the don't-care region // from the output. 
@@ -535,10 +535,10 @@ __global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, con } } -std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize) { +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, + int64_t packedBatch, int opt, bool packOutput, bool relu, + bool dropout, float dropoutProb, int tileSize) { auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); @@ -548,17 +548,17 @@ std::vector transducer_joint_cuda_forward(torch::Tensor f, torch: bool masked = dropout or relu; int64_t* batchOffsetPtr = nullptr; - torch::Tensor sum, mask; - auto maskOpt = tensorOpt.dtype(torch::kUInt8); - if (!packOutput) { - sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); - batchOffsetPtr = nullptr; - if (masked) mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); - } else { - sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); - batchOffsetPtr = batchOffset.data_ptr(); - if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt); - } + at::Tensor sum, mask; + auto maskOpt = tensorOpt.dtype(at::kByte); + if (!packOutput) { + sum = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); + batchOffsetPtr = nullptr; + if (masked) mask = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); + } else { + sum = at::empty({packedBatch, hiddenSize}, tensorOpt); + batchOffsetPtr = batchOffset.data_ptr(); + if (masked) mask = at::empty({packedBatch, hiddenSize}, maskOpt); + } uint8_t* maskPtr = masked ? 
mask.data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -649,9 +649,9 @@ std::vector transducer_joint_cuda_forward(torch::Tensor f, torch: return {sum}; } -std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, + at::Tensor gLen, at::Tensor batchOffset, int maxFLen, + int maxGLen, bool packOutput, float scale) { auto grad = in[0]; bool masked = (in.size() == 2); uint8_t* maskPtr = masked ? in[1].data_ptr() : nullptr; @@ -664,8 +664,8 @@ std::vector transducer_joint_cuda_backward(std::vectormaxThreadsPerBlock / C10_WARP_SIZE; - torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); - torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); + at::Tensor fGrad = at::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); + at::Tensor gGrad = at::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); int64_t* batchOffsetPtr = (!packOutput) ? 
nullptr : batchOffset.data_ptr(); diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index 4c124edac..64de84271 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include @@ -8,17 +9,17 @@ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, - torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, +std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, + at::Tensor txtLen, at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput); -torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, - torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, - torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, + at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, + at::Tensor label, at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput); -std::vector transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen, - torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen, +std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor fLen, + at::Tensor yLen, at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); @@ -28,9 +29,9 @@ std::vector transducer_loss_forward(torch::Tensor x, torch::Tenso return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput); } -torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, - torch::Tensor 
fLen, torch::Tensor yLen, torch::Tensor label, - torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, +at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor fLen, at::Tensor yLen, at::Tensor label, + at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); @@ -45,9 +46,33 @@ torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, fuseSoftmaxBackward, packedInput); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", - py::call_guard()); - m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", - py::call_guard()); +std::vector transducer_loss_forward_dispatch(at::Tensor x, at::Tensor label, at::Tensor fLen, + at::Tensor yLen, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, + bool packedInput) { + return transducer_loss_forward(x, label, fLen, yLen, batchOffset, static_cast(maxFLen), + static_cast(blankIdx), static_cast(opt), packedInput); +} + +at::Tensor transducer_loss_backward_dispatch(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, + at::Tensor beta, at::Tensor fLen, at::Tensor yLen, + at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput) { + return transducer_loss_backward(x, lossGrad, alpha, beta, fLen, yLen, label, batchOffset, static_cast(maxFLen), + static_cast(blankIdx), static_cast(opt), fuseSoftmaxBackward, + packedInput); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("transducer_loss_forward(Tensor x, Tensor label, Tensor fLen, Tensor yLen, Tensor batchOffset, int maxFLen, " + "int blankIdx, int opt, bool packedInput) -> Tensor[]"); + m.def("transducer_loss_backward(Tensor x, Tensor lossGrad, Tensor alpha, Tensor beta, Tensor fLen, Tensor yLen, " + "Tensor label, 
Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, " + "bool packedInput) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("transducer_loss_forward", &transducer_loss_forward_dispatch); + m.impl("transducer_loss_backward", &transducer_loss_backward_dispatch); } diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index f6f7e4ca0..b7613f8b9 100644 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -3,7 +3,6 @@ #include #include #include -#include #include @@ -455,8 +454,8 @@ __global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scal } } -std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, - torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, +std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, + at::Tensor txtLen, at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool packedInput) { auto scalarType = x.scalar_type(); auto tensorOpt = x.options(); @@ -470,9 +469,9 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: // The data type of alpha and beta will be resolved at dispatch time, // hence defined here and assigned later - torch::Tensor alpha; - torch::Tensor beta; - torch::Tensor loss = torch::empty({batchSize}, tensorOpt); + at::Tensor alpha; + at::Tensor beta; + at::Tensor loss = at::empty({batchSize}, tensorOpt); const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock; @@ -485,8 +484,8 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: using acc_t = at::acc_type; auto accType = c10::CppTypeToScalarType::value; auto accTensorOpt = tensorOpt.dtype(accType); - alpha 
= torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); - beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); + alpha = at::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); + beta = at::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); // decide what kernel to launch based on the problem size // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla @@ -514,12 +513,12 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: return {alpha, beta, loss}; } -torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, - torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, - torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, + at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, + at::Tensor label, at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, bool packedInput) { auto dtype = x.scalar_type(); - torch::Tensor xGrad; + at::Tensor xGrad; const int batchSize = label.size(0); const int maxGLen = label.size(1) + 1; const int dictSize = x.size(-1); @@ -532,7 +531,7 @@ torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossG if (fuseSoftmaxBackward) { // alloc empty tensors for performance, hence need to ensure zeros are writtern to // don't-care region in the kernel. - xGrad = torch::empty_like(x); + xGrad = at::empty_like(x); // Would like each thread to work on 4 hidden units const int workPerThread = 4; @@ -566,7 +565,7 @@ torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossG } else { // for non-fused kernel, the gradients need to be writtern are very sparse, hence initialize // the tensor with all zeros. - xGrad = torch::zeros_like(x); + xGrad = at::zeros_like(x); // don't launch more threads than needed. 
const int threads = std::min(maxThreadPerBlock, maxGLen); const dim3 blocks(maxFLen, batchSize); diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp index 76cd379f5..aaf9c0c6c 100644 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ b/apex/contrib/csrc/xentropy/interface.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include @@ -38,18 +39,37 @@ at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tens return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", - py::call_guard()); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", - py::call_guard()); - // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables - py::object version = py::cast( +std::vector softmax_xentropy_forward_dispatch(const at::Tensor& input, const at::Tensor& labels, + double smoothing, bool half_to_float) { + return softmax_xentropy_forward(input, labels, static_cast(smoothing), half_to_float); +} + +at::Tensor softmax_xentropy_backward_dispatch(const at::Tensor& grad_loss, const at::Tensor& logits, + const at::Tensor& max_log_sum_exp, const at::Tensor& labels, + double smoothing) { + return softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, static_cast(smoothing)); +} + +std::string softmax_xentropy_version() { #ifdef XENTROPY_VER - XENTROPY_VER + return XENTROPY_VER; #else - std::string{} + return {}; #endif - ); - m.attr("__version__") = version; +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("xentropy_forward(Tensor input, Tensor labels, float smoothing, bool half_to_float) -> Tensor[]"); + m.def("xentropy_backward(Tensor grad_loss, Tensor logits, Tensor max_log_sum_exp, Tensor labels, " + "float smoothing) -> Tensor"); + 
m.def("xentropy_version() -> str"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("xentropy_forward", &softmax_xentropy_forward_dispatch); + m.impl("xentropy_backward", &softmax_xentropy_backward_dispatch); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("xentropy_version", &softmax_xentropy_version); } diff --git a/apex/contrib/cudnn_gbn/batch_norm.py b/apex/contrib/cudnn_gbn/batch_norm.py index 8346b74aa..854cbb4d5 100644 --- a/apex/contrib/cudnn_gbn/batch_norm.py +++ b/apex/contrib/cudnn_gbn/batch_norm.py @@ -2,8 +2,8 @@ from torch.nn.modules.batchnorm import _BatchNorm from torch.nn import functional as F from torch import Tensor -import peer_memory_cuda as pm -import cudnn_gbn_lib +from apex._extensions import peer_memory_cuda as pm +from apex._extensions import cudnn_gbn_lib from torch.cuda.amp import custom_fwd, custom_bwd diff --git a/apex/contrib/fmha/fmha.py b/apex/contrib/fmha/fmha.py index 4ebe97e93..c4ae69a8d 100644 --- a/apex/contrib/fmha/fmha.py +++ b/apex/contrib/fmha/fmha.py @@ -27,7 +27,7 @@ import torch -import fmhalib as mha +from apex._extensions import fmhalib as mha class FMHAFun(torch.autograd.Function): diff --git a/apex/contrib/focal_loss/__init__.py b/apex/contrib/focal_loss/__init__.py index 2a187d029..ce459f28d 100644 --- a/apex/contrib/focal_loss/__init__.py +++ b/apex/contrib/focal_loss/__init__.py @@ -1,6 +1,6 @@ try: import torch - import focal_loss_cuda + from apex._extensions import focal_loss_cuda from .focal_loss import focal_loss del torch diff --git a/apex/contrib/focal_loss/focal_loss.py b/apex/contrib/focal_loss/focal_loss.py index 85c6f620e..295deb477 100644 --- a/apex/contrib/focal_loss/focal_loss.py +++ b/apex/contrib/focal_loss/focal_loss.py @@ -1,6 +1,6 @@ import torch -import focal_loss_cuda +from apex._extensions import focal_loss_cuda class FocalLoss(torch.autograd.Function): diff --git a/apex/contrib/group_norm/group_norm.py b/apex/contrib/group_norm/group_norm.py index 
998a220ff..0aa7ab902 100644 --- a/apex/contrib/group_norm/group_norm.py +++ b/apex/contrib/group_norm/group_norm.py @@ -10,8 +10,8 @@ import os import torch import torch.nn.init as init -import group_norm_cuda -import group_norm_v2_cuda +from apex._extensions import group_norm_cuda +from apex._extensions import group_norm_v2_cuda from torch import Tensor from torch.nn.parameter import Parameter diff --git a/apex/contrib/groupbn/__init__.py b/apex/contrib/groupbn/__init__.py index 4af4ac595..436f24ce4 100644 --- a/apex/contrib/groupbn/__init__.py +++ b/apex/contrib/groupbn/__init__.py @@ -1,6 +1,6 @@ try: import torch - import bnp + from apex._extensions import bnp from .batch_norm import BatchNorm2d_NHWC del torch diff --git a/apex/contrib/groupbn/batch_norm.py b/apex/contrib/groupbn/batch_norm.py index d5d71cd0f..6f78c2af2 100644 --- a/apex/contrib/groupbn/batch_norm.py +++ b/apex/contrib/groupbn/batch_norm.py @@ -2,7 +2,7 @@ import numpy as np from torch.nn.modules.batchnorm import _BatchNorm -import bnp +from apex._extensions import bnp class bn_NHWC_impl(torch.autograd.Function): diff --git a/apex/contrib/index_mul_2d/index_mul_2d.py b/apex/contrib/index_mul_2d/index_mul_2d.py index ab628b05d..47fe8c7a7 100644 --- a/apex/contrib/index_mul_2d/index_mul_2d.py +++ b/apex/contrib/index_mul_2d/index_mul_2d.py @@ -1,6 +1,6 @@ import torch -import fused_index_mul_2d +from apex._extensions import fused_index_mul_2d class IndexMul2d_(torch.autograd.Function): diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index ef134667d..7b962a307 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -2,7 +2,7 @@ from torch.nn import init from apex._autocast_utils import _cast_if_autocast_enabled -import fast_layer_norm +from apex._extensions import fast_layer_norm class FastLayerNormFN(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py 
b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py index 7f5fe1994..bfd8958fe 100644 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastEncdecAttnFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py index da7ef57dc..00c634bc8 100644 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py +++ b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py @@ -7,7 +7,7 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastEncdecAttnNormAddFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py index a03cbc518..bfea175b9 100644 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastSelfAttnFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py index 5cbae97ec..64546f86e 100644 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py +++ b/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastSelfAttnNormAddFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py 
b/apex/contrib/multihead_attn/mask_softmax_dropout_func.py index 995e68e51..64d0542c4 100644 --- a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py +++ b/apex/contrib/multihead_attn/mask_softmax_dropout_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class MaskSoftmaxDropout(torch.autograd.Function): diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index b1f6b58fc..0a5abd140 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -28,8 +28,8 @@ nccl_allocator = None from apex.multi_tensor_apply import multi_tensor_applier -import amp_C -import distributed_adam_cuda +from apex._extensions import amp_C +from apex._extensions import distributed_adam_cuda # Fallback to private functions if using PyTorch <1.13.0 try: @@ -126,7 +126,7 @@ def _coalescing_manager_append_work( # Import optional CUDA kernels _FOUND_DEPRECATED_FUSED_ADAM: bool = False try: - import fused_adam_cuda + from apex._extensions import fused_adam_cuda _FOUND_DEPRECATED_FUSED_ADAM = True except ImportError: diff --git a/apex/contrib/optimizers/distributed_fused_lamb.py b/apex/contrib/optimizers/distributed_fused_lamb.py index 93b964fbb..b8013dffd 100644 --- a/apex/contrib/optimizers/distributed_fused_lamb.py +++ b/apex/contrib/optimizers/distributed_fused_lamb.py @@ -2,7 +2,7 @@ import inspect import torch import importlib -import amp_C +from apex._extensions import amp_C from apex.multi_tensor_apply import multi_tensor_applier import torch.distributed.distributed_c10d as c10d @@ -140,8 +140,8 @@ def __init__( super(DistributedFusedLAMB, self).__init__(params, defaults) global fused_adam_cuda, distributed_lamb_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda") + fused_adam_cuda = 
importlib.import_module("apex._extensions.fused_adam_cuda") + distributed_lamb_cuda = importlib.import_module("apex._extensions.distributed_lamb_cuda") self._overflow_buf = torch.cuda.IntTensor([0]) self._has_overflow = False @@ -151,7 +151,7 @@ def __init__( self.multi_tensor_lamb_update_weights = ( distributed_lamb_cuda.multi_tensor_lamb_update_weights ) - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm diff --git a/apex/contrib/optimizers/fp16_optimizer.py b/apex/contrib/optimizers/fp16_optimizer.py index 856a181dc..3f5788a0e 100755 --- a/apex/contrib/optimizers/fp16_optimizer.py +++ b/apex/contrib/optimizers/fp16_optimizer.py @@ -56,7 +56,7 @@ def __init__( param_group["params"] = fp32_group if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm diff --git a/apex/contrib/optimizers/fused_adam.py b/apex/contrib/optimizers/fused_adam.py index 37a77b3cf..1ee075d8d 100644 --- a/apex/contrib/optimizers/fused_adam.py +++ b/apex/contrib/optimizers/fused_adam.py @@ -50,7 +50,7 @@ def __init__( amp_scale_adjustment=1.0, ): global fused_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") + fused_adam_cuda = importlib.import_module("apex._extensions.fused_adam_cuda") self._use_multi_tensor = False if use_mt: diff --git a/apex/contrib/optimizers/fused_lamb.py b/apex/contrib/optimizers/fused_lamb.py index d40cfeab7..9e41cacc6 100644 --- a/apex/contrib/optimizers/fused_lamb.py +++ b/apex/contrib/optimizers/fused_lamb.py @@ -87,11 +87,11 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - fused_lamb_cuda = importlib.import_module("fused_lamb_cuda") + 
fused_lamb_cuda = importlib.import_module("apex._extensions.fused_lamb_cuda") self.multi_tensor_lamb = fused_lamb_cuda.lamb else: raise RuntimeError("apex.contrib.optimizers.FusedLAMB requires cuda extensions") diff --git a/apex/contrib/optimizers/fused_sgd.py b/apex/contrib/optimizers/fused_sgd.py index e2acfcbaa..6fefda7bf 100644 --- a/apex/contrib/optimizers/fused_sgd.py +++ b/apex/contrib/optimizers/fused_sgd.py @@ -96,7 +96,7 @@ def __init__( self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py index 8995e806d..3512ac8c2 100644 --- a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py +++ b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py @@ -1,5 +1,5 @@ import torch -import peer_memory_cuda as pm +from apex._extensions import peer_memory_cuda as pm class PeerHaloExchanger1d: diff --git a/apex/contrib/peer_memory/peer_memory.py b/apex/contrib/peer_memory/peer_memory.py index 72b1a1098..9717b33c3 100644 --- a/apex/contrib/peer_memory/peer_memory.py +++ b/apex/contrib/peer_memory/peer_memory.py @@ -1,6 +1,6 @@ import torch import numpy as np -import peer_memory_cuda as pm +from apex._extensions import peer_memory_cuda as pm class PeerMemoryPool(object): diff --git a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu index 4d690ee56..caf3aa62e 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu +++ b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu @@ -1,7 +1,8 @@ -#include -#include +#include +#include +#include +#include #include -namespace py = pybind11; #define gpuErrchk(ans) 
\ { \ @@ -36,9 +37,27 @@ __device__ float group_2_to_4(float4 vals) { return best_sum; } -inline float* float_ptr_from_numpy(py::array_t& py_float) { return (float*)py_float.data(); } +inline void check_cpu_contiguous(torch::stable::Tensor const& tensor, const char* name) { + STD_TORCH_CHECK(tensor.is_cpu(), name, " must be a CPU tensor"); + STD_TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); +} + +inline float* float_ptr_from_tensor(torch::stable::Tensor const& tensor, const char* name) { + check_cpu_contiguous(tensor, name); + STD_TORCH_CHECK(tensor.scalar_type() == torch::headeronly::ScalarType::Float, name, " must have dtype float32"); + return static_cast(tensor.mutable_data_ptr()); +} + +inline unsigned int* uint_ptr_from_tensor(torch::stable::Tensor const& tensor, const char* name) { + check_cpu_contiguous(tensor, name); + STD_TORCH_CHECK(tensor.element_size() == sizeof(unsigned int), name, " must have 32-bit elements"); + return reinterpret_cast(tensor.mutable_data_ptr()); +} -inline unsigned int* uint_ptr_from_numpy(py::array_t& py_uint) { return (unsigned int*)py_uint.data(); } +inline unsigned int as_uint_arg(int64_t value, const char* name) { + STD_TORCH_CHECK(value >= 0 && value <= std::numeric_limits::max(), name, " is out of range"); + return static_cast(value); +} /********************************************************** * Check for the best permutation for an entire matrix @@ -136,67 +155,6 @@ int set_up_check_permutation_memory(float** dmatrix, unsigned int rows, unsigned return fresh_alloc; } -int run_check_permutations( - py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& - py_stripe_groups, // groups of stripes, group_width = stripes per group, num_groups = groups in the array - unsigned int group_width, unsigned int num_groups, - py::array_t& py_permutations, // array of permutations to try, group_width*4 values per each of - // num_permutations permutations - unsigned int num_permutations, - 
py::array_t& py_improvement, // improvment offered by the best permutation - py::array_t& py_permutation // the best permutation -) { - const unsigned int threads = 32; - static float* d_matrix; - static unsigned int* d_permutations; - static unsigned int* d_stripes; - static float* d_results; - static float* results; - - float* matrix = float_ptr_from_numpy(py_matrix); - unsigned int* stripe_groups = uint_ptr_from_numpy(py_stripe_groups); - unsigned int* permutations = uint_ptr_from_numpy(py_permutations); - float* improvement = float_ptr_from_numpy(py_improvement); - unsigned int* permutation = uint_ptr_from_numpy(py_permutation); - - int fresh_alloc = set_up_check_permutation_memory(&d_matrix, rows, cols, &d_stripes, group_width, num_groups, - &d_permutations, num_permutations, &d_results, &results); - if (fresh_alloc == 1) { - gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * group_width * 4 * sizeof(unsigned int), - cudaMemcpyHostToDevice)); - gpuErrchk( - cudaMemcpy(d_stripes, stripe_groups, group_width * num_groups * sizeof(unsigned int), cudaMemcpyHostToDevice)); - } - - // initialize results, new matrix - gpuErrchk(cudaMemset(d_results, 0, num_permutations * sizeof(float))); - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - - // get results for all permutations - permute_and_sum_after_2_to_4<<>>( - d_matrix, rows, cols, d_stripes, group_width, d_permutations, d_results); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy(results, d_results, num_permutations * sizeof(float), cudaMemcpyDeviceToHost)); - - // find the best permutation - could reduce on GPU - unsigned int best_permutation = 0; - float best_improvement = 0.0f; - for (unsigned int p = 1; p < num_permutations; ++p) { - float cur_improvement = results[p] - results[0]; - if (best_improvement < cur_improvement) { - best_permutation = p; - best_improvement = cur_improvement; - } - } - - *improvement = best_improvement; - 
*permutation = best_permutation; - - return 0; -} - /////////////////////////////////////////////////////////// /********************************************************** @@ -337,28 +295,6 @@ int set_up_sum_after_2_to_4_memory(float** dmatrix, unsigned int rows, unsigned return fresh_allocation; } -int run_subset_sum_after_2_to_4(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - unsigned int start_col, unsigned int end_col, unsigned int blocks, unsigned int threads, - py::array_t& py_output) { - static float* d_matrix; - static float* d_result; - - int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result); - - float* matrix = float_ptr_from_numpy(py_matrix); - float* output = float_ptr_from_numpy(py_output); - - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemset(d_result, 0, sizeof(float))); - - subset_sum_after_2_to_4<<>>(d_matrix, rows, cols, start_col, end_col, d_result); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy(output, d_result, sizeof(float), cudaMemcpyDeviceToHost)); - - return 0; -} - void set_up_permute_map_memory(float** dmatrix, unsigned int rows, unsigned int cols, unsigned int** dstripes, unsigned int num_groups, unsigned int group_width, unsigned int** dpermutations, unsigned int num_permutations, unsigned int perm_length, float** doutput, @@ -428,68 +364,6 @@ void set_up_permute_map_memory(float** dmatrix, unsigned int rows, unsigned int setUpPermLength = perm_length; } -int run_build_permute_map(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& py_stripes, unsigned int num_groups, unsigned int group_width, - py::array_t& py_permutations, unsigned int perm_length, - py::array_t& py_improvements, py::array_t& py_best_indices) { - static float* d_matrix = NULL; - static unsigned int* d_stripes = NULL; - static unsigned int* d_permutations = NULL; - static float* d_output = NULL; - static unsigned 
int* d_indices = NULL; - static float* hresult = NULL; - static unsigned int* hindices = NULL; - - const unsigned int num_permutations = py_permutations.size() / perm_length; - - const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40; - const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH; - const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH; - const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0); - - set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width, - &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices); - - float* matrix = float_ptr_from_numpy(py_matrix); - unsigned int* stripes = uint_ptr_from_numpy(py_stripes); - unsigned int* permutations = uint_ptr_from_numpy(py_permutations); - float* improvements = float_ptr_from_numpy(py_improvements); - unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices); - - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * perm_length * sizeof(unsigned int), - cudaMemcpyHostToDevice)); - - unsigned int group_offset = 0; - for (unsigned int l = 0; l < launches; ++l) { - unsigned int groups_this_launch = (l < full_launches) ? 
MAX_GROUPS_PER_LAUNCH : final_launch; - - gpuErrchk(cudaMemcpy(d_stripes, &stripes[group_offset * group_width], - groups_this_launch * group_width * sizeof(unsigned int), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemset(d_output, 0, groups_this_launch * num_permutations * sizeof(float))); - gpuErrchk(cudaMemset(d_indices, 0, groups_this_launch * sizeof(unsigned int))); - - unsigned int shmem = 32 * (32) * sizeof(float); - build_permute_map<<>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, - num_permutations, perm_length, d_output, d_indices); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk( - cudaMemcpy(hresult, d_output, num_permutations * groups_this_launch * sizeof(float), cudaMemcpyDeviceToHost)); - gpuErrchk(cudaMemcpy(hindices, d_indices, groups_this_launch * sizeof(unsigned int), cudaMemcpyDeviceToHost)); - - // thread0 stuck the minimum in the first slot of each group - for (unsigned int g = 0; g < groups_this_launch; ++g) { - improvements[group_offset + g] = hresult[g * num_permutations]; - best_indices[group_offset + g] = hindices[g]; - } - - group_offset += groups_this_launch; - } - - return 0; -} - /********************************************************** * Build the swap map for channel_swaps **********************************************************/ @@ -594,17 +468,178 @@ void set_up_swap_map_memory(float** dmatrix, unsigned int rows, unsigned int col } } -int run_build_swap_map(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& py_stripe_pairs, py::array_t& py_output) { +/////////////////////////////////////////////////////////// + +int64_t apex_permutation_search_check_permutations( + torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, int64_t cols_arg, + torch::stable::Tensor const& stripe_groups_tensor, int64_t group_width_arg, int64_t num_groups_arg, + torch::stable::Tensor const& permutations_tensor, int64_t num_permutations_arg, + torch::stable::Tensor const& improvement_tensor, 
torch::stable::Tensor const& permutation_tensor) { + static float* d_matrix; + static unsigned int* d_permutations; + static unsigned int* d_stripes; + static float* d_results; + static float* results; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int group_width = as_uint_arg(group_width_arg, "group_width"); + unsigned int num_groups = as_uint_arg(num_groups_arg, "num_groups"); + unsigned int num_permutations = as_uint_arg(num_permutations_arg, "num_permutations"); + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + unsigned int* stripe_groups = uint_ptr_from_tensor(stripe_groups_tensor, "stripe_groups"); + unsigned int* permutations = uint_ptr_from_tensor(permutations_tensor, "permutations"); + float* improvement = float_ptr_from_tensor(improvement_tensor, "improvement"); + unsigned int* permutation = uint_ptr_from_tensor(permutation_tensor, "permutation"); + + int fresh_alloc = set_up_check_permutation_memory(&d_matrix, rows, cols, &d_stripes, group_width, num_groups, + &d_permutations, num_permutations, &d_results, &results); + if (fresh_alloc == 1) { + gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * group_width * 4 * sizeof(unsigned int), + cudaMemcpyHostToDevice)); + gpuErrchk( + cudaMemcpy(d_stripes, stripe_groups, group_width * num_groups * sizeof(unsigned int), cudaMemcpyHostToDevice)); + } + + gpuErrchk(cudaMemset(d_results, 0, num_permutations * sizeof(float))); + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + + permute_and_sum_after_2_to_4<<>>( + d_matrix, rows, cols, d_stripes, group_width, d_permutations, d_results); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk(cudaMemcpy(results, d_results, num_permutations * sizeof(float), cudaMemcpyDeviceToHost)); + + unsigned int best_permutation = 0; + float best_improvement = 0.0f; + for (unsigned int p = 1; p < num_permutations; ++p) { + float 
cur_improvement = results[p] - results[0]; + if (best_improvement < cur_improvement) { + best_permutation = p; + best_improvement = cur_improvement; + } + } + + *improvement = best_improvement; + *permutation = best_permutation; + return 0; +} + +int64_t apex_permutation_search_sum_after_2_to_4(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, int64_t start_col_arg, int64_t end_col_arg, + int64_t blocks_arg, int64_t threads_arg, + torch::stable::Tensor const& output_tensor) { + static float* d_matrix; + static float* d_result; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int start_col = as_uint_arg(start_col_arg, "start_col"); + unsigned int end_col = as_uint_arg(end_col_arg, "end_col"); + unsigned int blocks = as_uint_arg(blocks_arg, "blocks"); + unsigned int threads = as_uint_arg(threads_arg, "threads"); + + int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result); + (void)fresh_allocation; + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + float* output = float_ptr_from_tensor(output_tensor, "output"); + + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemset(d_result, 0, sizeof(float))); + + subset_sum_after_2_to_4<<>>(d_matrix, rows, cols, start_col, end_col, d_result); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk(cudaMemcpy(output, d_result, sizeof(float), cudaMemcpyDeviceToHost)); + return 0; +} + +int64_t apex_permutation_search_build_permute_map( + torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, int64_t cols_arg, + torch::stable::Tensor const& stripes_tensor, int64_t num_groups_arg, int64_t group_width_arg, + torch::stable::Tensor const& permutations_tensor, int64_t perm_length_arg, + torch::stable::Tensor const& improvements_tensor, torch::stable::Tensor const& best_indices_tensor) { + static float* d_matrix = NULL; + 
static unsigned int* d_stripes = NULL; + static unsigned int* d_permutations = NULL; + static float* d_output = NULL; + static unsigned int* d_indices = NULL; + static float* hresult = NULL; + static unsigned int* hindices = NULL; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int num_groups = as_uint_arg(num_groups_arg, "num_groups"); + unsigned int group_width = as_uint_arg(group_width_arg, "group_width"); + unsigned int perm_length = as_uint_arg(perm_length_arg, "perm_length"); + STD_TORCH_CHECK(perm_length > 0, "perm_length must be positive"); + unsigned int num_permutations = as_uint_arg(permutations_tensor.numel() / perm_length, "num_permutations"); + + const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40; + const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH; + const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH; + const unsigned int launches = full_launches + (final_launch != 0 ? 
1 : 0); + + set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width, + &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, + &hindices); + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + unsigned int* stripes = uint_ptr_from_tensor(stripes_tensor, "stripes"); + unsigned int* permutations = uint_ptr_from_tensor(permutations_tensor, "permutations"); + float* improvements = float_ptr_from_tensor(improvements_tensor, "improvements"); + unsigned int* best_indices = uint_ptr_from_tensor(best_indices_tensor, "best_indices"); + + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * perm_length * sizeof(unsigned int), + cudaMemcpyHostToDevice)); + + unsigned int group_offset = 0; + for (unsigned int l = 0; l < launches; ++l) { + unsigned int groups_this_launch = (l < full_launches) ? 
MAX_GROUPS_PER_LAUNCH : final_launch; + + gpuErrchk(cudaMemcpy(d_stripes, &stripes[group_offset * group_width], + groups_this_launch * group_width * sizeof(unsigned int), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemset(d_output, 0, groups_this_launch * num_permutations * sizeof(float))); + gpuErrchk(cudaMemset(d_indices, 0, groups_this_launch * sizeof(unsigned int))); + + unsigned int shmem = 32 * (32) * sizeof(float); + build_permute_map<<>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, + num_permutations, perm_length, d_output, d_indices); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk( + cudaMemcpy(hresult, d_output, num_permutations * groups_this_launch * sizeof(float), cudaMemcpyDeviceToHost)); + gpuErrchk(cudaMemcpy(hindices, d_indices, groups_this_launch * sizeof(unsigned int), cudaMemcpyDeviceToHost)); + + for (unsigned int g = 0; g < groups_this_launch; ++g) { + improvements[group_offset + g] = hresult[g * num_permutations]; + best_indices[group_offset + g] = hindices[g]; + } + + group_offset += groups_this_launch; + } + + return 0; +} + +int64_t apex_permutation_search_build_swap_map(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, + torch::stable::Tensor const& stripe_pairs_tensor, + torch::stable::Tensor const& output_tensor) { static float* d_matrix = NULL; static float* d_result = NULL; static unsigned int* d_stripe_pairs = NULL; - float* matrix = float_ptr_from_numpy(py_matrix); //(float*)py_matrix.data(); - unsigned int* stripe_pairs = uint_ptr_from_numpy(py_stripe_pairs); //(unsigned int*)py_stripe_pairs.data(); - float* output = float_ptr_from_numpy(py_output); //(float*)py_output.data(); + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int num_pairs = as_uint_arg(stripe_pairs_tensor.numel() / 2, "num_pairs"); - unsigned int num_pairs = py_stripe_pairs.size() / 2; + float* matrix = float_ptr_from_tensor(matrix_tensor, 
"matrix"); + unsigned int* stripe_pairs = uint_ptr_from_tensor(stripe_pairs_tensor, "stripe_pairs"); + float* output = float_ptr_from_tensor(output_tensor, "output"); set_up_swap_map_memory(&d_matrix, rows, cols, &d_stripe_pairs, num_pairs, &d_result); gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); @@ -616,17 +651,25 @@ int run_build_swap_map(py::array_t& py_matrix, unsigned int rows, unsigne gpuErrchk(cudaDeviceSynchronize()); gpuErrchk(cudaMemcpy(output, d_result, num_pairs * 16 * sizeof(float), cudaMemcpyDeviceToHost)); - return 0; } -/////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("sum_after_2_to_4", &run_subset_sum_after_2_to_4, "matrix sum after applying 2:4 (CUDA)", - py::call_guard()); - m.def("build_permute_map", &run_build_permute_map, "optimize stripe groups (CUDA)", - py::call_guard()); - m.def("check_permutations", &run_check_permutations, "exhaustively check all permutations (CUDA)", - py::call_guard()); - m.def("build_swap_map", &run_build_swap_map, "channel swaps (CUDA)", py::call_guard()); +STABLE_TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("permutation_search_sum_after_2_to_4(Tensor matrix, int rows, int cols, int start_col, int end_col, " + "int blocks, int threads, Tensor(a!) output) -> int"); + m.def("permutation_search_build_permute_map(Tensor matrix, int rows, int cols, Tensor stripes, int num_groups, " + "int group_width, Tensor permutations, int perm_length, Tensor(a!) improvements, Tensor(b!) best_indices) " + "-> int"); + m.def("permutation_search_check_permutations(Tensor matrix, int rows, int cols, Tensor stripe_groups, " + "int group_width, int num_groups, Tensor permutations, int num_permutations, Tensor(a!) improvement, " + "Tensor(b!) permutation) -> int"); + m.def("permutation_search_build_swap_map(Tensor matrix, int rows, int cols, Tensor stripe_pairs, " + "Tensor(a!) 
output) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(apex, CPU, m) { + m.impl("permutation_search_sum_after_2_to_4", TORCH_BOX(&apex_permutation_search_sum_after_2_to_4)); + m.impl("permutation_search_build_permute_map", TORCH_BOX(&apex_permutation_search_build_permute_map)); + m.impl("permutation_search_check_permutations", TORCH_BOX(&apex_permutation_search_check_permutations)); + m.impl("permutation_search_build_swap_map", TORCH_BOX(&apex_permutation_search_build_swap_map)); } diff --git a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py index 7d253fa0b..04d3a010a 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py +++ b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py @@ -6,12 +6,12 @@ gpus_found = 0 kernels_found = True try: - import permutation_search_cuda as permutation_search_cuda_kernels + from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels print("Found permutation search CUDA kernels") except ImportError: try: - from . 
import permutation_search_cuda as permutation_search_cuda_kernels + from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels print("Found permutation search CUDA kernels for standalone testing") diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py index c85afa612..4e4dda0a0 100644 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ b/apex/contrib/test/layer_norm/test_fast_layer_norm.py @@ -6,7 +6,7 @@ SKIP_TEST = None try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm - import fast_layer_norm as fln + from apex._extensions import fast_layer_norm as fln except ImportError as e: SKIP_TEST = e diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py index bb53d39a6..c4c4f9103 100755 --- a/apex/contrib/transducer/transducer.py +++ b/apex/contrib/transducer/transducer.py @@ -1,6 +1,6 @@ import torch -import transducer_loss_cuda -import transducer_joint_cuda +from apex._extensions import transducer_loss_cuda +from apex._extensions import transducer_joint_cuda class TransducerJoint(torch.nn.Module): diff --git a/apex/contrib/xentropy/softmax_xentropy.py b/apex/contrib/xentropy/softmax_xentropy.py index 528f743b0..78fb44d10 100644 --- a/apex/contrib/xentropy/softmax_xentropy.py +++ b/apex/contrib/xentropy/softmax_xentropy.py @@ -1,6 +1,6 @@ import torch -import xentropy_cuda +from apex._extensions import xentropy_cuda class SoftmaxCrossEntropyLoss(torch.autograd.Function): diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py index 239d12727..fffe5cf39 100644 --- a/apex/fused_dense/fused_dense.py +++ b/apex/fused_dense/fused_dense.py @@ -1,6 +1,8 @@ +import math + import torch from torch import nn -import fused_dense_cuda +from apex._extensions import fused_dense_cuda from apex._autocast_utils import _cast_if_autocast_enabled @@ -60,6 +62,8 @@ def backward(ctx, grad_output): def 
_fused_dense(input, weight, bias): args = _cast_if_autocast_enabled(input, weight, bias) with torch.amp.autocast("cuda", enabled=False): + if args[0].dtype == torch.bfloat16: + return torch.matmul(args[0], args[1].t()) + args[2] return FusedDenseFunc.apply(*args) @@ -86,6 +90,14 @@ def __init__(self, in_features, out_features, bias=True): else: # assert False, "no-bias option not added yet" self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) def forward(self, input): if self.bias is not None: @@ -105,6 +117,15 @@ def __init__(self, in_features, intermediate_features, out_features, bias=True): self.bias1 = nn.Parameter(torch.empty(intermediate_features)) self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features)) self.bias2 = nn.Parameter(torch.empty(out_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + for weight, bias in ((self.weight1, self.bias1), (self.weight2, self.bias2)): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(bias, -bound, bound) def forward(self, input): return _fused_dense_gelu_dense(input, self.weight1, self.bias1, self.weight2, self.bias2) diff --git a/apex/mlp/mlp.py b/apex/mlp/mlp.py index 4297b0a73..834f3b70b 100644 --- a/apex/mlp/mlp.py +++ b/apex/mlp/mlp.py @@ -5,7 +5,7 @@ from torch import nn from apex._autocast_utils import _cast_if_autocast_enabled -import mlp_cuda +from apex._extensions import mlp_cuda class MlpFunction(torch.autograd.Function): diff --git a/apex/multi_tensor_apply/multi_tensor_apply.py 
b/apex/multi_tensor_apply/multi_tensor_apply.py index ba2c21ada..4eeb3f6ba 100644 --- a/apex/multi_tensor_apply/multi_tensor_apply.py +++ b/apex/multi_tensor_apply/multi_tensor_apply.py @@ -4,7 +4,7 @@ class MultiTensorApply(object): def __init__(self, chunk_size): try: - import amp_C + from apex._extensions import amp_C MultiTensorApply.available = True self.chunk_size = chunk_size diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index a0f3833bc..5a3d7bc83 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -40,7 +40,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -87,7 +87,7 @@ def fused_layer_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") input_ = input.contiguous() weight_ = weight.contiguous() @@ -204,7 +204,7 @@ class FusedRMSNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -247,7 +247,7 @@ def 
fused_rms_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") input_ = input.contiguous() weight_ = weight.contiguous() @@ -358,7 +358,7 @@ class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -380,7 +380,7 @@ class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -401,7 +401,7 @@ class FusedLayerNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -439,7 +439,7 @@ def fused_layer_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - 
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, normalized_shape, eps) @@ -535,7 +535,7 @@ class FusedRMSNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -573,7 +573,7 @@ def fused_rms_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, normalized_shape, eps) @@ -791,7 +791,7 @@ def __init__( super().__init__() global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) @@ -908,7 +908,7 @@ def __init__( super().__init__() global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) diff --git a/apex/optimizers/fused_adagrad.py b/apex/optimizers/fused_adagrad.py index d9dbcd7cc..43c80d0e4 100644 --- 
a/apex/optimizers/fused_adagrad.py +++ b/apex/optimizers/fused_adagrad.py @@ -56,7 +56,7 @@ def __init__( self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/apex/optimizers/fused_adam.py b/apex/optimizers/fused_adam.py index 45d7e017e..e7b42c13f 100644 --- a/apex/optimizers/fused_adam.py +++ b/apex/optimizers/fused_adam.py @@ -125,7 +125,7 @@ def __init__( self._step_supports_amp_scaling = True if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") diff --git a/apex/optimizers/fused_lamb.py b/apex/optimizers/fused_lamb.py index a5630a03c..ac6d79e85 100644 --- a/apex/optimizers/fused_lamb.py +++ b/apex/optimizers/fused_lamb.py @@ -88,7 +88,7 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm # Skip buffer diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index 2ecdddfd2..4a759ffaa 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -50,7 +50,7 @@ def __init__( self.param_groups[idx][item] = group[item].to(device=device) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm_mp # Skip buffer diff --git a/apex/optimizers/fused_novograd.py b/apex/optimizers/fused_novograd.py index b72e2e3e6..c044fc59d 100644 --- a/apex/optimizers/fused_novograd.py +++ b/apex/optimizers/fused_novograd.py @@ -93,7 +93,7 @@ def __init__( ) super(FusedNovoGrad, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions 
import amp_C # Skip buffer # Creating the overflow buffer on the same device as the params tensors. diff --git a/apex/optimizers/fused_sgd.py b/apex/optimizers/fused_sgd.py index f4cec9f57..429a6b5b2 100644 --- a/apex/optimizers/fused_sgd.py +++ b/apex/optimizers/fused_sgd.py @@ -111,7 +111,7 @@ def __init__( self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp index 585e7872a..694aaf760 100644 --- a/csrc/amp_C_frontend.cpp +++ b/csrc/amp_C_frontend.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, float scale); @@ -80,44 +80,214 @@ at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor gro const double backoff_factor, const int64_t growth_interval, const int hysteresis); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_scale", &multi_tensor_scale_cuda, "Fused overflow check + scale for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, "Fused SGD optimizer for list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda, "out = a*x + b*y for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, "Computes L2 norm for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda, "Computes L2 norm for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda, - "Computes L2 norm for a list of contiguous tensors and does scaling", py::call_guard()); - m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda, - "Computes L2 norm for a list of contiguous tensors after unscaling 
(unscaling is only performed for L2 norm " - "computation, and tensors are not updated)", - py::call_guard()); - m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, "Computes update part of LAMB optimizer", - py::call_guard()); - m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, - "Completes application of gradient to parameters for LAMB optimizer", py::call_guard()); - m.def("multi_tensor_adam", &multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling", - py::call_guard()); - m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and " - "FP32 master weights", - py::call_guard()); - m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); - m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); - m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda, "Updates scale while accounting for hysteresis", - py::call_guard()); +namespace { +void apex_multi_tensor_scale(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double scale) { + multi_tensor_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(scale)); +} + +void apex_multi_tensor_sgd(int64_t 
chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double wd, double momentum, + double dampening, double lr, bool nesterov, bool first_run, bool wd_after_momentum, + double scale) { + multi_tensor_sgd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(wd), + static_cast(momentum), static_cast(dampening), static_cast(lr), nesterov, + first_run, wd_after_momentum, static_cast(scale)); +} + +void apex_multi_tensor_axpby(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double a, double b, + int64_t arg_to_check) { + multi_tensor_axpby_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(a), + static_cast(b), static_cast(arg_to_check)); +} + +std::tuple apex_multi_tensor_l2norm( + int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::optional per_tensor_python) { + return multi_tensor_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_python); +} + +std::tuple apex_multi_tensor_l2norm_mp( + int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + at::optional per_tensor_python) { + return multi_tensor_l2norm_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), + per_tensor_python); +} + +std::tuple apex_multi_tensor_l2norm_scale( + int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, double scale, + at::optional per_tensor_python) { + return multi_tensor_l2norm_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), + static_cast(scale), per_tensor_python); +} + +std::tuple apex_multi_tensor_unscale_l2norm( + int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, + at::optional per_tensor_python) { + return multi_tensor_unscale_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), inv_scale, + per_tensor_python); +} + +void apex_multi_tensor_lamb_stage1_cuda(int64_t chunk_size, at::Tensor noop_flag, + std::vector> 
tensor_lists, + at::Tensor per_tensor_decay, int64_t step, double beta1, double beta2, + double epsilon, at::Tensor global_grad_norm, double max_global_grad_norm) { + multi_tensor_lamb_stage1_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_decay, + static_cast(step), static_cast(beta1), static_cast(beta2), + static_cast(epsilon), global_grad_norm, static_cast(max_global_grad_norm)); +} + +void apex_multi_tensor_lamb_stage2_cuda(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, double lr, + double weight_decay, at::optional use_nvlamb_python) { + multi_tensor_lamb_stage2_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), + per_tensor_param_norm, per_tensor_update_norm, static_cast(lr), + static_cast(weight_decay), use_nvlamb_python); +} + +void apex_multi_tensor_adam(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double lr, double beta1, double beta2, + double epsilon, int64_t step, int64_t mode, int64_t bias_correction, double weight_decay) { + multi_tensor_adam_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay)); +} + +void apex_multi_tensor_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, double beta1, + double beta2, double epsilon, at::Tensor step, int64_t mode, + int64_t bias_correction, double weight_decay, at::Tensor inv_scale) { + multi_tensor_adam_capturable_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + step, static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), inv_scale); +} + +void apex_multi_tensor_adam_capturable_master(int64_t chunk_size, 
at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, + double beta1, double beta2, double epsilon, at::Tensor step, + int64_t mode, int64_t bias_correction, double weight_decay, + at::Tensor inv_scale) { + multi_tensor_adam_capturable_master_cuda( + static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, static_cast(beta1), + static_cast(beta2), static_cast(epsilon), step, static_cast(mode), + static_cast(bias_correction), static_cast(weight_decay), inv_scale); +} + +void apex_multi_tensor_adagrad(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double lr, double epsilon, + int64_t mode, double weight_decay) { + multi_tensor_adagrad_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(epsilon), static_cast(mode), static_cast(weight_decay)); +} + +void apex_multi_tensor_novograd(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_norms, double lr, + double beta1, double beta2, double epsilon, int64_t step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, int64_t norm_type) { + multi_tensor_novograd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), grad_norms, + static_cast(lr), static_cast(beta1), static_cast(beta2), + static_cast(epsilon), static_cast(step), static_cast(bias_correction), + static_cast(weight_decay), static_cast(grad_averaging), + static_cast(mode), static_cast(norm_type)); +} + +void apex_multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double lr, double beta1, double beta2, + double epsilon, int64_t step, int64_t bias_correction, double weight_decay, + int64_t grad_averaging, int64_t mode, at::Tensor global_grad_norm, double max_grad_norm, + at::optional use_nvlamb_python) { + multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(beta1), static_cast(beta2), 
static_cast(epsilon), + static_cast(step), static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), static_cast(mode), global_grad_norm, + static_cast(max_grad_norm), use_nvlamb_python); +} + +void apex_multi_tensor_lamb_mp(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, double beta1, + double beta2, double epsilon, at::Tensor step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, + at::Tensor global_grad_norm, at::Tensor max_grad_norm, + at::optional use_nvlamb_python, at::Tensor found_inf, at::Tensor inv_scale) { + multi_tensor_lamb_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, + static_cast(beta1), static_cast(beta2), static_cast(epsilon), step, + static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), static_cast(mode), global_grad_norm, max_grad_norm, + use_nvlamb_python, found_inf, inv_scale); +} + +at::Tensor apex_update_scale_hysteresis(at::Tensor current_scale, at::Tensor growth_tracker, + at::Tensor hysteresis_tracker, at::Tensor found_inf, double growth_factor, + double backoff_factor, int64_t growth_interval, int64_t hysteresis) { + return update_scale_hysteresis_cuda(current_scale, growth_tracker, hysteresis_tracker, found_inf, growth_factor, + backoff_factor, growth_interval, static_cast(hysteresis)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("amp_multi_tensor_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale) -> ()"); + m.def("amp_multi_tensor_sgd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float wd, float momentum, " + "float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) -> ()"); + m.def("amp_multi_tensor_axpby(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float a, float b, " + "int arg_to_check) -> ()"); + m.def("amp_multi_tensor_l2norm(int chunk_size, Tensor 
noop_flag, Tensor[][] tensor_lists, bool? per_tensor_python) " + "-> (Tensor, Tensor)"); + m.def("amp_multi_tensor_l2norm_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def("amp_multi_tensor_l2norm_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def("amp_multi_tensor_unscale_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor inv_scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def("amp_multi_tensor_lamb_stage1_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_decay, int step, float beta1, float beta2, float epsilon, Tensor global_grad_norm, " + "float max_global_grad_norm) -> ()"); + m.def("amp_multi_tensor_lamb_stage2_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, float lr, float weight_decay, " + "bool? 
use_nvlamb_python) -> ()"); + m.def("amp_multi_tensor_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def("amp_multi_tensor_adam_capturable(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def("amp_multi_tensor_adam_capturable_master(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def("amp_multi_tensor_adagrad(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float epsilon, " + "int mode, float weight_decay) -> ()"); + m.def("amp_multi_tensor_novograd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor grad_norms, " + "float lr, float beta1, float beta2, float epsilon, int step, int bias_correction, float weight_decay, " + "int grad_averaging, int mode, int norm_type) -> ()"); + m.def("amp_multi_tensor_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, float max_grad_norm, bool? use_nvlamb_python) -> ()"); + m.def("amp_multi_tensor_lamb_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, float beta1, " + "float beta2, float epsilon, Tensor step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, Tensor max_grad_norm, bool? 
use_nvlamb_python, Tensor found_inf, " + "Tensor inv_scale) -> ()"); + m.def("amp_update_scale_hysteresis(Tensor current_scale, Tensor growth_tracker, Tensor hysteresis_tracker, " + "Tensor found_inf, float growth_factor, float backoff_factor, int growth_interval, int hysteresis) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("amp_multi_tensor_scale", &apex_multi_tensor_scale); + m.impl("amp_multi_tensor_sgd", &apex_multi_tensor_sgd); + m.impl("amp_multi_tensor_axpby", &apex_multi_tensor_axpby); + m.impl("amp_multi_tensor_l2norm", &apex_multi_tensor_l2norm); + m.impl("amp_multi_tensor_l2norm_mp", &apex_multi_tensor_l2norm_mp); + m.impl("amp_multi_tensor_l2norm_scale", &apex_multi_tensor_l2norm_scale); + m.impl("amp_multi_tensor_unscale_l2norm", &apex_multi_tensor_unscale_l2norm); + m.impl("amp_multi_tensor_lamb_stage1_cuda", &apex_multi_tensor_lamb_stage1_cuda); + m.impl("amp_multi_tensor_lamb_stage2_cuda", &apex_multi_tensor_lamb_stage2_cuda); + m.impl("amp_multi_tensor_adam", &apex_multi_tensor_adam); + m.impl("amp_multi_tensor_adam_capturable", &apex_multi_tensor_adam_capturable); + m.impl("amp_multi_tensor_adam_capturable_master", &apex_multi_tensor_adam_capturable_master); + m.impl("amp_multi_tensor_adagrad", &apex_multi_tensor_adagrad); + m.impl("amp_multi_tensor_novograd", &apex_multi_tensor_novograd); + m.impl("amp_multi_tensor_lamb", &apex_multi_tensor_lamb); + m.impl("amp_multi_tensor_lamb_mp", &apex_multi_tensor_lamb_mp); + m.impl("amp_update_scale_hysteresis", &apex_update_scale_hysteresis); } diff --git a/csrc/fused_dense.cpp b/csrc/fused_dense.cpp index 5b09f1e7c..eb04868d4 100644 --- a/csrc/fused_dense.cpp +++ b/csrc/fused_dense.cpp @@ -1,6 +1,7 @@ #include -#include -#include +#include +#include +#include #include @@ -40,7 +41,6 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), 
"linear_bias_forward", [&] { scalar_t* w_ptr = weight.data_ptr(); - scalar_t* b_ptr = bias.data_ptr(); [[maybe_unused]] auto result = linear_bias_forward_cuda(input, w_ptr, bias, in_features, batch_size, out_features, out, // out.data_ptr(), @@ -48,7 +48,7 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b (void*)(lt_workspace.data_ptr())); }); - return {out}; + return out; } std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) { @@ -74,7 +74,6 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] { scalar_t* w_ptr = weight.data_ptr(); - scalar_t* d_b_ptr = d_bias.data_ptr(); [[maybe_unused]] auto result = linear_bias_backward_cuda( input.data_ptr(), w_ptr, d_output.data_ptr(), in_features, batch_size, out_features, d_weight.data_ptr(), d_bias.data_ptr(), d_input.data_ptr(), @@ -142,8 +141,6 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] { - // scalar_t* w_ptr = weight.data_ptr(); - // scalar_t* d_b_ptr = d_bias.data_ptr(); [[maybe_unused]] auto result = linear_gelu_linear_backward_cuda( input.data_ptr(), gelu_in.data_ptr(), output1.data_ptr(), weight1.data_ptr(), weight2.data_ptr(), d_output1.data_ptr(), @@ -157,12 +154,18 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward", py::call_guard()); - m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward", - py::call_guard()); - m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward", - 
py::call_guard()); - m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_dense_linear_bias_forward(Tensor input, Tensor weight, Tensor bias) -> Tensor"); + m.def("fused_dense_linear_bias_backward(Tensor input, Tensor weight, Tensor d_output) -> Tensor[]"); + m.def("fused_dense_linear_gelu_linear_forward(Tensor input, Tensor weight1, Tensor bias1, Tensor weight2, " + "Tensor bias2) -> Tensor[]"); + m.def("fused_dense_linear_gelu_linear_backward(Tensor input, Tensor gelu_in, Tensor output1, Tensor weight1, " + "Tensor weight2, Tensor d_output2) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_dense_linear_bias_forward", &linear_bias_forward); + m.impl("fused_dense_linear_bias_backward", &linear_bias_backward); + m.impl("fused_dense_linear_gelu_linear_forward", &linear_gelu_linear_forward); + m.impl("fused_dense_linear_gelu_linear_backward", &linear_gelu_linear_backward); } diff --git a/csrc/fused_dense_cuda.cu b/csrc/fused_dense_cuda.cu index 9cab0dfa8..fd9af6b85 100644 --- a/csrc/fused_dense_cuda.cu +++ b/csrc/fused_dense_cuda.cu @@ -4,7 +4,6 @@ #include #include #include -#include /* Includes, cuda */ #include diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index 4e5abae44..cbb50ba22 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -1,7 +1,9 @@ -#include +#include +#include #include #include +#include #include namespace { @@ -249,24 +251,84 @@ std::vector rms_norm_gradient_affine(at::Tensor& dout, at::Tensor& i return {grad_input, grad_gamma}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)", py::call_guard()); - m.def("forward", &layer_norm, "LayerNorm forward (CUDA)", py::call_guard()); - m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)", - py::call_guard()); - m.def("backward", 
&layer_norm_gradient, "LayerNorm backward (CUDA)", py::call_guard()); - - m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, - "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", - py::call_guard()); - - m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)", py::call_guard()); - m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)", py::call_guard()); - m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)", - py::call_guard()); - m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)", py::call_guard()); - - m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, - "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", - py::call_guard()); +namespace { +at::Tensor apex_layer_norm_gradient(const at::Tensor& dout, const std::optional& mean, + const at::Tensor& invvar, const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + return layer_norm_gradient(dout_, mean, invvar_, input_or_output_, normalized_shape, epsilon, memory_efficient); +} + +std::vector apex_layer_norm_gradient_affine(const at::Tensor& dout, + const std::optional& mean, + const at::Tensor& invvar, + const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, const at::Tensor& gamma, + const at::Tensor& beta, double epsilon, + bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + at::Tensor gamma_ = gamma; + at::Tensor beta_ = beta; + return layer_norm_gradient_affine(dout_, mean, invvar_, input_or_output_, normalized_shape, gamma_, beta_, epsilon, + memory_efficient); +} + +at::Tensor apex_rms_norm_gradient(const at::Tensor& dout, const at::Tensor& invvar, + const at::Tensor& input_or_output, 
at::IntArrayRef normalized_shape, double epsilon, + bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + return rms_norm_gradient(dout_, invvar_, input_or_output_, normalized_shape, epsilon, memory_efficient); +} + +std::vector apex_rms_norm_gradient_affine(const at::Tensor& dout, const at::Tensor& invvar, + const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, const at::Tensor& gamma, + double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + at::Tensor gamma_ = gamma; + return rms_norm_gradient_affine(dout_, invvar_, input_or_output_, normalized_shape, gamma_, epsilon, + memory_efficient); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_layer_norm_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); + m.def("fused_layer_norm_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); + m.def("fused_layer_norm_backward_affine(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, Tensor beta, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def("fused_layer_norm_backward(Tensor dout, Tensor? 
mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, float epsilon, bool memory_efficient) -> Tensor"); + m.def("fused_layer_norm_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); + m.def("fused_layer_norm_rms_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, float epsilon) " + "-> Tensor[]"); + m.def("fused_layer_norm_rms_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); + m.def("fused_layer_norm_rms_backward_affine(Tensor dout, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def("fused_layer_norm_rms_backward(Tensor dout, Tensor invvar, Tensor input_or_output, int[] normalized_shape, " + "float epsilon, bool memory_efficient) -> Tensor"); + m.def("fused_layer_norm_rms_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, " + "float epsilon) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_layer_norm_forward_affine", &layer_norm_affine); + m.impl("fused_layer_norm_forward", &layer_norm); + m.impl("fused_layer_norm_backward_affine", &apex_layer_norm_gradient_affine); + m.impl("fused_layer_norm_backward", &apex_layer_norm_gradient); + m.impl("fused_layer_norm_forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes); + m.impl("fused_layer_norm_rms_forward_affine", &rms_norm_affine); + m.impl("fused_layer_norm_rms_forward", &rms_norm); + m.impl("fused_layer_norm_rms_backward_affine", &apex_rms_norm_gradient_affine); + m.impl("fused_layer_norm_rms_backward", &apex_rms_norm_gradient); + m.impl("fused_layer_norm_rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes); } diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp index c0b2ded82..fb6762e6d 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.cpp 
+++ b/csrc/megatron/fused_rotary_positional_embedding.cpp @@ -14,32 +14,32 @@ * limitations under the License. */ -#include +#include +#include namespace fused_rope { -torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output); +at::Tensor fwd_cuda(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output); -torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output); +at::Tensor bwd_cuda(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output); -torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output); +at::Tensor fwd_cached_cuda(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output); -torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output); +at::Tensor bwd_cached_cuda(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output); -torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs); +at::Tensor fwd_thd_cuda(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs); -torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, - const torch::Tensor& freqs); +at::Tensor bwd_thd_cuda(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs); -torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w); +at::Tensor fwd_2d_cuda(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w); -torch::Tensor bwd_2d_cuda(const torch::Tensor& 
output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w); +at::Tensor bwd_2d_cuda(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w); -torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { +at::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(input.size(0) == freqs.size(0), "expected input and freqs tensor have the same sequence length"); @@ -53,7 +53,7 @@ torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool t return fwd_cuda(input, freqs, transpose_output); } -torch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { +at::Tensor bwd(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(output_grads.size(0) == freqs.size(0), @@ -68,8 +68,8 @@ torch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor& freqs, co return bwd_cuda(output_grads, freqs, transpose_output); } -torch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, - const bool transpose_output) { +at::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); @@ -86,8 +86,8 @@ torch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const a return fwd_cached_cuda(input, cos, sin, transpose_output); } -torch::Tensor bwd_cached(const 
torch::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, - const bool transpose_output) { +at::Tensor bwd_cached(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); @@ -106,7 +106,7 @@ torch::Tensor bwd_cached(const torch::Tensor& output_grads, const at::Tensor& co return bwd_cached_cuda(output_grads, cos, sin, transpose_output); } -torch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor fwd_thd(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -120,7 +120,7 @@ torch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor& cu_seqlen return fwd_thd_cuda(input, cu_seqlens, freqs); } -torch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor bwd_thd(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -134,8 +134,8 @@ torch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::Tensor& cu return bwd_thd_cuda(output_grads, cu_seqlens, freqs); } -torch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor fwd_2d(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { TORCH_CHECK(input.dim() 
== 5, "expected input to be 5D tensor"); TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); @@ -151,8 +151,8 @@ torch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& cos_h, con return fwd_2d_cuda(input, cos_h, sin_h, cos_w, sin_w); } -torch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor bwd_2d(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { TORCH_CHECK(output_grads.dim() == 5, "expected output_grads to be 5D tensor"); TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); @@ -172,24 +172,24 @@ torch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::Tensor& cos } // end namespace fused_rope -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fused_rope::fwd, "Fused Rotary Positional Embedding -- Forward.", - py::call_guard()); - m.def("backward", &fused_rope::bwd, "Fused Rotary Positional Embedding -- Backward.", - py::call_guard()); - // cache sin/cos - m.def("forward_cached", &fused_rope::fwd_cached, "Fused Rotary Positional Embedding Cached -- Forward.", - py::call_guard()); - m.def("backward_cached", &fused_rope::bwd_cached, "Fused Rotary Positional Embedding Cached -- Backward.", - py::call_guard()); - // thd - m.def("forward_thd", &fused_rope::fwd_thd, "Fused Rotary Positional Embedding for thd layout -- Forward.", - py::call_guard()); - m.def("backward_thd", &fused_rope::bwd_thd, "Fused Rotary Positional Embedding for thd layout -- Backward.", - py::call_guard()); - // 2d - m.def("forward_2d", &fused_rope::fwd_2d, "2D Fused Rotary Positional Embedding -- Forward.", - py::call_guard()); - m.def("backward_2d", &fused_rope::bwd_2d, "2D Fused Rotary 
Positional Embedding -- Backward.", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_rope_forward(Tensor input, Tensor freqs, bool transpose_output) -> Tensor"); + m.def("fused_rope_backward(Tensor output_grads, Tensor freqs, bool transpose_output) -> Tensor"); + m.def("fused_rope_forward_cached(Tensor input, Tensor cos, Tensor sin, bool transpose_output) -> Tensor"); + m.def("fused_rope_backward_cached(Tensor output_grads, Tensor cos, Tensor sin, bool transpose_output) -> Tensor"); + m.def("fused_rope_forward_thd(Tensor input, Tensor cu_seqlens, Tensor freqs) -> Tensor"); + m.def("fused_rope_backward_thd(Tensor output_grads, Tensor cu_seqlens, Tensor freqs) -> Tensor"); + m.def("fused_rope_forward_2d(Tensor input, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); + m.def("fused_rope_backward_2d(Tensor output_grads, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_rope_forward", &fused_rope::fwd); + m.impl("fused_rope_backward", &fused_rope::bwd); + m.impl("fused_rope_forward_cached", &fused_rope::fwd_cached); + m.impl("fused_rope_backward_cached", &fused_rope::bwd_cached); + m.impl("fused_rope_forward_thd", &fused_rope::fwd_thd); + m.impl("fused_rope_backward_thd", &fused_rope::bwd_thd); + m.impl("fused_rope_forward_2d", &fused_rope::fwd_2d); + m.impl("fused_rope_backward_2d", &fused_rope::bwd_2d); } diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h index 32a6042da..7af1864a9 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.h +++ b/csrc/megatron/fused_rotary_positional_embedding.h @@ -18,9 +18,9 @@ #include #include +#include #include #include -#include namespace { diff --git a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu index 82fef191d..7d5ec9299 100644 --- 
a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu +++ b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu @@ -21,7 +21,7 @@ namespace fused_rope { -torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output) { +at::Tensor fwd_cuda(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { // input sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -42,11 +42,11 @@ torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, c // output auto act_options = input.options().requires_grad(false); - torch::Tensor output; + at::Tensor output; if (transpose_output) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + output = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - output = torch::empty({s, b, h, d}, act_options); + output = at::empty({s, b, h, d}, act_options); } // output strides const int o_stride_s = output.stride(0); @@ -62,7 +62,7 @@ torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, c return output; } -torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output) { +at::Tensor bwd_cuda(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -82,11 +82,11 @@ torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& f const int d2 = freqs.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads; + at::Tensor input_grads; if (transpose_output) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + input_grads = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - input_grads = torch::empty({s, b, h, d}, act_options); + input_grads = at::empty({s, b, h, d}, act_options); } const int o_stride_s = input_grads.stride(0); const int 
o_stride_b = input_grads.stride(1); @@ -156,8 +156,8 @@ torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& f TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), "' with '", toString(TYPE2), "'"); \ } -torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output) { +at::Tensor fwd_cached_cuda(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { // input sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -178,11 +178,11 @@ torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& c // output auto act_options = input.options().requires_grad(false); - torch::Tensor output; + at::Tensor output; if (transpose_output) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + output = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - output = torch::empty({s, b, h, d}, act_options); + output = at::empty({s, b, h, d}, act_options); } // output strides const int o_stride_s = output.stride(0); @@ -198,8 +198,8 @@ torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& c return output; } -torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output) { +at::Tensor bwd_cached_cuda(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -219,11 +219,11 @@ torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Te const int d2 = cos.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads; + at::Tensor input_grads; if (transpose_output) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + input_grads = at::empty({b, s, h, d}, 
act_options).transpose(0, 1); } else { - input_grads = torch::empty({s, b, h, d}, act_options); + input_grads = at::empty({s, b, h, d}, act_options); } const int o_stride_s = input_grads.stride(0); const int o_stride_b = input_grads.stride(1); @@ -238,7 +238,7 @@ torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Te return input_grads; } -torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor fwd_thd_cuda(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { // input sizes: (t, h, d) // t: cumulative sum of sequence lengths // h: head num @@ -258,7 +258,7 @@ torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_s // output auto act_options = input.options().requires_grad(false); - auto output = torch::empty({t, h, d}, act_options); + auto output = at::empty({t, h, d}, act_options); // output strides const int o_stride_t = output.stride(0); const int o_stride_h = output.stride(1); @@ -272,8 +272,7 @@ torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_s return output; } -torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, - const torch::Tensor& freqs) { +at::Tensor bwd_thd_cuda(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { // output_grads sizes: (t, h, d) // t: cumulative sum of sequence lengths // h: head num @@ -292,7 +291,7 @@ torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tenso const int d2 = freqs.size(3); auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({t, h, d}, act_options); + auto input_grads = at::empty({t, h, d}, act_options); const int o_stride_t = input_grads.stride(0); const int o_stride_h = input_grads.stride(1); const int o_stride_d = input_grads.stride(2); @@ -305,8 +304,8 @@ torch::Tensor bwd_thd_cuda(const 
torch::Tensor& output_grads, const torch::Tenso return input_grads; } -torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor fwd_2d_cuda(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { // input sizes: (b, ih, iw, h, d) // b: batch size // ih: image height @@ -327,7 +326,7 @@ torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h // output auto act_options = input.options().requires_grad(false); - auto output = torch::empty({b, ih * iw, h, d}, act_options); + auto output = at::empty({b, ih * iw, h, d}, act_options); // output strides const int o_stride_b = output.stride(0); const int o_stride_s = output.stride(1); @@ -343,8 +342,8 @@ torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h return output; } -torch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor bwd_2d_cuda(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { // output_grads sizes: (b, ih, iw, h, d) // b: batch size // ih: image height @@ -364,7 +363,7 @@ torch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor const int stride_d = output_grads.stride(4); auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({b, ih * iw, h, d}, act_options); + auto input_grads = at::empty({b, ih * iw, h, d}, act_options); const int o_stride_b = input_grads.stride(0); const int o_stride_s = input_grads.stride(1); const int o_stride_h = input_grads.stride(2); diff --git a/csrc/megatron/fused_weight_gradient_dense.cpp b/csrc/megatron/fused_weight_gradient_dense.cpp index 
71c826cfe..633ef068b 100644 --- a/csrc/megatron/fused_weight_gradient_dense.cpp +++ b/csrc/megatron/fused_weight_gradient_dense.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -7,9 +8,26 @@ void wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_ void wgrad_gemm_accum_fp16_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32", - py::call_guard()); - m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16", - py::call_guard()); +void wgrad_gemm_accum_fp32_dispatch(const at::Tensor& input, const at::Tensor& d_output, const at::Tensor& d_weight) { + at::Tensor input_arg = input; + at::Tensor d_output_arg = d_output; + at::Tensor d_weight_arg = d_weight; + wgrad_gemm_accum_fp32_cuda_stub(input_arg, d_output_arg, d_weight_arg); +} + +void wgrad_gemm_accum_fp16_dispatch(const at::Tensor& input, const at::Tensor& d_output, const at::Tensor& d_weight) { + at::Tensor input_arg = input; + at::Tensor d_output_arg = d_output; + at::Tensor d_weight_arg = d_weight; + wgrad_gemm_accum_fp16_cuda_stub(input_arg, d_output_arg, d_weight_arg); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_weight_gradient_mlp_wgrad_gemm_accum_fp32(Tensor input, Tensor d_output, Tensor(a!) d_weight) -> ()"); + m.def("fused_weight_gradient_mlp_wgrad_gemm_accum_fp16(Tensor input, Tensor d_output, Tensor(a!) 
d_weight) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_weight_gradient_mlp_wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_dispatch); + m.impl("fused_weight_gradient_mlp_wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_dispatch); } diff --git a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu index 09af5c0a9..87810500d 100644 --- a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/csrc/megatron/fused_weight_gradient_dense_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_cuda.cu index 6adc8c7e2..e05f8e937 100644 --- a/csrc/megatron/fused_weight_gradient_dense_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_cuda.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/csrc/megatron/generic_scaled_masked_softmax.cpp b/csrc/megatron/generic_scaled_masked_softmax.cpp index cb7a78809..a0cdf38f6 100644 --- a/csrc/megatron/generic_scaled_masked_softmax.cpp +++ b/csrc/megatron/generic_scaled_masked_softmax.cpp @@ -15,7 +15,8 @@ */ #include -#include +#include +#include #include @@ -23,11 +24,11 @@ namespace multihead_attn { namespace fused_softmax { namespace generic_scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { 
TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); @@ -36,7 +37,7 @@ torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float s return fwd_cuda(input, mask, scale_factor); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -50,14 +51,27 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma return bwd_cuda(output_grads, softmax_results, scale_factor); } +at::Tensor fwd_dispatch(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { + return fwd(input, mask, static_cast(scale_factor)); +} + +at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + return bwd(output_grads, softmax_results, static_cast(scale_factor)); +} + } // end namespace generic_scaled_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("generic_scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def("generic_scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); +} - m.def("backward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); 
+TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("generic_scaled_masked_softmax_forward", + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd_dispatch); + m.impl("generic_scaled_masked_softmax_backward", + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd_dispatch); } diff --git a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu index cc1e55297..0ca0837b0 100644 --- a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "generic_scaled_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace generic_scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); @@ -43,7 +44,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -58,7 +59,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& 
softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -69,7 +70,7 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& const int key_seq_len = output_grads.size(3); auto act_options = output_grads.options(); - torch::Tensor input_grad = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor input_grad = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); void* output_grads_ptr = static_cast(output_grads.data_ptr()); diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp index 82963817f..f358f8609 100644 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ b/csrc/megatron/scaled_masked_softmax.cpp @@ -15,7 +15,8 @@ */ #include -#include +#include +#include #include @@ -23,13 +24,13 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads); -torch::Tensor fwd(torch::Tensor& input, torch::Tensor& mask, float scale_factor) { +at::Tensor fwd(at::Tensor& input, at::Tensor& mask, float scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); @@ -40,7 +41,7 @@ torch::Tensor fwd(torch::Tensor& input, torch::Tensor& mask, float scale_factor) return fwd_cuda(input, mask, 
scale_factor); } -torch::Tensor bwd(torch::Tensor& output_grads, torch::Tensor& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor& output_grads, at::Tensor& softmax_results, float scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -60,17 +61,40 @@ int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int att return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); } +at::Tensor fwd_dispatch(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { + at::Tensor input_arg = input; + at::Tensor mask_arg = mask; + return fwd(input_arg, mask_arg, static_cast(scale_factor)); +} + +at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + at::Tensor output_grads_arg = output_grads; + at::Tensor softmax_results_arg = softmax_results; + return bwd(output_grads_arg, softmax_results_arg, static_cast(scale_factor)); +} + +int64_t get_batch_per_block_dispatch(int64_t query_seq_len, int64_t key_seq_len, int64_t batches, int64_t attn_heads) { + return get_batch_per_block(static_cast(query_seq_len), static_cast(key_seq_len), static_cast(batches), + static_cast(attn_heads)); +} + } // end namespace scaled_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def("scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, float scale_factor) -> Tensor"); + m.def("scaled_masked_softmax_get_batch_per_block(int query_seq_len, int key_seq_len, int batches, " + "int attn_heads) -> 
int"); +} - m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_masked_softmax_forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd_dispatch); + m.impl("scaled_masked_softmax_backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd_dispatch); +} - m.def("get_batch_per_block", &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size.", py::call_guard()); +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("scaled_masked_softmax_get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block_dispatch); } diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu index 879697819..fecc5cd37 100644 --- a/csrc/megatron/scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -33,7 +34,7 @@ int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, in return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); } -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); @@ -49,7 +50,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, 
act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -64,7 +65,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -75,7 +76,7 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& const int key_seq_len = output_grads.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor input_grads = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); void* input_grads_ptr = static_cast(input_grads.data_ptr()); void* output_grads_ptr = static_cast(output_grads.data_ptr()); diff --git a/csrc/megatron/scaled_softmax.cpp b/csrc/megatron/scaled_softmax.cpp index 7e4d44c1b..8538247b4 100644 --- a/csrc/megatron/scaled_softmax.cpp +++ b/csrc/megatron/scaled_softmax.cpp @@ -15,7 +15,8 @@ */ #include -#include +#include +#include #include @@ -23,11 +24,11 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, float 
scale_factor) { +at::Tensor fwd(at::Tensor const& input, float scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); @@ -35,7 +36,7 @@ torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { return fwd_cuda(input, scale_factor); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -49,13 +50,24 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma return bwd_cuda(output_grads, softmax_results, scale_factor); } +at::Tensor fwd_dispatch(at::Tensor const& input, double scale_factor) { + return fwd(input, static_cast(scale_factor)); +} + +at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + return bwd(output_grads, softmax_results, static_cast(scale_factor)); +} + } // end namespace scaled_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_softmax::fwd, - "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); - m.def("backward", &multihead_attn::fused_softmax::scaled_softmax::bwd, - "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def("scaled_softmax_backward(Tensor output_grads, Tensor softmax_results, float scale_factor) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_softmax_forward", 
&multihead_attn::fused_softmax::scaled_softmax::fwd_dispatch); + m.impl("scaled_softmax_backward", &multihead_attn::fused_softmax::scaled_softmax::bwd_dispatch); } diff --git a/csrc/megatron/scaled_softmax_cuda.cu b/csrc/megatron/scaled_softmax_cuda.cu index 4ec6e4bdf..93e08de31 100644 --- a/csrc/megatron/scaled_softmax_cuda.cu +++ b/csrc/megatron/scaled_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int attn_heads = input.size(1); @@ -40,7 +41,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -54,7 +55,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp index 726ea7dbe..2856197ca 100644 --- 
a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp @@ -15,7 +15,8 @@ */ #include -#include +#include +#include #include @@ -23,11 +24,11 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, float scale_factor) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); @@ -35,7 +36,7 @@ torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { return fwd_cuda(input, scale_factor); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); @@ -49,13 +50,27 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma return bwd_cuda(output_grads, softmax_results, scale_factor); } +at::Tensor fwd_dispatch(at::Tensor const& input, double scale_factor) { + return fwd(input, static_cast(scale_factor)); +} + +at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + return bwd(output_grads, softmax_results, static_cast(scale_factor)); +} + } // end namespace 
scaled_upper_triang_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); - m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def("scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_upper_triang_masked_softmax_forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd_dispatch); + m.impl("scaled_upper_triang_masked_softmax_backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd_dispatch); } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu index 86813c6a3..251209a87 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_upper_triang_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor) { // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); @@ -37,7 +38,7 @@ 
torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, seq_len, seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -51,7 +52,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp index 5ee75a405..358b97173 100644 --- a/csrc/mlp.cpp +++ b/csrc/mlp.cpp @@ -1,6 +1,7 @@ #include -#include -#include +#include +#include +#include #include @@ -106,7 +107,25 @@ std::vector mlp_backward(int use_bias, int activation, at::Tensor gr return outputs; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &mlp_forward, "MLP forward", py::call_guard()); - m.def("backward", &mlp_backward, "MLP backward", py::call_guard()); +namespace { +std::vector apex_mlp_forward(int64_t use_bias, int64_t activation, std::vector inputs) { + return mlp_forward(static_cast(use_bias), static_cast(activation), std::move(inputs)); +} + +std::vector apex_mlp_backward(int64_t use_bias, int64_t activation, at::Tensor grad_o, + std::vector fprop_outputs, std::vector inputs) { + return mlp_backward(static_cast(use_bias), static_cast(activation), grad_o, std::move(fprop_outputs), + std::move(inputs)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("mlp_forward(int use_bias, int activation, Tensor[] inputs) -> Tensor[]"); + m.def("mlp_backward(int use_bias, int activation, 
Tensor grad_o, Tensor[] fprop_outputs, Tensor[] inputs) " + "-> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("mlp_forward", &apex_mlp_forward); + m.impl("mlp_backward", &apex_mlp_backward); } diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index 4a870da4d..912fa096a 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -4,7 +4,6 @@ #include #include #include -#include /* Includes, cuda */ #include diff --git a/csrc/syncbn.cpp b/csrc/syncbn.cpp index d91a6ad29..b18db1f14 100644 --- a/csrc/syncbn.cpp +++ b/csrc/syncbn.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include @@ -68,22 +68,42 @@ at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output, const at::Ten const at::optional z, const at::Tensor mean, const at::Tensor inv_std, const at::optional weight, const at::optional shift); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance", py::call_guard()); - m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance", - py::call_guard()); - m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward", py::call_guard()); - m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad", - py::call_guard()); - m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad", - py::call_guard()); - m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc", - py::call_guard()); - m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc", - py::call_guard()); - m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc", - py::call_guard()); - m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc", - py::call_guard()); - m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last", py::call_guard()); +namespace { 
+std::vector apex_welford_parallel_CUDA(const at::Tensor mean_feature_nodes, + const at::Tensor var_biased_feature_nodes, const at::Tensor numel, + double eps) { + return welford_parallel_CUDA(mean_feature_nodes, var_biased_feature_nodes, numel, static_cast(eps)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("syncbn_welford_mean_var(Tensor input) -> Tensor[]"); + m.def("syncbn_welford_parallel(Tensor mean_feature_nodes, Tensor var_biased_feature_nodes, Tensor numel, float eps) " + "-> Tensor[]"); + m.def("syncbn_batchnorm_forward(Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, Tensor? shift) -> Tensor"); + m.def("syncbn_reduce_bn(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) -> Tensor[]"); + m.def("syncbn_batchnorm_backward(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def("syncbn_welford_mean_var_c_last(Tensor input) -> Tensor[]"); + m.def("syncbn_batchnorm_forward_c_last(Tensor input, Tensor? z, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor? shift, bool fuse_relu) -> Tensor"); + m.def("syncbn_reduce_bn_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) " + "-> Tensor[]"); + m.def("syncbn_batchnorm_backward_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, " + "Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def("syncbn_relu_bw_c_last(Tensor grad_output, Tensor input, Tensor? z, Tensor mean, Tensor inv_std, " + "Tensor? weight, Tensor? 
shift) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("syncbn_welford_mean_var", &welford_mean_var_CUDA); + m.impl("syncbn_welford_parallel", &apex_welford_parallel_CUDA); + m.impl("syncbn_batchnorm_forward", &batchnorm_forward_CUDA); + m.impl("syncbn_reduce_bn", &reduce_bn_CUDA); + m.impl("syncbn_batchnorm_backward", &batchnorm_backward_CUDA); + m.impl("syncbn_welford_mean_var_c_last", &welford_mean_var_c_last_CUDA); + m.impl("syncbn_batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA); + m.impl("syncbn_reduce_bn_c_last", &reduce_bn_c_last_CUDA); + m.impl("syncbn_batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA); + m.impl("syncbn_relu_bw_c_last", &relu_backward_c_last_CUDA); } diff --git a/setup.py b/setup.py index a0e59cf86..b60c00e20 100644 --- a/setup.py +++ b/setup.py @@ -168,10 +168,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int if has_flag("--cpp_ext", "APEX_CPP_EXT"): if "--cpp_ext" in sys.argv: sys.argv.remove("--cpp_ext") - ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"])) + # apex._extensions.apex_C is a pure Python shim to avoid exposing a C++ ABI + # for dense tensor flattening utilities. 
-_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +bare_metal_version = None +if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if has_flag("--distributed_adam", "APEX_DISTRIBUTED_ADAM"): if "--distributed_adam" in sys.argv: @@ -179,12 +182,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--distributed_adam") ext_modules.append( CUDAExtension( - name="distributed_adam_cuda", + name="_distributed_adam_cuda", sources=[ "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp", "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], @@ -198,12 +202,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--distributed_lamb") ext_modules.append( CUDAExtension( - name="distributed_lamb_cuda", + name="_distributed_lamb_cuda", sources=[ "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp", "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], @@ -219,7 +224,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="amp_C", + name="_amp_C", sources=[ "csrc/amp_C_frontend.cpp", "csrc/multi_tensor_sgd_kernel.cu", @@ -237,6 +242,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "csrc/multi_tensor_lamb_mp.cu", "csrc/update_scale_hysteresis.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": [ @@ -250,8 +256,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="syncbn", + 
name="_syncbn", sources=["csrc/syncbn.cpp", "csrc/welford.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -261,8 +268,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_layer_norm_cuda", + name="_fused_layer_norm_cuda", sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-maxrregcount=50", "-O3", "--use_fast_math"], @@ -272,8 +280,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="mlp_cuda", + name="_mlp_cuda", sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -282,8 +291,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fused_dense_cuda", + name="_fused_dense_cuda", sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -293,11 +303,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_upper_triang_masked_softmax_cuda", + name="_scaled_upper_triang_masked_softmax_cuda", sources=[ "csrc/megatron/scaled_upper_triang_masked_softmax.cpp", "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -314,11 +325,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="generic_scaled_masked_softmax_cuda", + name="_generic_scaled_masked_softmax_cuda", sources=[ "csrc/megatron/generic_scaled_masked_softmax.cpp", "csrc/megatron/generic_scaled_masked_softmax_cuda.cu", ], + 
py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -335,11 +347,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_masked_softmax_cuda", + name="_scaled_masked_softmax_cuda", sources=[ "csrc/megatron/scaled_masked_softmax.cpp", "csrc/megatron/scaled_masked_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -356,11 +369,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_softmax_cuda", + name="_scaled_softmax_cuda", sources=[ "csrc/megatron/scaled_softmax.cpp", "csrc/megatron/scaled_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -377,11 +391,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_rotary_positional_embedding", + name="_fused_rotary_positional_embedding", sources=[ "csrc/megatron/fused_rotary_positional_embedding.cpp", "csrc/megatron/fused_rotary_positional_embedding_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -398,13 +413,14 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_weight_gradient_mlp_cuda", + name="_fused_weight_gradient_mlp_cuda", include_dirs=[os.path.join(this_dir, "csrc")], sources=[ "csrc/megatron/fused_weight_gradient_dense.cpp", "csrc/megatron/fused_weight_gradient_dense_cuda.cu", "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": [ @@ -431,7 +447,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: 
int cc_flag = ["-Xcompiler", "-fPIC", "-shared"] ext_modules.append( CUDAExtension( - name="permutation_search_cuda", + name="_permutation_search_cuda", sources=[ "apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu" ], @@ -446,6 +462,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ], extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"] + cc_flag}, + py_limited_api=True, ) ) @@ -455,7 +472,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--bnp") ext_modules.append( CUDAExtension( - name="bnp", + name="_bnp", sources=[ "apex/contrib/csrc/groupbn/batch_norm.cu", "apex/contrib/csrc/groupbn/ipc.cu", @@ -463,6 +480,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": [], "nvcc": [ @@ -485,7 +503,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int print(f"`--xentropy` setting version of {xentropy_ver}") ext_modules.append( CUDAExtension( - name="xentropy_cuda", + name="_xentropy_cuda", sources=[ "apex/contrib/csrc/xentropy/interface.cpp", "apex/contrib/csrc/xentropy/xentropy_kernel.cu", @@ -495,6 +513,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"] + [f'-DXENTROPY_VER="{xentropy_ver}"'], "nvcc": ["-O3"], }, + py_limited_api=True, ) ) @@ -504,7 +523,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--focal_loss") ext_modules.append( CUDAExtension( - name="focal_loss_cuda", + name="_focal_loss_cuda", sources=[ "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp", "apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu", @@ -514,6 +533,7 @@ def check_cudnn_version_and_warn(global_option: str, 
required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math", "--ftz=false"], }, + py_limited_api=True, ) ) @@ -524,7 +544,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="group_norm_cuda", + name="_group_norm_cuda", sources=[ "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp", ] @@ -538,6 +558,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "--ftz=false", ], }, + py_limited_api=True, ) ) @@ -554,7 +575,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="group_norm_v2_cuda", + name="_group_norm_v2_cuda", sources=[ "apex/contrib/csrc/group_norm_v2/gn.cpp", "apex/contrib/csrc/group_norm_v2/gn_cuda.cu", @@ -574,6 +595,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ] + arch_flags, }, + py_limited_api=True, ) ) @@ -583,7 +605,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--index_mul_2d") ext_modules.append( CUDAExtension( - name="fused_index_mul_2d", + name="_fused_index_mul_2d", sources=[ "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp", "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu", @@ -593,6 +615,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math", "--ftz=false"], }, + py_limited_api=True, ) ) @@ -602,7 +625,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--deprecated_fused_adam") ext_modules.append( CUDAExtension( - name="fused_adam_cuda", + name="_fused_adam_cuda", sources=[ "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp", "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu", @@ -612,6 +635,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int 
"cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], }, + py_limited_api=True, ) ) @@ -621,7 +645,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--deprecated_fused_lamb") ext_modules.append( CUDAExtension( - name="fused_lamb_cuda", + name="_fused_lamb_cuda", sources=[ "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp", "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu", @@ -632,6 +656,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], }, + py_limited_api=True, ) ) @@ -649,12 +674,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fast_layer_norm", + name="_fast_layer_norm", sources=[ "apex/contrib/csrc/layer_norm/ln_api.cpp", "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu", "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"] + generator_flag, "nvcc": [ @@ -701,7 +727,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fmhalib", + name="_fmhalib", sources=[ "apex/contrib/csrc/fmha/fmha_api.cpp", "apex/contrib/csrc/fmha/src/fmha_fill.cu", @@ -732,6 +758,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src"), ], + py_limited_api=True, ) ) @@ -752,7 +779,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fast_multihead_attn", + name="_fast_multihead_attn", sources=[ "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp", "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu", @@ -783,6 +810,7 @@ def check_cudnn_version_and_warn(global_option: str, 
required_cudnn_version: int "apex/contrib/csrc/multihead_attn/cutlass/tools/util/include", ), ], + py_limited_api=True, ) ) @@ -792,11 +820,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--transducer") ext_modules.append( CUDAExtension( - name="transducer_joint_cuda", + name="_transducer_joint_cuda", sources=[ "apex/contrib/csrc/transducer/transducer_joint.cpp", "apex/contrib/csrc/transducer/transducer_joint_kernel.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"] + generator_flag, "nvcc": ["-O3"] + generator_flag, @@ -809,11 +838,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="transducer_loss_cuda", + name="_transducer_loss_cuda", sources=[ "apex/contrib/csrc/transducer/transducer_loss.cpp", "apex/contrib/csrc/transducer/transducer_loss_kernel.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -838,12 +868,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="cudnn_gbn_lib", + name="_cudnn_gbn_lib", sources=[ "apex/contrib/csrc/cudnn_gbn/norm_sample.cpp", "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp", ], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3", "-g"] + generator_flag}, ) ) @@ -854,11 +885,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--peer_memory") ext_modules.append( CUDAExtension( - name="peer_memory_cuda", + name="_peer_memory_cuda", sources=[ "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", "apex/contrib/csrc/peer_memory/peer_memory.cpp", ], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -870,11 +902,13 @@ def check_cudnn_version_and_warn(global_option: 
str, required_cudnn_version: int raise_if_cuda_home_none("--nccl_p2p") ext_modules.append( CUDAExtension( - name="nccl_p2p_cuda", + name="_nccl_p2p_cuda", sources=[ "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp", ], + py_limited_api=True, + libraries=["nccl"], extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -896,9 +930,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fast_bottleneck", + name="_fast_bottleneck", sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -920,9 +955,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fused_conv_bias_relu", + name="_fused_conv_bias_relu", sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) diff --git a/tests/L0/run_optimizers/test_lamb.py b/tests/L0/run_optimizers/test_lamb.py index 3b208e61b..171cf9827 100644 --- a/tests/L0/run_optimizers/test_lamb.py +++ b/tests/L0/run_optimizers/test_lamb.py @@ -39,7 +39,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(RefLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm # Skip buffer diff --git a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py b/tests/distributed/synced_batchnorm/single_gpu_unit_test.py index 18ea55dcc..e8b8180f7 100644 --- 
a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py +++ b/tests/distributed/synced_batchnorm/single_gpu_unit_test.py @@ -4,7 +4,7 @@ if True: print("using setup tools") - import syncbn + from apex._extensions import syncbn else: print("using jit") from torch.utils.cpp_extension import load diff --git a/tests/distributed/synced_batchnorm/test_groups.py b/tests/distributed/synced_batchnorm/test_groups.py index e95aa9984..4be20dbb4 100644 --- a/tests/distributed/synced_batchnorm/test_groups.py +++ b/tests/distributed/synced_batchnorm/test_groups.py @@ -1,7 +1,7 @@ import torch import numpy as np import apex -import syncbn +from apex._extensions import syncbn import os import argparse import torch.optim as optim diff --git a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py b/tests/distributed/synced_batchnorm/two_gpu_unit_test.py index 3c97e9ac6..343d7763b 100644 --- a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py +++ b/tests/distributed/synced_batchnorm/two_gpu_unit_test.py @@ -1,7 +1,7 @@ import torch import numpy as np import apex -import syncbn +from apex._extensions import syncbn import os import argparse import torch.optim as optim From ea797b8ef52eab17655ee13ad4188b920ce31f93 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 28 May 2026 00:34:50 +0900 Subject: [PATCH 2/4] Use dispatcher-native scalar types for transducer ops Register the transducer joint and loss checked entry points directly with the dispatcher instead of routing through _dispatch adapters that immediately narrowed int64_t and double arguments back to int and float. Keep the dispatcher-facing C++ signatures aligned with schema binding semantics, where schema int arrives as int64_t and schema float arrives as double. The only remaining narrowing is at CUDA kernel call sites that still require float values. 
--- .../csrc/transducer/transducer_joint.cpp | 37 +++++-------------- .../transducer/transducer_joint_kernel.cu | 33 +++++++++-------- .../csrc/transducer/transducer_loss.cpp | 36 +++++------------- .../csrc/transducer/transducer_loss_kernel.cu | 8 ++-- 4 files changed, 41 insertions(+), 73 deletions(-) diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 3dd0b7d0a..9eb012b74 100644 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -9,17 +9,17 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, at::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize); + int64_t packedBatch, int64_t opt, bool packOutput, bool relu, + bool dropout, double dropoutProb, int64_t tileSize); std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale); + at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t maxGLen, bool packOutput, double scale); std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, at::Tensor batchOffset, int64_t packedBatch, - int opt, bool packOutput, bool relu, bool dropout, - float dropoutProb, int tileSize) { + int64_t opt, bool packOutput, bool relu, bool dropout, + double dropoutProb, int64_t tileSize) { CHECK_INPUT(f); CHECK_INPUT(g); CHECK_INPUT(fLen); @@ -30,8 +30,8 @@ std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at: } std::vector transducer_joint_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { + at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t maxGLen, bool packOutput, double scale) { for (auto 
t : in) { CHECK_INPUT(t); } @@ -41,23 +41,6 @@ std::vector transducer_joint_backward(std::vector in, at return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); } -std::vector transducer_joint_forward_dispatch(at::Tensor f, at::Tensor g, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, - int64_t packedBatch, int64_t opt, bool packOutput, - bool relu, bool dropout, double dropoutProb, - int64_t tileSize) { - return transducer_joint_forward(f, g, fLen, gLen, batchOffset, packedBatch, static_cast(opt), packOutput, relu, - dropout, static_cast(dropoutProb), static_cast(tileSize)); -} - -std::vector transducer_joint_backward_dispatch(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, - int64_t maxFLen, int64_t maxGLen, bool packOutput, - double scale) { - return transducer_joint_backward(in, fLen, gLen, batchOffset, static_cast(maxFLen), static_cast(maxGLen), - packOutput, static_cast(scale)); -} - TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("transducer_joint_forward(Tensor f, Tensor g, Tensor fLen, Tensor gLen, Tensor batchOffset, int packedBatch, " "int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) -> Tensor[]"); @@ -66,6 +49,6 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("transducer_joint_forward", &transducer_joint_forward_dispatch); - m.impl("transducer_joint_backward", &transducer_joint_backward_dispatch); + m.impl("transducer_joint_forward", &transducer_joint_forward); + m.impl("transducer_joint_backward", &transducer_joint_backward); } diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index c6df4748c..402b9b314 100644 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -537,8 +537,8 @@ __global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, 
con std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, at::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize) { + int64_t packedBatch, int64_t opt, bool packOutput, bool relu, + bool dropout, double dropoutProb, int64_t tileSize) { auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); @@ -637,8 +637,9 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g kernel<<>>(f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, - hiddenPerBlock, packOutput, relu, dropout, 1.0f - dropoutProb, - rng_engine_inputs, sum.data_ptr(), maskPtr); + hiddenPerBlock, packOutput, relu, dropout, + static_cast(1.0 - dropoutProb), + rng_engine_inputs, sum.data_ptr(), maskPtr); })); } @@ -650,8 +651,8 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g } std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { + at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t maxGLen, bool packOutput, double scale) { auto grad = in[0]; bool masked = (in.size() == 2); uint8_t* maskPtr = masked ? in[1].data_ptr() : nullptr; @@ -703,7 +704,9 @@ std::vector transducer_joint_cuda_backward(std::vector i (reinterpret_cast(fGradPtr) % vecAlignment == 0) and (reinterpret_cast(gGradPtr) % vecAlignment == 0); - if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { + const float scale_arg = static_cast(scale); + + if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. 
const dim3 blocks((hiddenSize + C10_WARP_SIZE * vectFactor - 1) / (C10_WARP_SIZE * vectFactor), @@ -711,26 +714,26 @@ std::vector transducer_joint_cuda_backward(std::vector i if (masked) { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } else { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } } else { const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE, maxFLen + maxGLen, batchSize); if (masked) { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } else { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } } })); diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index 64de84271..c31cb0fb2 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -10,17 +10,17 @@ CHECK_CONTIGUOUS(x) std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, - at::Tensor txtLen, at::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput); + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput); at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor 
lossGrad, at::Tensor alpha, at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, - at::Tensor label, at::Tensor batchOffset, int maxFLen, int blankIdx, - int opt, bool fuseSoftmaxBackward, bool packedInput); + at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, bool packedInput); std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor fLen, - at::Tensor yLen, at::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput) { + at::Tensor yLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(fLen); @@ -31,7 +31,7 @@ std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, at::Tensor fLen, at::Tensor yLen, at::Tensor label, - at::Tensor batchOffset, int maxFLen, int blankIdx, int opt, + at::Tensor batchOffset, int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); @@ -46,24 +46,6 @@ at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tenso fuseSoftmaxBackward, packedInput); } -std::vector transducer_loss_forward_dispatch(at::Tensor x, at::Tensor label, at::Tensor fLen, - at::Tensor yLen, at::Tensor batchOffset, - int64_t maxFLen, int64_t blankIdx, int64_t opt, - bool packedInput) { - return transducer_loss_forward(x, label, fLen, yLen, batchOffset, static_cast(maxFLen), - static_cast(blankIdx), static_cast(opt), packedInput); -} - -at::Tensor transducer_loss_backward_dispatch(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, - at::Tensor beta, at::Tensor fLen, at::Tensor yLen, - at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, - bool packedInput) { - return transducer_loss_backward(x, 
lossGrad, alpha, beta, fLen, yLen, label, batchOffset, static_cast(maxFLen), - static_cast(blankIdx), static_cast(opt), fuseSoftmaxBackward, - packedInput); -} - TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("transducer_loss_forward(Tensor x, Tensor label, Tensor fLen, Tensor yLen, Tensor batchOffset, int maxFLen, " "int blankIdx, int opt, bool packedInput) -> Tensor[]"); @@ -73,6 +55,6 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("transducer_loss_forward", &transducer_loss_forward_dispatch); - m.impl("transducer_loss_backward", &transducer_loss_backward_dispatch); + m.impl("transducer_loss_forward", &transducer_loss_forward); + m.impl("transducer_loss_backward", &transducer_loss_backward); } diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index b7613f8b9..1a41018ac 100644 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -455,8 +455,8 @@ __global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scal } std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, - at::Tensor txtLen, at::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput) { + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput) { auto scalarType = x.scalar_type(); auto tensorOpt = x.options(); const int batchSize = label.size(0); @@ -515,8 +515,8 @@ std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor la at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, - at::Tensor label, at::Tensor batchOffset, int maxFLen, int blankIdx, - int opt, bool fuseSoftmaxBackward, bool packedInput) { + at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool 
fuseSoftmaxBackward, bool packedInput) { auto dtype = x.scalar_type(); at::Tensor xGrad; const int batchSize = label.size(0); From 5430932c15e01fdea4c68d691239e90daafce250 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 28 May 2026 01:40:43 +0900 Subject: [PATCH 3/4] Remove simple dispatcher scalar adapters Python dispatcher schemas canonicalize int and float arguments to int64_t and double before invoking C++ kernels. Several APEX custom op frontends still registered small *_dispatch helpers only to narrow those values to int or float and then call the checked C++ entry point. Register the checked entry points directly for xentropy, focal loss, Megatron scaled softmax variants, fused LAMB, and distributed optimizer frontends. Keep the explicit narrowing at the existing CUDA helper boundary so legacy kernels continue to receive the same scalar types, while the dispatcher-facing code uses the native scalar widths. Leave wrappers that adapt const dispatcher tensors to mutable internal tensor references in place; those are a separate legacy API boundary rather than the scalar precision cleanup handled here. 
--- .../csrc/focal_loss/focal_loss_cuda.cpp | 18 ++--- .../csrc/optimizers/fused_lamb_cuda.cpp | 11 ++- .../optimizers/multi_tensor_distopt_adam.cpp | 33 +++++---- .../optimizers/multi_tensor_distopt_lamb.cpp | 30 ++++---- apex/contrib/csrc/xentropy/interface.cpp | 23 ++----- .../generic_scaled_masked_softmax.cpp | 20 ++---- csrc/megatron/scaled_masked_softmax.cpp | 68 ++++++++----------- csrc/megatron/scaled_softmax.cpp | 20 ++---- .../scaled_upper_triang_masked_softmax.cpp | 20 ++---- 9 files changed, 92 insertions(+), 151 deletions(-) diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index 3d79df36d..8978d51c6 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -23,13 +23,14 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Ten std::vector focal_loss_forward(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, const at::Tensor& num_positives_sum, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { + const double alpha, const double gamma, const double smoothing_factor) { CHECK_INPUT(cls_output); CHECK_INPUT(cls_targets_at_level); CHECK_INPUT(num_positives_sum); - return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, - smoothing_factor); + return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, + static_cast(alpha), static_cast(gamma), + static_cast(smoothing_factor)); } at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& partial_grad, @@ -40,15 +41,6 @@ at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); } -std::vector focal_loss_forward_dispatch(const at::Tensor& cls_output, - const at::Tensor& 
cls_targets_at_level, - const at::Tensor& num_positives_sum, int64_t num_real_classes, - double alpha, double gamma, double smoothing_factor) { - return focal_loss_forward(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, - static_cast(alpha), static_cast(gamma), - static_cast(smoothing_factor)); -} - TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("focal_loss_forward(Tensor cls_output, Tensor cls_targets_at_level, Tensor num_positives_sum, " "int num_real_classes, float alpha, float gamma, float smoothing_factor) -> Tensor[]"); @@ -56,6 +48,6 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("focal_loss_forward", &focal_loss_forward_dispatch); + m.impl("focal_loss_forward", &focal_loss_forward); m.impl("focal_loss_backward", &focal_loss_backward); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp index 36cf86a69..a89c33aa6 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp @@ -6,11 +6,10 @@ void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, double lr, double beta1, - double beta2, double epsilon, int64_t step, int64_t bias_correction, - double weight_decay, int64_t grad_averaging, int64_t mode, double global_grad_norm, - double max_grad_norm) { +void multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, double global_grad_norm, + double max_grad_norm) { multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, tensor_lists, static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(epsilon), static_cast(step), static_cast(bias_correction), static_cast(weight_decay), @@ -25,5 +24,5 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, 
m) { - m.impl("fused_lamb_lamb", &multi_tensor_lamb_dispatch); + m.impl("fused_lamb_lamb", &multi_tensor_lamb); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index d8e443152..87f1bdb0e 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -16,32 +16,31 @@ void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tens float eps, int step, int mode, int bias_correction, float weight_decay); -void multi_tensor_fused_adam_dispatch(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, at::Tensor grad_scale, - double lr, double beta1, double beta2, double eps, int64_t step, int64_t mode, - int64_t bias_correction, double weight_decay) { +void multi_tensor_fused_adam(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, double lr, + double beta1, double beta2, double eps, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { multi_tensor_fused_adam_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(eps), static_cast(step), static_cast(mode), static_cast(bias_correction), static_cast(weight_decay)); } -void multi_tensor_fused_adam_capturable_dispatch(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_scale, at::Tensor lr, double beta1, double beta2, - double eps, at::Tensor step, int64_t mode, int64_t bias_correction, - double weight_decay) { +void multi_tensor_fused_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + at::Tensor lr, double beta1, double beta2, double eps, at::Tensor step, + int64_t mode, int64_t bias_correction, double weight_decay) { multi_tensor_fused_adam_capturable_cuda(static_cast(chunk_size), 
noop_flag, tensor_lists, grad_scale, lr, static_cast(beta1), static_cast(beta2), static_cast(eps), step, static_cast(mode), static_cast(bias_correction), static_cast(weight_decay)); } -void multi_tensor_fused_adam_with_param_remainders_dispatch(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_scale, double lr, double beta1, - double beta2, double eps, int64_t step, int64_t mode, - int64_t bias_correction, double weight_decay) { +void multi_tensor_fused_adam_with_param_remainders(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, double lr, double beta1, double beta2, + double eps, int64_t step, int64_t mode, int64_t bias_correction, + double weight_decay) { multi_tensor_fused_adam_with_param_remainders_cuda( static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(eps), static_cast(step), @@ -61,8 +60,8 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("distributed_adam_multi_tensor_fused_adam", &multi_tensor_fused_adam_dispatch); - m.impl("distributed_adam_multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_dispatch); + m.impl("distributed_adam_multi_tensor_fused_adam", &multi_tensor_fused_adam); + m.impl("distributed_adam_multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable); m.impl("distributed_adam_multi_tensor_fused_adam_with_param_remainders", - &multi_tensor_fused_adam_with_param_remainders_dispatch); + &multi_tensor_fused_adam_with_param_remainders); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index b4a71514e..6c7ee6fe3 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -14,26 +14,24 @@ void multi_tensor_lamb_update_weights_cuda(int chunk_size, 
at::Tensor noop_flag, at::Tensor update_norm_offset, at::Tensor learning_rate, at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb); -void multi_tensor_lamb_compute_update_term_dispatch(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, - at::Tensor per_tensor_beta3, - at::Tensor per_tensor_bias_correction, at::Tensor step, - at::Tensor per_tensor_epsilon, int64_t mode, - at::Tensor per_tensor_decay, at::Tensor global_scale, - at::Tensor global_grad_norm, double max_grad_norm) { +void multi_tensor_lamb_compute_update_term(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, + at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, + at::Tensor step, at::Tensor per_tensor_epsilon, int64_t mode, + at::Tensor per_tensor_decay, at::Tensor global_scale, + at::Tensor global_grad_norm, double max_grad_norm) { multi_tensor_lamb_compute_update_term_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_beta1, per_tensor_beta2, per_tensor_beta3, per_tensor_bias_correction, step, per_tensor_epsilon, static_cast(mode), per_tensor_decay, global_scale, global_grad_norm, static_cast(max_grad_norm)); } -void multi_tensor_lamb_update_weights_dispatch(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, - at::Tensor update_norm_offset, at::Tensor learning_rate, - at::Tensor per_tensor_decay, at::Tensor global_grad_norm, - bool use_nvlamb) { +void multi_tensor_lamb_update_weights(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, + at::Tensor update_norm_offset, at::Tensor learning_rate, + at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb) { 
multi_tensor_lamb_update_weights_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_param_norm, per_tensor_update_norm, update_norm_offset, learning_rate, per_tensor_decay, global_grad_norm, use_nvlamb); @@ -50,6 +48,6 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("distributed_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_dispatch); - m.impl("distributed_lamb_update_weights", &multi_tensor_lamb_update_weights_dispatch); + m.impl("distributed_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term); + m.impl("distributed_lamb_update_weights", &multi_tensor_lamb_update_weights); } diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp index aaf9c0c6c..40461798d 100644 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ b/apex/contrib/csrc/xentropy/interface.cpp @@ -21,33 +21,22 @@ at::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at: CHECK_CONTIGUOUS(x) std::vector softmax_xentropy_forward(const at::Tensor& input, const at::Tensor& labels, - const float smoothing, const bool half_to_float) { + const double smoothing, const bool half_to_float) { CHECK_CUDA(input); CHECK_INPUT(labels); - return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); + return softmax_xentropy_cuda(input, labels, static_cast(smoothing), half_to_float); } at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits, const at::Tensor& max_log_sum_exp, const at::Tensor& labels, - const float smoothing) { + const double smoothing) { CHECK_CUDA(grad_loss); CHECK_CUDA(logits); CHECK_INPUT(max_log_sum_exp); CHECK_INPUT(labels); - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); -} - -std::vector softmax_xentropy_forward_dispatch(const at::Tensor& input, const at::Tensor& labels, - double smoothing, bool half_to_float) { - return softmax_xentropy_forward(input, 
labels, static_cast(smoothing), half_to_float); -} - -at::Tensor softmax_xentropy_backward_dispatch(const at::Tensor& grad_loss, const at::Tensor& logits, - const at::Tensor& max_log_sum_exp, const at::Tensor& labels, - double smoothing) { - return softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, static_cast(smoothing)); + return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, static_cast(smoothing)); } std::string softmax_xentropy_version() { @@ -66,8 +55,8 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("xentropy_forward", &softmax_xentropy_forward_dispatch); - m.impl("xentropy_backward", &softmax_xentropy_backward_dispatch); + m.impl("xentropy_forward", &softmax_xentropy_forward); + m.impl("xentropy_backward", &softmax_xentropy_backward); } TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { diff --git a/csrc/megatron/generic_scaled_masked_softmax.cpp b/csrc/megatron/generic_scaled_masked_softmax.cpp index a0cdf38f6..7d1dc638b 100644 --- a/csrc/megatron/generic_scaled_masked_softmax.cpp +++ b/csrc/megatron/generic_scaled_masked_softmax.cpp @@ -28,16 +28,16 @@ at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); - return fwd_cuda(input, mask, scale_factor); + return fwd_cuda(input, mask, static_cast(scale_factor)); } -at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor) { 
+at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -48,15 +48,7 @@ at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -at::Tensor fwd_dispatch(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { - return fwd(input, mask, static_cast(scale_factor)); -} - -at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { - return bwd(output_grads, softmax_results, static_cast(scale_factor)); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace generic_scaled_masked_softmax @@ -71,7 +63,7 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("generic_scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd_dispatch); + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd); m.impl("generic_scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd_dispatch); + &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd); } diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp index f358f8609..c4d2fcff1 100644 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ b/csrc/megatron/scaled_masked_softmax.cpp @@ -30,52 +30,40 @@ at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_re int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads); -at::Tensor fwd(at::Tensor& input, at::Tensor& mask, float scale_factor) { - TORCH_CHECK(input.dim() == 4, "expected 4D 
tensor"); - TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), +at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { + auto input_arg = input; + auto mask_arg = mask; + TORCH_CHECK(input_arg.dim() == 4, "expected 4D tensor"); + TORCH_CHECK((input_arg.scalar_type() == at::ScalarType::Half) || (input_arg.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); - if (!input.is_contiguous()) input = input.contiguous(); - if (!mask.is_contiguous()) mask = mask.contiguous(); + TORCH_CHECK(mask_arg.dim() == 4, "expected 4D tensor"); + if (!input_arg.is_contiguous()) input_arg = input_arg.contiguous(); + if (!mask_arg.is_contiguous()) mask_arg = mask_arg.contiguous(); - return fwd_cuda(input, mask, scale_factor); + return fwd_cuda(input_arg, mask_arg, static_cast(scale_factor)); } -at::Tensor bwd(at::Tensor& output_grads, at::Tensor& softmax_results, float scale_factor) { - TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); - TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + auto output_grads_arg = output_grads; + auto softmax_results_arg = softmax_results; + TORCH_CHECK(output_grads_arg.dim() == 4, "expected 3D tensor"); + TORCH_CHECK(softmax_results_arg.dim() == 4, "expected 3D tensor"); - TORCH_CHECK( - (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + TORCH_CHECK((output_grads_arg.scalar_type() == at::ScalarType::Half) || + (output_grads_arg.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - if 
(!output_grads.is_contiguous()) output_grads = output_grads.contiguous(); - if (!softmax_results.is_contiguous()) softmax_results = softmax_results.contiguous(); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -at::Tensor fwd_dispatch(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { - at::Tensor input_arg = input; - at::Tensor mask_arg = mask; - return fwd(input_arg, mask_arg, static_cast(scale_factor)); -} + TORCH_CHECK((softmax_results_arg.scalar_type() == at::ScalarType::Half) || + (softmax_results_arg.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + if (!output_grads_arg.is_contiguous()) output_grads_arg = output_grads_arg.contiguous(); + if (!softmax_results_arg.is_contiguous()) softmax_results_arg = softmax_results_arg.contiguous(); -at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { - at::Tensor output_grads_arg = output_grads; - at::Tensor softmax_results_arg = softmax_results; - return bwd(output_grads_arg, softmax_results_arg, static_cast(scale_factor)); + return bwd_cuda(output_grads_arg, softmax_results_arg, static_cast(scale_factor)); } -int64_t get_batch_per_block_dispatch(int64_t query_seq_len, int64_t key_seq_len, int64_t batches, int64_t attn_heads) { - return get_batch_per_block(static_cast(query_seq_len), static_cast(key_seq_len), static_cast(batches), - static_cast(attn_heads)); +int64_t get_batch_per_block(int64_t query_seq_len, int64_t key_seq_len, int64_t batches, int64_t attn_heads) { + return get_batch_per_block_cuda(static_cast(query_seq_len), static_cast(key_seq_len), + static_cast(batches), static_cast(attn_heads)); } } // end namespace scaled_masked_softmax @@ -90,11 +78,11 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } 
TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("scaled_masked_softmax_forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd_dispatch); - m.impl("scaled_masked_softmax_backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd_dispatch); + m.impl("scaled_masked_softmax_forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd); + m.impl("scaled_masked_softmax_backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd); } TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { m.impl("scaled_masked_softmax_get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block_dispatch); + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block); } diff --git a/csrc/megatron/scaled_softmax.cpp b/csrc/megatron/scaled_softmax.cpp index 8538247b4..7bc8a6e17 100644 --- a/csrc/megatron/scaled_softmax.cpp +++ b/csrc/megatron/scaled_softmax.cpp @@ -28,15 +28,15 @@ at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -at::Tensor fwd(at::Tensor const& input, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, double scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return fwd_cuda(input, scale_factor); + return fwd_cuda(input, static_cast(scale_factor)); } -at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -47,15 +47,7 @@ at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results 
(softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -at::Tensor fwd_dispatch(at::Tensor const& input, double scale_factor) { - return fwd(input, static_cast(scale_factor)); -} - -at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { - return bwd(output_grads, softmax_results, static_cast(scale_factor)); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace scaled_softmax @@ -68,6 +60,6 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("scaled_softmax_forward", &multihead_attn::fused_softmax::scaled_softmax::fwd_dispatch); - m.impl("scaled_softmax_backward", &multihead_attn::fused_softmax::scaled_softmax::bwd_dispatch); + m.impl("scaled_softmax_forward", &multihead_attn::fused_softmax::scaled_softmax::fwd); + m.impl("scaled_softmax_backward", &multihead_attn::fused_softmax::scaled_softmax::bwd); } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp index 2856197ca..eb2a7d9e1 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp @@ -28,15 +28,15 @@ at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -at::Tensor fwd(at::Tensor const& input, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, double scale_factor) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return fwd_cuda(input, scale_factor); + return fwd_cuda(input, static_cast(scale_factor)); } -at::Tensor bwd(at::Tensor const& output_grads, 
at::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); @@ -47,15 +47,7 @@ at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -at::Tensor fwd_dispatch(at::Tensor const& input, double scale_factor) { - return fwd(input, static_cast(scale_factor)); -} - -at::Tensor bwd_dispatch(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { - return bwd(output_grads, softmax_results, static_cast(scale_factor)); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace scaled_upper_triang_masked_softmax @@ -70,7 +62,7 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("scaled_upper_triang_masked_softmax_forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd_dispatch); + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd); m.impl("scaled_upper_triang_masked_softmax_backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd_dispatch); + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd); } From f7c9d521c66412bcb77f14979aa141b23e0c7c9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 16:46:38 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apex/_custom_ops.py | 4 +- apex/_extensions/amp_C.py | 16 +- apex/_extensions/bnp.py | 8 +- apex/_extensions/cudnn_gbn_lib.py | 4 +- 
apex/_extensions/fast_bottleneck.py | 35 +- apex/_extensions/fast_multihead_attn.py | 33 +- apex/_extensions/focal_loss_cuda.py | 10 +- .../generic_scaled_masked_softmax_cuda.py | 8 +- apex/_extensions/mlp_cuda.py | 4 +- apex/_extensions/permutation_search_cuda.py | 19 +- .../_extensions/scaled_masked_softmax_cuda.py | 9 +- apex/_extensions/scaled_softmax_cuda.py | 4 +- ...scaled_upper_triang_masked_softmax_cuda.py | 4 +- apex/_extensions/xentropy_cuda.py | 4 +- apex/contrib/csrc/bottleneck/bottleneck.cpp | 119 ++-- apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp | 18 +- apex/contrib/csrc/fmha/fmha_api.cpp | 17 +- apex/contrib/csrc/fmha/src/fmha_fill.cu | 2 +- .../csrc/focal_loss/focal_loss_cuda.cpp | 5 +- .../csrc/group_norm/group_norm_nhwc_op.cpp | 18 +- apex/contrib/csrc/group_norm_v2/gn.cpp | 16 +- apex/contrib/csrc/groupbn/interface.cpp | 112 ++-- .../csrc/index_mul_2d/index_mul_2d_cuda.cpp | 36 +- apex/contrib/csrc/layer_norm/ln_api.cpp | 11 +- .../additive_masked_softmax_dropout_cuda.cu | 6 +- .../encdec_multihead_attn_cuda.cu | 19 +- .../encdec_multihead_attn_norm_add_cuda.cu | 28 +- .../masked_softmax_dropout_cuda.cu | 6 +- .../multihead_attn_frontend.cpp | 539 +++++++++--------- ..._multihead_attn_bias_additive_mask_cuda.cu | 17 +- .../self_multihead_attn_bias_cuda.cu | 16 +- .../self_multihead_attn_cuda.cu | 14 +- .../self_multihead_attn_norm_add_cuda.cu | 22 +- apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp | 29 +- apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh | 1 + .../csrc/optimizers/fused_adam_cuda.cpp | 37 +- .../csrc/optimizers/fused_lamb_cuda.cpp | 11 +- .../optimizers/multi_tensor_distopt_adam.cpp | 35 +- .../optimizers/multi_tensor_distopt_lamb.cpp | 20 +- apex/contrib/csrc/peer_memory/peer_memory.cpp | 20 +- .../csrc/peer_memory/peer_memory_cuda.cuh | 1 + .../csrc/transducer/transducer_joint.cpp | 82 +-- .../transducer/transducer_joint_kernel.cu | 83 ++- .../csrc/transducer/transducer_loss.cpp | 36 +- .../csrc/transducer/transducer_loss_kernel.cu | 12 +- 
apex/contrib/csrc/xentropy/interface.cpp | 9 +- .../permutation_search_kernels.cu | 57 +- apex/contrib/transducer/transducer.py | 4 +- apex/normalization/fused_layer_norm.py | 40 +- csrc/amp_C_frontend.cpp | 179 +++--- csrc/fused_dense.cpp | 12 +- csrc/layer_norm_cuda.cpp | 54 +- .../fused_rotary_positional_embedding.cpp | 7 +- .../generic_scaled_masked_softmax.cpp | 13 +- csrc/megatron/scaled_masked_softmax.cpp | 12 +- csrc/megatron/scaled_softmax.cpp | 2 +- .../scaled_upper_triang_masked_softmax.cpp | 7 +- csrc/mlp.cpp | 7 +- csrc/syncbn.cpp | 30 +- 59 files changed, 1098 insertions(+), 885 deletions(-) diff --git a/apex/_custom_ops.py b/apex/_custom_ops.py index 6a74ef22f..4e43f2e3d 100644 --- a/apex/_custom_ops.py +++ b/apex/_custom_ops.py @@ -17,7 +17,9 @@ def load_custom_op_library(extension_name, anchor_file): key=lambda path: (".cpython-" in path.name, path.name), ) if not candidates: - raise ImportError(f"Could not find shared library for {extension_name!r} next to {anchor_file}") + raise ImportError( + f"Could not find shared library for {extension_name!r} next to {anchor_file}" + ) library = str(candidates[0]) if library not in _loaded_libraries: diff --git a/apex/_extensions/amp_C.py b/apex/_extensions/amp_C.py index 7fb2fce23..25426100c 100644 --- a/apex/_extensions/amp_C.py +++ b/apex/_extensions/amp_C.py @@ -8,7 +8,9 @@ def multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale): - return _ops.amp_multi_tensor_scale(chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale)) + return _ops.amp_multi_tensor_scale( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale) + ) def multi_tensor_sgd( @@ -51,11 +53,15 @@ def multi_tensor_axpby(chunk_size, noop_flag, tensor_lists, a, b, arg_to_check): def multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): - return _ops.amp_multi_tensor_l2norm(chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python) + return 
_ops.amp_multi_tensor_l2norm( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python + ) def multi_tensor_l2norm_mp(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): - return _ops.amp_multi_tensor_l2norm_mp(chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python) + return _ops.amp_multi_tensor_l2norm_mp( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python + ) def multi_tensor_l2norm_scale(chunk_size, noop_flag, tensor_lists, scale, per_tensor_python=None): @@ -64,7 +70,9 @@ def multi_tensor_l2norm_scale(chunk_size, noop_flag, tensor_lists, scale, per_te ) -def multi_tensor_unscale_l2norm(chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor_python=None): +def multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor_python=None +): return _ops.amp_multi_tensor_unscale_l2norm( chunk_size, noop_flag, tensor_list_arg(tensor_lists), inv_scale, per_tensor_python ) diff --git a/apex/_extensions/bnp.py b/apex/_extensions/bnp.py index 38d1c25b5..2ec7e01b4 100644 --- a/apex/_extensions/bnp.py +++ b/apex/_extensions/bnp.py @@ -73,7 +73,9 @@ def bn_fwd_nhwc( ) -def bn_fwd_eval_nhwc(x, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon, fuse_relu): +def bn_fwd_eval_nhwc( + x, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon, fuse_relu +): return _ops.bnp_bn_fwd_eval_nhwc( x, scale, @@ -192,7 +194,9 @@ def bn_addrelu_fwd_nhwc( ) -def bn_addrelu_fwd_eval_nhwc(x, z, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon): +def bn_addrelu_fwd_eval_nhwc( + x, z, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon +): return _ops.bnp_bn_addrelu_fwd_eval_nhwc( x, z, diff --git a/apex/_extensions/cudnn_gbn_lib.py b/apex/_extensions/cudnn_gbn_lib.py index 96e9c5891..e0a302b4d 100644 --- a/apex/_extensions/cudnn_gbn_lib.py +++ b/apex/_extensions/cudnn_gbn_lib.py @@ 
-41,7 +41,9 @@ def forward( ) -def backward(x, dy, scale, minibatch_mean, minibatch_inv_var, epsilon, bn_group, rank_id, peer_buffers): +def backward( + x, dy, scale, minibatch_mean, minibatch_inv_var, epsilon, bn_group, rank_id, peer_buffers +): return _ops.cudnn_gbn_backward( x, dy, diff --git a/apex/_extensions/fast_bottleneck.py b/apex/_extensions/fast_bottleneck.py index 99091d2de..98cbc5510 100644 --- a/apex/_extensions/fast_bottleneck.py +++ b/apex/_extensions/fast_bottleneck.py @@ -66,7 +66,11 @@ def forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, inputs, w1by3, out2_part def forward_out2_pad(explicit_nhwc, stride_1x1, inputs, outputs, out1_pad): return _ops.fast_bottleneck_forward_out2_pad( - bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), out1_pad + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + out1_pad, ) @@ -90,11 +94,17 @@ def backward_grad_out2(explicit_nhwc, stride_1x1, inputs, outputs): def backward_grad_out1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): return _ops.fast_bottleneck_backward_grad_out1( - bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out2 + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, ) -def backward_grad_out1_mask(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, thresholdTop, thresholdBottom): +def backward_grad_out1_mask( + explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, thresholdTop, thresholdBottom +): return _ops.fast_bottleneck_backward_grad_out1_mask( bool(explicit_nhwc), scalar_int(stride_1x1), @@ -134,13 +144,22 @@ def backward_grad_out1_halo_corr( def backward_wgrad2_pad(explicit_nhwc, stride_1x1, inputs, outputs, input, grad_out2): return _ops.fast_bottleneck_backward_wgrad2_pad( - bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), input, grad_out2 + 
bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + input, + grad_out2, ) def backward_wgrad2(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): return _ops.fast_bottleneck_backward_wgrad2( - bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out2 + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, ) @@ -163,7 +182,11 @@ def backward_wgrad3(explicit_nhwc, stride_1x1, inputs, outputs): def backward_wgrad1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out1): return _ops.fast_bottleneck_backward_wgrad1( - bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs), grad_out1 + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out1, ) diff --git a/apex/_extensions/fast_multihead_attn.py b/apex/_extensions/fast_multihead_attn.py index bf863accb..21e2aaf34 100644 --- a/apex/_extensions/fast_multihead_attn.py +++ b/apex/_extensions/fast_multihead_attn.py @@ -7,25 +7,46 @@ _ops = torch.ops.apex -def additive_mask_softmax_dropout_forward(use_mask, is_training, heads, input, pad_mask, dropout_prob): +def additive_mask_softmax_dropout_forward( + use_mask, is_training, heads, input, pad_mask, dropout_prob +): return _ops.fast_multihead_attn_additive_mask_softmax_dropout_forward( - bool(use_mask), bool(is_training), scalar_int(heads), input, pad_mask, scalar_float(dropout_prob) + bool(use_mask), + bool(is_training), + scalar_int(heads), + input, + pad_mask, + scalar_float(dropout_prob), ) -def additive_mask_softmax_dropout_backward(use_mask, heads, output_grads, softmax_results, dropout_mask, dropout_prob): +def additive_mask_softmax_dropout_backward( + use_mask, heads, output_grads, softmax_results, dropout_mask, dropout_prob +): return _ops.fast_multihead_attn_additive_mask_softmax_dropout_backward( - bool(use_mask), scalar_int(heads), 
output_grads, softmax_results, dropout_mask, scalar_float(dropout_prob) + bool(use_mask), + scalar_int(heads), + output_grads, + softmax_results, + dropout_mask, + scalar_float(dropout_prob), ) def mask_softmax_dropout_forward(use_mask, is_training, heads, input, pad_mask, dropout_prob): return _ops.fast_multihead_attn_mask_softmax_dropout_forward( - bool(use_mask), bool(is_training), scalar_int(heads), input, pad_mask, scalar_float(dropout_prob) + bool(use_mask), + bool(is_training), + scalar_int(heads), + input, + pad_mask, + scalar_float(dropout_prob), ) -def mask_softmax_dropout_backward(use_mask, heads, output_grads, softmax_results, dropout_mask, padding_mask, dropout_prob): +def mask_softmax_dropout_backward( + use_mask, heads, output_grads, softmax_results, dropout_mask, padding_mask, dropout_prob +): return _ops.fast_multihead_attn_mask_softmax_dropout_backward( bool(use_mask), scalar_int(heads), diff --git a/apex/_extensions/focal_loss_cuda.py b/apex/_extensions/focal_loss_cuda.py index edf9924a0..ae941bdd5 100644 --- a/apex/_extensions/focal_loss_cuda.py +++ b/apex/_extensions/focal_loss_cuda.py @@ -6,7 +6,15 @@ load_custom_op_library("_focal_loss_cuda", __file__) -def forward(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, smoothing_factor): +def forward( + cls_output, + cls_targets_at_level, + num_positives_sum, + num_real_classes, + alpha, + gamma, + smoothing_factor, +): return torch.ops.apex.focal_loss_forward( cls_output, cls_targets_at_level, diff --git a/apex/_extensions/generic_scaled_masked_softmax_cuda.py b/apex/_extensions/generic_scaled_masked_softmax_cuda.py index 76c212d48..770bd6894 100644 --- a/apex/_extensions/generic_scaled_masked_softmax_cuda.py +++ b/apex/_extensions/generic_scaled_masked_softmax_cuda.py @@ -7,8 +7,12 @@ def forward(input, mask, scale_factor): - return torch.ops.apex.generic_scaled_masked_softmax_forward(input, mask, scalar_float(scale_factor)) + return 
torch.ops.apex.generic_scaled_masked_softmax_forward( + input, mask, scalar_float(scale_factor) + ) def backward(output_grads, softmax_results, scale_factor): - return torch.ops.apex.generic_scaled_masked_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) + return torch.ops.apex.generic_scaled_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/mlp_cuda.py b/apex/_extensions/mlp_cuda.py index 75a228c1b..23e5b8d4c 100644 --- a/apex/_extensions/mlp_cuda.py +++ b/apex/_extensions/mlp_cuda.py @@ -11,4 +11,6 @@ def forward(use_bias, activation, inputs): def backward(use_bias, activation, grad_o, fprop_outputs, inputs): - return torch.ops.apex.mlp_backward(use_bias, activation, grad_o, list(fprop_outputs), list(inputs)) + return torch.ops.apex.mlp_backward( + use_bias, activation, grad_o, list(fprop_outputs), list(inputs) + ) diff --git a/apex/_extensions/permutation_search_cuda.py b/apex/_extensions/permutation_search_cuda.py index aa3c93dcf..c5caf0a4c 100644 --- a/apex/_extensions/permutation_search_cuda.py +++ b/apex/_extensions/permutation_search_cuda.py @@ -29,7 +29,18 @@ def sum_after_2_to_4(matrix, rows, cols, start_col, end_col, blocks, threads, ou ) -def build_permute_map(matrix, rows, cols, stripes, num_groups, group_width, permutations, perm_length, improvements, best_indices): +def build_permute_map( + matrix, + rows, + cols, + stripes, + num_groups, + group_width, + permutations, + perm_length, + improvements, + best_indices, +): return _op("permutation_search_build_permute_map")( _as_tensor(matrix), scalar_int(rows), @@ -72,5 +83,9 @@ def check_permutations( def build_swap_map(matrix, rows, cols, stripe_pairs, output): return _op("permutation_search_build_swap_map")( - _as_tensor(matrix), scalar_int(rows), scalar_int(cols), _as_tensor(stripe_pairs), _as_tensor(output) + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripe_pairs), + 
_as_tensor(output), ) diff --git a/apex/_extensions/scaled_masked_softmax_cuda.py b/apex/_extensions/scaled_masked_softmax_cuda.py index f6de1cca3..3c6e17e84 100644 --- a/apex/_extensions/scaled_masked_softmax_cuda.py +++ b/apex/_extensions/scaled_masked_softmax_cuda.py @@ -11,10 +11,15 @@ def forward(input, mask, scale_factor): def backward(output_grads, softmax_results, scale_factor): - return torch.ops.apex.scaled_masked_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) + return torch.ops.apex.scaled_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads): return torch.ops.apex.scaled_masked_softmax_get_batch_per_block( - scalar_int(query_seq_len), scalar_int(key_seq_len), scalar_int(batches), scalar_int(attn_heads) + scalar_int(query_seq_len), + scalar_int(key_seq_len), + scalar_int(batches), + scalar_int(attn_heads), ) diff --git a/apex/_extensions/scaled_softmax_cuda.py b/apex/_extensions/scaled_softmax_cuda.py index f9bc15596..5b9ff0b9d 100644 --- a/apex/_extensions/scaled_softmax_cuda.py +++ b/apex/_extensions/scaled_softmax_cuda.py @@ -11,4 +11,6 @@ def forward(input, scale_factor): def backward(output_grads, softmax_results, scale_factor): - return torch.ops.apex.scaled_softmax_backward(output_grads, softmax_results, scalar_float(scale_factor)) + return torch.ops.apex.scaled_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py index 9f23bb41f..8feebacba 100644 --- a/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py +++ b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py @@ -7,7 +7,9 @@ def forward(input, scale_factor): - return torch.ops.apex.scaled_upper_triang_masked_softmax_forward(input, scalar_float(scale_factor)) + return 
torch.ops.apex.scaled_upper_triang_masked_softmax_forward( + input, scalar_float(scale_factor) + ) def backward(output_grads, softmax_results, scale_factor): diff --git a/apex/_extensions/xentropy_cuda.py b/apex/_extensions/xentropy_cuda.py index 969fad21a..0ae58d8d8 100644 --- a/apex/_extensions/xentropy_cuda.py +++ b/apex/_extensions/xentropy_cuda.py @@ -13,4 +13,6 @@ def forward(input, labels, smoothing, half_to_float): def backward(grad_loss, logits, max_log_sum_exp, labels, smoothing): - return torch.ops.apex.xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, scalar_float(smoothing)) + return torch.ops.apex.xentropy_backward( + grad_loss, logits, max_log_sum_exp, labels, scalar_float(smoothing) + ) diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp index a521ca090..e33201cb9 100644 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -3581,9 +3581,9 @@ void apex_fast_bottleneck_forward_out2(bool explicit_nhwc, int64_t stride_1X1, s bottleneck_forward_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); } -void apex_fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, std::vector outputs, - at::Tensor thresholdTop, at::Tensor thresholdBottom) { +void apex_fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor thresholdTop, + at::Tensor thresholdBottom) { bottleneck_forward_out2_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), thresholdTop, thresholdBottom); } @@ -3599,9 +3599,8 @@ at::Tensor apex_fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::T return bottleneck_forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, std::move(inputs), w1by3, out2_part_halo); } -void apex_fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, 
std::vector outputs, - at::Tensor out1_pad) { +void apex_fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor out1_pad) { bottleneck_forward_out2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), out1_pad); } @@ -3616,44 +3615,42 @@ std::vector apex_fast_bottleneck_backward_init(bool explicit_nhwc, i } at::Tensor apex_fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, - std::vector outputs) { + std::vector inputs, std::vector outputs) { return bottleneck_backward_grad_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); } at::Tensor apex_fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, - std::vector outputs, at::Tensor grad_out2) { + std::vector inputs, std::vector outputs, + at::Tensor grad_out2) { return bottleneck_backward_grad_out1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2); } at::Tensor apex_fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, - std::vector outputs, at::Tensor grad_out2, - at::Tensor thresholdTop, at::Tensor thresholdBottom) { + std::vector inputs, std::vector outputs, + at::Tensor grad_out2, at::Tensor thresholdTop, + at::Tensor thresholdBottom) { return bottleneck_backward_grad_out1_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2, thresholdTop, thresholdBottom); } -at::Tensor apex_fast_bottleneck_backward_grad_out1_halo_corr( - bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, at::Tensor w1by3, - std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) { +at::Tensor apex_fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, at::Tensor w1by3, + std::vector outputs, at::Tensor grad_out2_halo, 
+ at::Tensor relu1_halo, at::Tensor part_grad_out1) { return bottleneck_backward_grad_out1_halo_corr(explicit_nhwc, as_int(stride_1X1), std::move(inputs), w1by3, std::move(outputs), grad_out2_halo, relu1_halo, part_grad_out1); } at::Tensor apex_fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, - std::vector outputs, at::Tensor grad_out2_halo, - at::Tensor relu1_halo) { + std::vector inputs, std::vector outputs, + at::Tensor grad_out2_halo, at::Tensor relu1_halo) { return bottleneck_backward_grad_out1_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2_halo, relu1_halo); } -void apex_fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, std::vector outputs, - at::Tensor input, at::Tensor grad_out2) { +void apex_fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor input, at::Tensor grad_out2) { bottleneck_backward_wgrad2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), input, grad_out2); } @@ -3664,9 +3661,8 @@ void apex_fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int64_t stride_1X1 } at::Tensor apex_fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int64_t stride_1X1, - std::vector inputs, - std::vector outputs, at::Tensor input, - at::Tensor grad_out2_halo) { + std::vector inputs, std::vector outputs, + at::Tensor input, at::Tensor grad_out2_halo) { return bottleneck_backward_wgrad2_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), input, grad_out2_halo); } @@ -3682,8 +3678,7 @@ void apex_fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int64_t stride_1X1 } void apex_fast_bottleneck_backward_rest(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, - std::vector outputs, at::Tensor grad_out2, - at::Tensor grad_out1) { + std::vector outputs, at::Tensor grad_out2, at::Tensor 
grad_out1) { bottleneck_backward_rest(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2, grad_out1); } @@ -3695,37 +3690,51 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("fast_bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); m.def("fast_bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); m.def("fast_bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); - m.def("fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor thresholdTop, Tensor thresholdBottom) -> ()"); + m.def( + "fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor thresholdTop, Tensor thresholdBottom) -> ()"); m.def("fast_bottleneck_forward_out2_halo(bool explicit_nhwc, Tensor fat_halo_y1, Tensor[] inputs) -> Tensor"); - m.def("fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, Tensor slim_halo_y1, Tensor[] inputs, " - "Tensor w1by3, Tensor out2_part_halo) -> Tensor"); - m.def("fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor out1_pad) -> ()"); + m.def( + "fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, Tensor slim_halo_y1, Tensor[] inputs, " + "Tensor w1by3, Tensor out2_part_halo) -> Tensor"); + m.def( + "fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor out1_pad) -> ()"); m.def("fast_bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); m.def("fast_bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); - m.def("fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " - "-> Tensor"); - 
m.def("fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor grad_out2) -> Tensor"); - m.def("fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " - "Tensor[] outputs, Tensor grad_out2, Tensor thresholdTop, Tensor thresholdBottom) -> Tensor"); - m.def("fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " - "Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo) -> Tensor"); - m.def("fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " - "Tensor w1by3, Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo, Tensor part_grad_out1) -> Tensor"); - m.def("fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor input, Tensor grad_out2) -> ()"); - m.def("fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor grad_out2) -> ()"); - m.def("fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor input, Tensor grad_out2_halo) -> Tensor"); - m.def("fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " - "-> ()"); - m.def("fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor grad_out1) -> ()"); - m.def("fast_bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " - "Tensor grad_out2, Tensor grad_out1) -> ()"); + m.def( + "fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> Tensor"); + m.def( + 
"fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2, Tensor thresholdTop, Tensor thresholdBottom) -> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo) -> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor w1by3, Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo, Tensor part_grad_out1) -> Tensor"); + m.def( + "fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2) -> ()"); + m.def( + "fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> ()"); + m.def( + "fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2_halo) -> Tensor"); + m.def( + "fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> ()"); + m.def( + "fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out1) -> ()"); + m.def( + "fast_bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2, Tensor grad_out1) -> ()"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index 460046c7f..1c8c2d273 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -133,20 +133,22 @@ at::Tensor apex_cudnn_gbn_forward(const at::Tensor& x, const at::Tensor& scale, } std::vector apex_cudnn_gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& 
scale, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, double epsilon, - int64_t bn_group, int64_t rank_id, at::IntArrayRef peer_buffers) { + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + double epsilon, int64_t bn_group, int64_t rank_id, + at::IntArrayRef peer_buffers) { return gbn_backward(x, dy, scale, minibatch_mean, minibatch_inv_var, static_cast(epsilon), bn_group, static_cast(rank_id), to_int_vector(peer_buffers)); } } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("cudnn_gbn_forward(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_var, " - "Tensor minibatch_mean, Tensor minibatch_inv_var, float momentum, float epsilon, int bn_group, int rank_id, " - "int[] peer_buffers) -> Tensor"); - m.def("cudnn_gbn_backward(Tensor x, Tensor dy, Tensor scale, Tensor minibatch_mean, Tensor minibatch_inv_var, " - "float epsilon, int bn_group, int rank_id, int[] peer_buffers) -> Tensor[]"); + m.def( + "cudnn_gbn_forward(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, float momentum, float epsilon, int bn_group, int rank_id, " + "int[] peer_buffers) -> Tensor"); + m.def( + "cudnn_gbn_backward(Tensor x, Tensor dy, Tensor scale, Tensor minibatch_mean, Tensor minibatch_inv_var, " + "float epsilon, int bn_group, int rank_id, int[] peer_buffers) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp index e7193e663..c2a6a7a43 100644 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ b/apex/contrib/csrc/fmha/fmha_api.cpp @@ -25,8 +25,8 @@ * ******************************************************************************/ -#include #include +#include #include #include "fmha.h" @@ -340,12 +340,15 @@ std::vector apex_fmha_bwd_nl(const at::Tensor& dout, const at::Tenso } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("fmha_fwd(Tensor 
qkv, Tensor cu_seqlens, float p_dropout, int max_seq_len, bool is_training, bool is_nl, " - "bool zero_tensors, Generator? gen) -> Tensor[]"); - m.def("fmha_bwd(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, int max_seq_len, " - "bool zero_tensors) -> Tensor[]"); - m.def("fmha_bwd_nl(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, " - "int max_seq_len, bool zero_tensors) -> Tensor[]"); + m.def( + "fmha_fwd(Tensor qkv, Tensor cu_seqlens, float p_dropout, int max_seq_len, bool is_training, bool is_nl, " + "bool zero_tensors, Generator? gen) -> Tensor[]"); + m.def( + "fmha_bwd(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, int max_seq_len, " + "bool zero_tensors) -> Tensor[]"); + m.def( + "fmha_bwd_nl(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, " + "int max_seq_len, bool zero_tensors) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/fmha/src/fmha_fill.cu b/apex/contrib/csrc/fmha/src/fmha_fill.cu index 00614ce6f..01f6e6c71 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fill.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fill.cu @@ -25,9 +25,9 @@ * ******************************************************************************/ +#include #include #include -#include constexpr int block_size = 512; constexpr int ctas_per_sm = 4; diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index 8978d51c6..c02dbd7f5 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -42,8 +42,9 @@ at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("focal_loss_forward(Tensor cls_output, Tensor cls_targets_at_level, Tensor num_positives_sum, " - "int num_real_classes, float alpha, float gamma, float smoothing_factor) -> Tensor[]"); + m.def( 
+ "focal_loss_forward(Tensor cls_output, Tensor cls_targets_at_level, Tensor num_positives_sum, " + "int num_real_classes, float alpha, float gamma, float smoothing_factor) -> Tensor[]"); m.def("focal_loss_backward(Tensor grad_output, Tensor partial_grad, Tensor num_positives_sum) -> Tensor"); } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp index 01a0553bc..14ef3812b 100644 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp @@ -7,16 +7,16 @@ #include #include -#include "group_norm_nhwc.h" -#include "group_norm_nhwc_bwd_one_pass.h" -#include "group_norm_nhwc_fwd_one_pass.h" - #include #include #include #include #include +#include "group_norm_nhwc.h" +#include "group_norm_nhwc_bwd_one_pass.h" +#include "group_norm_nhwc_fwd_one_pass.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// #define CHECK_CUDA_STATUS(call) \ @@ -283,10 +283,12 @@ std::vector apex_group_norm_bwd(at::Tensor grad_output, at::Tensor s } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("group_norm_forward(Tensor input, int groups, Tensor weight, Tensor bias, float eps, int passes, " - "bool with_swish) -> Tensor[]"); - m.def("group_norm_backward(Tensor grad_output, Tensor sums, Tensor input, int groups, Tensor weight, Tensor bias, " - "float eps, int passes, bool with_swish) -> Tensor[]"); + m.def( + "group_norm_forward(Tensor input, int groups, Tensor weight, Tensor bias, float eps, int passes, " + "bool with_swish) -> Tensor[]"); + m.def( + "group_norm_backward(Tensor grad_output, Tensor sums, Tensor input, int groups, Tensor weight, Tensor bias, " + "float eps, int passes, bool with_swish) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/group_norm_v2/gn.cpp b/apex/contrib/csrc/group_norm_v2/gn.cpp index bbe3e4b5c..51d062eec 100644 --- 
a/apex/contrib/csrc/group_norm_v2/gn.cpp +++ b/apex/contrib/csrc/group_norm_v2/gn.cpp @@ -35,8 +35,7 @@ at::Tensor gn(at::Tensor x, at::Tensor w, at::Tensor b, float eps, bool silu, in } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - at::Tensor red_buffer = - at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + at::Tensor red_buffer = at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); @@ -83,8 +82,7 @@ auto gn_bwd(at::Tensor grad_output, at::Tensor x, at::Tensor w, at::Tensor b, at } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - at::Tensor red_buffer = - at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + at::Tensor red_buffer = at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); @@ -127,10 +125,12 @@ std::vector apex_group_norm_v2_gn_bwd(at::Tensor grad_output, at::Te } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("group_norm_v2_gn(Tensor x, Tensor w, Tensor b, float eps, bool silu, int num_groups, " - "Tensor? mean_var_out, int sm_margin) -> Tensor"); - m.def("group_norm_v2_gn_bwd(Tensor grad_output, Tensor x, Tensor w, Tensor b, Tensor mean_var, float eps, " - "bool silu, int num_groups, int sm_margin) -> Tensor[]"); + m.def( + "group_norm_v2_gn(Tensor x, Tensor w, Tensor b, float eps, bool silu, int num_groups, " + "Tensor? 
mean_var_out, int sm_margin) -> Tensor"); + m.def( + "group_norm_v2_gn_bwd(Tensor grad_output, Tensor x, Tensor w, Tensor b, Tensor mean_var, float eps, " + "bool silu, int num_groups, int sm_margin) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/groupbn/interface.cpp b/apex/contrib/csrc/groupbn/interface.cpp index b15250e92..cd6eeb7ef 100644 --- a/apex/contrib/csrc/groupbn/interface.cpp +++ b/apex/contrib/csrc/groupbn/interface.cpp @@ -93,14 +93,13 @@ at::Tensor apex_bnp_bn_fwd_nhwc(const at::Tensor& x, const at::Tensor& scale, co const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, double momentum, double epsilon, bool fuse_relu, c10::optional my_data, c10::optional pair_data, - c10::optional pair_data2, c10::optional pair_data3, - int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, - int64_t grid_dim_x, bool coop) { + c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, + const at::Tensor& magic_tensor, int64_t occupancy, int64_t grid_dim_x, bool coop) { return nhwc_bn_fwd_train(x, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, - static_cast(momentum), static_cast(epsilon), fuse_relu, - optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), - optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, - static_cast(occupancy), static_cast(grid_dim_x), coop); + static_cast(momentum), static_cast(epsilon), fuse_relu, optional_ptr(my_data), + optional_ptr(pair_data), optional_ptr(pair_data2), optional_ptr(pair_data3), + static_cast(bn_group), magic_tensor, static_cast(occupancy), + static_cast(grid_dim_x), coop); } at::Tensor apex_bnp_bn_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, @@ -111,13 +110,15 @@ at::Tensor apex_bnp_bn_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& scal static_cast(momentum), static_cast(epsilon), fuse_relu); } 
-std::vector apex_bnp_bn_bwd_nhwc( - const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, - const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, double momentum, double epsilon, bool fuse_relu, - c10::optional my_data, c10::optional pair_data, c10::optional pair_data2, - c10::optional pair_data3, int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, - int64_t grid_dim_x, bool coop) { +std::vector apex_bnp_bn_bwd_nhwc(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, + double momentum, double epsilon, bool fuse_relu, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, + int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { return nhwc_bn_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, static_cast(momentum), static_cast(epsilon), fuse_relu, optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), optional_ptr(pair_data3), @@ -125,36 +126,37 @@ std::vector apex_bnp_bn_bwd_nhwc( static_cast(grid_dim_x), coop); } -at::Tensor apex_bnp_bn_addrelu_fwd_nhwc( - const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, - const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, double momentum, - double epsilon, c10::optional my_data, c10::optional pair_data, - c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, - const at::Tensor& magic_tensor, 
int64_t occupancy, int64_t grid_dim_x, bool coop) { - return nhwc_bn_addrelu_fwd_train(x, z, scale, bias, running_mean, running_inv_var, minibatch_mean, - minibatch_inv_var, bitmask, ret_cta, static_cast(momentum), - static_cast(epsilon), optional_ptr(my_data), optional_ptr(pair_data), - optional_ptr(pair_data2), optional_ptr(pair_data3), static_cast(bn_group), - magic_tensor, static_cast(occupancy), static_cast(grid_dim_x), coop); +at::Tensor apex_bnp_bn_addrelu_fwd_nhwc(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, + const at::Tensor& ret_cta, double momentum, double epsilon, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, + int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_addrelu_fwd_train(x, z, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, + bitmask, ret_cta, static_cast(momentum), static_cast(epsilon), + optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), + optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, + static_cast(occupancy), static_cast(grid_dim_x), coop); } at::Tensor apex_bnp_bn_addrelu_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& ret_cta, int64_t bn_group, double momentum, double epsilon) { - return nhwc_bn_addrelu_fwd_eval(x, z, scale, bias, running_mean, running_inv_var, ret_cta, - static_cast(bn_group), static_cast(momentum), - static_cast(epsilon)); + return nhwc_bn_addrelu_fwd_eval(x, z, scale, bias, running_mean, running_inv_var, ret_cta, static_cast(bn_group), + static_cast(momentum), 
static_cast(epsilon)); } std::vector apex_bnp_bn_addrelu_bwd_nhwc( const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, double momentum, - double epsilon, c10::optional my_data, c10::optional pair_data, - c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, - const at::Tensor& magic_tensor, int64_t occupancy, int64_t grid_dim_x, bool coop) { + double epsilon, c10::optional my_data, c10::optional pair_data, c10::optional pair_data2, + c10::optional pair_data3, int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { return nhwc_bn_addrelu_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, bitmask, ret_cta, static_cast(momentum), static_cast(epsilon), optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), @@ -168,28 +170,34 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("bnp_get_data_ptr(Tensor data) -> int"); m.def("bnp_get_remote_data_ptr(Tensor handle, int offset) -> int"); m.def("bnp_close_remote_data(Tensor handle) -> ()"); - m.def("bnp_bn_fwd_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " - "Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, float epsilon, " - "bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? 
pair_data3, int bn_group, " - "Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); - m.def("bnp_bn_fwd_eval_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " - "Tensor ret_cta, int bn_group, float momentum, float epsilon, bool fuse_relu) -> Tensor"); - m.def("bnp_bn_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " - "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, " - "float epsilon, bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " - "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def( + "bnp_bn_fwd_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, float epsilon, " + "bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, int bn_group, " + "Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def( + "bnp_bn_fwd_eval_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor ret_cta, int bn_group, float momentum, float epsilon, bool fuse_relu) -> Tensor"); + m.def( + "bnp_bn_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, " + "float epsilon, bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? 
pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); m.def("bnp_bn_fwd_nhwc_occupancy() -> int"); m.def("bnp_bn_bwd_nhwc_occupancy() -> int"); - m.def("bnp_bn_addrelu_fwd_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " - "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " - "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " - "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); - m.def("bnp_bn_addrelu_fwd_eval_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " - "Tensor running_inv_var, Tensor ret_cta, int bn_group, float momentum, float epsilon) -> Tensor"); - m.def("bnp_bn_addrelu_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " - "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " - "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " - "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def( + "bnp_bn_addrelu_fwd_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? 
pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def( + "bnp_bn_addrelu_fwd_eval_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor ret_cta, int bn_group, float momentum, float epsilon) -> Tensor"); + m.def( + "bnp_bn_addrelu_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); m.def("bnp_bn_addrelu_fwd_nhwc_occupancy() -> int"); m.def("bnp_bn_addrelu_bwd_nhwc_occupancy() -> int"); } diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp index c2d7d34e5..702525cf8 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -82,9 +82,9 @@ void index_mul_2d_float_backward_dispatch(const at::Tensor& grad_in1, const at:: void index_mul_2d_float_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, const at::Tensor& grad_in2, const at::Tensor& grad_out, - const at::Tensor& grad_grad_in1, - const at::Tensor& grad_grad_in2, const at::Tensor& in1, - const at::Tensor& in2, const at::Tensor& idx1) { + const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2, + const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { at::Tensor grad_grad_out_arg = grad_grad_out; at::Tensor grad_in1_arg = grad_in1; at::Tensor grad_in2_arg = grad_in2; @@ -108,9 +108,9 @@ void index_mul_2d_half_backward_dispatch(const at::Tensor& grad_in1, const at::T void index_mul_2d_half_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, 
const at::Tensor& grad_in2, const at::Tensor& grad_out, - const at::Tensor& grad_grad_in1, - const at::Tensor& grad_grad_in2, const at::Tensor& in1, - const at::Tensor& in2, const at::Tensor& idx1) { + const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2, + const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { at::Tensor grad_grad_out_arg = grad_grad_out; at::Tensor grad_in1_arg = grad_in1; at::Tensor grad_in2_arg = grad_in2; @@ -120,17 +120,21 @@ void index_mul_2d_half_backward_backward_dispatch(const at::Tensor& grad_grad_ou TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("index_mul_2d_float_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); - m.def("index_mul_2d_float_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " - "Tensor in2, Tensor idx1) -> ()"); - m.def("index_mul_2d_float_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " - "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " - "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_float_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_float_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); m.def("index_mul_2d_half_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); - m.def("index_mul_2d_half_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " - "Tensor in2, Tensor idx1) -> ()"); - m.def("index_mul_2d_half_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " - "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " - "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_half_backward(Tensor(a!) grad_in1, Tensor(b!) 
grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_half_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 89af4b675..e5c0a0adc 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -1,7 +1,7 @@ #include +#include #include -#include #include "ln.h" /* @@ -260,16 +260,17 @@ std::vector apex_fast_layer_norm_ln_fwd(const at::Tensor& x, const a std::vector apex_fast_layer_norm_ln_bwd(const at::Tensor& dz, const at::Tensor& x_or_z, const c10::optional& mu, const at::Tensor& rsigma, - const at::Tensor& gamma, - const c10::optional& beta, bool memory_efficient) { + const at::Tensor& gamma, const c10::optional& beta, + bool memory_efficient) { return ln_bwd(dz, x_or_z, mu, rsigma, gamma, beta, memory_efficient); } } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("fast_layer_norm_ln_fwd(Tensor x, Tensor gamma, Tensor beta, float epsilon) -> Tensor[]"); - m.def("fast_layer_norm_ln_bwd(Tensor dz, Tensor x_or_z, Tensor? mu, Tensor rsigma, Tensor gamma, Tensor? beta, " - "bool memory_efficient) -> Tensor[]"); + m.def( + "fast_layer_norm_ln_bwd(Tensor dz, Tensor x_or_z, Tensor? mu, Tensor rsigma, Tensor gamma, Tensor? 
beta, " + "bool memory_efficient) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu index e5fa4f413..abf64b052 100644 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -22,7 +22,7 @@ namespace fused_softmax { namespace additive_mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, - float dropout_prob) { + float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); @@ -74,7 +74,7 @@ std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& } at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index b92647375..99dd6fdde 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -21,9 +21,9 @@ namespace encdec { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, - at::Tensor const& 
input_weights_kv, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -144,12 +144,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& inputs_q, at::Tensor const& inputs_kv, - at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 4cdb3c301..2ec6a3e7a 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -22,10 +22,10 @@ namespace encdec_norm_add { namespace 
cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -178,15 +178,15 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, - at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& 
input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu index fcb626165..0b58c33b8 100644 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -20,7 +20,7 @@ namespace fused_softmax { namespace mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, - float dropout_prob) { + float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); @@ -72,7 +72,7 @@ std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& } at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index 1272dde59..8e66c5f60 100644 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include #include @@ -15,13 +15,13 @@ namespace fused_softmax { namespace 
additive_mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, - float dropout_prob); + float dropout_prob); at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, float dropout_prob); + at::Tensor const& dropout_mask, float dropout_prob); std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, - at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { @@ -34,7 +34,7 @@ std::vector fwd(bool use_mask, bool is_training, int heads, at::Tens } at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -50,13 +50,13 @@ at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Ten namespace mask_softmax_dropout { std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, - float dropout_prob); + float dropout_prob); at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, - at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D 
tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); @@ -70,7 +70,7 @@ std::vector fwd(bool use_mask, bool is_training, int heads, at::Tens } at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, at::Tensor const& padding_mask, float dropout_prob) { + at::Tensor const& dropout_mask, at::Tensor const& padding_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -91,21 +91,20 @@ namespace encdec { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& inputs_q, at::Tensor const& inputs_kv, - at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs_q, at::Tensor const& inputs_kv, - at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, - at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, 
at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); @@ -128,12 +127,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& inputs_q, at::Tensor const& inputs_kv, - at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -172,27 +170,27 @@ namespace encdec_norm_add { namespace 
cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, - at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, - float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs_q, at::Tensor const& inputs_kv, - at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, - at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, - at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, 
at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); @@ -220,15 +218,14 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, - at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, - at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, at::Tensor const& 
inputs_kv, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -280,18 +277,18 @@ namespace self { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 
3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -310,10 +307,10 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -344,22 +341,22 @@ namespace self_bias { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& input_biases, at::Tensor const& output_biases, - const uint8_t* pad_mask, float dropout_prob); + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, + const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - // at::Tensor const& input_biases, - // at::Tensor const& output_biases, - at::Tensor const& dropout_mask, float 
dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& input_biases, - at::Tensor const& output_biases, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, at::Tensor const& pad_mask, + float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -378,10 +375,10 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, 
"expected 3D tensor"); @@ -412,24 +409,24 @@ namespace self_bias_additive_mask { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& input_biases, at::Tensor const& output_biases, - const half* pad_mask, float dropout_prob); + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, + float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, - // at::Tensor const& softmax_results, - at::Tensor const& bmm1_results, at::Tensor const& pad_mask, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - // at::Tensor const& input_biases, - // at::Tensor const& output_biases, - at::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& input_biases, - at::Tensor const& output_biases, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, + // at::Tensor const& softmax_results, + at::Tensor const& bmm1_results, at::Tensor const& pad_mask, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& 
output_biases, at::Tensor const& pad_mask, + float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -449,11 +446,10 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& bmm1_results, - at::Tensor const& pad_mask, at::Tensor const& input_lin_results, - at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -483,23 +479,23 @@ namespace self_norm_add { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& 
lyr_nrm_results, - at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, - at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - at::Tensor const& dropout_add_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& pad_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); @@ -522,13 +518,12 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, } std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, - 
at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, - at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - at::Tensor const& dropout_add_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -575,30 +570,37 @@ int as_int(int64_t value) { return static_cast(value); } float as_float(double value) { return static_cast(value); } -std::vector apex_fast_multihead_attn_additive_mask_softmax_dropout_forward( - bool use_mask, bool is_training, int64_t heads, at::Tensor const& input, at::Tensor const& pad_mask, - double dropout_prob) { - return multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd( - use_mask, is_training, as_int(heads), input, pad_mask, as_float(dropout_prob)); +std::vector apex_fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, + int64_t heads, + at::Tensor const& input, + at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd(use_mask, is_training, as_int(heads), input, + pad_mask, as_float(dropout_prob)); } -at::Tensor apex_fast_multihead_attn_additive_mask_softmax_dropout_backward( - bool use_mask, int64_t heads, at::Tensor 
const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, double dropout_prob) { +at::Tensor apex_fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int64_t heads, + at::Tensor const& output_grads, + at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, + double dropout_prob) { return multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd( use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, as_float(dropout_prob)); } -std::vector apex_fast_multihead_attn_mask_softmax_dropout_forward( - bool use_mask, bool is_training, int64_t heads, at::Tensor const& input, at::Tensor const& pad_mask, - double dropout_prob) { - return multihead_attn::fused_softmax::mask_softmax_dropout::fwd( - use_mask, is_training, as_int(heads), input, pad_mask, as_float(dropout_prob)); +std::vector apex_fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, + int64_t heads, at::Tensor const& input, + at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::mask_softmax_dropout::fwd(use_mask, is_training, as_int(heads), input, pad_mask, + as_float(dropout_prob)); } -at::Tensor apex_fast_multihead_attn_mask_softmax_dropout_backward( - bool use_mask, int64_t heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, - at::Tensor const& dropout_mask, at::Tensor const& padding_mask, double dropout_prob) { +at::Tensor apex_fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int64_t heads, + at::Tensor const& output_grads, + at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, + at::Tensor const& padding_mask, double dropout_prob) { return multihead_attn::fused_softmax::mask_softmax_dropout::bwd( use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, padding_mask, as_float(dropout_prob)); } @@ -613,61 +615,59 @@ std::vector apex_fast_multihead_attn_encdec_multihead_attn_forward( } 
std::vector apex_fast_multihead_attn_encdec_multihead_attn_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { - return multihead_attn::encdec::cublas_gemmex::bwd( - as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, - input_lin_kv_results, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, dropout_mask, - as_float(dropout_prob)); + return multihead_attn::encdec::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_q_results, input_lin_kv_results, + inputs_q, inputs_kv, input_weights_q, input_weights_kv, + output_weights, dropout_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_forward( bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs_q, - at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& pad_mask, - double dropout_prob) { + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, at::Tensor const& 
output_weights, + at::Tensor const& pad_mask, double dropout_prob) { return multihead_attn::encdec_norm_add::cublas_gemmex::fwd( use_mask, use_time_mask, is_training, as_int(heads), inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, - at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& dropout_mask, - at::Tensor const& dropout_add_mask, double dropout_prob) { + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + double dropout_prob) { return multihead_attn::encdec_norm_add::cublas_gemmex::bwd( as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, - input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, - lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, - dropout_add_mask, as_float(dropout_prob)); + 
input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_add_mask, + as_float(dropout_prob)); } -std::vector apex_fast_multihead_attn_self_attn_forward( - bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& pad_mask, - double dropout_prob) { +std::vector apex_fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, + int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, + at::Tensor const& output_weights, + at::Tensor const& pad_mask, double dropout_prob) { return multihead_attn::self::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, - input_weights, output_weights, pad_mask, as_float(dropout_prob)); + input_weights, output_weights, pad_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_self_attn_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { return multihead_attn::self::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, inputs, input_weights, - output_weights, dropout_mask, 
as_float(dropout_prob)); + softmax_results, input_lin_results, inputs, input_weights, + output_weights, dropout_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_self_attn_bias_forward( @@ -675,18 +675,18 @@ std::vector apex_fast_multihead_attn_self_attn_bias_forward( at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& input_biases, at::Tensor const& output_biases, at::Tensor const& pad_mask, double dropout_prob) { return multihead_attn::self_bias::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, - input_weights, output_weights, input_biases, output_biases, - pad_mask, as_float(dropout_prob)); + input_weights, output_weights, input_biases, output_biases, + pad_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_self_attn_bias_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { - return multihead_attn::self_bias::cublas_gemmex::bwd( - as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, inputs, - input_weights, output_weights, dropout_mask, as_float(dropout_prob)); + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::self_bias::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_results, inputs, input_weights, + output_weights, dropout_mask, 
as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_forward( @@ -699,10 +699,10 @@ std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_fo } std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& bmm1_results, at::Tensor const& pad_mask, - at::Tensor const& input_lin_results, at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, double dropout_prob) { + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& bmm1_results, at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, double dropout_prob) { return multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd( as_int(heads), output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results, inputs, input_weights, output_weights, dropout_mask, as_float(dropout_prob)); @@ -710,82 +710,96 @@ std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_ba std::vector apex_fast_multihead_attn_self_attn_norm_add_forward( bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, - at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, - at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& pad_mask, - double dropout_prob) { - return multihead_attn::self_norm_add::cublas_gemmex::fwd( - use_mask, use_time_mask, is_training, as_int(heads), inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, - input_weights, output_weights, pad_mask, as_float(dropout_prob)); + at::Tensor const& lyr_nrm_gamma_weights, 
at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_norm_add::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, + output_weights, pad_mask, as_float(dropout_prob)); } std::vector apex_fast_multihead_attn_self_attn_norm_add_backward( - int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs, - at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, - at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, - at::Tensor const& dropout_add_mask, double dropout_prob) { + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + double dropout_prob) { return multihead_attn::self_norm_add::cublas_gemmex::bwd( as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, - lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, - input_weights, output_weights, dropout_mask, dropout_add_mask, as_float(dropout_prob)); + lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, + output_weights, 
dropout_mask, dropout_add_mask, as_float(dropout_prob)); } } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, " - "Tensor input, Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " - "Tensor softmax_results, Tensor dropout_mask, float dropout_prob) -> Tensor"); - m.def("fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, Tensor input, " - "Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " - "Tensor softmax_results, Tensor dropout_mask, Tensor padding_mask, float dropout_prob) -> Tensor"); - m.def("fast_multihead_attn_encdec_multihead_attn_forward(bool use_mask, bool use_time_mask, bool is_training, " - "int heads, Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, " - "Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_encdec_multihead_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " - "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, Tensor input_lin_kv_results, " - "Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " - "Tensor dropout_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_encdec_multihead_attn_norm_add_forward(bool use_mask, bool use_time_mask, " - "bool is_training, int heads, Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, " - "Tensor lyr_nrm_beta_weights, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " - "Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_encdec_multihead_attn_norm_add_backward(int 
heads, Tensor output_grads, " - "Tensor matmul2_results, Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, " - "Tensor input_lin_kv_results, Tensor lyr_nrm_results, Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, " - "Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " - "Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, Tensor dropout_mask, " - "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " - "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) " - "-> Tensor[]"); - m.def("fast_multihead_attn_self_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " - "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " - "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_bias_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " - "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor input_biases, Tensor output_biases, " - "Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_bias_backward(int heads, Tensor output_grads, Tensor matmul2_results, " - "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " - "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_bias_additive_mask_forward(bool use_mask, bool use_time_mask, " - "bool is_training, int heads, Tensor inputs, Tensor input_weights, Tensor output_weights, " - "Tensor input_biases, Tensor output_biases, Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_bias_additive_mask_backward(int heads, 
Tensor output_grads, " - "Tensor matmul2_results, Tensor dropout_results, Tensor bmm1_results, Tensor pad_mask, " - "Tensor input_lin_results, Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " - "float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_norm_add_forward(bool use_mask, bool use_time_mask, bool is_training, " - "int heads, Tensor inputs, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " - "Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); - m.def("fast_multihead_attn_self_attn_norm_add_backward(int heads, Tensor output_grads, Tensor matmul2_results, " - "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor lyr_nrm_results, " - "Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, Tensor inputs, Tensor lyr_nrm_gamma_weights, " - "Tensor lyr_nrm_beta_weights, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " - "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, " + "Tensor input, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, float dropout_prob) -> Tensor"); + m.def( + "fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, Tensor input, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, Tensor padding_mask, float dropout_prob) -> Tensor"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs_q, Tensor inputs_kv, Tensor 
input_weights_q, Tensor input_weights_kv, " + "Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, Tensor input_lin_kv_results, " + "Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_norm_add_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_norm_add_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, " + "Tensor input_lin_kv_results, Tensor lyr_nrm_results, Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, " + "Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) " + "-> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> 
Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor input_biases, Tensor output_biases, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_additive_mask_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs, Tensor input_weights, Tensor output_weights, " + "Tensor input_biases, Tensor output_biases, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_additive_mask_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor bmm1_results, Tensor pad_mask, " + "Tensor input_lin_results, Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_norm_add_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_norm_add_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor lyr_nrm_results, " + "Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, Tensor inputs, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, 
float dropout_prob) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { @@ -793,12 +807,9 @@ TORCH_LIBRARY_IMPL(apex, CUDA, m) { &apex_fast_multihead_attn_additive_mask_softmax_dropout_forward); m.impl("fast_multihead_attn_additive_mask_softmax_dropout_backward", &apex_fast_multihead_attn_additive_mask_softmax_dropout_backward); - m.impl("fast_multihead_attn_mask_softmax_dropout_forward", - &apex_fast_multihead_attn_mask_softmax_dropout_forward); - m.impl("fast_multihead_attn_mask_softmax_dropout_backward", - &apex_fast_multihead_attn_mask_softmax_dropout_backward); - m.impl("fast_multihead_attn_encdec_multihead_attn_forward", - &apex_fast_multihead_attn_encdec_multihead_attn_forward); + m.impl("fast_multihead_attn_mask_softmax_dropout_forward", &apex_fast_multihead_attn_mask_softmax_dropout_forward); + m.impl("fast_multihead_attn_mask_softmax_dropout_backward", &apex_fast_multihead_attn_mask_softmax_dropout_backward); + m.impl("fast_multihead_attn_encdec_multihead_attn_forward", &apex_fast_multihead_attn_encdec_multihead_attn_forward); m.impl("fast_multihead_attn_encdec_multihead_attn_backward", &apex_fast_multihead_attn_encdec_multihead_attn_backward); m.impl("fast_multihead_attn_encdec_multihead_attn_norm_add_forward", diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index 7d08be994..e6c229f64 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -21,9 +21,9 @@ namespace self_bias_additive_mask { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& 
input_biases, at::Tensor const& output_biases, - const half* pad_mask, float dropout_prob) { + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, + float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -124,11 +124,10 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& bmm1_results, - at::Tensor const& pad_mask, at::Tensor const& input_lin_results, - at::Tensor const& inputs, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index 1a2a4942c..67f49b5ab 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -21,9 +21,9 @@ namespace self_bias { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& input_biases, at::Tensor const& 
output_biases, - const uint8_t* pad_mask, float dropout_prob) { + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -135,10 +135,10 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index cd071d3a6..c8720d0d2 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -21,8 +21,8 @@ namespace self { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { 
const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -128,10 +128,10 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& inputs, - at::Tensor const& input_weights, at::Tensor const& output_weights, - at::Tensor const& dropout_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index d20985a43..ae0b6f04b 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -5,9 +5,9 @@ #if __has_include() #include #endif +#include #include #include -#include #include #include @@ -22,9 +22,9 @@ namespace self_norm_add { namespace cublas_gemmex { std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, - at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); 
const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -163,13 +163,13 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads } std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, - at::Tensor const& dropout_results, at::Tensor const& softmax_results, - at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, - at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, - at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, - at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, - at::Tensor const& output_weights, at::Tensor const& dropout_mask, - at::Tensor const& dropout_add_mask, float dropout_prob) { + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp index 6d532ad19..6080b93e8 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp @@ -14,9 +14,10 @@ * limitations under the License. 
*/ -#include "nccl_p2p_cuda.cuh" #include +#include "nccl_p2p_cuda.cuh" + namespace { at::Tensor apex_nccl_p2p_get_unique_nccl_id(int64_t n) { return apex::contrib::nccl_p2p::get_unique_nccl_id(static_cast(n)); @@ -24,23 +25,23 @@ at::Tensor apex_nccl_p2p_get_unique_nccl_id(int64_t n) { int64_t apex_nccl_p2p_init_nccl_comm(at::Tensor unique_nccl_id, int64_t my_rank, int64_t num_ranks) { return apex::contrib::nccl_p2p::init_nccl_comm(unique_nccl_id, static_cast(my_rank), - static_cast(num_ranks)); + static_cast(num_ranks)); } void apex_nccl_p2p_left_right_halo_exchange_inplace(int64_t handle, int64_t left_rank, int64_t right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) { - apex::contrib::nccl_p2p::left_right_halo_exchange_inplace( - static_cast(handle), static_cast(left_rank), static_cast(right_rank), left_output_halo, - right_output_halo, left_input_halo, right_input_halo); + apex::contrib::nccl_p2p::left_right_halo_exchange_inplace(static_cast(handle), static_cast(left_rank), + static_cast(right_rank), left_output_halo, + right_output_halo, left_input_halo, right_input_halo); } std::vector apex_nccl_p2p_left_right_halo_exchange(int64_t handle, int64_t left_rank, int64_t right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) { - return apex::contrib::nccl_p2p::left_right_halo_exchange( - static_cast(handle), static_cast(left_rank), static_cast(right_rank), left_output_halo, - right_output_halo); + return apex::contrib::nccl_p2p::left_right_halo_exchange(static_cast(handle), static_cast(left_rank), + static_cast(right_rank), left_output_halo, + right_output_halo); } void apex_nccl_p2p_add_delay(int64_t delay) { apex::contrib::nccl_p2p::add_delay(static_cast(delay)); } @@ -49,11 +50,13 @@ void apex_nccl_p2p_add_delay(int64_t delay) { apex::contrib::nccl_p2p::add_delay TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("nccl_p2p_get_unique_nccl_id(int n) -> Tensor"); 
m.def("nccl_p2p_init_nccl_comm(Tensor unique_nccl_id, int my_rank, int num_ranks) -> int"); - m.def("nccl_p2p_left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, " - "Tensor left_output_halo, Tensor right_output_halo, Tensor(a!) left_input_halo, " - "Tensor(b!) right_input_halo) -> ()"); - m.def("nccl_p2p_left_right_halo_exchange(int handle, int left_rank, int right_rank, Tensor left_output_halo, " - "Tensor right_output_halo) -> Tensor[]"); + m.def( + "nccl_p2p_left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, " + "Tensor left_output_halo, Tensor right_output_halo, Tensor(a!) left_input_halo, " + "Tensor(b!) right_input_halo) -> ()"); + m.def( + "nccl_p2p_left_right_halo_exchange(int handle, int left_rank, int right_rank, Tensor left_output_halo, " + "Tensor right_output_halo) -> Tensor[]"); m.def("nccl_p2p_add_delay(int delay) -> ()"); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh index 3b91e7619..dd6b4bd43 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh @@ -16,6 +16,7 @@ #pragma once #include + #include #ifndef _nccl_p2p_h_ #define _nccl_p2p_h_ diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp index 3652a39ac..baee1710c 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp @@ -110,8 +110,8 @@ void adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tens } void reversible_adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tensor& m, const at::Tensor& v, - const at::Tensor& g, double lr, double beta1, double beta2, double eps, - double grad_scale, int64_t step, int64_t mode, int64_t bias_correction, double decay) { + const at::Tensor& g, double lr, double beta1, double beta2, double eps, double grad_scale, + int64_t step, int64_t 
mode, int64_t bias_correction, double decay) { at::Tensor p_arg = p; at::Tensor p_copy_arg = p_copy; at::Tensor m_arg = m; @@ -161,20 +161,25 @@ void maybe_cast_cuda_mt_dispatch(int64_t chunk_size, at::Tensor overflow_flag, } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("fused_adam_strided_check_finite(Tensor(a!) overflow_flag, Tensor p_copy, int stride, " - "int clear_overflow_first) -> ()"); - m.def("fused_adam_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " - "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " - "int bias_correction, float decay) -> ()"); - m.def("fused_adam_reversible_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " - "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " - "int bias_correction, float decay) -> ()"); - m.def("fused_adam_adam_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists, float lr, " - "float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, " - "float decay) -> ()"); - m.def("fused_adam_maybe_adam_undo(Tensor overflow_flag, Tensor(a!) p, Tensor(b!) m, Tensor(c!) v, Tensor(d!) g, " - "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " - "int bias_correction, float decay) -> ()"); + m.def( + "fused_adam_strided_check_finite(Tensor(a!) overflow_flag, Tensor p_copy, int stride, " + "int clear_overflow_first) -> ()"); + m.def( + "fused_adam_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def( + "fused_adam_reversible_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) 
v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def( + "fused_adam_adam_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists, float lr, " + "float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, " + "float decay) -> ()"); + m.def( + "fused_adam_maybe_adam_undo(Tensor overflow_flag, Tensor(a!) p, Tensor(b!) m, Tensor(c!) v, Tensor(d!) g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); m.def("fused_adam_maybe_cast(Tensor overflow_flag, Tensor p_in, Tensor(a!) p_out) -> ()"); m.def("fused_adam_maybe_cast_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists) -> ()"); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp index a89c33aa6..200258e0c 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp @@ -18,11 +18,10 @@ void multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, std::vector ()"); + m.def( + "fused_lamb_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, float global_grad_norm, float max_grad_norm) -> ()"); } -TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("fused_lamb_lamb", &multi_tensor_lamb); -} +TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("fused_lamb_lamb", &multi_tensor_lamb); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 87f1bdb0e..6e88ea604 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -20,10 +20,10 @@ void 
multi_tensor_fused_adam(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor grad_scale, double lr, double beta1, double beta2, double eps, int64_t step, int64_t mode, int64_t bias_correction, double weight_decay) { - multi_tensor_fused_adam_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), - static_cast(beta1), static_cast(beta2), static_cast(eps), - static_cast(step), static_cast(mode), static_cast(bias_correction), - static_cast(weight_decay)); + multi_tensor_fused_adam_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, + static_cast(lr), static_cast(beta1), static_cast(beta2), + static_cast(eps), static_cast(step), static_cast(mode), + static_cast(bias_correction), static_cast(weight_decay)); } void multi_tensor_fused_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, @@ -31,9 +31,9 @@ void multi_tensor_fused_adam_capturable(int64_t chunk_size, at::Tensor noop_flag at::Tensor lr, double beta1, double beta2, double eps, at::Tensor step, int64_t mode, int64_t bias_correction, double weight_decay) { multi_tensor_fused_adam_capturable_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, lr, - static_cast(beta1), static_cast(beta2), - static_cast(eps), step, static_cast(mode), - static_cast(bias_correction), static_cast(weight_decay)); + static_cast(beta1), static_cast(beta2), static_cast(eps), + step, static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay)); } void multi_tensor_fused_adam_with_param_remainders(int64_t chunk_size, at::Tensor noop_flag, @@ -48,15 +48,18 @@ void multi_tensor_fused_adam_with_param_remainders(int64_t chunk_size, at::Tenso } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("distributed_adam_multi_tensor_fused_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, " - "int bias_correction, float weight_decay) -> ()"); - 
m.def("distributed_adam_multi_tensor_fused_adam_capturable(int chunk_size, Tensor noop_flag, " - "Tensor[][] tensor_lists, Tensor grad_scale, Tensor lr, float beta1, float beta2, float eps, " - "Tensor step, int mode, int bias_correction, float weight_decay) -> ()"); - m.def("distributed_adam_multi_tensor_fused_adam_with_param_remainders(int chunk_size, Tensor noop_flag, " - "Tensor[][] tensor_lists, Tensor grad_scale, float lr, float beta1, float beta2, float eps, " - "int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "distributed_adam_multi_tensor_fused_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, " + "int bias_correction, float weight_decay) -> ()"); + m.def( + "distributed_adam_multi_tensor_fused_adam_capturable(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, Tensor lr, float beta1, float beta2, float eps, " + "Tensor step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "distributed_adam_multi_tensor_fused_adam_with_param_remainders(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, float lr, float beta1, float beta2, float eps, " + "int step, int mode, int bias_correction, float weight_decay) -> ()"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index 6c7ee6fe3..63aa147ff 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -23,8 +23,8 @@ void multi_tensor_lamb_compute_update_term(int64_t chunk_size, at::Tensor noop_f at::Tensor global_grad_norm, double max_grad_norm) { multi_tensor_lamb_compute_update_term_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_beta1, per_tensor_beta2, per_tensor_beta3, 
per_tensor_bias_correction, step, - per_tensor_epsilon, static_cast(mode), per_tensor_decay, - global_scale, global_grad_norm, static_cast(max_grad_norm)); + per_tensor_epsilon, static_cast(mode), per_tensor_decay, global_scale, + global_grad_norm, static_cast(max_grad_norm)); } void multi_tensor_lamb_update_weights(int64_t chunk_size, at::Tensor noop_flag, @@ -38,13 +38,15 @@ void multi_tensor_lamb_update_weights(int64_t chunk_size, at::Tensor noop_flag, } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("distributed_lamb_compute_update_term(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "Tensor per_tensor_beta1, Tensor per_tensor_beta2, Tensor per_tensor_beta3, " - "Tensor per_tensor_bias_correction, Tensor step, Tensor per_tensor_epsilon, int mode, " - "Tensor per_tensor_decay, Tensor global_scale, Tensor global_grad_norm, float max_grad_norm) -> ()"); - m.def("distributed_lamb_update_weights(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, Tensor update_norm_offset, " - "Tensor learning_rate, Tensor per_tensor_decay, Tensor global_grad_norm, bool use_nvlamb) -> ()"); + m.def( + "distributed_lamb_compute_update_term(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_beta1, Tensor per_tensor_beta2, Tensor per_tensor_beta3, " + "Tensor per_tensor_bias_correction, Tensor step, Tensor per_tensor_epsilon, int mode, " + "Tensor per_tensor_decay, Tensor global_scale, Tensor global_grad_norm, float max_grad_norm) -> ()"); + m.def( + "distributed_lamb_update_weights(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, Tensor update_norm_offset, " + "Tensor learning_rate, Tensor per_tensor_decay, Tensor global_grad_norm, bool use_nvlamb) -> ()"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/peer_memory/peer_memory.cpp 
b/apex/contrib/csrc/peer_memory/peer_memory.cpp index 52b2f53d4..f4ae8c80c 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory.cpp +++ b/apex/contrib/csrc/peer_memory/peer_memory.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ -#include "peer_memory_cuda.cuh" #include +#include "peer_memory_cuda.cuh" + namespace { std::vector apex_peer_memory_get_raw_peers(at::Tensor ipc_addresses, int64_t peer_rank, int64_t raw) { return apex::contrib::peer_memory::get_raw_peers(ipc_addresses, static_cast(peer_rank), raw); @@ -44,8 +45,8 @@ void apex_peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, i at::Tensor btm_out_transfer, at::Tensor btm_out_halo) { apex::contrib::peer_memory::push_pull_halos_1d(diagnostics, explicit_nhwc, static_cast(numSM), static_cast(rank), top_zero, top_in_halo, top_in_transfer, - top_out_transfer, top_out_halo, btm_zero, btm_in_halo, - btm_in_transfer, btm_out_transfer, btm_out_halo); + top_out_transfer, top_out_halo, btm_zero, btm_in_halo, btm_in_transfer, + btm_out_transfer, btm_out_halo); } } // namespace @@ -58,10 +59,11 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("peer_memory_blob_view_half(int raw, int[] shape, bool channels_last) -> Tensor"); m.def("peer_memory_blob_view_float(int raw, int[] shape, bool channels_last) -> Tensor"); m.def("peer_memory_blob_view_int(int raw, int[] shape, bool channels_last) -> Tensor"); - m.def("peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int numSM, int rank, bool top_zero, " - "Tensor(a!) top_in_halo, Tensor(b!) top_in_transfer, Tensor(c!) top_out_transfer, Tensor(d!) top_out_halo, " - "bool btm_zero, Tensor(e!) btm_in_halo, Tensor(f!) btm_in_transfer, Tensor(g!) btm_out_transfer, " - "Tensor(h!) btm_out_halo) -> ()"); + m.def( + "peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int numSM, int rank, bool top_zero, " + "Tensor(a!) top_in_halo, Tensor(b!) top_in_transfer, Tensor(c!) top_out_transfer, Tensor(d!) 
top_out_halo, " + "bool btm_zero, Tensor(e!) btm_in_halo, Tensor(f!) btm_in_transfer, Tensor(g!) btm_out_transfer, " + "Tensor(h!) btm_out_halo) -> ()"); } TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { @@ -75,6 +77,4 @@ TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { m.impl("peer_memory_blob_view_int", &apex_peer_memory_blob_view_int); } -TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("peer_memory_push_pull_halos_1d", &apex_peer_memory_push_pull_halos_1d); -} +TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("peer_memory_push_pull_halos_1d", &apex_peer_memory_push_pull_halos_1d); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 92a7fd876..4357ebaa9 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -16,6 +16,7 @@ #pragma once #include + #include #ifndef _peer_memory_h_ #define _peer_memory_h_ diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 9eb012b74..4deedce0d 100644 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -1,51 +1,53 @@ #include #include - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, - int64_t packedBatch, int64_t opt, bool packOutput, bool relu, - bool dropout, double dropoutProb, int64_t tileSize); -std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t maxGLen, bool packOutput, double scale); +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a 
CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) -std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int64_t packedBatch, - int64_t opt, bool packOutput, bool relu, bool dropout, - double dropoutProb, int64_t tileSize) { - CHECK_INPUT(f); - CHECK_INPUT(g); - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) CHECK_INPUT(batchOffset); - return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, - dropoutProb, tileSize); -} - -std::vector transducer_joint_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t maxGLen, bool packOutput, double scale) { - for (auto t : in) { - CHECK_INPUT(t); - } - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) CHECK_INPUT(batchOffset); +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize); + +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale); + +std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize) { + CHECK_INPUT(f); + CHECK_INPUT(g); + CHECK_INPUT(fLen); + CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); + return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, + dropoutProb, tileSize); +} + +std::vector 
transducer_joint_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale) { + for (auto t : in) { + CHECK_INPUT(t); + } + CHECK_INPUT(fLen); + CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("transducer_joint_forward(Tensor f, Tensor g, Tensor fLen, Tensor gLen, Tensor batchOffset, int packedBatch, " - "int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) -> Tensor[]"); - m.def("transducer_joint_backward(Tensor[] input, Tensor fLen, Tensor gLen, Tensor batchOffset, int maxFLen, " - "int maxGLen, bool packOutput, float scale) -> Tensor[]"); + m.def( + "transducer_joint_forward(Tensor f, Tensor g, Tensor fLen, Tensor gLen, Tensor batchOffset, int packedBatch, " + "int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) -> Tensor[]"); + m.def( + "transducer_joint_backward(Tensor[] input, Tensor fLen, Tensor gLen, Tensor batchOffset, int maxFLen, " + "int maxGLen, bool packOutput, float scale) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index 402b9b314..d8915abe3 100644 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -1,8 +1,8 @@ -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #ifdef OLD_GENERATOR_PATH #include @@ -179,7 +179,7 @@ __global__ void transducer_joint_forward(const scalar_t* f, const scalar_t* g, c } } else if (packOutput == false and t < maxFLen and u < maxGLen) { // Need to write finite data to don't-care region because we instantiate the result tensor -// with at::empty for 
performance reasons. Even though it is don't-care region, the +// with at::empty for performance reasons. Even though it is don't-care region, the // contents need to be finite, otherwise could lead to NaN in WGRAD. // In packing mode, this write is no longer necessary as we remove the don't-care region // from the output. @@ -535,10 +535,10 @@ __global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, con } } -std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, - int64_t packedBatch, int64_t opt, bool packOutput, bool relu, - bool dropout, double dropoutProb, int64_t tileSize) { +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize) { auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); @@ -548,17 +548,17 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g bool masked = dropout or relu; int64_t* batchOffsetPtr = nullptr; - at::Tensor sum, mask; - auto maskOpt = tensorOpt.dtype(at::kByte); - if (!packOutput) { - sum = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); - batchOffsetPtr = nullptr; - if (masked) mask = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); - } else { - sum = at::empty({packedBatch, hiddenSize}, tensorOpt); - batchOffsetPtr = batchOffset.data_ptr(); - if (masked) mask = at::empty({packedBatch, hiddenSize}, maskOpt); - } + at::Tensor sum, mask; + auto maskOpt = tensorOpt.dtype(at::kByte); + if (!packOutput) { + sum = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); + batchOffsetPtr = nullptr; + if (masked) mask = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); + } else { + sum = at::empty({packedBatch, hiddenSize}, tensorOpt); + 
batchOffsetPtr = batchOffset.data_ptr(); + if (masked) mask = at::empty({packedBatch, hiddenSize}, maskOpt); + } uint8_t* maskPtr = masked ? mask.data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -635,11 +635,10 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g } } - kernel<<>>(f.data_ptr(), g.data_ptr(), fLen.data_ptr(), - gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, - hiddenPerBlock, packOutput, relu, dropout, - static_cast(1.0 - dropoutProb), - rng_engine_inputs, sum.data_ptr(), maskPtr); + kernel<<>>( + f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, hiddenPerBlock, packOutput, relu, dropout, + static_cast(1.0 - dropoutProb), rng_engine_inputs, sum.data_ptr(), maskPtr); })); } @@ -650,9 +649,9 @@ std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g return {sum}; } -std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, - at::Tensor gLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t maxGLen, bool packOutput, double scale) { +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale) { auto grad = in[0]; bool masked = (in.size() == 2); uint8_t* maskPtr = masked ? in[1].data_ptr() : nullptr; @@ -665,8 +664,8 @@ std::vector transducer_joint_cuda_backward(std::vector i const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; - at::Tensor fGrad = at::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); - at::Tensor gGrad = at::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); + at::Tensor fGrad = at::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); + at::Tensor gGrad = at::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); int64_t* batchOffsetPtr = (!packOutput) ? 
nullptr : batchOffset.data_ptr(); @@ -704,9 +703,9 @@ std::vector transducer_joint_cuda_backward(std::vector i (reinterpret_cast(fGradPtr) % vecAlignment == 0) and (reinterpret_cast(gGradPtr) % vecAlignment == 0); - const float scale_arg = static_cast(scale); - - if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { + const float scale_arg = static_cast(scale); + + if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. const dim3 blocks((hiddenSize + C10_WARP_SIZE * vectFactor - 1) / (C10_WARP_SIZE * vectFactor), @@ -714,26 +713,26 @@ std::vector transducer_joint_cuda_backward(std::vector i if (masked) { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } else { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } } else { const dim3 blocks((hiddenSize + C10_WARP_SIZE - 1) / C10_WARP_SIZE, maxFLen + maxGLen, batchSize); if (masked) { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } else { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, - fGradPtr, gGradPtr); + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, + fGradPtr, gGradPtr); } } })); diff --git 
a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index c31cb0fb2..48f20ca03 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -10,17 +10,17 @@ CHECK_CONTIGUOUS(x) std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, - at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool packedInput); + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput); -at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, - at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, - at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, bool packedInput); +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor audLen, at::Tensor txtLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput); -std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor fLen, - at::Tensor yLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool packedInput) { +std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor fLen, at::Tensor yLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t blankIdx, int64_t opt, + bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(fLen); @@ -30,9 +30,9 @@ std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, } at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, - at::Tensor fLen, at::Tensor yLen, at::Tensor label, - at::Tensor batchOffset, int64_t maxFLen, int64_t blankIdx, int64_t opt, - bool fuseSoftmaxBackward, bool 
packedInput) { + at::Tensor fLen, at::Tensor yLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(lossGrad); @@ -47,11 +47,13 @@ at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tenso } TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("transducer_loss_forward(Tensor x, Tensor label, Tensor fLen, Tensor yLen, Tensor batchOffset, int maxFLen, " - "int blankIdx, int opt, bool packedInput) -> Tensor[]"); - m.def("transducer_loss_backward(Tensor x, Tensor lossGrad, Tensor alpha, Tensor beta, Tensor fLen, Tensor yLen, " - "Tensor label, Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, " - "bool packedInput) -> Tensor"); + m.def( + "transducer_loss_forward(Tensor x, Tensor label, Tensor fLen, Tensor yLen, Tensor batchOffset, int maxFLen, " + "int blankIdx, int opt, bool packedInput) -> Tensor[]"); + m.def( + "transducer_loss_backward(Tensor x, Tensor lossGrad, Tensor alpha, Tensor beta, Tensor fLen, Tensor yLen, " + "Tensor label, Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, " + "bool packedInput) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index 1a41018ac..437dcbb2a 100644 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -455,8 +455,8 @@ __global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scal } std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, - at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool packedInput) { + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput) { auto 
scalarType = x.scalar_type(); auto tensorOpt = x.options(); const int batchSize = label.size(0); @@ -513,10 +513,10 @@ std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor la return {alpha, beta, loss}; } -at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, - at::Tensor beta, at::Tensor audLen, at::Tensor txtLen, - at::Tensor label, at::Tensor batchOffset, int64_t maxFLen, - int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, bool packedInput) { +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor audLen, at::Tensor txtLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput) { auto dtype = x.scalar_type(); at::Tensor xGrad; const int batchSize = label.size(0); diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp index 40461798d..748d85d0b 100644 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ b/apex/contrib/csrc/xentropy/interface.cpp @@ -49,8 +49,9 @@ std::string softmax_xentropy_version() { TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("xentropy_forward(Tensor input, Tensor labels, float smoothing, bool half_to_float) -> Tensor[]"); - m.def("xentropy_backward(Tensor grad_loss, Tensor logits, Tensor max_log_sum_exp, Tensor labels, " - "float smoothing) -> Tensor"); + m.def( + "xentropy_backward(Tensor grad_loss, Tensor logits, Tensor max_log_sum_exp, Tensor labels, " + "float smoothing) -> Tensor"); m.def("xentropy_version() -> str"); } @@ -59,6 +60,4 @@ TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("xentropy_backward", &softmax_xentropy_backward); } -TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { - m.impl("xentropy_version", &softmax_xentropy_version); -} +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { m.impl("xentropy_version", &softmax_xentropy_version); } diff --git 
a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu index caf3aa62e..8c8b0c159 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu +++ b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu @@ -1,8 +1,9 @@ +#include #include #include #include + #include -#include #define gpuErrchk(ans) \ { \ @@ -470,11 +471,13 @@ void set_up_swap_map_memory(float** dmatrix, unsigned int rows, unsigned int col /////////////////////////////////////////////////////////// -int64_t apex_permutation_search_check_permutations( - torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, int64_t cols_arg, - torch::stable::Tensor const& stripe_groups_tensor, int64_t group_width_arg, int64_t num_groups_arg, - torch::stable::Tensor const& permutations_tensor, int64_t num_permutations_arg, - torch::stable::Tensor const& improvement_tensor, torch::stable::Tensor const& permutation_tensor) { +int64_t apex_permutation_search_check_permutations(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, torch::stable::Tensor const& stripe_groups_tensor, + int64_t group_width_arg, int64_t num_groups_arg, + torch::stable::Tensor const& permutations_tensor, + int64_t num_permutations_arg, + torch::stable::Tensor const& improvement_tensor, + torch::stable::Tensor const& permutation_tensor) { static float* d_matrix; static unsigned int* d_permutations; static unsigned int* d_stripes; @@ -556,11 +559,13 @@ int64_t apex_permutation_search_sum_after_2_to_4(torch::stable::Tensor const& ma return 0; } -int64_t apex_permutation_search_build_permute_map( - torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, int64_t cols_arg, - torch::stable::Tensor const& stripes_tensor, int64_t num_groups_arg, int64_t group_width_arg, - torch::stable::Tensor const& 
permutations_tensor, int64_t perm_length_arg, - torch::stable::Tensor const& improvements_tensor, torch::stable::Tensor const& best_indices_tensor) { +int64_t apex_permutation_search_build_permute_map(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, torch::stable::Tensor const& stripes_tensor, + int64_t num_groups_arg, int64_t group_width_arg, + torch::stable::Tensor const& permutations_tensor, + int64_t perm_length_arg, + torch::stable::Tensor const& improvements_tensor, + torch::stable::Tensor const& best_indices_tensor) { static float* d_matrix = NULL; static unsigned int* d_stripes = NULL; static unsigned int* d_permutations = NULL; @@ -583,8 +588,7 @@ int64_t apex_permutation_search_build_permute_map( const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0); set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width, - &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, - &hindices); + &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices); float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); unsigned int* stripes = uint_ptr_from_tensor(stripes_tensor, "stripes"); @@ -626,8 +630,7 @@ int64_t apex_permutation_search_build_permute_map( } int64_t apex_permutation_search_build_swap_map(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, - int64_t cols_arg, - torch::stable::Tensor const& stripe_pairs_tensor, + int64_t cols_arg, torch::stable::Tensor const& stripe_pairs_tensor, torch::stable::Tensor const& output_tensor) { static float* d_matrix = NULL; static float* d_result = NULL; @@ -655,16 +658,20 @@ int64_t apex_permutation_search_build_swap_map(torch::stable::Tensor const& matr } STABLE_TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("permutation_search_sum_after_2_to_4(Tensor matrix, int rows, int cols, int start_col, int end_col, " - "int blocks, int threads, Tensor(a!) 
output) -> int"); - m.def("permutation_search_build_permute_map(Tensor matrix, int rows, int cols, Tensor stripes, int num_groups, " - "int group_width, Tensor permutations, int perm_length, Tensor(a!) improvements, Tensor(b!) best_indices) " - "-> int"); - m.def("permutation_search_check_permutations(Tensor matrix, int rows, int cols, Tensor stripe_groups, " - "int group_width, int num_groups, Tensor permutations, int num_permutations, Tensor(a!) improvement, " - "Tensor(b!) permutation) -> int"); - m.def("permutation_search_build_swap_map(Tensor matrix, int rows, int cols, Tensor stripe_pairs, " - "Tensor(a!) output) -> int"); + m.def( + "permutation_search_sum_after_2_to_4(Tensor matrix, int rows, int cols, int start_col, int end_col, " + "int blocks, int threads, Tensor(a!) output) -> int"); + m.def( + "permutation_search_build_permute_map(Tensor matrix, int rows, int cols, Tensor stripes, int num_groups, " + "int group_width, Tensor permutations, int perm_length, Tensor(a!) improvements, Tensor(b!) best_indices) " + "-> int"); + m.def( + "permutation_search_check_permutations(Tensor matrix, int rows, int cols, Tensor stripe_groups, " + "int group_width, int num_groups, Tensor permutations, int num_permutations, Tensor(a!) improvement, " + "Tensor(b!) permutation) -> int"); + m.def( + "permutation_search_build_swap_map(Tensor matrix, int rows, int cols, Tensor stripe_pairs, " + "Tensor(a!) 
output) -> int"); } STABLE_TORCH_LIBRARY_IMPL(apex, CPU, m) { diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py index c4c4f9103..18db7f826 100755 --- a/apex/contrib/transducer/transducer.py +++ b/apex/contrib/transducer/transducer.py @@ -1,6 +1,6 @@ import torch -from apex._extensions import transducer_loss_cuda -from apex._extensions import transducer_joint_cuda +from apex._extensions import transducer_loss_cuda +from apex._extensions import transducer_joint_cuda class TransducerJoint(torch.nn.Module): diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index 5a3d7bc83..cd61dec85 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -40,7 +40,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -87,7 +89,9 @@ def fused_layer_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() weight_ = weight.contiguous() @@ -204,7 +208,9 @@ class FusedRMSNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = 
importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -247,7 +253,9 @@ def fused_rms_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() weight_ = weight.contiguous() @@ -358,7 +366,9 @@ class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -380,7 +390,9 @@ class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -401,7 +413,9 @@ class FusedLayerNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = 
importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -439,7 +453,9 @@ def fused_layer_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, normalized_shape, eps) @@ -535,7 +551,9 @@ class FusedRMSNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -573,7 +591,9 @@ def fused_rms_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, normalized_shape, eps) diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp index 694aaf760..7ad36f74b 100644 --- a/csrc/amp_C_frontend.cpp +++ b/csrc/amp_C_frontend.cpp @@ -86,10 +86,9 @@ void apex_multi_tensor_scale(int64_t chunk_size, at::Tensor noop_flag, multi_tensor_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), 
static_cast(scale)); } -void apex_multi_tensor_sgd(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, double wd, double momentum, - double dampening, double lr, bool nesterov, bool first_run, bool wd_after_momentum, - double scale) { +void apex_multi_tensor_sgd(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double wd, double momentum, double dampening, double lr, bool nesterov, bool first_run, + bool wd_after_momentum, double scale) { multi_tensor_sgd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(wd), static_cast(momentum), static_cast(dampening), static_cast(lr), nesterov, first_run, wd_after_momentum, static_cast(scale)); @@ -102,54 +101,56 @@ void apex_multi_tensor_axpby(int64_t chunk_size, at::Tensor noop_flag, static_cast(b), static_cast(arg_to_check)); } -std::tuple apex_multi_tensor_l2norm( - int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::optional per_tensor_python) { +std::tuple apex_multi_tensor_l2norm(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { return multi_tensor_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_python); } -std::tuple apex_multi_tensor_l2norm_mp( - int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, - at::optional per_tensor_python) { +std::tuple apex_multi_tensor_l2norm_mp(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { return multi_tensor_l2norm_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), - per_tensor_python); + per_tensor_python); } -std::tuple apex_multi_tensor_l2norm_scale( - int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, double scale, - at::optional per_tensor_python) { +std::tuple apex_multi_tensor_l2norm_scale(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + double scale, at::optional 
per_tensor_python) { return multi_tensor_l2norm_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), - static_cast(scale), per_tensor_python); + static_cast(scale), per_tensor_python); } -std::tuple apex_multi_tensor_unscale_l2norm( - int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor inv_scale, - at::optional per_tensor_python) { +std::tuple apex_multi_tensor_unscale_l2norm(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor inv_scale, + at::optional per_tensor_python) { return multi_tensor_unscale_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), inv_scale, - per_tensor_python); + per_tensor_python); } void apex_multi_tensor_lamb_stage1_cuda(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_decay, int64_t step, double beta1, double beta2, - double epsilon, at::Tensor global_grad_norm, double max_global_grad_norm) { + std::vector> tensor_lists, at::Tensor per_tensor_decay, + int64_t step, double beta1, double beta2, double epsilon, + at::Tensor global_grad_norm, double max_global_grad_norm) { multi_tensor_lamb_stage1_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_decay, static_cast(step), static_cast(beta1), static_cast(beta2), - static_cast(epsilon), global_grad_norm, static_cast(max_global_grad_norm)); + static_cast(epsilon), global_grad_norm, + static_cast(max_global_grad_norm)); } void apex_multi_tensor_lamb_stage2_cuda(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, double lr, double weight_decay, at::optional use_nvlamb_python) { - multi_tensor_lamb_stage2_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), - per_tensor_param_norm, per_tensor_update_norm, static_cast(lr), - static_cast(weight_decay), use_nvlamb_python); + multi_tensor_lamb_stage2_cuda(static_cast(chunk_size), 
noop_flag, std::move(tensor_lists), per_tensor_param_norm, + per_tensor_update_norm, static_cast(lr), static_cast(weight_decay), + use_nvlamb_python); } -void apex_multi_tensor_adam(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, double lr, double beta1, double beta2, - double epsilon, int64_t step, int64_t mode, int64_t bias_correction, double weight_decay) { +void apex_multi_tensor_adam(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { multi_tensor_adam_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(epsilon), static_cast(step), static_cast(mode), static_cast(bias_correction), @@ -168,9 +169,8 @@ void apex_multi_tensor_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, void apex_multi_tensor_adam_capturable_master(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, - double beta1, double beta2, double epsilon, at::Tensor step, - int64_t mode, int64_t bias_correction, double weight_decay, - at::Tensor inv_scale) { + double beta1, double beta2, double epsilon, at::Tensor step, int64_t mode, + int64_t bias_correction, double weight_decay, at::Tensor inv_scale) { multi_tensor_adam_capturable_master_cuda( static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, static_cast(beta1), static_cast(beta2), static_cast(epsilon), step, static_cast(mode), @@ -191,15 +191,14 @@ void apex_multi_tensor_novograd(int64_t chunk_size, at::Tensor noop_flag, multi_tensor_novograd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), grad_norms, static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(epsilon), static_cast(step), static_cast(bias_correction), - static_cast(weight_decay), static_cast(grad_averaging), - static_cast(mode), 
static_cast(norm_type)); + static_cast(weight_decay), static_cast(grad_averaging), static_cast(mode), + static_cast(norm_type)); } -void apex_multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, double lr, double beta1, double beta2, - double epsilon, int64_t step, int64_t bias_correction, double weight_decay, - int64_t grad_averaging, int64_t mode, at::Tensor global_grad_norm, double max_grad_norm, - at::optional use_nvlamb_python) { +void apex_multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, + int64_t bias_correction, double weight_decay, int64_t grad_averaging, int64_t mode, + at::Tensor global_grad_norm, double max_grad_norm, at::optional use_nvlamb_python) { multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), static_cast(beta1), static_cast(beta2), static_cast(epsilon), static_cast(step), static_cast(bias_correction), static_cast(weight_decay), @@ -210,9 +209,9 @@ void apex_multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, void apex_multi_tensor_lamb_mp(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor lr, double beta1, double beta2, double epsilon, at::Tensor step, int64_t bias_correction, - double weight_decay, int64_t grad_averaging, int64_t mode, - at::Tensor global_grad_norm, at::Tensor max_grad_norm, - at::optional use_nvlamb_python, at::Tensor found_inf, at::Tensor inv_scale) { + double weight_decay, int64_t grad_averaging, int64_t mode, at::Tensor global_grad_norm, + at::Tensor max_grad_norm, at::optional use_nvlamb_python, at::Tensor found_inf, + at::Tensor inv_scale) { multi_tensor_lamb_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, static_cast(beta1), static_cast(beta2), static_cast(epsilon), step, static_cast(bias_correction), static_cast(weight_decay), @@ -230,46 +229,62 @@ at::Tensor 
apex_update_scale_hysteresis(at::Tensor current_scale, at::Tensor gro TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("amp_multi_tensor_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale) -> ()"); - m.def("amp_multi_tensor_sgd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float wd, float momentum, " - "float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) -> ()"); - m.def("amp_multi_tensor_axpby(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float a, float b, " - "int arg_to_check) -> ()"); - m.def("amp_multi_tensor_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, bool? per_tensor_python) " - "-> (Tensor, Tensor)"); - m.def("amp_multi_tensor_l2norm_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "bool? per_tensor_python) -> (Tensor, Tensor)"); - m.def("amp_multi_tensor_l2norm_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale, " - "bool? per_tensor_python) -> (Tensor, Tensor)"); - m.def("amp_multi_tensor_unscale_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor inv_scale, " - "bool? per_tensor_python) -> (Tensor, Tensor)"); - m.def("amp_multi_tensor_lamb_stage1_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "Tensor per_tensor_decay, int step, float beta1, float beta2, float epsilon, Tensor global_grad_norm, " - "float max_global_grad_norm) -> ()"); - m.def("amp_multi_tensor_lamb_stage2_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " - "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, float lr, float weight_decay, " - "bool? 
use_nvlamb_python) -> ()"); - m.def("amp_multi_tensor_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " - "float beta2, float epsilon, int step, int mode, int bias_correction, float weight_decay) -> ()"); - m.def("amp_multi_tensor_adam_capturable(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " - "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " - "Tensor inv_scale) -> ()"); - m.def("amp_multi_tensor_adam_capturable_master(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " - "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " - "Tensor inv_scale) -> ()"); - m.def("amp_multi_tensor_adagrad(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float epsilon, " - "int mode, float weight_decay) -> ()"); - m.def("amp_multi_tensor_novograd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor grad_norms, " - "float lr, float beta1, float beta2, float epsilon, int step, int bias_correction, float weight_decay, " - "int grad_averaging, int mode, int norm_type) -> ()"); - m.def("amp_multi_tensor_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " - "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " - "int mode, Tensor global_grad_norm, float max_grad_norm, bool? use_nvlamb_python) -> ()"); - m.def("amp_multi_tensor_lamb_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, float beta1, " - "float beta2, float epsilon, Tensor step, int bias_correction, float weight_decay, int grad_averaging, " - "int mode, Tensor global_grad_norm, Tensor max_grad_norm, bool? 
use_nvlamb_python, Tensor found_inf, " - "Tensor inv_scale) -> ()"); - m.def("amp_update_scale_hysteresis(Tensor current_scale, Tensor growth_tracker, Tensor hysteresis_tracker, " - "Tensor found_inf, float growth_factor, float backoff_factor, int growth_interval, int hysteresis) -> Tensor"); + m.def( + "amp_multi_tensor_sgd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float wd, float momentum, " + "float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) -> ()"); + m.def( + "amp_multi_tensor_axpby(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float a, float b, " + "int arg_to_check) -> ()"); + m.def( + "amp_multi_tensor_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, bool? per_tensor_python) " + "-> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_l2norm_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_l2norm_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_unscale_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor inv_scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_lamb_stage1_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_decay, int step, float beta1, float beta2, float epsilon, Tensor global_grad_norm, " + "float max_global_grad_norm) -> ()"); + m.def( + "amp_multi_tensor_lamb_stage2_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, float lr, float weight_decay, " + "bool? 
use_nvlamb_python) -> ()"); + m.def( + "amp_multi_tensor_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "amp_multi_tensor_adam_capturable(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_multi_tensor_adam_capturable_master(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_multi_tensor_adagrad(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float epsilon, " + "int mode, float weight_decay) -> ()"); + m.def( + "amp_multi_tensor_novograd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor grad_norms, " + "float lr, float beta1, float beta2, float epsilon, int step, int bias_correction, float weight_decay, " + "int grad_averaging, int mode, int norm_type) -> ()"); + m.def( + "amp_multi_tensor_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, float max_grad_norm, bool? use_nvlamb_python) -> ()"); + m.def( + "amp_multi_tensor_lamb_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, float beta1, " + "float beta2, float epsilon, Tensor step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, Tensor max_grad_norm, bool? 
use_nvlamb_python, Tensor found_inf, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_update_scale_hysteresis(Tensor current_scale, Tensor growth_tracker, Tensor hysteresis_tracker, " + "Tensor found_inf, float growth_factor, float backoff_factor, int growth_interval, int hysteresis) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/fused_dense.cpp b/csrc/fused_dense.cpp index eb04868d4..4d98eb7ef 100644 --- a/csrc/fused_dense.cpp +++ b/csrc/fused_dense.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -157,10 +157,12 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("fused_dense_linear_bias_forward(Tensor input, Tensor weight, Tensor bias) -> Tensor"); m.def("fused_dense_linear_bias_backward(Tensor input, Tensor weight, Tensor d_output) -> Tensor[]"); - m.def("fused_dense_linear_gelu_linear_forward(Tensor input, Tensor weight1, Tensor bias1, Tensor weight2, " - "Tensor bias2) -> Tensor[]"); - m.def("fused_dense_linear_gelu_linear_backward(Tensor input, Tensor gelu_in, Tensor output1, Tensor weight1, " - "Tensor weight2, Tensor d_output2) -> Tensor[]"); + m.def( + "fused_dense_linear_gelu_linear_forward(Tensor input, Tensor weight1, Tensor bias1, Tensor weight2, " + "Tensor bias2) -> Tensor[]"); + m.def( + "fused_dense_linear_gelu_linear_backward(Tensor input, Tensor gelu_in, Tensor output1, Tensor weight1, " + "Tensor weight2, Tensor d_output2) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index cbb50ba22..3361f22c5 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -261,13 +261,10 @@ at::Tensor apex_layer_norm_gradient(const at::Tensor& dout, const std::optional< return layer_norm_gradient(dout_, mean, invvar_, input_or_output_, normalized_shape, epsilon, memory_efficient); } -std::vector apex_layer_norm_gradient_affine(const at::Tensor& dout, - const 
std::optional& mean, - const at::Tensor& invvar, - const at::Tensor& input_or_output, +std::vector apex_layer_norm_gradient_affine(const at::Tensor& dout, const std::optional& mean, + const at::Tensor& invvar, const at::Tensor& input_or_output, at::IntArrayRef normalized_shape, const at::Tensor& gamma, - const at::Tensor& beta, double epsilon, - bool memory_efficient) { + const at::Tensor& beta, double epsilon, bool memory_efficient) { at::Tensor dout_ = dout; at::Tensor invvar_ = invvar; at::Tensor input_or_output_ = input_or_output; @@ -277,9 +274,8 @@ std::vector apex_layer_norm_gradient_affine(const at::Tensor& dout, memory_efficient); } -at::Tensor apex_rms_norm_gradient(const at::Tensor& dout, const at::Tensor& invvar, - const at::Tensor& input_or_output, at::IntArrayRef normalized_shape, double epsilon, - bool memory_efficient) { +at::Tensor apex_rms_norm_gradient(const at::Tensor& dout, const at::Tensor& invvar, const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, double epsilon, bool memory_efficient) { at::Tensor dout_ = dout; at::Tensor invvar_ = invvar; at::Tensor input_or_output_ = input_or_output; @@ -300,24 +296,32 @@ std::vector apex_rms_norm_gradient_affine(const at::Tensor& dout, co } // namespace TORCH_LIBRARY_FRAGMENT(apex, m) { - m.def("fused_layer_norm_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " - "float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); m.def("fused_layer_norm_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); - m.def("fused_layer_norm_backward_affine(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " - "int[] normalized_shape, Tensor gamma, Tensor beta, float epsilon, bool memory_efficient) -> Tensor[]"); - m.def("fused_layer_norm_backward(Tensor dout, Tensor? 
mean, Tensor invvar, Tensor input_or_output, " - "int[] normalized_shape, float epsilon, bool memory_efficient) -> Tensor"); - m.def("fused_layer_norm_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " - "float epsilon) -> Tensor[]"); - m.def("fused_layer_norm_rms_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, float epsilon) " - "-> Tensor[]"); + m.def( + "fused_layer_norm_backward_affine(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, Tensor beta, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def( + "fused_layer_norm_backward(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, float epsilon, bool memory_efficient) -> Tensor"); + m.def( + "fused_layer_norm_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, float epsilon) " + "-> Tensor[]"); m.def("fused_layer_norm_rms_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); - m.def("fused_layer_norm_rms_backward_affine(Tensor dout, Tensor invvar, Tensor input_or_output, " - "int[] normalized_shape, Tensor gamma, float epsilon, bool memory_efficient) -> Tensor[]"); - m.def("fused_layer_norm_rms_backward(Tensor dout, Tensor invvar, Tensor input_or_output, int[] normalized_shape, " - "float epsilon, bool memory_efficient) -> Tensor"); - m.def("fused_layer_norm_rms_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, " - "float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_backward_affine(Tensor dout, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_backward(Tensor dout, Tensor invvar, Tensor 
input_or_output, int[] normalized_shape, " + "float epsilon, bool memory_efficient) -> Tensor"); + m.def( + "fused_layer_norm_rms_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, " + "float epsilon) -> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp index fb6762e6d..4efd711c7 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.cpp +++ b/csrc/megatron/fused_rotary_positional_embedding.cpp @@ -134,8 +134,8 @@ at::Tensor bwd_thd(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, return bwd_thd_cuda(output_grads, cu_seqlens, freqs); } -at::Tensor fwd_2d(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, - const at::Tensor& cos_w, const at::Tensor& sin_w) { +at::Tensor fwd_2d(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, const at::Tensor& cos_w, + const at::Tensor& sin_w) { TORCH_CHECK(input.dim() == 5, "expected input to be 5D tensor"); TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); @@ -180,7 +180,8 @@ TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("fused_rope_forward_thd(Tensor input, Tensor cu_seqlens, Tensor freqs) -> Tensor"); m.def("fused_rope_backward_thd(Tensor output_grads, Tensor cu_seqlens, Tensor freqs) -> Tensor"); m.def("fused_rope_forward_2d(Tensor input, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); - m.def("fused_rope_backward_2d(Tensor output_grads, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); + m.def( + "fused_rope_backward_2d(Tensor output_grads, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/megatron/generic_scaled_masked_softmax.cpp b/csrc/megatron/generic_scaled_masked_softmax.cpp index 7d1dc638b..770989801 100644 --- 
a/csrc/megatron/generic_scaled_masked_softmax.cpp +++ b/csrc/megatron/generic_scaled_masked_softmax.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include #include @@ -57,13 +57,12 @@ at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("generic_scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); - m.def("generic_scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " - "float scale_factor) -> Tensor"); + m.def( + "generic_scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { - m.impl("generic_scaled_masked_softmax_forward", - &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd); - m.impl("generic_scaled_masked_softmax_backward", - &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd); + m.impl("generic_scaled_masked_softmax_forward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd); + m.impl("generic_scaled_masked_softmax_backward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd); } diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp index c4d2fcff1..e5bab742f 100644 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ b/csrc/megatron/scaled_masked_softmax.cpp @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include #include +#include #include #include @@ -34,8 +34,9 @@ at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, double scale_fac auto input_arg = input; auto mask_arg = mask; TORCH_CHECK(input_arg.dim() == 4, "expected 4D tensor"); - TORCH_CHECK((input_arg.scalar_type() == at::ScalarType::Half) || (input_arg.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); + TORCH_CHECK( + (input_arg.scalar_type() == at::ScalarType::Half) || (input_arg.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); TORCH_CHECK(mask_arg.dim() == 4, "expected 4D tensor"); if (!input_arg.is_contiguous()) input_arg = input_arg.contiguous(); if (!mask_arg.is_contiguous()) mask_arg = mask_arg.contiguous(); @@ -73,8 +74,9 @@ int64_t get_batch_per_block(int64_t query_seq_len, int64_t key_seq_len, int64_t TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); m.def("scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, float scale_factor) -> Tensor"); - m.def("scaled_masked_softmax_get_batch_per_block(int query_seq_len, int key_seq_len, int batches, " - "int attn_heads) -> int"); + m.def( + "scaled_masked_softmax_get_batch_per_block(int query_seq_len, int key_seq_len, int batches, " + "int attn_heads) -> int"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/megatron/scaled_softmax.cpp b/csrc/megatron/scaled_softmax.cpp index 7bc8a6e17..5a43f8ba1 100644 --- a/csrc/megatron/scaled_softmax.cpp +++ b/csrc/megatron/scaled_softmax.cpp @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#include #include +#include #include #include diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp index eb2a7d9e1..6d907d017 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include #include @@ -56,8 +56,9 @@ at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); - m.def("scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " - "float scale_factor) -> Tensor"); + m.def( + "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp index 358b97173..be88dc993 100644 --- a/csrc/mlp.cpp +++ b/csrc/mlp.cpp @@ -1,6 +1,6 @@ -#include #include #include +#include #include #include @@ -121,8 +121,9 @@ std::vector apex_mlp_backward(int64_t use_bias, int64_t activation, TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("mlp_forward(int use_bias, int activation, Tensor[] inputs) -> Tensor[]"); - m.def("mlp_backward(int use_bias, int activation, Tensor grad_o, Tensor[] fprop_outputs, Tensor[] inputs) " - "-> Tensor[]"); + m.def( + "mlp_backward(int use_bias, int activation, Tensor grad_o, Tensor[] fprop_outputs, Tensor[] inputs) " + "-> Tensor[]"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) { diff --git a/csrc/syncbn.cpp b/csrc/syncbn.cpp index b18db1f14..b5dc892c2 100644 --- a/csrc/syncbn.cpp +++ b/csrc/syncbn.cpp @@ -78,21 +78,27 @@ std::vector apex_welford_parallel_CUDA(const at::Tensor mean_feature TORCH_LIBRARY_FRAGMENT(apex, m) { m.def("syncbn_welford_mean_var(Tensor input) -> Tensor[]"); - m.def("syncbn_welford_parallel(Tensor 
mean_feature_nodes, Tensor var_biased_feature_nodes, Tensor numel, float eps) " - "-> Tensor[]"); + m.def( + "syncbn_welford_parallel(Tensor mean_feature_nodes, Tensor var_biased_feature_nodes, Tensor numel, float eps) " + "-> Tensor[]"); m.def("syncbn_batchnorm_forward(Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, Tensor? shift) -> Tensor"); m.def("syncbn_reduce_bn(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) -> Tensor[]"); - m.def("syncbn_batchnorm_backward(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, " - "Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def( + "syncbn_batchnorm_backward(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); m.def("syncbn_welford_mean_var_c_last(Tensor input) -> Tensor[]"); - m.def("syncbn_batchnorm_forward_c_last(Tensor input, Tensor? z, Tensor mean, Tensor inv_std, Tensor? weight, " - "Tensor? shift, bool fuse_relu) -> Tensor"); - m.def("syncbn_reduce_bn_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) " - "-> Tensor[]"); - m.def("syncbn_batchnorm_backward_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, " - "Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); - m.def("syncbn_relu_bw_c_last(Tensor grad_output, Tensor input, Tensor? z, Tensor mean, Tensor inv_std, " - "Tensor? weight, Tensor? shift) -> Tensor"); + m.def( + "syncbn_batchnorm_forward_c_last(Tensor input, Tensor? z, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor? shift, bool fuse_relu) -> Tensor"); + m.def( + "syncbn_reduce_bn_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) " + "-> Tensor[]"); + m.def( + "syncbn_batchnorm_backward_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, " + "Tensor? 
weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def( + "syncbn_relu_bw_c_last(Tensor grad_output, Tensor input, Tensor? z, Tensor mean, Tensor inv_std, " + "Tensor? weight, Tensor? shift) -> Tensor"); } TORCH_LIBRARY_IMPL(apex, CUDA, m) {