diff --git a/apex/_custom_ops.py b/apex/_custom_ops.py new file mode 100644 index 000000000..4e43f2e3d --- /dev/null +++ b/apex/_custom_ops.py @@ -0,0 +1,44 @@ +from pathlib import Path + +import torch + +_loaded_libraries = set() + + +def load_custom_op_library(extension_name, anchor_file): + base_dir = Path(anchor_file).resolve().parent + search_dirs = [base_dir, base_dir.parent, base_dir.parent.parent] + candidates = sorted( + { + candidate + for directory in search_dirs + for candidate in directory.glob(f"{extension_name}*.so") + }, + key=lambda path: (".cpython-" in path.name, path.name), + ) + if not candidates: + raise ImportError( + f"Could not find shared library for {extension_name!r} next to {anchor_file}" + ) + + library = str(candidates[0]) + if library not in _loaded_libraries: + torch.ops.load_library(library) + _loaded_libraries.add(library) + return library + + +def scalar_float(value): + if isinstance(value, torch.Tensor): + return float(value.item()) + return float(value) + + +def scalar_int(value): + if isinstance(value, torch.Tensor): + return int(value.item()) + return int(value) + + +def tensor_list_arg(value): + return [list(tensors) for tensors in value] diff --git a/apex/_extensions/__init__.py b/apex/_extensions/__init__.py new file mode 100644 index 000000000..0e9ebbece --- /dev/null +++ b/apex/_extensions/__init__.py @@ -0,0 +1 @@ +"""Private Python shims for APEX dispatcher libraries.""" diff --git a/apex/_extensions/amp_C.py b/apex/_extensions/amp_C.py new file mode 100644 index 000000000..25426100c --- /dev/null +++ b/apex/_extensions/amp_C.py @@ -0,0 +1,358 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, tensor_list_arg + + +load_custom_op_library("_amp_C", __file__) +_ops = torch.ops.apex + + +def multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale): + return _ops.amp_multi_tensor_scale( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale) + ) + + +def multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, +): + return _ops.amp_multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(wd), + scalar_float(momentum), + scalar_float(dampening), + scalar_float(lr), + nesterov, + first_run, + wd_after_momentum, + scalar_float(scale), + ) + + +def multi_tensor_axpby(chunk_size, noop_flag, tensor_lists, a, b, arg_to_check): + return _ops.amp_multi_tensor_axpby( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(a), + scalar_float(b), + arg_to_check, + ) + + +def multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python + ) + + +def multi_tensor_l2norm_mp(chunk_size, noop_flag, tensor_lists, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm_mp( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), per_tensor_python + ) + + +def multi_tensor_l2norm_scale(chunk_size, noop_flag, tensor_lists, scale, per_tensor_python=None): + return _ops.amp_multi_tensor_l2norm_scale( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), scalar_float(scale), per_tensor_python + ) + + +def multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor_python=None +): + return _ops.amp_multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_list_arg(tensor_lists), inv_scale, per_tensor_python + ) + + +def multi_tensor_lamb_stage1_cuda( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_decay, + step, + beta1, + beta2, + epsilon, + global_grad_norm, + max_global_grad_norm, +): + return _ops.amp_multi_tensor_lamb_stage1_cuda( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_decay, + step, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + global_grad_norm, + scalar_float(max_global_grad_norm), + ) + + +def multi_tensor_lamb_stage2_cuda( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_param_norm, + per_tensor_update_norm, + lr, + weight_decay, + use_nvlamb_python=None, +): + return _ops.amp_multi_tensor_lamb_stage2_cuda( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_param_norm, + per_tensor_update_norm, + scalar_float(lr), + scalar_float(weight_decay), + use_nvlamb_python, + ) + + +def multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.amp_multi_tensor_adam( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + ) + + +def multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, +): + return _ops.amp_multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + inv_scale, + ) + + +def multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, +): + return _ops.amp_multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + mode, + bias_correction, + scalar_float(weight_decay), + inv_scale, + ) + + +def multi_tensor_adagrad(chunk_size, noop_flag, tensor_lists, lr, epsilon, mode, weight_decay): + return _ops.amp_multi_tensor_adagrad( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(epsilon), + mode, + scalar_float(weight_decay), + ) + + +def multi_tensor_novograd( + chunk_size, + noop_flag, + tensor_lists, + grad_norms, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + norm_type, +): + return _ops.amp_multi_tensor_novograd( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + grad_norms, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + norm_type, + ) + + +def multi_tensor_lamb( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python=None, +): + return _ops.amp_multi_tensor_lamb( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + global_grad_norm, + scalar_float(max_grad_norm), + use_nvlamb_python, + ) + + +def multi_tensor_lamb_mp( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python, + found_inf, + inv_scale, +): + return _ops.amp_multi_tensor_lamb_mp( + chunk_size, + noop_flag, + tensor_list_arg(tensor_lists), + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + step, + bias_correction, + scalar_float(weight_decay), + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, + use_nvlamb_python, + found_inf, + inv_scale, + ) + + +def update_scale_hysteresis( + current_scale, + growth_tracker, + hysteresis_tracker, + found_inf, + growth_factor, + backoff_factor, + growth_interval, + hysteresis, +): + return _ops.amp_update_scale_hysteresis( + current_scale, + growth_tracker, + hysteresis_tracker, + found_inf, + scalar_float(growth_factor), + scalar_float(backoff_factor), + growth_interval, + hysteresis, + ) diff --git a/apex/_extensions/apex_C.py b/apex/_extensions/apex_C.py new file mode 100644 index 000000000..e96db76a1 --- /dev/null +++ b/apex/_extensions/apex_C.py @@ -0,0 +1,18 @@ +import torch + + +def flatten(tensors): + tensors = list(tensors) + if len(tensors) == 0: + return torch.tensor([]) + return torch.cat([tensor.contiguous().view(-1) for tensor in tensors], dim=0) + + +def unflatten(flat, tensors): + outputs = [] + offset = 0 + for tensor in tensors: + numel = tensor.numel() + outputs.append(flat.narrow(0, offset, numel).view_as(tensor)) + offset += numel + return outputs diff --git a/apex/_extensions/bnp.py b/apex/_extensions/bnp.py new file mode 100644 index 000000000..2ec7e01b4 --- /dev/null +++ b/apex/_extensions/bnp.py @@ -0,0 +1,267 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_bnp", __file__) +_ops = torch.ops.apex + + +def _optional_ptr(value): + return None if value is None else scalar_int(value) + + +def get_buffer_size(bn_sync_steps): + return _ops.bnp_get_buffer_size(scalar_int(bn_sync_steps)) + + +def get_data_ptr(data): + return _ops.bnp_get_data_ptr(data) + + +def get_remote_data_ptr(handle, offset): + return _ops.bnp_get_remote_data_ptr(handle, scalar_int(offset)) + + +def close_remote_data(handle): + return _ops.bnp_close_remote_data(handle) + + +def bn_fwd_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + momentum, + epsilon, + fuse_relu, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_fwd_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_fwd_eval_nhwc( + x, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon, fuse_relu +): + return _ops.bnp_bn_fwd_eval_nhwc( + x, + scale, + bias, + running_mean, + running_inv_var, + ret_cta, + scalar_int(bn_group), + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + ) + + +def bn_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + momentum, + epsilon, + fuse_relu, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + bool(fuse_relu), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_fwd_nhwc_occupancy(): + return _ops.bnp_bn_fwd_nhwc_occupancy() + + +def bn_bwd_nhwc_occupancy(): + return _ops.bnp_bn_bwd_nhwc_occupancy() + + +def bn_addrelu_fwd_nhwc( + x, + z, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + momentum, + epsilon, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_addrelu_fwd_nhwc( + x, + z, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_addrelu_fwd_eval_nhwc( + x, z, scale, bias, running_mean, running_inv_var, ret_cta, bn_group, momentum, epsilon +): + return _ops.bnp_bn_addrelu_fwd_eval_nhwc( + x, + z, + scale, + bias, + running_mean, + running_inv_var, + ret_cta, + scalar_int(bn_group), + scalar_float(momentum), + scalar_float(epsilon), + ) + + +def bn_addrelu_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + momentum, + epsilon, + my_data, + pair_data, + pair_data2, + pair_data3, + bn_group, + magic, + occupancy, + grid_dim_x, + coop, +): + return _ops.bnp_bn_addrelu_bwd_nhwc( + x, + dy, + scale, + bias, + running_mean, + running_inv_var, + minibatch_mean, + minibatch_inv_var, + bitmask, + ret_cta, + scalar_float(momentum), + scalar_float(epsilon), + _optional_ptr(my_data), + _optional_ptr(pair_data), + _optional_ptr(pair_data2), + _optional_ptr(pair_data3), + scalar_int(bn_group), + magic, + scalar_int(occupancy), + scalar_int(grid_dim_x), + bool(coop), + ) + + +def bn_addrelu_fwd_nhwc_occupancy(): + return _ops.bnp_bn_addrelu_fwd_nhwc_occupancy() + + +def bn_addrelu_bwd_nhwc_occupancy(): + return _ops.bnp_bn_addrelu_bwd_nhwc_occupancy() diff --git a/apex/_extensions/cudnn_gbn_lib.py b/apex/_extensions/cudnn_gbn_lib.py new file mode 100644 index 000000000..e0a302b4d --- /dev/null +++ b/apex/_extensions/cudnn_gbn_lib.py @@ -0,0 +1,57 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_cudnn_gbn_lib", __file__) +_ops = torch.ops.apex + + +def _int_list(values): + return [scalar_int(value) for value in values] + + +def forward( + x, + scale, + bias, + running_mean, + running_var, + minibatch_mean, + minibatch_inv_var, + momentum, + epsilon, + bn_group, + rank_id, + peer_buffers, +): + return _ops.cudnn_gbn_forward( + x, + scale, + bias, + running_mean, + running_var, + minibatch_mean, + minibatch_inv_var, + scalar_float(momentum), + scalar_float(epsilon), + scalar_int(bn_group), + scalar_int(rank_id), + _int_list(peer_buffers), + ) + + +def backward( + x, dy, scale, minibatch_mean, minibatch_inv_var, epsilon, bn_group, rank_id, peer_buffers +): + return _ops.cudnn_gbn_backward( + x, + dy, + scale, + minibatch_mean, + minibatch_inv_var, + scalar_float(epsilon), + scalar_int(bn_group), + scalar_int(rank_id), + _int_list(peer_buffers), + ) diff --git a/apex/_extensions/distributed_adam_cuda.py b/apex/_extensions/distributed_adam_cuda.py new file mode 100644 index 000000000..26499e541 --- /dev/null +++ b/apex/_extensions/distributed_adam_cuda.py @@ -0,0 +1,97 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_distributed_adam_cuda", __file__) +_ops = torch.ops.apex + + +def multi_tensor_fused_adam( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(weight_decay), + ) + + +def multi_tensor_fused_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam_capturable( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + lr, + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + step, + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(weight_decay), + ) + + +def multi_tensor_fused_adam_with_param_remainders( + chunk_size, + noop_flag, + tensor_lists, + grad_scale, + lr, + beta1, + beta2, + eps, + step, + mode, + bias_correction, + weight_decay, +): + return _ops.distributed_adam_multi_tensor_fused_adam_with_param_remainders( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + grad_scale, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(weight_decay), + ) diff --git a/apex/_extensions/distributed_lamb_cuda.py b/apex/_extensions/distributed_lamb_cuda.py new file mode 100644 index 000000000..1e2948e8b --- /dev/null +++ b/apex/_extensions/distributed_lamb_cuda.py @@ -0,0 +1,67 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_distributed_lamb_cuda", __file__) +_ops = torch.ops.apex + + +def multi_tensor_lamb_compute_update_term( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_beta1, + per_tensor_beta2, + per_tensor_beta3, + per_tensor_bias_correction, + step, + per_tensor_epsilon, + mode, + per_tensor_decay, + global_scale, + global_grad_norm, + max_grad_norm, +): + return _ops.distributed_lamb_compute_update_term( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_beta1, + per_tensor_beta2, + per_tensor_beta3, + per_tensor_bias_correction, + step, + per_tensor_epsilon, + scalar_int(mode), + per_tensor_decay, + global_scale, + global_grad_norm, + scalar_float(max_grad_norm), + ) + + +def multi_tensor_lamb_update_weights( + chunk_size, + noop_flag, + tensor_lists, + per_tensor_param_norm, + per_tensor_update_norm, + update_norm_offset, + learning_rate, + per_tensor_decay, + global_grad_norm, + use_nvlamb, +): + return _ops.distributed_lamb_update_weights( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + per_tensor_param_norm, + per_tensor_update_norm, + update_norm_offset, + learning_rate, + per_tensor_decay, + global_grad_norm, + use_nvlamb, + ) diff --git a/apex/_extensions/fast_bottleneck.py b/apex/_extensions/fast_bottleneck.py new file mode 100644 index 000000000..98cbc5510 --- /dev/null +++ b/apex/_extensions/fast_bottleneck.py @@ -0,0 +1,201 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_fast_bottleneck", __file__) +_ops = torch.ops.apex + + +def _tensor_list(values): + return list(values) + + +def forward(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_forward( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def backward(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_backward( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def forward_init(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_forward_init( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def forward_out1(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_out1( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def forward_out2(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_out2( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def forward_out2_mask(explicit_nhwc, stride_1x1, inputs, outputs, thresholdTop, thresholdBottom): + return _ops.fast_bottleneck_forward_out2_mask( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + thresholdTop, + thresholdBottom, + ) + + +def forward_out2_halo(explicit_nhwc, fat_halo_y1, inputs): + return _ops.fast_bottleneck_forward_out2_halo( + bool(explicit_nhwc), fat_halo_y1, _tensor_list(inputs) + ) + + +def forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, inputs, w1by3, out2_part_halo): + return _ops.fast_bottleneck_forward_out2_halo_corr( + bool(explicit_nhwc), slim_halo_y1, _tensor_list(inputs), w1by3, out2_part_halo + ) + + +def forward_out2_pad(explicit_nhwc, stride_1x1, inputs, outputs, out1_pad): + return _ops.fast_bottleneck_forward_out2_pad( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + out1_pad, + ) + + +def forward_rest(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_forward_rest( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_init(explicit_nhwc, stride_1x1, inputs): + return _ops.fast_bottleneck_backward_init( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs) + ) + + +def backward_grad_out2(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_backward_grad_out2( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_grad_out1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): + return _ops.fast_bottleneck_backward_grad_out1( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + ) + + +def backward_grad_out1_mask( + explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, thresholdTop, thresholdBottom +): + return _ops.fast_bottleneck_backward_grad_out1_mask( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + thresholdTop, + thresholdBottom, + ) + + +def backward_grad_out1_halo(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2_halo, relu1_halo): + return _ops.fast_bottleneck_backward_grad_out1_halo( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2_halo, + relu1_halo, + ) + + +def backward_grad_out1_halo_corr( + explicit_nhwc, stride_1x1, inputs, w1by3, outputs, grad_out2_halo, relu1_halo, part_grad_out1 +): + return _ops.fast_bottleneck_backward_grad_out1_halo_corr( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + w1by3, + _tensor_list(outputs), + grad_out2_halo, + relu1_halo, + part_grad_out1, + ) + + +def backward_wgrad2_pad(explicit_nhwc, stride_1x1, inputs, outputs, input, grad_out2): + return _ops.fast_bottleneck_backward_wgrad2_pad( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + input, + grad_out2, + ) + + +def backward_wgrad2(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2): + return _ops.fast_bottleneck_backward_wgrad2( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + ) + + +def backward_wgrad2_halo(explicit_nhwc, stride_1x1, inputs, outputs, input, grad_out2_halo): + return _ops.fast_bottleneck_backward_wgrad2_halo( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + input, + grad_out2_halo, + ) + + +def backward_wgrad3(explicit_nhwc, stride_1x1, inputs, outputs): + return _ops.fast_bottleneck_backward_wgrad3( + bool(explicit_nhwc), scalar_int(stride_1x1), _tensor_list(inputs), _tensor_list(outputs) + ) + + +def backward_wgrad1(explicit_nhwc, stride_1x1, inputs, outputs, grad_out1): + return _ops.fast_bottleneck_backward_wgrad1( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out1, + ) + + +def backward_rest(explicit_nhwc, stride_1x1, inputs, outputs, grad_out2, grad_out1): + return _ops.fast_bottleneck_backward_rest( + bool(explicit_nhwc), + scalar_int(stride_1x1), + _tensor_list(inputs), + _tensor_list(outputs), + grad_out2, + grad_out1, + ) diff --git a/apex/_extensions/fast_layer_norm.py b/apex/_extensions/fast_layer_norm.py new file mode 100644 index 000000000..00f84c3c2 --- /dev/null +++ b/apex/_extensions/fast_layer_norm.py @@ -0,0 +1,28 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_fast_layer_norm", __file__) +_ops = torch.ops.apex + + +def ln_fwd(x, gamma, beta, epsilon): + return _ops.fast_layer_norm_ln_fwd( + x, + gamma, + beta, + scalar_float(epsilon), + ) + + +def ln_bwd(dz, x_or_z, mu, rsigma, gamma, beta, memory_efficient): + return _ops.fast_layer_norm_ln_bwd( + dz, + x_or_z, + mu, + rsigma, + gamma, + beta, + memory_efficient, + ) diff --git a/apex/_extensions/fast_multihead_attn.py b/apex/_extensions/fast_multihead_attn.py new file mode 100644 index 000000000..21e2aaf34 --- /dev/null +++ b/apex/_extensions/fast_multihead_attn.py @@ -0,0 +1,432 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_fast_multihead_attn", __file__) +_ops = torch.ops.apex + + +def additive_mask_softmax_dropout_forward( + use_mask, is_training, heads, input, pad_mask, dropout_prob +): + return _ops.fast_multihead_attn_additive_mask_softmax_dropout_forward( + bool(use_mask), + bool(is_training), + scalar_int(heads), + input, + pad_mask, + scalar_float(dropout_prob), + ) + + +def additive_mask_softmax_dropout_backward( + use_mask, heads, output_grads, softmax_results, dropout_mask, dropout_prob +): + return _ops.fast_multihead_attn_additive_mask_softmax_dropout_backward( + bool(use_mask), + scalar_int(heads), + output_grads, + softmax_results, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def mask_softmax_dropout_forward(use_mask, is_training, heads, input, pad_mask, dropout_prob): + return _ops.fast_multihead_attn_mask_softmax_dropout_forward( + bool(use_mask), + bool(is_training), + scalar_int(heads), + input, + pad_mask, + scalar_float(dropout_prob), + ) + + +def mask_softmax_dropout_backward( + use_mask, heads, output_grads, softmax_results, dropout_mask, padding_mask, dropout_prob +): + return _ops.fast_multihead_attn_mask_softmax_dropout_backward( + bool(use_mask), + scalar_int(heads), + output_grads, + softmax_results, + dropout_mask, + padding_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + inputs_q, + inputs_kv, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_norm_add_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_norm_add_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def encdec_multihead_attn_norm_add_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_add_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_encdec_multihead_attn_norm_add_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_q_results, + input_lin_kv_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs_q, + inputs_kv, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights_q, + input_weights_kv, + output_weights, + dropout_mask, + dropout_add_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_additive_mask_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_additive_mask_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + input_weights, + output_weights, + input_biases, + output_biases, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_bias_additive_mask_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + bmm1_results, + pad_mask, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_bias_additive_mask_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + bmm1_results, + pad_mask, + input_lin_results, + inputs, + input_weights, + output_weights, + dropout_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_norm_add_forward( + use_mask, + use_time_mask, + is_training, + heads, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + pad_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_norm_add_forward( + bool(use_mask), + bool(use_time_mask), + bool(is_training), + scalar_int(heads), + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + pad_mask, + scalar_float(dropout_prob), + ) + + +def self_attn_norm_add_backward( + heads, + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + dropout_mask, + dropout_add_mask, + dropout_prob, +): + return _ops.fast_multihead_attn_self_attn_norm_add_backward( + scalar_int(heads), + output_grads, + matmul2_results, + dropout_results, + softmax_results, + input_lin_results, + lyr_nrm_results, + lyr_nrm_mean, + lyr_nrm_invvar, + inputs, + lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, + input_weights, + output_weights, + dropout_mask, + dropout_add_mask, + scalar_float(dropout_prob), + ) diff --git a/apex/_extensions/fmhalib.py b/apex/_extensions/fmhalib.py new file mode 100644 index 000000000..c4267445b --- /dev/null +++ b/apex/_extensions/fmhalib.py @@ -0,0 +1,48 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_fmhalib", __file__) +_ops = torch.ops.apex + + +def fwd(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, is_nl, zero_tensors, gen): + return _ops.fmha_fwd( + qkv, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(is_training), + bool(is_nl), + bool(zero_tensors), + gen, + ) + + +def fwd_nl(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, is_nl, zero_tensors, gen): + return fwd(qkv, cu_seqlens, p_dropout, max_seq_len, is_training, True, zero_tensors, gen) + + +def bwd(dout, qkv, softmax, cu_seqlens, p_dropout, max_seq_len, zero_tensors): + return _ops.fmha_bwd( + dout, + qkv, + softmax, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(zero_tensors), + ) + + +def bwd_nl(dout, qkv, softmax, cu_seqlens, p_dropout, max_seq_len, zero_tensors): + return _ops.fmha_bwd_nl( + dout, + qkv, + softmax, + cu_seqlens, + scalar_float(p_dropout), + scalar_int(max_seq_len), + bool(zero_tensors), + ) diff --git a/apex/_extensions/focal_loss_cuda.py b/apex/_extensions/focal_loss_cuda.py new file mode 100644 index 000000000..ae941bdd5 --- /dev/null +++ b/apex/_extensions/focal_loss_cuda.py @@ -0,0 +1,29 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_focal_loss_cuda", __file__) + + +def forward( + cls_output, + cls_targets_at_level, + num_positives_sum, + num_real_classes, + alpha, + gamma, + smoothing_factor, +): + return torch.ops.apex.focal_loss_forward( + cls_output, + cls_targets_at_level, + num_positives_sum, + num_real_classes, + scalar_float(alpha), + scalar_float(gamma), + scalar_float(smoothing_factor), + ) + + +backward = torch.ops.apex.focal_loss_backward diff --git a/apex/_extensions/fused_adam_cuda.py b/apex/_extensions/fused_adam_cuda.py new file mode 100644 index 000000000..36c4d2d45 --- /dev/null +++ b/apex/_extensions/fused_adam_cuda.py @@ -0,0 +1,154 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int, tensor_list_arg + + +load_custom_op_library("_fused_adam_cuda", __file__) + + +def strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first): + return torch.ops.apex.fused_adam_strided_check_finite( + overflow_flag, p_copy, scalar_int(stride), scalar_int(clear_overflow_first) + ) + + +def adam( + p, + p_copy, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_adam( + p, + p_copy, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def reversible_adam( + p, + p_copy, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_reversible_adam( + p, + p_copy, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def adam_mt( + chunk_size, + overflow_flag, + tensor_lists, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_adam_mt( + scalar_int(chunk_size), + overflow_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def maybe_adam_undo( + overflow_flag, + p, + m, + v, + g, + lr, + beta1, + beta2, + eps, + grad_scale, + step, + mode, + bias_correction, + decay, +): + return torch.ops.apex.fused_adam_maybe_adam_undo( + overflow_flag, + p, + m, + v, + g, + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(eps), + scalar_float(grad_scale), + scalar_int(step), + scalar_int(mode), + scalar_int(bias_correction), + scalar_float(decay), + ) + + +def maybe_cast(overflow_flag, p_in, p_out): + return torch.ops.apex.fused_adam_maybe_cast(overflow_flag, p_in, p_out) + + +def maybe_cast_mt(chunk_size, overflow_flag, tensor_lists): + return torch.ops.apex.fused_adam_maybe_cast_mt( + scalar_int(chunk_size), overflow_flag, tensor_list_arg(tensor_lists) + ) diff --git a/apex/_extensions/fused_conv_bias_relu.py b/apex/_extensions/fused_conv_bias_relu.py new file mode 100644 index 000000000..64d58bbf6 --- /dev/null +++ b/apex/_extensions/fused_conv_bias_relu.py @@ -0,0 +1,67 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_fused_conv_bias_relu", __file__) +_ops = torch.ops.apex + + +def _tensor_list(inputs): + return list(inputs) + + +def forward(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_no_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_no_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward_no_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward_no_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_mask(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_mask( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def forward_cscale_cbias_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_forward_cscale_cbias_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) + + +def backward_cscale_cbias_relu(inputs, padding, stride): + return _ops.fused_conv_bias_relu_backward_cscale_cbias_relu( + _tensor_list(inputs), + scalar_int(padding), + scalar_int(stride), + ) diff --git a/apex/_extensions/fused_dense_cuda.py b/apex/_extensions/fused_dense_cuda.py new file mode 100644 index 000000000..c189bf8c9 --- /dev/null +++ b/apex/_extensions/fused_dense_cuda.py @@ -0,0 +1,11 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_dense_cuda", __file__) + +linear_bias_forward = torch.ops.apex.fused_dense_linear_bias_forward +linear_bias_backward = torch.ops.apex.fused_dense_linear_bias_backward +linear_gelu_linear_forward = torch.ops.apex.fused_dense_linear_gelu_linear_forward +linear_gelu_linear_backward = torch.ops.apex.fused_dense_linear_gelu_linear_backward diff --git a/apex/_extensions/fused_index_mul_2d.py b/apex/_extensions/fused_index_mul_2d.py new file mode 100644 index 000000000..6c475b506 --- /dev/null +++ b/apex/_extensions/fused_index_mul_2d.py @@ -0,0 +1,13 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_index_mul_2d", __file__) + +float_forward = torch.ops.apex.index_mul_2d_float_forward +float_backward = torch.ops.apex.index_mul_2d_float_backward +float_backward_backward = torch.ops.apex.index_mul_2d_float_backward_backward +half_forward = torch.ops.apex.index_mul_2d_half_forward +half_backward = torch.ops.apex.index_mul_2d_half_backward +half_backward_backward = torch.ops.apex.index_mul_2d_half_backward_backward diff --git a/apex/_extensions/fused_lamb_cuda.py b/apex/_extensions/fused_lamb_cuda.py new file mode 100644 index 000000000..2f3eeb843 --- /dev/null +++ b/apex/_extensions/fused_lamb_cuda.py @@ -0,0 +1,45 @@ +import torch + +from apex._custom_ops import ( + load_custom_op_library, + scalar_float, + scalar_int, + tensor_list_arg, +) + + +load_custom_op_library("_fused_lamb_cuda", __file__) + + +def lamb( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + bias_correction, + weight_decay, + grad_averaging, + mode, + global_grad_norm, + max_grad_norm, +): + return torch.ops.apex.fused_lamb_lamb( + scalar_int(chunk_size), + noop_flag, + tensor_list_arg(tensor_lists), + scalar_float(lr), + scalar_float(beta1), + scalar_float(beta2), + scalar_float(epsilon), + scalar_int(step), + scalar_int(bias_correction), + scalar_float(weight_decay), + scalar_int(grad_averaging), + scalar_int(mode), + scalar_float(global_grad_norm), + scalar_float(max_grad_norm), + ) diff --git a/apex/_extensions/fused_layer_norm_cuda.py b/apex/_extensions/fused_layer_norm_cuda.py new file mode 100644 index 000000000..1cc17eb23 --- /dev/null +++ b/apex/_extensions/fused_layer_norm_cuda.py @@ -0,0 +1,17 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_layer_norm_cuda", __file__) + +forward_affine = torch.ops.apex.fused_layer_norm_forward_affine +forward = torch.ops.apex.fused_layer_norm_forward +backward_affine = torch.ops.apex.fused_layer_norm_backward_affine +backward = torch.ops.apex.fused_layer_norm_backward +forward_affine_mixed_dtypes = torch.ops.apex.fused_layer_norm_forward_affine_mixed_dtypes +rms_forward_affine = torch.ops.apex.fused_layer_norm_rms_forward_affine +rms_forward = torch.ops.apex.fused_layer_norm_rms_forward +rms_backward_affine = torch.ops.apex.fused_layer_norm_rms_backward_affine +rms_backward = torch.ops.apex.fused_layer_norm_rms_backward +rms_forward_affine_mixed_dtypes = torch.ops.apex.fused_layer_norm_rms_forward_affine_mixed_dtypes diff --git a/apex/_extensions/fused_rotary_positional_embedding.py b/apex/_extensions/fused_rotary_positional_embedding.py new file mode 100644 index 000000000..ca27b6d3e --- /dev/null +++ b/apex/_extensions/fused_rotary_positional_embedding.py @@ -0,0 +1,15 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_rotary_positional_embedding", __file__) + +forward = torch.ops.apex.fused_rope_forward +backward = torch.ops.apex.fused_rope_backward +forward_cached = torch.ops.apex.fused_rope_forward_cached +backward_cached = torch.ops.apex.fused_rope_backward_cached +forward_thd = torch.ops.apex.fused_rope_forward_thd +backward_thd = torch.ops.apex.fused_rope_backward_thd +forward_2d = torch.ops.apex.fused_rope_forward_2d +backward_2d = torch.ops.apex.fused_rope_backward_2d diff --git a/apex/_extensions/fused_weight_gradient_mlp_cuda.py b/apex/_extensions/fused_weight_gradient_mlp_cuda.py new file mode 100644 index 000000000..fc9c78c43 --- /dev/null +++ b/apex/_extensions/fused_weight_gradient_mlp_cuda.py @@ -0,0 +1,9 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_fused_weight_gradient_mlp_cuda", __file__) + +wgrad_gemm_accum_fp32 = torch.ops.apex.fused_weight_gradient_mlp_wgrad_gemm_accum_fp32 +wgrad_gemm_accum_fp16 = torch.ops.apex.fused_weight_gradient_mlp_wgrad_gemm_accum_fp16 diff --git a/apex/_extensions/generic_scaled_masked_softmax_cuda.py b/apex/_extensions/generic_scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..770bd6894 --- /dev/null +++ b/apex/_extensions/generic_scaled_masked_softmax_cuda.py @@ -0,0 +1,18 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_generic_scaled_masked_softmax_cuda", __file__) + + +def forward(input, mask, scale_factor): + return torch.ops.apex.generic_scaled_masked_softmax_forward( + input, mask, scalar_float(scale_factor) + ) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.generic_scaled_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/group_norm_cuda.py b/apex/_extensions/group_norm_cuda.py new file mode 100644 index 000000000..8e97f1745 --- /dev/null +++ b/apex/_extensions/group_norm_cuda.py @@ -0,0 +1,32 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_group_norm_cuda", __file__) + + +def forward(input, groups, weight, bias, eps, passes, with_swish=False): + return torch.ops.apex.group_norm_forward( + input, + scalar_int(groups), + weight, + bias, + scalar_float(eps), + scalar_int(passes), + bool(with_swish), + ) + + +def backward(grad_output, sums, input, groups, weight, bias, eps, passes, with_swish=False): + return torch.ops.apex.group_norm_backward( + grad_output, + sums, + input, + scalar_int(groups), + weight, + bias, + scalar_float(eps), + scalar_int(passes), + bool(with_swish), + ) diff --git a/apex/_extensions/group_norm_v2_cuda.py b/apex/_extensions/group_norm_v2_cuda.py new file mode 100644 index 000000000..f6744ff22 --- /dev/null +++ b/apex/_extensions/group_norm_v2_cuda.py @@ -0,0 +1,33 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_group_norm_v2_cuda", __file__) + + +def gn(x, w, b, eps, silu, num_groups, mean_var_out=None, sm_margin=0): + return torch.ops.apex.group_norm_v2_gn( + x, + w, + b, + scalar_float(eps), + bool(silu), + scalar_int(num_groups), + mean_var_out, + scalar_int(sm_margin), + ) + + +def gn_bwd(grad_output, x, w, b, mean_var, eps, silu, num_groups, sm_margin=0): + return torch.ops.apex.group_norm_v2_gn_bwd( + grad_output, + x, + w, + b, + mean_var, + scalar_float(eps), + bool(silu), + scalar_int(num_groups), + scalar_int(sm_margin), + ) diff --git a/apex/_extensions/mlp_cuda.py b/apex/_extensions/mlp_cuda.py new file mode 100644 index 000000000..23e5b8d4c --- /dev/null +++ b/apex/_extensions/mlp_cuda.py @@ -0,0 +1,16 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_mlp_cuda", __file__) + + +def forward(use_bias, activation, inputs): + return torch.ops.apex.mlp_forward(use_bias, activation, list(inputs)) + + +def backward(use_bias, activation, grad_o, fprop_outputs, inputs): + return torch.ops.apex.mlp_backward( + use_bias, activation, grad_o, list(fprop_outputs), list(inputs) + ) diff --git a/apex/_extensions/nccl_p2p_cuda.py b/apex/_extensions/nccl_p2p_cuda.py new file mode 100644 index 000000000..c0b44c6e7 --- /dev/null +++ b/apex/_extensions/nccl_p2p_cuda.py @@ -0,0 +1,53 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_nccl_p2p_cuda", __file__) +_ops = torch.ops.apex + + +def get_unique_nccl_id(n): + return _ops.nccl_p2p_get_unique_nccl_id(scalar_int(n)) + + +def init_nccl_comm(unique_nccl_id, my_rank, num_ranks): + return _ops.nccl_p2p_init_nccl_comm( + unique_nccl_id, + scalar_int(my_rank), + scalar_int(num_ranks), + ) + + +def left_right_halo_exchange_inplace( + handle, + left_rank, + right_rank, + left_output_halo, + right_output_halo, + left_input_halo, + right_input_halo, +): + return _ops.nccl_p2p_left_right_halo_exchange_inplace( + scalar_int(handle), + scalar_int(left_rank), + scalar_int(right_rank), + left_output_halo, + right_output_halo, + left_input_halo, + right_input_halo, + ) + + +def left_right_halo_exchange(handle, left_rank, right_rank, left_output_halo, right_output_halo): + return _ops.nccl_p2p_left_right_halo_exchange( + scalar_int(handle), + scalar_int(left_rank), + scalar_int(right_rank), + left_output_halo, + right_output_halo, + ) + + +def add_delay(delay): + return _ops.nccl_p2p_add_delay(scalar_int(delay)) diff --git a/apex/_extensions/peer_memory_cuda.py b/apex/_extensions/peer_memory_cuda.py new file mode 100644 index 000000000..c99127905 --- /dev/null +++ b/apex/_extensions/peer_memory_cuda.py @@ -0,0 +1,93 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_peer_memory_cuda", __file__) +_ops = torch.ops.apex + + +def _shape_arg(shape): + return [scalar_int(dim) for dim in shape] + + +def allocate_raw(size): + return _ops.peer_memory_allocate_raw(scalar_int(size)) + + +def free_raw(raw): + return _ops.peer_memory_free_raw(scalar_int(raw)) + + +def zero(raw, size): + return _ops.peer_memory_zero(scalar_int(raw), scalar_int(size)) + + +def get_raw_ipc_address(raw): + return _ops.peer_memory_get_raw_ipc_address(scalar_int(raw)) + + +def get_raw_peers(ipc_addresses, peer_rank, raw): + return _ops.peer_memory_get_raw_peers( + ipc_addresses, + scalar_int(peer_rank), + scalar_int(raw), + ) + + +def blob_view_half(raw, shape, channels_last): + return _ops.peer_memory_blob_view_half( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def blob_view_float(raw, shape, channels_last): + return _ops.peer_memory_blob_view_float( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def blob_view_int(raw, shape, channels_last): + return _ops.peer_memory_blob_view_int( + scalar_int(raw), + _shape_arg(shape), + bool(channels_last), + ) + + +def push_pull_halos_1d( + diagnostics, + explicit_nhwc, + numSM, + rank, + top_zero, + top_in_halo, + top_in_transfer, + top_out_transfer, + top_out_halo, + btm_zero, + btm_in_halo, + btm_in_transfer, + btm_out_transfer, + btm_out_halo, +): + return _ops.peer_memory_push_pull_halos_1d( + bool(diagnostics), + bool(explicit_nhwc), + scalar_int(numSM), + scalar_int(rank), + bool(top_zero), + top_in_halo, + top_in_transfer, + top_out_transfer, + top_out_halo, + bool(btm_zero), + btm_in_halo, + btm_in_transfer, + btm_out_transfer, + btm_out_halo, + ) diff --git a/apex/_extensions/permutation_search_cuda.py b/apex/_extensions/permutation_search_cuda.py new file mode 100644 index 000000000..c5caf0a4c --- /dev/null +++ b/apex/_extensions/permutation_search_cuda.py @@ -0,0 +1,91 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_permutation_search_cuda", __file__) + + +def _as_tensor(value): + if isinstance(value, torch.Tensor): + return value + return torch.from_numpy(value) + + +def _op(op_name): + return getattr(torch.ops.apex, op_name) + + +def sum_after_2_to_4(matrix, rows, cols, start_col, end_col, blocks, threads, output): + return _op("permutation_search_sum_after_2_to_4")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + scalar_int(start_col), + scalar_int(end_col), + scalar_int(blocks), + scalar_int(threads), + _as_tensor(output), + ) + + +def build_permute_map( + matrix, + rows, + cols, + stripes, + num_groups, + group_width, + permutations, + perm_length, + improvements, + best_indices, +): + return _op("permutation_search_build_permute_map")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripes), + scalar_int(num_groups), + scalar_int(group_width), + _as_tensor(permutations), + scalar_int(perm_length), + _as_tensor(improvements), + _as_tensor(best_indices), + ) + + +def check_permutations( + matrix, + rows, + cols, + stripe_groups, + group_width, + num_groups, + permutations, + num_permutations, + improvement, + permutation, +): + return _op("permutation_search_check_permutations")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripe_groups), + scalar_int(group_width), + scalar_int(num_groups), + _as_tensor(permutations), + scalar_int(num_permutations), + _as_tensor(improvement), + _as_tensor(permutation), + ) + + +def build_swap_map(matrix, rows, cols, stripe_pairs, output): + return _op("permutation_search_build_swap_map")( + _as_tensor(matrix), + scalar_int(rows), + scalar_int(cols), + _as_tensor(stripe_pairs), + _as_tensor(output), + ) diff --git a/apex/_extensions/scaled_masked_softmax_cuda.py b/apex/_extensions/scaled_masked_softmax_cuda.py new file mode 100644 index 000000000..3c6e17e84 --- /dev/null +++ b/apex/_extensions/scaled_masked_softmax_cuda.py @@ -0,0 +1,25 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_scaled_masked_softmax_cuda", __file__) + + +def forward(input, mask, scale_factor): + return torch.ops.apex.scaled_masked_softmax_forward(input, mask, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) + + +def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads): + return torch.ops.apex.scaled_masked_softmax_get_batch_per_block( + scalar_int(query_seq_len), + scalar_int(key_seq_len), + scalar_int(batches), + scalar_int(attn_heads), + ) diff --git a/apex/_extensions/scaled_softmax_cuda.py b/apex/_extensions/scaled_softmax_cuda.py new file mode 100644 index 000000000..5b9ff0b9d --- /dev/null +++ b/apex/_extensions/scaled_softmax_cuda.py @@ -0,0 +1,16 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_scaled_softmax_cuda", __file__) + + +def forward(input, scale_factor): + return torch.ops.apex.scaled_softmax_forward(input, scalar_float(scale_factor)) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py new file mode 100644 index 000000000..8feebacba --- /dev/null +++ b/apex/_extensions/scaled_upper_triang_masked_softmax_cuda.py @@ -0,0 +1,18 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_scaled_upper_triang_masked_softmax_cuda", __file__) + + +def forward(input, scale_factor): + return torch.ops.apex.scaled_upper_triang_masked_softmax_forward( + input, scalar_float(scale_factor) + ) + + +def backward(output_grads, softmax_results, scale_factor): + return torch.ops.apex.scaled_upper_triang_masked_softmax_backward( + output_grads, softmax_results, scalar_float(scale_factor) + ) diff --git a/apex/_extensions/syncbn.py b/apex/_extensions/syncbn.py new file mode 100644 index 000000000..6dc9d8bf2 --- /dev/null +++ b/apex/_extensions/syncbn.py @@ -0,0 +1,17 @@ +import torch + +from apex._custom_ops import load_custom_op_library + + +load_custom_op_library("_syncbn", __file__) + +welford_mean_var = torch.ops.apex.syncbn_welford_mean_var +welford_parallel = torch.ops.apex.syncbn_welford_parallel +batchnorm_forward = torch.ops.apex.syncbn_batchnorm_forward +reduce_bn = torch.ops.apex.syncbn_reduce_bn +batchnorm_backward = torch.ops.apex.syncbn_batchnorm_backward +welford_mean_var_c_last = torch.ops.apex.syncbn_welford_mean_var_c_last +batchnorm_forward_c_last = torch.ops.apex.syncbn_batchnorm_forward_c_last +reduce_bn_c_last = torch.ops.apex.syncbn_reduce_bn_c_last +batchnorm_backward_c_last = torch.ops.apex.syncbn_batchnorm_backward_c_last +relu_bw_c_last = torch.ops.apex.syncbn_relu_bw_c_last diff --git a/apex/_extensions/transducer_joint_cuda.py b/apex/_extensions/transducer_joint_cuda.py new file mode 100644 index 000000000..3533e2a5b --- /dev/null +++ b/apex/_extensions/transducer_joint_cuda.py @@ -0,0 +1,50 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float, scalar_int + + +load_custom_op_library("_transducer_joint_cuda", __file__) +_ops = torch.ops.apex + + +def forward( + f, + g, + f_len, + g_len, + batch_offset, + packed_batch, + opt, + pack_output, + relu, + dropout, + dropout_prob, + tile_size, +): + return _ops.transducer_joint_forward( + f, + g, + f_len, + g_len, + batch_offset, + scalar_int(packed_batch), + scalar_int(opt), + pack_output, + relu, + dropout, + scalar_float(dropout_prob), + scalar_int(tile_size), + ) + + +def backward(input, f_len, g_len, batch_offset, max_f_len, max_g_len, pack_output, scale): + return _ops.transducer_joint_backward( + list(input), + f_len, + g_len, + batch_offset, + scalar_int(max_f_len), + scalar_int(max_g_len), + pack_output, + scalar_float(scale), + ) diff --git a/apex/_extensions/transducer_loss_cuda.py b/apex/_extensions/transducer_loss_cuda.py new file mode 100644 index 000000000..72e6fe09c --- /dev/null +++ b/apex/_extensions/transducer_loss_cuda.py @@ -0,0 +1,53 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_int + + +load_custom_op_library("_transducer_loss_cuda", __file__) +_ops = torch.ops.apex + + +def forward(x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, opt, packed_input): + return _ops.transducer_loss_forward( + x, + label, + f_len, + y_len, + batch_offset, + scalar_int(max_f_len), + scalar_int(blank_idx), + scalar_int(opt), + packed_input, + ) + + +def backward( + x, + loss_grad, + alpha, + beta, + f_len, + y_len, + label, + batch_offset, + max_f_len, + blank_idx, + opt, + fuse_softmax_backward, + packed_input, +): + return _ops.transducer_loss_backward( + x, + loss_grad, + alpha, + beta, + f_len, + y_len, + label, + batch_offset, + scalar_int(max_f_len), + scalar_int(blank_idx), + scalar_int(opt), + fuse_softmax_backward, + packed_input, + ) diff --git a/apex/_extensions/xentropy_cuda.py b/apex/_extensions/xentropy_cuda.py new file mode 100644 index 000000000..0ae58d8d8 --- /dev/null +++ b/apex/_extensions/xentropy_cuda.py @@ -0,0 +1,18 @@ +import torch + +from apex._custom_ops import load_custom_op_library, scalar_float + + +load_custom_op_library("_xentropy_cuda", __file__) + +__version__ = torch.ops.apex.xentropy_version() + + +def forward(input, labels, smoothing, half_to_float): + return torch.ops.apex.xentropy_forward(input, labels, scalar_float(smoothing), half_to_float) + + +def backward(grad_loss, logits, max_log_sum_exp, labels, smoothing): + return torch.ops.apex.xentropy_backward( + grad_loss, logits, max_log_sum_exp, labels, scalar_float(smoothing) + ) diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py index e24a9251b..a15557d41 100644 --- a/apex/contrib/bottleneck/bottleneck.py +++ b/apex/contrib/bottleneck/bottleneck.py @@ -4,8 +4,8 @@ from torch import nn from apex import check_cudnn_version_and_warn -import fast_bottleneck -import nccl_p2p_cuda as inc +from apex._extensions import fast_bottleneck +from apex._extensions import nccl_p2p_cuda as inc assert check_cudnn_version_and_warn(__name__, 8400) diff --git a/apex/contrib/bottleneck/halo_exchangers.py b/apex/contrib/bottleneck/halo_exchangers.py index eb0224c5b..8d261bdc4 100644 --- a/apex/contrib/bottleneck/halo_exchangers.py +++ b/apex/contrib/bottleneck/halo_exchangers.py @@ -1,6 +1,6 @@ import torch -import nccl_p2p_cuda as inc -import peer_memory_cuda as pm +from apex._extensions import nccl_p2p_cuda as inc +from apex._extensions import peer_memory_cuda as pm # Communication free halo exchanger. diff --git a/apex/contrib/clip_grad/clip_grad.py b/apex/contrib/clip_grad/clip_grad.py index b34dc43bd..2290a8703 100644 --- a/apex/contrib/clip_grad/clip_grad.py +++ b/apex/contrib/clip_grad/clip_grad.py @@ -4,7 +4,7 @@ _kernel_import_succeeded = False try: - import amp_C + from apex._extensions import amp_C from apex.multi_tensor_apply import multi_tensor_applier _kernel_import_succeeded = True @@ -74,6 +74,17 @@ def clip_grad_norm_( else: grads_misc.append(grad) + if grads_fp16: + # Preserve PyTorch's current fp16 clipping semantics exactly. The fused + # norm path accumulates and returns fp32, which can perturb the fp16 + # scale enough to change rounded gradient values. + return torch.nn.utils.clip_grad_norm_( + parameters, + max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, + ) + # Compute gradient L2 norms norms = [] dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device) diff --git a/apex/contrib/conv_bias_relu/conv_bias_relu.py b/apex/contrib/conv_bias_relu/conv_bias_relu.py index 533d6421f..faabdbff3 100644 --- a/apex/contrib/conv_bias_relu/conv_bias_relu.py +++ b/apex/contrib/conv_bias_relu/conv_bias_relu.py @@ -1,7 +1,7 @@ import torch from apex import check_cudnn_version_and_warn -import fused_conv_bias_relu +from apex._extensions import fused_conv_bias_relu check_cudnn_version_and_warn(__name__, 8400) diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp index 3c98a39dd..e33201cb9 100644 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ b/apex/contrib/csrc/bottleneck/bottleneck.cpp @@ -1,8 +1,7 @@ #include #include // for getcudnnhandle #include -#include -#include +#include #include #include @@ -438,7 +437,7 @@ void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -583,7 +582,7 @@ void run_conv_scale_bias_add_activation(int64_t* x_dim_padded, int64_t* pad, int void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -696,7 +695,7 @@ void run_conv_scale_bias(int64_t* x_dim_padded, int64_t* pad, int64_t* convstrid void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -810,7 +809,7 @@ void run_dconv_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* convst void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -905,7 +904,7 @@ void run_dconv(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t void run_dconv_add(int64_t* x_dim_padded, int64_t* pad, int64_t* convstride, int64_t* dilation, int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1821,7 +1820,7 @@ void run_conv_add_scale_bias_activation(int64_t* x_dim_padded, int64_t* pad, int int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1966,7 +1965,7 @@ void run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, int64_t* pad int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrB, at::Half* devPtrI, int* devPtrT, int* devPtrU, int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -2236,7 +2235,7 @@ void run_dconv_add_drelu_dscale(int64_t* x_dim_padded, int64_t* pad, int64_t* co int64_t* w_dim_padded, int64_t* y_dim_padded, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -2367,7 +2366,7 @@ void run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, int64_t* pad, int64_t* c int64_t* w_dim_padded, int64_t* y_dim_padded, int64_t* threshold_dim, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, at::Half* devPtrZ, at::Half* devPtrR, int* devPtrT, int* devPtrU, int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -3554,43 +3553,211 @@ void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector()); - m.def("backward", &bottleneck_backward, "Bottleneck block backward", py::call_guard()); - m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init", py::call_guard()); - m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward", py::call_guard()); - m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward", - py::call_guard()); - m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward", - py::call_guard()); - m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward", py::call_guard()); - m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init", - py::call_guard()); - m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward", - py::call_guard()); - m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward", - py::call_guard()); - m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward", - py::call_guard()); - m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward", - py::call_guard()); +namespace { +int as_int(int64_t value) { return static_cast(value); } + +std::vector apex_fast_bottleneck_forward(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_forward(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +std::vector apex_fast_bottleneck_backward(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_backward(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +std::vector apex_fast_bottleneck_forward_init(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_forward_init(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +void apex_fast_bottleneck_forward_out1(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_out1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_forward_out2(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor thresholdTop, + at::Tensor thresholdBottom) { + bottleneck_forward_out2_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), thresholdTop, + thresholdBottom); +} + +at::Tensor apex_fast_bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, + std::vector inputs) { + return bottleneck_forward_out2_halo(explicit_nhwc, fat_halo_y1, std::move(inputs)); +} + +at::Tensor apex_fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, + std::vector inputs, at::Tensor w1by3, + at::Tensor out2_part_halo) { + return bottleneck_forward_out2_halo_corr(explicit_nhwc, slim_halo_y1, std::move(inputs), w1by3, out2_part_halo); +} + +void apex_fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor out1_pad) { + bottleneck_forward_out2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), out1_pad); +} + +void apex_fast_bottleneck_forward_rest(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_forward_rest(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +std::vector apex_fast_bottleneck_backward_init(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs) { + return bottleneck_backward_init(explicit_nhwc, as_int(stride_1X1), std::move(inputs)); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs) { + return bottleneck_backward_grad_out2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor grad_out2) { + return bottleneck_backward_grad_out1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor grad_out2, at::Tensor thresholdTop, + at::Tensor thresholdBottom) { + return bottleneck_backward_grad_out1_mask(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2, thresholdTop, thresholdBottom); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, at::Tensor w1by3, + std::vector outputs, at::Tensor grad_out2_halo, + at::Tensor relu1_halo, at::Tensor part_grad_out1) { + return bottleneck_backward_grad_out1_halo_corr(explicit_nhwc, as_int(stride_1X1), std::move(inputs), w1by3, + std::move(outputs), grad_out2_halo, relu1_halo, part_grad_out1); +} + +at::Tensor apex_fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor grad_out2_halo, at::Tensor relu1_halo) { + return bottleneck_backward_grad_out1_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + grad_out2_halo, relu1_halo); +} + +void apex_fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor input, at::Tensor grad_out2) { + bottleneck_backward_wgrad2_pad(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), input, + grad_out2); +} + +void apex_fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2) { + bottleneck_backward_wgrad2(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2); +} + +at::Tensor apex_fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int64_t stride_1X1, + std::vector inputs, std::vector outputs, + at::Tensor input, at::Tensor grad_out2_halo) { + return bottleneck_backward_wgrad2_halo(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), + input, grad_out2_halo); +} + +void apex_fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs) { + bottleneck_backward_wgrad3(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs)); +} + +void apex_fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out1) { + bottleneck_backward_wgrad1(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out1); +} + +void apex_fast_bottleneck_backward_rest(bool explicit_nhwc, int64_t stride_1X1, std::vector inputs, + std::vector outputs, at::Tensor grad_out2, at::Tensor grad_out1) { + bottleneck_backward_rest(explicit_nhwc, as_int(stride_1X1), std::move(inputs), std::move(outputs), grad_out2, + grad_out1); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fast_bottleneck_forward(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_backward(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def("fast_bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def("fast_bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def( + "fast_bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor thresholdTop, Tensor thresholdBottom) -> ()"); + m.def("fast_bottleneck_forward_out2_halo(bool explicit_nhwc, Tensor fat_halo_y1, Tensor[] inputs) -> Tensor"); + m.def( + "fast_bottleneck_forward_out2_halo_corr(bool explicit_nhwc, Tensor slim_halo_y1, Tensor[] inputs, " + "Tensor w1by3, Tensor out2_part_halo) -> Tensor"); + m.def( + "fast_bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor out1_pad) -> ()"); + m.def("fast_bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) -> ()"); + m.def("fast_bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, Tensor[] inputs) -> Tensor[]"); + m.def( + "fast_bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2, Tensor thresholdTop, Tensor thresholdBottom) -> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo) -> Tensor"); + m.def( + "fast_bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, " + "Tensor w1by3, Tensor[] outputs, Tensor grad_out2_halo, Tensor relu1_halo, Tensor part_grad_out1) -> Tensor"); + m.def( + "fast_bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2) -> ()"); + m.def( + "fast_bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2) -> ()"); + m.def( + "fast_bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor input, Tensor grad_out2_halo) -> Tensor"); + m.def( + "fast_bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs) " + "-> ()"); + m.def( + "fast_bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out1) -> ()"); + m.def( + "fast_bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, Tensor[] inputs, Tensor[] outputs, " + "Tensor grad_out2, Tensor grad_out1) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_bottleneck_forward", &apex_fast_bottleneck_forward); + m.impl("fast_bottleneck_backward", &apex_fast_bottleneck_backward); + m.impl("fast_bottleneck_forward_init", &apex_fast_bottleneck_forward_init); + m.impl("fast_bottleneck_forward_out1", &apex_fast_bottleneck_forward_out1); + m.impl("fast_bottleneck_forward_out2", &apex_fast_bottleneck_forward_out2); + m.impl("fast_bottleneck_forward_out2_mask", &apex_fast_bottleneck_forward_out2_mask); + m.impl("fast_bottleneck_forward_out2_halo", &apex_fast_bottleneck_forward_out2_halo); + m.impl("fast_bottleneck_forward_out2_halo_corr", &apex_fast_bottleneck_forward_out2_halo_corr); + m.impl("fast_bottleneck_forward_out2_pad", &apex_fast_bottleneck_forward_out2_pad); + m.impl("fast_bottleneck_forward_rest", &apex_fast_bottleneck_forward_rest); + m.impl("fast_bottleneck_backward_init", &apex_fast_bottleneck_backward_init); + m.impl("fast_bottleneck_backward_grad_out2", &apex_fast_bottleneck_backward_grad_out2); + m.impl("fast_bottleneck_backward_grad_out1", &apex_fast_bottleneck_backward_grad_out1); + m.impl("fast_bottleneck_backward_grad_out1_mask", &apex_fast_bottleneck_backward_grad_out1_mask); + m.impl("fast_bottleneck_backward_grad_out1_halo", &apex_fast_bottleneck_backward_grad_out1_halo); + m.impl("fast_bottleneck_backward_grad_out1_halo_corr", &apex_fast_bottleneck_backward_grad_out1_halo_corr); + m.impl("fast_bottleneck_backward_wgrad2_pad", &apex_fast_bottleneck_backward_wgrad2_pad); + m.impl("fast_bottleneck_backward_wgrad2", &apex_fast_bottleneck_backward_wgrad2); + m.impl("fast_bottleneck_backward_wgrad2_halo", &apex_fast_bottleneck_backward_wgrad2_halo); + m.impl("fast_bottleneck_backward_wgrad3", &apex_fast_bottleneck_backward_wgrad3); + m.impl("fast_bottleneck_backward_wgrad1", &apex_fast_bottleneck_backward_wgrad1); + m.impl("fast_bottleneck_backward_rest", &apex_fast_bottleneck_backward_rest); } diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp index 9f6924a65..90347d18d 100644 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +++ b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp @@ -1,8 +1,7 @@ #include #include // for getcudnnhandle #include -#include -#include +#include #include #include @@ -193,7 +192,7 @@ cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, std::strin void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -331,7 +330,7 @@ void run_conv_bias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, int8_t* devPtrM, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -530,7 +529,7 @@ void run_conv_bias_mask_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrS, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -732,7 +731,7 @@ void run_conv_cscale_cbias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrB, at::Half* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -896,7 +895,7 @@ void run_conv_bias_relu(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrS, at::Half* devPtrDX) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1032,7 +1031,7 @@ void run_drelu_dscale(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPt void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtrDY, at::Half* devPtrR, at::Half* devPtrDR, float* devPtrDB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1159,7 +1158,7 @@ void run_drelu_dbias(int64_t* dy_dim, cudnnDataType_t dataType, at::Half* devPtr void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* pad, int64_t* convstride, int64_t* dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrR, at::Half* devPtrRg, float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1323,7 +1322,7 @@ void run_dconv_drelu_dbias(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64 void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad, int64_t* conv_stride, int64_t* conv_dilation, cudnnDataType_t dataType, at::Half* devPtrX, at::Half* devPtrW, at::Half* devPtrY, cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { @@ -1427,7 +1426,7 @@ void run_dconv(int64_t* x_dim, int64_t* w_dim, int64_t* y_dim, int64_t* conv_pad } void run_dbias(int64_t* x_dim, cudnnDataType_t dataType, at::Half* devPtrX, float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); std::stringstream log_buf; try { int convDim = 2; @@ -1898,16 +1897,22 @@ std::vector conv_bias_backward(std::vector inputs, int64 return outputs; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward", py::call_guard()); - m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward", - py::call_guard()); - m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward", py::call_guard()); - m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward", py::call_guard()); - m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward", - py::call_guard()); - m.def("forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward, "Fused Conv-(const)Scale-(const)Bias-ReLU", - py::call_guard()); - m.def("backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward, - "Fused Conv-(const)Scale-(const)Bias-ReLU backward", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_conv_bias_relu_forward(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_backward(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_no_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_backward_no_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_mask(Tensor[] inputs, int padding, int stride) -> Tensor[]"); + m.def("fused_conv_bias_relu_forward_cscale_cbias_relu(Tensor[] inputs, int padding, int stride) -> Tensor"); + m.def("fused_conv_bias_relu_backward_cscale_cbias_relu(Tensor[] inputs, int padding, int stride) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_conv_bias_relu_forward", &conv_bias_relu_forward); + m.impl("fused_conv_bias_relu_backward", &conv_bias_relu_backward); + m.impl("fused_conv_bias_relu_forward_no_relu", &conv_bias_forward); + m.impl("fused_conv_bias_relu_backward_no_relu", &conv_bias_backward); + m.impl("fused_conv_bias_relu_forward_mask", &conv_bias_mask_relu_forward); + m.impl("fused_conv_bias_relu_forward_cscale_cbias_relu", &conv_cscale_cbias_relu_forward); + m.impl("fused_conv_bias_relu_backward_cscale_cbias_relu", &conv_cscale_cbias_relu_backward); } diff --git a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp index 858d429d0..1c8c2d273 100644 --- a/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp +++ b/apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp @@ -1,8 +1,8 @@ #include -#include -#include +#include #include +#include #include #include "norm_sample.h" @@ -117,7 +117,41 @@ std::vector gbn_backward(const at::Tensor& x, const at::Tensor& dy, return std::vector{x_grad, scale_grad, bias_grad}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &gbn_forward, "Group batch norm forward", py::call_guard()); - m.def("backward", &gbn_backward, "Group batch backward", py::call_guard()); +namespace { +std::vector to_int_vector(at::IntArrayRef values) { + return std::vector(values.begin(), values.end()); +} + +at::Tensor apex_cudnn_gbn_forward(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + double momentum, double epsilon, int64_t bn_group, int64_t rank_id, + at::IntArrayRef peer_buffers) { + return gbn_forward(x, scale, bias, running_mean, running_var, minibatch_mean, minibatch_inv_var, + static_cast(momentum), static_cast(epsilon), bn_group, static_cast(rank_id), + to_int_vector(peer_buffers)); +} + +std::vector apex_cudnn_gbn_backward(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + double epsilon, int64_t bn_group, int64_t rank_id, + at::IntArrayRef peer_buffers) { + return gbn_backward(x, dy, scale, minibatch_mean, minibatch_inv_var, static_cast(epsilon), bn_group, + static_cast(rank_id), to_int_vector(peer_buffers)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "cudnn_gbn_forward(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, float momentum, float epsilon, int bn_group, int rank_id, " + "int[] peer_buffers) -> Tensor"); + m.def( + "cudnn_gbn_backward(Tensor x, Tensor dy, Tensor scale, Tensor minibatch_mean, Tensor minibatch_inv_var, " + "float epsilon, int bn_group, int rank_id, int[] peer_buffers) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("cudnn_gbn_forward", &apex_cudnn_gbn_forward); + m.impl("cudnn_gbn_backward", &apex_cudnn_gbn_backward); } diff --git a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp index a0da49493..1a398fefb 100644 --- a/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp +++ b/apex/contrib/csrc/cudnn_gbn/norm_sample.cpp @@ -22,10 +22,9 @@ #include "norm_sample.h" +#include #include // for getcudnnhandle #include -#include -#include #include "cudnn_backend.h" @@ -74,7 +73,7 @@ void generateStrides(const int64_t* dimA, int64_t* strideA, int64_t nbDims, cudn cudnn_frontend::ExecutionPlan run_batch_norm_forward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get the cudnn handle - cudnnHandle_t handle = torch::native::getCudnnHandle(); + cudnnHandle_t handle = at::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; @@ -225,7 +224,7 @@ void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPt const std::vector& peer_devPtrs, double epsilon_val, double exponential_decay_factor, size_t peer_size, int rank_id) { // get handle - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); // get stream cudaStream_t stream; @@ -275,7 +274,7 @@ void execute_batch_norm_forward(cudnn_frontend::ExecutionPlan plan, void* xDevPt cudnn_frontend::ExecutionPlan run_batch_norm_backward(int64_t* tensorDims, int64_t* perChannelSum, int64_t* epsilon, int64_t* peerDims, cudnnDataType_t data_type) { // get cudnn handle - cudnnHandle_t handle = torch::native::getCudnnHandle(); + cudnnHandle_t handle = at::native::getCudnnHandle(); // Creates the necessary tensor descriptors int64_t tensor_stride[4]; @@ -406,7 +405,7 @@ void execute_batch_norm_backward(cudnn_frontend::ExecutionPlan plan, void* xDevP const std::vector& peer_devPtrs, void* dxDevPtr, void* dscaledevPtr, void* dbiasdevPtr, double epsilon_val, size_t peer_size, int rank_id) { // get handle - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); + cudnnHandle_t handle_ = at::native::getCudnnHandle(); // get stream cudaStream_t stream; diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp index ed8b3de5f..c2a6a7a43 100644 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ b/apex/contrib/csrc/fmha/fmha_api.cpp @@ -25,8 +25,9 @@ * ******************************************************************************/ +#include #include -#include +#include #include "fmha.h" @@ -81,7 +82,6 @@ std::vector mha_fwd( const at::Tensor& cu_seqlens, // b+1 const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors, c10::optional gen_) { - using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); @@ -127,12 +127,12 @@ std::vector mha_fwd( TORCH_CHECK(head_size == 64); auto opts = qkv.options(); - auto ctx = torch::empty({total, num_heads, head_size}, opts); + auto ctx = at::empty({total, num_heads, head_size}, opts); - auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts); + auto s = at::empty({batch_size, num_heads, seq_len, seq_len}, opts); if (zero_tensors) { - mha_fill(ctx, cu_seqlens.index({Slice(-1, None)})); + mha_fill(ctx, cu_seqlens.slice(0, -1)); } auto gen = at::get_generator_or_default(gen_, at::cuda::detail::getDefaultCUDAGenerator()); @@ -165,7 +165,6 @@ std::vector mha_bwd( const float p_dropout, // probability to drop const int max_seq_len, // max sequence length to choose the kernel const bool zero_tensors) { - using namespace torch::indexing; auto dprops = at::cuda::getCurrentDeviceProperties(); TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) || (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0)); @@ -189,10 +188,10 @@ std::vector mha_bwd( auto stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK(qkv.dtype() == torch::kFloat16); - TORCH_CHECK(dout.dtype() == torch::kFloat16); - TORCH_CHECK(softmax.dtype() == torch::kFloat16); - TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); + TORCH_CHECK(qkv.dtype() == at::kHalf); + TORCH_CHECK(dout.dtype() == at::kHalf); + TORCH_CHECK(softmax.dtype() == at::kHalf); + TORCH_CHECK(cu_seqlens.dtype() == at::kInt); TORCH_CHECK(qkv.is_cuda()); TORCH_CHECK(cu_seqlens.is_cuda()); @@ -213,10 +212,10 @@ std::vector mha_bwd( TORCH_CHECK(batch_size > 0); TORCH_CHECK(head_size == 64); - auto dqkv = torch::empty_like(qkv); + auto dqkv = at::empty_like(qkv); if (zero_tensors) { - mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)})); + mha_fill(dqkv, cu_seqlens.slice(0, -1)); } Fused_multihead_attention_fprop_params params; @@ -274,7 +273,7 @@ std::vector mha_bwd_nl( auto opts = qkv.options(); - auto dqkv = torch::empty_like(qkv); + auto dqkv = at::empty_like(qkv); if (zero_tensors) { dqkv.zero_(); @@ -286,7 +285,7 @@ std::vector mha_bwd_nl( } else if (batch_size == 2) { num_chunks = 3; } - auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); + auto dkv = at::empty({total, num_chunks, 2, num_heads, head_size}, opts); Fused_multihead_attention_fprop_params params; @@ -307,10 +306,8 @@ std::vector mha_bwd_nl( // SPLIT-K reduction of num_chunks dK, dV parts - // The equivalent of the following Pytorch code: - // using namespace torch::indexing; - // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); - // torch::sum_out(view_out, dkv, 1); + // The equivalent Python operation would reduce the split-K chunks into the + // K/V gradient slice of dqkv. const int hidden_size = num_heads * head_size; fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, @@ -319,9 +316,43 @@ std::vector mha_bwd_nl( return {dqkv, softmax, dkv}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention for BERT"; - m.def("fwd", &mha_fwd, "Forward pass", py::call_guard()); - m.def("bwd", &mha_bwd, "Backward pass", py::call_guard()); - m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard()); +namespace { +std::vector apex_fmha_fwd(const at::Tensor& qkv, const at::Tensor& cu_seqlens, double p_dropout, + int64_t max_seq_len, bool is_training, bool is_nl, bool zero_tensors, + c10::optional gen) { + return mha_fwd(qkv, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), is_training, is_nl, + zero_tensors, gen); +} + +std::vector apex_fmha_bwd(const at::Tensor& dout, const at::Tensor& qkv, at::Tensor softmax, + const at::Tensor& cu_seqlens, double p_dropout, int64_t max_seq_len, + bool zero_tensors) { + return mha_bwd(dout, qkv, softmax, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), + zero_tensors); +} + +std::vector apex_fmha_bwd_nl(const at::Tensor& dout, const at::Tensor& qkv, at::Tensor softmax, + const at::Tensor& cu_seqlens, double p_dropout, int64_t max_seq_len, + bool zero_tensors) { + return mha_bwd_nl(dout, qkv, softmax, cu_seqlens, static_cast(p_dropout), static_cast(max_seq_len), + zero_tensors); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "fmha_fwd(Tensor qkv, Tensor cu_seqlens, float p_dropout, int max_seq_len, bool is_training, bool is_nl, " + "bool zero_tensors, Generator? gen) -> Tensor[]"); + m.def( + "fmha_bwd(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, int max_seq_len, " + "bool zero_tensors) -> Tensor[]"); + m.def( + "fmha_bwd_nl(Tensor dout, Tensor qkv, Tensor(a!) softmax, Tensor cu_seqlens, float p_dropout, " + "int max_seq_len, bool zero_tensors) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fmha_fwd", &apex_fmha_fwd); + m.impl("fmha_bwd", &apex_fmha_bwd); + m.impl("fmha_bwd_nl", &apex_fmha_bwd_nl); } diff --git a/apex/contrib/csrc/fmha/src/fmha_fill.cu b/apex/contrib/csrc/fmha/src/fmha_fill.cu index f2e0d925d..01f6e6c71 100644 --- a/apex/contrib/csrc/fmha/src/fmha_fill.cu +++ b/apex/contrib/csrc/fmha/src/fmha_fill.cu @@ -25,9 +25,9 @@ * ******************************************************************************/ +#include #include #include -#include constexpr int block_size = 512; constexpr int ctas_per_sm = 4; diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp index 4e3555c3a..c02dbd7f5 100644 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -22,13 +23,14 @@ at::Tensor focal_loss_backward_cuda(const at::Tensor& grad_output, const at::Ten std::vector focal_loss_forward(const at::Tensor& cls_output, const at::Tensor& cls_targets_at_level, const at::Tensor& num_positives_sum, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { + const double alpha, const double gamma, const double smoothing_factor) { CHECK_INPUT(cls_output); CHECK_INPUT(cls_targets_at_level); CHECK_INPUT(num_positives_sum); - return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, alpha, gamma, - smoothing_factor); + return focal_loss_forward_cuda(cls_output, cls_targets_at_level, num_positives_sum, num_real_classes, + static_cast(alpha), static_cast(gamma), + static_cast(smoothing_factor)); } at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& partial_grad, @@ -39,9 +41,14 @@ at::Tensor focal_loss_backward(const at::Tensor& grad_output, const at::Tensor& return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &focal_loss_forward, "Focal loss calculation forward (CUDA)", - py::call_guard()); - m.def("backward", &focal_loss_backward, "Focal loss calculation backward (CUDA)", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "focal_loss_forward(Tensor cls_output, Tensor cls_targets_at_level, Tensor num_positives_sum, " + "int num_real_classes, float alpha, float gamma, float smoothing_factor) -> Tensor[]"); + m.def("focal_loss_backward(Tensor grad_output, Tensor partial_grad, Tensor num_positives_sum) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("focal_loss_forward", &focal_loss_forward); + m.impl("focal_loss_backward", &focal_loss_backward); } diff --git a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp index 2a9575d14..14ef3812b 100644 --- a/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp +++ b/apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp @@ -3,8 +3,15 @@ * SPDX-License-Identifier: BSD-3-Clause */ +#include #include -#include +#include + +#include +#include +#include +#include +#include #include "group_norm_nhwc.h" #include "group_norm_nhwc_bwd_one_pass.h" @@ -21,15 +28,16 @@ } \ } while (0) -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_CHANNELS_LAST(x) TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) -#define CHECK_NHWC_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CHANNELS_LAST(x) +#define CHECK_TENSOR_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_TENSOR_CHANNELS_LAST(x) \ + TORCH_CHECK(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be channels last") +#define CHECK_TENSOR_INPUT(x) \ + CHECK_TENSOR_CUDA(x); \ + CHECK_TENSOR_CONTIGUOUS(x) +#define CHECK_NHWC_TENSOR_INPUT(x) \ + CHECK_TENSOR_CUDA(x); \ + CHECK_TENSOR_CHANNELS_LAST(x) static bool initialized = false; static cudaDeviceProp props; @@ -39,13 +47,13 @@ const std::unordered_set supported_c_values = {128, 256, 320, 384, 448, 2048, 2240, 2560, 2688, 3072, 3136, 3584, 4096}; const std::unordered_set supported_groups_values = {16, 32}; -std::vector group_norm_fwd(torch::Tensor input, int groups, torch::Tensor weight, torch::Tensor bias, - float eps, int passes, bool with_swish = false) { +std::vector group_norm_fwd(at::Tensor input, int groups, at::Tensor weight, at::Tensor bias, float eps, + int passes, bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } - CHECK_NHWC_INPUT(input); + CHECK_NHWC_TENSOR_INPUT(input); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments @@ -83,26 +91,26 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch params_fwd.with_swish = with_swish; PrecisionMode mode; - if (input.dtype() == torch::kFloat32) { - if (weight.dtype() == torch::kFloat16) { + if (input.dtype() == at::kFloat) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP32IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } - } else if (input.dtype() == torch::kBFloat16) { - if (weight.dtype() == torch::kFloat16) { + } else if (input.dtype() == at::kBFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::BF16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { - if (weight.dtype() == torch::kFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; @@ -126,13 +134,13 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch } // Allocate on the device. - auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); + auto red_buffer = at::empty({static_cast(red_buffer_elts)}, options.dtype(at::kFloat)); params_fwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. - auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); + auto barriers = at::zeros({static_cast(barriers_elts)}, options.dtype(at::kInt)); params_fwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({static_cast(zeroed_red_buffer_elts)}, options.dtype(at::kFloat)); params_fwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { @@ -145,14 +153,14 @@ std::vector group_norm_fwd(torch::Tensor input, int groups, torch return {output, sums_d}; } -std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tensor sums, torch::Tensor input, - int groups, torch::Tensor weight, torch::Tensor bias, float eps, int passes, - bool with_swish = false) { +std::vector group_norm_bwd(at::Tensor grad_output, at::Tensor sums, at::Tensor input, int groups, + at::Tensor weight, at::Tensor bias, float eps, int passes, + bool with_swish = false) { if (!initialized) { CHECK_CUDA_STATUS(cudaGetDeviceProperties(&props, 0)); initialized = true; } - CHECK_NHWC_INPUT(grad_output); + CHECK_NHWC_TENSOR_INPUT(grad_output); auto stream = at::cuda::getCurrentCUDAStream(); // Achieve group norm arguments @@ -197,26 +205,26 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens params_bwd.with_swish = with_swish; PrecisionMode mode; - if (input.dtype() == torch::kFloat32) { - if (weight.dtype() == torch::kFloat16) { + if (input.dtype() == at::kFloat) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP32IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP32IOBF16W; } else { mode = PrecisionMode::FP32IOFP32W; } - } else if (input.dtype() == torch::kBFloat16) { - if (weight.dtype() == torch::kFloat16) { + } else if (input.dtype() == at::kBFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::BF16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::BF16IOBF16W; } else { mode = PrecisionMode::BF16IOFP32W; } } else { - if (weight.dtype() == torch::kFloat16) { + if (weight.dtype() == at::kHalf) { mode = PrecisionMode::FP16IOFP16W; - } else if (weight.dtype() == torch::kBFloat16) { + } else if (weight.dtype() == at::kBFloat16) { mode = PrecisionMode::FP16IOBF16W; } else { mode = PrecisionMode::FP16IOFP32W; @@ -240,13 +248,13 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens } // Allocate on the device. - auto red_buffer = at::empty({red_buffer_elts}, options.dtype(at::kFloat)); + auto red_buffer = at::empty({static_cast(red_buffer_elts)}, options.dtype(at::kFloat)); params_bwd.red_buffer = red_buffer.data_ptr(); // Allocate the buffer if needed. - auto barriers = at::zeros({barriers_elts}, options.dtype(at::kInt)); + auto barriers = at::zeros({static_cast(barriers_elts)}, options.dtype(at::kInt)); params_bwd.barriers = barriers.data_ptr(); - auto zeroed_red_buffer = at::zeros({zeroed_red_buffer_elts}, options.dtype(at::kFloat)); + auto zeroed_red_buffer = at::zeros({static_cast(zeroed_red_buffer_elts)}, options.dtype(at::kFloat)); params_bwd.zeroed_red_buffer = zeroed_red_buffer.data_ptr(); if (passes == 1) { @@ -259,7 +267,31 @@ std::vector group_norm_bwd(torch::Tensor grad_output, torch::Tens return {grad_input, grad_weight, grad_bias}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &group_norm_fwd, "NHWC group norm forward", py::call_guard()); - m.def("backward", &group_norm_bwd, "NHWC group norm backward", py::call_guard()); +namespace { +std::vector apex_group_norm_fwd(at::Tensor input, int64_t groups, at::Tensor weight, at::Tensor bias, + double eps, int64_t passes, bool with_swish) { + return group_norm_fwd(input, static_cast(groups), weight, bias, static_cast(eps), + static_cast(passes), with_swish); +} + +std::vector apex_group_norm_bwd(at::Tensor grad_output, at::Tensor sums, at::Tensor input, int64_t groups, + at::Tensor weight, at::Tensor bias, double eps, int64_t passes, + bool with_swish) { + return group_norm_bwd(grad_output, sums, input, static_cast(groups), weight, bias, static_cast(eps), + static_cast(passes), with_swish); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "group_norm_forward(Tensor input, int groups, Tensor weight, Tensor bias, float eps, int passes, " + "bool with_swish) -> Tensor[]"); + m.def( + "group_norm_backward(Tensor grad_output, Tensor sums, Tensor input, int groups, Tensor weight, Tensor bias, " + "float eps, int passes, bool with_swish) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("group_norm_forward", &apex_group_norm_fwd); + m.impl("group_norm_backward", &apex_group_norm_bwd); } diff --git a/apex/contrib/csrc/group_norm_v2/gn.cpp b/apex/contrib/csrc/group_norm_v2/gn.cpp index 23b2a8b18..51d062eec 100644 --- a/apex/contrib/csrc/group_norm_v2/gn.cpp +++ b/apex/contrib/csrc/group_norm_v2/gn.cpp @@ -1,25 +1,33 @@ #include "gn.hpp" +#include #include -#include +#include +#include +#include + +#include +#include +#include +#include namespace group_norm_v2 { -torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, bool silu, int num_groups, - std::optional mean_var_out, int sm_margin) { - if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != torch::kFloat32)) { +at::Tensor gn(at::Tensor x, at::Tensor w, at::Tensor b, float eps, bool silu, int num_groups, + std::optional mean_var_out, int sm_margin) { + if (w.dtype() != b.dtype() || (mean_var_out.has_value() && mean_var_out->dtype() != at::kFloat)) { throw std::invalid_argument("gn dtype mismatch"); } - torch::Tensor out = torch::empty_like(x); + at::Tensor out = at::empty_like(x); float* ptr_mean_var_out = mean_var_out.has_value() ? mean_var_out->data_ptr() : nullptr; cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, nullptr, nullptr, sm_margin, stream, device_id, @@ -27,18 +35,17 @@ torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, b } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - torch::Tensor red_buffer = - torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; + at::Tensor red_buffer = at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_cuda((half*)out.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_cuda((__nv_bfloat16*)out.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, ptr_mean_var_out, red_buffer.data_ptr(), @@ -49,24 +56,24 @@ torch::Tensor gn(torch::Tensor x, torch::Tensor w, torch::Tensor b, float eps, b return out; } -auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch::Tensor b, torch::Tensor mean_var, - float eps, bool silu, int num_groups, int sm_margin) { - if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != torch::kFloat32) { +auto gn_bwd(at::Tensor grad_output, at::Tensor x, at::Tensor w, at::Tensor b, at::Tensor mean_var, float eps, bool silu, + int num_groups, int sm_margin) { + if (w.dtype() != b.dtype() || x.dtype() != grad_output.dtype() || mean_var.dtype() != at::kFloat) { throw std::invalid_argument("gn_bwd dtype mismatch"); } - torch::Tensor grad_input = torch::empty_like(x); - torch::Tensor grad_weight = torch::empty_like(w); - torch::Tensor grad_bias = torch::empty_like(w); + at::Tensor grad_input = at::empty_like(x); + at::Tensor grad_weight = at::empty_like(w); + at::Tensor grad_bias = at::empty_like(w); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); int device_id = at::cuda::getCurrentCUDAStream().device().index(); group_norm_v2::Meta meta; - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, nullptr, nullptr, sm_margin, stream, device_id, &meta, true); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), @@ -75,19 +82,18 @@ auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch:: } else { throw std::invalid_argument("gn only supports half or bfloat16 input and weight"); } - torch::Tensor red_buffer = - torch::empty({meta.red_buffer_size}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - thread_local torch::Tensor barrier; + at::Tensor red_buffer = at::empty({meta.red_buffer_size}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + thread_local at::Tensor barrier; if (barrier.size(0) < meta.barrier_size) { - barrier = torch::zeros({meta.barrier_size}, torch::TensorOptions().dtype(torch::kUInt32).device(torch::kCUDA)); + barrier = at::zeros({meta.barrier_size}, at::TensorOptions().dtype(at::ScalarType::UInt32).device(at::kCUDA)); } - if (x.dtype() == torch::kHalf && w.dtype() == torch::kHalf) { + if (x.dtype() == at::kHalf && w.dtype() == at::kHalf) { group_norm_v2::gn_bwd_cuda((half*)grad_input.data_ptr(), (half*)grad_weight.data_ptr(), (half*)grad_bias.data_ptr(), (half*)grad_output.data_ptr(), (half*)x.data_ptr(), (half*)w.data_ptr(), (half*)b.data_ptr(), mean_var.data_ptr(), eps, silu, x.size(0), x.size(2) * x.size(3), num_groups, x.size(1) / num_groups, red_buffer.data_ptr(), barrier.data_ptr(), sm_margin, stream, device_id, nullptr, false); - } else if (x.dtype() == torch::kBFloat16 && w.dtype() == torch::kBFloat16) { + } else if (x.dtype() == at::kBFloat16 && w.dtype() == at::kBFloat16) { group_norm_v2::gn_bwd_cuda((__nv_bfloat16*)grad_input.data_ptr(), (__nv_bfloat16*)grad_weight.data_ptr(), (__nv_bfloat16*)grad_bias.data_ptr(), (__nv_bfloat16*)grad_output.data_ptr(), (__nv_bfloat16*)x.data_ptr(), (__nv_bfloat16*)w.data_ptr(), (__nv_bfloat16*)b.data_ptr(), @@ -102,9 +108,32 @@ auto gn_bwd(torch::Tensor grad_output, torch::Tensor x, torch::Tensor w, torch:: } // namespace group_norm_v2 -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("gn", &group_norm_v2::gn, py::arg("x"), py::arg("w"), py::arg("b"), py::arg("eps"), py::arg("silu"), - py::arg("num_groups"), py::arg("mean_var_out") = py::none(), py::arg("sm_margin") = 0, ""); - m.def("gn_bwd", &group_norm_v2::gn_bwd, py::arg("grad_output"), py::arg("x"), py::arg("w"), py::arg("b"), - py::arg("mean_var"), py::arg("eps"), py::arg("silu"), py::arg("num_groups"), py::arg("sm_margin") = 0, ""); +namespace { +at::Tensor apex_group_norm_v2_gn(at::Tensor x, at::Tensor w, at::Tensor b, double eps, bool silu, int64_t num_groups, + const std::optional& mean_var_out, int64_t sm_margin) { + return group_norm_v2::gn(x, w, b, static_cast(eps), silu, static_cast(num_groups), mean_var_out, + static_cast(sm_margin)); +} + +std::vector apex_group_norm_v2_gn_bwd(at::Tensor grad_output, at::Tensor x, at::Tensor w, at::Tensor b, + at::Tensor mean_var, double eps, bool silu, int64_t num_groups, + int64_t sm_margin) { + auto grads = group_norm_v2::gn_bwd(grad_output, x, w, b, mean_var, static_cast(eps), silu, + static_cast(num_groups), static_cast(sm_margin)); + return {std::get<0>(grads), std::get<1>(grads), std::get<2>(grads)}; +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "group_norm_v2_gn(Tensor x, Tensor w, Tensor b, float eps, bool silu, int num_groups, " + "Tensor? mean_var_out, int sm_margin) -> Tensor"); + m.def( + "group_norm_v2_gn_bwd(Tensor grad_output, Tensor x, Tensor w, Tensor b, Tensor mean_var, float eps, " + "bool silu, int num_groups, int sm_margin) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("group_norm_v2_gn", &apex_group_norm_v2_gn); + m.impl("group_norm_v2_gn_bwd", &apex_group_norm_v2_gn_bwd); } diff --git a/apex/contrib/csrc/groupbn/interface.cpp b/apex/contrib/csrc/groupbn/interface.cpp index 27891bce1..cd6eeb7ef 100644 --- a/apex/contrib/csrc/groupbn/interface.cpp +++ b/apex/contrib/csrc/groupbn/interface.cpp @@ -1,18 +1,13 @@ #include #include #include -#include -#include -#include -#include +#include #include "ATen/Generator.h" #include "ATen/Scalar.h" #include "ATen/Storage.h" #include "ATen/Tensor.h" -namespace py = pybind11; - int64_t get_buffer_size(const int bn_sync_steps); void* get_data_ptr(const at::Tensor& data); @@ -72,29 +67,157 @@ int nhwc_bn_bwd_occupancy(); int nhwc_bn_addrelu_fwd_occupancy(); int nhwc_bn_addrelu_bwd_occupancy(); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_buffer_size", &get_buffer_size, "get_buffer_size", py::call_guard()); - m.def("get_data_ptr", &get_data_ptr, "get_data_ptr", py::call_guard()); - m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr", py::call_guard()); - m.def("close_remote_data", &close_remote_data, "close_remote_data", py::call_guard()); - - m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc", py::call_guard()); - m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc", py::call_guard()); - m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc", py::call_guard()); - - m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy", - py::call_guard()); - m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy", - py::call_guard()); - - m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc", - py::call_guard()); - m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc", - py::call_guard()); - m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc", py::call_guard()); - - m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy", - py::call_guard()); - m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy", - py::call_guard()); +namespace { +void* optional_ptr(const c10::optional& ptr) { + return ptr.has_value() ? reinterpret_cast(ptr.value()) : nullptr; +} + +int64_t apex_bnp_get_buffer_size(int64_t bn_sync_steps) { return get_buffer_size(static_cast(bn_sync_steps)); } + +int64_t apex_bnp_get_data_ptr(const at::Tensor& data) { return reinterpret_cast(get_data_ptr(data)); } + +int64_t apex_bnp_get_remote_data_ptr(const at::Tensor& handle, int64_t offset) { + return reinterpret_cast(get_remote_data_ptr(handle, offset)); +} + +int64_t apex_bnp_bn_fwd_nhwc_occupancy() { return nhwc_bn_fwd_occupancy(); } + +int64_t apex_bnp_bn_bwd_nhwc_occupancy() { return nhwc_bn_bwd_occupancy(); } + +int64_t apex_bnp_bn_addrelu_fwd_nhwc_occupancy() { return nhwc_bn_addrelu_fwd_occupancy(); } + +int64_t apex_bnp_bn_addrelu_bwd_nhwc_occupancy() { return nhwc_bn_addrelu_bwd_occupancy(); } + +at::Tensor apex_bnp_bn_fwd_nhwc(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& minibatch_mean, const at::Tensor& minibatch_inv_var, + const at::Tensor& ret_cta, double momentum, double epsilon, bool fuse_relu, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, int64_t bn_group, + const at::Tensor& magic_tensor, int64_t occupancy, int64_t grid_dim_x, bool coop) { + return nhwc_bn_fwd_train(x, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, + static_cast(momentum), static_cast(epsilon), fuse_relu, optional_ptr(my_data), + optional_ptr(pair_data), optional_ptr(pair_data2), optional_ptr(pair_data3), + static_cast(bn_group), magic_tensor, static_cast(occupancy), + static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, + const at::Tensor& ret_cta, int64_t bn_group, double momentum, double epsilon, + bool fuse_relu) { + return nhwc_bn_fwd_eval(x, scale, bias, running_mean, running_inv_var, ret_cta, static_cast(bn_group), + static_cast(momentum), static_cast(epsilon), fuse_relu); +} + +std::vector apex_bnp_bn_bwd_nhwc(const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& ret_cta, + double momentum, double epsilon, bool fuse_relu, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, + int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, ret_cta, + static_cast(momentum), static_cast(epsilon), fuse_relu, optional_ptr(my_data), + optional_ptr(pair_data), optional_ptr(pair_data2), optional_ptr(pair_data3), + static_cast(bn_group), magic_tensor, static_cast(occupancy), + static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_addrelu_fwd_nhwc(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, + const at::Tensor& ret_cta, double momentum, double epsilon, + c10::optional my_data, c10::optional pair_data, + c10::optional pair_data2, c10::optional pair_data3, + int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_addrelu_fwd_train(x, z, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, + bitmask, ret_cta, static_cast(momentum), static_cast(epsilon), + optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), + optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, + static_cast(occupancy), static_cast(grid_dim_x), coop); +} + +at::Tensor apex_bnp_bn_addrelu_fwd_eval_nhwc(const at::Tensor& x, const at::Tensor& z, const at::Tensor& scale, + const at::Tensor& bias, const at::Tensor& running_mean, + const at::Tensor& running_inv_var, const at::Tensor& ret_cta, + int64_t bn_group, double momentum, double epsilon) { + return nhwc_bn_addrelu_fwd_eval(x, z, scale, bias, running_mean, running_inv_var, ret_cta, static_cast(bn_group), + static_cast(momentum), static_cast(epsilon)); +} + +std::vector apex_bnp_bn_addrelu_bwd_nhwc( + const at::Tensor& x, const at::Tensor& dy, const at::Tensor& scale, const at::Tensor& bias, + const at::Tensor& running_mean, const at::Tensor& running_inv_var, const at::Tensor& minibatch_mean, + const at::Tensor& minibatch_inv_var, const at::Tensor& bitmask, const at::Tensor& ret_cta, double momentum, + double epsilon, c10::optional my_data, c10::optional pair_data, c10::optional pair_data2, + c10::optional pair_data3, int64_t bn_group, const at::Tensor& magic_tensor, int64_t occupancy, + int64_t grid_dim_x, bool coop) { + return nhwc_bn_addrelu_bwd(x, dy, scale, bias, running_mean, running_inv_var, minibatch_mean, minibatch_inv_var, + bitmask, ret_cta, static_cast(momentum), static_cast(epsilon), + optional_ptr(my_data), optional_ptr(pair_data), optional_ptr(pair_data2), + optional_ptr(pair_data3), static_cast(bn_group), magic_tensor, + static_cast(occupancy), static_cast(grid_dim_x), coop); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("bnp_get_buffer_size(int bn_sync_steps) -> int"); + m.def("bnp_get_data_ptr(Tensor data) -> int"); + m.def("bnp_get_remote_data_ptr(Tensor handle, int offset) -> int"); + m.def("bnp_close_remote_data(Tensor handle) -> ()"); + m.def( + "bnp_bn_fwd_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, float epsilon, " + "bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, int bn_group, " + "Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def( + "bnp_bn_fwd_eval_nhwc(Tensor x, Tensor scale, Tensor bias, Tensor running_mean, Tensor running_inv_var, " + "Tensor ret_cta, int bn_group, float momentum, float epsilon, bool fuse_relu) -> Tensor"); + m.def( + "bnp_bn_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor ret_cta, float momentum, " + "float epsilon, bool fuse_relu, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def("bnp_bn_fwd_nhwc_occupancy() -> int"); + m.def("bnp_bn_bwd_nhwc_occupancy() -> int"); + m.def( + "bnp_bn_addrelu_fwd_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor"); + m.def( + "bnp_bn_addrelu_fwd_eval_nhwc(Tensor x, Tensor z, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor ret_cta, int bn_group, float momentum, float epsilon) -> Tensor"); + m.def( + "bnp_bn_addrelu_bwd_nhwc(Tensor x, Tensor dy, Tensor scale, Tensor bias, Tensor running_mean, " + "Tensor running_inv_var, Tensor minibatch_mean, Tensor minibatch_inv_var, Tensor bitmask, Tensor ret_cta, " + "float momentum, float epsilon, int? my_data, int? pair_data, int? pair_data2, int? pair_data3, " + "int bn_group, Tensor magic_tensor, int occupancy, int grid_dim_x, bool coop) -> Tensor[]"); + m.def("bnp_bn_addrelu_fwd_nhwc_occupancy() -> int"); + m.def("bnp_bn_addrelu_bwd_nhwc_occupancy() -> int"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("bnp_get_buffer_size", &apex_bnp_get_buffer_size); + m.impl("bnp_get_data_ptr", &apex_bnp_get_data_ptr); + m.impl("bnp_get_remote_data_ptr", &apex_bnp_get_remote_data_ptr); + m.impl("bnp_close_remote_data", &close_remote_data); + m.impl("bnp_bn_fwd_nhwc_occupancy", &apex_bnp_bn_fwd_nhwc_occupancy); + m.impl("bnp_bn_bwd_nhwc_occupancy", &apex_bnp_bn_bwd_nhwc_occupancy); + m.impl("bnp_bn_addrelu_fwd_nhwc_occupancy", &apex_bnp_bn_addrelu_fwd_nhwc_occupancy); + m.impl("bnp_bn_addrelu_bwd_nhwc_occupancy", &apex_bnp_bn_addrelu_bwd_nhwc_occupancy); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("bnp_bn_fwd_nhwc", &apex_bnp_bn_fwd_nhwc); + m.impl("bnp_bn_fwd_eval_nhwc", &apex_bnp_bn_fwd_eval_nhwc); + m.impl("bnp_bn_bwd_nhwc", &apex_bnp_bn_bwd_nhwc); + m.impl("bnp_bn_addrelu_fwd_nhwc", &apex_bnp_bn_addrelu_fwd_nhwc); + m.impl("bnp_bn_addrelu_fwd_eval_nhwc", &apex_bnp_bn_addrelu_fwd_eval_nhwc); + m.impl("bnp_bn_addrelu_bwd_nhwc", &apex_bnp_bn_addrelu_bwd_nhwc); } diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp index e585bc73f..702525cf8 100644 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -65,17 +66,82 @@ void index_mul_2d_half_backwrad_backward(at::Tensor& grad_grad_out, at::Tensor& grad_grad_in2, in1, in2, idx1); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("float_forward", &index_mul_2d_float_forward, "index mul float calculation forward (CUDA)", - py::call_guard()); - m.def("float_backward", &index_mul_2d_float_backward, "index mul float calculation backward (CUDA)", - py::call_guard()); - m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, - "index mul float calculation backward backward (CUDA)", py::call_guard()); - m.def("half_forward", &index_mul_2d_half_forward, "index mul half calculation forward (CUDA)", - py::call_guard()); - m.def("half_backward", &index_mul_2d_half_backward, "index mul half calculation backward (CUDA)", - py::call_guard()); - m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, - "index mul half calculation backward backward (CUDA)", py::call_guard()); +void index_mul_2d_float_forward_dispatch(const at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor out_arg = out; + index_mul_2d_float_forward(out_arg, in1, in2, idx1); +} + +void index_mul_2d_float_backward_dispatch(const at::Tensor& grad_in1, const at::Tensor& grad_in2, + const at::Tensor& grad_out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_float_backward(grad_in1_arg, grad_in2_arg, grad_out, in1, in2, idx1); +} + +void index_mul_2d_float_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, + const at::Tensor& grad_in2, const at::Tensor& grad_out, + const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2, + const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_grad_out_arg = grad_grad_out; + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_float_backwrad_backward(grad_grad_out_arg, grad_in1_arg, grad_in2_arg, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); +} + +void index_mul_2d_half_forward_dispatch(const at::Tensor& out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor out_arg = out; + index_mul_2d_half_forward(out_arg, in1, in2, idx1); +} + +void index_mul_2d_half_backward_dispatch(const at::Tensor& grad_in1, const at::Tensor& grad_in2, + const at::Tensor& grad_out, const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_half_backward(grad_in1_arg, grad_in2_arg, grad_out, in1, in2, idx1); +} + +void index_mul_2d_half_backward_backward_dispatch(const at::Tensor& grad_grad_out, const at::Tensor& grad_in1, + const at::Tensor& grad_in2, const at::Tensor& grad_out, + const at::Tensor& grad_grad_in1, const at::Tensor& grad_grad_in2, + const at::Tensor& in1, const at::Tensor& in2, + const at::Tensor& idx1) { + at::Tensor grad_grad_out_arg = grad_grad_out; + at::Tensor grad_in1_arg = grad_in1; + at::Tensor grad_in2_arg = grad_in2; + index_mul_2d_half_backwrad_backward(grad_grad_out_arg, grad_in1_arg, grad_in2_arg, grad_out, grad_grad_in1, + grad_grad_in2, in1, in2, idx1); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("index_mul_2d_float_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_float_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_float_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def("index_mul_2d_half_forward(Tensor(a!) out, Tensor in1, Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_half_backward(Tensor(a!) grad_in1, Tensor(b!) grad_in2, Tensor grad_out, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); + m.def( + "index_mul_2d_half_backward_backward(Tensor(a!) grad_grad_out, Tensor(b!) grad_in1, " + "Tensor(c!) grad_in2, Tensor grad_out, Tensor grad_grad_in1, Tensor grad_grad_in2, Tensor in1, " + "Tensor in2, Tensor idx1) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("index_mul_2d_float_forward", &index_mul_2d_float_forward_dispatch); + m.impl("index_mul_2d_float_backward", &index_mul_2d_float_backward_dispatch); + m.impl("index_mul_2d_float_backward_backward", &index_mul_2d_float_backward_backward_dispatch); + m.impl("index_mul_2d_half_forward", &index_mul_2d_half_forward_dispatch); + m.impl("index_mul_2d_half_backward", &index_mul_2d_half_backward_dispatch); + m.impl("index_mul_2d_half_backward_backward", &index_mul_2d_half_backward_backward_dispatch); } diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 025fd52a8..e5c0a0adc 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -1,6 +1,7 @@ -#include +#include +#include +#include -#include "ATen/cuda/CUDAContext.h" #include "ln.h" /* @@ -30,12 +31,12 @@ BwdRegistry BWD_FUNCS; //////////////////////////////////////////////////////////////////////////////////////////////////// -uint32_t get_type_id(torch::Dtype dtype) { - if (dtype == torch::kFloat16) { +uint32_t get_type_id(at::ScalarType dtype) { + if (dtype == at::kHalf) { return TypeId::Value; - } else if (dtype == torch::kBFloat16) { + } else if (dtype == at::kBFloat16) { return TypeId::Value; - } else if (dtype == torch::kFloat32) { + } else if (dtype == at::kFloat) { return TypeId::Value; } else { TORCH_CHECK(false, "Type not supported: ", dtype); @@ -44,7 +45,8 @@ uint32_t get_type_id(torch::Dtype dtype) { //////////////////////////////////////////////////////////////////////////////////////////////////// -uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { +uint64_t get_key(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, at::ScalarType ctype, + uint64_t hidden_size) { using namespace layer_norm; uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); @@ -56,8 +58,8 @@ uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, tor //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, - torch::Dtype ctype, uint32_t hidden_size) { +layer_norm::FwdFunction& get_fwd_launcher(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, + at::ScalarType ctype, uint32_t hidden_size) { auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::FWD_FUNCS.end()) { return iter->second; @@ -68,8 +70,8 @@ layer_norm::FwdFunction& get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype //////////////////////////////////////////////////////////////////////////////////////////////////// -layer_norm::BwdFunction& get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, - torch::Dtype ctype, uint32_t hidden_size) { +layer_norm::BwdFunction& get_bwd_launcher(at::ScalarType wtype, at::ScalarType itype, at::ScalarType otype, + at::ScalarType ctype, uint32_t hidden_size) { auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); if (iter != layer_norm::BWD_FUNCS.end()) { return iter->second; @@ -87,7 +89,7 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size auto itype = x.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; - auto ctype = torch::kFloat32; + auto ctype = at::kFloat; TORCH_CHECK(beta.scalar_type() == wtype); @@ -110,10 +112,10 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size auto opts = x.options(); - auto z = torch::empty(sizes, opts.dtype(otype)); + auto z = at::empty(sizes, opts.dtype(otype)); - auto mu = torch::empty({rows}, opts.dtype(ctype)); - auto rsigma = torch::empty({rows}, opts.dtype(ctype)); + auto mu = at::empty({rows}, opts.dtype(ctype)); + auto rsigma = at::empty({rows}, opts.dtype(ctype)); layer_norm::LaunchParams launch_params; @@ -142,8 +144,8 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size if (launch_params.barrier_size > 0) { auto options = x.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + barrier = at::zeros(launch_params.barrier_size, options.dtype(at::kInt)); + workspace = at::empty(launch_params.workspace_bytes, options.dtype(at::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } @@ -157,15 +159,15 @@ std::vector ln_fwd(const at::Tensor& x, // BxSxhidden_size //////////////////////////////////////////////////////////////////////////////////////////////////// std::vector ln_bwd(const at::Tensor& dz, // BxSxhidden_size const at::Tensor& x_or_z, // BxSxhidden_size - c10::optional& mu_, // BxS, FP32! + const c10::optional& mu_, // BxS, FP32! const at::Tensor& rsigma, // BxS, FP32! const at::Tensor& gamma, // hidden_size - c10::optional& beta_, // hidden_size + const c10::optional& beta_, // hidden_size bool memory_efficient) { auto itype = x_or_z.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; - auto ctype = torch::kFloat32; + auto ctype = at::kFloat; TORCH_CHECK(dz.dtype() == otype); TORCH_CHECK(rsigma.dtype() == ctype); @@ -200,9 +202,9 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh auto options = x_or_z.options(); - auto dx = torch::empty_like(x_or_z); - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); + auto dx = at::empty_like(x_or_z); + auto dgamma = at::empty_like(gamma); + auto dbeta = at::empty_like(gamma); layer_norm::LaunchParams launch_params; launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); @@ -212,8 +214,8 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh launcher(launch_params, true); - auto dgamma_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); - auto dbeta_part = torch::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + auto dgamma_part = at::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); + auto dbeta_part = at::empty({launch_params.params.ctas_per_col, hidden_size}, options.dtype(ctype)); at::Tensor workspace, barrier; layer_norm::BwdParams& params = launch_params.params; @@ -237,8 +239,8 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh if (launch_params.barrier_size > 0) { // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); + barrier = at::zeros(launch_params.barrier_size, options.dtype(at::kInt)); + workspace = at::empty(launch_params.workspace_bytes, options.dtype(at::kChar)); params.workspace = workspace.data_ptr(); params.barrier = barrier.data_ptr(); } @@ -250,8 +252,30 @@ std::vector ln_bwd(const at::Tensor& dz, // BxSxh //////////////////////////////////////////////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; - m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel", py::call_guard()); - m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel", py::call_guard()); +namespace { +std::vector apex_fast_layer_norm_ln_fwd(const at::Tensor& x, const at::Tensor& gamma, + const at::Tensor& beta, double epsilon) { + return ln_fwd(x, gamma, beta, static_cast(epsilon)); } + +std::vector apex_fast_layer_norm_ln_bwd(const at::Tensor& dz, const at::Tensor& x_or_z, + const c10::optional& mu, const at::Tensor& rsigma, + const at::Tensor& gamma, const c10::optional& beta, + bool memory_efficient) { + return ln_bwd(dz, x_or_z, mu, rsigma, gamma, beta, memory_efficient); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fast_layer_norm_ln_fwd(Tensor x, Tensor gamma, Tensor beta, float epsilon) -> Tensor[]"); + m.def( + "fast_layer_norm_ln_bwd(Tensor dz, Tensor x_or_z, Tensor? mu, Tensor rsigma, Tensor gamma, Tensor? beta, " + "bool memory_efficient) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_layer_norm_ln_fwd", &apex_fast_layer_norm_ln_fwd); + m.impl("fast_layer_norm_ln_bwd", &apex_fast_layer_norm_ln_bwd); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu index a18cf12fc..abf64b052 100644 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -19,8 +21,8 @@ namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, + float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); @@ -36,11 +38,11 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = input.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* input_ptr = static_cast(input.data_ptr()); @@ -71,8 +73,8 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob) { +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; @@ -84,7 +86,7 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tens cublasSetStream(handle, stream); // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); + // at::Tensor input_grads = at::empty_like(output_grads); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu index 2f93f108a..99dd6fdde 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -18,10 +20,10 @@ namespace multihead_attn { namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -50,15 +52,15 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); + at::Tensor input_lin_q_results = at::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + at::Tensor input_lin_kv_results = at::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); @@ -141,13 +143,12 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -174,16 +175,16 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_q_grads = at::empty_like(inputs_q); + at::Tensor input_kv_grads = at::empty_like(inputs_kv); + at::Tensor input_weight_q_grads = at::empty_like(input_weights_q); + at::Tensor input_weight_kv_grads = at::empty_like(input_weights_kv); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_q_output_grads = at::empty_like(input_lin_q_results); + at::Tensor input_lin_kv_output_grads = at::empty_like(input_lin_kv_results); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu index 4376ededb..2ec6a3e7a 100644 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -19,11 +21,11 @@ namespace multihead_attn { namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -53,22 +55,22 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs_q.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); - - torch::Tensor input_lin_q_results = torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); + auto lyr_nrm_options = act_options.dtype(at::kFloat); + auto mask_options = act_options.dtype(at::kByte); + + at::Tensor lyr_nrm_mean = at::empty({batches_q}, lyr_nrm_options); + at::Tensor lyr_nrm_invvar = at::empty({batches_q}, lyr_nrm_options); + at::Tensor lyr_nrm_results = at::empty_like(inputs_q, act_options); + + at::Tensor input_lin_q_results = at::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); + at::Tensor input_lin_kv_results = at::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor output_lin_results = at::empty_like(inputs_q, act_options); + at::Tensor dropout_add_mask = at::empty_like(inputs_q, mask_options); + at::Tensor outputs = at::empty_like(inputs_q, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); @@ -175,16 +177,16 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, - float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + float dropout_prob) { const int embed_dim = inputs_q.size(2); const int sequences = inputs_q.size(1); const int q_seq_len = inputs_q.size(0); @@ -212,20 +214,20 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_q_grads = at::empty_like(inputs_q); + at::Tensor input_kv_grads = at::empty_like(inputs_kv); + at::Tensor lyr_nrm_gamma_grads = at::empty_like(lyr_nrm_gamma_weights); + at::Tensor lyr_nrm_beta_grads = at::empty_like(lyr_nrm_beta_weights); + at::Tensor input_weight_q_grads = at::empty_like(input_weights_q); + at::Tensor input_weight_kv_grads = at::empty_like(input_weights_kv); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor dropout_add_grads = torch::empty_like(output_grads); - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = torch::empty_like(input_lin_kv_results); - at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); + at::Tensor dropout_add_grads = at::empty_like(output_grads); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_q_output_grads = at::empty_like(input_lin_q_results); + at::Tensor input_lin_kv_output_grads = at::empty_like(input_lin_kv_results); + at::Tensor input_lin_q_grads = at::empty_like(inputs_q); auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu index bde131da5..0b58c33b8 100644 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -17,8 +19,8 @@ namespace multihead_attn { namespace fused_softmax { namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, - float dropout_prob) { +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, + float dropout_prob) { const int attn_batches = input.size(0); const int sequences = attn_batches / heads; const int q_seq_len = input.size(1); @@ -34,11 +36,11 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = input.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) void* input_ptr = static_cast(input.data_ptr()); @@ -69,8 +71,8 @@ std::vector fwd_cuda(bool is_training, int heads, torch::Tensor c return {dropout_results, dropout_mask, softmax_results}; } -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob) { const int attn_batches = output_grads.size(0); const int q_seq_len = output_grads.size(1); const int k_seq_len = q_seq_len; @@ -82,7 +84,7 @@ torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tens cublasSetStream(handle, stream); // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); + // at::Tensor input_grads = at::empty_like(output_grads); // Apply Dropout Mask and Scale by Dropout Probability // Softmax Grad diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp index e1d3bcc29..8e66c5f60 100644 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp @@ -1,5 +1,6 @@ +#include #include -#include +#include #include @@ -13,14 +14,14 @@ namespace multihead_attn { namespace fused_softmax { namespace additive_mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const half* pad_mask, - float dropout_prob); +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const half* pad_mask, + float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob); +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, - torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); if (use_mask) { @@ -32,8 +33,8 @@ std::vector fwd(bool use_mask, bool is_training, int heads, torch dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, float dropout_prob) { +at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -48,14 +49,14 @@ torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, t } // namespace additive_mask_softmax_dropout namespace mask_softmax_dropout { -std::vector fwd_cuda(bool is_training, int heads, torch::Tensor const& input, const uint8_t* pad_mask, - float dropout_prob); +std::vector fwd_cuda(bool is_training, int heads, at::Tensor const& input, const uint8_t* pad_mask, + float dropout_prob); -torch::Tensor bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); +at::Tensor bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, const uint8_t* padding_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool is_training, int heads, torch::Tensor const& input, - torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool is_training, int heads, at::Tensor const& input, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.scalar_type() == at::ScalarType::Half, "Only HALF is supported"); @@ -68,8 +69,8 @@ std::vector fwd(bool use_mask, bool is_training, int heads, torch dropout_prob); } -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, torch::Tensor const& softmax_results, - torch::Tensor const& dropout_mask, torch::Tensor const& padding_mask, float dropout_prob) { +at::Tensor bwd(bool use_mask, int heads, at::Tensor const& output_grads, at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, at::Tensor const& padding_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_mask.dim() == 3, "expected 3D tensor"); @@ -89,22 +90,21 @@ torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const& output_grads, t namespace encdec { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights_q.dim() == 2, "expected 2D tensor"); @@ -126,13 +126,12 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini output_weights, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -170,28 +169,28 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace encdec_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); - -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, - float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs_q, torch::Tensor const& inputs_kv, - torch::Tensor const& lyr_nrm_gamma_weights, torch::Tensor const& lyr_nrm_beta_weights, - torch::Tensor const& input_weights_q, torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs_q.dim() == 3, "expected 3D tensor"); TORCH_CHECK(inputs_kv.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); @@ -218,16 +217,15 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_q_results, torch::Tensor const& input_lin_kv_results, - torch::Tensor const& lyr_nrm_results, torch::Tensor const& lyr_nrm_mean, - torch::Tensor const& lyr_nrm_invvar, torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, torch::Tensor const& dropout_add_mask, - float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, + at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs_q, at::Tensor const& inputs_kv, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -278,19 +276,19 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob); +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob); +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob); -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -308,11 +306,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -342,23 +340,23 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_bias { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, - const uint8_t* pad_mask, float dropout_prob); - -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& input_biases, - torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, + const uint8_t* pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, at::Tensor const& pad_mask, + float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -376,11 +374,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -410,25 +408,25 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, - const half* pad_mask, float dropout_prob); - -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, - // torch::Tensor const& softmax_results, - torch::Tensor const& bmm1_results, torch::Tensor const& pad_mask, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const& dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& input_biases, - torch::Tensor const& output_biases, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, + float dropout_prob); + +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, + // at::Tensor const& softmax_results, + at::Tensor const& bmm1_results, at::Tensor const& pad_mask, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + // at::Tensor const& input_biases, + // at::Tensor const& output_biases, + at::Tensor const& dropout_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, at::Tensor const& pad_mask, + float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input_weights.dim() == 2, "expected 2D tensor"); TORCH_CHECK(output_weights.dim() == 2, "expected 2D tensor"); @@ -447,12 +445,11 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results, - torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -481,24 +478,24 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor namespace self_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob); - -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor const& dropout_add_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob); + +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob); + +std::vector fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& pad_mask, float dropout_prob) { TORCH_CHECK(inputs.dim() == 3, "expected 3D tensor"); TORCH_CHECK(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); TORCH_CHECK(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); @@ -520,14 +517,13 @@ std::vector fwd(bool use_mask, bool use_time_mask, bool is_traini output_weights, use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, dropout_prob); } -std::vector bwd(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor const& dropout_add_mask, float dropout_prob) { +std::vector bwd(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, float dropout_prob) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(matmul2_results.dim() == 3, "expected 3D tensor"); TORCH_CHECK(dropout_results.dim() == 3, "expected 3D tensor"); @@ -569,42 +565,267 @@ std::vector bwd(int heads, torch::Tensor const& output_grads, tor } // end namespace self_norm_add } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("additive_mask_softmax_dropout_forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); - m.def("additive_mask_softmax_dropout_backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); - m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward.", py::call_guard()); - m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward.", py::call_guard()); - m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::cublas_gemmex::fwd, - "Encdec Multihead Attention Forward.", py::call_guard()); - m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::cublas_gemmex::bwd, - "Encdec Multihead Attention Backward.", py::call_guard()); - m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.", - py::call_guard()); - m.def("encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.", - py::call_guard()); - m.def("self_attn_forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.", - py::call_guard()); - m.def("self_attn_backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.", - py::call_guard()); - m.def("self_attn_bias_forward", &multihead_attn::self_bias::cublas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward.", py::call_guard()); - m.def("self_attn_bias_backward", &multihead_attn::self_bias::cublas_gemmex::bwd, - "Self Multihead Attention with Bias -- Backward.", py::call_guard()); - m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward.", py::call_guard()); - m.def("self_attn_bias_additive_mask_backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, - "Self Multihead Attention with Bias -- Backward.", py::call_guard()); - m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Forward.", py::call_guard()); - m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Backward.", - py::call_guard()); +namespace { +int as_int(int64_t value) { return static_cast(value); } + +float as_float(double value) { return static_cast(value); } + +std::vector apex_fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, + int64_t heads, + at::Tensor const& input, + at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd(use_mask, is_training, as_int(heads), input, + pad_mask, as_float(dropout_prob)); +} + +at::Tensor apex_fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int64_t heads, + at::Tensor const& output_grads, + at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd( + use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, + int64_t heads, at::Tensor const& input, + at::Tensor const& pad_mask, + double dropout_prob) { + return multihead_attn::fused_softmax::mask_softmax_dropout::fwd(use_mask, is_training, as_int(heads), input, pad_mask, + as_float(dropout_prob)); +} + +at::Tensor apex_fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int64_t heads, + at::Tensor const& output_grads, + at::Tensor const& softmax_results, + at::Tensor const& dropout_mask, + at::Tensor const& padding_mask, double dropout_prob) { + return multihead_attn::fused_softmax::mask_softmax_dropout::bwd( + use_mask, as_int(heads), output_grads, softmax_results, dropout_mask, padding_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::encdec::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs_q, + inputs_kv, input_weights_q, input_weights_kv, output_weights, + pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& input_weights_q, + at::Tensor const& input_weights_kv, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::encdec::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_q_results, input_lin_kv_results, + inputs_q, inputs_kv, input_weights_q, input_weights_kv, + output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs_q, + at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, at::Tensor const& output_weights, + at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::encdec_norm_add::cublas_gemmex::fwd( + use_mask, use_time_mask, is_training, as_int(heads), inputs_q, inputs_kv, lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_encdec_multihead_attn_norm_add_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_q_results, at::Tensor const& input_lin_kv_results, + at::Tensor const& lyr_nrm_results, at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs_q, at::Tensor const& inputs_kv, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights_q, at::Tensor const& input_weights_kv, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + double dropout_prob) { + return multihead_attn::encdec_norm_add::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_q_results, + input_lin_kv_results, lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, inputs_kv, lyr_nrm_gamma_weights, + lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, dropout_mask, dropout_add_mask, + as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, + int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, + at::Tensor const& output_weights, + at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + input_weights, output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::self::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_results, inputs, input_weights, + output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_bias::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + input_weights, output_weights, input_biases, output_biases, + pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& dropout_mask, + double dropout_prob) { + return multihead_attn::self_bias::cublas_gemmex::bwd(as_int(heads), output_grads, matmul2_results, dropout_results, + softmax_results, input_lin_results, inputs, input_weights, + output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, at::Tensor const& input_biases, + at::Tensor const& output_biases, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd( + use_mask, use_time_mask, is_training, as_int(heads), inputs, input_weights, output_weights, input_biases, + output_biases, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_bias_additive_mask_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& bmm1_results, at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, double dropout_prob) { + return multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, bmm1_results, pad_mask, input_lin_results, inputs, + input_weights, output_weights, dropout_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_norm_add_forward( + bool use_mask, bool use_time_mask, bool is_training, int64_t heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& pad_mask, double dropout_prob) { + return multihead_attn::self_norm_add::cublas_gemmex::fwd(use_mask, use_time_mask, is_training, as_int(heads), inputs, + lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, + output_weights, pad_mask, as_float(dropout_prob)); +} + +std::vector apex_fast_multihead_attn_self_attn_norm_add_backward( + int64_t heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, at::Tensor const& dropout_results, + at::Tensor const& softmax_results, at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, at::Tensor const& dropout_add_mask, + double dropout_prob) { + return multihead_attn::self_norm_add::cublas_gemmex::bwd( + as_int(heads), output_grads, matmul2_results, dropout_results, softmax_results, input_lin_results, + lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, + output_weights, dropout_mask, dropout_add_mask, as_float(dropout_prob)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "fast_multihead_attn_additive_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, " + "Tensor input, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_additive_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, float dropout_prob) -> Tensor"); + m.def( + "fast_multihead_attn_mask_softmax_dropout_forward(bool use_mask, bool is_training, int heads, Tensor input, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_mask_softmax_dropout_backward(bool use_mask, int heads, Tensor output_grads, " + "Tensor softmax_results, Tensor dropout_mask, Tensor padding_mask, float dropout_prob) -> Tensor"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, " + "Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, Tensor input_lin_kv_results, " + "Tensor inputs_q, Tensor inputs_kv, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_norm_add_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_encdec_multihead_attn_norm_add_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor softmax_results, Tensor input_lin_q_results, " + "Tensor input_lin_kv_results, Tensor lyr_nrm_results, Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, " + "Tensor inputs_q, Tensor inputs_kv, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights_q, Tensor input_weights_kv, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) " + "-> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_forward(bool use_mask, bool use_time_mask, bool is_training, int heads, " + "Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor input_biases, Tensor output_biases, " + "Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor inputs, " + "Tensor input_weights, Tensor output_weights, Tensor dropout_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_additive_mask_forward(bool use_mask, bool use_time_mask, " + "bool is_training, int heads, Tensor inputs, Tensor input_weights, Tensor output_weights, " + "Tensor input_biases, Tensor output_biases, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_bias_additive_mask_backward(int heads, Tensor output_grads, " + "Tensor matmul2_results, Tensor dropout_results, Tensor bmm1_results, Tensor pad_mask, " + "Tensor input_lin_results, Tensor inputs, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_norm_add_forward(bool use_mask, bool use_time_mask, bool is_training, " + "int heads, Tensor inputs, Tensor lyr_nrm_gamma_weights, Tensor lyr_nrm_beta_weights, " + "Tensor input_weights, Tensor output_weights, Tensor pad_mask, float dropout_prob) -> Tensor[]"); + m.def( + "fast_multihead_attn_self_attn_norm_add_backward(int heads, Tensor output_grads, Tensor matmul2_results, " + "Tensor dropout_results, Tensor softmax_results, Tensor input_lin_results, Tensor lyr_nrm_results, " + "Tensor lyr_nrm_mean, Tensor lyr_nrm_invvar, Tensor inputs, Tensor lyr_nrm_gamma_weights, " + "Tensor lyr_nrm_beta_weights, Tensor input_weights, Tensor output_weights, Tensor dropout_mask, " + "Tensor dropout_add_mask, float dropout_prob) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fast_multihead_attn_additive_mask_softmax_dropout_forward", + &apex_fast_multihead_attn_additive_mask_softmax_dropout_forward); + m.impl("fast_multihead_attn_additive_mask_softmax_dropout_backward", + &apex_fast_multihead_attn_additive_mask_softmax_dropout_backward); + m.impl("fast_multihead_attn_mask_softmax_dropout_forward", &apex_fast_multihead_attn_mask_softmax_dropout_forward); + m.impl("fast_multihead_attn_mask_softmax_dropout_backward", &apex_fast_multihead_attn_mask_softmax_dropout_backward); + m.impl("fast_multihead_attn_encdec_multihead_attn_forward", &apex_fast_multihead_attn_encdec_multihead_attn_forward); + m.impl("fast_multihead_attn_encdec_multihead_attn_backward", + &apex_fast_multihead_attn_encdec_multihead_attn_backward); + m.impl("fast_multihead_attn_encdec_multihead_attn_norm_add_forward", + &apex_fast_multihead_attn_encdec_multihead_attn_norm_add_forward); + m.impl("fast_multihead_attn_encdec_multihead_attn_norm_add_backward", + &apex_fast_multihead_attn_encdec_multihead_attn_norm_add_backward); + m.impl("fast_multihead_attn_self_attn_forward", &apex_fast_multihead_attn_self_attn_forward); + m.impl("fast_multihead_attn_self_attn_backward", &apex_fast_multihead_attn_self_attn_backward); + m.impl("fast_multihead_attn_self_attn_bias_forward", &apex_fast_multihead_attn_self_attn_bias_forward); + m.impl("fast_multihead_attn_self_attn_bias_backward", &apex_fast_multihead_attn_self_attn_bias_backward); + m.impl("fast_multihead_attn_self_attn_bias_additive_mask_forward", + &apex_fast_multihead_attn_self_attn_bias_additive_mask_forward); + m.impl("fast_multihead_attn_self_attn_bias_additive_mask_backward", + &apex_fast_multihead_attn_self_attn_bias_additive_mask_backward); + m.impl("fast_multihead_attn_self_attn_norm_add_forward", &apex_fast_multihead_attn_self_attn_norm_add_forward); + m.impl("fast_multihead_attn_self_attn_norm_add_backward", &apex_fast_multihead_attn_self_attn_norm_add_backward); } #undef CHECK_CUDA diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu index 7be40eba5..e6c229f64 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -18,10 +20,10 @@ namespace multihead_attn { namespace self_bias_additive_mask { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, - const half* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, const half* pad_mask, + float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -47,14 +49,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor bmm1_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -121,12 +123,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, bmm1_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& bmm1_results, - torch::Tensor const& pad_mask, torch::Tensor const& input_lin_results, - torch::Tensor const& inputs, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& bmm1_results, + at::Tensor const& pad_mask, at::Tensor const& input_lin_results, + at::Tensor const& inputs, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -149,13 +150,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu index 8e39a28d6..67f49b5ab 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -18,10 +20,10 @@ namespace multihead_attn { namespace self_bias { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& input_biases, torch::Tensor const& output_biases, - const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& input_biases, at::Tensor const& output_biases, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -47,14 +49,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -132,11 +134,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -159,13 +161,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu index 929e7e791..c8720d0d2 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -18,9 +20,9 @@ namespace multihead_attn { namespace self { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -45,14 +47,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); + auto mask_options = act_options.dtype(at::kByte); - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -125,11 +127,11 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he return {input_lin_results, softmax_results, dropout_results, dropout_mask, matmul2_results, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& inputs, - torch::Tensor const& input_weights, torch::Tensor const& output_weights, - torch::Tensor const& dropout_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& inputs, + at::Tensor const& input_weights, at::Tensor const& output_weights, + at::Tensor const& dropout_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -152,13 +154,13 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu index fa481f8c9..ae0b6f04b 100644 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu @@ -2,10 +2,12 @@ #include #include #include +#if __has_include() #include +#endif +#include #include #include -#include #include #include @@ -19,10 +21,10 @@ namespace multihead_attn { namespace self_norm_add { namespace cublas_gemmex { -std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, torch::Tensor const& inputs, - torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, const uint8_t* pad_mask, float dropout_prob) { +std::vector fwd_cuda(bool use_time_mask, bool is_training, int heads, at::Tensor const& inputs, + at::Tensor const& lyr_nrm_gamma_weights, at::Tensor const& lyr_nrm_beta_weights, + at::Tensor const& input_weights, at::Tensor const& output_weights, + const uint8_t* pad_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -48,21 +50,21 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he // 3 Intermediate Results + Output (Note: dropout intermediates are generated // by ATen library code) auto act_options = inputs.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options); - - torch::Tensor input_lin_results = torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); + auto lyr_nrm_options = act_options.dtype(at::kFloat); + auto mask_options = act_options.dtype(at::kByte); + + at::Tensor lyr_nrm_mean = at::empty({batches}, lyr_nrm_options); + at::Tensor lyr_nrm_invvar = at::empty({batches}, lyr_nrm_options); + at::Tensor lyr_nrm_results = at::empty_like(inputs, act_options); + + at::Tensor input_lin_results = at::empty({q_seq_len, sequences, output_lin_dim}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_results = at::empty({attn_batches, q_seq_len, k_seq_len}, act_options); + at::Tensor dropout_mask = at::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); + at::Tensor matmul2_results = at::empty({q_seq_len, attn_batches, head_dim}, act_options); + at::Tensor output_lin_results = at::empty_like(inputs, act_options); + at::Tensor dropout_add_mask = at::empty_like(inputs, mask_options); + at::Tensor outputs = at::empty_like(inputs, act_options); // Input Linear Results Pointers to Q, K, and V of interviewed activations void* q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); @@ -160,14 +162,14 @@ std::vector fwd_cuda(bool use_time_mask, bool is_training, int he dropout_results, dropout_mask, matmul2_results, dropout_add_mask, outputs}; } -std::vector bwd_cuda(int heads, torch::Tensor const& output_grads, torch::Tensor const& matmul2_results, - torch::Tensor const& dropout_results, torch::Tensor const& softmax_results, - torch::Tensor const& input_lin_results, torch::Tensor const& lyr_nrm_results, - torch::Tensor const& lyr_nrm_mean, torch::Tensor const& lyr_nrm_invvar, - torch::Tensor const& inputs, torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, torch::Tensor const& input_weights, - torch::Tensor const& output_weights, torch::Tensor const& dropout_mask, - torch::Tensor const& dropout_add_mask, float dropout_prob) { +std::vector bwd_cuda(int heads, at::Tensor const& output_grads, at::Tensor const& matmul2_results, + at::Tensor const& dropout_results, at::Tensor const& softmax_results, + at::Tensor const& input_lin_results, at::Tensor const& lyr_nrm_results, + at::Tensor const& lyr_nrm_mean, at::Tensor const& lyr_nrm_invvar, + at::Tensor const& inputs, at::Tensor const& lyr_nrm_gamma_weights, + at::Tensor const& lyr_nrm_beta_weights, at::Tensor const& input_weights, + at::Tensor const& output_weights, at::Tensor const& dropout_mask, + at::Tensor const& dropout_add_mask, float dropout_prob) { const int embed_dim = inputs.size(2); const int sequences = inputs.size(1); const int q_seq_len = inputs.size(0); @@ -191,17 +193,17 @@ std::vector bwd_cuda(int heads, torch::Tensor const& output_grads cublasSetStream(handle, stream); // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); + at::Tensor input_grads = at::empty_like(inputs); + at::Tensor lyr_nrm_gamma_grads = at::empty_like(lyr_nrm_gamma_weights); + at::Tensor lyr_nrm_beta_grads = at::empty_like(lyr_nrm_beta_weights); + at::Tensor input_weight_grads = at::empty_like(input_weights); + at::Tensor output_weight_grads = at::empty_like(output_weights); // Intermediate Tensor Allocations - torch::Tensor dropout_add_grads = torch::empty_like(output_grads); - torch::Tensor output_lin_grads = torch::empty_like(matmul2_results); - torch::Tensor matmul2_grads = torch::empty_like(dropout_results); - torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - torch::Tensor input_lin_grads = torch::empty_like(inputs); + at::Tensor dropout_add_grads = at::empty_like(output_grads); + at::Tensor output_lin_grads = at::empty_like(matmul2_results); + at::Tensor matmul2_grads = at::empty_like(dropout_results); + at::Tensor input_lin_output_grads = at::empty_like(input_lin_results); + at::Tensor input_lin_grads = at::empty_like(inputs); auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); auto k_lin_results_ptr = static_cast(input_lin_results.data_ptr()) + head_dim; diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh index 8537e81e7..b9184c064 100644 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh @@ -1,7 +1,9 @@ #pragma once #include #include +#if __has_include() #include +#endif #include #include diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp index 3db57c4e9..6080b93e8 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp @@ -14,16 +14,59 @@ * limitations under the License. */ +#include + #include "nccl_p2p_cuda.cuh" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id", - py::call_guard()); - m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm", - py::call_guard()); - m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, - "left_right_halo_exchange_inplace", py::call_guard()); - m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange", - py::call_guard()); - m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay", py::call_guard()); +namespace { +at::Tensor apex_nccl_p2p_get_unique_nccl_id(int64_t n) { + return apex::contrib::nccl_p2p::get_unique_nccl_id(static_cast(n)); +} + +int64_t apex_nccl_p2p_init_nccl_comm(at::Tensor unique_nccl_id, int64_t my_rank, int64_t num_ranks) { + return apex::contrib::nccl_p2p::init_nccl_comm(unique_nccl_id, static_cast(my_rank), + static_cast(num_ranks)); +} + +void apex_nccl_p2p_left_right_halo_exchange_inplace(int64_t handle, int64_t left_rank, int64_t right_rank, + at::Tensor left_output_halo, at::Tensor right_output_halo, + at::Tensor left_input_halo, at::Tensor right_input_halo) { + apex::contrib::nccl_p2p::left_right_halo_exchange_inplace(static_cast(handle), static_cast(left_rank), + static_cast(right_rank), left_output_halo, + right_output_halo, left_input_halo, right_input_halo); +} + +std::vector apex_nccl_p2p_left_right_halo_exchange(int64_t handle, int64_t left_rank, int64_t right_rank, + at::Tensor left_output_halo, + at::Tensor right_output_halo) { + return apex::contrib::nccl_p2p::left_right_halo_exchange(static_cast(handle), static_cast(left_rank), + static_cast(right_rank), left_output_halo, + right_output_halo); +} + +void apex_nccl_p2p_add_delay(int64_t delay) { apex::contrib::nccl_p2p::add_delay(static_cast(delay)); } +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("nccl_p2p_get_unique_nccl_id(int n) -> Tensor"); + m.def("nccl_p2p_init_nccl_comm(Tensor unique_nccl_id, int my_rank, int num_ranks) -> int"); + m.def( + "nccl_p2p_left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, " + "Tensor left_output_halo, Tensor right_output_halo, Tensor(a!) left_input_halo, " + "Tensor(b!) right_input_halo) -> ()"); + m.def( + "nccl_p2p_left_right_halo_exchange(int handle, int left_rank, int right_rank, Tensor left_output_halo, " + "Tensor right_output_halo) -> Tensor[]"); + m.def("nccl_p2p_add_delay(int delay) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("nccl_p2p_get_unique_nccl_id", &apex_nccl_p2p_get_unique_nccl_id); + m.impl("nccl_p2p_init_nccl_comm", &apex_nccl_p2p_init_nccl_comm); + m.impl("nccl_p2p_add_delay", &apex_nccl_p2p_add_delay); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("nccl_p2p_left_right_halo_exchange_inplace", &apex_nccl_p2p_left_right_halo_exchange_inplace); + m.impl("nccl_p2p_left_right_halo_exchange", &apex_nccl_p2p_left_right_halo_exchange); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu index 21091bd1d..01a3a0c43 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu @@ -1,6 +1,6 @@ +#include #include #include -#include #include #include @@ -84,8 +84,8 @@ class NcclCommWrapper { ncclDataType_t ncclType = get_nccl_type(left_output_halo); bool left_zero = (left_rank < 0); bool right_zero = (right_rank < 0); - size_t left_n = torch::numel(left_output_halo); - size_t right_n = torch::numel(right_output_halo); + size_t left_n = left_output_halo.numel(); + size_t right_n = right_output_halo.numel(); assert(left_n > 0 && left_n == right_n); if (left_zero) { left_input_halo.zero_(); @@ -119,8 +119,8 @@ class NcclCommWrapper { // after halo exchange: // left_output_halo of rank+1 ends up in right_input_halo of rank // right_output_halo of rank-1 ends up in left_input_halo of rank - auto right_input_halo = torch::empty_like(left_output_halo); - auto left_input_halo = torch::empty_like(right_output_halo); + auto right_input_halo = at::empty_like(left_output_halo); + auto left_input_halo = at::empty_like(right_output_halo); left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); return {left_input_halo, right_input_halo}; @@ -161,8 +161,8 @@ namespace nccl_p2p { at::Tensor get_unique_nccl_id(int n) { ncclUniqueId id; ncclGetUniqueId(&id); - auto id_tensor = torch::empty({n, (int)sizeof(ncclUniqueId)}, - torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false)); + auto id_tensor = at::empty({n, (int)sizeof(ncclUniqueId)}, + at::TensorOptions().dtype(at::kByte).device(at::kCPU).requires_grad(false)); auto id_ptr = id_tensor.data_ptr(); size_t offset = 0; for (int i = 0; i < n; ++i) { @@ -200,7 +200,7 @@ std::vector left_right_halo_exchange(int handle, int left_rank, int void add_delay(int delay) { auto stream = at::cuda::getCurrentCUDAStream(); - auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto t = at::empty({1}, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)); AddDelay_kernel<<<1, 1, 0, stream>>>(delay, t.data_ptr()); } diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh index a047bedb6..dd6b4bd43 100644 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh @@ -15,7 +15,9 @@ */ #pragma once -#include +#include + +#include #ifndef _nccl_p2p_h_ #define _nccl_p2p_h_ diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp index 9d69ba01b..baee1710c 100644 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp @@ -1,4 +1,5 @@ -#include +#include +#include // CUDA forward declaration void fused_strided_check_finite(at::Tensor& overflow_flag, at::Tensor& p_copy, int stride, int clear_overflow_first); @@ -88,18 +89,107 @@ void maybe_cast(at::Tensor& overflow_flag, at::Tensor& p_in, at::Tensor& p_out) maybe_cast_cuda(overflow_flag, p_in, p_out); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("strided_check_finite", &strided_check_finite, "Strided finite check.", - py::call_guard()); - m.def("adam", &adam, "Adam optimized CUDA implementation.", py::call_guard()); - m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.", - py::call_guard()); - m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.", - py::call_guard()); - m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.", - py::call_guard()); - m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.", - py::call_guard()); - m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.", - py::call_guard()); +void strided_check_finite_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p_copy, int64_t stride, + int64_t clear_overflow_first) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_copy_arg = p_copy; + strided_check_finite(overflow_flag_arg, p_copy_arg, static_cast(stride), static_cast(clear_overflow_first)); +} + +void adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tensor& m, const at::Tensor& v, + const at::Tensor& g, double lr, double beta1, double beta2, double eps, double grad_scale, + int64_t step, int64_t mode, int64_t bias_correction, double decay) { + at::Tensor p_arg = p; + at::Tensor p_copy_arg = p_copy; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor g_arg = g; + adam(p_arg, p_copy_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), static_cast(step), + static_cast(mode), static_cast(bias_correction), static_cast(decay)); +} + +void reversible_adam_dispatch(const at::Tensor& p, const at::Tensor& p_copy, const at::Tensor& m, const at::Tensor& v, + const at::Tensor& g, double lr, double beta1, double beta2, double eps, double grad_scale, + int64_t step, int64_t mode, int64_t bias_correction, double decay) { + at::Tensor p_arg = p; + at::Tensor p_copy_arg = p_copy; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor g_arg = g; + reversible_adam(p_arg, p_copy_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(decay)); +} + +void fused_adam_cuda_mt_dispatch(int64_t chunk_size, at::Tensor overflow_flag, + std::vector> tensor_lists, double lr, double beta1, + double beta2, double eps, double grad_scale, int64_t step, int64_t mode, + int64_t bias_correction, double decay) { + fused_adam_cuda_mt(static_cast(chunk_size), overflow_flag, tensor_lists, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(eps), + static_cast(grad_scale), static_cast(step), static_cast(mode), + static_cast(bias_correction), static_cast(decay)); +} + +void maybe_adam_undo_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p, const at::Tensor& m, + const at::Tensor& v, const at::Tensor& g, double lr, double beta1, double beta2, + double eps, double grad_scale, int64_t step, int64_t mode, int64_t bias_correction, + double decay) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_arg = p; + at::Tensor m_arg = m; + at::Tensor v_arg = v; + at::Tensor g_arg = g; + maybe_adam_undo(overflow_flag_arg, p_arg, m_arg, v_arg, g_arg, static_cast(lr), static_cast(beta1), + static_cast(beta2), static_cast(eps), static_cast(grad_scale), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(decay)); +} + +void maybe_cast_dispatch(const at::Tensor& overflow_flag, const at::Tensor& p_in, const at::Tensor& p_out) { + at::Tensor overflow_flag_arg = overflow_flag; + at::Tensor p_in_arg = p_in; + at::Tensor p_out_arg = p_out; + maybe_cast(overflow_flag_arg, p_in_arg, p_out_arg); +} + +void maybe_cast_cuda_mt_dispatch(int64_t chunk_size, at::Tensor overflow_flag, + std::vector> tensor_lists) { + maybe_cast_cuda_mt(static_cast(chunk_size), overflow_flag, tensor_lists); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "fused_adam_strided_check_finite(Tensor(a!) overflow_flag, Tensor p_copy, int stride, " + "int clear_overflow_first) -> ()"); + m.def( + "fused_adam_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def( + "fused_adam_reversible_adam(Tensor(a!) p, Tensor(b!) p_copy, Tensor(c!) m, Tensor(d!) v, Tensor g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def( + "fused_adam_adam_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists, float lr, " + "float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, " + "float decay) -> ()"); + m.def( + "fused_adam_maybe_adam_undo(Tensor overflow_flag, Tensor(a!) p, Tensor(b!) m, Tensor(c!) v, Tensor(d!) g, " + "float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, " + "int bias_correction, float decay) -> ()"); + m.def("fused_adam_maybe_cast(Tensor overflow_flag, Tensor p_in, Tensor(a!) p_out) -> ()"); + m.def("fused_adam_maybe_cast_mt(int chunk_size, Tensor overflow_flag, Tensor[][] tensor_lists) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_adam_strided_check_finite", &strided_check_finite_dispatch); + m.impl("fused_adam_adam", &adam_dispatch); + m.impl("fused_adam_reversible_adam", &reversible_adam_dispatch); + m.impl("fused_adam_adam_mt", &fused_adam_cuda_mt_dispatch); + m.impl("fused_adam_maybe_adam_undo", &maybe_adam_undo_dispatch); + m.impl("fused_adam_maybe_cast", &maybe_cast_dispatch); + m.impl("fused_adam_maybe_cast_mt", &maybe_cast_cuda_mt_dispatch); } diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp index b0aa92082..200258e0c 100644 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp @@ -1,11 +1,27 @@ -#include +#include +#include void multi_tensor_lamb_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int bias_correction, const float weight_decay, const int grad_averaging, const int mode, const float global_grad_norm, const float max_grad_norm); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); +void multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, double global_grad_norm, + double max_grad_norm) { + multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, tensor_lists, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), static_cast(mode), static_cast(global_grad_norm), + static_cast(max_grad_norm)); } + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "fused_lamb_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, float global_grad_norm, float max_grad_norm) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("fused_lamb_lamb", &multi_tensor_lamb); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp index 29592a6af..6e88ea604 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, at::Tensor grad_scale, float lr, @@ -16,17 +16,55 @@ void multi_tensor_fused_adam_with_param_remainders_cuda(int chunk_size, at::Tens float eps, int step, int mode, int bias_correction, float weight_decay); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, - "CUDA kernels for multi-tensor Adam, " - "with param copy", - py::call_guard()); - m.def("multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable_cuda, - "CUDA kernels for multi-tensor Adam, " - "with param copy, capturable for CUDA graph", - py::call_guard()); - m.def("multi_tensor_fused_adam_with_param_remainders", &multi_tensor_fused_adam_with_param_remainders_cuda, - "CUDA kernel for multi-tensor Adam, " - "with stored param remainders and param copy", - py::call_guard()); +void multi_tensor_fused_adam(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, double lr, + double beta1, double beta2, double eps, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { + multi_tensor_fused_adam_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, + static_cast(lr), static_cast(beta1), static_cast(beta2), + static_cast(eps), static_cast(step), static_cast(mode), + static_cast(bias_correction), static_cast(weight_decay)); +} + +void multi_tensor_fused_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_scale, + at::Tensor lr, double beta1, double beta2, double eps, at::Tensor step, + int64_t mode, int64_t bias_correction, double weight_decay) { + multi_tensor_fused_adam_capturable_cuda(static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, lr, + static_cast(beta1), static_cast(beta2), static_cast(eps), + step, static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay)); +} + +void multi_tensor_fused_adam_with_param_remainders(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor grad_scale, double lr, double beta1, double beta2, + double eps, int64_t step, int64_t mode, int64_t bias_correction, + double weight_decay) { + multi_tensor_fused_adam_with_param_remainders_cuda( + static_cast(chunk_size), noop_flag, tensor_lists, grad_scale, static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(eps), static_cast(step), + static_cast(mode), static_cast(bias_correction), static_cast(weight_decay)); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "distributed_adam_multi_tensor_fused_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor grad_scale, float lr, float beta1, float beta2, float eps, int step, int mode, " + "int bias_correction, float weight_decay) -> ()"); + m.def( + "distributed_adam_multi_tensor_fused_adam_capturable(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, Tensor lr, float beta1, float beta2, float eps, " + "Tensor step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "distributed_adam_multi_tensor_fused_adam_with_param_remainders(int chunk_size, Tensor noop_flag, " + "Tensor[][] tensor_lists, Tensor grad_scale, float lr, float beta1, float beta2, float eps, " + "int step, int mode, int bias_correction, float weight_decay) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("distributed_adam_multi_tensor_fused_adam", &multi_tensor_fused_adam); + m.impl("distributed_adam_multi_tensor_fused_adam_capturable", &multi_tensor_fused_adam_capturable); + m.impl("distributed_adam_multi_tensor_fused_adam_with_param_remainders", + &multi_tensor_fused_adam_with_param_remainders); } diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp index f74bebfc2..63aa147ff 100644 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_lamb_compute_update_term_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, @@ -14,9 +14,42 @@ void multi_tensor_lamb_update_weights_cuda(int chunk_size, at::Tensor noop_flag, at::Tensor update_norm_offset, at::Tensor learning_rate, at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, - "Computes update term for LAMB optimizer", py::call_guard()); - m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, - "Applies update term for LAMB optimizer", py::call_guard()); +void multi_tensor_lamb_compute_update_term(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_beta1, at::Tensor per_tensor_beta2, + at::Tensor per_tensor_beta3, at::Tensor per_tensor_bias_correction, + at::Tensor step, at::Tensor per_tensor_epsilon, int64_t mode, + at::Tensor per_tensor_decay, at::Tensor global_scale, + at::Tensor global_grad_norm, double max_grad_norm) { + multi_tensor_lamb_compute_update_term_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_beta1, + per_tensor_beta2, per_tensor_beta3, per_tensor_bias_correction, step, + per_tensor_epsilon, static_cast(mode), per_tensor_decay, global_scale, + global_grad_norm, static_cast(max_grad_norm)); +} + +void multi_tensor_lamb_update_weights(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, + at::Tensor update_norm_offset, at::Tensor learning_rate, + at::Tensor per_tensor_decay, at::Tensor global_grad_norm, bool use_nvlamb) { + multi_tensor_lamb_update_weights_cuda(static_cast(chunk_size), noop_flag, tensor_lists, per_tensor_param_norm, + per_tensor_update_norm, update_norm_offset, learning_rate, per_tensor_decay, + global_grad_norm, use_nvlamb); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "distributed_lamb_compute_update_term(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_beta1, Tensor per_tensor_beta2, Tensor per_tensor_beta3, " + "Tensor per_tensor_bias_correction, Tensor step, Tensor per_tensor_epsilon, int mode, " + "Tensor per_tensor_decay, Tensor global_scale, Tensor global_grad_norm, float max_grad_norm) -> ()"); + m.def( + "distributed_lamb_update_weights(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, Tensor update_norm_offset, " + "Tensor learning_rate, Tensor per_tensor_decay, Tensor global_grad_norm, bool use_nvlamb) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("distributed_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term); + m.impl("distributed_lamb_update_weights", &multi_tensor_lamb_update_weights); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory.cpp b/apex/contrib/csrc/peer_memory/peer_memory.cpp index bc19e6206..f4ae8c80c 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory.cpp +++ b/apex/contrib/csrc/peer_memory/peer_memory.cpp @@ -14,23 +14,67 @@ * limitations under the License. */ +#include + #include "peer_memory_cuda.cuh" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw", - py::call_guard()); - m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw", py::call_guard()); - m.def("zero", &apex::contrib::peer_memory::zero, "zero", py::call_guard()); - m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address", - py::call_guard()); - m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers", - py::call_guard()); - m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half", - py::call_guard()); - m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float", - py::call_guard()); - m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int", - py::call_guard()); - m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d", - py::call_guard()); +namespace { +std::vector apex_peer_memory_get_raw_peers(at::Tensor ipc_addresses, int64_t peer_rank, int64_t raw) { + return apex::contrib::peer_memory::get_raw_peers(ipc_addresses, static_cast(peer_rank), raw); +} + +at::Tensor apex_peer_memory_blob_view_half(int64_t raw, at::IntArrayRef shape, bool channels_last) { + return apex::contrib::peer_memory::blob_view_half(raw, std::vector(shape.begin(), shape.end()), + channels_last); +} + +at::Tensor apex_peer_memory_blob_view_float(int64_t raw, at::IntArrayRef shape, bool channels_last) { + return apex::contrib::peer_memory::blob_view_float(raw, std::vector(shape.begin(), shape.end()), + channels_last); +} + +at::Tensor apex_peer_memory_blob_view_int(int64_t raw, at::IntArrayRef shape, bool channels_last) { + return apex::contrib::peer_memory::blob_view_int(raw, std::vector(shape.begin(), shape.end()), + channels_last); } + +void apex_peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int64_t numSM, int64_t rank, + bool top_zero, at::Tensor top_in_halo, at::Tensor top_in_transfer, + at::Tensor top_out_transfer, at::Tensor top_out_halo, bool btm_zero, + at::Tensor btm_in_halo, at::Tensor btm_in_transfer, + at::Tensor btm_out_transfer, at::Tensor btm_out_halo) { + apex::contrib::peer_memory::push_pull_halos_1d(diagnostics, explicit_nhwc, static_cast(numSM), + static_cast(rank), top_zero, top_in_halo, top_in_transfer, + top_out_transfer, top_out_halo, btm_zero, btm_in_halo, btm_in_transfer, + btm_out_transfer, btm_out_halo); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("peer_memory_allocate_raw(int size) -> int"); + m.def("peer_memory_free_raw(int raw) -> ()"); + m.def("peer_memory_zero(int raw, int size) -> ()"); + m.def("peer_memory_get_raw_ipc_address(int raw) -> Tensor"); + m.def("peer_memory_get_raw_peers(Tensor ipc_addresses, int peer_rank, int raw) -> int[]"); + m.def("peer_memory_blob_view_half(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def("peer_memory_blob_view_float(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def("peer_memory_blob_view_int(int raw, int[] shape, bool channels_last) -> Tensor"); + m.def( + "peer_memory_push_pull_halos_1d(bool diagnostics, bool explicit_nhwc, int numSM, int rank, bool top_zero, " + "Tensor(a!) top_in_halo, Tensor(b!) top_in_transfer, Tensor(c!) top_out_transfer, Tensor(d!) top_out_halo, " + "bool btm_zero, Tensor(e!) btm_in_halo, Tensor(f!) btm_in_transfer, Tensor(g!) btm_out_transfer, " + "Tensor(h!) btm_out_halo) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("peer_memory_allocate_raw", &apex::contrib::peer_memory::allocate_raw); + m.impl("peer_memory_free_raw", &apex::contrib::peer_memory::free_raw); + m.impl("peer_memory_zero", &apex::contrib::peer_memory::zero); + m.impl("peer_memory_get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address); + m.impl("peer_memory_get_raw_peers", &apex_peer_memory_get_raw_peers); + m.impl("peer_memory_blob_view_half", &apex_peer_memory_blob_view_half); + m.impl("peer_memory_blob_view_float", &apex_peer_memory_blob_view_float); + m.impl("peer_memory_blob_view_int", &apex_peer_memory_blob_view_int); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { m.impl("peer_memory_push_pull_halos_1d", &apex_peer_memory_push_pull_halos_1d); } diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu index 974ed5d64..3a43c860f 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu @@ -1,7 +1,8 @@ +#include #include +#include #include #include -#include #include #include @@ -51,7 +52,7 @@ at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOpt size *= sizeof(T); // TODO: Implement dynamic reuse of pooled peer memory. // We provide no deleter function because all peer memory allocations are static in this implementation. - return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); + return at::from_blob((void*)raw_ptr, shape, strides, options); } void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) { @@ -330,7 +331,7 @@ at::Tensor get_raw_ipc_address(int64_t raw) { cudaIpcMemHandle_t mem_handle; CUDACHECK(cudaIpcGetMemHandle(&mem_handle, (void*)raw)); const int n = sizeof(cudaIpcMemHandle_t); - auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); + auto address_tensor = at::empty({n}, at::TensorOptions().dtype(at::kByte)); auto address_tensor_p = address_tensor.data_ptr(); memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); return address_tensor; @@ -354,15 +355,16 @@ std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int6 } at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); + return blob_view((at::Half*)raw, shape, at::TensorOptions().dtype(at::kHalf).device(at::kCUDA), + channels_last); } at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); + return blob_view((float*)raw, shape, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA), channels_last); } at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) { - return blob_view((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); + return blob_view((int*)raw, shape, at::TensorOptions().dtype(at::kInt).device(at::kCUDA), channels_last); } void push_pull_halos_1d( diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh index 29cf7a108..4357ebaa9 100644 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh @@ -15,7 +15,9 @@ */ #pragma once -#include +#include + +#include #ifndef _peer_memory_h_ #define _peer_memory_h_ diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp index 1175c1676..4deedce0d 100644 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ b/apex/contrib/csrc/transducer/transducer_joint.cpp @@ -1,49 +1,56 @@ -#include -#include - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize); - -std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale); - -std::vector transducer_joint_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int64_t packedBatch, - int opt, bool packOutput, bool relu, bool dropout, - float dropoutProb, int tileSize) { - CHECK_INPUT(f); - CHECK_INPUT(g); - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) CHECK_INPUT(batchOffset); - return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, - dropoutProb, tileSize); -} - -std::vector transducer_joint_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { - for (auto t : in) { - CHECK_INPUT(t); - } - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) CHECK_INPUT(batchOffset); - return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)", - py::call_guard()); - m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)", - py::call_guard()); -} +#include +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize); + +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale); + +std::vector transducer_joint_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize) { + CHECK_INPUT(f); + CHECK_INPUT(g); + CHECK_INPUT(fLen); + CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); + return transducer_joint_cuda_forward(f, g, fLen, gLen, batchOffset, packedBatch, opt, packOutput, relu, dropout, + dropoutProb, tileSize); +} + +std::vector transducer_joint_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale) { + for (auto t : in) { + CHECK_INPUT(t); + } + CHECK_INPUT(fLen); + CHECK_INPUT(gLen); + if (packOutput) CHECK_INPUT(batchOffset); + return transducer_joint_cuda_backward(in, fLen, gLen, batchOffset, maxFLen, maxGLen, packOutput, scale); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "transducer_joint_forward(Tensor f, Tensor g, Tensor fLen, Tensor gLen, Tensor batchOffset, int packedBatch, " + "int opt, bool packOutput, bool relu, bool dropout, float dropoutProb, int tileSize) -> Tensor[]"); + m.def( + "transducer_joint_backward(Tensor[] input, Tensor fLen, Tensor gLen, Tensor batchOffset, int maxFLen, " + "int maxGLen, bool packOutput, float scale) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("transducer_joint_forward", &transducer_joint_forward); + m.impl("transducer_joint_backward", &transducer_joint_backward); +} diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu index acb9f9e9e..d8915abe3 100644 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu @@ -1,8 +1,8 @@ +#include #include #include #include #include -#include #ifdef OLD_GENERATOR_PATH #include @@ -179,7 +179,7 @@ __global__ void transducer_joint_forward(const scalar_t* f, const scalar_t* g, c } } else if (packOutput == false and t < maxFLen and u < maxGLen) { // Need to write finite data to don't-care region because we instantiate the result tensor -// with torch::empty for performance reasons. Even though it is don't-care region, the +// with at::empty for performance reasons. Even though it is don't-care region, the // contents need to be finite, otherwise could lead to NaN in WGRAD. // In packing mode, this write is no longer necessary as we remove the don't-care region // from the output. @@ -535,10 +535,10 @@ __global__ void transducer_joint_combined_vec_backward(const scalar_t* grad, con } } -std::vector transducer_joint_cuda_forward(torch::Tensor f, torch::Tensor g, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, - int64_t packedBatch, int opt, bool packOutput, bool relu, - bool dropout, float dropoutProb, int tileSize) { +std::vector transducer_joint_cuda_forward(at::Tensor f, at::Tensor g, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t packedBatch, int64_t opt, + bool packOutput, bool relu, bool dropout, double dropoutProb, + int64_t tileSize) { auto tensorOpt = f.options(); auto dtype = f.scalar_type(); const auto batchSize = f.size(0); @@ -548,16 +548,16 @@ std::vector transducer_joint_cuda_forward(torch::Tensor f, torch: bool masked = dropout or relu; int64_t* batchOffsetPtr = nullptr; - torch::Tensor sum, mask; - auto maskOpt = tensorOpt.dtype(torch::kUInt8); + at::Tensor sum, mask; + auto maskOpt = tensorOpt.dtype(at::kByte); if (!packOutput) { - sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); + sum = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); batchOffsetPtr = nullptr; - if (masked) mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); + if (masked) mask = at::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); } else { - sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); + sum = at::empty({packedBatch, hiddenSize}, tensorOpt); batchOffsetPtr = batchOffset.data_ptr(); - if (masked) mask = torch::empty({packedBatch, hiddenSize}, maskOpt); + if (masked) mask = at::empty({packedBatch, hiddenSize}, maskOpt); } uint8_t* maskPtr = masked ? mask.data_ptr() : nullptr; @@ -635,10 +635,10 @@ std::vector transducer_joint_cuda_forward(torch::Tensor f, torch: } } - kernel<<>>(f.data_ptr(), g.data_ptr(), fLen.data_ptr(), - gLen.data_ptr(), batchOffsetPtr, maxFLen, maxGLen, hiddenSize, - hiddenPerBlock, packOutput, relu, dropout, 1.0f - dropoutProb, - rng_engine_inputs, sum.data_ptr(), maskPtr); + kernel<<>>( + f.data_ptr(), g.data_ptr(), fLen.data_ptr(), gLen.data_ptr(), + batchOffsetPtr, maxFLen, maxGLen, hiddenSize, hiddenPerBlock, packOutput, relu, dropout, + static_cast(1.0 - dropoutProb), rng_engine_inputs, sum.data_ptr(), maskPtr); })); } @@ -649,9 +649,9 @@ std::vector transducer_joint_cuda_forward(torch::Tensor f, torch: return {sum}; } -std::vector transducer_joint_cuda_backward(std::vector in, torch::Tensor fLen, - torch::Tensor gLen, torch::Tensor batchOffset, int maxFLen, - int maxGLen, bool packOutput, float scale) { +std::vector transducer_joint_cuda_backward(std::vector in, at::Tensor fLen, at::Tensor gLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t maxGLen, + bool packOutput, double scale) { auto grad = in[0]; bool masked = (in.size() == 2); uint8_t* maskPtr = masked ? in[1].data_ptr() : nullptr; @@ -664,8 +664,8 @@ std::vector transducer_joint_cuda_backward(std::vectormaxThreadsPerBlock / C10_WARP_SIZE; - torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); - torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); + at::Tensor fGrad = at::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); + at::Tensor gGrad = at::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); int64_t* batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); @@ -703,6 +703,8 @@ std::vector transducer_joint_cuda_backward(std::vector(fGradPtr) % vecAlignment == 0) and (reinterpret_cast(gGradPtr) % vecAlignment == 0); + const float scale_arg = static_cast(scale); + if (vectFactor > 1 and hiddenSize % vectFactor == 0 and memAlign) { // If vectorization helps and the alignment requirement is met, use the vectorized // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. @@ -711,12 +713,12 @@ std::vector transducer_joint_cuda_backward(std::vector <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, fGradPtr, gGradPtr); } else { transducer_joint_combined_vec_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, fGradPtr, gGradPtr); } } else { @@ -724,12 +726,12 @@ std::vector transducer_joint_cuda_backward(std::vector <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, fGradPtr, gGradPtr); } else { transducer_joint_combined_backward <<>>(gradPtr, maskPtr, fLenPtr, gLenPtr, batchOffsetPtr, - maxFLen, maxGLen, hiddenSize, packOutput, scale, + maxFLen, maxGLen, hiddenSize, packOutput, scale_arg, fGradPtr, gGradPtr); } } diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp index 4c124edac..48f20ca03 100644 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ b/apex/contrib/csrc/transducer/transducer_loss.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include @@ -8,18 +9,18 @@ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) -std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, - torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput); +std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput); -torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, - torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, - torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, - int opt, bool fuseSoftmaxBackward, bool packedInput); +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor audLen, at::Tensor txtLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput); -std::vector transducer_loss_forward(torch::Tensor x, torch::Tensor label, torch::Tensor fLen, - torch::Tensor yLen, torch::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput) { +std::vector transducer_loss_forward(at::Tensor x, at::Tensor label, at::Tensor fLen, at::Tensor yLen, + at::Tensor batchOffset, int64_t maxFLen, int64_t blankIdx, int64_t opt, + bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(fLen); @@ -28,10 +29,10 @@ std::vector transducer_loss_forward(torch::Tensor x, torch::Tenso return transducer_loss_cuda_forward(x, label, fLen, yLen, batchOffset, maxFLen, blankIdx, opt, packedInput); } -torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, torch::Tensor beta, - torch::Tensor fLen, torch::Tensor yLen, torch::Tensor label, - torch::Tensor batchOffset, int maxFLen, int blankIdx, int opt, - bool fuseSoftmaxBackward, bool packedInput) { +at::Tensor transducer_loss_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor fLen, at::Tensor yLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput) { CHECK_INPUT(x); CHECK_INPUT(label); CHECK_INPUT(lossGrad); @@ -45,9 +46,17 @@ torch::Tensor transducer_loss_backward(torch::Tensor x, torch::Tensor lossGrad, fuseSoftmaxBackward, packedInput); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)", - py::call_guard()); - m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "transducer_loss_forward(Tensor x, Tensor label, Tensor fLen, Tensor yLen, Tensor batchOffset, int maxFLen, " + "int blankIdx, int opt, bool packedInput) -> Tensor[]"); + m.def( + "transducer_loss_backward(Tensor x, Tensor lossGrad, Tensor alpha, Tensor beta, Tensor fLen, Tensor yLen, " + "Tensor label, Tensor batchOffset, int maxFLen, int blankIdx, int opt, bool fuseSoftmaxBackward, " + "bool packedInput) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("transducer_loss_forward", &transducer_loss_forward); + m.impl("transducer_loss_backward", &transducer_loss_backward); } diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu index f6f7e4ca0..437dcbb2a 100644 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu @@ -3,7 +3,6 @@ #include #include #include -#include #include @@ -455,9 +454,9 @@ __global__ void transducer_loss_fused_vec_backward(const scalar_t* x, const scal } } -std::vector transducer_loss_cuda_forward(torch::Tensor x, torch::Tensor label, torch::Tensor audLen, - torch::Tensor txtLen, torch::Tensor batchOffset, int maxFLen, - int blankIdx, int opt, bool packedInput) { +std::vector transducer_loss_cuda_forward(at::Tensor x, at::Tensor label, at::Tensor audLen, + at::Tensor txtLen, at::Tensor batchOffset, int64_t maxFLen, + int64_t blankIdx, int64_t opt, bool packedInput) { auto scalarType = x.scalar_type(); auto tensorOpt = x.options(); const int batchSize = label.size(0); @@ -470,9 +469,9 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: // The data type of alpha and beta will be resolved at dispatch time, // hence defined here and assigned later - torch::Tensor alpha; - torch::Tensor beta; - torch::Tensor loss = torch::empty({batchSize}, tensorOpt); + at::Tensor alpha; + at::Tensor beta; + at::Tensor loss = at::empty({batchSize}, tensorOpt); const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock; @@ -485,8 +484,8 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: using acc_t = at::acc_type; auto accType = c10::CppTypeToScalarType::value; auto accTensorOpt = tensorOpt.dtype(accType); - alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); - beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); + alpha = at::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); + beta = at::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); // decide what kernel to launch based on the problem size // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla @@ -514,12 +513,12 @@ std::vector transducer_loss_cuda_forward(torch::Tensor x, torch:: return {alpha, beta, loss}; } -torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossGrad, torch::Tensor alpha, - torch::Tensor beta, torch::Tensor audLen, torch::Tensor txtLen, - torch::Tensor label, torch::Tensor batchOffset, int maxFLen, int blankIdx, - int opt, bool fuseSoftmaxBackward, bool packedInput) { +at::Tensor transducer_loss_cuda_backward(at::Tensor x, at::Tensor lossGrad, at::Tensor alpha, at::Tensor beta, + at::Tensor audLen, at::Tensor txtLen, at::Tensor label, at::Tensor batchOffset, + int64_t maxFLen, int64_t blankIdx, int64_t opt, bool fuseSoftmaxBackward, + bool packedInput) { auto dtype = x.scalar_type(); - torch::Tensor xGrad; + at::Tensor xGrad; const int batchSize = label.size(0); const int maxGLen = label.size(1) + 1; const int dictSize = x.size(-1); @@ -532,7 +531,7 @@ torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossG if (fuseSoftmaxBackward) { // alloc empty tensors for performance, hence need to ensure zeros are writtern to // don't-care region in the kernel. - xGrad = torch::empty_like(x); + xGrad = at::empty_like(x); // Would like each thread to work on 4 hidden units const int workPerThread = 4; @@ -566,7 +565,7 @@ torch::Tensor transducer_loss_cuda_backward(torch::Tensor x, torch::Tensor lossG } else { // for non-fused kernel, the gradients need to be writtern are very sparse, hence initialize // the tensor with all zeros. - xGrad = torch::zeros_like(x); + xGrad = at::zeros_like(x); // don't launch more threads than needed. const int threads = std::min(maxThreadPerBlock, maxGLen); const dim3 blocks(maxFLen, batchSize); diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp index 76cd379f5..748d85d0b 100644 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ b/apex/contrib/csrc/xentropy/interface.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include @@ -20,36 +21,43 @@ at::Tensor softmax_xentropy_backward_cuda(const at::Tensor& grad_loss, const at: CHECK_CONTIGUOUS(x) std::vector softmax_xentropy_forward(const at::Tensor& input, const at::Tensor& labels, - const float smoothing, const bool half_to_float) { + const double smoothing, const bool half_to_float) { CHECK_CUDA(input); CHECK_INPUT(labels); - return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); + return softmax_xentropy_cuda(input, labels, static_cast(smoothing), half_to_float); } at::Tensor softmax_xentropy_backward(const at::Tensor& grad_loss, const at::Tensor& logits, const at::Tensor& max_log_sum_exp, const at::Tensor& labels, - const float smoothing) { + const double smoothing) { CHECK_CUDA(grad_loss); CHECK_CUDA(logits); CHECK_INPUT(max_log_sum_exp); CHECK_INPUT(labels); - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); + return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, static_cast(smoothing)); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", - py::call_guard()); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", - py::call_guard()); - // ref: https://pybind11.readthedocs.io/en/stable/basics.html#exporting-variables - py::object version = py::cast( +std::string softmax_xentropy_version() { #ifdef XENTROPY_VER - XENTROPY_VER + return XENTROPY_VER; #else - std::string{} + return {}; #endif - ); - m.attr("__version__") = version; } + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("xentropy_forward(Tensor input, Tensor labels, float smoothing, bool half_to_float) -> Tensor[]"); + m.def( + "xentropy_backward(Tensor grad_loss, Tensor logits, Tensor max_log_sum_exp, Tensor labels, " + "float smoothing) -> Tensor"); + m.def("xentropy_version() -> str"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("xentropy_forward", &softmax_xentropy_forward); + m.impl("xentropy_backward", &softmax_xentropy_backward); +} + +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { m.impl("xentropy_version", &softmax_xentropy_version); } diff --git a/apex/contrib/cudnn_gbn/batch_norm.py b/apex/contrib/cudnn_gbn/batch_norm.py index 8346b74aa..854cbb4d5 100644 --- a/apex/contrib/cudnn_gbn/batch_norm.py +++ b/apex/contrib/cudnn_gbn/batch_norm.py @@ -2,8 +2,8 @@ from torch.nn.modules.batchnorm import _BatchNorm from torch.nn import functional as F from torch import Tensor -import peer_memory_cuda as pm -import cudnn_gbn_lib +from apex._extensions import peer_memory_cuda as pm +from apex._extensions import cudnn_gbn_lib from torch.cuda.amp import custom_fwd, custom_bwd diff --git a/apex/contrib/fmha/fmha.py b/apex/contrib/fmha/fmha.py index 4ebe97e93..c4ae69a8d 100644 --- a/apex/contrib/fmha/fmha.py +++ b/apex/contrib/fmha/fmha.py @@ -27,7 +27,7 @@ import torch -import fmhalib as mha +from apex._extensions import fmhalib as mha class FMHAFun(torch.autograd.Function): diff --git a/apex/contrib/focal_loss/__init__.py b/apex/contrib/focal_loss/__init__.py index 2a187d029..ce459f28d 100644 --- a/apex/contrib/focal_loss/__init__.py +++ b/apex/contrib/focal_loss/__init__.py @@ -1,6 +1,6 @@ try: import torch - import focal_loss_cuda + from apex._extensions import focal_loss_cuda from .focal_loss import focal_loss del torch diff --git a/apex/contrib/focal_loss/focal_loss.py b/apex/contrib/focal_loss/focal_loss.py index 85c6f620e..295deb477 100644 --- a/apex/contrib/focal_loss/focal_loss.py +++ b/apex/contrib/focal_loss/focal_loss.py @@ -1,6 +1,6 @@ import torch -import focal_loss_cuda +from apex._extensions import focal_loss_cuda class FocalLoss(torch.autograd.Function): diff --git a/apex/contrib/group_norm/group_norm.py b/apex/contrib/group_norm/group_norm.py index 998a220ff..0aa7ab902 100644 --- a/apex/contrib/group_norm/group_norm.py +++ b/apex/contrib/group_norm/group_norm.py @@ -10,8 +10,8 @@ import os import torch import torch.nn.init as init -import group_norm_cuda -import group_norm_v2_cuda +from apex._extensions import group_norm_cuda +from apex._extensions import group_norm_v2_cuda from torch import Tensor from torch.nn.parameter import Parameter diff --git a/apex/contrib/groupbn/__init__.py b/apex/contrib/groupbn/__init__.py index 4af4ac595..436f24ce4 100644 --- a/apex/contrib/groupbn/__init__.py +++ b/apex/contrib/groupbn/__init__.py @@ -1,6 +1,6 @@ try: import torch - import bnp + from apex._extensions import bnp from .batch_norm import BatchNorm2d_NHWC del torch diff --git a/apex/contrib/groupbn/batch_norm.py b/apex/contrib/groupbn/batch_norm.py index d5d71cd0f..6f78c2af2 100644 --- a/apex/contrib/groupbn/batch_norm.py +++ b/apex/contrib/groupbn/batch_norm.py @@ -2,7 +2,7 @@ import numpy as np from torch.nn.modules.batchnorm import _BatchNorm -import bnp +from apex._extensions import bnp class bn_NHWC_impl(torch.autograd.Function): diff --git a/apex/contrib/index_mul_2d/index_mul_2d.py b/apex/contrib/index_mul_2d/index_mul_2d.py index ab628b05d..47fe8c7a7 100644 --- a/apex/contrib/index_mul_2d/index_mul_2d.py +++ b/apex/contrib/index_mul_2d/index_mul_2d.py @@ -1,6 +1,6 @@ import torch -import fused_index_mul_2d +from apex._extensions import fused_index_mul_2d class IndexMul2d_(torch.autograd.Function): diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index ef134667d..7b962a307 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -2,7 +2,7 @@ from torch.nn import init from apex._autocast_utils import _cast_if_autocast_enabled -import fast_layer_norm +from apex._extensions import fast_layer_norm class FastLayerNormFN(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py index 7f5fe1994..bfd8958fe 100644 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastEncdecAttnFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py index da7ef57dc..00c634bc8 100644 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py +++ b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py @@ -7,7 +7,7 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastEncdecAttnNormAddFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py index a03cbc518..bfea175b9 100644 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py +++ b/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastSelfAttnFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py index 5cbae97ec..64546f86e 100644 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py +++ b/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class FastSelfAttnNormAddFunc(torch.autograd.Function): diff --git a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py b/apex/contrib/multihead_attn/mask_softmax_dropout_func.py index 995e68e51..64d0542c4 100644 --- a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py +++ b/apex/contrib/multihead_attn/mask_softmax_dropout_func.py @@ -1,6 +1,6 @@ import torch -import fast_multihead_attn +from apex._extensions import fast_multihead_attn class MaskSoftmaxDropout(torch.autograd.Function): diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index b1f6b58fc..0a5abd140 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -28,8 +28,8 @@ nccl_allocator = None from apex.multi_tensor_apply import multi_tensor_applier -import amp_C -import distributed_adam_cuda +from apex._extensions import amp_C +from apex._extensions import distributed_adam_cuda # Fallback to private functions if using PyTorch <1.13.0 try: @@ -126,7 +126,7 @@ def _coalescing_manager_append_work( # Import optional CUDA kernels _FOUND_DEPRECATED_FUSED_ADAM: bool = False try: - import fused_adam_cuda + from apex._extensions import fused_adam_cuda _FOUND_DEPRECATED_FUSED_ADAM = True except ImportError: diff --git a/apex/contrib/optimizers/distributed_fused_lamb.py b/apex/contrib/optimizers/distributed_fused_lamb.py index 93b964fbb..b8013dffd 100644 --- a/apex/contrib/optimizers/distributed_fused_lamb.py +++ b/apex/contrib/optimizers/distributed_fused_lamb.py @@ -2,7 +2,7 @@ import inspect import torch import importlib -import amp_C +from apex._extensions import amp_C from apex.multi_tensor_apply import multi_tensor_applier import torch.distributed.distributed_c10d as c10d @@ -140,8 +140,8 @@ def __init__( super(DistributedFusedLAMB, self).__init__(params, defaults) global fused_adam_cuda, distributed_lamb_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda") + fused_adam_cuda = importlib.import_module("apex._extensions.fused_adam_cuda") + distributed_lamb_cuda = importlib.import_module("apex._extensions.distributed_lamb_cuda") self._overflow_buf = torch.cuda.IntTensor([0]) self._has_overflow = False @@ -151,7 +151,7 @@ def __init__( self.multi_tensor_lamb_update_weights = ( distributed_lamb_cuda.multi_tensor_lamb_update_weights ) - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm diff --git a/apex/contrib/optimizers/fp16_optimizer.py b/apex/contrib/optimizers/fp16_optimizer.py index 856a181dc..3f5788a0e 100755 --- a/apex/contrib/optimizers/fp16_optimizer.py +++ b/apex/contrib/optimizers/fp16_optimizer.py @@ -56,7 +56,7 @@ def __init__( param_group["params"] = fp32_group if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.overflow_buf = torch.cuda.IntTensor([0]) self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm diff --git a/apex/contrib/optimizers/fused_adam.py b/apex/contrib/optimizers/fused_adam.py index 37a77b3cf..1ee075d8d 100644 --- a/apex/contrib/optimizers/fused_adam.py +++ b/apex/contrib/optimizers/fused_adam.py @@ -50,7 +50,7 @@ def __init__( amp_scale_adjustment=1.0, ): global fused_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") + fused_adam_cuda = importlib.import_module("apex._extensions.fused_adam_cuda") self._use_multi_tensor = False if use_mt: diff --git a/apex/contrib/optimizers/fused_lamb.py b/apex/contrib/optimizers/fused_lamb.py index d40cfeab7..9e41cacc6 100644 --- a/apex/contrib/optimizers/fused_lamb.py +++ b/apex/contrib/optimizers/fused_lamb.py @@ -87,11 +87,11 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - fused_lamb_cuda = importlib.import_module("fused_lamb_cuda") + fused_lamb_cuda = importlib.import_module("apex._extensions.fused_lamb_cuda") self.multi_tensor_lamb = fused_lamb_cuda.lamb else: raise RuntimeError("apex.contrib.optimizers.FusedLAMB requires cuda extensions") diff --git a/apex/contrib/optimizers/fused_sgd.py b/apex/contrib/optimizers/fused_sgd.py index e2acfcbaa..6fefda7bf 100644 --- a/apex/contrib/optimizers/fused_sgd.py +++ b/apex/contrib/optimizers/fused_sgd.py @@ -96,7 +96,7 @@ def __init__( self.wd_after_momentum = wd_after_momentum if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py index 8995e806d..3512ac8c2 100644 --- a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py +++ b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py @@ -1,5 +1,5 @@ import torch -import peer_memory_cuda as pm +from apex._extensions import peer_memory_cuda as pm class PeerHaloExchanger1d: diff --git a/apex/contrib/peer_memory/peer_memory.py b/apex/contrib/peer_memory/peer_memory.py index 72b1a1098..9717b33c3 100644 --- a/apex/contrib/peer_memory/peer_memory.py +++ b/apex/contrib/peer_memory/peer_memory.py @@ -1,6 +1,6 @@ import torch import numpy as np -import peer_memory_cuda as pm +from apex._extensions import peer_memory_cuda as pm class PeerMemoryPool(object): diff --git a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu index 4d690ee56..8c8b0c159 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu +++ b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu @@ -1,7 +1,9 @@ -#include -#include #include -namespace py = pybind11; +#include +#include +#include + +#include #define gpuErrchk(ans) \ { \ @@ -36,9 +38,27 @@ __device__ float group_2_to_4(float4 vals) { return best_sum; } -inline float* float_ptr_from_numpy(py::array_t& py_float) { return (float*)py_float.data(); } +inline void check_cpu_contiguous(torch::stable::Tensor const& tensor, const char* name) { + STD_TORCH_CHECK(tensor.is_cpu(), name, " must be a CPU tensor"); + STD_TORCH_CHECK(tensor.is_contiguous(), name, " must be contiguous"); +} + +inline float* float_ptr_from_tensor(torch::stable::Tensor const& tensor, const char* name) { + check_cpu_contiguous(tensor, name); + STD_TORCH_CHECK(tensor.scalar_type() == torch::headeronly::ScalarType::Float, name, " must have dtype float32"); + return static_cast(tensor.mutable_data_ptr()); +} + +inline unsigned int* uint_ptr_from_tensor(torch::stable::Tensor const& tensor, const char* name) { + check_cpu_contiguous(tensor, name); + STD_TORCH_CHECK(tensor.element_size() == sizeof(unsigned int), name, " must have 32-bit elements"); + return reinterpret_cast(tensor.mutable_data_ptr()); +} -inline unsigned int* uint_ptr_from_numpy(py::array_t& py_uint) { return (unsigned int*)py_uint.data(); } +inline unsigned int as_uint_arg(int64_t value, const char* name) { + STD_TORCH_CHECK(value >= 0 && value <= std::numeric_limits::max(), name, " is out of range"); + return static_cast(value); +} /********************************************************** * Check for the best permutation for an entire matrix @@ -136,67 +156,6 @@ int set_up_check_permutation_memory(float** dmatrix, unsigned int rows, unsigned return fresh_alloc; } -int run_check_permutations( - py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& - py_stripe_groups, // groups of stripes, group_width = stripes per group, num_groups = groups in the array - unsigned int group_width, unsigned int num_groups, - py::array_t& py_permutations, // array of permutations to try, group_width*4 values per each of - // num_permutations permutations - unsigned int num_permutations, - py::array_t& py_improvement, // improvment offered by the best permutation - py::array_t& py_permutation // the best permutation -) { - const unsigned int threads = 32; - static float* d_matrix; - static unsigned int* d_permutations; - static unsigned int* d_stripes; - static float* d_results; - static float* results; - - float* matrix = float_ptr_from_numpy(py_matrix); - unsigned int* stripe_groups = uint_ptr_from_numpy(py_stripe_groups); - unsigned int* permutations = uint_ptr_from_numpy(py_permutations); - float* improvement = float_ptr_from_numpy(py_improvement); - unsigned int* permutation = uint_ptr_from_numpy(py_permutation); - - int fresh_alloc = set_up_check_permutation_memory(&d_matrix, rows, cols, &d_stripes, group_width, num_groups, - &d_permutations, num_permutations, &d_results, &results); - if (fresh_alloc == 1) { - gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * group_width * 4 * sizeof(unsigned int), - cudaMemcpyHostToDevice)); - gpuErrchk( - cudaMemcpy(d_stripes, stripe_groups, group_width * num_groups * sizeof(unsigned int), cudaMemcpyHostToDevice)); - } - - // initialize results, new matrix - gpuErrchk(cudaMemset(d_results, 0, num_permutations * sizeof(float))); - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - - // get results for all permutations - permute_and_sum_after_2_to_4<<>>( - d_matrix, rows, cols, d_stripes, group_width, d_permutations, d_results); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy(results, d_results, num_permutations * sizeof(float), cudaMemcpyDeviceToHost)); - - // find the best permutation - could reduce on GPU - unsigned int best_permutation = 0; - float best_improvement = 0.0f; - for (unsigned int p = 1; p < num_permutations; ++p) { - float cur_improvement = results[p] - results[0]; - if (best_improvement < cur_improvement) { - best_permutation = p; - best_improvement = cur_improvement; - } - } - - *improvement = best_improvement; - *permutation = best_permutation; - - return 0; -} - /////////////////////////////////////////////////////////// /********************************************************** @@ -337,28 +296,6 @@ int set_up_sum_after_2_to_4_memory(float** dmatrix, unsigned int rows, unsigned return fresh_allocation; } -int run_subset_sum_after_2_to_4(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - unsigned int start_col, unsigned int end_col, unsigned int blocks, unsigned int threads, - py::array_t& py_output) { - static float* d_matrix; - static float* d_result; - - int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result); - - float* matrix = float_ptr_from_numpy(py_matrix); - float* output = float_ptr_from_numpy(py_output); - - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemset(d_result, 0, sizeof(float))); - - subset_sum_after_2_to_4<<>>(d_matrix, rows, cols, start_col, end_col, d_result); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy(output, d_result, sizeof(float), cudaMemcpyDeviceToHost)); - - return 0; -} - void set_up_permute_map_memory(float** dmatrix, unsigned int rows, unsigned int cols, unsigned int** dstripes, unsigned int num_groups, unsigned int group_width, unsigned int** dpermutations, unsigned int num_permutations, unsigned int perm_length, float** doutput, @@ -428,68 +365,6 @@ void set_up_permute_map_memory(float** dmatrix, unsigned int rows, unsigned int setUpPermLength = perm_length; } -int run_build_permute_map(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& py_stripes, unsigned int num_groups, unsigned int group_width, - py::array_t& py_permutations, unsigned int perm_length, - py::array_t& py_improvements, py::array_t& py_best_indices) { - static float* d_matrix = NULL; - static unsigned int* d_stripes = NULL; - static unsigned int* d_permutations = NULL; - static float* d_output = NULL; - static unsigned int* d_indices = NULL; - static float* hresult = NULL; - static unsigned int* hindices = NULL; - - const unsigned int num_permutations = py_permutations.size() / perm_length; - - const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40; - const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH; - const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH; - const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0); - - set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width, - &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices); - - float* matrix = float_ptr_from_numpy(py_matrix); - unsigned int* stripes = uint_ptr_from_numpy(py_stripes); - unsigned int* permutations = uint_ptr_from_numpy(py_permutations); - float* improvements = float_ptr_from_numpy(py_improvements); - unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices); - - gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * perm_length * sizeof(unsigned int), - cudaMemcpyHostToDevice)); - - unsigned int group_offset = 0; - for (unsigned int l = 0; l < launches; ++l) { - unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch; - - gpuErrchk(cudaMemcpy(d_stripes, &stripes[group_offset * group_width], - groups_this_launch * group_width * sizeof(unsigned int), cudaMemcpyHostToDevice)); - gpuErrchk(cudaMemset(d_output, 0, groups_this_launch * num_permutations * sizeof(float))); - gpuErrchk(cudaMemset(d_indices, 0, groups_this_launch * sizeof(unsigned int))); - - unsigned int shmem = 32 * (32) * sizeof(float); - build_permute_map<<>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, - num_permutations, perm_length, d_output, d_indices); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk( - cudaMemcpy(hresult, d_output, num_permutations * groups_this_launch * sizeof(float), cudaMemcpyDeviceToHost)); - gpuErrchk(cudaMemcpy(hindices, d_indices, groups_this_launch * sizeof(unsigned int), cudaMemcpyDeviceToHost)); - - // thread0 stuck the minimum in the first slot of each group - for (unsigned int g = 0; g < groups_this_launch; ++g) { - improvements[group_offset + g] = hresult[g * num_permutations]; - best_indices[group_offset + g] = hindices[g]; - } - - group_offset += groups_this_launch; - } - - return 0; -} - /********************************************************** * Build the swap map for channel_swaps **********************************************************/ @@ -594,17 +469,180 @@ void set_up_swap_map_memory(float** dmatrix, unsigned int rows, unsigned int col } } -int run_build_swap_map(py::array_t& py_matrix, unsigned int rows, unsigned int cols, - py::array_t& py_stripe_pairs, py::array_t& py_output) { +/////////////////////////////////////////////////////////// + +int64_t apex_permutation_search_check_permutations(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, torch::stable::Tensor const& stripe_groups_tensor, + int64_t group_width_arg, int64_t num_groups_arg, + torch::stable::Tensor const& permutations_tensor, + int64_t num_permutations_arg, + torch::stable::Tensor const& improvement_tensor, + torch::stable::Tensor const& permutation_tensor) { + static float* d_matrix; + static unsigned int* d_permutations; + static unsigned int* d_stripes; + static float* d_results; + static float* results; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int group_width = as_uint_arg(group_width_arg, "group_width"); + unsigned int num_groups = as_uint_arg(num_groups_arg, "num_groups"); + unsigned int num_permutations = as_uint_arg(num_permutations_arg, "num_permutations"); + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + unsigned int* stripe_groups = uint_ptr_from_tensor(stripe_groups_tensor, "stripe_groups"); + unsigned int* permutations = uint_ptr_from_tensor(permutations_tensor, "permutations"); + float* improvement = float_ptr_from_tensor(improvement_tensor, "improvement"); + unsigned int* permutation = uint_ptr_from_tensor(permutation_tensor, "permutation"); + + int fresh_alloc = set_up_check_permutation_memory(&d_matrix, rows, cols, &d_stripes, group_width, num_groups, + &d_permutations, num_permutations, &d_results, &results); + if (fresh_alloc == 1) { + gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * group_width * 4 * sizeof(unsigned int), + cudaMemcpyHostToDevice)); + gpuErrchk( + cudaMemcpy(d_stripes, stripe_groups, group_width * num_groups * sizeof(unsigned int), cudaMemcpyHostToDevice)); + } + + gpuErrchk(cudaMemset(d_results, 0, num_permutations * sizeof(float))); + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + + permute_and_sum_after_2_to_4<<>>( + d_matrix, rows, cols, d_stripes, group_width, d_permutations, d_results); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk(cudaMemcpy(results, d_results, num_permutations * sizeof(float), cudaMemcpyDeviceToHost)); + + unsigned int best_permutation = 0; + float best_improvement = 0.0f; + for (unsigned int p = 1; p < num_permutations; ++p) { + float cur_improvement = results[p] - results[0]; + if (best_improvement < cur_improvement) { + best_permutation = p; + best_improvement = cur_improvement; + } + } + + *improvement = best_improvement; + *permutation = best_permutation; + return 0; +} + +int64_t apex_permutation_search_sum_after_2_to_4(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, int64_t start_col_arg, int64_t end_col_arg, + int64_t blocks_arg, int64_t threads_arg, + torch::stable::Tensor const& output_tensor) { + static float* d_matrix; + static float* d_result; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int start_col = as_uint_arg(start_col_arg, "start_col"); + unsigned int end_col = as_uint_arg(end_col_arg, "end_col"); + unsigned int blocks = as_uint_arg(blocks_arg, "blocks"); + unsigned int threads = as_uint_arg(threads_arg, "threads"); + + int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result); + (void)fresh_allocation; + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + float* output = float_ptr_from_tensor(output_tensor, "output"); + + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemset(d_result, 0, sizeof(float))); + + subset_sum_after_2_to_4<<>>(d_matrix, rows, cols, start_col, end_col, d_result); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk(cudaMemcpy(output, d_result, sizeof(float), cudaMemcpyDeviceToHost)); + return 0; +} + +int64_t apex_permutation_search_build_permute_map(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, torch::stable::Tensor const& stripes_tensor, + int64_t num_groups_arg, int64_t group_width_arg, + torch::stable::Tensor const& permutations_tensor, + int64_t perm_length_arg, + torch::stable::Tensor const& improvements_tensor, + torch::stable::Tensor const& best_indices_tensor) { + static float* d_matrix = NULL; + static unsigned int* d_stripes = NULL; + static unsigned int* d_permutations = NULL; + static float* d_output = NULL; + static unsigned int* d_indices = NULL; + static float* hresult = NULL; + static unsigned int* hindices = NULL; + + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int num_groups = as_uint_arg(num_groups_arg, "num_groups"); + unsigned int group_width = as_uint_arg(group_width_arg, "group_width"); + unsigned int perm_length = as_uint_arg(perm_length_arg, "perm_length"); + STD_TORCH_CHECK(perm_length > 0, "perm_length must be positive"); + unsigned int num_permutations = as_uint_arg(permutations_tensor.numel() / perm_length, "num_permutations"); + + const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40; + const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH; + const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH; + const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0); + + set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups, MAX_GROUPS_PER_LAUNCH), group_width, + &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices); + + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + unsigned int* stripes = uint_ptr_from_tensor(stripes_tensor, "stripes"); + unsigned int* permutations = uint_ptr_from_tensor(permutations_tensor, "permutations"); + float* improvements = float_ptr_from_tensor(improvements_tensor, "improvements"); + unsigned int* best_indices = uint_ptr_from_tensor(best_indices_tensor, "best_indices"); + + gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemcpy(d_permutations, permutations, num_permutations * perm_length * sizeof(unsigned int), + cudaMemcpyHostToDevice)); + + unsigned int group_offset = 0; + for (unsigned int l = 0; l < launches; ++l) { + unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch; + + gpuErrchk(cudaMemcpy(d_stripes, &stripes[group_offset * group_width], + groups_this_launch * group_width * sizeof(unsigned int), cudaMemcpyHostToDevice)); + gpuErrchk(cudaMemset(d_output, 0, groups_this_launch * num_permutations * sizeof(float))); + gpuErrchk(cudaMemset(d_indices, 0, groups_this_launch * sizeof(unsigned int))); + + unsigned int shmem = 32 * (32) * sizeof(float); + build_permute_map<<>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, + num_permutations, perm_length, d_output, d_indices); + gpuErrchk(cudaDeviceSynchronize()); + + gpuErrchk( + cudaMemcpy(hresult, d_output, num_permutations * groups_this_launch * sizeof(float), cudaMemcpyDeviceToHost)); + gpuErrchk(cudaMemcpy(hindices, d_indices, groups_this_launch * sizeof(unsigned int), cudaMemcpyDeviceToHost)); + + for (unsigned int g = 0; g < groups_this_launch; ++g) { + improvements[group_offset + g] = hresult[g * num_permutations]; + best_indices[group_offset + g] = hindices[g]; + } + + group_offset += groups_this_launch; + } + + return 0; +} + +int64_t apex_permutation_search_build_swap_map(torch::stable::Tensor const& matrix_tensor, int64_t rows_arg, + int64_t cols_arg, torch::stable::Tensor const& stripe_pairs_tensor, + torch::stable::Tensor const& output_tensor) { static float* d_matrix = NULL; static float* d_result = NULL; static unsigned int* d_stripe_pairs = NULL; - float* matrix = float_ptr_from_numpy(py_matrix); //(float*)py_matrix.data(); - unsigned int* stripe_pairs = uint_ptr_from_numpy(py_stripe_pairs); //(unsigned int*)py_stripe_pairs.data(); - float* output = float_ptr_from_numpy(py_output); //(float*)py_output.data(); + unsigned int rows = as_uint_arg(rows_arg, "rows"); + unsigned int cols = as_uint_arg(cols_arg, "cols"); + unsigned int num_pairs = as_uint_arg(stripe_pairs_tensor.numel() / 2, "num_pairs"); - unsigned int num_pairs = py_stripe_pairs.size() / 2; + float* matrix = float_ptr_from_tensor(matrix_tensor, "matrix"); + unsigned int* stripe_pairs = uint_ptr_from_tensor(stripe_pairs_tensor, "stripe_pairs"); + float* output = float_ptr_from_tensor(output_tensor, "output"); set_up_swap_map_memory(&d_matrix, rows, cols, &d_stripe_pairs, num_pairs, &d_result); gpuErrchk(cudaMemcpy(d_matrix, matrix, rows * cols * sizeof(float), cudaMemcpyHostToDevice)); @@ -616,17 +654,29 @@ int run_build_swap_map(py::array_t& py_matrix, unsigned int rows, unsigne gpuErrchk(cudaDeviceSynchronize()); gpuErrchk(cudaMemcpy(output, d_result, num_pairs * 16 * sizeof(float), cudaMemcpyDeviceToHost)); - return 0; } -/////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("sum_after_2_to_4", &run_subset_sum_after_2_to_4, "matrix sum after applying 2:4 (CUDA)", - py::call_guard()); - m.def("build_permute_map", &run_build_permute_map, "optimize stripe groups (CUDA)", - py::call_guard()); - m.def("check_permutations", &run_check_permutations, "exhaustively check all permutations (CUDA)", - py::call_guard()); - m.def("build_swap_map", &run_build_swap_map, "channel swaps (CUDA)", py::call_guard()); +STABLE_TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "permutation_search_sum_after_2_to_4(Tensor matrix, int rows, int cols, int start_col, int end_col, " + "int blocks, int threads, Tensor(a!) output) -> int"); + m.def( + "permutation_search_build_permute_map(Tensor matrix, int rows, int cols, Tensor stripes, int num_groups, " + "int group_width, Tensor permutations, int perm_length, Tensor(a!) improvements, Tensor(b!) best_indices) " + "-> int"); + m.def( + "permutation_search_check_permutations(Tensor matrix, int rows, int cols, Tensor stripe_groups, " + "int group_width, int num_groups, Tensor permutations, int num_permutations, Tensor(a!) improvement, " + "Tensor(b!) permutation) -> int"); + m.def( + "permutation_search_build_swap_map(Tensor matrix, int rows, int cols, Tensor stripe_pairs, " + "Tensor(a!) output) -> int"); +} + +STABLE_TORCH_LIBRARY_IMPL(apex, CPU, m) { + m.impl("permutation_search_sum_after_2_to_4", TORCH_BOX(&apex_permutation_search_sum_after_2_to_4)); + m.impl("permutation_search_build_permute_map", TORCH_BOX(&apex_permutation_search_build_permute_map)); + m.impl("permutation_search_check_permutations", TORCH_BOX(&apex_permutation_search_check_permutations)); + m.impl("permutation_search_build_swap_map", TORCH_BOX(&apex_permutation_search_build_swap_map)); } diff --git a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py index 7d253fa0b..04d3a010a 100644 --- a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py +++ b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py @@ -6,12 +6,12 @@ gpus_found = 0 kernels_found = True try: - import permutation_search_cuda as permutation_search_cuda_kernels + from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels print("Found permutation search CUDA kernels") except ImportError: try: - from . import permutation_search_cuda as permutation_search_cuda_kernels + from apex._extensions import permutation_search_cuda as permutation_search_cuda_kernels print("Found permutation search CUDA kernels for standalone testing") diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py index c85afa612..4e4dda0a0 100644 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ b/apex/contrib/test/layer_norm/test_fast_layer_norm.py @@ -6,7 +6,7 @@ SKIP_TEST = None try: from apex.contrib.layer_norm.layer_norm import FastLayerNorm - import fast_layer_norm as fln + from apex._extensions import fast_layer_norm as fln except ImportError as e: SKIP_TEST = e diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py index bb53d39a6..18db7f826 100755 --- a/apex/contrib/transducer/transducer.py +++ b/apex/contrib/transducer/transducer.py @@ -1,6 +1,6 @@ import torch -import transducer_loss_cuda -import transducer_joint_cuda +from apex._extensions import transducer_loss_cuda +from apex._extensions import transducer_joint_cuda class TransducerJoint(torch.nn.Module): diff --git a/apex/contrib/xentropy/softmax_xentropy.py b/apex/contrib/xentropy/softmax_xentropy.py index 528f743b0..78fb44d10 100644 --- a/apex/contrib/xentropy/softmax_xentropy.py +++ b/apex/contrib/xentropy/softmax_xentropy.py @@ -1,6 +1,6 @@ import torch -import xentropy_cuda +from apex._extensions import xentropy_cuda class SoftmaxCrossEntropyLoss(torch.autograd.Function): diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py index 239d12727..fffe5cf39 100644 --- a/apex/fused_dense/fused_dense.py +++ b/apex/fused_dense/fused_dense.py @@ -1,6 +1,8 @@ +import math + import torch from torch import nn -import fused_dense_cuda +from apex._extensions import fused_dense_cuda from apex._autocast_utils import _cast_if_autocast_enabled @@ -60,6 +62,8 @@ def backward(ctx, grad_output): def _fused_dense(input, weight, bias): args = _cast_if_autocast_enabled(input, weight, bias) with torch.amp.autocast("cuda", enabled=False): + if args[0].dtype == torch.bfloat16: + return torch.matmul(args[0], args[1].t()) + args[2] return FusedDenseFunc.apply(*args) @@ -86,6 +90,14 @@ def __init__(self, in_features, out_features, bias=True): else: # assert False, "no-bias option not added yet" self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) def forward(self, input): if self.bias is not None: @@ -105,6 +117,15 @@ def __init__(self, in_features, intermediate_features, out_features, bias=True): self.bias1 = nn.Parameter(torch.empty(intermediate_features)) self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features)) self.bias2 = nn.Parameter(torch.empty(out_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + for weight, bias in ((self.weight1, self.bias1), (self.weight2, self.bias2)): + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + nn.init.uniform_(bias, -bound, bound) def forward(self, input): return _fused_dense_gelu_dense(input, self.weight1, self.bias1, self.weight2, self.bias2) diff --git a/apex/mlp/mlp.py b/apex/mlp/mlp.py index 4297b0a73..834f3b70b 100644 --- a/apex/mlp/mlp.py +++ b/apex/mlp/mlp.py @@ -5,7 +5,7 @@ from torch import nn from apex._autocast_utils import _cast_if_autocast_enabled -import mlp_cuda +from apex._extensions import mlp_cuda class MlpFunction(torch.autograd.Function): diff --git a/apex/multi_tensor_apply/multi_tensor_apply.py b/apex/multi_tensor_apply/multi_tensor_apply.py index ba2c21ada..4eeb3f6ba 100644 --- a/apex/multi_tensor_apply/multi_tensor_apply.py +++ b/apex/multi_tensor_apply/multi_tensor_apply.py @@ -4,7 +4,7 @@ class MultiTensorApply(object): def __init__(self, chunk_size): try: - import amp_C + from apex._extensions import amp_C MultiTensorApply.available = True self.chunk_size = chunk_size diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index a0f3833bc..cd61dec85 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -40,7 +40,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -87,7 +89,9 @@ def fused_layer_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() weight_ = weight.contiguous() @@ -204,7 +208,9 @@ class FusedRMSNormAffineFunction(torch.autograd.Function): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -247,7 +253,9 @@ def fused_rms_norm_affine_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() weight_ = weight.contiguous() @@ -358,7 +366,9 @@ class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -380,7 +390,9 @@ class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -401,7 +413,9 @@ class FusedLayerNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -439,7 +453,9 @@ def fused_layer_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, normalized_shape, eps) @@ -535,7 +551,9 @@ class FusedRMSNormFunction(torch.autograd.Function): def forward(ctx, input, normalized_shape, eps, memory_efficient=False): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) ctx.normalized_shape = normalized_shape ctx.eps = eps ctx.memory_efficient = memory_efficient @@ -573,7 +591,9 @@ def fused_rms_norm_fwd( ) -> Tuple[torch.Tensor, torch.Tensor]: global fused_layer_norm_cuda if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module( + "apex._extensions.fused_layer_norm_cuda" + ) input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, normalized_shape, eps) @@ -791,7 +811,7 @@ def __init__( super().__init__() global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) @@ -908,7 +928,7 @@ def __init__( super().__init__() global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + fused_layer_norm_cuda = importlib.import_module("apex._extensions.fused_layer_norm_cuda") if isinstance(normalized_shape, numbers.Integral): normalized_shape = (normalized_shape,) diff --git a/apex/optimizers/fused_adagrad.py b/apex/optimizers/fused_adagrad.py index d9dbcd7cc..43c80d0e4 100644 --- a/apex/optimizers/fused_adagrad.py +++ b/apex/optimizers/fused_adagrad.py @@ -56,7 +56,7 @@ def __init__( self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.cuda.IntTensor([0]) diff --git a/apex/optimizers/fused_adam.py b/apex/optimizers/fused_adam.py index 45d7e017e..e7b42c13f 100644 --- a/apex/optimizers/fused_adam.py +++ b/apex/optimizers/fused_adam.py @@ -125,7 +125,7 @@ def __init__( self._step_supports_amp_scaling = True if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") diff --git a/apex/optimizers/fused_lamb.py b/apex/optimizers/fused_lamb.py index a5630a03c..ac6d79e85 100644 --- a/apex/optimizers/fused_lamb.py +++ b/apex/optimizers/fused_lamb.py @@ -88,7 +88,7 @@ def __init__( ) super(FusedLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm # Skip buffer diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py index 2ecdddfd2..4a759ffaa 100644 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ b/apex/optimizers/fused_mixed_precision_lamb.py @@ -50,7 +50,7 @@ def __init__( self.param_groups[idx][item] = group[item].to(device=device) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm_mp # Skip buffer diff --git a/apex/optimizers/fused_novograd.py b/apex/optimizers/fused_novograd.py index b72e2e3e6..c044fc59d 100644 --- a/apex/optimizers/fused_novograd.py +++ b/apex/optimizers/fused_novograd.py @@ -93,7 +93,7 @@ def __init__( ) super(FusedNovoGrad, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer # Creating the overflow buffer on the same device as the params tensors. diff --git a/apex/optimizers/fused_sgd.py b/apex/optimizers/fused_sgd.py index f4cec9f57..429a6b5b2 100644 --- a/apex/optimizers/fused_sgd.py +++ b/apex/optimizers/fused_sgd.py @@ -111,7 +111,7 @@ def __init__( self.set_grad_none = set_grad_none if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C # Skip buffer self._dummy_overflow_buf = torch.tensor( diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp index 585e7872a..7ad36f74b 100644 --- a/csrc/amp_C_frontend.cpp +++ b/csrc/amp_C_frontend.cpp @@ -1,4 +1,4 @@ -#include +#include void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, float scale); @@ -80,44 +80,229 @@ at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, at::Tensor gro const double backoff_factor, const int64_t growth_interval, const int hysteresis); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_scale", &multi_tensor_scale_cuda, "Fused overflow check + scale for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, "Fused SGD optimizer for list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda, "out = a*x + b*y for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, "Computes L2 norm for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda, "Computes L2 norm for a list of contiguous tensors", - py::call_guard()); - m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda, - "Computes L2 norm for a list of contiguous tensors and does scaling", py::call_guard()); - m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda, - "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only performed for L2 norm " - "computation, and tensors are not updated)", - py::call_guard()); - m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, "Computes update part of LAMB optimizer", - py::call_guard()); - m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, - "Completes application of gradient to parameters for LAMB optimizer", py::call_guard()); - m.def("multi_tensor_adam", &multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support and LR scheduling", - py::call_guard()); - m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda, - "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph support, LR scheduling and " - "FP32 master weights", - py::call_guard()); - m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda, - "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); - m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); - m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda, "Computes and apply update for LAMB optimizer", - py::call_guard()); - m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda, "Updates scale while accounting for hysteresis", - py::call_guard()); +namespace { +void apex_multi_tensor_scale(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double scale) { + multi_tensor_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(scale)); +} + +void apex_multi_tensor_sgd(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double wd, double momentum, double dampening, double lr, bool nesterov, bool first_run, + bool wd_after_momentum, double scale) { + multi_tensor_sgd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(wd), + static_cast(momentum), static_cast(dampening), static_cast(lr), nesterov, + first_run, wd_after_momentum, static_cast(scale)); +} + +void apex_multi_tensor_axpby(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double a, double b, + int64_t arg_to_check) { + multi_tensor_axpby_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(a), + static_cast(b), static_cast(arg_to_check)); +} + +std::tuple apex_multi_tensor_l2norm(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { + return multi_tensor_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_python); +} + +std::tuple apex_multi_tensor_l2norm_mp(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::optional per_tensor_python) { + return multi_tensor_l2norm_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), + per_tensor_python); +} + +std::tuple apex_multi_tensor_l2norm_scale(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + double scale, at::optional per_tensor_python) { + return multi_tensor_l2norm_scale_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), + static_cast(scale), per_tensor_python); +} + +std::tuple apex_multi_tensor_unscale_l2norm(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor inv_scale, + at::optional per_tensor_python) { + return multi_tensor_unscale_l2norm_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), inv_scale, + per_tensor_python); +} + +void apex_multi_tensor_lamb_stage1_cuda(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor per_tensor_decay, + int64_t step, double beta1, double beta2, double epsilon, + at::Tensor global_grad_norm, double max_global_grad_norm) { + multi_tensor_lamb_stage1_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_decay, + static_cast(step), static_cast(beta1), static_cast(beta2), + static_cast(epsilon), global_grad_norm, + static_cast(max_global_grad_norm)); +} + +void apex_multi_tensor_lamb_stage2_cuda(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + at::Tensor per_tensor_param_norm, at::Tensor per_tensor_update_norm, double lr, + double weight_decay, at::optional use_nvlamb_python) { + multi_tensor_lamb_stage2_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), per_tensor_param_norm, + per_tensor_update_norm, static_cast(lr), static_cast(weight_decay), + use_nvlamb_python); +} + +void apex_multi_tensor_adam(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, int64_t mode, + int64_t bias_correction, double weight_decay) { + multi_tensor_adam_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay)); +} + +void apex_multi_tensor_adam_capturable(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, double beta1, + double beta2, double epsilon, at::Tensor step, int64_t mode, + int64_t bias_correction, double weight_decay, at::Tensor inv_scale) { + multi_tensor_adam_capturable_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + step, static_cast(mode), static_cast(bias_correction), + static_cast(weight_decay), inv_scale); +} + +void apex_multi_tensor_adam_capturable_master(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, + double beta1, double beta2, double epsilon, at::Tensor step, int64_t mode, + int64_t bias_correction, double weight_decay, at::Tensor inv_scale) { + multi_tensor_adam_capturable_master_cuda( + static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, static_cast(beta1), + static_cast(beta2), static_cast(epsilon), step, static_cast(mode), + static_cast(bias_correction), static_cast(weight_decay), inv_scale); +} + +void apex_multi_tensor_adagrad(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, double lr, double epsilon, + int64_t mode, double weight_decay) { + multi_tensor_adagrad_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(epsilon), static_cast(mode), static_cast(weight_decay)); +} + +void apex_multi_tensor_novograd(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor grad_norms, double lr, + double beta1, double beta2, double epsilon, int64_t step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, int64_t norm_type) { + multi_tensor_novograd_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), grad_norms, + static_cast(lr), static_cast(beta1), static_cast(beta2), + static_cast(epsilon), static_cast(step), static_cast(bias_correction), + static_cast(weight_decay), static_cast(grad_averaging), static_cast(mode), + static_cast(norm_type)); +} + +void apex_multi_tensor_lamb(int64_t chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, + double lr, double beta1, double beta2, double epsilon, int64_t step, + int64_t bias_correction, double weight_decay, int64_t grad_averaging, int64_t mode, + at::Tensor global_grad_norm, double max_grad_norm, at::optional use_nvlamb_python) { + multi_tensor_lamb_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), static_cast(lr), + static_cast(beta1), static_cast(beta2), static_cast(epsilon), + static_cast(step), static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), static_cast(mode), global_grad_norm, + static_cast(max_grad_norm), use_nvlamb_python); +} + +void apex_multi_tensor_lamb_mp(int64_t chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, at::Tensor lr, double beta1, + double beta2, double epsilon, at::Tensor step, int64_t bias_correction, + double weight_decay, int64_t grad_averaging, int64_t mode, at::Tensor global_grad_norm, + at::Tensor max_grad_norm, at::optional use_nvlamb_python, at::Tensor found_inf, + at::Tensor inv_scale) { + multi_tensor_lamb_mp_cuda(static_cast(chunk_size), noop_flag, std::move(tensor_lists), lr, + static_cast(beta1), static_cast(beta2), static_cast(epsilon), step, + static_cast(bias_correction), static_cast(weight_decay), + static_cast(grad_averaging), static_cast(mode), global_grad_norm, max_grad_norm, + use_nvlamb_python, found_inf, inv_scale); +} + +at::Tensor apex_update_scale_hysteresis(at::Tensor current_scale, at::Tensor growth_tracker, + at::Tensor hysteresis_tracker, at::Tensor found_inf, double growth_factor, + double backoff_factor, int64_t growth_interval, int64_t hysteresis) { + return update_scale_hysteresis_cuda(current_scale, growth_tracker, hysteresis_tracker, found_inf, growth_factor, + backoff_factor, growth_interval, static_cast(hysteresis)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("amp_multi_tensor_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale) -> ()"); + m.def( + "amp_multi_tensor_sgd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float wd, float momentum, " + "float dampening, float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) -> ()"); + m.def( + "amp_multi_tensor_axpby(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float a, float b, " + "int arg_to_check) -> ()"); + m.def( + "amp_multi_tensor_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, bool? per_tensor_python) " + "-> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_l2norm_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_l2norm_scale(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_unscale_l2norm(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor inv_scale, " + "bool? per_tensor_python) -> (Tensor, Tensor)"); + m.def( + "amp_multi_tensor_lamb_stage1_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_decay, int step, float beta1, float beta2, float epsilon, Tensor global_grad_norm, " + "float max_global_grad_norm) -> ()"); + m.def( + "amp_multi_tensor_lamb_stage2_cuda(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, " + "Tensor per_tensor_param_norm, Tensor per_tensor_update_norm, float lr, float weight_decay, " + "bool? use_nvlamb_python) -> ()"); + m.def( + "amp_multi_tensor_adam(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int mode, int bias_correction, float weight_decay) -> ()"); + m.def( + "amp_multi_tensor_adam_capturable(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_multi_tensor_adam_capturable_master(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, " + "float beta1, float beta2, float epsilon, Tensor step, int mode, int bias_correction, float weight_decay, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_multi_tensor_adagrad(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float epsilon, " + "int mode, float weight_decay) -> ()"); + m.def( + "amp_multi_tensor_novograd(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor grad_norms, " + "float lr, float beta1, float beta2, float epsilon, int step, int bias_correction, float weight_decay, " + "int grad_averaging, int mode, int norm_type) -> ()"); + m.def( + "amp_multi_tensor_lamb(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, float lr, float beta1, " + "float beta2, float epsilon, int step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, float max_grad_norm, bool? use_nvlamb_python) -> ()"); + m.def( + "amp_multi_tensor_lamb_mp(int chunk_size, Tensor noop_flag, Tensor[][] tensor_lists, Tensor lr, float beta1, " + "float beta2, float epsilon, Tensor step, int bias_correction, float weight_decay, int grad_averaging, " + "int mode, Tensor global_grad_norm, Tensor max_grad_norm, bool? use_nvlamb_python, Tensor found_inf, " + "Tensor inv_scale) -> ()"); + m.def( + "amp_update_scale_hysteresis(Tensor current_scale, Tensor growth_tracker, Tensor hysteresis_tracker, " + "Tensor found_inf, float growth_factor, float backoff_factor, int growth_interval, int hysteresis) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("amp_multi_tensor_scale", &apex_multi_tensor_scale); + m.impl("amp_multi_tensor_sgd", &apex_multi_tensor_sgd); + m.impl("amp_multi_tensor_axpby", &apex_multi_tensor_axpby); + m.impl("amp_multi_tensor_l2norm", &apex_multi_tensor_l2norm); + m.impl("amp_multi_tensor_l2norm_mp", &apex_multi_tensor_l2norm_mp); + m.impl("amp_multi_tensor_l2norm_scale", &apex_multi_tensor_l2norm_scale); + m.impl("amp_multi_tensor_unscale_l2norm", &apex_multi_tensor_unscale_l2norm); + m.impl("amp_multi_tensor_lamb_stage1_cuda", &apex_multi_tensor_lamb_stage1_cuda); + m.impl("amp_multi_tensor_lamb_stage2_cuda", &apex_multi_tensor_lamb_stage2_cuda); + m.impl("amp_multi_tensor_adam", &apex_multi_tensor_adam); + m.impl("amp_multi_tensor_adam_capturable", &apex_multi_tensor_adam_capturable); + m.impl("amp_multi_tensor_adam_capturable_master", &apex_multi_tensor_adam_capturable_master); + m.impl("amp_multi_tensor_adagrad", &apex_multi_tensor_adagrad); + m.impl("amp_multi_tensor_novograd", &apex_multi_tensor_novograd); + m.impl("amp_multi_tensor_lamb", &apex_multi_tensor_lamb); + m.impl("amp_multi_tensor_lamb_mp", &apex_multi_tensor_lamb_mp); + m.impl("amp_update_scale_hysteresis", &apex_update_scale_hysteresis); } diff --git a/csrc/fused_dense.cpp b/csrc/fused_dense.cpp index 5b09f1e7c..4d98eb7ef 100644 --- a/csrc/fused_dense.cpp +++ b/csrc/fused_dense.cpp @@ -1,6 +1,7 @@ +#include +#include #include -#include -#include +#include #include @@ -40,7 +41,6 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_forward", [&] { scalar_t* w_ptr = weight.data_ptr(); - scalar_t* b_ptr = bias.data_ptr(); [[maybe_unused]] auto result = linear_bias_forward_cuda(input, w_ptr, bias, in_features, batch_size, out_features, out, // out.data_ptr(), @@ -48,7 +48,7 @@ at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor b (void*)(lt_workspace.data_ptr())); }); - return {out}; + return out; } std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) { @@ -74,7 +74,6 @@ std::vector linear_bias_backward(at::Tensor input, at::Tensor weight AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] { scalar_t* w_ptr = weight.data_ptr(); - scalar_t* d_b_ptr = d_bias.data_ptr(); [[maybe_unused]] auto result = linear_bias_backward_cuda( input.data_ptr(), w_ptr, d_output.data_ptr(), in_features, batch_size, out_features, d_weight.data_ptr(), d_bias.data_ptr(), d_input.data_ptr(), @@ -142,8 +141,6 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] { - // scalar_t* w_ptr = weight.data_ptr(); - // scalar_t* d_b_ptr = d_bias.data_ptr(); [[maybe_unused]] auto result = linear_gelu_linear_backward_cuda( input.data_ptr(), gelu_in.data_ptr(), output1.data_ptr(), weight1.data_ptr(), weight2.data_ptr(), d_output1.data_ptr(), @@ -157,12 +154,20 @@ std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward", py::call_guard()); - m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward", - py::call_guard()); - m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward", - py::call_guard()); - m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_dense_linear_bias_forward(Tensor input, Tensor weight, Tensor bias) -> Tensor"); + m.def("fused_dense_linear_bias_backward(Tensor input, Tensor weight, Tensor d_output) -> Tensor[]"); + m.def( + "fused_dense_linear_gelu_linear_forward(Tensor input, Tensor weight1, Tensor bias1, Tensor weight2, " + "Tensor bias2) -> Tensor[]"); + m.def( + "fused_dense_linear_gelu_linear_backward(Tensor input, Tensor gelu_in, Tensor output1, Tensor weight1, " + "Tensor weight2, Tensor d_output2) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_dense_linear_bias_forward", &linear_bias_forward); + m.impl("fused_dense_linear_bias_backward", &linear_bias_backward); + m.impl("fused_dense_linear_gelu_linear_forward", &linear_gelu_linear_forward); + m.impl("fused_dense_linear_gelu_linear_backward", &linear_gelu_linear_backward); } diff --git a/csrc/fused_dense_cuda.cu b/csrc/fused_dense_cuda.cu index 9cab0dfa8..fd9af6b85 100644 --- a/csrc/fused_dense_cuda.cu +++ b/csrc/fused_dense_cuda.cu @@ -4,7 +4,6 @@ #include #include #include -#include /* Includes, cuda */ #include diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index 4e5abae44..3361f22c5 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -1,7 +1,9 @@ -#include +#include +#include #include #include +#include #include namespace { @@ -249,24 +251,88 @@ std::vector rms_norm_gradient_affine(at::Tensor& dout, at::Tensor& i return {grad_input, grad_gamma}; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)", py::call_guard()); - m.def("forward", &layer_norm, "LayerNorm forward (CUDA)", py::call_guard()); - m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)", - py::call_guard()); - m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)", py::call_guard()); - - m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, - "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", - py::call_guard()); - - m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)", py::call_guard()); - m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)", py::call_guard()); - m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)", - py::call_guard()); - m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)", py::call_guard()); - - m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, - "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation", - py::call_guard()); +namespace { +at::Tensor apex_layer_norm_gradient(const at::Tensor& dout, const std::optional& mean, + const at::Tensor& invvar, const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + return layer_norm_gradient(dout_, mean, invvar_, input_or_output_, normalized_shape, epsilon, memory_efficient); +} + +std::vector apex_layer_norm_gradient_affine(const at::Tensor& dout, const std::optional& mean, + const at::Tensor& invvar, const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, const at::Tensor& gamma, + const at::Tensor& beta, double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + at::Tensor gamma_ = gamma; + at::Tensor beta_ = beta; + return layer_norm_gradient_affine(dout_, mean, invvar_, input_or_output_, normalized_shape, gamma_, beta_, epsilon, + memory_efficient); +} + +at::Tensor apex_rms_norm_gradient(const at::Tensor& dout, const at::Tensor& invvar, const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + return rms_norm_gradient(dout_, invvar_, input_or_output_, normalized_shape, epsilon, memory_efficient); +} + +std::vector apex_rms_norm_gradient_affine(const at::Tensor& dout, const at::Tensor& invvar, + const at::Tensor& input_or_output, + at::IntArrayRef normalized_shape, const at::Tensor& gamma, + double epsilon, bool memory_efficient) { + at::Tensor dout_ = dout; + at::Tensor invvar_ = invvar; + at::Tensor input_or_output_ = input_or_output; + at::Tensor gamma_ = gamma; + return rms_norm_gradient_affine(dout_, invvar_, input_or_output_, normalized_shape, gamma_, epsilon, + memory_efficient); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def( + "fused_layer_norm_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); + m.def("fused_layer_norm_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_backward_affine(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, Tensor beta, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def( + "fused_layer_norm_backward(Tensor dout, Tensor? mean, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, float epsilon, bool memory_efficient) -> Tensor"); + m.def( + "fused_layer_norm_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, Tensor beta, " + "float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_forward_affine(Tensor input, int[] normalized_shape, Tensor gamma, float epsilon) " + "-> Tensor[]"); + m.def("fused_layer_norm_rms_forward(Tensor input, int[] normalized_shape, float epsilon) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_backward_affine(Tensor dout, Tensor invvar, Tensor input_or_output, " + "int[] normalized_shape, Tensor gamma, float epsilon, bool memory_efficient) -> Tensor[]"); + m.def( + "fused_layer_norm_rms_backward(Tensor dout, Tensor invvar, Tensor input_or_output, int[] normalized_shape, " + "float epsilon, bool memory_efficient) -> Tensor"); + m.def( + "fused_layer_norm_rms_forward_affine_mixed_dtypes(Tensor input, int[] normalized_shape, Tensor gamma, " + "float epsilon) -> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_layer_norm_forward_affine", &layer_norm_affine); + m.impl("fused_layer_norm_forward", &layer_norm); + m.impl("fused_layer_norm_backward_affine", &apex_layer_norm_gradient_affine); + m.impl("fused_layer_norm_backward", &apex_layer_norm_gradient); + m.impl("fused_layer_norm_forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes); + m.impl("fused_layer_norm_rms_forward_affine", &rms_norm_affine); + m.impl("fused_layer_norm_rms_forward", &rms_norm); + m.impl("fused_layer_norm_rms_backward_affine", &apex_rms_norm_gradient_affine); + m.impl("fused_layer_norm_rms_backward", &apex_rms_norm_gradient); + m.impl("fused_layer_norm_rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes); } diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp index c0b2ded82..4efd711c7 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.cpp +++ b/csrc/megatron/fused_rotary_positional_embedding.cpp @@ -14,32 +14,32 @@ * limitations under the License. */ -#include +#include +#include namespace fused_rope { -torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output); +at::Tensor fwd_cuda(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output); -torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output); +at::Tensor bwd_cuda(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output); -torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output); +at::Tensor fwd_cached_cuda(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output); -torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output); +at::Tensor bwd_cached_cuda(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output); -torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs); +at::Tensor fwd_thd_cuda(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs); -torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, - const torch::Tensor& freqs); +at::Tensor bwd_thd_cuda(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs); -torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w); +at::Tensor fwd_2d_cuda(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w); -torch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w); +at::Tensor bwd_2d_cuda(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w); -torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { +at::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(input.size(0) == freqs.size(0), "expected input and freqs tensor have the same sequence length"); @@ -53,7 +53,7 @@ torch::Tensor fwd(const at::Tensor& input, const at::Tensor& freqs, const bool t return fwd_cuda(input, freqs, transpose_output); } -torch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { +at::Tensor bwd(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(output_grads.size(0) == freqs.size(0), @@ -68,8 +68,8 @@ torch::Tensor bwd(const torch::Tensor& output_grads, const at::Tensor& freqs, co return bwd_cuda(output_grads, freqs, transpose_output); } -torch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, - const bool transpose_output) { +at::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); @@ -86,8 +86,8 @@ torch::Tensor fwd_cached(const at::Tensor& input, const at::Tensor& cos, const a return fwd_cached_cuda(input, cos, sin, transpose_output); } -torch::Tensor bwd_cached(const torch::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, - const bool transpose_output) { +at::Tensor bwd_cached(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); @@ -106,7 +106,7 @@ torch::Tensor bwd_cached(const torch::Tensor& output_grads, const at::Tensor& co return bwd_cached_cuda(output_grads, cos, sin, transpose_output); } -torch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor fwd_thd(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -120,7 +120,7 @@ torch::Tensor fwd_thd(const torch::Tensor& input, const torch::Tensor& cu_seqlen return fwd_thd_cuda(input, cu_seqlens, freqs); } -torch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor bwd_thd(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -134,8 +134,8 @@ torch::Tensor bwd_thd(const torch::Tensor& output_grads, const torch::Tensor& cu return bwd_thd_cuda(output_grads, cu_seqlens, freqs); } -torch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor fwd_2d(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, const at::Tensor& cos_w, + const at::Tensor& sin_w) { TORCH_CHECK(input.dim() == 5, "expected input to be 5D tensor"); TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); @@ -151,8 +151,8 @@ torch::Tensor fwd_2d(const torch::Tensor& input, const torch::Tensor& cos_h, con return fwd_2d_cuda(input, cos_h, sin_h, cos_w, sin_w); } -torch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor bwd_2d(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { TORCH_CHECK(output_grads.dim() == 5, "expected output_grads to be 5D tensor"); TORCH_CHECK(cos_h.dim() == 4, "expected cos_h to be 4D tensor"); TORCH_CHECK(sin_h.dim() == 4, "expected sin_h to be 4D tensor"); @@ -172,24 +172,25 @@ torch::Tensor bwd_2d(const torch::Tensor& output_grads, const torch::Tensor& cos } // end namespace fused_rope -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fused_rope::fwd, "Fused Rotary Positional Embedding -- Forward.", - py::call_guard()); - m.def("backward", &fused_rope::bwd, "Fused Rotary Positional Embedding -- Backward.", - py::call_guard()); - // cache sin/cos - m.def("forward_cached", &fused_rope::fwd_cached, "Fused Rotary Positional Embedding Cached -- Forward.", - py::call_guard()); - m.def("backward_cached", &fused_rope::bwd_cached, "Fused Rotary Positional Embedding Cached -- Backward.", - py::call_guard()); - // thd - m.def("forward_thd", &fused_rope::fwd_thd, "Fused Rotary Positional Embedding for thd layout -- Forward.", - py::call_guard()); - m.def("backward_thd", &fused_rope::bwd_thd, "Fused Rotary Positional Embedding for thd layout -- Backward.", - py::call_guard()); - // 2d - m.def("forward_2d", &fused_rope::fwd_2d, "2D Fused Rotary Positional Embedding -- Forward.", - py::call_guard()); - m.def("backward_2d", &fused_rope::bwd_2d, "2D Fused Rotary Positional Embedding -- Backward.", - py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_rope_forward(Tensor input, Tensor freqs, bool transpose_output) -> Tensor"); + m.def("fused_rope_backward(Tensor output_grads, Tensor freqs, bool transpose_output) -> Tensor"); + m.def("fused_rope_forward_cached(Tensor input, Tensor cos, Tensor sin, bool transpose_output) -> Tensor"); + m.def("fused_rope_backward_cached(Tensor output_grads, Tensor cos, Tensor sin, bool transpose_output) -> Tensor"); + m.def("fused_rope_forward_thd(Tensor input, Tensor cu_seqlens, Tensor freqs) -> Tensor"); + m.def("fused_rope_backward_thd(Tensor output_grads, Tensor cu_seqlens, Tensor freqs) -> Tensor"); + m.def("fused_rope_forward_2d(Tensor input, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); + m.def( + "fused_rope_backward_2d(Tensor output_grads, Tensor cos_h, Tensor sin_h, Tensor cos_w, Tensor sin_w) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_rope_forward", &fused_rope::fwd); + m.impl("fused_rope_backward", &fused_rope::bwd); + m.impl("fused_rope_forward_cached", &fused_rope::fwd_cached); + m.impl("fused_rope_backward_cached", &fused_rope::bwd_cached); + m.impl("fused_rope_forward_thd", &fused_rope::fwd_thd); + m.impl("fused_rope_backward_thd", &fused_rope::bwd_thd); + m.impl("fused_rope_forward_2d", &fused_rope::fwd_2d); + m.impl("fused_rope_backward_2d", &fused_rope::bwd_2d); } diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h index 32a6042da..7af1864a9 100644 --- a/csrc/megatron/fused_rotary_positional_embedding.h +++ b/csrc/megatron/fused_rotary_positional_embedding.h @@ -18,9 +18,9 @@ #include #include +#include #include #include -#include namespace { diff --git a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu index 82fef191d..7d5ec9299 100644 --- a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu +++ b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu @@ -21,7 +21,7 @@ namespace fused_rope { -torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, const bool transpose_output) { +at::Tensor fwd_cuda(const at::Tensor& input, const at::Tensor& freqs, const bool transpose_output) { // input sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -42,11 +42,11 @@ torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, c // output auto act_options = input.options().requires_grad(false); - torch::Tensor output; + at::Tensor output; if (transpose_output) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + output = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - output = torch::empty({s, b, h, d}, act_options); + output = at::empty({s, b, h, d}, act_options); } // output strides const int o_stride_s = output.stride(0); @@ -62,7 +62,7 @@ torch::Tensor fwd_cuda(const torch::Tensor& input, const torch::Tensor& freqs, c return output; } -torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& freqs, const bool transpose_output) { +at::Tensor bwd_cuda(const at::Tensor& output_grads, const at::Tensor& freqs, const bool transpose_output) { // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -82,11 +82,11 @@ torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& f const int d2 = freqs.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads; + at::Tensor input_grads; if (transpose_output) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + input_grads = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - input_grads = torch::empty({s, b, h, d}, act_options); + input_grads = at::empty({s, b, h, d}, act_options); } const int o_stride_s = input_grads.stride(0); const int o_stride_b = input_grads.stride(1); @@ -156,8 +156,8 @@ torch::Tensor bwd_cuda(const torch::Tensor& output_grads, const torch::Tensor& f TORCH_CHECK(false, #NAME, " not supported for '", toString(TYPE1), "' with '", toString(TYPE2), "'"); \ } -torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output) { +at::Tensor fwd_cached_cuda(const at::Tensor& input, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { // input sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -178,11 +178,11 @@ torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& c // output auto act_options = input.options().requires_grad(false); - torch::Tensor output; + at::Tensor output; if (transpose_output) { - output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + output = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - output = torch::empty({s, b, h, d}, act_options); + output = at::empty({s, b, h, d}, act_options); } // output strides const int o_stride_s = output.stride(0); @@ -198,8 +198,8 @@ torch::Tensor fwd_cached_cuda(const torch::Tensor& input, const torch::Tensor& c return output; } -torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos, const torch::Tensor& sin, - const bool transpose_output) { +at::Tensor bwd_cached_cuda(const at::Tensor& output_grads, const at::Tensor& cos, const at::Tensor& sin, + const bool transpose_output) { // output_grads sizes: (s, b, h, d) // s: sequence length // b: batch size @@ -219,11 +219,11 @@ torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Te const int d2 = cos.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads; + at::Tensor input_grads; if (transpose_output) { - input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + input_grads = at::empty({b, s, h, d}, act_options).transpose(0, 1); } else { - input_grads = torch::empty({s, b, h, d}, act_options); + input_grads = at::empty({s, b, h, d}, act_options); } const int o_stride_s = input_grads.stride(0); const int o_stride_b = input_grads.stride(1); @@ -238,7 +238,7 @@ torch::Tensor bwd_cached_cuda(const torch::Tensor& output_grads, const torch::Te return input_grads; } -torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_seqlens, const torch::Tensor& freqs) { +at::Tensor fwd_thd_cuda(const at::Tensor& input, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { // input sizes: (t, h, d) // t: cumulative sum of sequence lengths // h: head num @@ -258,7 +258,7 @@ torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_s // output auto act_options = input.options().requires_grad(false); - auto output = torch::empty({t, h, d}, act_options); + auto output = at::empty({t, h, d}, act_options); // output strides const int o_stride_t = output.stride(0); const int o_stride_h = output.stride(1); @@ -272,8 +272,7 @@ torch::Tensor fwd_thd_cuda(const torch::Tensor& input, const torch::Tensor& cu_s return output; } -torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tensor& cu_seqlens, - const torch::Tensor& freqs) { +at::Tensor bwd_thd_cuda(const at::Tensor& output_grads, const at::Tensor& cu_seqlens, const at::Tensor& freqs) { // output_grads sizes: (t, h, d) // t: cumulative sum of sequence lengths // h: head num @@ -292,7 +291,7 @@ torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tenso const int d2 = freqs.size(3); auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({t, h, d}, act_options); + auto input_grads = at::empty({t, h, d}, act_options); const int o_stride_t = input_grads.stride(0); const int o_stride_h = input_grads.stride(1); const int o_stride_d = input_grads.stride(2); @@ -305,8 +304,8 @@ torch::Tensor bwd_thd_cuda(const torch::Tensor& output_grads, const torch::Tenso return input_grads; } -torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor fwd_2d_cuda(const at::Tensor& input, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { // input sizes: (b, ih, iw, h, d) // b: batch size // ih: image height @@ -327,7 +326,7 @@ torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h // output auto act_options = input.options().requires_grad(false); - auto output = torch::empty({b, ih * iw, h, d}, act_options); + auto output = at::empty({b, ih * iw, h, d}, act_options); // output strides const int o_stride_b = output.stride(0); const int o_stride_s = output.stride(1); @@ -343,8 +342,8 @@ torch::Tensor fwd_2d_cuda(const torch::Tensor& input, const torch::Tensor& cos_h return output; } -torch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor& cos_h, const torch::Tensor& sin_h, - const torch::Tensor& cos_w, const torch::Tensor& sin_w) { +at::Tensor bwd_2d_cuda(const at::Tensor& output_grads, const at::Tensor& cos_h, const at::Tensor& sin_h, + const at::Tensor& cos_w, const at::Tensor& sin_w) { // output_grads sizes: (b, ih, iw, h, d) // b: batch size // ih: image height @@ -364,7 +363,7 @@ torch::Tensor bwd_2d_cuda(const torch::Tensor& output_grads, const torch::Tensor const int stride_d = output_grads.stride(4); auto act_options = output_grads.options().requires_grad(false); - auto input_grads = torch::empty({b, ih * iw, h, d}, act_options); + auto input_grads = at::empty({b, ih * iw, h, d}, act_options); const int o_stride_b = input_grads.stride(0); const int o_stride_s = input_grads.stride(1); const int o_stride_h = input_grads.stride(2); diff --git a/csrc/megatron/fused_weight_gradient_dense.cpp b/csrc/megatron/fused_weight_gradient_dense.cpp index 71c826cfe..633ef068b 100644 --- a/csrc/megatron/fused_weight_gradient_dense.cpp +++ b/csrc/megatron/fused_weight_gradient_dense.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include @@ -7,9 +8,26 @@ void wgrad_gemm_accum_fp32_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_ void wgrad_gemm_accum_fp16_cuda_stub(at::Tensor& input_2d, at::Tensor& d_output_2d, at::Tensor& d_weight); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32", - py::call_guard()); - m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16", - py::call_guard()); +void wgrad_gemm_accum_fp32_dispatch(const at::Tensor& input, const at::Tensor& d_output, const at::Tensor& d_weight) { + at::Tensor input_arg = input; + at::Tensor d_output_arg = d_output; + at::Tensor d_weight_arg = d_weight; + wgrad_gemm_accum_fp32_cuda_stub(input_arg, d_output_arg, d_weight_arg); +} + +void wgrad_gemm_accum_fp16_dispatch(const at::Tensor& input, const at::Tensor& d_output, const at::Tensor& d_weight) { + at::Tensor input_arg = input; + at::Tensor d_output_arg = d_output; + at::Tensor d_weight_arg = d_weight; + wgrad_gemm_accum_fp16_cuda_stub(input_arg, d_output_arg, d_weight_arg); +} + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("fused_weight_gradient_mlp_wgrad_gemm_accum_fp32(Tensor input, Tensor d_output, Tensor(a!) d_weight) -> ()"); + m.def("fused_weight_gradient_mlp_wgrad_gemm_accum_fp16(Tensor input, Tensor d_output, Tensor(a!) d_weight) -> ()"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("fused_weight_gradient_mlp_wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_dispatch); + m.impl("fused_weight_gradient_mlp_wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_dispatch); } diff --git a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu index 09af5c0a9..87810500d 100644 --- a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/csrc/megatron/fused_weight_gradient_dense_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_cuda.cu index 6adc8c7e2..e05f8e937 100644 --- a/csrc/megatron/fused_weight_gradient_dense_cuda.cu +++ b/csrc/megatron/fused_weight_gradient_dense_cuda.cu @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include diff --git a/csrc/megatron/generic_scaled_masked_softmax.cpp b/csrc/megatron/generic_scaled_masked_softmax.cpp index cb7a78809..770989801 100644 --- a/csrc/megatron/generic_scaled_masked_softmax.cpp +++ b/csrc/megatron/generic_scaled_masked_softmax.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ +#include #include -#include +#include #include @@ -23,20 +24,20 @@ namespace multihead_attn { namespace fused_softmax { namespace generic_scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); - return fwd_cuda(input, mask, scale_factor); + return fwd_cuda(input, mask, static_cast(scale_factor)); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -47,17 +48,21 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace generic_scaled_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("generic_scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def( + "generic_scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); +} - m.def("backward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("generic_scaled_masked_softmax_forward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::fwd); + m.impl("generic_scaled_masked_softmax_backward", &multihead_attn::fused_softmax::generic_scaled_masked_softmax::bwd); } diff --git a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu index cc1e55297..0ca0837b0 100644 --- a/csrc/megatron/generic_scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/generic_scaled_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "generic_scaled_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace generic_scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); @@ -43,7 +44,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -58,7 +59,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -69,7 +70,7 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& const int key_seq_len = output_grads.size(3); auto act_options = output_grads.options(); - torch::Tensor input_grad = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor input_grad = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); void* output_grads_ptr = static_cast(output_grads.data_ptr()); diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp index 82963817f..e5bab742f 100644 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ b/csrc/megatron/scaled_masked_softmax.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ +#include #include -#include +#include #include @@ -23,54 +24,67 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads); -torch::Tensor fwd(torch::Tensor& input, torch::Tensor& mask, float scale_factor) { - TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); - TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - TORCH_CHECK(mask.dim() == 4, "expected 4D tensor"); - if (!input.is_contiguous()) input = input.contiguous(); - if (!mask.is_contiguous()) mask = mask.contiguous(); +at::Tensor fwd(at::Tensor const& input, at::Tensor const& mask, double scale_factor) { + auto input_arg = input; + auto mask_arg = mask; + TORCH_CHECK(input_arg.dim() == 4, "expected 4D tensor"); + TORCH_CHECK( + (input_arg.scalar_type() == at::ScalarType::Half) || (input_arg.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + TORCH_CHECK(mask_arg.dim() == 4, "expected 4D tensor"); + if (!input_arg.is_contiguous()) input_arg = input_arg.contiguous(); + if (!mask_arg.is_contiguous()) mask_arg = mask_arg.contiguous(); - return fwd_cuda(input, mask, scale_factor); + return fwd_cuda(input_arg, mask_arg, static_cast(scale_factor)); } -torch::Tensor bwd(torch::Tensor& output_grads, torch::Tensor& softmax_results, float scale_factor) { - TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); - TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { + auto output_grads_arg = output_grads; + auto softmax_results_arg = softmax_results; + TORCH_CHECK(output_grads_arg.dim() == 4, "expected 3D tensor"); + TORCH_CHECK(softmax_results_arg.dim() == 4, "expected 3D tensor"); - TORCH_CHECK( - (output_grads.scalar_type() == at::ScalarType::Half) || (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), + TORCH_CHECK((output_grads_arg.scalar_type() == at::ScalarType::Half) || + (output_grads_arg.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + TORCH_CHECK((softmax_results_arg.scalar_type() == at::ScalarType::Half) || + (softmax_results_arg.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - if (!output_grads.is_contiguous()) output_grads = output_grads.contiguous(); - if (!softmax_results.is_contiguous()) softmax_results = softmax_results.contiguous(); + if (!output_grads_arg.is_contiguous()) output_grads_arg = output_grads_arg.contiguous(); + if (!softmax_results_arg.is_contiguous()) softmax_results_arg = softmax_results_arg.contiguous(); - return bwd_cuda(output_grads, softmax_results, scale_factor); + return bwd_cuda(output_grads_arg, softmax_results_arg, static_cast(scale_factor)); } -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +int64_t get_batch_per_block(int64_t query_seq_len, int64_t key_seq_len, int64_t batches, int64_t attn_heads) { + return get_batch_per_block_cuda(static_cast(query_seq_len), static_cast(key_seq_len), + static_cast(batches), static_cast(attn_heads)); } } // end namespace scaled_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_masked_softmax_forward(Tensor input, Tensor mask, float scale_factor) -> Tensor"); + m.def("scaled_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, float scale_factor) -> Tensor"); + m.def( + "scaled_masked_softmax_get_batch_per_block(int query_seq_len, int key_seq_len, int batches, " + "int attn_heads) -> int"); +} - m.def("backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_masked_softmax_forward", &multihead_attn::fused_softmax::scaled_masked_softmax::fwd); + m.impl("scaled_masked_softmax_backward", &multihead_attn::fused_softmax::scaled_masked_softmax::bwd); +} - m.def("get_batch_per_block", &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size.", py::call_guard()); +TORCH_LIBRARY_IMPL(apex, CompositeExplicitAutograd, m) { + m.impl("scaled_masked_softmax_get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block); } diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu index 879697819..fecc5cd37 100644 --- a/csrc/megatron/scaled_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -33,7 +34,7 @@ int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, in return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); } -torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, at::Tensor const& mask, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int pad_batches = mask.size(0); @@ -49,7 +50,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -64,7 +65,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& mask, fl return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -75,7 +76,7 @@ torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& const int key_seq_len = output_grads.size(3); auto act_options = output_grads.options().requires_grad(false); - torch::Tensor input_grads = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor input_grads = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); void* input_grads_ptr = static_cast(input_grads.data_ptr()); void* output_grads_ptr = static_cast(output_grads.data_ptr()); diff --git a/csrc/megatron/scaled_softmax.cpp b/csrc/megatron/scaled_softmax.cpp index 7e4d44c1b..5a43f8ba1 100644 --- a/csrc/megatron/scaled_softmax.cpp +++ b/csrc/megatron/scaled_softmax.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ +#include #include -#include +#include #include @@ -23,19 +24,19 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, double scale_factor) { TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return fwd_cuda(input, scale_factor); + return fwd_cuda(input, static_cast(scale_factor)); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 4, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 4, "expected 3D tensor"); @@ -46,16 +47,19 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace scaled_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_softmax::fwd, - "Self Multihead Attention scaled, softmax -- Forward.", py::call_guard()); - m.def("backward", &multihead_attn::fused_softmax::scaled_softmax::bwd, - "Self Multihead Attention scaled, softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def("scaled_softmax_backward(Tensor output_grads, Tensor softmax_results, float scale_factor) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_softmax_forward", &multihead_attn::fused_softmax::scaled_softmax::fwd); + m.impl("scaled_softmax_backward", &multihead_attn::fused_softmax::scaled_softmax::bwd); } diff --git a/csrc/megatron/scaled_softmax_cuda.cu b/csrc/megatron/scaled_softmax_cuda.cu index 4ec6e4bdf..93e08de31 100644 --- a/csrc/megatron/scaled_softmax_cuda.cu +++ b/csrc/megatron/scaled_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor) { // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] const int batches = input.size(0); const int attn_heads = input.size(1); @@ -40,7 +41,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + at::Tensor softmax_results = at::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -54,7 +55,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp index 726ea7dbe..6d907d017 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp @@ -14,8 +14,9 @@ * limitations under the License. */ +#include #include -#include +#include #include @@ -23,19 +24,19 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor); +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor); -torch::Tensor bwd_cuda(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor); +at::Tensor bwd_cuda(at::Tensor const& output_grads, at::Tensor const& softmax_results, float scale_factor); -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd(at::Tensor const& input, double scale_factor) { TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return fwd_cuda(input, scale_factor); + return fwd_cuda(input, static_cast(scale_factor)); } -torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softmax_results, float scale_factor) { +at::Tensor bwd(at::Tensor const& output_grads, at::Tensor const& softmax_results, double scale_factor) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(softmax_results.dim() == 3, "expected 3D tensor"); @@ -46,16 +47,23 @@ torch::Tensor bwd(torch::Tensor const& output_grads, torch::Tensor const& softma (softmax_results.scalar_type() == at::ScalarType::BFloat16), "Only fp16 and bf16 are supported"); - return bwd_cuda(output_grads, softmax_results, scale_factor); + return bwd_cuda(output_grads, softmax_results, static_cast(scale_factor)); } } // end namespace scaled_upper_triang_masked_softmax } // end namespace fused_softmax } // end namespace multihead_attn -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward.", py::call_guard()); - m.def("backward", &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward.", py::call_guard()); +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("scaled_upper_triang_masked_softmax_forward(Tensor input, float scale_factor) -> Tensor"); + m.def( + "scaled_upper_triang_masked_softmax_backward(Tensor output_grads, Tensor softmax_results, " + "float scale_factor) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("scaled_upper_triang_masked_softmax_forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd); + m.impl("scaled_upper_triang_masked_softmax_backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd); } diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu index 86813c6a3..251209a87 100644 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu +++ b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu @@ -18,9 +18,10 @@ #include #include #include +#if __has_include() #include +#endif #include -#include #include "scaled_upper_triang_masked_softmax.h" #include "type_shim.h" @@ -29,7 +30,7 @@ namespace multihead_attn { namespace fused_softmax { namespace scaled_upper_triang_masked_softmax { -torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { +at::Tensor fwd_cuda(at::Tensor const& input, float scale_factor) { // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] const int attn_batches = input.size(0); const int seq_len = input.size(1); @@ -37,7 +38,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { // Output auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = torch::empty({attn_batches, seq_len, seq_len}, act_options); + at::Tensor softmax_results = at::empty({attn_batches, seq_len, seq_len}, act_options); // Softmax Intermediate Result Ptr void* input_ptr = static_cast(input.data_ptr()); @@ -51,7 +52,7 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor) { return softmax_results; } -torch::Tensor bwd_cuda(torch::Tensor const& output_grads_, torch::Tensor const& softmax_results_, float scale_factor) { +at::Tensor bwd_cuda(at::Tensor const& output_grads_, at::Tensor const& softmax_results_, float scale_factor) { auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp index 5ee75a405..be88dc993 100644 --- a/csrc/mlp.cpp +++ b/csrc/mlp.cpp @@ -1,6 +1,7 @@ +#include +#include #include -#include -#include +#include #include @@ -106,7 +107,26 @@ std::vector mlp_backward(int use_bias, int activation, at::Tensor gr return outputs; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &mlp_forward, "MLP forward", py::call_guard()); - m.def("backward", &mlp_backward, "MLP backward", py::call_guard()); +namespace { +std::vector apex_mlp_forward(int64_t use_bias, int64_t activation, std::vector inputs) { + return mlp_forward(static_cast(use_bias), static_cast(activation), std::move(inputs)); +} + +std::vector apex_mlp_backward(int64_t use_bias, int64_t activation, at::Tensor grad_o, + std::vector fprop_outputs, std::vector inputs) { + return mlp_backward(static_cast(use_bias), static_cast(activation), grad_o, std::move(fprop_outputs), + std::move(inputs)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("mlp_forward(int use_bias, int activation, Tensor[] inputs) -> Tensor[]"); + m.def( + "mlp_backward(int use_bias, int activation, Tensor grad_o, Tensor[] fprop_outputs, Tensor[] inputs) " + "-> Tensor[]"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("mlp_forward", &apex_mlp_forward); + m.impl("mlp_backward", &apex_mlp_backward); } diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu index 4a870da4d..912fa096a 100644 --- a/csrc/mlp_cuda.cu +++ b/csrc/mlp_cuda.cu @@ -4,7 +4,6 @@ #include #include #include -#include /* Includes, cuda */ #include diff --git a/csrc/syncbn.cpp b/csrc/syncbn.cpp index d91a6ad29..b5dc892c2 100644 --- a/csrc/syncbn.cpp +++ b/csrc/syncbn.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include @@ -68,22 +68,48 @@ at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output, const at::Ten const at::optional z, const at::Tensor mean, const at::Tensor inv_std, const at::optional weight, const at::optional shift); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance", py::call_guard()); - m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance", - py::call_guard()); - m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward", py::call_guard()); - m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad", - py::call_guard()); - m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad", - py::call_guard()); - m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc", - py::call_guard()); - m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc", - py::call_guard()); - m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc", - py::call_guard()); - m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc", - py::call_guard()); - m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last", py::call_guard()); +namespace { +std::vector apex_welford_parallel_CUDA(const at::Tensor mean_feature_nodes, + const at::Tensor var_biased_feature_nodes, const at::Tensor numel, + double eps) { + return welford_parallel_CUDA(mean_feature_nodes, var_biased_feature_nodes, numel, static_cast(eps)); +} +} // namespace + +TORCH_LIBRARY_FRAGMENT(apex, m) { + m.def("syncbn_welford_mean_var(Tensor input) -> Tensor[]"); + m.def( + "syncbn_welford_parallel(Tensor mean_feature_nodes, Tensor var_biased_feature_nodes, Tensor numel, float eps) " + "-> Tensor[]"); + m.def("syncbn_batchnorm_forward(Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, Tensor? shift) -> Tensor"); + m.def("syncbn_reduce_bn(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) -> Tensor[]"); + m.def( + "syncbn_batchnorm_backward(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def("syncbn_welford_mean_var_c_last(Tensor input) -> Tensor[]"); + m.def( + "syncbn_batchnorm_forward_c_last(Tensor input, Tensor? z, Tensor mean, Tensor inv_std, Tensor? weight, " + "Tensor? shift, bool fuse_relu) -> Tensor"); + m.def( + "syncbn_reduce_bn_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, Tensor? weight) " + "-> Tensor[]"); + m.def( + "syncbn_batchnorm_backward_c_last(Tensor grad_output, Tensor input, Tensor mean, Tensor inv_std, " + "Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor"); + m.def( + "syncbn_relu_bw_c_last(Tensor grad_output, Tensor input, Tensor? z, Tensor mean, Tensor inv_std, " + "Tensor? weight, Tensor? shift) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(apex, CUDA, m) { + m.impl("syncbn_welford_mean_var", &welford_mean_var_CUDA); + m.impl("syncbn_welford_parallel", &apex_welford_parallel_CUDA); + m.impl("syncbn_batchnorm_forward", &batchnorm_forward_CUDA); + m.impl("syncbn_reduce_bn", &reduce_bn_CUDA); + m.impl("syncbn_batchnorm_backward", &batchnorm_backward_CUDA); + m.impl("syncbn_welford_mean_var_c_last", &welford_mean_var_c_last_CUDA); + m.impl("syncbn_batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA); + m.impl("syncbn_reduce_bn_c_last", &reduce_bn_c_last_CUDA); + m.impl("syncbn_batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA); + m.impl("syncbn_relu_bw_c_last", &relu_backward_c_last_CUDA); } diff --git a/setup.py b/setup.py index a0e59cf86..b60c00e20 100644 --- a/setup.py +++ b/setup.py @@ -168,10 +168,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int if has_flag("--cpp_ext", "APEX_CPP_EXT"): if "--cpp_ext" in sys.argv: sys.argv.remove("--cpp_ext") - ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"])) + # apex._extensions.apex_C is a pure Python shim to avoid exposing a C++ ABI + # for dense tensor flattening utilities. -_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +bare_metal_version = None +if CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) if has_flag("--distributed_adam", "APEX_DISTRIBUTED_ADAM"): if "--distributed_adam" in sys.argv: @@ -179,12 +182,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--distributed_adam") ext_modules.append( CUDAExtension( - name="distributed_adam_cuda", + name="_distributed_adam_cuda", sources=[ "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp", "apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], @@ -198,12 +202,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--distributed_lamb") ext_modules.append( CUDAExtension( - name="distributed_lamb_cuda", + name="_distributed_lamb_cuda", sources=[ "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp", "apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], @@ -219,7 +224,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="amp_C", + name="_amp_C", sources=[ "csrc/amp_C_frontend.cpp", "csrc/multi_tensor_sgd_kernel.cu", @@ -237,6 +242,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "csrc/multi_tensor_lamb_mp.cu", "csrc/update_scale_hysteresis.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": [ @@ -250,8 +256,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="syncbn", + name="_syncbn", sources=["csrc/syncbn.cpp", "csrc/welford.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -261,8 +268,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_layer_norm_cuda", + name="_fused_layer_norm_cuda", sources=["csrc/layer_norm_cuda.cpp", "csrc/layer_norm_cuda_kernel.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-maxrregcount=50", "-O3", "--use_fast_math"], @@ -272,8 +280,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="mlp_cuda", + name="_mlp_cuda", sources=["csrc/mlp.cpp", "csrc/mlp_cuda.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -282,8 +291,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fused_dense_cuda", + name="_fused_dense_cuda", sources=["csrc/fused_dense.cpp", "csrc/fused_dense_cuda.cu"], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": ["-O3"], @@ -293,11 +303,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_upper_triang_masked_softmax_cuda", + name="_scaled_upper_triang_masked_softmax_cuda", sources=[ "csrc/megatron/scaled_upper_triang_masked_softmax.cpp", "csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -314,11 +325,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="generic_scaled_masked_softmax_cuda", + name="_generic_scaled_masked_softmax_cuda", sources=[ "csrc/megatron/generic_scaled_masked_softmax.cpp", "csrc/megatron/generic_scaled_masked_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -335,11 +347,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_masked_softmax_cuda", + name="_scaled_masked_softmax_cuda", sources=[ "csrc/megatron/scaled_masked_softmax.cpp", "csrc/megatron/scaled_masked_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -356,11 +369,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="scaled_softmax_cuda", + name="_scaled_softmax_cuda", sources=[ "csrc/megatron/scaled_softmax.cpp", "csrc/megatron/scaled_softmax_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -377,11 +391,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_rotary_positional_embedding", + name="_fused_rotary_positional_embedding", sources=[ "csrc/megatron/fused_rotary_positional_embedding.cpp", "csrc/megatron/fused_rotary_positional_embedding_cuda.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -398,13 +413,14 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fused_weight_gradient_mlp_cuda", + name="_fused_weight_gradient_mlp_cuda", include_dirs=[os.path.join(this_dir, "csrc")], sources=[ "csrc/megatron/fused_weight_gradient_dense.cpp", "csrc/megatron/fused_weight_gradient_dense_cuda.cu", "csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"], "nvcc": [ @@ -431,7 +447,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int cc_flag = ["-Xcompiler", "-fPIC", "-shared"] ext_modules.append( CUDAExtension( - name="permutation_search_cuda", + name="_permutation_search_cuda", sources=[ "apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu" ], @@ -446,6 +462,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ], extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"] + cc_flag}, + py_limited_api=True, ) ) @@ -455,7 +472,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--bnp") ext_modules.append( CUDAExtension( - name="bnp", + name="_bnp", sources=[ "apex/contrib/csrc/groupbn/batch_norm.cu", "apex/contrib/csrc/groupbn/ipc.cu", @@ -463,6 +480,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "apex/contrib/csrc/groupbn/batch_norm_add_relu.cu", ], include_dirs=[os.path.join(this_dir, "csrc")], + py_limited_api=True, extra_compile_args={ "cxx": [], "nvcc": [ @@ -485,7 +503,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int print(f"`--xentropy` setting version of {xentropy_ver}") ext_modules.append( CUDAExtension( - name="xentropy_cuda", + name="_xentropy_cuda", sources=[ "apex/contrib/csrc/xentropy/interface.cpp", "apex/contrib/csrc/xentropy/xentropy_kernel.cu", @@ -495,6 +513,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"] + [f'-DXENTROPY_VER="{xentropy_ver}"'], "nvcc": ["-O3"], }, + py_limited_api=True, ) ) @@ -504,7 +523,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--focal_loss") ext_modules.append( CUDAExtension( - name="focal_loss_cuda", + name="_focal_loss_cuda", sources=[ "apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp", "apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu", @@ -514,6 +533,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math", "--ftz=false"], }, + py_limited_api=True, ) ) @@ -524,7 +544,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="group_norm_cuda", + name="_group_norm_cuda", sources=[ "apex/contrib/csrc/group_norm/group_norm_nhwc_op.cpp", ] @@ -538,6 +558,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "--ftz=false", ], }, + py_limited_api=True, ) ) @@ -554,7 +575,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="group_norm_v2_cuda", + name="_group_norm_v2_cuda", sources=[ "apex/contrib/csrc/group_norm_v2/gn.cpp", "apex/contrib/csrc/group_norm_v2/gn_cuda.cu", @@ -574,6 +595,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ] + arch_flags, }, + py_limited_api=True, ) ) @@ -583,7 +605,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--index_mul_2d") ext_modules.append( CUDAExtension( - name="fused_index_mul_2d", + name="_fused_index_mul_2d", sources=[ "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp", "apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu", @@ -593,6 +615,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math", "--ftz=false"], }, + py_limited_api=True, ) ) @@ -602,7 +625,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--deprecated_fused_adam") ext_modules.append( CUDAExtension( - name="fused_adam_cuda", + name="_fused_adam_cuda", sources=[ "apex/contrib/csrc/optimizers/fused_adam_cuda.cpp", "apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu", @@ -612,6 +635,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], }, + py_limited_api=True, ) ) @@ -621,7 +645,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--deprecated_fused_lamb") ext_modules.append( CUDAExtension( - name="fused_lamb_cuda", + name="_fused_lamb_cuda", sources=[ "apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp", "apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu", @@ -632,6 +656,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "cxx": ["-O3"], "nvcc": ["-O3", "--use_fast_math"], }, + py_limited_api=True, ) ) @@ -649,12 +674,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fast_layer_norm", + name="_fast_layer_norm", sources=[ "apex/contrib/csrc/layer_norm/ln_api.cpp", "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu", "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"] + generator_flag, "nvcc": [ @@ -701,7 +727,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ext_modules.append( CUDAExtension( - name="fmhalib", + name="_fmhalib", sources=[ "apex/contrib/csrc/fmha/fmha_api.cpp", "apex/contrib/csrc/fmha/src/fmha_fill.cu", @@ -732,6 +758,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src"), ], + py_limited_api=True, ) ) @@ -752,7 +779,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fast_multihead_attn", + name="_fast_multihead_attn", sources=[ "apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp", "apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu", @@ -783,6 +810,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "apex/contrib/csrc/multihead_attn/cutlass/tools/util/include", ), ], + py_limited_api=True, ) ) @@ -792,11 +820,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--transducer") ext_modules.append( CUDAExtension( - name="transducer_joint_cuda", + name="_transducer_joint_cuda", sources=[ "apex/contrib/csrc/transducer/transducer_joint.cpp", "apex/contrib/csrc/transducer/transducer_joint_kernel.cu", ], + py_limited_api=True, extra_compile_args={ "cxx": ["-O3"] + generator_flag, "nvcc": ["-O3"] + generator_flag, @@ -809,11 +838,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="transducer_loss_cuda", + name="_transducer_loss_cuda", sources=[ "apex/contrib/csrc/transducer/transducer_loss.cpp", "apex/contrib/csrc/transducer/transducer_loss_kernel.cu", ], + py_limited_api=True, include_dirs=[os.path.join(this_dir, "csrc")], extra_compile_args={ "cxx": ["-O3"], @@ -838,12 +868,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="cudnn_gbn_lib", + name="_cudnn_gbn_lib", sources=[ "apex/contrib/csrc/cudnn_gbn/norm_sample.cpp", "apex/contrib/csrc/cudnn_gbn/cudnn_gbn.cpp", ], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3", "-g"] + generator_flag}, ) ) @@ -854,11 +885,12 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--peer_memory") ext_modules.append( CUDAExtension( - name="peer_memory_cuda", + name="_peer_memory_cuda", sources=[ "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", "apex/contrib/csrc/peer_memory/peer_memory.cpp", ], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -870,11 +902,13 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int raise_if_cuda_home_none("--nccl_p2p") ext_modules.append( CUDAExtension( - name="nccl_p2p_cuda", + name="_nccl_p2p_cuda", sources=[ "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp", ], + py_limited_api=True, + libraries=["nccl"], extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -896,9 +930,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fast_bottleneck", + name="_fast_bottleneck", sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) @@ -920,9 +955,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int ) ext_modules.append( CUDAExtension( - name="fused_conv_bias_relu", + name="_fused_conv_bias_relu", sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"], include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], + py_limited_api=True, extra_compile_args={"cxx": ["-O3"] + generator_flag}, ) ) diff --git a/tests/L0/run_optimizers/test_lamb.py b/tests/L0/run_optimizers/test_lamb.py index 3b208e61b..171cf9827 100644 --- a/tests/L0/run_optimizers/test_lamb.py +++ b/tests/L0/run_optimizers/test_lamb.py @@ -39,7 +39,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(RefLAMB, self).__init__(params, defaults) if multi_tensor_applier.available: - import amp_C + from apex._extensions import amp_C self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm # Skip buffer diff --git a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py b/tests/distributed/synced_batchnorm/single_gpu_unit_test.py index 18ea55dcc..e8b8180f7 100644 --- a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py +++ b/tests/distributed/synced_batchnorm/single_gpu_unit_test.py @@ -4,7 +4,7 @@ if True: print("using setup tools") - import syncbn + from apex._extensions import syncbn else: print("using jit") from torch.utils.cpp_extension import load diff --git a/tests/distributed/synced_batchnorm/test_groups.py b/tests/distributed/synced_batchnorm/test_groups.py index e95aa9984..4be20dbb4 100644 --- a/tests/distributed/synced_batchnorm/test_groups.py +++ b/tests/distributed/synced_batchnorm/test_groups.py @@ -1,7 +1,7 @@ import torch import numpy as np import apex -import syncbn +from apex._extensions import syncbn import os import argparse import torch.optim as optim diff --git a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py b/tests/distributed/synced_batchnorm/two_gpu_unit_test.py index 3c97e9ac6..343d7763b 100644 --- a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py +++ b/tests/distributed/synced_batchnorm/two_gpu_unit_test.py @@ -1,7 +1,7 @@ import torch import numpy as np import apex -import syncbn +from apex._extensions import syncbn import os import argparse import torch.optim as optim