14 changes: 7 additions & 7 deletions fbgemm_gpu/codegen/genscript/optimizers.py
@@ -126,7 +126,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
"""
split_post_update = """
if (max_norm > 0.0) {
- CUDA_KERNEL_ASSERT(!(std::is_same<emb_t, uint8_t>::value && !cache_weights)); // not supported for uint8 yet
+ CUDA_KERNEL_ASSERT(!(std::is_same_v<emb_t, uint8_t> && !cache_weights)); // not supported for uint8 yet

// compute weight norm
at::acc_type<cache_t, true> weight_sum_square = 0.0;
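Note on the change above: std::is_same_v is the C++17 variable-template shorthand for std::is_same<...>::value, so the assertion's behavior is unchanged. A minimal sketch of the equivalence (hypothetical helper, not from this codebase):

    #include <cstdint>
    #include <type_traits>

    template <typename emb_t>
    constexpr bool is_uint8_emb() {
        // The two spellings always yield the same compile-time constant.
        static_assert(std::is_same<emb_t, std::uint8_t>::value ==
                      std::is_same_v<emb_t, std::uint8_t>);
        return std::is_same_v<emb_t, std::uint8_t>;
    }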
@@ -932,9 +932,9 @@ def lamb() -> Dict[str, Any]:
at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
float2 qparams;
- if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
+ if constexpr (std::is_same_v<emb_t, uint8_t>) { if (!cache_weights) {
qparams = weight_row.load_qparams();
- }
+ } }
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
@@ -1038,9 +1038,9 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
at::acc_type<cache_t, true> rtw_sum_sq = 0.0;
auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
float2 qparams;
- if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
+ if constexpr (std::is_same_v<emb_t, uint8_t>) { if (!cache_weights) {
qparams = weight_row.load_qparams();
- }
+ } }
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
@@ -1378,9 +1378,9 @@ def lars_sgd() -> Dict[str, Any]:

auto weight_row = WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(weights, cache_weights, D);
float2 qparams;
- if (std::is_same<emb_t, uint8_t>::value && !cache_weights) {
+ if constexpr (std::is_same_v<emb_t, uint8_t>) { if (!cache_weights) {
qparams = weight_row.load_qparams();
- }
+ } }
"""
split_precomputation += generate_optimized_grad_sum_loop_access(
"""
@@ -251,14 +251,14 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
using float16 = uint16_t;
using bfloat16 = uint16_t;
using int8 = uint8_t;
- using base_fbgemm_out_t = typename std::conditional<
+ using base_fbgemm_out_t = std::conditional_t<
std::is_same_v<output_t, at::Half>,
float16,
- std::conditional<std::is_same_v<output_t, at::BFloat16>, bfloat16, std::conditional<std::is_same_v<output_t, float>, float, int8>::type> ::type >::type;
- using other_fbgemm_out_t = typename std::conditional<
+ std::conditional_t<std::is_same_v<output_t, at::BFloat16>, bfloat16, std::conditional_t<std::is_same_v<output_t, float>, float, int8>>>;
+ using other_fbgemm_out_t = std::conditional_t<
std::is_same_v<output_t, at::Half>,
float16,
- std::conditional<std::is_same_v<output_t, at::BFloat16>, bfloat16, float>::type> ::type;
+ std::conditional_t<std::is_same_v<output_t, at::BFloat16>, bfloat16, float>>;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
const auto* indices_acc = indices.const_data_ptr<index_t>();
const auto* offsets_acc = offsets.const_data_ptr<index_t>();
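For reference on the hunk above: std::conditional_t<C, A, B> is simply an alias for typename std::conditional<C, A, B>::type, so nesting the alias removes the trailing ::type tokens that made the old spelling hard to read. A reduced example of the same selection logic, with plain stand-in types instead of at::Half / at::BFloat16 (illustrative only):

    #include <cstdint>
    #include <type_traits>

    using float16 = std::uint16_t;  // storage type, as in the kernel
    using int8 = std::uint8_t;

    template <typename output_t>
    using base_out_t = std::conditional_t<
        std::is_same_v<output_t, short>,   // stand-in for at::Half
        float16,
        std::conditional_t<std::is_same_v<output_t, float>, float, int8>>;

    static_assert(std::is_same_v<base_out_t<float>, float>);
    static_assert(std::is_same_v<base_out_t<double>, int8>);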
@@ -320,7 +320,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
const index_t* offsets_begin_ptr = offsets_acc + t * B;

bool success = true;
- const bool has_weight = {{ "true" if weighted else "false" }};
+ constexpr bool has_weight = {{ "true" if weighted else "false" }};
const bool normalize_by_lengths = static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN;

const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
@@ -139,13 +139,13 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) {
uint32_t input_rows_in_flight = min(static_cast<uint32_t>(InputRowsInFlight), max_Ls - L_start);

- typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow];
+ using AllBuffers = uint4[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow];
__shared__ AllBuffers buffers;

{% if weighted %}
// In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for
// packed bags.
- typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1];
+ using AllIndiceWeights = float[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1];
__shared__ AllIndiceWeights buffers_indice_weights;
{% endif %}
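The typedef-to-using changes in this kernel are purely syntactic: with an array type, the alias-declaration keeps the alias name on the left and the full type (element type plus extents) on the right, instead of burying the name inside the declarator. A small sketch (dimensions made up for illustration):

    // Old spelling: the alias name sits inside the array declarator.
    typedef float OldBuffers[4][8][2];

    // New spelling: name = type, read left to right.
    using NewBuffers = float[4][8][2];

    static_assert(sizeof(OldBuffers) == sizeof(NewBuffers), "same type, different spelling");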

@@ -173,7 +173,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
const unsigned int shfl_sync_mask = 0xffffffffu;
#endif
constexpr int VEC_WIDTH = 4;
- constexpr auto kIsInt8 = std::is_same<emb_t, uint8_t>::value;
+ constexpr auto kIsInt8 = std::is_same_v<emb_t, uint8_t>;
int32_t T = weights_offsets.size(0);
const int32_t num_long_runs = num_long_run_ids[0];
const auto warp_id = threadIdx.y;
@@ -169,7 +169,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
const unsigned int shfl_sync_mask = 0xffffffffu;
#endif
constexpr int VEC_WIDTH = 4;
- constexpr auto kIsInt8 = std::is_same<emb_t, uint8_t>::value;
+ constexpr auto kIsInt8 = std::is_same_v<emb_t, uint8_t>;

struct SharedMemory<Vec4TAcc<cache_t>> smem;
const int32_t grad_sum_stride = max_D / VEC_WIDTH;
@@ -973,13 +973,13 @@ Tensor {{ embedding_cuda_op }}(

{%- if not dense and optimizer != "none" %}
at::PhiloxCudaState rng_engine_inputs;
- if (stochastic_rounding && !std::is_same<emb_t, float>::value) {
+ if constexpr (!std::is_same_v<emb_t, float>) { if (stochastic_rounding) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =
at::check_generator<at::CUDAGeneratorImpl>(gen)
->philox_cuda_state(4);
- }
+ } }
{%- endif %}

DISPATCH_OPTIMAL_KERNEL(max_D, [&] {
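This hunk and the optimizer templates earlier in the diff (lamb, partial_rowwise_lamb, lars_sgd) apply the same transformation: a condition that mixed a compile-time type check with a runtime flag is split so the type check becomes if constexpr and only the flag stays a runtime branch, which means the guarded body is discarded entirely for the excluded type (float here, non-uint8 in the optimizers) while the default-initialized state is kept, matching the old runtime behavior. A reduced, CUDA-free sketch of the control flow (placeholder names, not the real API):

    #include <type_traits>

    struct RngState { int counter = 0; };  // stands in for at::PhiloxCudaState

    template <typename emb_t>
    RngState maybe_init_rng(bool stochastic_rounding) {
        RngState rng{};  // stays default-constructed when the branch is dropped
        if constexpr (!std::is_same_v<emb_t, float>) {
            if (stochastic_rounding) {
                rng.counter = 4;  // stands in for philox_cuda_state(4)
            }
        }
        return rng;
    }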
@@ -1041,7 +1041,7 @@ Tensor {{ embedding_cuda_op }}(
// A temp buffer to accumulate gradients with atomics.
auto temp_grad_accum = at::zeros(
{use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D},
- aligned_grad_output.options().dtype(std::is_same<cache_t, double>::value ? at::kDouble : at::kFloat));
+ aligned_grad_output.options().dtype(std::is_same_v<cache_t, double> ? at::kDouble : at::kFloat));

DISPATCH_PLACEHOLDER_TYPES(
{%- for ph_name in args.placeholder_tensor_names %}
@@ -101,10 +101,10 @@ void split_embedding_forward_cpu_kernel(

bool success = true;
if (use_fbgemm) {
- using fbgemm_weight_t = typename std::conditional<
+ using fbgemm_weight_t = std::conditional_t<
std::is_same_v<weights_t, at::Half>,
fbgemm::float16,
- weights_t>::type;
+ weights_t>;
auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
fbgemm_weight_t,
/*IndexType=*/index_t,
@@ -220,10 +220,10 @@ Tensor split_embedding_codegen_forward_cpu(

FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
weights.scalar_type(), "split_embedding_cpu_forward_2", [&] {
- using ind_weights_t = std::conditional<
+ using ind_weights_t = std::conditional_t<
std::is_same_v<scalar_t, double>,
double,
- float>::type;
+ float>;

AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(), "split_embedding_cpu_forward_3", [&] {
@@ -427,7 +427,7 @@ void csr2csc_template_(
csr_offsets[table_to_feature_offset[0] * B];

using pair_t = std::pair<int, scalar_t>;
- using value_t = typename std::conditional<IS_VALUE_PAIR, pair_t, int>::type;
+ using value_t = std::conditional_t<IS_VALUE_PAIR, pair_t, int>;

csc.column_segment_ids = fbgemm::makeAlignedUniquePtr<int>(64, nnz);
auto tmpBufKeys = fbgemm::makeAlignedUniquePtr<int>(64, NS);
@@ -365,7 +365,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
reinterpret_cast<uintptr_t>(&lxu_cache_weights[cache_idx * max_D_cache]) :
reinterpret_cast<uintptr_t>(&weights[indices[l] * load_D]);
}
- if constexpr (!std::is_same<emb_t, cache_t>::value) {
+ if constexpr (!std::is_same_v<emb_t, cache_t>) {
cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing);
}
}
@@ -592,7 +592,7 @@ __noinline__ __device__ void process_all_indices_large_Ls(
reinterpret_cast<uintptr_t>(&lxu_cache_weights[cache_idx * max_D_cache]) :
reinterpret_cast<uintptr_t>(&weights[indices[l] * load_D]);
}
- if constexpr (!std::is_same<emb_t, cache_t>::value) {
+ if constexpr (!std::is_same_v<emb_t, cache_t>) {
cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing);
// Shift cache_look_up_bits based on group_id
cache_look_up_bits >>= static_cast<uint32_t>(threadIdx.x / LOAD_GROUP_SIZE);
@@ -903,7 +903,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
output_vec_t, \
look_up_bits_t, \
USE_CACHE, \
- USE_CACHE && !std::is_same<emb_t, cache_t>::value, \
+ USE_CACHE && !std::is_same_v<emb_t, cache_t>, \
NUM_PARAMS * NUM_WARPS, \
STEP, \
STEP_MASK, \
@@ -608,7 +608,7 @@ batch_index_select_dim0_codegen_forward_cuda(

{%- if has_experimental %}
const bool is_experimental_ = (
- is_experimental && !(std::is_same<emb_t, uint8_t>() || std::is_same<output_t, uint8_t>())
+ is_experimental && !(std::is_same_v<emb_t, uint8_t> || std::is_same_v<output_t, uint8_t>)
);
// if max_D > {{ legacy_max_embedding_dim }}, use TBE v2
if (!is_experimental_ && max_D <= {{ legacy_max_embedding_dim }}) {
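A note on the old expression in the hunk above: std::is_same<emb_t, uint8_t>() value-initializes the trait and relies on std::integral_constant's conversion to bool, whereas is_same_v names the constant directly; the two are equivalent. A minimal check of that equivalence (generic helper, not part of this codebase):

    #include <cstdint>
    #include <type_traits>

    template <typename emb_t, typename output_t>
    constexpr bool any_uint8() {
        // Old style: trait objects converted to bool via integral_constant.
        constexpr bool old_style =
            std::is_same<emb_t, std::uint8_t>() || std::is_same<output_t, std::uint8_t>();
        // New style: variable templates, no temporary trait object.
        constexpr bool new_style =
            std::is_same_v<emb_t, std::uint8_t> || std::is_same_v<output_t, std::uint8_t>;
        static_assert(old_style == new_style);
        return new_style;
    }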
@@ -155,7 +155,7 @@ void split_embedding_{{ optimizer }}_update(
TORCH_CHECK(!(std::is_same_v<emb_t, uint8_t>));

at::PhiloxCudaState rng_engine_inputs;
- if (stochastic_rounding && !std::is_same<emb_t, float>::value) {
+ if (stochastic_rounding && !std::is_same_v<emb_t, float>) {
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
std::lock_guard<std::mutex> lock(gen.mutex());
rng_engine_inputs =