-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathsimplify.cpp
More file actions
223 lines (199 loc) · 7.11 KB
/
simplify.cpp
File metadata and controls
223 lines (199 loc) · 7.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
#include "Halide.h"
#include <functional>
#include "fuzz_helpers.h"
#include "random_expr_generator.h"
// Test the simplifier in Halide by testing for equivalence of randomly generated expressions.
namespace {
using std::map;
using std::string;
using namespace Halide;
using namespace Halide::Internal;
struct SimpilfyResult : public std::variant<Expr, InternalError> {
using std::variant<Expr, InternalError>::variant;
bool ok() const {
return index() == 0;
}
bool failed() const {
return index() == 1;
}
operator Expr() const {
return std::get<Expr>(*this);
}
};
// Simplifies `e`, turning an InternalError thrown by the simplifier into a failed
// SimpilfyResult. On failure, the offending expression and the error text are written
// to stderr before returning.
SimpilfyResult safe_simplify(const Expr &e) {
    try {
        return simplify(e);
    } catch (const InternalError &failure) {
        std::cerr << "Simplifier failed to simplify expression:\n"
                  << e << "\n"
                  << failure.what() << "\n";
        return failure;
    }
}
// Checks an (original, simplified) expression pair under one sampled variable assignment.
//
// `a` is the original expression and `b` is what the simplifier produced for it. `vars`
// maps each fuzz variable name to the leaf expression substituted for it. The checks are:
//   1. If the simplifier made no changes, it must have returned the very same IR node.
//   2. Re-simplifying `b` must not throw and must leave it unchanged (idempotency).
//      On failure, the sub-expressions of `a` are searched for a smaller reproducer.
//   3. After substituting `vars`, `a` and `b` must simplify to equal constants.
// Returns false (after logging to stderr) if any check fails, and true otherwise.
// If the substituted expressions do not fold to constants, the sample is treated as a pass.
bool test_simplification(Expr a, Expr b, const map<string, Expr> &vars) {
    if (equal(a, b) && !a.same_as(b)) {
        std::cerr << "Simplifier created new IR node but made no changes:\n"
                  << a << "\n";
        return false;
    }
    SimpilfyResult sb = safe_simplify(b);
    if (sb.failed() || !equal(b, static_cast<Expr>(sb))) {
        // Test all sub-expressions bottom-up: mutate_base visits the children before the
        // checks below run on the parent, so the smallest offending one is reported.
        bool found_failure = false;
        mutate_with(a, [&](auto *self, const Expr &e) {
            self->mutate_base(e);
            if (found_failure) {
                // A descendant already reproduced the problem. Don't re-simplify (and
                // re-log errors for) every enclosing expression.
                return e;
            }
            Expr s, ss;
            if (SimpilfyResult res = safe_simplify(e); res.ok()) {
                s = res;
            } else {
                found_failure = true;
                return e;
            }
            if (SimpilfyResult res = safe_simplify(s); res.ok()) {
                ss = res;
            } else {
                found_failure = true;
                return e;
            }
            if (!equal(s, ss)) {
                std::cerr << "Idempotency failure\n    "
                          << e << "\n -> "
                          << s << "\n -> "
                          << ss << "\n";
                // These are broken out below to make it easier to parse any logging
                // added to the simplifier to debug the failure.
                std::cerr << "---------------------------------\n"
                          << "Begin simplification of original:\n"
                          << s << "\n";
                std::cerr << "---------------------------------\n"
                          << "Begin resimplification of result:\n"
                          << ss << "\n"
                          << "---------------------------------\n";
                found_failure = true;
            }
            return e;
        });
        return false;
    }
    Expr a_v = substitute(vars, a);
    if (SimpilfyResult res = safe_simplify(a_v); res.ok()) {
        a_v = res;
    } else {
        return false;
    }
    Expr b_v = substitute(vars, b);
    if (SimpilfyResult res = safe_simplify(b_v); res.ok()) {
        b_v = res;
    } else {
        return false;
    }
    // If the simplifier didn't produce constants, there must be
    // undefined behavior in this expression. Ignore it.
    if (!Internal::is_const(a_v) || !Internal::is_const(b_v)) {
        return true;
    }
    if (!equal(a_v, b_v)) {
        std::cerr << "Simplified Expr is not equal() to Original Expr!\n";
        for (const auto &[var, val] : vars) {
            std::cerr << "Var " << var << " = " << val << "\n";
        }
        std::cerr << "Original Expr is: " << a << "\n";
        std::cerr << "Simplified Expr is: " << b << "\n";
        std::cerr << "   " << a << " -> " << a_v << "\n";
        std::cerr << "   " << b << " -> " << b_v << "\n";
        return false;
    }
    return true;
}
bool test_expression(RandomExpressionGenerator ®, Expr test, int samples) {
Expr simplified;
if (SimpilfyResult res = safe_simplify(test); res.ok()) {
simplified = res;
} else {
return false;
}
map<string, Expr> vars;
for (const auto &fuzz_var : reg.fuzz_vars) {
vars[fuzz_var.name()] = Expr();
}
for (int i = 0; i < samples; i++) {
for (auto &[var, val] : vars) {
constexpr size_t kMaxLeafIterations = 10000;
// Don't let the random leaf depend on v itself.
size_t iterations = 0;
do {
val = reg.random_leaf(Int(32), true);
iterations++;
} while (expr_uses_var(val, var) && iterations < kMaxLeafIterations);
}
if (!test_simplification(test, simplified, vars)) {
return false;
}
}
return true;
}
// Rebuilds `in`, leaving its top `limit` levels unsimplified and running the simplifier
// on each sub-expression found exactly `limit` levels down. An InternalError thrown by
// the simplifier is returned as a failed SimpilfyResult instead of propagating.
SimpilfyResult simplify_at_depth(int limit, const Expr &in) {
    // Levels still to descend before simplifying. It is restored on the way back up,
    // so sibling subtrees see the same depth.
    int remaining = limit;
    try {
        Expr result = mutate_with(in, [&](auto *self, const Expr &e) {
            if (remaining != 0) {
                --remaining;
                Expr rebuilt = self->mutate_base(e);
                ++remaining;
                return rebuilt;
            }
            return simplify(e);
        });
        return result;
    } catch (const InternalError &failure) {
        return failure;
    }
}
} // namespace
// Fuzz entry point. Generates one random expression and checks that simplifying it
// preserves its value. On failure, searches its sub-expressions for a smaller reproducer.
// Returns 0 on success and 1 on failure.
FUZZ_TEST(simplify, FuzzingContext &fuzz) {
    // Depth of the randomly generated expression trees.
    constexpr int depth = 6;
    // Number of samples to test the generated expressions for.
    constexpr int samples = 3;
    // Number of samples to test the generated expressions for during minimization.
    constexpr int samples_during_minimization = 100;
    RandomExpressionGenerator reg{fuzz};
    // FIXME: UInt64 fails!
    reg.fuzz_types = {UInt(1), UInt(8), UInt(16), UInt(32), Int(8), Int(16), Int(32)};
    // FIXME: These need to be disabled (otherwise crashes and/or failures):
    // reg.gen_ramp_of_vector = false;
    // reg.gen_broadcast_of_vector = false;
    // reg.gen_vector_reduce = false;
    reg.gen_reinterpret = false;
    reg.gen_shuffles = false;
    int width = fuzz.PickValueInArray({1, 2, 3, 4, 6, 8});
    Expr test = reg.random_expr(reg.random_type(width), depth);
    if (test_expression(reg, test, samples)) {
        return 0;
    }

    // Failure. Find the minimal subexpression that failed.
    std::cerr << "Testing subexpressions...\n";
    // The result of self->mutate_base() is discarded below, so a value returned for a
    // child never reaches the root. The reproducer is therefore recorded here rather
    // than returned through the mutator. If minimization fails to reproduce the
    // failure, the whole expression is reported.
    Expr minimal = test;
    bool found_failure = false;
    mutate_with(test, [&](auto *self, const Expr &e) {
        // Children are visited first, so the deepest failing subexpression wins.
        self->mutate_base(e);
        if (found_failure || !e.type().bits()) {
            return e;
        }
        for (int i = 1; i < 4 && !found_failure; i++) {
            SimpilfyResult limited_res = simplify_at_depth(i, e);
            if (limited_res.failed()) {
                // The simplifier itself threw while processing this subexpression.
                found_failure = true;
                minimal = e;
            } else {
                Expr limited = limited_res;
                if (!test_expression(reg, limited, samples_during_minimization)) {
                    found_failure = true;
                    minimal = limited;
                }
            }
        }
        if (!found_failure && !test_expression(reg, e, samples_during_minimization)) {
            found_failure = true;
            minimal = e;
        }
        return e;
    });
    std::cerr << "Final test case: " << minimal << "\n";
    return 1;
}