diff --git a/CMakeLists.txt b/CMakeLists.txt index 816ea8a9e965..1683152eebc5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -218,6 +218,7 @@ Halide_feature(WITH_DOCS "Halide's Doxygen documentation" OFF) Halide_feature(WITH_PACKAGING "Halide's CMake package install rules" TOP_LEVEL) Halide_feature(WITH_PYTHON_BINDINGS "Halide's native Python module (not the whole pip package)" ON DEPENDS Halide_ENABLE_EXCEPTIONS AND Halide_ENABLE_RTTI) +Halide_feature(WITH_COMPILER_PROFILING "Enable internal compiler tracing" OFF) Halide_feature(WITH_SERIALIZATION "Include experimental Serialization/Deserialization code" ON) Halide_feature(WITH_SERIALIZATION_JIT_ROUNDTRIP_TESTING "Intercepting JIT compilation with a serialization roundtrip, for test only" diff --git a/Makefile b/Makefile index b841da33060a..82c518dc1855 100644 --- a/Makefile +++ b/Makefile @@ -483,6 +483,7 @@ SOURCE_FILES = \ CodeGen_WebAssembly.cpp \ CodeGen_WebGPU_Dev.cpp \ CodeGen_X86.cpp \ + CompilerProfiling.cpp \ CompilerLogger.cpp \ ConstantBounds.cpp \ ConstantInterval.cpp \ @@ -677,6 +678,7 @@ HEADER_FILES = \ CodeGen_PyTorch.h \ CodeGen_Targets.h \ CodeGen_WebGPU_Dev.h \ + CompilerProfiling.h \ CompilerLogger.h \ ConciseCasts.h \ CPlusPlusMangle.h \ diff --git a/src/AddAtomicMutex.cpp b/src/AddAtomicMutex.cpp index cf3b0ae8bb89..eb3789b6ef8f 100644 --- a/src/AddAtomicMutex.cpp +++ b/src/AddAtomicMutex.cpp @@ -415,7 +415,7 @@ class AddAtomicMutex : public IRMutator { std::string name = unique_name('t'); index_let = index; index = Variable::make(index.type(), name); - body = ReplaceStoreIndexWithVar(op->producer_name, index).mutate(body); + body = ReplaceStoreIndexWithVar(op->producer_name, index)(body); } // This generates a pointer to the mutex array Expr mutex_array = Variable::make( @@ -454,8 +454,8 @@ Stmt add_atomic_mutex(Stmt s, const std::vector &outputs) { CheckAtomicValidity check; s.accept(&check); if (check.any_atomic) { - s = RemoveUnnecessaryMutexUse().mutate(s); - s = 
AddAtomicMutex(outputs).mutate(s); + s = RemoveUnnecessaryMutexUse()(s); + s = AddAtomicMutex(outputs)(s); } return s; } diff --git a/src/AddImageChecks.cpp b/src/AddImageChecks.cpp index 7114ad360135..0d3863bdaa0d 100644 --- a/src/AddImageChecks.cpp +++ b/src/AddImageChecks.cpp @@ -1,4 +1,5 @@ #include "AddImageChecks.h" +#include "CompilerProfiling.h" #include "ExternFuncArgument.h" #include "Function.h" #include "IRMutator.h" @@ -103,6 +104,7 @@ class TrimStmtToPartsThatAccessBuffers : public IRMutator { bool touches_buffer = false; const map &buffers; +protected: using IRMutator::visit; Expr visit(const Call *op) override { @@ -161,6 +163,7 @@ Stmt add_image_checks_inner(Stmt s, const map &env, const FuncValueBounds &fb, bool will_inject_host_copies) { + ZoneScoped; bool no_bounds_query = t.has_feature(Target::NoBoundsQuery); @@ -185,10 +188,10 @@ Stmt add_image_checks_inner(Stmt s, // Add the input buffer(s) and annotate which output buffers are // used on host. - s.accept(&finder); + finder(s); Scope empty_scope; - Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs).mutate(s); + Stmt sub_stmt = TrimStmtToPartsThatAccessBuffers(bufs)(s); map boxes = boxes_touched(sub_stmt, empty_scope, fb); // Now iterate through all the buffers, creating a list of lets @@ -225,7 +228,7 @@ Stmt add_image_checks_inner(Stmt s, string extent_name = concat_strings(name, ".extent.", i); string stride_name = concat_strings(name, ".stride.", i); replace_with_required[min_name] = Variable::make(Int(32), min_name + ".required"); - replace_with_required[extent_name] = simplify(Variable::make(Int(32), extent_name + ".required")); + replace_with_required[extent_name] = Variable::make(Int(32), extent_name + ".required"); replace_with_required[stride_name] = Variable::make(Int(32), stride_name + ".required"); } } @@ -737,6 +740,7 @@ Stmt add_image_checks(const Stmt &s, // Checks for images go at the marker deposited by computation // bounds inference. 
class Injector : public IRMutator { + protected: using IRMutator::visit; Expr visit(const Variable *op) override { @@ -794,9 +798,10 @@ Stmt add_image_checks(const Stmt &s, bool will_inject_host_copies) : outputs(outputs), t(t), order(order), env(env), fb(fb), will_inject_host_copies(will_inject_host_copies) { } - } injector(outputs, t, order, env, fb, will_inject_host_copies); + }; + Injector injector(outputs, t, order, env, fb, will_inject_host_copies); - return injector.mutate(s); + return injector(s); } } // namespace Internal diff --git a/src/AddParameterChecks.cpp b/src/AddParameterChecks.cpp index e63c4a6395d5..18a7dfb21556 100644 --- a/src/AddParameterChecks.cpp +++ b/src/AddParameterChecks.cpp @@ -1,4 +1,5 @@ #include "AddParameterChecks.h" +#include "CompilerProfiling.h" #include "IROperator.h" #include "IRVisitor.h" #include "Substitute.h" diff --git a/src/AlignLoads.cpp b/src/AlignLoads.cpp index 263c6b4844de..9801411341dd 100644 --- a/src/AlignLoads.cpp +++ b/src/AlignLoads.cpp @@ -165,7 +165,7 @@ class AlignLoads : public IRMutator { } // namespace Stmt align_loads(const Stmt &s, int alignment, int min_bytes_to_align) { - return AlignLoads(alignment, min_bytes_to_align).mutate(s); + return AlignLoads(alignment, min_bytes_to_align)(s); } } // namespace Internal diff --git a/src/AllocationBoundsInference.cpp b/src/AllocationBoundsInference.cpp index a1e0831b975e..5598f286d276 100644 --- a/src/AllocationBoundsInference.cpp +++ b/src/AllocationBoundsInference.cpp @@ -169,8 +169,8 @@ class StripDeclareBoxTouched : public IRMutator { Stmt allocation_bounds_inference(Stmt s, const map &env, const FuncValueBounds &fb) { - s = AllocationInference(env, fb).mutate(s); - s = StripDeclareBoxTouched().mutate(s); + s = AllocationInference(env, fb)(s); + s = StripDeclareBoxTouched()(s); return s; } diff --git a/src/Associativity.cpp b/src/Associativity.cpp index bd67f0245af6..0598e6ce5c6f 100644 --- a/src/Associativity.cpp +++ b/src/Associativity.cpp @@ -339,7 +339,7 
@@ AssociativeOp prove_associativity(const string &f, vector args, vector sema; std::set producers_dropped; @@ -285,6 +286,7 @@ class GenerateProducerBody : public NoOpCollapsingMutator { }; class GenerateConsumerBody : public NoOpCollapsingMutator { +protected: const string &func; vector sema; @@ -342,6 +344,7 @@ class GenerateConsumerBody : public NoOpCollapsingMutator { }; class CloneAcquire : public IRMutator { +protected: using IRMutator::visit; const string &old_name; @@ -390,6 +393,7 @@ class CountConsumeNodes : public IRVisitor { }; class ForkAsyncProducers : public IRMutator { +protected: using IRMutator::visit; const map &env; @@ -414,8 +418,8 @@ class ForkAsyncProducers : public IRMutator { sema_vars.push_back(Variable::make(type_of(), sema_names.back())); } - Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires).mutate(body); - Stmt consumer = GenerateConsumerBody(name, sema_vars).mutate(body); + Stmt producer = GenerateProducerBody(name, sema_vars, cloned_acquires)(body); + Stmt consumer = GenerateConsumerBody(name, sema_vars)(body); // Recurse on both sides producer = mutate(producer); @@ -434,7 +438,7 @@ class ForkAsyncProducers : public IRMutator { // of the producer and consumer. const vector &clones = cloned_acquires[sema_name]; for (const auto &i : clones) { - body = CloneAcquire(sema_name, i).mutate(body); + body = CloneAcquire(sema_name, i)(body); body = LetStmt::make(i, sema_space, body); } @@ -493,6 +497,7 @@ class ForkAsyncProducers : public IRMutator { // simple failure case, error_async_require_fail. One has not been // written for the complex nested case yet.) class InitializeSemaphores : public IRMutator { +protected: using IRMutator::visit; const Type sema_type = type_of(); @@ -558,6 +563,7 @@ class InitializeSemaphores : public IRMutator { // A class to support stmt_uses_vars queries that repeatedly hit the same // sub-stmts. Used to support TightenProducerConsumerNodes below. 
class CachingStmtUsesVars : public IRMutator { +protected: const Scope<> &query; bool found_use = false; std::map cache; @@ -613,6 +619,7 @@ class CachingStmtUsesVars : public IRMutator { // Tighten the scope of consume nodes as much as possible to avoid needless synchronization. class TightenProducerConsumerNodes : public IRMutator { +protected: using IRMutator::visit; Stmt make_producer_consumer(const string &name, bool is_producer, Stmt body, const Scope<> &scope, CachingStmtUsesVars &uses_vars) { @@ -703,6 +710,7 @@ class TightenProducerConsumerNodes : public IRMutator { // Update indices to add ring buffer. class UpdateIndices : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const Provide *op) override { @@ -734,6 +742,7 @@ class UpdateIndices : public IRMutator { // Inject ring buffering. class InjectRingBuffering : public IRMutator { +protected: using IRMutator::visit; struct Loop { @@ -768,7 +777,7 @@ class InjectRingBuffering : public IRMutator { } current_index = current_index % f.schedule().ring_buffer(); // Adds an extra index for to the all of the references of f. - body = UpdateIndices(op->name, current_index).mutate(body); + body = UpdateIndices(op->name, current_index)(body); if (f.schedule().async()) { Expr sema_var = Variable::make(type_of(), f.name() + ".folding_semaphore.ring_buffer"); @@ -816,6 +825,7 @@ class InjectRingBuffering : public IRMutator { // Broaden the scope of acquire nodes to pack trailing work into the // same task and to potentially reduce the nesting depth of tasks. 
class ExpandAcquireNodes : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const Block *op) override { @@ -918,6 +928,7 @@ class ExpandAcquireNodes : public IRMutator { }; class TightenForkNodes : public IRMutator { +protected: using IRMutator::visit; Stmt make_fork(const Stmt &first, const Stmt &rest) { @@ -1005,12 +1016,13 @@ class TightenForkNodes : public IRMutator { } // namespace Stmt fork_async_producers(Stmt s, const map &env) { - s = TightenProducerConsumerNodes(env).mutate(s); - s = InjectRingBuffering(env).mutate(s); - s = ForkAsyncProducers(env).mutate(s); - s = ExpandAcquireNodes().mutate(s); - s = TightenForkNodes().mutate(s); - s = InitializeSemaphores().mutate(s); + ZoneScoped; + s = TightenProducerConsumerNodes(env)(s); + s = InjectRingBuffering(env)(s); + s = ForkAsyncProducers(env)(s); + s = ExpandAcquireNodes()(s); + s = TightenForkNodes()(s); + s = InitializeSemaphores()(s); return s; } diff --git a/src/AutoScheduleUtils.cpp b/src/AutoScheduleUtils.cpp index d2227f831462..5f4578ee484f 100644 --- a/src/AutoScheduleUtils.cpp +++ b/src/AutoScheduleUtils.cpp @@ -53,14 +53,14 @@ Expr substitute_var_estimates(Expr e) { if (!e.defined()) { return e; } - return simplify(SubstituteVarEstimates().mutate(e)); + return simplify(SubstituteVarEstimates()(e)); } Stmt substitute_var_estimates(Stmt s) { if (!s.defined()) { return s; } - return simplify(SubstituteVarEstimates().mutate(s)); + return simplify(SubstituteVarEstimates()(s)); } int string_to_int(const string &s) { diff --git a/src/BoundConstantExtentLoops.cpp b/src/BoundConstantExtentLoops.cpp index c4c4a17eb297..bc76cbeb2738 100644 --- a/src/BoundConstantExtentLoops.cpp +++ b/src/BoundConstantExtentLoops.cpp @@ -12,6 +12,7 @@ namespace Internal { namespace { class BoundLoops : public IRMutator { +protected: using IRMutator::visit; std::vector> lets; @@ -128,7 +129,7 @@ class BoundLoops : public IRMutator { } // namespace Stmt bound_constant_extent_loops(const Stmt &s) { - return 
BoundLoops().mutate(s); + return BoundLoops()(s); } } // namespace Internal diff --git a/src/BoundSmallAllocations.cpp b/src/BoundSmallAllocations.cpp index f3347c0f47fd..89f7c738ea9b 100644 --- a/src/BoundSmallAllocations.cpp +++ b/src/BoundSmallAllocations.cpp @@ -156,7 +156,8 @@ class BoundSmallAllocations : public IRMutator { } // namespace Stmt bound_small_allocations(const Stmt &s) { - return BoundSmallAllocations().mutate(s); + ZoneScoped; + return BoundSmallAllocations()(s); } } // namespace Internal diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 32b4159ea4f7..c5548ee870bc 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -119,6 +119,7 @@ Expr find_constant_bound(const Expr &e, Direction d, const Scope &scop } Interval find_constant_bounds(const Expr &e, const Scope &scope) { + ZoneScoped; Expr expr = bound_correlated_differences(simplify(remove_likelies(e))); Interval interval = bounds_of_expr_in_scope(expr, scope, FuncValueBounds(), true); interval = simplify(interval); @@ -227,7 +228,7 @@ class Bounds : public IRVisitor { #endif // DO_TRACK_BOUNDS_INTERVALS -private: +protected: // Compute the intrinsic bounds of a function. void bounds_of_func(const string &name, int value_index, Type t) { // if we can't get a good bound from the function, fall back to the bounds of the type. 
@@ -1795,7 +1796,7 @@ Interval bounds_of_expr_in_scope_with_indent(const Expr &expr, const Scope vars_depth; @@ -2247,7 +2249,7 @@ class BoxesTouched : public IRGraphVisitor { #endif // DO_TRACK_BOUNDS_INTERVALS -private: +protected: struct VarInstance { string var; int instance; @@ -3000,6 +3002,7 @@ class BoxesTouched : public IRGraphVisitor { map boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool consider_provides, const string &fn, const Scope &scope, const FuncValueBounds &fb) { + ZoneScoped; if (!fn.empty() && s.defined()) { // Filter things down to the relevant sub-Stmts, so we don't spend a // long time reasoning about lets and ifs that don't surround an @@ -3099,7 +3102,7 @@ map boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool // as possible, so that BoxesTouched can prune the variable scope tighter // when encountering the IfThenElse. if (s.defined()) { - s = SolveIfThenElse().mutate(s); + s = SolveIfThenElse()(s); } // Do calls and provides separately, for better simplification. 
@@ -3108,18 +3111,18 @@ map boxes_touched(const Expr &e, Stmt s, bool consider_calls, bool if (consider_calls) { if (e.defined()) { - e.accept(&calls); + calls(e); } if (s.defined()) { - s.accept(&calls); + calls(s); } } if (consider_provides) { if (e.defined()) { - e.accept(&provides); + provides(e); } if (s.defined()) { - s.accept(&provides); + provides(s); } } @@ -3255,6 +3258,7 @@ Interval compute_pure_function_definition_value_bounds( FuncValueBounds compute_function_value_bounds(const vector &order, const map &env) { + ZoneScoped; FuncValueBounds fb; for (const auto &func_name : order) { diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 72f45360b3b5..dda848e3bc7e 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -402,7 +402,7 @@ class BoundsInference : public IRMutator { } select_to_if_then_else; for (auto &e : exprs) { - e.value = select_to_if_then_else.mutate(e.value); + e.value = select_to_if_then_else(e.value); } } @@ -1382,8 +1382,7 @@ Stmt bounds_inference(Stmt s, s = For::make("", 0, 0, ForType::Serial, Partition::Never, DeviceAPI::None, s); s = BoundsInference(funcs, fused_func_groups, fused_pairs_in_groups, - outputs, func_bounds, target) - .mutate(s); + outputs, func_bounds, target)(s); return s.as()->body; } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ee8150c351bd..116f88a1c4ad 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -86,6 +86,7 @@ target_sources( CodeGen_Vulkan_Dev.h CodeGen_WebGPU_Dev.h CompilerLogger.h + CompilerProfiling.h ConciseCasts.h CPlusPlusMangle.h ConstantBounds.h @@ -269,6 +270,7 @@ target_sources( CodeGen_WebGPU_Dev.cpp CodeGen_X86.cpp CompilerLogger.cpp + CompilerProfiling.cpp CPlusPlusMangle.cpp ConstantBounds.cpp ConstantInterval.cpp @@ -515,6 +517,11 @@ target_compile_definitions(Halide PRIVATE WITH_SPIRV) target_compile_definitions(Halide PRIVATE WITH_VULKAN) target_compile_definitions(Halide PRIVATE WITH_WEBGPU) +if (WITH_COMPILER_PROFILING) + 
target_compile_definitions(Halide PRIVATE WITH_COMPILER_PROFILING) +endif() + + ## # Flatbuffers and Serialization dependencies. ## diff --git a/src/CSE.cpp b/src/CSE.cpp index c2a46d93bc4d..b54f2424cc31 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -1,6 +1,7 @@ #include #include "CSE.h" +#include "CompilerProfiling.h" #include "IREquality.h" #include "IRMutator.h" #include "IROperator.h" @@ -186,6 +187,7 @@ class Replacer : public IRGraphMutator { }; class RemoveLets : public IRGraphMutator { +protected: using IRGraphMutator::visit; Scope scope; @@ -218,6 +220,7 @@ class RemoveLets : public IRGraphMutator { }; class CSEEveryExprInStmt : public IRMutator { +protected: bool lift_all; using IRMutator::visit; @@ -260,6 +263,7 @@ class CSEEveryExprInStmt : public IRMutator { } // namespace Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { + ZoneScoped; Expr e = e_in; // Early-out for trivial cases. @@ -269,7 +273,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { debug(4) << "\n\n\nInput to CSE " << e << "\n"; - e = RemoveLets().mutate(e); + e = RemoveLets()(e); debug(4) << "After removing lets: " << e << "\n"; @@ -277,6 +281,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { // the same name as the temporaries we intend to introduce. Find any such // Vars so that we know not to use those names. class UniqueNameProvider : public IRGraphVisitor { + protected: using IRGraphVisitor::visit; const char prefix = 't'; // Annoyingly, this can't be static because this is a local class. 
@@ -303,14 +308,18 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { } while (vars.count(name)); return name; } - } namer; - e.accept(&namer); + }; + UniqueNameProvider namer; + { + ZoneScopedN("UniqueNameProvider"); + e.accept(&namer); + } GVN gvn; - e = gvn.mutate(e); + e = gvn(e); ComputeUseCounts count_uses(gvn, lift_all); - count_uses.include(e); + count_uses(e); debug(4) << "Canonical form without lets " << e << "\n"; @@ -331,7 +340,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { // Rebuild the expr to include references to the variables: Replacer replacer(replacements); - e = replacer.mutate(e); + e = replacer(e); debug(4) << "With variables " << e << "\n"; @@ -340,7 +349,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { // Drop this variable as an acceptable replacement for this expr. replacer.erase(value); // Use containing lets in the value. - e = Let::make(var, replacer.mutate(value), e); + e = Let::make(var, replacer(value), e); } debug(4) << "With lets: " << e << "\n"; @@ -349,7 +358,8 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { } Stmt common_subexpression_elimination(const Stmt &s, bool lift_all) { - return CSEEveryExprInStmt(lift_all).mutate(s); + ZoneScoped; + return CSEEveryExprInStmt(lift_all)(s); } // Testing code. 
@@ -388,8 +398,7 @@ class NormalizeVarNames : public IRMutator { void check(const Expr &in, const Expr &correct) { Expr result = common_subexpression_elimination(in); - NormalizeVarNames n; - result = n.mutate(result); + result = NormalizeVarNames()(result); internal_assert(equal(result, correct)) << "Incorrect CSE:\n" << in diff --git a/src/CanonicalizeGPUVars.cpp b/src/CanonicalizeGPUVars.cpp index 7ca9b7c4fbf5..f5f4d7490f48 100644 --- a/src/CanonicalizeGPUVars.cpp +++ b/src/CanonicalizeGPUVars.cpp @@ -363,10 +363,9 @@ class ValidateGPUSchedule : public IRVisitor { } // anonymous namespace Stmt canonicalize_gpu_vars(Stmt s) { - ValidateGPUSchedule validator; - s.accept(&validator); - CanonicalizeGPUVars canonicalizer; - s = canonicalizer.mutate(s); + ZoneScoped; + ValidateGPUSchedule()(s); + s = CanonicalizeGPUVars()(s); return s; } diff --git a/src/ClampUnsafeAccesses.cpp b/src/ClampUnsafeAccesses.cpp index ed6955446196..b976e55d8d0f 100644 --- a/src/ClampUnsafeAccesses.cpp +++ b/src/ClampUnsafeAccesses.cpp @@ -107,7 +107,7 @@ struct ClampUnsafeAccesses : IRMutator { } // namespace Stmt clamp_unsafe_accesses(const Stmt &s, const std::map &env, FuncValueBounds &func_bounds) { - return ClampUnsafeAccesses(env, func_bounds).mutate(s); + return ClampUnsafeAccesses(env, func_bounds)(s); } } // namespace Halide::Internal diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index d5c3879d36af..0546c4de58df 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -1193,7 +1193,7 @@ void CodeGen_ARM::compile_func(const LoweredFunc &f, // Substitute in strided loads to get vld2/3/4 emission. We don't do it // on Apple silicon, because doing a dense load and then shuffling is // actually faster. - func.body = SubstituteInStridedLoads().mutate(func.body); + func.body = SubstituteInStridedLoads()(func.body); } // Look for opportunities to turn a + (b << c) into umlal/smlal // and a - (b << c) into umlsl/smlsl. 
diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 65892bff2c2c..7c3b179ec1b1 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -189,6 +189,7 @@ class TypeInfoGatherer : public IRGraphVisitor { CodeGen_C::CodeGen_C(ostream &s, const Target &t, OutputKind output_kind, const std::string &guard) : IRPrinter(s), id("$$ BAD ID $$"), target(t), output_kind(output_kind) { + ZoneScoped; if (output_kind == CPlusPlusFunctionInfoHeader) { // If it's a header, emit an include guard. diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index 8445ab2c527c..9d9bc8475845 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -1119,7 +1119,7 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s, }; FindSharedAllocationsAndUniquify fsa; - s = fsa.mutate(s); + s = fsa(s); uint32_t total_shared_bytes = 0; for (const Stmt &sop : fsa.allocs) { diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 0cadc5fe2a4b..90e0a4a48900 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -316,7 +316,7 @@ class SloppyUnpredicateLoadsAndStores : public IRMutator { }; Stmt sloppy_unpredicate_loads_and_stores(const Stmt &s) { - return SloppyUnpredicateLoadsAndStores().mutate(s); + return SloppyUnpredicateLoadsAndStores()(s); } class InjectHVXLocks : public IRMutator { @@ -463,7 +463,7 @@ class InjectHVXLocks : public IRMutator { Stmt inject_hvx_lock_unlock(Stmt body, const Target &target) { InjectHVXLocks i(target); - body = i.mutate(body); + body = i(body); if (i.uses_hvx) { body = acquire_hvx_context(body, target); } diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index b87254c3a1a7..80b7bfd45ec1 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -9,6 +9,7 @@ #include "CodeGen_Posix.h" #include "CodeGen_Targets.h" #include "CompilerLogger.h" +#include "CompilerProfiling.h" #include "Debug.h" #include "Deinterleave.h" #include "EmulateFloat16Math.h" @@ 
-223,6 +224,7 @@ std::unique_ptr CodeGen_LLVM::new_for_target(const Target &target, } void CodeGen_LLVM::initialize_llvm() { + ZoneScoped; static std::once_flag init_llvm_once; std::call_once(init_llvm_once, []() { // You can hack in command-line args to llvm with the @@ -260,6 +262,7 @@ void CodeGen_LLVM::initialize_llvm() { } void CodeGen_LLVM::init_context() { + ZoneScoped; // Ensure our IRBuilder is using the current context. builder = std::make_unique>(*context); @@ -301,6 +304,7 @@ void CodeGen_LLVM::init_context() { } void CodeGen_LLVM::init_module() { + ZoneScoped; init_context(); // Start with a module containing the initial module for this target. @@ -496,6 +500,7 @@ CodeGen_LLVM::ScopedFastMath::~ScopedFastMath() { } std::unique_ptr CodeGen_LLVM::compile(const Module &input) { + ZoneScoped; any_strict_float = input.any_strict_float(); init_codegen(input.name()); @@ -625,6 +630,7 @@ std::unique_ptr CodeGen_LLVM::compile(const Module &input) { } std::unique_ptr CodeGen_LLVM::finish_codegen() { + ZoneScoped; llvm::for_each(*module, set_function_attributes_from_halide_target_options); // Verify the module is ok @@ -704,6 +710,7 @@ void CodeGen_LLVM::end_func(const std::vector &args) { void CodeGen_LLVM::compile_func(const LoweredFunc &f, const std::string &simple_name, const std::string &extern_name) { + ZoneScoped; // Generate the function declaration and argument unpacking code. 
begin_func(f.linkage, simple_name, extern_name, f.args); @@ -1133,6 +1140,7 @@ llvm::Type *CodeGen_LLVM::llvm_type_of(const Type &t) const { } void CodeGen_LLVM::optimize_module() { + ZoneScoped; debug(3) << "Optimizing module\n"; auto time_start = std::chrono::high_resolution_clock::now(); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index bac293b6ef16..785a2eecfd5b 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -70,6 +70,7 @@ class CodeGen_Metal_Dev : public CodeGen_GPU_Dev { public: CodeGen_Metal_C(std::ostream &s, const Target &t) : CodeGen_GPU_C(s, t) { + ZoneScoped; abs_returns_unsigned_type = false; #define alias(x, y) \ @@ -643,6 +644,7 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const FloatImm *op) { void CodeGen_Metal_Dev::add_kernel(Stmt s, const string &name, const vector &args) { + ZoneScoped; debug(2) << "CodeGen_Metal_Dev::compile " << name << "\n"; // We need to scalarize/de-predicate any loads/stores, since Metal does not @@ -676,7 +678,7 @@ struct BufferSize { void CodeGen_Metal_Dev::CodeGen_Metal_C::add_kernel(const Stmt &s, const string &name, const vector &args) { - + ZoneScoped; debug(2) << "Adding Metal kernel " << name << "\n"; // Figure out which arguments should be passed in constant. 
@@ -826,6 +828,7 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::add_kernel(const Stmt &s, } void CodeGen_Metal_Dev::init_module() { + ZoneScoped; debug(2) << "Metal device codegen init_module\n"; // wipe the internal kernel source @@ -866,6 +869,7 @@ void CodeGen_Metal_Dev::init_module() { } vector CodeGen_Metal_Dev::compile_to_src() { + ZoneScoped; string str = src_stream.str(); debug(1) << "Metal kernel:\n" << str << "\n"; @@ -922,6 +926,7 @@ std::string CodeGen_Metal_Dev::print_gpu_name(const std::string &name) { } // namespace std::unique_ptr new_CodeGen_Metal_Dev(const Target &target) { + ZoneScoped; return std::make_unique(target); } diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 807d75444ed4..43d4e60c8a37 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -909,6 +909,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Atomic *op) { void CodeGen_OpenCL_Dev::add_kernel(Stmt s, const string &name, const vector &args) { + ZoneScoped; debug(2) << "CodeGen_OpenCL_Dev::compile " << name << "\n"; // We need to scalarize/de-predicate any loads/stores, since OpenCL does not @@ -943,6 +944,7 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::add_kernel(Stmt s, const string &name, const vector &args) { + ZoneScoped; debug(2) << "Adding OpenCL kernel " << name << "\n"; debug(2) << "Eliminating bool vectors\n"; diff --git a/src/CodeGen_PTX_Dev.cpp b/src/CodeGen_PTX_Dev.cpp index 3deab616ecc5..8c13c0bb6410 100644 --- a/src/CodeGen_PTX_Dev.cpp +++ b/src/CodeGen_PTX_Dev.cpp @@ -130,6 +130,7 @@ Type CodeGen_PTX_Dev::upgrade_type_for_storage(const Type &t) const { void CodeGen_PTX_Dev::add_kernel(Stmt stmt, const std::string &name, const std::vector &args) { + ZoneScoped; internal_assert(module != nullptr); debug(2) << "In CodeGen_PTX_Dev::add_kernel\n"; @@ -220,6 +221,7 @@ void CodeGen_PTX_Dev::add_kernel(Stmt stmt, } void CodeGen_PTX_Dev::init_module() { + ZoneScoped; // This class uses multiple inheritance. 
It's a GPU device code generator, // and also an llvm-based one. Both of these track strict_float presence, // but OffloadGPULoops only sets the GPU device code generator flag, so here @@ -538,7 +540,7 @@ void CodeGen_PTX_Dev::codegen_vector_reduce(const VectorReduce *op, const Expr & Expr b_slice = Shuffle::make_slice(b, i + l * factor, 1, p.factor); i_slice = Call::make(i_slice.type(), p.name, {a_slice, b_slice, i_slice}, Call::PureExtern); } - i_slice = RewriteLoadsAs32Bit().mutate(i_slice); + i_slice = RewriteLoadsAs32Bit()(i_slice); i_slice = simplify(i_slice); i_slice = common_subexpression_elimination(i_slice); result.push_back(i_slice); diff --git a/src/CompilerLogger.cpp b/src/CompilerLogger.cpp index 566fa4b9b926..3bf81728466c 100644 --- a/src/CompilerLogger.cpp +++ b/src/CompilerLogger.cpp @@ -123,7 +123,7 @@ void JSONCompilerLogger::obfuscate() { std::string rule = it.first; for (const auto &e : it.second) { ObfuscateNames obfuscater; - n[rule].emplace_back(obfuscater.mutate(e)); + n[rule].emplace_back(obfuscater(e)); } } matched_simplifier_rules = n; @@ -136,8 +136,8 @@ void JSONCompilerLogger::obfuscate() { // to post-process output from multiple unrelated Generators // and combine Exprs with similar shapes. 
ObfuscateNames obfuscater; - auto failed_to_prove = obfuscater.mutate(it.first); - auto original_expr = obfuscater.mutate(it.second); + auto failed_to_prove = obfuscater(it.first); + auto original_expr = obfuscater(it.second); n.emplace_back(std::move(failed_to_prove), std::move(original_expr)); } failed_to_prove_exprs = n; diff --git a/src/CompilerProfiling.cpp b/src/CompilerProfiling.cpp new file mode 100644 index 000000000000..06af2b7907b8 --- /dev/null +++ b/src/CompilerProfiling.cpp @@ -0,0 +1,193 @@ +#include "CompilerProfiling.h" + +#ifdef WITH_COMPILER_PROFILING + +#include "Util.h" + +#if !defined(_WIN32) +#include +#else +#include +#pragma warning(disable : 4091) +#include +#pragma comment(lib, "dbghelp.lib") +#endif + +#include +#include +#include + +namespace Halide { +namespace Internal { +namespace Profiling { + +static std::string demangle(const char *name) { +#if !defined(_WIN32) + int status = 0; + char *p = abi::__cxa_demangle(name, nullptr, nullptr, &status); + std::string ret(p); + free(p); + return ret; +#else + char demangled_name[8192]; + if (UnDecorateSymbolName(name, demangled_name, sizeof(demangled_name), + UNDNAME_COMPLETE)) { + std::string ret(demangled_name); + return ret; + } else { + DWORD error = GetLastError(); + std::cout << "UnDecorateSymbolName error: " << error << std::endl; + return name; + } +#endif +} + +static std::string cleanup_name(std::string name) { + debug(2) << "Cleaned " << name << " => "; + { + std::string_view profiled = "Halide::Internal::Profiling::Profiled<"; + if (size_t idx = name.find(profiled); idx != std::string::npos) { + size_t cl = idx + 1; + int num_open = 1; + while (num_open != 0 && cl < name.size()) { + if (name[cl] == '<') { + num_open++; + } else if (name[cl] == '>') { + num_open--; + } + cl++; + } + name = name.substr(idx + profiled.size(), cl - idx - profiled.size() - 1); + } + } + name = replace_all(name, "Halide::Internal::", ""); + name = replace_all(name, "(anonymous namespace)::", ""); + 
debug(2) << name << "\n"; + return name; +} + +Context::Context() { + if (std::string file = get_env_variable("HL_COMPILER_TRACE_FILE"); !file.empty()) { + if (std::string bits = get_env_variable("HL_COMPILER_TRACE_BITS"); !bits.empty()) { + active_bits = BIT_GENERIC; + std::vector sp = split_string(bits, ","); + for (const std::string &s : sp) { + if (s == "stmt") { + active_bits |= BIT_STMT; + } else if (s == "expr") { + active_bits |= BIT_EXPR; + } + } + } else { + active_bits = BIT_GENERIC; + } + } +} +Context::~Context() { + if (std::string file = get_env_variable("HL_COMPILER_TRACE_FILE"); !file.empty()) { + Profiling::write_halide_profiling_trace(file); + } +} + +void write_halide_profiling_trace(const std::string &file) { + ZoneScoped; + std::lock_guard lock(ctx.mutex); // Ensure no threads are born while serializing + + debug(1) << "Emitting trace.json events: " << file << "\n"; + + // Pass 1: Find the absolute global t=0 across all threads + uint64_t global_start_cycles = static_cast(-1); + for (const auto &trace : ctx.traces) { + if (trace.events.empty()) { + continue; + } + + uint64_t first_cycle = trace.start_cycles_64; + uint32_t first_timer = trace.events[0].timer; + uint32_t anchor_lower = static_cast(first_cycle & 0xFFFFFFFF); + + // Did the 32-bit timer roll over between thread creation and the first event? 
+ if (first_timer < anchor_lower) { + first_cycle += 0x100000000ULL; + } + first_cycle = (first_cycle & 0xFFFFFFFF00000000ULL) | first_timer; + + if (first_cycle < global_start_cycles) { + global_start_cycles = first_cycle; + } + } + + std::map demangled_names; + + std::ofstream out(file); + out << "[\n"; + bool first = true; + + for (const auto &trace : ctx.traces) { + uint64_t current_cycles = trace.start_cycles_64; + uint32_t last_timer = static_cast(current_cycles & 0xFFFFFFFF); + + for (const auto &ev : trace.events) { + // Reconstruct the 64-bit timeline for this specific thread + if (ev.timer < last_timer) { + current_cycles += 0x1'0000'0000ULL; + } + current_cycles = (current_cycles & 0xFFFFFFFF00000000ULL) | ev.timer; + last_timer = ev.timer; + + double ts_us = static_cast(current_cycles - global_start_cycles) * 1e-3; + + if (!first) { + out << ",\n"; + } + first = false; + + out << " {"; + if (ev.event_type == 0) { + const char *src_tag = ev.src_tag; + if ( +#if !defined(_WIN32) + std::strncmp(src_tag, "N6Halide", 8) == 0 +#else + std::strncmp(src_tag, "class Halide", 12) == 0 || + std::strncmp(src_tag, "struct Halide", 13) == 0 +#endif + ) { + if (auto it = demangled_names.find(src_tag); it == demangled_names.end()) { + std::string dn = demangle(src_tag); + dn = cleanup_name(dn); + auto result = demangled_names.emplace(src_tag, std::move(dn)); + src_tag = result.first->second.c_str(); + } else { + src_tag = it->second.c_str(); + } + } + // Only add strings to the start event. + if (ev.tag == Event::Visitor) { + const char *node_name = IRNodeType_string((IRNodeType)ev.data); + out << "\"name\": \"" << node_name << "\", "; + out << "\"cat\": \"" << src_tag << "\", "; + } else if (ev.tag == Event::Generic) { + out << "\"name\": \"" << src_tag << "\", "; + out << "\"args\": {"; + out << "\"data\": \"" << ev.data << "\""; + out << "}, "; + } + } + out << "\"ph\": \"" << (ev.event_type == 0 ? 
"B" : "E") << "\", "; + out << "\"pid\": 1, "; + out << "\"tid\": " << trace.tid << ", "; // Same TID forces perfectly nested flame-graph boxes + out << "\"ts\": " << std::fixed << std::setprecision(3) << ts_us; + + out << "}\n"; // Let's put a newline to not kill bad parsers. + } + } + + out << "\n]\n"; +} + +} // namespace Profiling +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/CompilerProfiling.h b/src/CompilerProfiling.h new file mode 100644 index 000000000000..228b98bce3bf --- /dev/null +++ b/src/CompilerProfiling.h @@ -0,0 +1,245 @@ +#ifndef HALIDE_PROFILED_IR_VISITOR_H +#define HALIDE_PROFILED_IR_VISITOR_H + +#ifdef WITH_COMPILER_PROFILING + +#include "IR.h" + +#include +#include +#include +#include + +#ifndef __FUNCTION_NAME__ +#ifdef WIN32 // WINDOWS +#define __FUNCTION_NAME__ __FUNCTION__ +#else //*NIX +#define __FUNCTION_NAME__ __func__ +#endif +#endif + +/** \file + * Defines the base class for things that recursively walk over the IR + */ + +namespace Halide { +namespace Internal { +namespace Profiling { + +constexpr uint8_t BIT_GENERIC = 1 << 0; +constexpr uint8_t BIT_STMT = 1 << 1; +constexpr uint8_t BIT_EXPR = 1 << 2; + +inline uint64_t performance_counter() { + auto now = std::chrono::high_resolution_clock::now().time_since_epoch(); + return static_cast(std::chrono::duration_cast(now).count()); +} + +struct Event { + const char *src_tag; + uint32_t timer; + uint8_t event_type : 1; // 0 for Start, 1 for Stop + + enum Tag : uint8_t { + Generic, + Visitor, + } tag : 1; + + /** Optional for visitors and mutators. 
*/ + uint32_t data : 30; +}; +static_assert(sizeof(Event) == sizeof(void *) + 8); + +struct ThreadTrace { + uint32_t tid; + uint64_t start_cycles_64; // 64-bit anchor to align threads accurately + std::vector events; +}; + +struct Context { + std::mutex mutex; + uint8_t active_bits{0}; + + // std::list guarantees that pointers to elements are never invalidated + std::list traces; + + Context(); + ~Context(); +}; + +inline Context ctx; + +inline ThreadTrace &init_thread_profiler() { + std::lock_guard lock(ctx.mutex); + + // Hash thread ID to get a clean 32-bit integer for the Chrome UI + uint32_t tid = static_cast(std::hash{}(std::this_thread::get_id())); + uint64_t anchor = performance_counter(); + + ctx.traces.push_back({tid, anchor, {}}); + ThreadTrace &tt = ctx.traces.back(); + + // Pre-allocate ~16MB per thread to guarantee zero allocations during hot loops + tt.events.reserve(1'000'000); + + return tt; +} + +inline thread_local ThreadTrace &tls_thread_trace = init_thread_profiler(); +inline thread_local std::vector *tls_profiling_events = &tls_thread_trace.events; + +inline void zone_begin(const char *src_tag, Event::Tag tag, unsigned data = 0) { + Event pe; + pe.timer = static_cast(performance_counter()); + pe.tag = tag; + pe.src_tag = src_tag; + pe.event_type = 0; // Start + pe.data = data; + tls_profiling_events->push_back(pe); +} + +inline void zone_end(const char *src_tag, Event::Tag tag, unsigned data = 0) { + Event pe; + pe.timer = static_cast(performance_counter()); + pe.tag = tag; + pe.src_tag = src_tag; + pe.event_type = 1; // Stop + pe.data = data; + tls_profiling_events->push_back(pe); +} + +inline void generic_zone_begin(const char *src_tag, unsigned data = 0) { + if (ctx.active_bits & BIT_GENERIC) { + zone_begin(src_tag, Event::Tag::Generic, data); + } +} + +inline void generic_zone_end(const char *src_tag, unsigned data = 0) { + if (ctx.active_bits & BIT_GENERIC) { + zone_end(src_tag, Event::Tag::Generic, data); + } +} + +struct 
ZoneScopedVisitor_ { + IRNodeType node_type; + const char *src_tag; + bool active; + + ZoneScopedVisitor_(IRNodeType node_type, const char *src_tag, uint8_t activation_bits) + : node_type(node_type), src_tag(src_tag), active(ctx.active_bits & activation_bits) { + if (active) { + zone_begin(src_tag, Event::Visitor, (unsigned)node_type); + } + } + + ZoneScopedVisitor_(const Expr &e, const char *src_tag) + : node_type(e.defined() ? e->node_type : IRNodeType::IntImm), + src_tag(src_tag), + active((ctx.active_bits & BIT_EXPR) && e.defined()) { + if (active) { + zone_begin(src_tag, Event::Visitor, (unsigned)node_type); + } + } + + ZoneScopedVisitor_(const Stmt &s, const char *src_tag) + : node_type(s.defined() ? s->node_type : IRNodeType::IntImm), + src_tag(src_tag), + active((ctx.active_bits & BIT_STMT) && s.defined()) { + if (active) { + zone_begin(src_tag, Event::Visitor, (unsigned)node_type); + } + } + + ~ZoneScopedVisitor_() { + if (active) { + zone_end(src_tag, Event::Visitor, (unsigned)node_type); + } + } +}; + +struct ZoneScoped_ { + Event::Tag tag; + const char *src_tag; + + ZoneScoped_(const char *src_tag, unsigned data = 0, Event::Tag tag = Event::Tag::Generic) + : tag(tag), + src_tag(src_tag) { + zone_begin(src_tag, tag, data); + } + + ~ZoneScoped_() { + zone_end(src_tag, tag, 0); + } +}; + +#define ZoneScoped \ + Halide::Internal::Profiling::ZoneScoped_ __zone_scoped(__FUNCTION_NAME__) +#define ZoneScopedN(...) \ + Halide::Internal::Profiling::ZoneScoped_ __zone_scoped(__VA_ARGS__) +#define ZoneScopedVisitor(...) 
\ + Halide::Internal::Profiling::ZoneScopedVisitor_ __zone_scoped(__VA_ARGS__) + +#ifdef HALIDE_ENABLE_RTTI +#define HalideVisitorDynamicNameTag typeid(*this).name() +#else +#define HalideVisitorDynamicNameTag __FUNCTION__ +#endif + +template +class Profiled : public Base { +public: + using Base::Base; +#ifdef HALIDE_ENABLE_RTTI + const char *tag = typeid(Base).name(); +#else + const char *tag = "Unknown (no RTTI)"; +#endif + +#define PROFILE_VISIT_STMT_OVERRIDE(T) \ + auto visit(const T *op) -> decltype(this->Base::visit(op)) override { \ + ZoneScopedVisitor_ _prof(IRNodeType::T, tag, BIT_STMT); \ + return Base::visit(op); \ + } + HALIDE_FOR_EACH_IR_STMT(PROFILE_VISIT_STMT_OVERRIDE) + +#undef PROFILE_VISIT_STMT_OVERRIDE + +#define PROFILE_VISIT_EXPR_OVERRIDE(T) \ + auto visit(const T *op) -> decltype(this->Base::visit(op)) override { \ + ZoneScopedVisitor_ _prof(IRNodeType::T, tag, BIT_EXPR); \ + return Base::visit(op); \ + } + HALIDE_FOR_EACH_IR_EXPR(PROFILE_VISIT_EXPR_OVERRIDE) + +#undef PROFILE_VISIT_EXPR_OVERRIDE +}; + +void write_halide_profiling_trace(const std::string &file); + +} // namespace Profiling + +template +using Profiled = Profiling::Profiled; + +} // namespace Internal +} // namespace Halide + +#else + +namespace Profiling { +inline void generic_zone_begin(const char *src_tag, unsigned data = 0) { +} +inline void generic_zone_end(const char *src_tag, unsigned data = 0) { +} +} // namespace Profiling + +template +using Profiled = Base; +#define ZoneScoped +#define ZoneScopedN(...) +#define ZoneScopedVisitor(...) 
+#define HalideVisitorDynamicNameTag + +#endif + +#endif // HALIDE_PROFILED_IR_VISITOR_H diff --git a/src/DebugToFile.cpp b/src/DebugToFile.cpp index 89ea6c36c92d..d4f90a75b4fd 100644 --- a/src/DebugToFile.cpp +++ b/src/DebugToFile.cpp @@ -125,12 +125,12 @@ class AddDummyRealizations : public IRMutator { Stmt debug_to_file(Stmt s, const vector &outputs, const map &env) { // Temporarily wrap the produce nodes for the output functions in // realize nodes so that we know when to write the debug outputs. - s = AddDummyRealizations(outputs).mutate(s); + s = AddDummyRealizations(outputs)(s); - s = DebugToFile(env).mutate(s); + s = DebugToFile(env)(s); // Remove the realize node we wrapped around the output - s = RemoveDummyRealizations(outputs).mutate(s); + s = RemoveDummyRealizations(outputs)(s); return s; } diff --git a/src/Definition.cpp b/src/Definition.cpp index 5e6b00d95867..1cd83ce4d45b 100644 --- a/src/Definition.cpp +++ b/src/Definition.cpp @@ -6,6 +6,7 @@ #include "IR.h" #include "IRMutator.h" #include "IROperator.h" +#include "IRVisitor.h" #include "Var.h" namespace Halide { @@ -26,47 +27,47 @@ struct DefinitionContents { : predicate(const_true()) { } - void accept(IRVisitor *visitor) const { + void accept(IRVisitor &visitor) const { if (predicate.defined()) { - predicate.accept(visitor); + visitor(predicate); } for (const Expr &val : values) { - val.accept(visitor); + visitor(val); } for (const Expr &arg : args) { - arg.accept(visitor); + visitor(arg); } - stage_schedule.accept(visitor); + stage_schedule.accept(&visitor); for (const Specialization &s : specializations) { if (s.condition.defined()) { - s.condition.accept(visitor); + s.condition.accept(&visitor); } - s.definition.accept(visitor); + s.definition.accept(&visitor); } } - void mutate(IRMutator *mutator) { + void mutate(IRMutator &mutator) { if (predicate.defined()) { - predicate = mutator->mutate(predicate); + predicate = mutator(predicate); } for (auto &value : values) { - value = 
mutator->mutate(value); + value = mutator(value); } for (auto &arg : args) { - arg = mutator->mutate(arg); + arg = mutator(arg); } - stage_schedule.mutate(mutator); + stage_schedule.mutate(&mutator); for (Specialization &s : specializations) { if (s.condition.defined()) { - s.condition = mutator->mutate(s.condition); + s.condition = mutator(s.condition); } - s.definition.mutate(mutator); + s.definition.mutate(&mutator); } } }; @@ -146,11 +147,11 @@ bool Definition::is_init() const { } void Definition::accept(IRVisitor *visitor) const { - contents->accept(visitor); + contents->accept(*visitor); } void Definition::mutate(IRMutator *mutator) { - contents->mutate(mutator); + contents->mutate(*mutator); } std::vector &Definition::args() { diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index f7a5b5f49aa8..0c3d0aa1f618 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -172,8 +172,7 @@ class StoreCollector : public IRMutator { Stmt collect_strided_stores(const Stmt &stmt, const std::string &name, int stride, int max_stores, std::vector lets, std::vector &stores) { - StoreCollector collect(name, stride, max_stores, lets, stores); - return collect.mutate(stmt); + return StoreCollector(name, stride, max_stores, lets, stores)(stmt); } class Deinterleaver : public IRGraphMutator { @@ -407,8 +406,7 @@ class Deinterleaver : public IRGraphMutator { Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) { e = substitute_in_all_lets(e); - Deinterleaver d(starting_lane, lane_stride, new_lanes, lets); - e = d.mutate(e); + e = Deinterleaver(starting_lane, lane_stride, new_lanes, lets)(e); e = common_subexpression_elimination(e); return simplify(e); } @@ -802,7 +800,7 @@ class Interleaver : public IRMutator { } // namespace Stmt rewrite_interleavings(const Stmt &s) { - return Interleaver().mutate(s); + return Interleaver()(s); } namespace { diff --git a/src/DerivativeUtils.cpp b/src/DerivativeUtils.cpp index 
23643010855f..54e1c37e0b84 100644 --- a/src/DerivativeUtils.cpp +++ b/src/DerivativeUtils.cpp @@ -116,7 +116,7 @@ Expr add_let_expression(const Expr &expr, const map &let_var_mapping, const vector &let_variables) { // TODO: find a faster way to do this - Expr ret = StripLets().mutate(expr); + Expr ret = StripLets()(expr); bool changed = true; vector injected(let_variables.size(), false); while (changed) { @@ -593,7 +593,7 @@ struct SubstituteCallArgWithPureArg : public IRMutator { } // namespace Expr substitute_call_arg_with_pure_arg(Func f, int variable_id, const Expr &e) { - return simplify(SubstituteCallArgWithPureArg(std::move(f), variable_id).mutate(e)); + return simplify(SubstituteCallArgWithPureArg(std::move(f), variable_id)(e)); } } // namespace Internal diff --git a/src/DistributeShifts.cpp b/src/DistributeShifts.cpp index 4d053e7d8dfe..1169466f947c 100644 --- a/src/DistributeShifts.cpp +++ b/src/DistributeShifts.cpp @@ -196,7 +196,7 @@ class DistributeShiftsAsMuls : public IRMutator { } // namespace Stmt distribute_shifts(const Stmt &s, bool multiply_adds) { - return DistributeShiftsAsMuls(multiply_adds).mutate(s); + return DistributeShiftsAsMuls(multiply_adds)(s); } } // namespace Internal diff --git a/src/EarlyFree.cpp b/src/EarlyFree.cpp index 8b664c2bcf8d..0014915a0700 100644 --- a/src/EarlyFree.cpp +++ b/src/EarlyFree.cpp @@ -1,11 +1,8 @@ -#include #include #include "EarlyFree.h" -#include "ExprUsesVar.h" -#include "IREquality.h" #include "IRMutator.h" -#include "InjectHostDevBufferCopies.h" +#include "IRVisitor.h" namespace Halide { namespace Internal { @@ -159,7 +156,7 @@ class InjectEarlyFrees : public IRMutator { InjectMarker inject_marker; inject_marker.func = alloc->name; inject_marker.last_use = last_use.last_use; - stmt = inject_marker.mutate(stmt); + stmt = inject_marker(stmt); } else { stmt = Allocate::make(alloc->name, alloc->type, alloc->memory_type, alloc->extents, alloc->condition, @@ -174,7 +171,7 @@ class InjectEarlyFrees : public 
IRMutator { Stmt inject_early_frees(const Stmt &s) { InjectEarlyFrees early_frees; - return early_frees.mutate(s); + return early_frees(s); } } // namespace Internal diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index e4afa8f21569..b68e1efa2d83 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -322,11 +322,11 @@ class EliminateBoolVectors : public IRMutator { } // namespace Stmt eliminate_bool_vectors(const Stmt &s) { - return EliminateBoolVectors().mutate(s); + return EliminateBoolVectors()(s); } Expr eliminate_bool_vectors(const Expr &e) { - return EliminateBoolVectors().mutate(e); + return EliminateBoolVectors()(e); } } // namespace Internal diff --git a/src/Expr.cpp b/src/Expr.cpp index d73bd72660fa..7d55fe9350c4 100644 --- a/src/Expr.cpp +++ b/src/Expr.cpp @@ -4,6 +4,20 @@ namespace Halide { namespace Internal { +const char *IRNodeType_string(IRNodeType type) { + switch (type) { +#define PROFILE_NODE_CASE(T) \ + case Halide::Internal::IRNodeType::T: \ + return #T; + + HALIDE_FOR_EACH_IR_NODE(PROFILE_NODE_CASE) +#undef PROFILE_NODE_CASE + + default: + internal_error << "Unknown Node Tag"; + } +} + const IntImm *IntImm::make(Type t, int64_t value) { internal_assert(t.is_int() && t.is_scalar()) << "IntImm must be a scalar Int\n"; diff --git a/src/Expr.h b/src/Expr.h index b9832c104de8..0c691221298d 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -21,63 +21,76 @@ namespace Internal { class IRMutator; class IRVisitor; +// Exprs, in order of strength. Code in IRMatch.h and the +// simplifier relies on this order for canonicalization of +// expressions, so you may need to update those modules if you +// change this list. 
+#define HALIDE_FOR_EACH_IR_EXPR(V) \ + V(IntImm) \ + V(UIntImm) \ + V(FloatImm) \ + V(StringImm) \ + V(Broadcast) \ + V(Cast) \ + V(Reinterpret) \ + V(Variable) \ + V(Add) \ + V(Sub) \ + V(Mod) \ + V(Mul) \ + V(Div) \ + V(Min) \ + V(Max) \ + V(EQ) \ + V(NE) \ + V(LT) \ + V(LE) \ + V(GT) \ + V(GE) \ + V(And) \ + V(Or) \ + V(Not) \ + V(Select) \ + V(Load) \ + V(Ramp) \ + V(Call) \ + V(Let) \ + V(Shuffle) \ + V(VectorReduce) + +/* Stmts */ +#define HALIDE_FOR_EACH_IR_STMT(V) \ + V(LetStmt) \ + V(AssertStmt) \ + V(ProducerConsumer) \ + V(For) \ + V(Acquire) \ + V(Store) \ + V(Provide) \ + V(Allocate) \ + V(Free) \ + V(Realize) \ + V(Block) \ + V(Fork) \ + V(IfThenElse) \ + V(Evaluate) \ + V(Prefetch) \ + V(Atomic) \ + V(HoistedStorage) + +#define HALIDE_FOR_EACH_IR_NODE(V) \ + HALIDE_FOR_EACH_IR_EXPR(V) \ + HALIDE_FOR_EACH_IR_STMT(V) + /** All our IR node types get unique IDs for the purposes of RTTI */ -enum class IRNodeType { - // Exprs, in order of strength. Code in IRMatch.h and the - // simplifier relies on this order for canonicalization of - // expressions, so you may need to update those modules if you - // change this list. - IntImm, - UIntImm, - FloatImm, - StringImm, - Broadcast, - Cast, - Reinterpret, - Variable, - Add, - Sub, - Mod, - Mul, - Div, - Min, - Max, - EQ, - NE, - LT, - LE, - GT, - GE, - And, - Or, - Not, - Select, - Load, - Ramp, - Call, - Let, - Shuffle, - VectorReduce, - // Stmts - LetStmt, - AssertStmt, - ProducerConsumer, - For, - Acquire, - Store, - Provide, - Allocate, - Free, - Realize, - Block, - Fork, - IfThenElse, - Evaluate, - Prefetch, - Atomic, - HoistedStorage +enum class IRNodeType : uint8_t { +#define DECL_ENUM(X) X, + HALIDE_FOR_EACH_IR_NODE(DECL_ENUM) +#undef DECL_ENUM }; +const char *IRNodeType_string(IRNodeType type); + constexpr IRNodeType StrongestExprNodeType = IRNodeType::VectorReduce; /** The abstract base classes for a node in the Halide IR. 
*/ diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 8fdcea73f34b..17d0004bdbf9 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -672,7 +672,7 @@ class ExtractTileOperations : public IRMutator { } // namespace Stmt extract_tile_operations(const Stmt &s) { - return ExtractTileOperations().mutate(s); + return ExtractTileOperations()(s); } } // namespace Internal } // namespace Halide diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index 2873c5eb4ca9..6dc0850e8300 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1,17 +1,14 @@ #include "FindIntrinsics.h" #include "CSE.h" -#include "CodeGen_Internal.h" -#include "ConciseCasts.h" #include "ConstantBounds.h" #include "IRMatch.h" #include "IRMutator.h" +#include "IRVisitor.h" #include "Simplify.h" namespace Halide { namespace Internal { -using namespace Halide::ConciseCasts; - namespace { // This routine provides a guard on the return type of intrisics such that only @@ -1121,6 +1118,7 @@ class FindIntrinsics : public IRMutator { // because each let in a chain has a wider value than the // ones it refers to. class SubstituteInWideningLets : public IRMutator { +protected: using IRMutator::visit; bool widens(const Expr &e) { @@ -1220,7 +1218,7 @@ class SubstituteInWideningLets : public IRMutator { if (should_replace) { size_t start_of_new_lets = frames.size(); - value = extractor.mutate(value); + value = extractor(value); // Mutate any subexpressions the extractor decided to // leave behind, in case they in turn depend on lets // we've decided to substitute in. 
@@ -1266,16 +1264,16 @@ class SubstituteInWideningLets : public IRMutator { } // namespace Stmt find_intrinsics(const Stmt &s) { - Stmt stmt = SubstituteInWideningLets().mutate(s); - stmt = FindIntrinsics().mutate(stmt); + Stmt stmt = SubstituteInWideningLets()(s); + stmt = FindIntrinsics()(stmt); // In case we want to hoist widening ops back out stmt = common_subexpression_elimination(stmt); return stmt; } Expr find_intrinsics(const Expr &e) { - Expr expr = SubstituteInWideningLets().mutate(e); - expr = FindIntrinsics().mutate(expr); + Expr expr = SubstituteInWideningLets()(e); + expr = FindIntrinsics()(expr); expr = common_subexpression_elimination(expr); return expr; } @@ -1651,11 +1649,11 @@ class LowerIntrinsics : public IRMutator { } // namespace Expr lower_intrinsics(const Expr &e) { - return LowerIntrinsics().mutate(e); + return LowerIntrinsics()(e); } Stmt lower_intrinsics(const Stmt &s) { - return LowerIntrinsics().mutate(s); + return LowerIntrinsics()(s); } } // namespace Internal diff --git a/src/FlattenNestedRamps.cpp b/src/FlattenNestedRamps.cpp index a98ed9bdb427..efa373f6970a 100644 --- a/src/FlattenNestedRamps.cpp +++ b/src/FlattenNestedRamps.cpp @@ -148,11 +148,11 @@ class LowerConcatBits : public IRMutator { } // namespace Stmt flatten_nested_ramps(const Stmt &s) { - return LowerConcatBits().mutate(FlattenRamps().mutate(s)); + return LowerConcatBits()(FlattenRamps()(s)); } Expr flatten_nested_ramps(const Expr &e) { - return LowerConcatBits().mutate(FlattenRamps().mutate(e)); + return LowerConcatBits()(FlattenRamps()(e)); } } // namespace Internal diff --git a/src/Func.cpp b/src/Func.cpp index 04b50412b6ac..a081116762df 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -614,7 +614,7 @@ vector substitute_self_reference(const vector &values, const string vector result; result.reserve(values.size()); for (const auto &val : values) { - result.push_back(subs.mutate(val)); + result.push_back(subs(val)); } return result; } diff --git a/src/Function.cpp 
b/src/Function.cpp index 54fe96f785f2..d9484e5aca0d 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -1,7 +1,5 @@ #include #include -#include -#include #include #include "CSE.h" @@ -168,10 +166,10 @@ struct FunctionContents { if (!extern_function_name.empty()) { for (ExternFuncArgument &i : extern_arguments) { if (i.is_expr()) { - i.expr = mutator->mutate(i.expr); + i.expr = (*mutator)(i.expr); } } - extern_proxy_expr = mutator->mutate(extern_proxy_expr); + extern_proxy_expr = (*mutator)(extern_proxy_expr); } } }; @@ -837,14 +835,14 @@ void Function::define_update(const vector &_args, vector values, con // memory leaks. We need to break these cycles. WeakenFunctionPtrs weakener(contents.get()); for (auto &arg : args) { - arg = weakener.mutate(arg); + arg = weakener(arg); } for (auto &value : values) { - value = weakener.mutate(value); + value = weakener(value); } if (check.reduction_domain.defined()) { check.reduction_domain.set_predicate( - weakener.mutate(check.reduction_domain.predicate())); + weakener(check.reduction_domain.predicate())); } Definition r(args, values, check.reduction_domain, false); diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index 88f9a542550f..a8d447bd391c 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -31,6 +31,7 @@ using std::vector; namespace { class ExtractBlockSize : public IRVisitor { +protected: Expr block_extent[3], block_count[3]; string block_var_name[3]; @@ -123,6 +124,7 @@ class ExtractBlockSize : public IRVisitor { }; class NormalizeDimensionality : public IRMutator { +protected: using IRMutator::visit; const ExtractBlockSize &block_size; @@ -184,6 +186,7 @@ class NormalizeDimensionality : public IRMutator { }; class ReplaceForWithIf : public IRMutator { +protected: using IRMutator::visit; const ExtractBlockSize &block_size; @@ -223,6 +226,7 @@ class ReplaceForWithIf : public IRMutator { }; class ExtractSharedAndHeapAllocations : public IRMutator { +protected: using 
IRMutator::visit; struct IntInterval { @@ -289,7 +293,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { public: vector allocations; -private: +protected: map shared; bool in_threads = false; @@ -907,7 +911,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { : alloc_name(alloc_name), cluster_name(cluster_name), offset(offset) { } } rewriter{alloc.name, name, offset}; - s = rewriter.mutate(s); + s = rewriter(s); } // Define the group offset in terms of the previous group in the cluster @@ -1041,6 +1045,7 @@ class ExtractSharedAndHeapAllocations : public IRMutator { // block. Should only be run after shared allocations have already // been extracted. class ExtractRegisterAllocations : public IRMutator { +protected: using IRMutator::visit; struct RegisterAllocation { @@ -1206,6 +1211,7 @@ class ExtractRegisterAllocations : public IRMutator { }; class InjectThreadBarriers : public IRMutator { +protected: bool in_threads = false, injected_barrier; using IRMutator::visit; @@ -1360,6 +1366,7 @@ class InjectThreadBarriers : public IRMutator { }; class FuseGPUThreadLoopsSingleKernel : public IRMutator { +protected: using IRMutator::visit; const ExtractBlockSize &block_size; ExtractSharedAndHeapAllocations &block_allocations; @@ -1373,7 +1380,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { << body << "\n\n"; NormalizeDimensionality n(block_size, op->device_api); - body = n.mutate(body); + body = n(body); debug(3) << "Normalized dimensionality:\n" << body << "\n\n"; @@ -1382,7 +1389,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { ExtractRegisterAllocations register_allocs; ForType innermost_loop_type = ForType::GPUThread; if (block_size.threads_dimensions()) { - body = register_allocs.mutate(body); + body = register_allocs(body); if (register_allocs.has_lane_loop) { innermost_loop_type = ForType::GPULane; } @@ -1394,14 +1401,14 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { if 
(register_allocs.has_thread_loop) { // If there's no loop over threads, everything is already synchronous. InjectThreadBarriers i{block_allocations, register_allocs}; - body = i.mutate(body); + body = i(body); } debug(3) << "Injected synchronization:\n" << body << "\n\n"; ReplaceForWithIf f(block_size); - body = f.mutate(body); + body = f(body); debug(3) << "Replaced for with if:\n" << body << "\n\n"; @@ -1448,6 +1455,7 @@ class FuseGPUThreadLoopsSingleKernel : public IRMutator { }; class FuseGPUThreadLoops : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const For *op) override { @@ -1463,17 +1471,17 @@ class FuseGPUThreadLoops : public IRMutator { // Do the analysis of thread block size and shared memory // usage. ExtractBlockSize block_size; - Stmt loop = Stmt(op); - loop.accept(&block_size); + block_size(op); + Stmt loop(op); ExtractSharedAndHeapAllocations block_allocations(op->device_api); - loop = block_allocations.mutate(loop); + loop = block_allocations(loop); debug(3) << "Pulled out shared allocations:\n" << loop << "\n\n"; // Mutate the inside of the kernel - loop = FuseGPUThreadLoopsSingleKernel(block_size, block_allocations).mutate(loop); + loop = FuseGPUThreadLoopsSingleKernel(block_size, block_allocations)(loop); loop = block_allocations.rewrap_kernel_launch(loop, block_size, op->device_api); @@ -1485,6 +1493,7 @@ class FuseGPUThreadLoops : public IRMutator { }; class ZeroGPULoopMins : public IRMutator { +protected: bool in_non_glsl_gpu = false; using IRMutator::visit; @@ -1516,13 +1525,14 @@ class ZeroGPULoopMins : public IRMutator { // Also used by InjectImageIntrinsics Stmt zero_gpu_loop_mins(const Stmt &s) { - return ZeroGPULoopMins().mutate(s); + return ZeroGPULoopMins()(s); } namespace { // Find the inner most GPU block of a statement. 
class FindInnermostGPUBlock : public IRVisitor { +protected: using IRVisitor::visit; void visit(const For *op) override { @@ -1540,6 +1550,7 @@ class FindInnermostGPUBlock : public IRVisitor { // Given a condition and a loop, add the condition // to the loop body. class AddConditionToALoop : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const For *op) override { @@ -1562,6 +1573,7 @@ class AddConditionToALoop : public IRMutator { // Push if statements between GPU blocks through all GPU blocks. // Throw error if the if statement has an else clause. class NormalizeIfStatements : public IRMutator { +protected: using IRMutator::visit; bool inside_gpu_blocks = false; @@ -1579,10 +1591,10 @@ class NormalizeIfStatements : public IRMutator { return IRMutator::visit(op); } FindInnermostGPUBlock find; - op->accept(&find); + find(op); if (find.found_gpu_block != nullptr) { internal_assert(!op->else_case.defined()) << "Found an if statement with else case between two GPU blocks.\n"; - return AddConditionToALoop(op->condition, find.found_gpu_block).mutate(op->then_case); + return AddConditionToALoop(op->condition, find.found_gpu_block)(op->then_case); } return IRMutator::visit(op); } @@ -1591,12 +1603,13 @@ class NormalizeIfStatements : public IRMutator { } // namespace Stmt fuse_gpu_thread_loops(Stmt s) { + ZoneScoped; // NormalizeIfStatements pushes the predicates between GPU blocks // into the innermost GPU block. FuseGPUThreadLoops would then // merge the predicate into the merged GPU thread. 
- s = NormalizeIfStatements().mutate(s); - s = FuseGPUThreadLoops().mutate(s); - s = ZeroGPULoopMins().mutate(s); + s = NormalizeIfStatements()(s); + s = FuseGPUThreadLoops()(s); + s = ZeroGPULoopMins()(s); return s; } diff --git a/src/FuzzFloatStores.cpp b/src/FuzzFloatStores.cpp index 35f311224cd1..511141889680 100644 --- a/src/FuzzFloatStores.cpp +++ b/src/FuzzFloatStores.cpp @@ -27,7 +27,7 @@ class FuzzFloatStores : public IRMutator { } // namespace Stmt fuzz_float_stores(const Stmt &s) { - return FuzzFloatStores().mutate(s); + return FuzzFloatStores()(s); } } // namespace Internal diff --git a/src/HexagonOffload.cpp b/src/HexagonOffload.cpp index ddab37ecc9b3..5cafec409579 100644 --- a/src/HexagonOffload.cpp +++ b/src/HexagonOffload.cpp @@ -693,7 +693,7 @@ class ReplaceParams : public IRMutator { }; Stmt replace_params(const Stmt &s, const std::map &replacements) { - return ReplaceParams(replacements).mutate(s); + return ReplaceParams(replacements)(s); } class InjectHexagonRpc : public IRMutator { diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index 5129c5aaa5c8..aeb271e63733 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -331,7 +331,7 @@ Expr apply_patterns(Expr x, const vector &patterns, const Target &targe } // Mutate the operands with the given mutator. for (Expr &op : matches) { - op = op_mutator->mutate(op); + op = (*op_mutator)(op); } x = replace_pattern(x, matches, p); @@ -2286,8 +2286,8 @@ Stmt optimize_hexagon_shuffles(const Stmt &s, int lut_alignment) { Stmt scatter_gather_generator(Stmt s) { // Generate vscatter-vgather instruction if target >= v65 s = substitute_in_all_lets(s); - s = ScatterGatherGenerator().mutate(s); - s = SyncronizationBarriers().mutate(s); + s = ScatterGatherGenerator()(s); + s = SyncronizationBarriers()(s); s = common_subexpression_elimination(s); return s; } @@ -2311,24 +2311,24 @@ Stmt optimize_hexagon_instructions(Stmt s, const Target &t) { // Pattern match VectorReduce IR node. 
Handle vector reduce instructions // before OptimizePatterns to prevent being mutated by patterns like // (v0 + v1 * c) -> add_mpy - s = VectorReducePatterns().mutate(s); + s = VectorReducePatterns()(s); debug(4) << "Hexagon: Lowering after VectorReducePatterns\n" << s << "\n"; // Peephole optimize for Hexagon instructions. These can generate // interleaves and deinterleaves alongside the HVX intrinsics. - s = OptimizePatterns(t).mutate(s); + s = OptimizePatterns(t)(s); debug(4) << "Hexagon: Lowering after OptimizePatterns\n" << s << "\n"; // Try to eliminate any redundant interleave/deinterleave pairs. - s = EliminateInterleaves(t, t.natural_vector_size(Int(8))).mutate(s); + s = EliminateInterleaves(t, t.natural_vector_size(Int(8)))(s); debug(4) << "Hexagon: Lowering after EliminateInterleaves\n" << s << "\n"; // There may be interleaves left over that we can fuse with other // operations. - s = FuseInterleaves().mutate(s); + s = FuseInterleaves()(s); debug(4) << "Hexagon: Lowering after FuseInterleaves\n" << s << "\n"; return s; diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index ffbb9406ad1e..7649fa9c2152 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -406,7 +406,7 @@ class WithLanes : public IRMutator { } // namespace Expr with_lanes(const Expr &x, int lanes) { - return WithLanes(lanes).mutate(x); + return WithLanes(lanes)(x); } } // namespace Internal diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 9eecd0579840..c5ae276bda62 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -45,63 +45,59 @@ Expr IRMutator::visit(const Reinterpret *op) { return Reinterpret::make(op->type, std::move(value)); } -namespace { -template -Expr mutate_binary_operator(IRMutator *mutator, const T *op) { - Expr a = mutator->mutate(op->a); - Expr b = mutator->mutate(op->b); - if (a.same_as(op->a) && - b.same_as(op->b)) { - return op; - } - return T::make(std::move(a), std::move(b)); -} -} // namespace +#define mutate_binary_operator \ + Expr a = mutate(op->a); \ + 
Expr b = mutate(op->b); \ + if (a.same_as(op->a) && \ + b.same_as(op->b)) { \ + return op; \ + } \ + return std::decay_t::make(std::move(a), std::move(b)) Expr IRMutator::visit(const Add *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Sub *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Mul *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Div *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Mod *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Min *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Max *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const EQ *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const NE *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const LT *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const LE *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const GT *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const GE *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const And *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Or *op) { - return mutate_binary_operator(this, op); + mutate_binary_operator; } Expr IRMutator::visit(const Not *op) { diff --git a/src/IRMutator.h b/src/IRMutator.h index 148caf56a54b..ab502f3edade 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -9,6 +9,7 @@ #include 
#include +#include "CompilerProfiling.h" #include "IR.h" namespace Halide { @@ -34,73 +35,51 @@ class IRMutator { * these in your subclass to mutate sub-expressions and * sub-statements. */ - virtual Expr mutate(const Expr &expr); - virtual Stmt mutate(const Stmt &stmt); + inline Expr operator()(const Expr &expr) { + ZoneScopedN(HalideVisitorDynamicNameTag); + return mutate(expr); + } + + inline Stmt operator()(const Stmt &stmt) { + ZoneScopedN(HalideVisitorDynamicNameTag); + return mutate(stmt); + } + + // Like mutate_with_changes, but discard the changes flag. + std::vector operator()(const std::vector &exprs) { + ZoneScopedN(HalideVisitorDynamicNameTag); + return mutate_with_changes(exprs).first; + } // Mutate all the Exprs and return the new list in ret, along with // a flag that is true iff at least one item in the list changed. std::pair, bool> mutate_with_changes(const std::vector &); +protected: + virtual Expr mutate(const Expr &expr); + virtual Stmt mutate(const Stmt &stmt); // Like mutate_with_changes, but discard the changes flag. 
std::vector mutate(const std::vector &exprs) { return mutate_with_changes(exprs).first; } -protected: // ExprNode<> and StmtNode<> are allowed to call visit (to implement mutate_expr/mutate_stmt()) template friend struct ExprNode; template friend struct StmtNode; - - virtual Expr visit(const IntImm *); - virtual Expr visit(const UIntImm *); - virtual Expr visit(const FloatImm *); - virtual Expr visit(const StringImm *); - virtual Expr visit(const Cast *); - virtual Expr visit(const Reinterpret *); - virtual Expr visit(const Add *); - virtual Expr visit(const Sub *); - virtual Expr visit(const Mul *); - virtual Expr visit(const Div *); - virtual Expr visit(const Mod *); - virtual Expr visit(const Min *); - virtual Expr visit(const Max *); - virtual Expr visit(const EQ *); - virtual Expr visit(const NE *); - virtual Expr visit(const LT *); - virtual Expr visit(const LE *); - virtual Expr visit(const GT *); - virtual Expr visit(const GE *); - virtual Expr visit(const And *); - virtual Expr visit(const Or *); - virtual Expr visit(const Not *); - virtual Expr visit(const Select *); - virtual Expr visit(const Load *); - virtual Expr visit(const Ramp *); - virtual Expr visit(const Broadcast *); - virtual Expr visit(const Let *); - virtual Stmt visit(const LetStmt *); - virtual Stmt visit(const AssertStmt *); - virtual Stmt visit(const ProducerConsumer *); - virtual Stmt visit(const Store *); - virtual Stmt visit(const Provide *); - virtual Stmt visit(const Allocate *); - virtual Stmt visit(const Free *); - virtual Stmt visit(const Realize *); - virtual Stmt visit(const Block *); - virtual Stmt visit(const Fork *); - virtual Stmt visit(const IfThenElse *); - virtual Stmt visit(const Evaluate *); - virtual Expr visit(const Call *); - virtual Expr visit(const Variable *); - virtual Stmt visit(const For *); - virtual Stmt visit(const Acquire *); - virtual Expr visit(const Shuffle *); - virtual Stmt visit(const Prefetch *); - virtual Stmt visit(const HoistedStorage *); - 
virtual Stmt visit(const Atomic *); - virtual Expr visit(const VectorReduce *); + template + friend std::pair mutate_region(Mutator *mutator, const Region &bounds, Args &&...args); + +#define HALIDE_DECL_VISIT_EXPR(T) \ + virtual Expr visit(const T *op); + HALIDE_FOR_EACH_IR_EXPR(HALIDE_DECL_VISIT_EXPR) +#undef HALIDE_DECL_VISIT_EXPR + +#define HALIDE_DECL_VISIT_STMT(T) \ + virtual Stmt visit(const T *op); + HALIDE_FOR_EACH_IR_STMT(HALIDE_DECL_VISIT_STMT) +#undef HALIDE_DECL_VISIT_STMT }; /** A mutator that caches and reapplies previously done mutations so @@ -111,10 +90,19 @@ class IRGraphMutator : public IRMutator { std::map expr_replacements; std::map stmt_replacements; -public: using IRMutator::mutate; Stmt mutate(const Stmt &s) override; Expr mutate(const Expr &e) override; + +public: + inline Expr operator()(const Expr &expr) { + ZoneScopedN(HalideVisitorDynamicNameTag); + return mutate(expr); + } + inline Stmt operator()(const Stmt &stmt) { + ZoneScopedN(HalideVisitorDynamicNameTag); + return mutate(stmt); + } }; /** A lambda-based IR mutator that accepts multiple lambdas for different @@ -131,6 +119,9 @@ struct LambdaMutator final : IRMutator { return IRMutator::visit(op); } +public: + using IRMutator::mutate; + private: LambdaOverloads handlers; @@ -149,150 +140,20 @@ struct LambdaMutator final : IRMutator { } protected: - Expr visit(const IntImm *op) override { - return this->visit_impl(op); - } - Expr visit(const UIntImm *op) override { - return this->visit_impl(op); - } - Expr visit(const FloatImm *op) override { - return this->visit_impl(op); - } - Expr visit(const StringImm *op) override { - return this->visit_impl(op); - } - Expr visit(const Cast *op) override { - return this->visit_impl(op); - } - Expr visit(const Reinterpret *op) override { - return this->visit_impl(op); - } - Expr visit(const Add *op) override { - return this->visit_impl(op); - } - Expr visit(const Sub *op) override { - return this->visit_impl(op); - } - Expr visit(const Mul *op) 
override { - return this->visit_impl(op); - } - Expr visit(const Div *op) override { - return this->visit_impl(op); - } - Expr visit(const Mod *op) override { - return this->visit_impl(op); - } - Expr visit(const Min *op) override { - return this->visit_impl(op); - } - Expr visit(const Max *op) override { - return this->visit_impl(op); - } - Expr visit(const EQ *op) override { - return this->visit_impl(op); - } - Expr visit(const NE *op) override { - return this->visit_impl(op); - } - Expr visit(const LT *op) override { - return this->visit_impl(op); - } - Expr visit(const LE *op) override { - return this->visit_impl(op); - } - Expr visit(const GT *op) override { - return this->visit_impl(op); - } - Expr visit(const GE *op) override { - return this->visit_impl(op); - } - Expr visit(const And *op) override { - return this->visit_impl(op); - } - Expr visit(const Or *op) override { - return this->visit_impl(op); - } - Expr visit(const Not *op) override { - return this->visit_impl(op); - } - Expr visit(const Select *op) override { - return this->visit_impl(op); - } - Expr visit(const Load *op) override { - return this->visit_impl(op); - } - Expr visit(const Ramp *op) override { - return this->visit_impl(op); - } - Expr visit(const Broadcast *op) override { - return this->visit_impl(op); - } - Expr visit(const Let *op) override { - return this->visit_impl(op); - } - Stmt visit(const LetStmt *op) override { - return this->visit_impl(op); - } - Stmt visit(const AssertStmt *op) override { - return this->visit_impl(op); - } - Stmt visit(const ProducerConsumer *op) override { - return this->visit_impl(op); - } - Stmt visit(const Store *op) override { - return this->visit_impl(op); - } - Stmt visit(const Provide *op) override { - return this->visit_impl(op); - } - Stmt visit(const Allocate *op) override { - return this->visit_impl(op); - } - Stmt visit(const Free *op) override { - return this->visit_impl(op); - } - Stmt visit(const Realize *op) override { - return 
this->visit_impl(op); - } - Stmt visit(const Block *op) override { - return this->visit_impl(op); - } - Stmt visit(const Fork *op) override { - return this->visit_impl(op); - } - Stmt visit(const IfThenElse *op) override { - return this->visit_impl(op); - } - Stmt visit(const Evaluate *op) override { - return this->visit_impl(op); - } - Expr visit(const Call *op) override { - return this->visit_impl(op); - } - Expr visit(const Variable *op) override { - return this->visit_impl(op); - } - Stmt visit(const For *op) override { - return this->visit_impl(op); - } - Stmt visit(const Acquire *op) override { - return this->visit_impl(op); - } - Expr visit(const Shuffle *op) override { - return this->visit_impl(op); - } - Stmt visit(const Prefetch *op) override { - return this->visit_impl(op); - } - Stmt visit(const HoistedStorage *op) override { - return this->visit_impl(op); - } - Stmt visit(const Atomic *op) override { - return this->visit_impl(op); - } - Expr visit(const VectorReduce *op) override { - return this->visit_impl(op); - } +#define HALIDE_CALL_VISIT_EXPR_IMPL(T) \ + Expr visit(const T *op) override { \ + ZoneScopedVisitor(IRNodeType::T, "LambdaMutator", Profiling::BIT_EXPR); \ + return this->visit_impl(op); \ + } + HALIDE_FOR_EACH_IR_EXPR(HALIDE_CALL_VISIT_EXPR_IMPL) +#undef HALIDE_CALL_VISIT_EXPR_IMPL +#define HALIDE_CALL_VISIT_STMT_IMPL(T) \ + Stmt visit(const T *op) override { \ + ZoneScopedVisitor(IRNodeType::T, "LambdaMutator", Profiling::BIT_STMT); \ + return this->visit_impl(op); \ + } + HALIDE_FOR_EACH_IR_STMT(HALIDE_CALL_VISIT_STMT_IMPL) +#undef HALIDE_CALL_VISIT_STMT_IMPL }; /** A lambda-based IR mutator that accepts multiple lambdas for overloading @@ -312,6 +173,7 @@ struct LambdaMutatorGeneric final : IRMutator { return IRMutator::mutate(op); } +public: Expr mutate(const Expr &e) override { if constexpr (std::is_invocable_v) { return handlers(this, e); @@ -338,7 +200,7 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) { using Generic = 
LambdaMutatorGeneric; if constexpr (std::is_invocable_v || std::is_invocable_v) { - return LambdaMutatorGeneric{std::forward(lambdas)...}.mutate(ir); + return LambdaMutatorGeneric{std::forward(lambdas)...}(ir); } else { LambdaMutator mutator{std::forward(lambdas)...}; // Each lambda must take two args: (auto *self, op). @@ -350,7 +212,7 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) { ...); static_assert(all_take_two_args, "All mutate_with lambdas must take two arguments: (auto *self, const T *op)"); - return mutator.mutate(ir); + return mutator(ir); } } diff --git a/src/IROperator.cpp b/src/IROperator.cpp index c729539daa29..bd5319944f26 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1075,18 +1075,22 @@ T remove_intrinsics(const T &e, const std::initializer_list & } // namespace Expr remove_likelies(const Expr &e) { + ZoneScoped; return remove_intrinsics(e, {Call::likely, Call::likely_if_innermost}); } Stmt remove_likelies(const Stmt &s) { + ZoneScoped; return remove_intrinsics(s, {Call::likely, Call::likely_if_innermost}); } Expr remove_promises(const Expr &e) { + ZoneScoped; return remove_intrinsics(e, {Call::promise_clamped, Call::unsafe_promise_clamped}); } Stmt remove_promises(const Stmt &s) { + ZoneScoped; return remove_intrinsics(s, {Call::promise_clamped, Call::unsafe_promise_clamped}); } diff --git a/src/IRPrinter.h b/src/IRPrinter.h index de812bdb0c47..12ffacb4fa88 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -65,7 +65,7 @@ class Closure; struct Interval; struct ConstantInterval; struct ModulusRemainder; -enum class IRNodeType; +enum class IRNodeType : uint8_t; /** Emit a halide node type on an output stream (such as std::cout) in * human-readable form */ diff --git a/src/IRVisitor.h b/src/IRVisitor.h index a14e71558e89..e8e67dd57b7b 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -3,6 +3,7 @@ #include +#include "CompilerProfiling.h" #include "IR.h" /** \file @@ -21,6 +22,22 @@ class IRVisitor { IRVisitor() = default; 
virtual ~IRVisitor() = default; + inline void operator()(const Stmt &s) { + ZoneScopedN(HalideVisitorDynamicNameTag); + s.accept(this); + } + + inline void operator()(const Expr &e) { + ZoneScopedN(HalideVisitorDynamicNameTag); + e.accept(this); + } + + template + inline void operator()(const T *op) { + ZoneScopedN(HalideVisitorDynamicNameTag); + visit(op); + } + protected: // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept()) template @@ -29,54 +46,10 @@ class IRVisitor { template friend struct StmtNode; - virtual void visit(const IntImm *); - virtual void visit(const UIntImm *); - virtual void visit(const FloatImm *); - virtual void visit(const StringImm *); - virtual void visit(const Cast *); - virtual void visit(const Reinterpret *); - virtual void visit(const Add *); - virtual void visit(const Sub *); - virtual void visit(const Mul *); - virtual void visit(const Div *); - virtual void visit(const Mod *); - virtual void visit(const Min *); - virtual void visit(const Max *); - virtual void visit(const EQ *); - virtual void visit(const NE *); - virtual void visit(const LT *); - virtual void visit(const LE *); - virtual void visit(const GT *); - virtual void visit(const GE *); - virtual void visit(const And *); - virtual void visit(const Or *); - virtual void visit(const Not *); - virtual void visit(const Select *); - virtual void visit(const Load *); - virtual void visit(const Ramp *); - virtual void visit(const Broadcast *); - virtual void visit(const Let *); - virtual void visit(const LetStmt *); - virtual void visit(const AssertStmt *); - virtual void visit(const ProducerConsumer *); - virtual void visit(const Store *); - virtual void visit(const Provide *); - virtual void visit(const Allocate *); - virtual void visit(const Free *); - virtual void visit(const Realize *); - virtual void visit(const Block *); - virtual void visit(const Fork *); - virtual void visit(const IfThenElse *); - virtual void visit(const Evaluate *); - virtual 
void visit(const Call *); - virtual void visit(const Variable *); - virtual void visit(const For *); - virtual void visit(const Acquire *); - virtual void visit(const Shuffle *); - virtual void visit(const Prefetch *); - virtual void visit(const HoistedStorage *); - virtual void visit(const Atomic *); - virtual void visit(const VectorReduce *); +#define HALIDE_DECL_VISIT(T) \ + virtual void visit(const T *op); + HALIDE_FOR_EACH_IR_NODE(HALIDE_DECL_VISIT) +#undef HALIDE_DECL_VISIT }; /** A lambda-based IR visitor that accepts multiple lambdas for different @@ -111,150 +84,20 @@ struct LambdaVisitor final : IRVisitor { } protected: - void visit(const IntImm *op) override { - this->visit_impl(op); - } - void visit(const UIntImm *op) override { - this->visit_impl(op); - } - void visit(const FloatImm *op) override { - this->visit_impl(op); - } - void visit(const StringImm *op) override { - this->visit_impl(op); - } - void visit(const Cast *op) override { - this->visit_impl(op); - } - void visit(const Reinterpret *op) override { - this->visit_impl(op); - } - void visit(const Add *op) override { - this->visit_impl(op); - } - void visit(const Sub *op) override { - this->visit_impl(op); - } - void visit(const Mul *op) override { - this->visit_impl(op); - } - void visit(const Div *op) override { - this->visit_impl(op); - } - void visit(const Mod *op) override { - this->visit_impl(op); - } - void visit(const Min *op) override { - this->visit_impl(op); - } - void visit(const Max *op) override { - this->visit_impl(op); - } - void visit(const EQ *op) override { - this->visit_impl(op); - } - void visit(const NE *op) override { - this->visit_impl(op); - } - void visit(const LT *op) override { - this->visit_impl(op); - } - void visit(const LE *op) override { - this->visit_impl(op); - } - void visit(const GT *op) override { - this->visit_impl(op); - } - void visit(const GE *op) override { - this->visit_impl(op); - } - void visit(const And *op) override { - this->visit_impl(op); - } 
- void visit(const Or *op) override { - this->visit_impl(op); - } - void visit(const Not *op) override { - this->visit_impl(op); - } - void visit(const Select *op) override { - this->visit_impl(op); - } - void visit(const Load *op) override { - this->visit_impl(op); - } - void visit(const Ramp *op) override { - this->visit_impl(op); - } - void visit(const Broadcast *op) override { - this->visit_impl(op); - } - void visit(const Let *op) override { - this->visit_impl(op); - } - void visit(const LetStmt *op) override { - this->visit_impl(op); - } - void visit(const AssertStmt *op) override { - this->visit_impl(op); - } - void visit(const ProducerConsumer *op) override { - this->visit_impl(op); - } - void visit(const Store *op) override { - this->visit_impl(op); - } - void visit(const Provide *op) override { - this->visit_impl(op); - } - void visit(const Allocate *op) override { - this->visit_impl(op); - } - void visit(const Free *op) override { - this->visit_impl(op); - } - void visit(const Realize *op) override { - this->visit_impl(op); - } - void visit(const Block *op) override { - this->visit_impl(op); - } - void visit(const Fork *op) override { - this->visit_impl(op); - } - void visit(const IfThenElse *op) override { - this->visit_impl(op); - } - void visit(const Evaluate *op) override { - this->visit_impl(op); - } - void visit(const Call *op) override { - this->visit_impl(op); - } - void visit(const Variable *op) override { - this->visit_impl(op); - } - void visit(const For *op) override { - this->visit_impl(op); - } - void visit(const Acquire *op) override { - this->visit_impl(op); - } - void visit(const Shuffle *op) override { - this->visit_impl(op); - } - void visit(const Prefetch *op) override { - this->visit_impl(op); - } - void visit(const HoistedStorage *op) override { - this->visit_impl(op); - } - void visit(const Atomic *op) override { - this->visit_impl(op); - } - void visit(const VectorReduce *op) override { - this->visit_impl(op); - } +#define 
HALIDE_CALL_VISIT_EXPR_IMPL(T) \ + void visit(const T *op) override { \ + ZoneScopedVisitor(IRNodeType::T, "LambdaVisitor", Profiling::BIT_EXPR); \ + this->visit_impl(op); \ + } + HALIDE_FOR_EACH_IR_EXPR(HALIDE_CALL_VISIT_EXPR_IMPL) +#undef HALIDE_CALL_VISIT_EXPR_IMPL +#define HALIDE_CALL_VISIT_STMT_IMPL(T) \ + void visit(const T *op) override { \ + ZoneScopedVisitor(IRNodeType::T, "LambdaVisitor", Profiling::BIT_STMT); \ + this->visit_impl(op); \ + } + HALIDE_FOR_EACH_IR_STMT(HALIDE_CALL_VISIT_STMT_IMPL) +#undef HALIDE_CALL_VISIT_STMT_IMPL }; template @@ -281,6 +124,22 @@ void visit_with(const IRHandle &ir, Lambdas &&...lambdas) { * without visiting the same node twice. This is for passes that are * capable of interpreting the IR as a DAG instead of a tree. */ class IRGraphVisitor : public IRVisitor { +public: + inline void operator()(const Expr &e) { + ZoneScopedN(HalideVisitorDynamicNameTag); + include(e); + } + inline void operator()(const Stmt &s) { + ZoneScopedN(HalideVisitorDynamicNameTag); + include(s); + } + +private: + /** The nodes visited so far. Only includes nodes with a ref count greater + * than one, because we know that nodes with a ref count of 1 will only be + * visited once if their parents are only visited once. */ + std::set visited; + protected: /** By default these methods add the node to the visited set, and * return whether or not it was already there. If it wasn't there, @@ -291,64 +150,13 @@ class IRGraphVisitor : public IRVisitor { virtual void include(const Stmt &); // @} -private: - /** The nodes visited so far. Only includes nodes with a ref count greater - * than one, because we know that nodes with a ref count of 1 will only be - * visited once if their parents are only visited once. */ - std::set visited; - -protected: /** These methods should call 'include' on the children to only * visit them if they haven't been visited already. 
*/ // @{ - void visit(const IntImm *) override; - void visit(const UIntImm *) override; - void visit(const FloatImm *) override; - void visit(const StringImm *) override; - void visit(const Cast *) override; - void visit(const Reinterpret *) override; - void visit(const Add *) override; - void visit(const Sub *) override; - void visit(const Mul *) override; - void visit(const Div *) override; - void visit(const Mod *) override; - void visit(const Min *) override; - void visit(const Max *) override; - void visit(const EQ *) override; - void visit(const NE *) override; - void visit(const LT *) override; - void visit(const LE *) override; - void visit(const GT *) override; - void visit(const GE *) override; - void visit(const And *) override; - void visit(const Or *) override; - void visit(const Not *) override; - void visit(const Select *) override; - void visit(const Load *) override; - void visit(const Ramp *) override; - void visit(const Broadcast *) override; - void visit(const Let *) override; - void visit(const LetStmt *) override; - void visit(const AssertStmt *) override; - void visit(const ProducerConsumer *) override; - void visit(const Store *) override; - void visit(const Provide *) override; - void visit(const Allocate *) override; - void visit(const Free *) override; - void visit(const Realize *) override; - void visit(const Block *) override; - void visit(const Fork *) override; - void visit(const IfThenElse *) override; - void visit(const Evaluate *) override; - void visit(const Call *) override; - void visit(const Variable *) override; - void visit(const For *) override; - void visit(const Acquire *) override; - void visit(const Shuffle *) override; - void visit(const Prefetch *) override; - void visit(const HoistedStorage *) override; - void visit(const Atomic *) override; - void visit(const VectorReduce *) override; +#define HALIDE_VISIT_OVERRIDE(T) \ + void visit(const T *) override; + HALIDE_FOR_EACH_IR_NODE(HALIDE_VISIT_OVERRIDE) +#undef 
HALIDE_VISIT_OVERRIDE // @} }; @@ -360,94 +168,28 @@ class IRGraphVisitor : public IRVisitor { template class VariadicVisitor { private: +#ifdef WITH_COMPILER_PROFILING +#ifdef HALIDE_ENABLE_RTTI + const char *name = typeid(T).name(); +#else + const char *name = "VariadicVisitor"; +#endif +#endif + template ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) { if (node == nullptr) { return ExprRet{}; } + ZoneScopedVisitor(node->node_type, name, Profiling::BIT_EXPR); switch (node->node_type) { - case IRNodeType::IntImm: - return ((T *)this)->visit((const IntImm *)node, std::forward(args)...); - case IRNodeType::UIntImm: - return ((T *)this)->visit((const UIntImm *)node, std::forward(args)...); - case IRNodeType::FloatImm: - return ((T *)this)->visit((const FloatImm *)node, std::forward(args)...); - case IRNodeType::StringImm: - return ((T *)this)->visit((const StringImm *)node, std::forward(args)...); - case IRNodeType::Broadcast: - return ((T *)this)->visit((const Broadcast *)node, std::forward(args)...); - case IRNodeType::Cast: - return ((T *)this)->visit((const Cast *)node, std::forward(args)...); - case IRNodeType::Reinterpret: - return ((T *)this)->visit((const Reinterpret *)node, std::forward(args)...); - case IRNodeType::Variable: - return ((T *)this)->visit((const Variable *)node, std::forward(args)...); - case IRNodeType::Add: - return ((T *)this)->visit((const Add *)node, std::forward(args)...); - case IRNodeType::Sub: - return ((T *)this)->visit((const Sub *)node, std::forward(args)...); - case IRNodeType::Mod: - return ((T *)this)->visit((const Mod *)node, std::forward(args)...); - case IRNodeType::Mul: - return ((T *)this)->visit((const Mul *)node, std::forward(args)...); - case IRNodeType::Div: - return ((T *)this)->visit((const Div *)node, std::forward(args)...); - case IRNodeType::Min: - return ((T *)this)->visit((const Min *)node, std::forward(args)...); - case IRNodeType::Max: - return ((T *)this)->visit((const Max *)node, 
std::forward(args)...); - case IRNodeType::EQ: - return ((T *)this)->visit((const EQ *)node, std::forward(args)...); - case IRNodeType::NE: - return ((T *)this)->visit((const NE *)node, std::forward(args)...); - case IRNodeType::LT: - return ((T *)this)->visit((const LT *)node, std::forward(args)...); - case IRNodeType::LE: - return ((T *)this)->visit((const LE *)node, std::forward(args)...); - case IRNodeType::GT: - return ((T *)this)->visit((const GT *)node, std::forward(args)...); - case IRNodeType::GE: - return ((T *)this)->visit((const GE *)node, std::forward(args)...); - case IRNodeType::And: - return ((T *)this)->visit((const And *)node, std::forward(args)...); - case IRNodeType::Or: - return ((T *)this)->visit((const Or *)node, std::forward(args)...); - case IRNodeType::Not: - return ((T *)this)->visit((const Not *)node, std::forward(args)...); - case IRNodeType::Select: - return ((T *)this)->visit((const Select *)node, std::forward(args)...); - case IRNodeType::Load: - return ((T *)this)->visit((const Load *)node, std::forward(args)...); - case IRNodeType::Ramp: - return ((T *)this)->visit((const Ramp *)node, std::forward(args)...); - case IRNodeType::Call: - return ((T *)this)->visit((const Call *)node, std::forward(args)...); - case IRNodeType::Let: - return ((T *)this)->visit((const Let *)node, std::forward(args)...); - case IRNodeType::Shuffle: - return ((T *)this)->visit((const Shuffle *)node, std::forward(args)...); - case IRNodeType::VectorReduce: - return ((T *)this)->visit((const VectorReduce *)node, std::forward(args)...); - // Explicitly list the Stmt types rather than using a - // default case so that when new IR nodes are added we - // don't miss them here. 
- case IRNodeType::LetStmt: - case IRNodeType::AssertStmt: - case IRNodeType::ProducerConsumer: - case IRNodeType::For: - case IRNodeType::Acquire: - case IRNodeType::Store: - case IRNodeType::Provide: - case IRNodeType::Allocate: - case IRNodeType::Free: - case IRNodeType::Realize: - case IRNodeType::Block: - case IRNodeType::Fork: - case IRNodeType::IfThenElse: - case IRNodeType::Evaluate: - case IRNodeType::Prefetch: - case IRNodeType::Atomic: - case IRNodeType::HoistedStorage: +#define HALIDE_SWITCH_EXPR(NT) \ + case IRNodeType::NT: \ + return ((T *)this)->visit((const NT *)node, std::forward(args)...); + HALIDE_FOR_EACH_IR_EXPR(HALIDE_SWITCH_EXPR) +#undef HALIDE_SWITCH_EXPR + + default: internal_error << "Unreachable"; } return ExprRet{}; @@ -458,74 +200,17 @@ class VariadicVisitor { if (node == nullptr) { return StmtRet{}; } + ZoneScopedVisitor(node->node_type, name, Profiling::BIT_STMT); switch (node->node_type) { - case IRNodeType::IntImm: - case IRNodeType::UIntImm: - case IRNodeType::FloatImm: - case IRNodeType::StringImm: - case IRNodeType::Broadcast: - case IRNodeType::Cast: - case IRNodeType::Reinterpret: - case IRNodeType::Variable: - case IRNodeType::Add: - case IRNodeType::Sub: - case IRNodeType::Mod: - case IRNodeType::Mul: - case IRNodeType::Div: - case IRNodeType::Min: - case IRNodeType::Max: - case IRNodeType::EQ: - case IRNodeType::NE: - case IRNodeType::LT: - case IRNodeType::LE: - case IRNodeType::GT: - case IRNodeType::GE: - case IRNodeType::And: - case IRNodeType::Or: - case IRNodeType::Not: - case IRNodeType::Select: - case IRNodeType::Load: - case IRNodeType::Ramp: - case IRNodeType::Call: - case IRNodeType::Let: - case IRNodeType::Shuffle: - case IRNodeType::VectorReduce: +#define HALIDE_SWITCH_STMT(NT) \ + case IRNodeType::NT: \ + return ((T *)this)->visit((const NT *)node, std::forward(args)...); + HALIDE_FOR_EACH_IR_STMT(HALIDE_SWITCH_STMT) +#undef HALIDE_SWITCH_STMT + + default: internal_error << "Unreachable"; break; - case 
IRNodeType::LetStmt: - return ((T *)this)->visit((const LetStmt *)node, std::forward(args)...); - case IRNodeType::AssertStmt: - return ((T *)this)->visit((const AssertStmt *)node, std::forward(args)...); - case IRNodeType::ProducerConsumer: - return ((T *)this)->visit((const ProducerConsumer *)node, std::forward(args)...); - case IRNodeType::For: - return ((T *)this)->visit((const For *)node, std::forward(args)...); - case IRNodeType::Acquire: - return ((T *)this)->visit((const Acquire *)node, std::forward(args)...); - case IRNodeType::Store: - return ((T *)this)->visit((const Store *)node, std::forward(args)...); - case IRNodeType::Provide: - return ((T *)this)->visit((const Provide *)node, std::forward(args)...); - case IRNodeType::Allocate: - return ((T *)this)->visit((const Allocate *)node, std::forward(args)...); - case IRNodeType::Free: - return ((T *)this)->visit((const Free *)node, std::forward(args)...); - case IRNodeType::Realize: - return ((T *)this)->visit((const Realize *)node, std::forward(args)...); - case IRNodeType::Block: - return ((T *)this)->visit((const Block *)node, std::forward(args)...); - case IRNodeType::Fork: - return ((T *)this)->visit((const Fork *)node, std::forward(args)...); - case IRNodeType::IfThenElse: - return ((T *)this)->visit((const IfThenElse *)node, std::forward(args)...); - case IRNodeType::Evaluate: - return ((T *)this)->visit((const Evaluate *)node, std::forward(args)...); - case IRNodeType::Prefetch: - return ((T *)this)->visit((const Prefetch *)node, std::forward(args)...); - case IRNodeType::Atomic: - return ((T *)this)->visit((const Atomic *)node, std::forward(args)...); - case IRNodeType::HoistedStorage: - return ((T *)this)->visit((const HoistedStorage *)node, std::forward(args)...); } return StmtRet{}; } diff --git a/src/InjectHostDevBufferCopies.cpp b/src/InjectHostDevBufferCopies.cpp index 4a899ee61fd6..c3b57b0c8435 100644 --- a/src/InjectHostDevBufferCopies.cpp +++ b/src/InjectHostDevBufferCopies.cpp @@ -1,6 
+1,7 @@ #include "InjectHostDevBufferCopies.h" #include "CodeGen_GPU_Dev.h" +#include "CompilerProfiling.h" #include "Debug.h" #include "ExternFuncArgument.h" #include "IRMutator.h" @@ -29,6 +30,7 @@ Stmt call_extern_and_assert(const string &name, const vector &args) { namespace { class FindBufferUsage : public IRVisitor { +protected: using IRVisitor::visit; void visit(const Load *op) override { @@ -142,6 +144,7 @@ class FindBufferUsage : public IRVisitor { // the buffer as we go, sniffing usage within each leaf using // FindBufferUsage, and injecting device buffer logic as needed. class InjectBufferCopiesForSingleBuffer : public IRMutator { +protected: using IRMutator::visit; // The buffer being managed @@ -215,10 +218,11 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator { } Stmt do_copies(Stmt s, FindBufferUsage *precomputed = nullptr) { + ZoneScoped; // Sniff what happens to the buffer inside the stmt FindBufferUsage local_finder(buffer, DeviceAPI::Host); if (!precomputed) { - s.accept(&local_finder); + local_finder(s); precomputed = &local_finder; } FindBufferUsage &finder = *precomputed; @@ -356,7 +360,7 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator { // leaf. Stmt visit(const For *op) override { FindBufferUsage finder(buffer, DeviceAPI::Host); - op->accept(&finder); + finder(op); if (finder.devices_touched.size() > 1) { // The state of the buffer going into the loop is the // union of the state before the loop starts and the state @@ -480,12 +484,14 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator { // Inject the buffer-handling logic for all internal // allocations. Inputs and outputs are handled below. class InjectBufferCopies : public IRMutator { +protected: using IRMutator::visit; // Inject the registration of a device destructor just after the // .buffer symbol is defined (which is safely before the first // device_malloc). 
class InjectDeviceDestructor : public IRMutator { + protected: using IRMutator::visit; Stmt visit(const LetStmt *op) override { @@ -514,6 +520,7 @@ class InjectBufferCopies : public IRMutator { // and an Allocate node that takes its host field from the // .buffer. class InjectCombinedAllocation : public IRMutator { + protected: using IRMutator::visit; Stmt visit(const LetStmt *op) override { @@ -565,6 +572,7 @@ class InjectBufferCopies : public IRMutator { }; Stmt inject_free_after_last_use(Stmt body, const Stmt &last_use, const Stmt &free_stmt) { + ZoneScoped; bool success = false; body = mutate_with( body, @@ -582,7 +590,7 @@ class InjectBufferCopies : public IRMutator { Stmt visit(const Allocate *op) override { FindBufferUsage finder(op->name, DeviceAPI::Host); - op->body.accept(&finder); + finder(op->body); bool touched_on_host = finder.devices_touched.count(DeviceAPI::Host); bool touched_on_device = finder.devices_touched.size() > (touched_on_host ? 1 : 0); @@ -595,7 +603,7 @@ class InjectBufferCopies : public IRMutator { Stmt body = mutate(op->body); InjectBufferCopiesForSingleBuffer injector(op->name, false, op->memory_type); - body = injector.mutate(body); + body = injector(body); string buffer_name = op->name + ".buffer"; Expr buffer = Variable::make(Handle(), buffer_name); @@ -621,8 +629,7 @@ class InjectBufferCopies : public IRMutator { Expr device_interface = make_device_interface_call(touching_device, op->memory_type); return InjectCombinedAllocation(op->name, op->type, op->extents, - op->condition, device_interface) - .mutate(body); + op->condition, device_interface)(body); } else { // Only touched on host but passed to an extern stage, or // only touched on device, or touched on multiple @@ -636,7 +643,7 @@ class InjectBufferCopies : public IRMutator { } // Add a device destructor - body = InjectDeviceDestructor(buffer_name).mutate(body); + body = InjectDeviceDestructor(buffer_name)(body); Expr condition = op->condition; bool touched_on_one_device = 
!touched_on_host && finder.devices_touched.size() == 1 && @@ -673,6 +680,7 @@ class InjectBufferCopies : public IRMutator { // ProducerConsumer node. Sometimes it's a Block containing a pair of // them. class FindOutermostProduce : public IRVisitor { +protected: using IRVisitor::visit; void visit(const Block *op) override { @@ -695,10 +703,12 @@ class FindOutermostProduce : public IRVisitor { // Inject the buffer handling code for the inputs and outputs at the // appropriate site. class InjectBufferCopiesForInputsAndOutputs : public IRMutator { +protected: Stmt site; // Find all references to external buffers. class FindInputsAndOutputs : public IRVisitor { + protected: using IRVisitor::visit; void include(const Parameter &p) { @@ -752,10 +762,11 @@ class InjectBufferCopiesForInputsAndOutputs : public IRMutator { Stmt mutate(const Stmt &s) override { if (s.same_as(site)) { FindInputsAndOutputs finder; - s.accept(&finder); + finder(s); Stmt new_stmt = s; for (const string &buf : finder.result) { - new_stmt = InjectBufferCopiesForSingleBuffer(buf, true, finder.result_storage.at(buf)).mutate(new_stmt); + ZoneScopedN("InjectBufferCopiesForSingleBuffer"); + new_stmt = InjectBufferCopiesForSingleBuffer(buf, true, finder.result_storage.at(buf))(new_stmt); } return new_stmt; } else { @@ -782,15 +793,15 @@ Stmt inject_host_dev_buffer_copies(Stmt s, const Target &t) { } // Handle internal allocations - s = InjectBufferCopies().mutate(s); + s = InjectBufferCopies()(s); // Handle inputs and outputs FindOutermostProduce outermost; - s.accept(&outermost); + outermost(s); if (outermost.result.defined()) { // If the entire pipeline simplified away, or just dispatches // to another pipeline, there may be no outermost produce. 
- s = InjectBufferCopiesForInputsAndOutputs(outermost.result).mutate(s); + s = InjectBufferCopiesForInputsAndOutputs(outermost.result)(s); } return s; diff --git a/src/Inline.cpp b/src/Inline.cpp index 31b6efcdf749..ce829f1c7326 100644 --- a/src/Inline.cpp +++ b/src/Inline.cpp @@ -181,15 +181,13 @@ class Inliner : public IRMutator { } }; -Stmt inline_function(Stmt s, const Function &f) { - Inliner i(f); - s = i.mutate(s); - return s; +Stmt inline_function(const Stmt &s, const Function &f) { + return Inliner(f)(s); } Expr inline_function(Expr e, const Function &f) { Inliner i(f); - e = i.mutate(e); + e = i(e); // TODO: making this > 1 should be desirable, // but explodes compiletimes in some situations. if (i.found > 0) { diff --git a/src/Inline.h b/src/Inline.h index fbbd78751e18..344e7c7ddf6d 100644 --- a/src/Inline.h +++ b/src/Inline.h @@ -16,7 +16,7 @@ class Function; * be inlined, it must not have any specializations (i.e. it can only have one * values definition). */ // @{ -Stmt inline_function(Stmt s, const Function &f); +Stmt inline_function(const Stmt &s, const Function &f); Expr inline_function(Expr e, const Function &f); void inline_function(Function caller, const Function &f); // @} diff --git a/src/InlineReductions.cpp b/src/InlineReductions.cpp index e3fc8c1311d2..2bd2873b2736 100644 --- a/src/InlineReductions.cpp +++ b/src/InlineReductions.cpp @@ -126,7 +126,7 @@ Expr sum(const RDom &r, Expr e, const Func &f) { << " passed to sum already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to sum must reference a reduction domain"; @@ -152,7 +152,7 @@ Expr saturating_sum(const RDom &r, Expr e, const Func &f) { << " passed to saturating_sum already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = 
v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to saturating_sum must reference a reduction domain"; @@ -179,7 +179,7 @@ Expr product(const RDom &r, Expr e, const Func &f) { << " passed to product already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to product must reference a reduction domain"; @@ -205,7 +205,7 @@ Expr maximum(const RDom &r, Expr e, const Func &f) { << " passed to maximum already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to maximum must reference a reduction domain"; @@ -232,7 +232,7 @@ Expr minimum(const RDom &r, Expr e, const Func &f) { << " passed to minimum already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to minimum must reference a reduction domain"; @@ -259,7 +259,7 @@ Tuple argmax(const RDom &r, Expr e, const Func &f) { << " passed to argmax already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to argmax must reference a reduction domain"; @@ -298,7 +298,7 @@ Tuple argmin(const RDom &r, Expr e, const Func &f) { << " passed to argmin already has a definition"; Internal::FindFreeVars v(r, f.name()); - e = v.mutate(common_subexpression_elimination(e)); + e = v(common_subexpression_elimination(e)); user_assert(v.rdom.defined()) << "Expression passed to argmin must reference a reduction domain"; diff --git a/src/JITModule.cpp b/src/JITModule.cpp index 
cc88e69a7dbc..cfd4d7d601b9 100644 --- a/src/JITModule.cpp +++ b/src/JITModule.cpp @@ -285,6 +285,7 @@ namespace { // Retrieve a function pointer from an llvm module, possibly by compiling it. JITModule::Symbol compile_and_get_function(llvm::orc::LLJIT &JIT, const string &name) { + ZoneScoped; debug(2) << "JIT Compiling " << name << "\n"; auto addr = JIT.lookup(name); @@ -341,6 +342,7 @@ JITModule::JITModule() { JITModule::JITModule(const Module &m, const LoweredFunc &fn, const std::vector &dependencies) { + ZoneScoped; jit_module = new JITModuleContents(); std::unique_ptr llvm_module(compile_module_to_llvm_module(m, *jit_module->context)); std::vector deps_with_runtime = dependencies; @@ -358,6 +360,7 @@ void compile_module_impl( std::unique_ptr m, const string &function_name, const Target &target, const std::vector &dependencies, const std::vector &requested_exports) { + ZoneScoped; // Ensure that LLVM is initialized CodeGen_LLVM::initialize_llvm(); @@ -518,6 +521,7 @@ void compile_module_impl( void JITModule::compile_module(std::unique_ptr m, const string &function_name, const Target &target, const std::vector &dependencies, const std::vector &requested_exports) { + ZoneScoped; // LLJIT's SimpleCompiler triggers LLVM's AsmPrinter, which can use a large // amount of stack (observed stack overflows on macOS worker threads with // 512KB default stacks). Use run_with_large_stack to ensure enough space. 
diff --git a/src/LICM.cpp b/src/LICM.cpp index ca5fcec3bd8b..a34c194cecb5 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -1,5 +1,6 @@ #include "LICM.h" #include "CSE.h" +#include "CompilerProfiling.h" #include "ExprUsesVar.h" #include "IREquality.h" #include "IRMutator.h" @@ -21,6 +22,7 @@ namespace { // Is it safe to lift an Expr out of a loop (and potentially across a device boundary) class CanLift : public IRVisitor { +protected: using IRVisitor::visit; void visit(const Call *op) override { @@ -54,6 +56,7 @@ class CanLift : public IRVisitor { // Lift pure loop invariants to the top level. Applied independently // to each loop. class LiftLoopInvariants : public IRMutator { +protected: using IRMutator::visit; Scope<> varying; @@ -183,6 +186,7 @@ class LiftLoopInvariants : public IRMutator { // them as just renamings of other variables. Easier to substitute // them in as a post-pass rather than make the pass above more clever. class SubstituteTrivialLets : public IRMutator { +protected: using IRMutator::visit; Expr visit(const Let *op) override { @@ -203,6 +207,7 @@ class SubstituteTrivialLets : public IRMutator { }; class LICM : public IRMutator { +protected: using IRMutator::visit; bool in_gpu_loop{false}; @@ -246,8 +251,8 @@ class LICM : public IRMutator { // Lift invariants LiftLoopInvariants lifter; - Stmt new_stmt = lifter.mutate(op); - new_stmt = SubstituteTrivialLets().mutate(new_stmt); + Stmt new_stmt = lifter(op); + new_stmt = SubstituteTrivialLets()(new_stmt); // As an optimization to reduce register pressure, take // the set of expressions to lift and check if any can @@ -336,6 +341,7 @@ class LICM : public IRMutator { // Reassociate summations to group together the loop invariants. Useful to run before LICM. 
class GroupLoopInvariants : public IRMutator { +protected: using IRMutator::visit; Scope var_depth; @@ -520,9 +526,10 @@ class GroupLoopInvariants : public IRMutator { } // namespace Stmt hoist_loop_invariant_values(Stmt s) { - s = GroupLoopInvariants().mutate(s); + ZoneScoped; + s = GroupLoopInvariants()(s); s = common_subexpression_elimination(s); - s = LICM().mutate(s); + s = LICM()(s); s = simplify_exprs(s); return s; } @@ -532,6 +539,7 @@ namespace { // Move IfThenElse nodes from the inside of a piece of Stmt IR to the // outside when legal. class HoistIfStatements : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const LetStmt *op) override { @@ -656,9 +664,8 @@ class HoistIfStatements : public IRMutator { } // namespace -Stmt hoist_loop_invariant_if_statements(Stmt s) { - s = HoistIfStatements().mutate(s); - return s; +Stmt hoist_loop_invariant_if_statements(const Stmt &s) { + return HoistIfStatements()(s); } } // namespace Internal diff --git a/src/LICM.h b/src/LICM.h index 3d04db35143e..6c9cd16116d7 100644 --- a/src/LICM.h +++ b/src/LICM.h @@ -19,7 +19,7 @@ Stmt hoist_loop_invariant_values(Stmt); /** Just hoist loop-invariant if statements as far up as * possible. Does not lift other values. It's useful to run this * earlier in lowering to simplify the IR. 
*/ -Stmt hoist_loop_invariant_if_statements(Stmt); +Stmt hoist_loop_invariant_if_statements(const Stmt &); } // namespace Internal } // namespace Halide diff --git a/src/LLVM_Runtime_Linker.cpp b/src/LLVM_Runtime_Linker.cpp index 73e0f53d281a..1c3fa22356ad 100644 --- a/src/LLVM_Runtime_Linker.cpp +++ b/src/LLVM_Runtime_Linker.cpp @@ -1,4 +1,5 @@ #include "LLVM_Runtime_Linker.h" +#include "CompilerProfiling.h" #include "Error.h" #include "LLVM_Headers.h" #include "Target.h" @@ -11,6 +12,7 @@ using std::vector; namespace { std::unique_ptr parse_bitcode_file(llvm::StringRef buf, llvm::LLVMContext *context, const char *id) { + ZoneScoped; llvm::MemoryBufferRef bitcode_buffer = llvm::MemoryBufferRef(buf, id); @@ -611,6 +613,7 @@ void convert_weak_to_linkonce(llvm::GlobalValue &gv) { // triple appropriately for the target. void link_modules(std::vector> &modules, Target t, bool allow_stripping_all_weak_functions = false) { + ZoneScoped; llvm::DataLayout data_layout = get_data_layout_for_target(t); llvm::Triple triple = Internal::get_triple_for_target(t); @@ -839,6 +842,7 @@ namespace Internal { std::unique_ptr link_with_wasm_jit_runtime(llvm::LLVMContext *c, const Target &t, std::unique_ptr extra_module) { + ZoneScoped; bool bits_64 = (t.bits == 64); bool debug = t.has_feature(Target::Debug); @@ -882,6 +886,7 @@ std::unique_ptr link_with_wasm_jit_runtime(llvm::LLVMContext *c, c /** Create an llvm module containing the support code for a given target. 
*/ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVMContext *c, bool for_shared_jit_runtime, bool just_gpu) { + ZoneScoped; enum InitialModuleType { ModuleAOT, ModuleAOTNoRuntime, @@ -1374,6 +1379,7 @@ std::unique_ptr get_initial_module_for_target(Target t, llvm::LLVM #ifdef WITH_NVPTX std::unique_ptr get_initial_module_for_ptx_device(Target target, llvm::LLVMContext *c) { + ZoneScoped; std::vector> modules; modules.push_back(get_initmod_ptx_dev_ll(c)); @@ -1427,6 +1433,7 @@ std::unique_ptr get_initial_module_for_ptx_device(Target target, l void add_bitcode_to_module(llvm::LLVMContext *context, llvm::Module &module, const std::vector &bitcode, const std::string &name) { + ZoneScoped; llvm::StringRef sb = llvm::StringRef((const char *)&bitcode[0], bitcode.size()); std::unique_ptr add_in = parse_bitcode_file(sb, context, name.c_str()); diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index fccc049ce5b6..dfc670f7d9e6 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -167,7 +167,7 @@ class StepForwards : public IRGraphMutator { Expr step_forwards(Expr e, const Scope &linear) { StepForwards step(linear); - e = step.mutate(e); + e = step(e); if (!step.success) { return Expr(); } else { @@ -554,7 +554,7 @@ class LoopCarry : public IRMutator { Stmt stmt; Stmt body = mutate(op->body); LoopCarryOverLoop carry(op->name, in_consume, max_carried_values); - body = carry.mutate(body); + body = carry(body); if (body.same_as(op->body)) { stmt = op; } else { @@ -583,9 +583,8 @@ class LoopCarry : public IRMutator { } // namespace -Stmt loop_carry(Stmt s, int max_carried_values) { - s = LoopCarry(max_carried_values).mutate(s); - return s; +Stmt loop_carry(const Stmt &s, int max_carried_values) { + return LoopCarry(max_carried_values)(s); } } // namespace Internal diff --git a/src/LoopCarry.h b/src/LoopCarry.h index f473e2627d54..e409935a0b3d 100644 --- a/src/LoopCarry.h +++ b/src/LoopCarry.h @@ -12,7 +12,7 @@ namespace Internal { * pessimization depending on 
how good the L1 cache is on the architecture * and how many memory issue slots there are. Currently only intended * for Hexagon. */ -Stmt loop_carry(Stmt, int max_carried_values = 8); +Stmt loop_carry(const Stmt &, int max_carried_values = 8); } // namespace Internal } // namespace Halide diff --git a/src/Lower.cpp b/src/Lower.cpp index 9b55bd20840d..076700ba222b 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -21,6 +21,7 @@ #include "CanonicalizeGPUVars.h" #include "ClampUnsafeAccesses.h" #include "CompilerLogger.h" +#include "CompilerProfiling.h" #include "Debug.h" #include "DebugArguments.h" #include "DebugToFile.h" @@ -94,29 +95,52 @@ namespace { class LoweringLogger { Stmt last_written; std::chrono::time_point last_time; + const char *last_msg; + std::vector> timings; bool time_lowering_passes = false; public: LoweringLogger() { - last_time = std::chrono::high_resolution_clock::now(); static bool should_time = !get_env_variable("HL_TIME_LOWERING_PASSES").empty(); time_lowering_passes = should_time; } - void operator()(const string &message, const Stmt &s) { + void begin(const char *msg) { + debug(1) << "Lowering pass: " << msg << "...\n"; + Profiling::generic_zone_begin(msg); + last_time = std::chrono::high_resolution_clock::now(); + last_msg = msg; + } + + void begin(const char *msg, int data) { + debug(1) << "Lowering pass: " << msg << " " << data << "...\n"; + Profiling::generic_zone_begin(msg, data); + last_time = std::chrono::high_resolution_clock::now(); + last_msg = msg; + } + + void end() { + Profiling::generic_zone_end(last_msg); auto t = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = t - last_time; + timings.emplace_back(diff.count() * 1000, last_msg); + } + + void end(const Stmt &s) { + Profiling::generic_zone_end(last_msg); + auto t = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = t - last_time; + timings.emplace_back(diff.count() * 1000, last_msg); if (!s.same_as(last_written)) { - debug(2) << 
message << "\n" + debug(2) << "Lowering after " << last_msg << "\n" << s << "\n"; last_written = s; last_time = t; } else { - debug(2) << message << " (unchanged)\n\n"; + debug(2) << "Lowering after " << last_msg << " (unchanged)\n\n"; last_time = t; } - timings.emplace_back(diff.count() * 1000, message); } ~LoweringLogger() { @@ -143,6 +167,7 @@ void lower_impl(const vector &output_funcs, bool trace_pipeline, const vector &custom_passes, Module &result_module) { + ZoneScoped; auto time_start = std::chrono::high_resolution_clock::now(); size_t initial_lowered_function_count = result_module.functions().size(); @@ -173,293 +198,294 @@ void lower_impl(const vector &output_funcs, LoweringLogger log; - debug(1) << "Creating initial loop nests...\n"; + log.begin("Creating initial loop nests"); bool any_memoized = false; Stmt s = schedule_functions(outputs, fused_groups, env, t, any_memoized); - log("Lowering after creating initial loop nests:", s); + log.end(s); if (any_memoized) { - debug(1) << "Injecting memoization...\n"; + log.begin("Injecting memoization"); s = inject_memoization(s, env, pipeline_name, outputs); - log("Lowering after injecting memoization:", s); + log.end(s); } else { debug(1) << "Skipping injecting memoization...\n"; } - debug(1) << "Injecting tracing...\n"; + log.begin("Injecting tracing"); s = inject_tracing(s, pipeline_name, trace_pipeline, env, outputs, t); - log("Lowering after injecting tracing:", s); + log.end(s); - debug(1) << "Adding checks for parameters\n"; + log.begin("Adding checks for parameters"); s = add_parameter_checks(requirements, s, t); - log("Lowering after injecting parameter checks:", s); + log.end(s); // Compute the maximum and minimum possible value of each // function. Used in later bounds inference passes. 
- debug(1) << "Computing bounds of each function's value\n"; + log.begin("Computing bounds of each function's value"); FuncValueBounds func_bounds = compute_function_value_bounds(order, env); + log.end(); // Clamp unsafe instances where a Func f accesses a Func g using // an index which depends on a third Func h. - debug(1) << "Clamping unsafe data-dependent accesses\n"; + log.begin("Clamping unsafe data-dependent accesses"); s = clamp_unsafe_accesses(s, env, func_bounds); - log("Lowering after clamping unsafe data-dependent accesses", s); + log.end(s); // This pass injects nested definitions of variable names, so we // can't simplify statements from here until we fix them up. (We // can still simplify Exprs). - debug(1) << "Performing computation bounds inference...\n"; + log.begin("Bounds inference"); s = bounds_inference(s, outputs, order, fused_groups, env, func_bounds, t); - log("Lowering after computation bounds inference:", s); + log.end(s); - debug(1) << "Asserting that all split factors are positive...\n"; + log.begin("Asserting that all split factors are positive"); s = add_split_factor_checks(s, env); - log("Lowering after asserting that all split factors are positive:", s); + log.end(s); - debug(1) << "Removing extern loops...\n"; + log.begin("Removing extern loops"); s = remove_extern_loops(s); - log("Lowering after removing extern loops:", s); + log.end(s); - debug(1) << "Performing sliding window optimization...\n"; + log.begin("Sliding window optimization"); s = sliding_window(s, env); - log("Lowering after sliding window:", s); + log.end(s); // This uniquifies the variable names, so we're good to simplify // after this point. This lets later passes assume syntactic // equivalence means semantic equivalence. 
- debug(1) << "Uniquifying variable names...\n"; + log.begin("Uniquifying variable names"); s = uniquify_variable_names(s); - log("Lowering after uniquifying variable names:", s); + log.end(s); - debug(1) << "Simplifying...\n"; + log.begin("Simplifying"); s = simplify(s); - log("Lowering after first simplification:", s); + log.end(s); - debug(1) << "Simplifying correlated differences...\n"; + log.begin("Simplifying correlated differences"); s = simplify_correlated_differences(s); - log("Lowering after simplifying correlated differences:", s); + log.end(s); - debug(1) << "Performing allocation bounds inference...\n"; + log.begin("Allocation bounds inference"); s = allocation_bounds_inference(s, env, func_bounds); - log("Lowering after allocation bounds inference:", s); + log.end(s); bool will_inject_host_copies = (t.has_gpu_feature() || t.has_feature(Target::HexagonDma) || (t.arch != Target::Hexagon && (t.has_feature(Target::HVX)))); - debug(1) << "Adding checks for images\n"; + log.begin("Adding checks for images"); s = add_image_checks(s, outputs, t, order, env, func_bounds, will_inject_host_copies); - log("Lowering after injecting image checks:", s); + log.end(s); - debug(1) << "Removing code that depends on undef values...\n"; + log.begin("Removing code that depends on undef values"); s = remove_undef(s); - log("Lowering after removing code that depends on undef values:", s); + log.end(s); - debug(1) << "Performing storage folding optimization...\n"; + log.begin("Performing storage folding optimization"); s = storage_folding(s, env); - log("Lowering after storage folding:", s); + log.end(s); - debug(1) << "Injecting debug_to_file calls...\n"; + log.begin("Injecting debug_to_file calls"); s = debug_to_file(s, outputs, env); - log("Lowering after injecting debug_to_file calls:", s); + log.end(s); - debug(1) << "Injecting prefetches...\n"; + log.begin("Injecting prefetches"); s = inject_prefetch(s, env); - log("Lowering after injecting prefetches:", s); + 
log.end(s); - debug(1) << "Discarding safe promises...\n"; + log.begin("Discarding safe promises"); s = lower_safe_promises(s); - log("Lowering after discarding safe promises:", s); + log.end(s); - debug(1) << "Dynamically skipping stages...\n"; + log.begin("Dynamically skipping stages"); s = skip_stages(s, outputs, fused_groups, env); - log("Lowering after dynamically skipping stages:", s); + log.end(s); - debug(1) << "Forking asynchronous producers...\n"; + log.begin("Forking asynchronous producers"); s = fork_async_producers(s, env); - log("Lowering after forking asynchronous producers:", s); + log.end(s); - debug(1) << "Destructuring tuple-valued realizations...\n"; + log.begin("Destructuring tuple-valued realizations"); s = split_tuples(s, env); - log("Lowering after destructuring tuple-valued realizations:", s); + log.end(s); if (t.has_gpu_feature()) { - debug(1) << "Canonicalizing GPU var names...\n"; + log.begin("Canonicalizing GPU var names"); s = canonicalize_gpu_vars(s); - log("Lowering after canonicalizing GPU var names:", s); + log.end(s); } - debug(1) << "Bounding small realizations...\n"; + log.begin("Bounding small realizations"); s = simplify_correlated_differences(s); s = bound_small_allocations(s); - log("Lowering after bounding small realizations:", s); + log.end(s); - debug(1) << "Performing storage flattening...\n"; + log.begin("Performing storage flattening"); s = storage_flattening(s, outputs, env, t); - log("Lowering after storage flattening:", s); + log.end(s); - debug(1) << "Adding atomic mutex allocation...\n"; + log.begin("Adding atomic mutex allocation"); s = add_atomic_mutex(s, outputs); - log("Lowering after adding atomic mutex allocation:", s); + log.end(s); - debug(1) << "Unpacking buffer arguments...\n"; + log.begin("Unpacking buffer arguments"); s = unpack_buffers(s); - log("Lowering after unpacking buffer arguments:", s); + log.end(s); if (any_memoized) { - debug(1) << "Rewriting memoized allocations...\n"; + 
log.begin("Rewriting memoized allocations"); s = rewrite_memoized_allocations(s, env); - log("Lowering after rewriting memoized allocations:", s); + log.end(s); } else { - debug(1) << "Skipping rewriting memoized allocations...\n"; + debug(1) << "Skipping rewriting memoized allocations\n"; } if (will_inject_host_copies) { - debug(1) << "Selecting a GPU API for GPU loops...\n"; + log.begin("Selecting a GPU API for GPU loops"); s = select_gpu_api(s, t); - log("Lowering after selecting a GPU API:", s); + log.end(s); - debug(1) << "Injecting host <-> dev buffer copies...\n"; + log.begin("Injecting host <-> dev buffer copies"); s = inject_host_dev_buffer_copies(s, t); - log("Lowering after injecting host <-> dev buffer copies:", s); + log.end(s); - debug(1) << "Selecting a GPU API for extern stages...\n"; + log.begin("Selecting a GPU API for extern stages"); s = select_gpu_api(s, t); - log("Lowering after selecting a GPU API for extern stages:", s); + log.end(s); } else { - debug(1) << "Injecting host-dirty marking...\n"; + log.begin("Injecting host-dirty marking"); s = inject_host_dev_buffer_copies(s, t); - log("Lowering after injecting host-dirty marking:", s); + log.end(s); } - debug(1) << "Simplifying...\n"; + log.begin("Simplifying"); s = simplify(s); s = unify_duplicate_lets(s); - log("Lowering after second simplification:", s); + log.end(s); - debug(1) << "Reduce prefetch dimension...\n"; + log.begin("Reduce prefetch dimension"); s = reduce_prefetch_dimension(s, t); - log("Lowering after reduce prefetch dimension:", s); + log.end(s); - debug(1) << "Simplifying correlated differences...\n"; + log.begin("Simplifying correlated differences"); s = simplify_correlated_differences(s); - log("Lowering after simplifying correlated differences:", s); + log.end(s); - debug(1) << "Bounding constant extent loops...\n"; + log.begin("Bounding constant extent loops"); s = bound_constant_extent_loops(s); - log("Lowering after bounding constant extent loops:", s); + log.end(s); - 
debug(1) << "Unrolling...\n"; + log.begin("Unrolling"); s = unroll_loops(s); - log("Lowering after unrolling:", s); + log.end(s); - debug(1) << "Vectorizing...\n"; + log.begin("Vectorizing"); s = vectorize_loops(s, env); s = simplify(s); - log("Lowering after vectorizing:", s); + log.end(s); if (t.has_gpu_feature() || t.has_feature(Target::Vulkan)) { - debug(1) << "Injecting per-block gpu synchronization...\n"; + log.begin("Injecting per-block gpu synchronization"); s = fuse_gpu_thread_loops(s); - log("Lowering after injecting per-block gpu synchronization:", s); + log.end(s); } - debug(1) << "Detecting vector interleavings...\n"; + log.begin("Detecting vector interleavings"); s = rewrite_interleavings(s); s = simplify(s); - log("Lowering after rewriting vector interleavings:", s); + log.end(s); - debug(1) << "Partitioning loops to simplify boundary conditions...\n"; + log.begin("Partitioning loops to simplify boundary conditions"); s = partition_loops(s); s = simplify(s); - log("Lowering after partitioning loops:", s); + log.end(s); - debug(1) << "Staging strided loads...\n"; + log.begin("Staging strided loads"); s = stage_strided_loads(s); - log("Lowering after staging strided loads:", s); + log.end(s); - debug(1) << "Trimming loops to the region over which they do something...\n"; + log.begin("Trimming loops to the region over which they do something"); s = trim_no_ops(s); - log("Lowering after loop trimming:", s); + log.end(s); - debug(1) << "Rebasing loops to zero...\n"; + log.begin("Rebasing loops to zero"); s = rebase_loops_to_zero(s); - debug(2) << "Lowering after rebasing loops to zero:\n" - << s << "\n\n"; + log.end(s); - debug(1) << "Hoisting loop invariant if statements...\n"; + log.begin("Hoisting loop invariant if statements"); s = hoist_loop_invariant_if_statements(s); - log("Lowering after hoisting loop invariant if statements:", s); + log.end(s); - debug(1) << "Injecting early frees...\n"; + log.begin("Injecting early frees"); s = 
inject_early_frees(s); - log("Lowering after injecting early frees:", s); + log.end(s); if (t.has_feature(Target::FuzzFloatStores)) { - debug(1) << "Fuzzing floating point stores...\n"; + log.begin("Fuzzing floating point stores"); s = fuzz_float_stores(s); - log("Lowering after fuzzing floating point stores:", s); + log.end(s); } - debug(1) << "Simplifying correlated differences...\n"; + log.begin("Simplifying correlated differences"); s = simplify_correlated_differences(s); - log("Lowering after simplifying correlated differences:", s); + log.end(s); - debug(1) << "Bounding small allocations...\n"; + log.begin("Bounding small allocations"); s = bound_small_allocations(s); - log("Lowering after bounding small allocations:", s); + log.end(s); if (t.has_feature(Target::Profile) || t.has_feature(Target::ProfileByTimer)) { - debug(1) << "Injecting profiling...\n"; + log.begin("Injecting profiling"); s = inject_profiling(s, pipeline_name, env); - log("Lowering after injecting profiling:", s); + log.end(s); } if (t.has_feature(Target::CUDA)) { - debug(1) << "Injecting warp shuffles...\n"; + log.begin("Injecting warp shuffles"); s = lower_warp_shuffles(s, t); - log("Lowering after injecting warp shuffles:", s); + log.end(s); } - debug(1) << "Simplifying...\n"; + log.begin("Simplifying"); s = common_subexpression_elimination(s); + log.end(); - debug(1) << "Lowering unsafe promises...\n"; + log.begin("Lowering unsafe promises"); s = lower_unsafe_promises(s, t); - log("Lowering after lowering unsafe promises:", s); + log.end(s); if (t.has_feature(Target::AVX512_SapphireRapids)) { - debug(1) << "Extracting tile operations...\n"; + log.begin("Extracting tile operations"); s = extract_tile_operations(s); - log("Lowering after extracting tile operations:", s); + log.end(s); } - debug(1) << "Flattening nested ramps...\n"; + log.begin("Flattening nested ramps"); s = flatten_nested_ramps(s); - log("Lowering after flattening nested ramps:", s); + log.end(s); - debug(1) << "Removing 
dead allocations and moving loop invariant code...\n"; + log.begin("Removing dead allocations and moving loop invariant code"); s = remove_dead_allocations(s); s = simplify(s); s = hoist_loop_invariant_values(s); s = hoist_loop_invariant_if_statements(s); - log("Lowering after removing dead allocations and hoisting loop invariants:", s); + log.end(s); - debug(1) << "Finding intrinsics...\n"; + log.begin("Finding intrinsics"); // Must be run after the last simplification, because it turns // divisions into shifts, which the simplifier reverses. s = find_intrinsics(s); - log("Lowering after finding intrinsics:", s); + log.end(s); - debug(1) << "Hoisting prefetches...\n"; + log.begin("Hoisting prefetches"); s = hoist_prefetches(s); - log("Lowering after hoisting prefetches:", s); + log.end(s); if (t.has_feature(Target::NoAsserts)) { - debug(1) << "Stripping asserts...\n"; + log.begin("Stripping asserts"); s = strip_asserts(s); - log("Lowering after stripping asserts:", s); + log.end(s); } debug(1) << "Lowering after final simplification:\n" @@ -467,10 +493,9 @@ void lower_impl(const vector &output_funcs, if (!custom_passes.empty()) { for (size_t i = 0; i < custom_passes.size(); i++) { - debug(1) << "Running custom lowering pass " << i << "...\n"; - s = custom_passes[i]->mutate(s); - debug(1) << "Lowering after custom pass " << i << ":\n" - << s << "\n\n"; + log.begin("Custom lowering pass", i); + s = custom_passes[i]->operator()(s); + log.end(s); } } @@ -478,19 +503,17 @@ void lower_impl(const vector &output_funcs, result_module.set_conceptual_code_stmt(s); if (t.arch != Target::Hexagon && t.has_feature(Target::HVX)) { - debug(1) << "Splitting off Hexagon offload...\n"; + log.begin("Splitting off Hexagon offload"); s = inject_hexagon_rpc(s, t, result_module); - debug(2) << "Lowering after splitting off Hexagon offload:\n" - << s << "\n"; + log.end(s); } else { debug(1) << "Skipping Hexagon offload...\n"; } if (t.has_gpu_feature()) { - debug(1) << "Offloading GPU 
loops...\n"; + log.begin("Offloading GPU loops"); s = inject_gpu_offload(s, t, any_strict_float); - debug(2) << "Lowering after splitting off GPU loops:\n" - << s << "\n\n"; + log.end(s); } else { debug(1) << "Skipping GPU offload...\n"; } @@ -505,7 +528,7 @@ void lower_impl(const vector &output_funcs, vector inferred_args = infer_arguments(s, outputs); std::vector closure_implementations; - debug(1) << "Lowering Parallel Tasks...\n"; + log.begin("Lowering Parallel Tasks"); s = lower_parallel_tasks(s, closure_implementations, pipeline_name, t); // Process any LoweredFunctions added by other passes. In practice, this // will likely not work well enough due to ordering issues with @@ -520,8 +543,7 @@ void lower_impl(const vector &output_funcs, for (auto &lowered_func : closure_implementations) { result_module.append(lowered_func); } - debug(2) << "Lowering after generating parallel tasks and closures:\n" - << s << "\n\n"; + log.end(s); vector public_args = args; for (const auto &out : outputs) { diff --git a/src/LowerParallelTasks.cpp b/src/LowerParallelTasks.cpp index 62e909136841..6d28d9779cbd 100644 --- a/src/LowerParallelTasks.cpp +++ b/src/LowerParallelTasks.cpp @@ -426,7 +426,7 @@ struct LowerParallelTasks : public IRMutator { Stmt lower_parallel_tasks(const Stmt &s, std::vector &closure_implementations, const std::string &name, const Target &t) { LowerParallelTasks lowering_mutator(name, t); - Stmt result = lowering_mutator.mutate(s); + Stmt result = lowering_mutator(s); // Main body will be dumped as part of standard lowering debugging, but closures will not be. 
debug(2) << [&] { diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index 6985b8ea9f46..dc2b68e43148 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -785,8 +785,8 @@ class HoistWarpShuffles : public IRMutator { Stmt else_case = mutate(op->else_case); HoistWarpShufflesFromSingleIfStmt hoister; - then_case = hoister.mutate(then_case); - else_case = hoister.mutate(else_case); + then_case = hoister(then_case); + else_case = hoister(else_case); Stmt s = IfThenElse::make(op->condition, then_case, else_case); if (hoister.success) { return hoister.rewrap(s); @@ -794,7 +794,7 @@ class HoistWarpShuffles : public IRMutator { // Need to move the ifstmt further inwards instead. internal_assert(!else_case.defined()) << "Cannot hoist warp shuffle out of " << s << "\n"; string pred_name = unique_name('p'); - s = MoveIfStatementInwards(Variable::make(op->condition.type(), pred_name)).mutate(then_case); + s = MoveIfStatementInwards(Variable::make(op->condition.type(), pred_name))(then_case); return LetStmt::make(pred_name, op->condition, s); } } @@ -824,8 +824,8 @@ class LowerWarpShufflesInEachKernel : public IRMutator { Stmt visit(const For *op) override { if (op->device_api == DeviceAPI::CUDA && has_lane_loop(op)) { Stmt s = op; - s = LowerWarpShuffles(cuda_cap).mutate(s); - s = HoistWarpShuffles().mutate(s); + s = LowerWarpShuffles(cuda_cap)(s); + s = HoistWarpShuffles()(s); return simplify(s); } else { return IRMutator::visit(op); @@ -844,9 +844,9 @@ class LowerWarpShufflesInEachKernel : public IRMutator { Stmt lower_warp_shuffles(Stmt s, const Target &t) { s = hoist_loop_invariant_values(s); - s = SubstituteInLaneVar().mutate(s); + s = SubstituteInLaneVar()(s); s = simplify(s); - s = LowerWarpShufflesInEachKernel(t.get_cuda_capability_lower_bound()).mutate(s); + s = LowerWarpShufflesInEachKernel(t.get_cuda_capability_lower_bound())(s); return s; }; diff --git a/src/Memoization.cpp b/src/Memoization.cpp index 21cfbd4c9dce..b55c4a66a264 
100644 --- a/src/Memoization.cpp +++ b/src/Memoization.cpp @@ -470,7 +470,7 @@ Stmt inject_memoization(const Stmt &s, const std::map &en InjectMemoization injector(env, memoize_instance++, name, outputs); - return injector.mutate(s); + return injector(s); } namespace { @@ -563,9 +563,7 @@ class RewriteMemoizedAllocations : public IRMutator { Stmt rewrite_memoized_allocations(const Stmt &s, const std::map &env) { - RewriteMemoizedAllocations rewriter(env); - - return rewriter.mutate(s); + return RewriteMemoizedAllocations(env)(s); } } // namespace Internal diff --git a/src/OffloadGPULoops.cpp b/src/OffloadGPULoops.cpp index 11b8c3ccecf3..cec32874f911 100644 --- a/src/OffloadGPULoops.cpp +++ b/src/OffloadGPULoops.cpp @@ -9,6 +9,7 @@ #include "CodeGen_PTX_Dev.h" #include "CodeGen_Vulkan_Dev.h" #include "CodeGen_WebGPU_Dev.h" +#include "CompilerProfiling.h" #include "ExprUsesVar.h" #include "IRMutator.h" #include "IROperator.h" @@ -44,7 +45,7 @@ class ExtractBounds : public IRVisitor { } } -private: +protected: bool found_shared = false; using IRVisitor::visit; @@ -87,6 +88,7 @@ class ExtractBounds : public IRVisitor { }; class InjectGpuOffload : public IRMutator { +protected: /** Child code generator for device kernels. 
*/ map> cgdev; @@ -131,7 +133,7 @@ class InjectGpuOffload : public IRMutator { << "A concrete device API should have been selected before codegen."; ExtractBounds bounds; - loop->accept(&bounds); + bounds(loop); debug(2) << "Kernel bounds: (" << bounds.num_threads[0] << ", " << bounds.num_threads[1] << ", " @@ -247,6 +249,7 @@ class InjectGpuOffload : public IRMutator { public: InjectGpuOffload(const Target &target, bool any_strict_float) : target(target) { + ZoneScoped; Target device_target = target; // For the GPU target we just want to pass the flags, to avoid the // generated kernel code unintentionally having any dependence on the @@ -280,12 +283,13 @@ class InjectGpuOffload : public IRMutator { } Stmt inject(const Stmt &s) { + ZoneScoped; // Create a new module for all of the kernels we find in this function. for (auto &i : cgdev) { i.second->init_module(); } - Stmt result = mutate(s); + Stmt result = (*this)(s); for (auto &i : cgdev) { string api_unique_name = i.second->api_unique_name(); diff --git a/src/OptimizeShuffles.cpp b/src/OptimizeShuffles.cpp index 0a88d02f0b60..83672fa59395 100644 --- a/src/OptimizeShuffles.cpp +++ b/src/OptimizeShuffles.cpp @@ -144,7 +144,7 @@ class OptimizeShuffles : public IRMutator { } // namespace Stmt optimize_shuffles(Stmt s, int lut_alignment) { - s = OptimizeShuffles(lut_alignment).mutate(s); + s = OptimizeShuffles(lut_alignment)(s); return s; } diff --git a/src/ParallelRVar.cpp b/src/ParallelRVar.cpp index f2240fb9e036..538cd144f449 100644 --- a/src/ParallelRVar.cpp +++ b/src/ParallelRVar.cpp @@ -110,7 +110,7 @@ bool can_parallelize_rvar(const string &v, // Make an expr representing the store done by a different thread. RenameFreeVars renamer; - auto other_store = renamer.mutate(args); + auto other_store = renamer(args); // Construct an expression which is true when the two threads are // in fact two different threads. 
We'll use this liberally in the @@ -147,7 +147,7 @@ bool can_parallelize_rvar(const string &v, // Add the definition's predicate if there is any if (pred.defined() || !is_const_one(pred)) { const Expr &this_pred = pred; - Expr other_pred = renamer.mutate(pred); + Expr other_pred = renamer(pred); debug(3) << "......this thread predicate: " << this_pred << "\n"; debug(3) << "......other thread predicate: " << other_pred << "\n"; hazard = hazard && this_pred && other_pred; @@ -156,7 +156,7 @@ bool can_parallelize_rvar(const string &v, debug(3) << "Attempting to falsify: " << hazard << "\n"; // Pull out common non-boolean terms hazard = common_subexpression_elimination(hazard); - hazard = SubstituteInBooleanLets().mutate(hazard); + hazard = SubstituteInBooleanLets()(hazard); hazard = simplify(hazard, bounds); debug(3) << "Simplified to: " << hazard << "\n"; diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index 8f80ea42cd85..e31b8b9387b2 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -1165,16 +1165,16 @@ bool has_likely_tag(const Expr &e, const Scope<> &scope) { } Stmt partition_loops(Stmt s) { - s = LowerLikelyIfInnermost().mutate(s); + s = LowerLikelyIfInnermost()(s); // Walk inwards to the first loop before doing any more work. 
s = mutate_with(s, [](auto *self, const For *op) { Stmt s = op; - s = MarkClampedRampsAsLikely().mutate(s); - s = ExpandSelects().mutate(s); - s = PartitionLoops().mutate(s); - s = RenormalizeGPULoops().mutate(s); - s = CollapseSelects().mutate(s); + s = MarkClampedRampsAsLikely()(s); + s = ExpandSelects()(s); + s = PartitionLoops()(s); + s = RenormalizeGPULoops()(s); + s = CollapseSelects()(s); return s; }); diff --git a/src/Pipeline.cpp b/src/Pipeline.cpp index c305a6105742..a7a5ea843601 100644 --- a/src/Pipeline.cpp +++ b/src/Pipeline.cpp @@ -4,6 +4,7 @@ #include "Argument.h" #include "Callable.h" #include "CodeGen_Internal.h" +#include "CompilerProfiling.h" #include "Deserialization.h" #include "FindCalls.h" #include "Func.h" @@ -495,6 +496,7 @@ Module Pipeline::compile_to_module(const vector &args, const string &fn_name, const Target &target, const LinkageType linkage_type) { + ZoneScoped; user_assert(defined()) << "Can't compile undefined Pipeline.\n"; for (const Function &f : contents->outputs) { @@ -580,6 +582,7 @@ Target Pipeline::get_compiled_jit_target() const { } void Pipeline::compile_jit(const Target &target_arg) { + ZoneScoped; user_assert(defined()) << "Pipeline is undefined\n"; Target target = target_arg; @@ -645,6 +648,7 @@ void Pipeline::compile_jit(const Target &target_arg) { } Callable Pipeline::compile_to_callable(const std::vector &args_in, const Target &target_arg) { + ZoneScoped; user_assert(defined()) << "Pipeline is undefined\n"; Target target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext); @@ -676,6 +680,7 @@ Callable Pipeline::compile_to_callable(const std::vector &args_in, con const std::vector &outputs, const std::map &jit_externs_in, const Target &target_arg) { + ZoneScoped; user_assert(!target_arg.has_unknowns()) << "Cannot jit-compile for target '" << target_arg << "'\n"; Target jit_target = target_arg.with_feature(Target::JIT).with_feature(Target::UserContext); diff --git a/src/Prefetch.cpp 
b/src/Prefetch.cpp index c0eedf50b817..2a4618eb72a5 100644 --- a/src/Prefetch.cpp +++ b/src/Prefetch.cpp @@ -1,5 +1,6 @@ #include #include +#include <unordered_map> #include #include "Bounds.h" @@ -19,6 +20,7 @@ namespace Internal { using std::map; using std::set; using std::string; +using std::unordered_map; using std::vector; /** @@ -42,12 +44,15 @@ namespace { // Collect the bounds of all the externally referenced buffers in a stmt. class CollectExternalBufferBounds : public IRVisitor { public: - map buffers; + unordered_map buffers; using IRVisitor::visit; void add_buffer_bounds(const string &name, const Buffer<> &image, const Parameter &param, int dims) { - Box b; + if (buffers.find(name) != buffers.end()) { + return; + } + Box b(dims); for (int i = 0; i < dims; ++i) { string dim_name = std::to_string(i); Expr buf_min_i = Variable::make(Int(32), concat_strings(name, ".min.", i), @@ -55,7 +60,7 @@ class CollectExternalBufferBounds : public IRVisitor { Expr buf_extent_i = Variable::make(Int(32), concat_strings(name, ".extent.", i), image, param, ReductionDomain()); Expr buf_max_i = buf_min_i + buf_extent_i - 1; - b.push_back(Interval(buf_min_i, buf_max_i)); + b[i] = Interval(buf_min_i, buf_max_i); } buffers.emplace(name, b); } @@ -74,13 +79,13 @@ class InjectPrefetch : public IRMutator { public: - InjectPrefetch(const map &e, const map &buffers) + InjectPrefetch(const map &e, const unordered_map &buffers) : env(e), external_buffers(buffers) { } -private: +protected: const map &env; - const map &external_buffers; + const unordered_map &external_buffers; Scope buffer_bounds; using IRMutator::visit; @@ -187,7 +192,7 @@ class InjectPlaceholderPrefetch : public IRMutator { : env(e), prefix(prefix), prefetch_list(prefetches) { } -private: +protected: const map &env; const string &prefix; const vector &prefetch_list; @@ -260,6 +265,7 @@ class InjectPlaceholderPrefetch : public IRMutator { // Reduce the prefetch dimension if bigger 
than 'max_dim'. It keeps the 'max_dim' // innermost dimensions and replaces the rests with for-loops. class ReducePrefetchDimension : public IRMutator { +protected: using IRMutator::visit; const size_t max_dim; @@ -321,6 +327,7 @@ class ReducePrefetchDimension : public IRMutator { // prefetch. This will split the prefetch call into multiple calls by adding // an outer for-loop around the prefetch. class SplitPrefetch : public IRMutator { +protected: using IRMutator::visit; Expr max_byte_size; @@ -400,6 +407,7 @@ void traverse_block(const Stmt &s, Fn &&f) { } class HoistPrefetches : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const Block *op) override { @@ -432,17 +440,20 @@ class HoistPrefetches : public IRMutator { Stmt inject_placeholder_prefetch(const Stmt &s, const map &env, const string &prefix, const vector &prefetches) { - Stmt stmt = InjectPlaceholderPrefetch(env, prefix, prefetches).mutate(s); + ZoneScoped; + Stmt stmt = InjectPlaceholderPrefetch(env, prefix, prefetches)(s); return stmt; } Stmt inject_prefetch(const Stmt &s, const map &env) { + ZoneScoped; CollectExternalBufferBounds finder; s.accept(&finder); - return InjectPrefetch(env, finder.buffers).mutate(s); + return InjectPrefetch(env, finder.buffers)(s); } Stmt reduce_prefetch_dimension(Stmt stmt, const Target &t) { + ZoneScoped; size_t max_dim = 0; Expr max_byte_size; @@ -461,17 +472,18 @@ Stmt reduce_prefetch_dimension(Stmt stmt, const Target &t) { } internal_assert(max_dim > 0); - stmt = ReducePrefetchDimension(max_dim).mutate(stmt); + stmt = ReducePrefetchDimension(max_dim)(stmt); if (max_byte_size.defined()) { // If the max byte size is specified, we may need to tile // the prefetch - stmt = SplitPrefetch(max_byte_size).mutate(stmt); + stmt = SplitPrefetch(max_byte_size)(stmt); } return stmt; } Stmt hoist_prefetches(const Stmt &s) { - return HoistPrefetches().mutate(s); + ZoneScoped; + return HoistPrefetches()(s); } } // namespace Internal diff --git a/src/Profiling.cpp 
b/src/Profiling.cpp index c8054e83544e..fae38b94fe7a 100644 --- a/src/Profiling.cpp +++ b/src/Profiling.cpp @@ -611,7 +611,7 @@ Stmt inject_profiling(const Stmt &stmt, const string &pipeline_name, const std:: Names names(pipeline_name); InjectProfiling profiling(names, env); - Stmt s = profiling.mutate(stmt); + Stmt s = profiling(stmt); int num_funcs = (int)(profiling.indices.size()); diff --git a/src/PurifyIndexMath.cpp b/src/PurifyIndexMath.cpp index 1ea205a6ff6b..0bdf7f60691d 100644 --- a/src/PurifyIndexMath.cpp +++ b/src/PurifyIndexMath.cpp @@ -1,7 +1,6 @@ #include "PurifyIndexMath.h" #include "IRMutator.h" #include "IROperator.h" -#include "Simplify.h" namespace Halide { namespace Internal { @@ -26,7 +25,7 @@ class PurifyIndexMath : public IRMutator { } // namespace Expr purify_index_math(const Expr &s) { - return PurifyIndexMath().mutate(s); + return PurifyIndexMath()(s); } } // namespace Internal diff --git a/src/Qualify.cpp b/src/Qualify.cpp index 470162909f94..ef4a32abb355 100644 --- a/src/Qualify.cpp +++ b/src/Qualify.cpp @@ -36,8 +36,7 @@ class QualifyExpr : public IRMutator { } // namespace Expr qualify(const string &prefix, const Expr &value) { - QualifyExpr q(prefix); - return q.mutate(value); + return QualifyExpr(prefix)(value); } } // namespace Internal diff --git a/src/Random.cpp b/src/Random.cpp index 57eeb69f9210..e97d43ce0680 100644 --- a/src/Random.cpp +++ b/src/Random.cpp @@ -141,8 +141,7 @@ class LowerRandom : public IRMutator { } // namespace Expr lower_random(const Expr &e, const vector &free_vars, int tag) { - LowerRandom r(free_vars, tag); - return r.mutate(e); + return LowerRandom(free_vars, tag)(e); } } // namespace Internal diff --git a/src/RebaseLoopsToZero.cpp b/src/RebaseLoopsToZero.cpp index 49f97126bb93..dfc0fb3d0709 100644 --- a/src/RebaseLoopsToZero.cpp +++ b/src/RebaseLoopsToZero.cpp @@ -46,7 +46,7 @@ class RebaseLoopsToZero : public IRMutator { } // namespace Stmt rebase_loops_to_zero(const Stmt &s) { - return 
RebaseLoopsToZero().mutate(s); + return RebaseLoopsToZero()(s); } } // namespace Internal diff --git a/src/Reduction.cpp b/src/Reduction.cpp index bedb51065694..b0dd66ee823f 100644 --- a/src/Reduction.cpp +++ b/src/Reduction.cpp @@ -102,32 +102,32 @@ struct ReductionDomainContents { } // Pass an IRVisitor through to all Exprs referenced in the ReductionDomainContents - void accept(IRVisitor *visitor) { + void accept(IRVisitor &visitor) { for (const ReductionVariable &rvar : domain) { if (rvar.min.defined()) { - rvar.min.accept(visitor); + visitor(rvar.min); } if (rvar.extent.defined()) { - rvar.extent.accept(visitor); + visitor(rvar.extent); } } if (predicate.defined()) { - predicate.accept(visitor); + visitor(predicate); } } // Pass an IRMutator through to all Exprs referenced in the ReductionDomainContents - void mutate(IRMutator *mutator) { + void mutate(IRMutator &mutator) { for (ReductionVariable &rvar : domain) { if (rvar.min.defined()) { - rvar.min = mutator->mutate(rvar.min); + rvar.min = mutator(rvar.min); } if (rvar.extent.defined()) { - rvar.extent = mutator->mutate(rvar.extent); + rvar.extent = mutator(rvar.extent); } } if (predicate.defined()) { - predicate = mutator->mutate(predicate); + predicate = mutator(predicate); } } }; @@ -196,7 +196,7 @@ class DropSelfReferences : public IRMutator { void ReductionDomain::set_predicate(const Expr &p) { // The predicate can refer back to the RDom. We need to break // those cycles to prevent a leak. 
- contents->predicate = DropSelfReferences(p, *this).mutate(p); + contents->predicate = DropSelfReferences(p, *this)(p); } void ReductionDomain::where(Expr predicate) { @@ -223,13 +223,13 @@ bool ReductionDomain::frozen() const { void ReductionDomain::accept(IRVisitor *visitor) const { if (contents.defined()) { - contents->accept(visitor); + contents->accept(*visitor); } } void ReductionDomain::mutate(IRMutator *mutator) { if (contents.defined()) { - contents->mutate(mutator); + contents->mutate(*mutator); } } diff --git a/src/RemoveDeadAllocations.cpp b/src/RemoveDeadAllocations.cpp index 33a1a0190b07..74a5c32751d2 100644 --- a/src/RemoveDeadAllocations.cpp +++ b/src/RemoveDeadAllocations.cpp @@ -9,6 +9,7 @@ namespace Internal { namespace { class RemoveDeadAllocations : public IRMutator { +protected: using IRMutator::visit; Scope allocs; @@ -88,7 +89,7 @@ class RemoveDeadAllocations : public IRMutator { } // namespace Stmt remove_dead_allocations(const Stmt &s) { - return RemoveDeadAllocations().mutate(s); + return RemoveDeadAllocations()(s); } } // namespace Internal diff --git a/src/RemoveExternLoops.cpp b/src/RemoveExternLoops.cpp index 9fb0e187b3eb..c3aab4db7c75 100644 --- a/src/RemoveExternLoops.cpp +++ b/src/RemoveExternLoops.cpp @@ -22,7 +22,7 @@ class RemoveExternLoops : public IRMutator { } // namespace Stmt remove_extern_loops(const Stmt &s) { - return RemoveExternLoops().mutate(s); + return RemoveExternLoops()(s); } } // namespace Internal diff --git a/src/RemoveUndef.cpp b/src/RemoveUndef.cpp index 9667aafe891a..512c984427a7 100644 --- a/src/RemoveUndef.cpp +++ b/src/RemoveUndef.cpp @@ -628,7 +628,7 @@ class RemoveUndef : public IRMutator { Stmt remove_undef(Stmt s) { RemoveUndef r; - s = r.mutate(s); + s = r(s); internal_assert(!r.predicate.defined()) << "Undefined expression leaked outside of a Store node: " << r.predicate << "\n"; diff --git a/src/Schedule.cpp b/src/Schedule.cpp index 9f5f51d7a043..a2583d0fb732 100644 --- a/src/Schedule.cpp +++ 
b/src/Schedule.cpp @@ -250,33 +250,33 @@ struct FuncScheduleContents { } // Pass an IRMutator through to all Exprs referenced in the FuncScheduleContents - void mutate(IRMutator *mutator) { + void mutate(IRMutator &mutator) { for (Bound &b : bounds) { if (b.min.defined()) { - b.min = mutator->mutate(b.min); + b.min = mutator(b.min); } if (b.extent.defined()) { - b.extent = mutator->mutate(b.extent); + b.extent = mutator(b.extent); } if (b.modulus.defined()) { - b.modulus = mutator->mutate(b.modulus); + b.modulus = mutator(b.modulus); } if (b.remainder.defined()) { - b.remainder = mutator->mutate(b.remainder); + b.remainder = mutator(b.remainder); } } for (Bound &b : estimates) { if (b.min.defined()) { - b.min = mutator->mutate(b.min); + b.min = mutator(b.min); } if (b.extent.defined()) { - b.extent = mutator->mutate(b.extent); + b.extent = mutator(b.extent); } if (b.modulus.defined()) { - b.modulus = mutator->mutate(b.modulus); + b.modulus = mutator(b.modulus); } if (b.remainder.defined()) { - b.remainder = mutator->mutate(b.remainder); + b.remainder = mutator(b.remainder); } } } @@ -313,23 +313,23 @@ struct StageScheduleContents { } // Pass an IRMutator through to all Exprs referenced in the StageScheduleContents - void mutate(IRMutator *mutator) { + void mutate(IRMutator &mutator) { for (ReductionVariable &r : rvars) { if (r.min.defined()) { - r.min = mutator->mutate(r.min); + r.min = mutator(r.min); } if (r.extent.defined()) { - r.extent = mutator->mutate(r.extent); + r.extent = mutator(r.extent); } } for (Split &s : splits) { if (s.factor.defined()) { - s.factor = mutator->mutate(s.factor); + s.factor = mutator(s.factor); } } for (PrefetchDirective &p : prefetches) { if (p.offset.defined()) { - p.offset = mutator->mutate(p.offset); + p.offset = mutator(p.offset); } } } @@ -521,7 +521,7 @@ void FuncSchedule::accept(IRVisitor *visitor) const { void FuncSchedule::mutate(IRMutator *mutator) { if (contents.defined()) { - contents->mutate(mutator); + 
contents->mutate(*mutator); } } @@ -665,7 +665,7 @@ void StageSchedule::accept(IRVisitor *visitor) const { void StageSchedule::mutate(IRMutator *mutator) { if (contents.defined()) { - contents->mutate(mutator); + contents->mutate(*mutator); } } diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 1a9e0858c4d6..a0b4a4c18876 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -133,7 +133,7 @@ class SubstituteIn : public IRGraphMutator { }; Stmt substitute_in(const string &name, const Expr &value, bool calls, bool provides, const Stmt &s) { - return SubstituteIn(name, value, calls, provides).mutate(s); + return SubstituteIn(name, value, calls, provides)(s); } class AddPredicates : public IRGraphMutator { @@ -177,7 +177,7 @@ class AddPredicates : public IRGraphMutator { }; Stmt add_predicates(const Expr &cond, const Function &func, ApplySplitResult::Type type, const Stmt &s) { - return AddPredicates(cond, func, type).mutate(s); + return AddPredicates(cond, func, type)(s); } // Build a loop nest about a provide node using a schedule @@ -1004,7 +1004,7 @@ Stmt inject_stmt(Stmt root, Stmt injected, const LoopLevel &level) { return Block::make(root, injected); } InjectStmt injector(injected, level); - root = injector.mutate(root); + root = injector(root); internal_assert(injector.found_level); return root; } @@ -1091,7 +1091,7 @@ Stmt substitute_fused_bounds(Stmt s, const map &replacements) } } subs(replacements); - return subs.mutate(s); + return subs(s); } // Add letstmts inside each parent loop that define the corresponding child loop @@ -1128,7 +1128,7 @@ Stmt add_loop_var_aliases(Stmt s, const map> &loop_var_alias } } add_aliases(loop_var_aliases); - return add_aliases.mutate(s); + return add_aliases(s); } // Shift the iteration domain of a loop nest by some factor. 
@@ -1161,8 +1161,7 @@ class ShiftLoopNest : public IRMutator { if (shifts.empty()) { return node; } - ShiftLoopNest visitor(shifts); - return visitor.mutate(node); + return ShiftLoopNest(shifts)(node); } }; @@ -2612,7 +2611,7 @@ Stmt schedule_functions(const vector &outputs, } else { debug(1) << "Injecting realization of " << funcs << "\n"; InjectFunctionRealization injector(funcs, is_output_list, target, env); - s = injector.mutate(s); + s = injector(s); internal_assert(injector.found_store_level() && injector.found_compute_level() && injector.found_hoist_storage_level()); } @@ -2625,7 +2624,7 @@ Stmt schedule_functions(const vector &outputs, s = root_loop->body; // We can also remove all the loops over __outermost now. - s = RemoveLoopsOverOutermost().mutate(s); + s = RemoveLoopsOverOutermost()(s); return s; } diff --git a/src/SelectGPUAPI.cpp b/src/SelectGPUAPI.cpp index ec73c883e955..4c764f6bc619 100644 --- a/src/SelectGPUAPI.cpp +++ b/src/SelectGPUAPI.cpp @@ -50,7 +50,7 @@ class SelectGPUAPI : public IRMutator { } // namespace Stmt select_gpu_api(const Stmt &s, const Target &t) { - return SelectGPUAPI(t).mutate(s); + return SelectGPUAPI(t)(s); } } // namespace Internal diff --git a/src/Simplify.cpp b/src/Simplify.cpp index edaccbcefc16..d1ac6d828efb 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -3,6 +3,7 @@ #include "CSE.h" #include "CompilerLogger.h" +#include "CompilerProfiling.h" #include "IRMutator.h" #include "Substitute.h" @@ -16,7 +17,7 @@ using std::string; using std::vector; Simplify::Simplify(const Scope *bi, const Scope *ai) { - + ZoneScoped; // Only respect the constant bounds from the containing scope. 
for (auto iter = bi->cbegin(); iter != bi->cend(); ++iter) { ExprInfo info; @@ -324,6 +325,7 @@ template T substitute_facts_impl(const T &t, const std::set &truths, const std::set &falsehoods) { + ZoneScoped; return mutate_with(t, [&](auto *self, const Expr &e) { if (e.type().is_bool()) { if (truths.count(e)) { @@ -364,6 +366,7 @@ Expr simplify(const Expr &e, const Scope &bounds, const Scope &alignment, const std::vector &assumptions) { + ZoneScoped; Simplify m(&bounds, &alignment); std::vector facts; facts.reserve(assumptions.size()); @@ -381,6 +384,7 @@ Stmt simplify(const Stmt &s, const Scope &bounds, const Scope &alignment, const std::vector &assumptions) { + ZoneScoped; Simplify m(&bounds, &alignment); std::vector facts; facts.reserve(assumptions.size()); @@ -403,10 +407,12 @@ class SimplifyExprs : public IRMutator { }; Stmt simplify_exprs(const Stmt &s) { + ZoneScoped; return SimplifyExprs().mutate(s); } bool can_prove(Expr e, const Scope &bounds) { + ZoneScoped; internal_assert(e.type().is_bool()) << "Argument to can_prove is not a boolean Expr: " << e << "\n"; @@ -450,7 +456,7 @@ bool can_prove(Expr e, const Scope &bounds) { std::vector> out_vars; } renamer; - e = renamer.mutate(e); + e = renamer(e); // Look for a concrete counter-example with random probing static std::mt19937 rng(0); diff --git a/src/SimplifyCorrelatedDifferences.cpp b/src/SimplifyCorrelatedDifferences.cpp index 3afe5d84dcce..b3c016ea3c57 100644 --- a/src/SimplifyCorrelatedDifferences.cpp +++ b/src/SimplifyCorrelatedDifferences.cpp @@ -20,6 +20,7 @@ using std::string; using std::vector; class PartiallyCancelDifferences : public IRMutator { +protected: using IRMutator::visit; // Symbols used by rewrite rules @@ -65,6 +66,7 @@ class PartiallyCancelDifferences : public IRMutator { }; class SimplifyCorrelatedDifferences : public IRMutator { +protected: using IRMutator::visit; string loop_var; @@ -177,6 +179,7 @@ class SimplifyCorrelatedDifferences : public IRMutator { // Add the names of any 
free variables in an expr to the provided set void track_free_vars(const Expr &e, std::set *vars) { class TrackFreeVars : public IRVisitor { + protected: using IRVisitor::visit; void visit(const Variable *op) override { if (!scope.contains(op->name)) { @@ -195,7 +198,7 @@ class SimplifyCorrelatedDifferences : public IRMutator { : vars(vars) { } } tracker(vars); - e.accept(&tracker); + tracker(e); } Expr cancel_correlated_subexpression(Expr e, const Expr &a, const Expr &b, bool correlated) { @@ -224,7 +227,7 @@ class SimplifyCorrelatedDifferences : public IRMutator { } e = common_subexpression_elimination(e); e = solve_expression(e, loop_var).result; - e = PartiallyCancelDifferences().mutate(e); + e = PartiallyCancelDifferences()(e); e = simplify(e); const bool check_non_monotonic = debug_is_active(1) || get_compiler_logger() != nullptr; @@ -308,11 +311,11 @@ class SimplifyCorrelatedDifferences : public IRMutator { } // namespace Stmt simplify_correlated_differences(const Stmt &stmt) { - return SimplifyCorrelatedDifferences().mutate(stmt); + return SimplifyCorrelatedDifferences()(stmt); } Expr bound_correlated_differences(const Expr &expr) { - return PartiallyCancelDifferences().mutate(expr); + return PartiallyCancelDifferences()(expr); } } // namespace Internal diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index ae17bfae0f8b..6031af145543 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -835,7 +835,7 @@ Stmt skip_stages(const Stmt &stmt, } SkipStages skipper(analysis, name_for_id); - stmt = skipper.mutate(stmt); + stmt = skipper(stmt); stmt = skipper.emit_outermost_defs(stmt); return stmt; }; diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index 3f779cb0bca2..9c0180316357 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -86,7 +86,7 @@ class ExpandExpr : public IRMutator { // Perform all the substitutions in a scope Expr expand_expr(const Expr &e, const Scope &scope) { ExpandExpr ee(scope); - Expr result = ee.mutate(e); 
+ Expr result = ee(e); debug(4) << "Expanded " << e << " into " << result << "\n"; return result; } @@ -612,7 +612,7 @@ class SlidingWindowOnFunctionAndLoop : public IRMutator { Interval new_bounds; Stmt translate_loop(const Stmt &s) { - return RollFunc(func, dim_idx, loop_var, old_bounds, new_bounds).mutate(s); + return RollFunc(func, dim_idx, loop_var, old_bounds, new_bounds)(s); } }; @@ -817,7 +817,7 @@ class SlidingWindow : public IRMutator { set &slid_dims = slid_dimensions[func.name()]; size_t old_slid_dims_size = slid_dims.size(); SlidingWindowOnFunctionAndLoop slider(func, name, prev_loop_min, slid_dims); - body = slider.mutate(body); + body = slider(body); if (func.schedule().memory_type() == MemoryType::Register && slider.old_bounds.has_lower_bound()) { @@ -845,7 +845,7 @@ class SlidingWindow : public IRMutator { {name + ".loop_min", loop_min}, }, body); - body = SubstitutePrefetchVar(name, new_name).mutate(body); + body = SubstitutePrefetchVar(name, new_name)(body); name = new_name; @@ -921,7 +921,7 @@ class AddLoopMinOrig : public IRMutator { } // namespace Stmt sliding_window(const Stmt &s, const map &env) { - return SlidingWindow(env).mutate(AddLoopMinOrig().mutate(s)); + return SlidingWindow(env)(AddLoopMinOrig()(s)); } } // namespace Internal diff --git a/src/Solve.cpp b/src/Solve.cpp index d8557026cf00..c6d51fe335db 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -61,7 +61,7 @@ class SolveExpression : public IRMutator { // Has the solve failed. bool failed = false; -private: +protected: // The variable we're solving for. string var; @@ -1135,8 +1135,9 @@ class SolveForInterval : public IRVisitor { } // Anonymous namespace SolverResult solve_expression(const Expr &e, const std::string &variable, const Scope &scope) { + ZoneScoped; SolveExpression solver(variable, scope); - Expr new_e = solver.mutate(e); + Expr new_e = solver(e); // The process has expanded lets. Re-collect them. 
new_e = common_subexpression_elimination(new_e); debug(3) << "Solved expr for " << variable << " :\n" diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index ee805634c151..adfe0968a65c 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -408,7 +408,7 @@ class SplitScatterGather : public IRMutator { vector vars; for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { string name = unique_name(op->name + "." + std::to_string(extractor.idx)); - lets.emplace_back(name, extractor.mutate(op->value)); + lets.emplace_back(name, extractor(op->value)); vars.push_back(Variable::make(op->value.type(), name)); } @@ -475,15 +475,15 @@ class SplitScatterGather : public IRMutator { vector args = op->args; for (Expr &a : args) { string name = unique_name('t'); - exprs.push_back(extractor.mutate(a)); + exprs.push_back(extractor(a)); names.push_back(name); a = Variable::make(a.type(), name); } vector values = op->values; for (Expr &v : values) { - v = extractor.mutate(v); + v = extractor(v); string name = unique_name('t'); - exprs.push_back(extractor.mutate(v)); + exprs.push_back(extractor(v)); names.push_back(name); v = Variable::make(v.type(), name); } @@ -524,8 +524,8 @@ class SplitScatterGather : public IRMutator { } // namespace Stmt split_tuples(const Stmt &stmt, const map &env) { - Stmt s = SplitTuples(env).mutate(stmt); - s = SplitScatterGather().mutate(s); + Stmt s = SplitTuples(env)(stmt); + s = SplitScatterGather()(s); return s; } diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index 4fbf1fb1f036..79b724e42f1a 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -414,7 +414,7 @@ class FlattenDimensions : public IRMutator { }; class HoistStorage : public IRMutator { - +protected: struct HoistedAllocationInfo { string name; Type type; @@ -579,6 +579,7 @@ class HoistStorage : public IRMutator { // Realizations, stores, and loads must all be on types that are // multiples of 8-bits. 
This really only affects bools class PromoteToMemoryType : public IRMutator { +protected: using IRMutator::visit; Type upgrade(Type t) { @@ -624,6 +625,7 @@ Stmt storage_flattening(Stmt s, const vector &outputs, const map &env, const Target &target) { + ZoneScoped; s = zero_gpu_loop_mins(s); // Make an environment that makes it easier to figure out which @@ -639,9 +641,9 @@ Stmt storage_flattening(Stmt s, tuple_env[p.first] = {p.second, 0}; } } - s = FlattenDimensions(tuple_env, outputs, target).mutate(s); - s = HoistStorage().mutate(s); - s = PromoteToMemoryType().mutate(s); + s = FlattenDimensions(tuple_env, outputs, target)(s); + s = HoistStorage()(s); + s = PromoteToMemoryType()(s); return s; } diff --git a/src/StorageFolding.cpp b/src/StorageFolding.cpp index e97b06a8a6b9..4b4fa8152d4a 100644 --- a/src/StorageFolding.cpp +++ b/src/StorageFolding.cpp @@ -691,8 +691,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { op->name, sema_var, dim, - storage_dim) - .mutate(body); + storage_dim)(body); if (storage_dim.fold_forward) { can_fold_forwards = true; @@ -791,7 +790,7 @@ class AttemptStorageFoldingOfFunction : public IRMutator { } else { head = dynamic_footprint; } - body = FoldStorageOfFunction(func.name(), (int)i - 1, factor, head).mutate(body); + body = FoldStorageOfFunction(func.name(), (int)i - 1, factor, head)(body); } // If the producer is async, it can run ahead by @@ -961,7 +960,7 @@ class StorageFolding : public IRMutator { } else { debug(3) << "Attempting to fold " << op->name << " automatically or explicitly\n"; } - body = folder.mutate(body); + body = folder(body); if (body.same_as(op->body)) { return op; @@ -1034,8 +1033,8 @@ class RemoveSlidingWindowMarkers : public IRMutator { } // namespace Stmt storage_folding(const Stmt &s, const std::map &env) { - Stmt stmt = StorageFolding(env).mutate(s); - stmt = RemoveSlidingWindowMarkers().mutate(stmt); + Stmt stmt = StorageFolding(env)(s); + stmt = RemoveSlidingWindowMarkers()(stmt); return 
stmt; } diff --git a/src/StrictifyFloat.cpp b/src/StrictifyFloat.cpp index 37263d00c89b..9deb86679808 100644 --- a/src/StrictifyFloat.cpp +++ b/src/StrictifyFloat.cpp @@ -127,7 +127,7 @@ class AnyStrictIntrinsics : public IRVisitor { } // namespace Expr strictify_float(const Expr &e) { - return Strictify{}.mutate(e); + return Strictify{}(e); } Expr unstrictify_float(const Call *op) { diff --git a/src/StripAsserts.cpp b/src/StripAsserts.cpp index e8f101fcfc0c..f1ffb8967ca9 100644 --- a/src/StripAsserts.cpp +++ b/src/StripAsserts.cpp @@ -109,7 +109,7 @@ class StripAsserts : public IRMutator { } // namespace Stmt strip_asserts(const Stmt &s) { - return StripAsserts().mutate(s); + return StripAsserts()(s); } } // namespace Internal diff --git a/src/Substitute.cpp b/src/Substitute.cpp index 9b280b0b0483..df71101d0b07 100644 --- a/src/Substitute.cpp +++ b/src/Substitute.cpp @@ -12,6 +12,7 @@ using std::string; namespace { class Substitute : public IRMutator { +protected: const map &replace; Scope<> hidden; @@ -104,24 +105,24 @@ Expr substitute(const string &name, const Expr &replacement, const Expr &expr) { map m; m[name] = replacement; Substitute s(m); - return s.mutate(expr); + return s(expr); } Stmt substitute(const string &name, const Expr &replacement, const Stmt &stmt) { map m; m[name] = replacement; Substitute s(m); - return s.mutate(stmt); + return s(stmt); } Expr substitute(const map &m, const Expr &expr) { Substitute s(m); - return s.mutate(expr); + return s(expr); } Stmt substitute(const map &m, const Stmt &stmt) { Substitute s(m); - return s.mutate(stmt); + return s(stmt); } namespace { @@ -139,10 +140,12 @@ auto substitute_impl(const Expr &find, const Expr &replacement, const T &ir) { } // namespace Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr) { + ZoneScoped; return substitute_impl(find, replacement, expr); } Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) { + ZoneScoped; return 
substitute_impl(find, replacement, stmt); } @@ -150,6 +153,7 @@ namespace { /** Substitute an expr for a var in a graph. */ class GraphSubstitute : public IRGraphMutator { +protected: string var; Expr value; @@ -202,25 +206,29 @@ class GraphSubstituteExpr : public IRGraphMutator { } // namespace Expr graph_substitute(const string &name, const Expr &replacement, const Expr &expr) { - return GraphSubstitute(name, replacement).mutate(expr); + ZoneScoped; + return GraphSubstitute(name, replacement)(expr); } Stmt graph_substitute(const string &name, const Expr &replacement, const Stmt &stmt) { - return GraphSubstitute(name, replacement).mutate(stmt); + ZoneScoped; + return GraphSubstitute(name, replacement)(stmt); } Expr graph_substitute(const Expr &find, const Expr &replacement, const Expr &expr) { - return GraphSubstituteExpr(find, replacement).mutate(expr); + ZoneScoped; + return GraphSubstituteExpr(find, replacement)(expr); } Stmt graph_substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) { - return GraphSubstituteExpr(find, replacement).mutate(stmt); + ZoneScoped; + return GraphSubstituteExpr(find, replacement)(stmt); } namespace { class SubstituteInAllLets : public IRGraphMutator { - +protected: using IRGraphMutator::visit; Expr visit(const Let *op) override { @@ -233,11 +241,13 @@ class SubstituteInAllLets : public IRGraphMutator { } // namespace Expr substitute_in_all_lets(const Expr &expr) { - return SubstituteInAllLets().mutate(expr); + ZoneScoped; + return SubstituteInAllLets()(expr); } Stmt substitute_in_all_lets(const Stmt &stmt) { - return SubstituteInAllLets().mutate(stmt); + ZoneScoped; + return SubstituteInAllLets()(stmt); } } // namespace Internal diff --git a/src/Tracing.cpp b/src/Tracing.cpp index 0bc9086d8635..59c310f757f2 100644 --- a/src/Tracing.cpp +++ b/src/Tracing.cpp @@ -350,7 +350,7 @@ Stmt inject_tracing(Stmt s, const string &pipeline_name, bool trace_pipeline, } // Inject tracing calls - s = tracing.mutate(s); + s = 
tracing(s); // Strip off the dummy realize blocks s = mutate_with(s, [&](auto *self, const Realize *op) { diff --git a/src/TrimNoOps.cpp b/src/TrimNoOps.cpp index 1842a702fab4..13d358a4f0bd 100644 --- a/src/TrimNoOps.cpp +++ b/src/TrimNoOps.cpp @@ -105,7 +105,7 @@ class IsNoOp : public IRVisitor { Expr equivalent_load = Load::make(op->value.type(), op->name, op->index, Buffer<>(), Parameter(), op->predicate, op->alignment); Expr is_no_op = equivalent_load == op->value; - is_no_op = StripIdentities().mutate(is_no_op); + is_no_op = StripIdentities()(is_no_op); // We need to call CSE since sometimes we have "let" stmt on the RHS // that makes the expr harder to solve, i.e. the solver will just give up // and return a conservative false on call to and_condition_over_domain(). @@ -405,7 +405,7 @@ class TrimNoOps : public IRMutator { // Simplify the body to take advantage of the fact that the // loop range is now truncated - body = simplify(SimplifyUsingBounds(op->name, i).mutate(body)); + body = simplify(SimplifyUsingBounds(op->name, i)(body)); string new_min_name = unique_name(op->name + ".new_min"); string new_max_name = unique_name(op->name + ".new_max"); @@ -445,9 +445,8 @@ class TrimNoOps : public IRMutator { } // namespace -Stmt trim_no_ops(Stmt s) { - s = TrimNoOps().mutate(s); - return s; +Stmt trim_no_ops(const Stmt &s) { + return TrimNoOps()(s); } } // namespace Internal diff --git a/src/TrimNoOps.h b/src/TrimNoOps.h index 51d264cd03fb..548e2d383857 100644 --- a/src/TrimNoOps.h +++ b/src/TrimNoOps.h @@ -13,7 +13,7 @@ namespace Internal { /** Truncate loop bounds to the region over which they actually do * something. 
For examples see test/correctness/trim_no_ops.cpp */ -Stmt trim_no_ops(Stmt s); +Stmt trim_no_ops(const Stmt &s); } // namespace Internal } // namespace Halide diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 91f0279de04c..2c6cc2e60178 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -5,7 +5,6 @@ #include "IRVisitor.h" #include "Scope.h" #include "Var.h" -#include namespace Halide { namespace Internal { @@ -16,7 +15,7 @@ using std::vector; namespace { class UniquifyVariableNames : public IRMutator { - +protected: using IRMutator::visit; // The mapping from old names to new names @@ -119,7 +118,7 @@ class UniquifyVariableNames : public IRMutator { }; class FindFreeVars : public IRVisitor { - +protected: using IRVisitor::visit; Scope<> scope; @@ -166,10 +165,10 @@ class FindFreeVars : public IRVisitor { } // namespace Stmt uniquify_variable_names(const Stmt &s) { + ZoneScoped; FindFreeVars finder; s.accept(&finder); - UniquifyVariableNames u(&finder.free_vars); - return u.mutate(s); + return UniquifyVariableNames(&finder.free_vars)(s); } namespace { diff --git a/src/UnrollLoops.cpp b/src/UnrollLoops.cpp index ffcba564966a..a1e736e12527 100644 --- a/src/UnrollLoops.cpp +++ b/src/UnrollLoops.cpp @@ -1,4 +1,5 @@ #include "UnrollLoops.h" +#include "CompilerProfiling.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" @@ -11,16 +12,17 @@ namespace Internal { namespace { class UnrollLoops : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const For *for_loop) override { if (for_loop->for_type == ForType::Unrolled) { + ZoneScopedN("UnrollForLoop"); Stmt body = for_loop->body; Expr extent = simplify(for_loop->extent()); const IntImm *e = extent.as(); - internal_assert(e) - << "Loop over " << for_loop->name << " should have had a constant extent\n"; + internal_assert(e) << "Loop over " << for_loop->name << " should have had a constant extent\n"; body = mutate(body); if 
(e->value == 1) { @@ -53,7 +55,8 @@ class UnrollLoops : public IRMutator { } // namespace Stmt unroll_loops(const Stmt &s) { - Stmt stmt = UnrollLoops().mutate(s); + ZoneScoped; + Stmt stmt = UnrollLoops()(s); // Unrolling duplicates variable names. Other passes assume variable names are unique. return uniquify_variable_names(stmt); } diff --git a/src/UnsafePromises.cpp b/src/UnsafePromises.cpp index c1fdc51d8758..02912594d49b 100644 --- a/src/UnsafePromises.cpp +++ b/src/UnsafePromises.cpp @@ -57,11 +57,11 @@ class LowerSafePromises : public IRMutator { } // namespace Stmt lower_unsafe_promises(const Stmt &s, const Target &t) { - return LowerUnsafePromises(t.has_feature(Target::CheckUnsafePromises)).mutate(s); + return LowerUnsafePromises(t.has_feature(Target::CheckUnsafePromises))(s); } Stmt lower_safe_promises(const Stmt &s) { - return LowerSafePromises().mutate(s); + return LowerSafePromises()(s); } } // namespace Internal diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 2d149adbaf20..3c230c9c6886 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -309,6 +309,7 @@ bool is_interleaved_ramp(const Expr &e, const Scope &scope, InterleavedRam // vector lane. This means loads and stores to them need to be // rewritten slightly. class RewriteAccessToVectorAlloc : public IRMutator { +protected: Expr var; string alloc; int lanes; @@ -363,6 +364,7 @@ class SerializeLoops : public IRMutator { // Wrap a vectorized predicate around a Load/Store node. class PredicateLoadStore : public IRMutator { +protected: string var; Expr vector_predicate; int lanes; @@ -480,6 +482,7 @@ struct VectorizedVar { // Substitutes a vector for a scalar var in a Stmt. Used on the // body of every vectorized loop. class VectorSubs : public IRMutator { +protected: // A list of vectorized loop vars encountered so far. The last // element corresponds to the most inner vectorized loop. 
std::vector vectorized_vars; @@ -862,12 +865,12 @@ class VectorSubs : public IRMutator { Stmt predicated_stmt; if (vectorize_predicate) { PredicateLoadStore p(vectorized_vars.front().name, cond); - predicated_stmt = p.mutate(then_case); + predicated_stmt = p(then_case); vectorize_predicate = p.is_vectorized(); } if (vectorize_predicate && else_case.defined()) { PredicateLoadStore p(vectorized_vars.front().name, !cond); - predicated_stmt = Block::make(predicated_stmt, p.mutate(else_case)); + predicated_stmt = Block::make(predicated_stmt, p(else_case)); vectorize_predicate = p.is_vectorized(); } @@ -1072,7 +1075,7 @@ class VectorSubs : public IRMutator { // Rewrite loads and stores to this allocation like so: // foo[x] -> foo[x*lanes + v] for (const auto &vv : vectorized_vars) { - body = RewriteAccessToVectorAlloc(vv.name + ".from_zero", op->name, vv.lanes).mutate(body); + body = RewriteAccessToVectorAlloc(vv.name + ".from_zero", op->name, vv.lanes)(body); } body = mutate(body); @@ -1314,7 +1317,7 @@ class VectorSubs : public IRMutator { // better luck vectorizing it. if (serialize_inner_loops) { - s = SerializeLoops().mutate(s); + s = SerializeLoops()(s); } // We'll need the original scalar versions of any containing lets. for (const auto &[var, value] : reverse_view(containing_lets)) { @@ -1410,6 +1413,7 @@ class VectorSubs : public IRMutator { }; // namespace class FindVectorizableExprsInAtomicNode : public IRMutator { +protected: // An Atomic node protects all accesses to a given buffer. We // consider a name "poisoned" if it depends on an access to this // buffer. 
We can't lift or vectorize anything that has been @@ -1502,6 +1506,7 @@ class FindVectorizableExprsInAtomicNode : public IRMutator { }; class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { +protected: const std::set &liftable; using IRMutator::visit; @@ -1556,6 +1561,7 @@ class LiftVectorizableExprsOutOfSingleAtomicNode : public IRMutator { }; class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const Atomic *op) override { @@ -1582,6 +1588,7 @@ class LiftVectorizableExprsOutOfAllAtomicNodes : public IRMutator { // Vectorize all loops marked as such in a Stmt class VectorizeLoops : public IRMutator { +protected: using IRMutator::visit; Stmt visit(const For *for_loop) override { @@ -1597,7 +1604,7 @@ class VectorizeLoops : public IRMutator { } VectorizedVar vectorized_var = {for_loop->name, for_loop->min, (int)extent->value}; - stmt = VectorSubs(vectorized_var).mutate(for_loop->body); + stmt = VectorSubs(vectorized_var)(for_loop->body); } else { stmt = IRMutator::visit(for_loop); } @@ -1622,6 +1629,7 @@ class AllStoresInScope : public IRVisitor { } }; bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) { + ZoneScoped; AllStoresInScope checker(scope); stmt.accept(&checker); return checker.result; @@ -1630,6 +1638,7 @@ bool all_stores_in_scope(const Stmt &stmt, const Scope<> &scope) { /** Drop any atomic nodes protecting buffers that are only accessed * from a single thread. */ class RemoveUnnecessaryAtomics : public IRMutator { +protected: using IRMutator::visit; // Allocations made from within this same thread @@ -1664,17 +1673,19 @@ class RemoveUnnecessaryAtomics : public IRMutator { }; Stmt vectorize_statement(const Stmt &stmt) { - return VectorizeLoops().mutate(stmt); + ZoneScoped; + return VectorizeLoops()(stmt); } } // namespace Stmt vectorize_loops(const Stmt &stmt, const map &env) { + ZoneScoped; // Limit the scope of atomic nodes to just the necessary stuff. 
// TODO: Should this be an earlier pass? It's probably a good idea // for non-vectorizing stuff too. - Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env).mutate(stmt); + Stmt s = LiftVectorizableExprsOutOfAllAtomicNodes(env)(stmt); s = vectorize_statement(s); - s = RemoveUnnecessaryAtomics().mutate(s); + s = RemoveUnnecessaryAtomics()(s); return s; }