@@ -1368,7 +1368,7 @@ class GPUTilingDedup {
13681368 }
13691369
13701370 /** Generate Halide GPU schedules. */
1371- void apply (AutoSchedule &sched) {
1371+ void apply (AutoSchedule &sched, const Expr &parallelism) {
13721372 if (!ordering.empty () && !is_initial_order) {
13731373 std::set<std::string> var_list;
13741374 for (const auto &v : ordering) {
@@ -1396,7 +1396,7 @@ class GPUTilingDedup {
13961396 }
13971397
13981398 GPUTileHelper helper{f, stage_num};
1399- Expr threads_budget = max_n_threads;
1399+ Expr threads_budget = min (parallelism, max_n_threads) ;
14001400
14011401 // Maximize GPU thread occupancy with the grid-stride loop.
14021402 //
@@ -1423,22 +1423,16 @@ class GPUTilingDedup {
14231423
14241424 const auto &[var, entry] = *iter;
14251425
1426- const bool should_unroll = can_prove (entry.factor <= 1 );
1427- if (should_unroll) {
1428- // Skip thread size of 1.
1429- continue ;
1430- }
1431-
14321426 split_info new_entry{entry};
1433- new_entry.factor = 1 ;
1427+ new_entry.factor = simplify (min (threads_budget, entry.factor));
14341428
14351429 const bool can_split = helper.try_split (new_entry);
14361430 if (!can_split) {
14371431 // If more than 3 gpu_blocks are defined, mark the current loop as the for-loop.
14381432 parallelize.erase (iter);
14391433 continue ;
14401434 }
1441- threads_budget = simplify (max (threads_budget / entry.factor, 1));
1435+ threads_budget = simplify (max (threads_budget / new_entry.factor, 1));
14421436 }
14431437
14441438 helper.commit (sched, is_compute_at);
@@ -2210,7 +2204,7 @@ Partitioner::find_best_tile_config(const Group &g) {
22102204 Group no_tile = g;
22112205 no_tile.tile_sizes = no_tile_config;
22122206
2213- bool show_analysis = false ;
2207+ constexpr bool show_analysis = false ;
22142208 GroupAnalysis no_tile_analysis = analyze_group (no_tile, show_analysis);
22152209
22162210 GroupAnalysis best_analysis = no_tile_analysis;
@@ -2233,7 +2227,7 @@ Partitioner::find_best_tile_config(const Group &g) {
22332227 Expr benefit = estimate_benefit (best_analysis, new_analysis,
22342228 no_redundant_work, true );
22352229
2236- if (show_analysis) {
2230+ if constexpr (show_analysis) {
22372231 debug (0 ) << " Benefit relative to not tiling:" << benefit << " \n " ;
22382232 debug (0 ) << " Best analysis:" << new_analysis;
22392233 debug (0 ) << " No tile analysis:" << no_tile_analysis;
@@ -3439,7 +3433,8 @@ void Partitioner::generate_group_cpu_schedule(
34393433 }
34403434 }
34413435 if (arch_params.is_gpu_schedule ) {
3442- auto parallelized_split = gpu_tiling.can_parallelize (v, iter->second );
3436+ const Expr gpu_threads = simplify (min (iter->second , arch_params.parallelism / def_par));
3437+ auto parallelized_split = gpu_tiling.can_parallelize (v, gpu_threads);
34433438 if (parallelized_split) {
34443439 auto split_vars = *parallelized_split;
34453440 inner_dims.emplace_back (split_vars.inner );
@@ -3463,7 +3458,7 @@ void Partitioner::generate_group_cpu_schedule(
34633458 }
34643459
34653460 if (arch_params.is_gpu_schedule ) {
3466- gpu_tiling.apply (sched);
3461+ gpu_tiling.apply (sched, arch_params.parallelism);
34673462 }
34683463
34693464 // Find the level at which group members will be computed.
@@ -3552,7 +3547,7 @@ void Partitioner::generate_group_cpu_schedule(
35523547 mem_rvars, mem_estimates, sched, gpu_tiling2);
35533548
35543549 if (arch_params.is_gpu_schedule ) {
3555- gpu_tiling2.apply (sched);
3550+ gpu_tiling2.apply (sched, arch_params.parallelism);
35563551 }
35573552 }
35583553}
0 commit comments