From 63d1806fd7823973698a3c5ceb89200021e5745c Mon Sep 17 00:00:00 2001
From: cyy
Date: Wed, 8 Apr 2026 22:55:34 +0800
Subject: [PATCH] Use std::is_same_v, std::conditional_t, and using aliases in codegen

---
 fbgemm_gpu/codegen/genscript/optimizers.py         | 14 +++++++-------
 .../embedding_forward_quantized_cpu_template.cpp   | 10 +++++-----
 ...forward_quantized_split_nbit_kernel_template.cu |  4 ++--
 ...embedding_backward_split_kernel_cta_template.cu |  2 +-
 ...mbedding_backward_split_kernel_warp_template.cu |  2 +-
 .../backward/embedding_backward_split_template.cu  |  6 +++---
 .../forward/embedding_forward_split_cpu.cpp        | 10 +++++-----
 .../embedding_forward_split_kernel_v2_template.cu  |  6 +++---
 .../forward/embedding_forward_split_template.cu    |  2 +-
 .../embedding_optimizer_split_template.cu          |  2 +-
 10 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/fbgemm_gpu/codegen/genscript/optimizers.py b/fbgemm_gpu/codegen/genscript/optimizers.py
index b63cf84d9f..0a3f370d38 100644
--- a/fbgemm_gpu/codegen/genscript/optimizers.py
+++ b/fbgemm_gpu/codegen/genscript/optimizers.py
@@ -126,7 +126,7 @@ def rowwise_adagrad() -> Dict[str, Any]:
     """
     split_post_update = """
     if (max_norm > 0.0) {
-        CUDA_KERNEL_ASSERT(!(std::is_same::value && !cache_weights)); // not supported for uint8 yet
+        CUDA_KERNEL_ASSERT(!(std::is_same_v && !cache_weights)); // not supported for uint8 yet
         // compute weight norm
         at::acc_type weight_sum_square = 0.0;
@@ -932,9 +932,9 @@ def lamb() -> Dict[str, Any]:
     at::acc_type rtw_sum_sq = 0.0;
     auto weight_row = WeightRow>(weights, cache_weights, D);
     float2 qparams;
-    if (std::is_same::value && !cache_weights) {
+    if constexpr (std::is_same_v) { if (!cache_weights) {
         qparams = weight_row.load_qparams();
-    }
+    } }
     """
     split_precomputation += generate_optimized_grad_sum_loop_access(
         """
@@ -1038,9 +1038,9 @@ def partial_rowwise_lamb() -> Dict[str, Any]:
     at::acc_type rtw_sum_sq = 0.0;
     auto weight_row = WeightRow>(weights, cache_weights, D);
     float2 qparams;
-    if (std::is_same::value && !cache_weights) {
+    if constexpr (std::is_same_v) { if (!cache_weights) {
         qparams = weight_row.load_qparams();
-    }
+    } }
     """
     split_precomputation += generate_optimized_grad_sum_loop_access(
         """
@@ -1378,9 +1378,9 @@ def lars_sgd() -> Dict[str, Any]:
     auto weight_row = WeightRow>(weights, cache_weights, D);
     float2 qparams;
-    if (std::is_same::value && !cache_weights) {
+    if constexpr (std::is_same_v) { if (!cache_weights) {
         qparams = weight_row.load_qparams();
-    }
+    } }
     """
     split_precomputation += generate_optimized_grad_sum_loop_access(
         """
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
index e2c3c76245..43a7e51642 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp
@@ -251,14 +251,14 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
     using float16 = uint16_t;
     using bfloat16 = uint16_t;
     using int8 = uint8_t;
-    using base_fbgemm_out_t = typename std::conditional<
+    using base_fbgemm_out_t = std::conditional_t<
         std::is_same_v, float16,
-        std::conditional, bfloat16, std::conditional, float, int8>::type> ::type >::type;
+        std::conditional_t, bfloat16, std::conditional_t, float, int8>>>;
-    using other_fbgemm_out_t = typename std::conditional<
+    using other_fbgemm_out_t = std::conditional_t<
         std::is_same_v, float16,
-        std::conditional, bfloat16, float>::type> ::type;
+        std::conditional_t, bfloat16, float>>;
     AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_", [&] {
         const auto* indices_acc = indices.const_data_ptr();
         const auto* offsets_acc = offsets.const_data_ptr();
@@ -320,7 +320,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
       const index_t* offsets_begin_ptr = offsets_acc + t * B;
       bool success = true;
-      const bool has_weight = {{ "true" if weighted else "false" }};
+      constexpr bool has_weight = {{ "true" if weighted else "false" }};
       const bool normalize_by_lengths = static_cast(pooling_mode) == PoolingMode::MEAN;
       const index_t index_size = offsets_acc[(t + 1) * B] - *offsets_begin_ptr;
diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
index d0a37e93e5..7172236b19 100644
--- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
+++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu
@@ -139,13 +139,13 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no
   for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) {
     uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start);
-    typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow];
+    using AllBuffers = uint4[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow];
     __shared__ AllBuffers buffers;
     {% if weighted %}
     // In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for
     // packed bags.
-    typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1];
+    using AllIndiceWeights = float[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1];
     __shared__ AllIndiceWeights buffers_indice_weights;
     {% endif %}
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
index ff8a0b4fff..49505a3fba 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu
@@ -173,7 +173,7 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
   const unsigned int shfl_sync_mask = 0xffffffffu;
 #endif
   constexpr int VEC_WIDTH = 4;
-  constexpr auto kIsInt8 = std::is_same::value;
+  constexpr auto kIsInt8 = std::is_same_v;
   int32_t T = weights_offsets.size(0);
   const int32_t num_long_runs = num_long_run_ids[0];
   const auto warp_id = threadIdx.y;
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
index 607f2030d6..92a0f9712b 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu
@@ -169,7 +169,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
   const unsigned int shfl_sync_mask = 0xffffffffu;
 #endif
   constexpr int VEC_WIDTH = 4;
-  constexpr auto kIsInt8 = std::is_same::value;
+  constexpr auto kIsInt8 = std::is_same_v;
   struct SharedMemory> smem;
   const int32_t grad_sum_stride = max_D / VEC_WIDTH;
diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index 7d850a5eba..7704f5b6b4 100755
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -973,13 +973,13 @@ Tensor {{ embedding_cuda_op }}(
     {%- if not dense and optimizer != "none" %}
     at::PhiloxCudaState rng_engine_inputs;
-    if (stochastic_rounding && !std::is_same::value) {
+    if constexpr (!std::is_same_v) { if (stochastic_rounding) {
         auto gen = at::cuda::detail::getDefaultCUDAGenerator();
         std::lock_guard lock(gen.mutex());
         rng_engine_inputs = at::check_generator(gen)
            ->philox_cuda_state(4);
-    }
+    } }
     {%- endif %}
     DISPATCH_OPTIMAL_KERNEL(max_D, [&] {
@@ -1041,7 +1041,7 @@ Tensor {{ embedding_cuda_op }}(
     // A temp buffer to accumulate gradients with atomics.
     auto temp_grad_accum = at::zeros(
         {use_deterministic_algorithms ? 0 : grad_accum_counter.numel(), max_D},
-        aligned_grad_output.options().dtype(std::is_same::value ? at::kDouble : at::kFloat));
+        aligned_grad_output.options().dtype(std::is_same_v ? at::kDouble : at::kFloat));
     DISPATCH_PLACEHOLDER_TYPES(
       {%- for ph_name in args.placeholder_tensor_names %}
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
index 11e71e9c8f..22d8d00e82 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp
@@ -101,10 +101,10 @@ void split_embedding_forward_cpu_kernel(
   bool success = true;
   if (use_fbgemm) {
-    using fbgemm_weight_t = typename std::conditional<
+    using fbgemm_weight_t = std::conditional_t<
         std::is_same_v,
         fbgemm::float16,
-        weights_t>::type;
+        weights_t>;
     auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
         fbgemm_weight_t,
         /*IndexType=*/index_t,
@@ -220,10 +220,10 @@ Tensor split_embedding_codegen_forward_cpu(
   FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
       weights.scalar_type(), "split_embedding_cpu_forward_2", [&] {
-        using ind_weights_t = std::conditional<
+        using ind_weights_t = std::conditional_t<
            std::is_same_v,
            double,
-            float>::type;
+            float>;
        AT_DISPATCH_INDEX_TYPES(
            offsets.scalar_type(), "split_embedding_cpu_forward_3", [&] {
@@ -427,7 +427,7 @@ void csr2csc_template_(
       csr_offsets[table_to_feature_offset[0] * B];
   using pair_t = std::pair;
-  using value_t = typename std::conditional::type;
+  using value_t = std::conditional_t;
   csc.column_segment_ids = fbgemm::makeAlignedUniquePtr(64, nnz);
   auto tmpBufKeys = fbgemm::makeAlignedUniquePtr(64, NS);
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
index bde1ef527f..3a742068ce 100644
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu
@@ -365,7 +365,7 @@ __noinline__ __device__ void process_all_indices_small_Ls(
           reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) :
           reinterpret_cast(&weights[indices[l] * load_D]);
     }
-    if constexpr (!std::is_same::value) {
+    if constexpr (!std::is_same_v) {
       cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing);
     }
   }
@@ -592,7 +592,7 @@ __noinline__ __device__ void process_all_indices_large_Ls(
          reinterpret_cast(&lxu_cache_weights[cache_idx * max_D_cache]) :
          reinterpret_cast(&weights[indices[l] * load_D]);
     }
-    if constexpr (!std::is_same::value) {
+    if constexpr (!std::is_same_v) {
      cache_look_up_bits = ballot_sync(cache_idx != kCacheLocationMissing);
      // Shift cache_look_up_bits based on group_id
      cache_look_up_bits >>= static_cast(threadIdx.x / LOAD_GROUP_SIZE);
@@ -903,7 +903,7 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
       output_vec_t, \
       look_up_bits_t, \
       USE_CACHE, \
-      USE_CACHE && !std::is_same::value, \
+      USE_CACHE && !std::is_same_v, \
       NUM_PARAMS * NUM_WARPS, \
       STEP, \
       STEP_MASK, \
diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
index 0d64cc8d6d..9532b01b78 100755
--- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu
@@ -608,7 +608,7 @@ batch_index_select_dim0_codegen_forward_cuda(
   {%- if has_experimental %}
   const bool is_experimental_ = (
-      is_experimental && !(std::is_same() || std::is_same())
+      is_experimental && !(std::is_same_v || std::is_same_v)
   );
   // if max_D > {{ legacy_max_embedding_dim }}, use TBE v2
   if (!is_experimental_ && max_D <= {{ legacy_max_embedding_dim }}) {
diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu
index 04f5d0437f..8a3a81b79e 100644
--- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu
+++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_template.cu
@@ -155,7 +155,7 @@ void split_embedding_{{ optimizer }}_update(
   TORCH_CHECK(!(std::is_same_v));
   at::PhiloxCudaState rng_engine_inputs;
-  if (stochastic_rounding && !std::is_same::value) {
+  if (stochastic_rounding && !std::is_same_v) {
     auto gen = at::cuda::detail::getDefaultCUDAGenerator();
     std::lock_guard lock(gen.mutex());
    rng_engine_inputs =
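
For reference, the idioms this patch switches to boil down to the following standalone sketch; it is not code from the FBGEMM tree, and the type names (emb_t, OldBuffer, maybe_load_qparams, ...) are hypothetical stand-ins chosen only for illustration.

    // Standalone sketch of the replacements applied throughout the patch.
    // The type names below are illustrative only; they are not the FBGEMM types.
    #include <cstdint>
    #include <type_traits>

    using emb_t = std::uint8_t;  // hypothetical stand-in for the codegen'd weight type

    // typedef -> using alias (covers the shared-memory array aliases in the nbit kernel).
    typedef float OldBuffer[4][8];
    using NewBuffer = float[4][8];

    // std::is_same<...>::value -> std::is_same_v<...>
    constexpr bool kIsInt8Old = std::is_same<emb_t, std::uint8_t>::value;
    constexpr bool kIsInt8New = std::is_same_v<emb_t, std::uint8_t>;

    // typename std::conditional<...>::type -> std::conditional_t<...>
    template <typename weights_t>
    struct OldStyle {
      using out_t = typename std::conditional<
          std::is_same<weights_t, std::uint8_t>::value, float, double>::type;
    };
    template <typename weights_t>
    struct NewStyle {
      using out_t = std::conditional_t<std::is_same_v<weights_t, std::uint8_t>, float, double>;
    };

    // Several hunks also hoist the compile-time type test into `if constexpr`,
    // leaving only the runtime flag inside the constexpr branch:
    template <typename T>
    float maybe_load_qparams(bool cache_weights) {
      float qparams = 0.0f;
      if constexpr (std::is_same_v<T, std::uint8_t>) {
        if (!cache_weights) {
          qparams = 1.0f;  // placeholder for weight_row.load_qparams()
        }
      }
      return qparams;
    }

std::conditional_t requires C++14 and std::is_same_v requires C++17; the rewrites are otherwise behavior-preserving shorthands for the trait member accesses and typedefs they replace.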