diff --git a/.gitignore b/.gitignore index 368a457c85c2..d8a64849438f 100644 --- a/.gitignore +++ b/.gitignore @@ -208,6 +208,8 @@ vcpkg_installed/ ################################################################################ ## IDE directories and metadata +.ccls-cache/ + # Visual Studio .vs/ out/ diff --git a/Makefile b/Makefile index 4e5cb36364f8..da2abf49fd1f 100644 --- a/Makefile +++ b/Makefile @@ -1473,7 +1473,7 @@ $(BIN_DIR)/$(TARGET)/correctness_opencl_runtime: $(ROOT_DIR)/test/correctness/op $(BIN_DIR)/performance_%: $(ROOT_DIR)/test/performance/%.cpp $(TEST_DEPS) $(CXX) $(TEST_CXX_FLAGS) $(OPTIMIZE) $< -I$(INCLUDE_DIR) -I$(ROOT_DIR)/src/runtime -I$(ROOT_DIR)/test/common $(TEST_LD_FLAGS) -o $@ -$(BIN_DIR)/fuzz_%: $(ROOT_DIR)/test/fuzz/%.cpp $(ROOT_DIR)/test/fuzz/halide_fuzz_main.cpp $(ROOT_DIR)/test/fuzz/fuzz_helpers.h $(ROOT_DIR)/test/fuzz/halide_fuzz_main.h $(TEST_DEPS) +$(BIN_DIR)/fuzz_%: $(ROOT_DIR)/test/fuzz/%.cpp $(ROOT_DIR)/test/fuzz/halide_fuzz_main.cpp $(ROOT_DIR)/test/fuzz/fuzz_helpers.h $(ROOT_DIR)/test/fuzz/halide_fuzz_main.h $(ROOT_DIR)/test/fuzz/ExprInterpreter.cpp $(ROOT_DIR)/test/fuzz/ExprInterpreter.h $(TEST_DEPS) $(CXX) $(TEST_CXX_FLAGS) -I$(ROOT_DIR)/src/runtime -I$(ROOT_DIR)/test/common $(OPTIMIZE_FOR_BUILD_TIME) $(filter %.cpp,$^) -I$(INCLUDE_DIR) $(TEST_LD_FLAGS) -o $@ -DHALIDE_FUZZER_BACKEND=0 # Error tests that link against libHalide diff --git a/src/IR.h b/src/IR.h index 3666581803db..4a473edb61f6 100644 --- a/src/IR.h +++ b/src/IR.h @@ -883,6 +883,30 @@ struct Call : public ExprNode { Call::strict_sub}); } + /** Does not include the strict_float intrinsics. */ + bool is_arithmetic_intrinsic() const { + return is_intrinsic( + {Call::widen_right_add, + Call::widen_right_mul, + Call::widen_right_sub, + Call::widening_add, + Call::widening_mul, + Call::widening_sub, + Call::saturating_add, + Call::saturating_sub, + Call::saturating_cast, + Call::widening_shift_left, + Call::widening_shift_right, + Call::rounding_shift_right, + Call::rounding_shift_left, + Call::halving_add, + Call::halving_sub, + Call::rounding_halving_add, + Call::rounding_mul_shift_right, + Call::mul_shift_right, + Call::sorted_avg}); + } + static const IRNodeType _node_type = IRNodeType::Call; }; diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index 97b8372cf2ab..b2b6711caf6f 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -36,18 +36,20 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, std::size_t Size) { } ]] HAVE_LIBFUZZER_FLAGS) -add_library(Halide_fuzz INTERFACE) +add_library(Halide_fuzz OBJECT) add_library(Halide::fuzz ALIAS Halide_fuzz) +target_sources(Halide_fuzz PRIVATE ExprInterpreter.cpp) +target_link_libraries(Halide_fuzz PRIVATE Halide::Halide Halide::Test) if (NOT HAVE_LIBFUZZER_FLAGS) if (LIB_FUZZING_ENGINE) message(FATAL_ERROR "Cannot set LIB_FUZZING_ENGINE when not building with -fsanitize=fuzzer or a compatible fuzzing engine.") endif () - target_sources(Halide_fuzz INTERFACE halide_fuzz_main.cpp halide_fuzz_main.h) - target_compile_definitions(Halide_fuzz INTERFACE HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_STDLIB) + target_sources(Halide_fuzz PRIVATE halide_fuzz_main.cpp halide_fuzz_main.h) + target_compile_definitions(Halide_fuzz PUBLIC HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_STDLIB) else () target_link_libraries(Halide_fuzz INTERFACE ${LIB_FUZZING_ENGINE}) - target_compile_definitions(Halide_fuzz INTERFACE HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_LIBFUZZER) + target_compile_definitions(Halide_fuzz PUBLIC HALIDE_FUZZER_BACKEND=HALIDE_FUZZER_BACKEND_LIBFUZZER) endif () foreach (fuzzer IN LISTS TEST_NAMES) diff --git a/test/fuzz/ExprInterpreter.cpp b/test/fuzz/ExprInterpreter.cpp new file mode 100644 index 000000000000..deb192d214f9 --- /dev/null +++ b/test/fuzz/ExprInterpreter.cpp @@ -0,0 +1,948 @@ +#include "ExprInterpreter.h" + +#include +#include +#include + +namespace Halide { +namespace Internal { + +bool has_undefined_overflow(Type t) { + return t.is_int() && t.bits() >= 32; +} + +std::ostream &operator<<(std::ostream &o, const ExprInterpreter::EvalValue &val) { + o << "(" << val.type << ")"; + if (val.lanes.size() > 1) { + o << "["; + } + bool first = true; + for (const auto &l : val.lanes) { + if (!first) { + o << ","; + } + first = false; + std::visit( + [&o](auto x) { + o << x; + }, + l); + } + if (val.lanes.size() > 1) { + o << "]"; + } + if (val.did_overflow) { + o << " (did overflow)"; + } + return o; +} + +bool ExprInterpreter::EvalValue::is_close(const ExprInterpreter::EvalValue &o, double threshold) const { + internal_assert(type.is_float()); + internal_assert(type == o.type); + for (int i = 0; i < lanes.size(); ++i) { + if (std::abs(std::get(lanes[i]) - std::get(o.lanes[i])) > threshold) { + return false; + } + } + return true; +} + +bool ExprInterpreter::EvalValue::operator==(const ExprInterpreter::EvalValue &o) const { + internal_assert(type == o.type); + for (int i = 0; i < lanes.size(); ++i) { + bool equal = std::visit( + [&](auto x) { + return x == std::get>(o.lanes[i]); + }, + lanes[i]); + if (!equal) { + return false; + } + } + return true; +} + +ExprInterpreter::EvalValue::EvalValue(Type t) : type(t), lanes(t.lanes()) { + for (int i = 0; i < t.lanes(); ++i) { + if (t.is_float()) { + lanes[i] = double{0.0}; + } else if (t.is_int()) { + lanes[i] = int64_t{0}; + } else { + lanes[i] = uint64_t{0}; + } + } +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_unary(Type t, const EvalValue &a, F f) { + EvalValue res(t); + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit( + [&f, &t, &res](auto x) -> Scalar { + bool overflow = false; + auto out = f(x, overflow); + res.did_overflow |= overflow; + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } + return static_cast(out); + }, + a.lanes[i]); + } + res.did_overflow = a.did_overflow; + return res; +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_binary(Type t, const EvalValue &a, const EvalValue &b, F f) { + EvalValue res(t); + if constexpr (StrictTypeMatch) { + internal_assert(a.type == b.type) << "Binary Type mismatch " << a.type << " != " << b.type; + } else { + internal_assert(a.type.lanes() == b.type.lanes()) << "Lanes mismatch " << a.type << " != " << b.type; + } + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit( + [&f, &t, &res](auto x, auto y) -> Scalar { + if constexpr (!StrictTypeMatch || std::is_same_v) { + bool overflow = false; + auto out = f(x, y, overflow); + res.did_overflow |= overflow; + if (t.is_float()) { + return static_cast(out); + } + if (t.is_int()) { + return static_cast(out); + } + return static_cast(out); + } else { + internal_error << "Binary operator has incompatible types"; + } + }, + a.lanes[i], b.lanes[i]); + } + res.did_overflow |= a.did_overflow || b.did_overflow; + return res; +} + +template +ExprInterpreter::EvalValue ExprInterpreter::apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f) { + EvalValue res(t); + res.did_overflow = a.did_overflow || b.did_overflow; + internal_assert(a.type == b.type); + for (int i = 0; i < t.lanes(); ++i) { + res.lanes[i] = std::visit( + [&f, &t](auto x, auto y) -> Scalar { + if constexpr (std::is_same_v) { + static_assert(std::is_same_v); + bool out = f(x, y); + return static_cast(out); + } else { + internal_error << "Binary operator type mismatch"; + } + }, + a.lanes[i], b.lanes[i]); + } + return res; +} + +ExprInterpreter::EvalValue ExprInterpreter::eval(const Expr &e) { + if (!e.defined()) { + return EvalValue(); + } + e.accept(this); + internal_assert(result.lanes.size() == result.type.lanes()); + truncate(result); + debug(2) << "Evaluated " << e << " to be " << result << "\n"; + return result; +} + +void ExprInterpreter::truncate(EvalValue &v) { + int b = v.type.bits(); + + // Floats do not overflow/truncate in the same way, + // and shifts >= 64 are Undefined Behavior in C++. + if (v.type.is_float() || b >= 64) { + return; + } + + uint64_t mask = (1ULL << b) - 1; + uint64_t sign_bit = 1ULL << (b - 1); + + for (int j = 0; j < v.type.lanes(); j++) { + std::visit( + [&](auto &x) { + // Only apply truncation to integer variants (int64_t, uint64_t) + if constexpr (std::is_integral_v>) { + uint64_t u = static_cast(x) & mask; + + // If the underlying variant is signed, perform sign-extension + if constexpr (std::is_signed_v>) { + if (u & sign_bit) { + u |= ~mask; + } + } + + x = static_cast>(u); + } + }, + v.lanes[j]); + } +} + +void ExprInterpreter::visit(const IntImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (int64_t)op->value; +} + +void ExprInterpreter::visit(const UIntImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (uint64_t)op->value; +} + +void ExprInterpreter::visit(const FloatImm *op) { + result = EvalValue(op->type); + result.lanes[0] = (double)op->value; +} + +void ExprInterpreter::visit(const StringImm *op) { + internal_error << "Cannot evaluate StringImm as a vector representation."; +} + +void ExprInterpreter::visit(const Variable *op) { + auto it = var_env.find(op->name); + if (it != var_env.end()) { + result = it->second; + } else { + internal_error << "Unbound variable in ExprInterpreter: " << op->name; + } +} + +void ExprInterpreter::visit(const Cast *op) { + result = apply_unary(op->type, eval(op->value), [&op](auto x, bool &overflow) { + if (has_undefined_overflow(op->type)) { + if (!op->type.can_represent(x)) { + overflow = true; + } + } + return x; + }); +} + +void ExprInterpreter::visit(const Reinterpret *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + + int in_lanes = val.type.lanes(); + int in_bits = val.type.bits(); + int in_bytes = in_bits / 8; + + int out_lanes = op->type.lanes(); + int out_bits = op->type.bits(); + int out_bytes = out_bits / 8; + + int total_bytes = std::max(1, (in_bits * in_lanes) / 8); + if (in_bytes == 0) { + in_bytes = 1; + } + if (out_bytes == 0) { + out_bytes = 1; + } + + std::vector buffer(total_bytes, 0); + + for (int j = 0; j < in_lanes; j++) { + char *dst = buffer.data() + j * in_bytes; + std::visit( + [&](auto x) { + if constexpr (std::is_floating_point_v) { + if (in_bits == 32) { + float f = static_cast(x); + std::memcpy(dst, &f, 4); + } else if (in_bits == 64) { + std::memcpy(dst, &x, 8); + } else { + internal_error << "Unsupported float bit width in Reinterpret input"; + } + } else { + uint64_t u = static_cast(x); + std::memcpy(dst, &u, in_bytes); + } + }, + val.lanes[j]); + } + + for (int j = 0; j < out_lanes; j++) { + const char *src = buffer.data() + j * out_bytes; + if (op->type.is_float()) { + if (out_bits == 32) { + float f = 0.0f; + std::memcpy(&f, src, 4); + result.lanes[j] = static_cast(f); + } else if (out_bits == 64) { + double f = 0.0; + std::memcpy(&f, src, 8); + result.lanes[j] = f; + } else { + internal_error << "Unsupported float bit width in Reinterpret output"; + } + } else if (op->type.is_int()) { + uint64_t u = 0; + std::memcpy(&u, src, out_bytes); + result.lanes[j] = static_cast(u); + } else { + uint64_t u = 0; + std::memcpy(&u, src, out_bytes); + result.lanes[j] = u; + } + } +} + +void ExprInterpreter::visit(const Add *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [&](auto x, auto y, bool &overflow) -> decltype(x) { + if (has_undefined_overflow(op->type)) { + int64_t r; + overflow = !add_with_overflow(op->a.type().bits(), x, y, &r); + return r; + } else { + return x + y; + } + }); +} +void ExprInterpreter::visit(const Sub *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [&](auto x, auto y, bool &overflow) -> decltype(x) { + if (has_undefined_overflow(op->type)) { + int64_t r; + overflow = !sub_with_overflow(op->a.type().bits(), x, y, &r); + return r; + } else { + return x - y; + } + }); +} +void ExprInterpreter::visit(const Mul *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [&](auto x, auto y, bool &overflow) -> decltype(x) { + if (has_undefined_overflow(op->type)) { + int64_t r; + overflow = !mul_with_overflow(op->a.type().bits(), x, y, &r); + return r; + } else { + return x * y; + } + }); +} +void ExprInterpreter::visit(const Min *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { return std::min(x, y); }); +} +void ExprInterpreter::visit(const Max *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { return std::max(x, y); }); +} + +void ExprInterpreter::visit(const EQ *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x == y; }); +} +void ExprInterpreter::visit(const NE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x != y; }); +} +void ExprInterpreter::visit(const LT *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x < y; }); +} +void ExprInterpreter::visit(const LE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x <= y; }); +} +void ExprInterpreter::visit(const GT *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x > y; }); +} +void ExprInterpreter::visit(const GE *op) { + result = apply_cmp(op->type, eval(op->a), eval(op->b), [](auto x, auto y) { return x >= y; }); +} + +void ExprInterpreter::visit(const Div *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { + if constexpr (std::is_floating_point_v) { + return x / y; + } else if constexpr (std::is_signed_v) { + if (y == 0) return decltype(x){0}; + // Prevent C++ hardware crash (SIGFPE) on INT_MIN / -1 + if (y == -1) return static_cast(~static_cast(x) + 1); + + auto q = x / y; + auto r = x % y; + + // Euclidean division correction: if the C++ remainder is negative, + // the quotient must shift so the remainder becomes positive. + if (r < 0) { + q += (y < 0) ? 1 : -1; + } + return q; + } else { + if (y == 0) return decltype(x){0}; + return x / y; + } + }); +} + +void ExprInterpreter::visit(const Mod *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { + if constexpr (std::is_floating_point_v) { + // Halide doc states floats fallback to fmod + if (y == 0.0) return decltype(x){0}; + + auto r = std::fmod(x, y); + + // Guarantee positive remainder for floats as well + if (r < 0) { + r += std::abs(y); + } + + return r; + + } else if constexpr (std::is_signed_v) { + if (y == 0) return decltype(x){0}; + + // Prevent C++ hardware crash (SIGFPE) on INT_MIN % -1 + if (y == -1) return decltype(x){0}; + + auto r = x % y; + + // Euclidean modulo correction: + // If the C++ remainder is negative, add the absolute value of the divisor. + if (r < 0) { + r += (y < 0) ? -y : y; + } + + return r; + + } else { + // Unsigned integers natively produce positive remainders + // and cannot be negative. + if (y == 0) return decltype(x){0}; + + return x % y; + } + }); +} + +void ExprInterpreter::visit(const And *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { + if constexpr (std::is_integral_v) { + return x & y; + } else { + internal_error << "Bitwise AND on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Or *op) { + result = apply_binary(op->type, eval(op->a), eval(op->b), [](auto x, auto y, bool &) { + if constexpr (std::is_integral_v) { + return x | y; + } else { + internal_error << "Bitwise OR on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Not *op) { + result = apply_unary(op->type, eval(op->a), [](auto x, bool &) { + if constexpr (std::is_integral_v) { + return ~x; + } else { + internal_error << "Bitwise NOT on floats"; + return x; + } + }); +} + +void ExprInterpreter::visit(const Select *op) { + EvalValue cond = eval(op->condition), t = eval(op->true_value), f = eval(op->false_value); + result = EvalValue(op->type); + result.did_overflow = cond.did_overflow || t.did_overflow || f.did_overflow; + for (int j = 0; j < op->type.lanes(); j++) { + bool c = std::visit([](auto x) { return x != 0; }, cond.lanes[j]); + result.lanes[j] = c ? t.lanes[j] : f.lanes[j]; + } +} + +void ExprInterpreter::visit(const Load *op) { + internal_error << "Load nodes are unsupported without memory mapping in ExprInterpreter."; +} + +void ExprInterpreter::visit(const Let *op) { + EvalValue val = eval(op->value); + auto old_val = var_env.find(op->name); + bool had_old = (old_val != var_env.end()); + EvalValue old; + if (had_old) { + old = old_val->second; + } + + var_env[op->name] = val; + result = eval(op->body); + + if (had_old) { + var_env[op->name] = old; + } else { + var_env.erase(op->name); + } +} + +void ExprInterpreter::visit(const Ramp *op) { + EvalValue base = eval(op->base), stride = eval(op->stride); + result = EvalValue(op->type); + result.did_overflow = base.did_overflow || stride.did_overflow; + + int n = base.type.lanes(); // The lane-width of the base and stride + + // ramp(b, s, l) = concat_vectors(b, b + s, b + 2*s, ... b + (l-1)*s) + for (int j = 0; j < op->lanes; j++) { + for (int k = 0; k < n; k++) { + std::visit( + [&](auto b, auto s) { + if constexpr (std::is_same_v) { + auto res = b + j * s; + if (has_undefined_overflow(op->type)) { + const int bits = op->type.bits(); + int64_t r; + bool overflow = false; + overflow |= !mul_with_overflow(bits, j, s, &r); + overflow |= !add_with_overflow(bits, b, r, &r); + result.did_overflow |= overflow; + } + if (op->type.is_float()) { + result.lanes[j * n + k] = static_cast(res); + } else if (op->type.is_int()) { + result.lanes[j * n + k] = static_cast(res); + } else { + result.lanes[j * n + k] = static_cast(res); + } + } else { + internal_error << "Ramp base and stride type mismatch"; + } + }, + base.lanes[k], stride.lanes[k]); + } + } +} + +void ExprInterpreter::visit(const Broadcast *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + result.did_overflow = val.did_overflow; + int v_lanes = op->value.type().lanes(); + for (int j = 0; j < op->lanes; j++) { + for (int k = 0; k < v_lanes; k++) { + result.lanes[j * v_lanes + k] = val.lanes[k]; + } + } +} + +void ExprInterpreter::visit(const Shuffle *op) { + result = EvalValue(op->type); + std::vector vecs; + vecs.reserve(op->vectors.size()); + for (const Expr &e : op->vectors) { + vecs.push_back(eval(e)); + } + + std::vector flat; + for (const EvalValue &v : vecs) { + result.did_overflow |= v.did_overflow; + for (int j = 0; j < v.type.lanes(); j++) { + flat.push_back(v.lanes[j]); + } + } + + for (int j = 0; j < (int)op->indices.size(); j++) { + int idx = op->indices[j]; + if (idx >= 0 && idx < (int)flat.size()) { + result.lanes[j] = flat[idx]; + } else { + internal_error << "Shuffle index out of bounds."; + } + } +} + +void ExprInterpreter::visit(const VectorReduce *op) { + EvalValue val = eval(op->value); + result = EvalValue(op->type); + result.did_overflow |= val.did_overflow; + int in_lanes = op->value.type().lanes(); + int out_lanes = op->type.lanes(); + int factor = in_lanes / out_lanes; + + bool check_overflow = has_undefined_overflow(op->type); + int bits = op->type.bits(); + + for (int j = 0; j < out_lanes; j++) { + Scalar res = val.lanes[j * factor]; + for (int k = 1; k < factor; k++) { + Scalar next = val.lanes[j * factor + k]; + res = std::visit( + [&](auto a, auto b) -> Scalar { + if constexpr (std::is_same_v) { + switch (op->op) { + case VectorReduce::Add: + if (check_overflow) { + result.did_overflow = add_would_overflow(bits, a, b); + } + return a + b; + case VectorReduce::Mul: + if (check_overflow) { + result.did_overflow = mul_would_overflow(bits, a, b); + } + return a * b; + case VectorReduce::Min: + return std::min(a, b); + case VectorReduce::Max: + return std::max(a, b); + case VectorReduce::And: + if constexpr (std::is_integral_v) { + return a & b; + } else { + internal_error << "And on floats"; + return a; + } + case VectorReduce::Or: + if constexpr (std::is_integral_v) { + return a | b; + } else { + internal_error << "Or on floats"; + return a; + } + default: + internal_error << "Unhandled VectorReduce op"; + return a; + } + } else { + internal_error << "VectorReduce type mismatch"; + return a; + } + }, + res, next); + } + + std::visit( + [&](auto x) { + if (op->type.is_float()) { + result.lanes[j] = static_cast(x); + } else if (op->type.is_int()) { + result.lanes[j] = static_cast(x); + } else { + result.lanes[j] = static_cast(x); + } + }, + res); + } +} + +void ExprInterpreter::visit(const Call *op) { + result = EvalValue(op->type); + std::vector args; + args.reserve(op->args.size()); + for (const Expr &e : op->args) { + auto arg_val = eval(e); + result.did_overflow |= arg_val.did_overflow; + args.push_back(std::move(arg_val)); + } + + if (op->is_intrinsic(Call::bitwise_and)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b, bool &) { + if constexpr (std::is_integral_v) { + return a & b; + } else { + internal_error << "bitwise_and on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_or)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b, bool &) { + if constexpr (std::is_integral_v) { + return a | b; + } else { + internal_error << "bitwise_or on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_xor)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b, bool &) { + if constexpr (std::is_integral_v) { + return a ^ b; + } else { + internal_error << "bitwise_xor on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::bitwise_not)) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { + if constexpr (std::is_integral_v) { + return ~a; + } else { + internal_error << "bitwise_not on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::shift_left)) { + result = apply_binary(op->type, args[0], args[1], [&op](auto a, auto b, bool &overflow) { + if constexpr (std::is_integral_v && std::is_integral_v) { + if (has_undefined_overflow(op->type)) { + internal_assert((std::is_same_v)); + + int bits = op->type.bits(); + int64_t shift_amount = static_cast(b); + + // Shifting by a negative amount or >= the bit-width drops bits / triggers overflow + if (shift_amount < 0 || shift_amount >= bits) { + overflow = true; + } else { + // To avoid dropping the most significant bits (including the sign bit), + // 'a' must be strictly bounded by [-2^(bits - 1 - b), 2^(bits - 1 - b) - 1]. + // We use 1ULL to prevent C++ UB when shifting into the sign bit. + int64_t max_val = static_cast((1ULL << (bits - 1 - shift_amount)) - 1); + int64_t min_val = -max_val - 1; + + if (a < min_val || a > max_val) { + overflow = true; + } + } + } + + // Safe Evaluation (Preventing host C++ UB): + // 1. Cast 'a' to uint64_t because left-shifting negative signed values was UB prior to C++20. + // 2. Mask 'b' to 63 to prevent host hardware crashes when shift_amount >= 64. + uint64_t safe_b = static_cast(b) & 63; + return static_cast(static_cast(a) << safe_b); + } else { + internal_error << "shift_left on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::shift_right)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b, bool &) { + if constexpr (std::is_integral_v && std::is_integral_v) { + return a >> b; + } else { + internal_error << "shift_right on float"; + return a; + } + }); + } else if (op->is_intrinsic(Call::abs)) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { + if constexpr (std::is_floating_point_v) { + return std::abs(a); + } else if constexpr (std::is_signed_v) { + if (a == std::numeric_limits::min()) { + return (uint64_t)(1ULL << 63); + } + return (uint64_t)std::abs(a); + } else { + return a; + } + }); + } else if (op->is_intrinsic(Call::bool_to_mask) || op->is_intrinsic(Call::cast_mask)) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { + if constexpr (std::is_integral_v) { + return a ? static_cast(-1) : 0; + } else { + internal_error << "mask intrinsic on float"; + return int64_t{0}; + } + }); + } else if (op->is_intrinsic(Call::select_mask) || op->is_intrinsic({Call::if_then_else, Call::if_then_else_mask})) { + for (int j = 0; j < op->type.lanes(); j++) { + bool cond = std::visit([](auto x) { return x != 0; }, args[0].lanes[j]); + result.lanes[j] = cond ? args[1].lanes[j] : args[2].lanes[j]; + } + } else if (op->is_intrinsic({Call::likely, Call::likely_if_innermost, Call::promise_clamped, Call::unsafe_promise_clamped})) { + result = args[0]; + } else if (op->is_intrinsic({Call::return_second, Call::require})) { + result = args[1]; + } else if (starts_with(op->name, "sin_")) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { return std::sin(a); }); + } else if (starts_with(op->name, "cos_")) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { return std::cos(a); }); + } else if (starts_with(op->name, "exp_")) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { return std::exp(a); }); + } else if (starts_with(op->name, "log_")) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { return std::log(a); }); + } else if (starts_with(op->name, "sqrt_")) { + result = apply_unary(op->type, args[0], [](auto a, bool &) { return std::sqrt(a); }); + } else if (op->is_intrinsic(Call::strict_fma)) { + internal_assert(op->args.size() == 3); + internal_assert(op->args[0].type().is_float()); + for (int j = 0; j < op->type.lanes(); j++) { + result.lanes[j] = std::visit( + [&](auto a, auto b, auto c) -> Scalar { + if constexpr (std::is_same_v && std::is_same_v) { + auto out = std::fma(a, b, c); + if (op->type.is_float()) { + return static_cast(out); + } + if (op->type.is_int()) { + return static_cast(out); + } + return static_cast(out); + } else { + internal_error << "Type mismatch in strict_fma"; + return double{0}; + } + }, + args[0].lanes[j], args[1].lanes[j], args[2].lanes[j]); + } + } else if (op->is_strict_float_intrinsic()) { + Expr unstrict = unstrictify_float(op); + unstrict.accept(this); + } else if (op->is_arithmetic_intrinsic()) { + Expr lower = lower_intrinsic(op); + lower.accept(this); + } else if (op->is_intrinsic(Call::absd)) { + result = apply_binary(op->type, args[0], args[1], [](auto a, auto b, bool &) { + return a < b ? b - a : a - b; + }); + } else if (op->is_intrinsic(Call::signed_integer_overflow)) { + result = EvalValue(op->type); + result.did_overflow = true; + } else { + internal_error << "Unhandled Call intrinsic / function in ExprInterpreter: " << op->name; + } +} + +namespace { + +void test_scalar_equivalence() { + ExprInterpreter interp; + + // 1. Integer scalar math equivalence + auto math_test_int = [](const auto &x, const auto &y) { + // Keeps values positive to align C++ truncation division with Halide's Euclidean division + return (x + y) * (x - y) + (x / y) + (x % y); + }; + + int32_t cx = 42, cy = 5; + int32_t c_res = math_test_int(cx, cy); + + Expr hx = Expr(cx), hy = Expr(cy); + Expr h_ast = math_test_int(hx, hy); + + auto eval_res = interp.eval(h_ast); + internal_assert(eval_res.type.is_int() && eval_res.type.bits() == 32 && eval_res.type.lanes() == 1); + internal_assert(std::get(eval_res.lanes[0]) == c_res) + << "Integer scalar evaluation mismatch. Expected: " << c_res + << ", Got: " << std::get(eval_res.lanes[0]); + + // 2. Float scalar math equivalence + using Halide::sin; + using std::sin; + auto math_test_float = [](const auto &x, const auto &y) { + return (x * y) - sin(x / (y + 1.0f)); + }; + + float fx = 3.14f, fy = 2.0f; + float f_res = math_test_float(fx, fy); + + Expr hfx = Expr(fx), hfy = Expr(fy); + Expr hf_ast = math_test_float(hfx, hfy); + + auto eval_f_res = interp.eval(hf_ast); + internal_assert(eval_f_res.type.is_float() && eval_f_res.type.bits() == 32 && eval_f_res.type.lanes() == 1); + + double diff = std::abs(std::get(eval_f_res.lanes[0]) - f_res); + internal_assert(diff < 1e-5) << "Float scalar evaluation mismatch."; +} + +void test_vector_operations() { + ExprInterpreter interp; + + // 1. Ramp: create a vector <10, 13, 16, 19> + Expr base = Expr(10); + Expr stride = Expr(3); + Expr ramp = Ramp::make(base, stride, 4); + + auto eval_ramp = interp.eval(ramp); + internal_assert(eval_ramp.type.lanes() == 4); + internal_assert(std::get(eval_ramp.lanes[0]) == 10); + internal_assert(std::get(eval_ramp.lanes[1]) == 13); + internal_assert(std::get(eval_ramp.lanes[2]) == 16); + internal_assert(std::get(eval_ramp.lanes[3]) == 19); + + // 2. Broadcast: <5, 5, 5> + Expr bc = Broadcast::make(Expr(5), 3); + auto eval_bc = interp.eval(bc); + internal_assert(eval_bc.type.lanes() == 3); + internal_assert(std::get(eval_bc.lanes[0]) == 5); + internal_assert(std::get(eval_bc.lanes[1]) == 5); + internal_assert(std::get(eval_bc.lanes[2]) == 5); + + // 3. Shuffle: reverse the ramp -> <19, 16, 13, 10> + Expr reversed = Shuffle::make({ramp}, {3, 2, 1, 0}); + auto eval_rev = interp.eval(reversed); + internal_assert(eval_rev.type.lanes() == 4); + internal_assert(std::get(eval_rev.lanes[0]) == 19); + internal_assert(std::get(eval_rev.lanes[1]) == 16); + internal_assert(std::get(eval_rev.lanes[2]) == 13); + internal_assert(std::get(eval_rev.lanes[3]) == 10); + + // 4. VectorReduce: Sum the ramp -> 10 + 13 + 16 + 19 = 58 + Expr sum = VectorReduce::make(VectorReduce::Add, ramp, 1); + auto eval_sum = interp.eval(sum); + internal_assert(eval_sum.type.lanes() == 1); + internal_assert(std::get(eval_sum.lanes[0]) == 58); + + // 5. Ramp of Ramp + Expr ramp_of_ramp = Ramp::make(ramp, Broadcast::make(100, 4), 4); + auto eval_ror = interp.eval(ramp_of_ramp); + internal_assert(eval_ror.type.lanes() == 16); + for (int i = 0; i < 4; ++i) { + internal_assert(std::get(eval_ror.lanes[4 * i + 0]) == 100 * i + 10); + internal_assert(std::get(eval_ror.lanes[4 * i + 1]) == 100 * i + 13); + internal_assert(std::get(eval_ror.lanes[4 * i + 2]) == 100 * i + 16); + internal_assert(std::get(eval_ror.lanes[4 * i + 3]) == 100 * i + 19); + } + + // 6. Broadcast of Ramp + Expr bc_of_ramp = Broadcast::make(ramp, 5); + auto eval_bor = interp.eval(bc_of_ramp); + internal_assert(eval_bor.type.lanes() == 20); + for (int i = 0; i < 5; ++i) { + internal_assert(std::get(eval_bor.lanes[4 * i + 0]) == 10); + internal_assert(std::get(eval_bor.lanes[4 * i + 1]) == 13); + internal_assert(std::get(eval_bor.lanes[4 * i + 2]) == 16); + internal_assert(std::get(eval_bor.lanes[4 * i + 3]) == 19); + } +} + +void test_let_and_scoping() { + ExprInterpreter interp; + + // Test: let x = 42 in (let x = x + 8 in x * 2) + // Inner scoping should shadow outer scoping and evaluate cleanly + Expr var_x = Variable::make(Int(32), "x"); + Expr inner_let = Let::make("x", var_x + Expr(8), var_x * Expr(2)); + Expr outer_let = Let::make("x", Expr(42), inner_let); + + auto res = interp.eval(outer_let); + internal_assert(res.type.is_int() && res.type.lanes() == 1); + + // (42 + 8) * 2 = 100 + internal_assert(std::get(res.lanes[0]) == 100) + << "Variable scoping / Let evaluation failed."; +} +} // namespace + +void ExprInterpreter::test() { + test_scalar_equivalence(); + test_vector_operations(); + test_let_and_scoping(); + + std::cout << "ExprInterpreter tests passed!" << "\n"; +} + +} // namespace Internal +} // namespace Halide diff --git a/test/fuzz/ExprInterpreter.h b/test/fuzz/ExprInterpreter.h new file mode 100644 index 000000000000..2c847138b988 --- /dev/null +++ b/test/fuzz/ExprInterpreter.h @@ -0,0 +1,93 @@ +#ifndef HALIDE_INTERNAL_EXPR_INTERPRETER_H +#define HALIDE_INTERNAL_EXPR_INTERPRETER_H + +#include + +#include +#include +#include +#include + +namespace Halide { +namespace Internal { + +class ExprInterpreter : public IRVisitor { +public: + using Scalar = std::variant; + + struct EvalValue { + Type type; + std::vector lanes; + bool did_overflow{false}; + + EvalValue() = default; + explicit EvalValue(Type t); + + bool is_close(const EvalValue &o, double threshold) const; + bool operator==(const EvalValue &o) const; + bool operator!=(const EvalValue &o) const { + return !operator==(o); + } + }; + + std::map var_env; + EvalValue result; + + EvalValue eval(const Expr &e); + +protected: + using IRVisitor::visit; + void truncate(EvalValue &v); + + void visit(const IntImm *op) override; + void visit(const UIntImm *op) override; + void visit(const FloatImm *op) override; + void visit(const StringImm *op) override; + void visit(const Variable *op) override; + void visit(const Cast *op) override; + void visit(const Reinterpret *op) override; + void visit(const Add *op) override; + void visit(const Sub *op) override; + void visit(const Mul *op) override; + void visit(const Div *op) override; + void visit(const Mod *op) override; + void visit(const Min *op) override; + void visit(const Max *op) override; + void visit(const EQ *op) override; + void visit(const NE *op) override; + void visit(const LT *op) override; + void visit(const LE *op) override; + void visit(const GT *op) override; + void visit(const GE *op) override; + void visit(const And *op) override; + void visit(const Or *op) override; + void visit(const Not *op) override; + void visit(const Select *op) override; + void visit(const Load *op) override; + void visit(const Ramp *op) override; + void visit(const Broadcast *op) override; + void visit(const Call *op) override; + void visit(const Shuffle *op) override; + void visit(const VectorReduce *op) override; + void visit(const Let *op) override; + +private: + template + EvalValue apply_unary(Type t, const EvalValue &a, F f); + + template + EvalValue apply_binary(Type t, const EvalValue &a, const EvalValue &b, F f); + + template + EvalValue apply_cmp(Type t, const EvalValue &a, const EvalValue &b, F f); + +public: + static void test(); +}; + +std::ostream &operator<<(std::ostream &o, const ExprInterpreter::EvalValue &val); + +} // namespace Internal +} // namespace Halide + +#endif // HALIDE_INTERNAL_EXPR_INTERPRETER_H diff --git a/test/fuzz/simplify.cpp b/test/fuzz/simplify.cpp index 2ddc82e6913d..08e6fdbd03d6 100644 --- a/test/fuzz/simplify.cpp +++ b/test/fuzz/simplify.cpp @@ -1,6 +1,7 @@ #include "Halide.h" #include +#include "ExprInterpreter.h" #include "fuzz_helpers.h" #include "random_expr_generator.h" @@ -12,7 +13,7 @@ using std::string; using namespace Halide; using namespace Halide::Internal; -bool test_simplification(Expr a, Expr b, const map &vars) { +bool test_simplification(const Expr &a, const Expr &b, const map &vars) { if (equal(a, b) && !a.same_as(b)) { std::cerr << "Simplifier created new IR node but made no changes:\n" << a << "\n"; @@ -71,7 +72,7 @@ bool test_simplification(Expr a, Expr b, const map &vars) { return true; } -bool test_expression(RandomExpressionGenerator ®, Expr test, int samples) { +bool test_expression(RandomExpressionGenerator ®, const Expr &test, int samples) { Expr simplified = simplify(test); map vars; @@ -94,6 +95,53 @@ bool test_expression(RandomExpressionGenerator ®, Expr test, int samples) { return false; } } + + // Additionally test a few rounds with the ExprInterpreter to test + // if the simplification was correct. + for (int i = 0; i < samples; ++i) { + ExprInterpreter ei; + for (const auto &fuzz_var : reg.fuzz_vars) { + ExprInterpreter::EvalValue val(Int(32)); + val.lanes[0] = reg.fuzz.ConsumeIntegral(); + ei.var_env[fuzz_var.name()] = std::move(val); + } + ExprInterpreter::EvalValue eval_test = ei.eval(test); + ExprInterpreter::EvalValue eval_simplified = ei.eval(simplified); + if (eval_test.did_overflow || eval_simplified.did_overflow) { + // The expression interpreter detected overflow on types that are + // defined by halide to be not-overflowable. So the simplifier will + // have done transformations which don't hold when the numbers do overflow. + continue; // Try different numbers instead! + } + bool good = true; + if (eval_test.type != eval_simplified.type) { + good = false; + } else { + if (eval_test.type.is_float()) { + if (!eval_test.is_close(eval_simplified, 1e-5)) { + good = false; + } + } else { + if (eval_test != eval_simplified) { + good = false; + } + } + } + if (!good) { + std::cerr << "ExprInterpreter of the following Exprs did not match:\n\n"; + std::cerr << "Original: " << test << "\n"; + std::cerr << "Value: " << eval_test << "\n\n"; + std::cerr << "Simplified: " << simplified << "\n"; + std::cerr << "Value: " << eval_simplified << "\n\n"; + std::cerr << "With the following variables values:\n"; + for (const auto &var : ei.var_env) { + std::cerr << "\t" << var.first << " = " << var.second << "\n"; + } + + return false; + } + } + return true; } @@ -131,6 +179,7 @@ FUZZ_TEST(simplify, FuzzingContext &fuzz) { int width = fuzz.PickValueInArray({1, 2, 3, 4, 6, 8}); Expr test = reg.random_expr(reg.random_type(width), depth); + debug(1) << "Testing " << test << "\n"; if (!test_expression(reg, test, samples)) { // Failure. Find the minimal subexpression that failed.