-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathSimplify_Exprs.cpp
More file actions
387 lines (351 loc) · 16.9 KB
/
Simplify_Exprs.cpp
File metadata and controls
387 lines (351 loc) · 16.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
#include "Simplify_Internal.h"
using std::string;
namespace Halide {
namespace Internal {
// Miscellaneous expression visitors that are too small to bother putting in their own files
// A signed integer constant simplifies to itself; all we do is report
// exact expression info: a single-point interval and an alignment that
// pins the value precisely (modulus 0, remainder == value).
Expr Simplify::visit(const IntImm *op, ExprInfo *info) {
    if (info == nullptr) {
        // Caller doesn't want info tracked.
        clear_expr_info(info);
        return op;
    }
    info->bounds = ConstantInterval::single_point(op->value);
    info->alignment = ModulusRemainder(0, op->value);
    info->cast_to(op->type);
    return op;
}
// An unsigned integer constant simplifies to itself. For info tracking
// we model it as an int64 that has been cast to the uint type.
Expr Simplify::visit(const UIntImm *op, ExprInfo *info) {
    if (!info) {
        clear_expr_info(info);
        return op;
    }
    // Reinterpret the bits as a signed 64-bit value.
    const int64_t as_signed = (int64_t)(op->value);
    info->bounds = ConstantInterval::single_point(as_signed);
    info->alignment = ModulusRemainder(0, as_signed);
    // If it's not representable as an int64, this will wrap the alignment appropriately:
    info->cast_to(op->type);
    if (as_signed < 0) {
        // The uint64 exceeded INT64_MAX; the best bound we can state in
        // signed arithmetic is "at least INT64_MAX".
        info->bounds = ConstantInterval::bounded_below(INT64_MAX);
    }
    return op;
}
// Float constants simplify to themselves. We don't track bounds or
// alignment for floating-point values, so clear any requested info.
Expr Simplify::visit(const FloatImm *op, ExprInfo *info) {
    clear_expr_info(info);
    return op;
}
// String constants simplify to themselves; there is no numeric info to
// report, so clear any requested info.
Expr Simplify::visit(const StringImm *op, ExprInfo *info) {
    clear_expr_info(info);
    return op;
}
// Simplify a Broadcast. The info of the broadcast is the info of its
// scalar (or narrower vector) value, since every lane is the same.
Expr Simplify::visit(const Broadcast *op, ExprInfo *info) {
    Expr value = mutate(op->value, info);
    const int lanes = op->lanes;
    auto rewrite = IRMatcher::rewriter(IRMatcher::broadcast(value, lanes), op->type);
    if (rewrite(broadcast(broadcast(x, c0), lanes), broadcast(x, c0 * lanes)) ||   // Nested broadcasts flatten into one.
        rewrite(broadcast(IRMatcher::Overflow(), lanes), IRMatcher::Overflow()) || // Overflow of the value poisons the whole vector.
        false) {
        return mutate(rewrite.result, info);
    }
    if (value.same_as(op->value)) {
        // Unchanged; reuse the existing node rather than reallocating.
        return op;
    } else {
        return Broadcast::make(value, op->lanes);
    }
}
// Simplify a VectorReduce (horizontal reduction of arg_lanes down to
// lanes). Propagates bounds/alignment info through the reduction where
// possible, then applies distributive-law rewrites to push the
// reduction inwards (so the non-reduced op runs on fewer lanes).
//
// Fix in this revision: in the And and Or cases, the fourth comparison
// rewrite pattern duplicated the `broadcast(x) < ramp(...)` pattern two
// lines above it (making it unreachable) while producing the `<=`
// result. The pattern is corrected to `broadcast(x) <= ramp(...)`,
// matching the ramp-vs-broadcast pair above it.
Expr Simplify::visit(const VectorReduce *op, ExprInfo *info) {
    Expr value = mutate(op->value, info);
    const int lanes = op->type.lanes();
    const int arg_lanes = op->value.type().lanes();
    const int factor = arg_lanes / lanes;
    if (factor == 1) {
        // Reducing by a factor of one is the identity.
        return value;
    }
    if (info && op->type.is_int_or_uint()) {
        switch (op->op) {
        case VectorReduce::Add:
            // Alignment of result is the alignment of the arg. Bounds
            // of the result can grow according to the reduction
            // factor.
            info->bounds = cast(op->type, info->bounds * factor);
            break;
        case VectorReduce::SaturatingAdd:
            // Like Add, but the sum clamps at the type's range.
            info->bounds = saturating_cast(op->type, info->bounds * factor);
            break;
        case VectorReduce::Mul:
            // Don't try to infer anything about bounds. Leave the
            // alignment unchanged even though we could theoretically
            // upgrade it.
            info->bounds = ConstantInterval{};
            break;
        case VectorReduce::Min:
        case VectorReduce::Max:
            // Bounds and alignment of the result are just the bounds and alignment of the arg.
            break;
        case VectorReduce::And:
        case VectorReduce::Or:
            // For integer types this is a bitwise operator. Don't try
            // to infer anything for now.
            info->bounds = ConstantInterval{};
            info->alignment = ModulusRemainder{};
            break;
        }
    }
    // We can pull multiplications by a broadcast out of horizontal
    // additions and do the horizontal addition earlier. This means we
    // do the multiplication on a vector with fewer lanes. This
    // approach applies whenever we have a distributive law. We'll
    // exploit the following distributive laws here:
    // - Multiplication distributes over addition
    // - min/max distributes over min/max
    // - and/or distributes over and/or
    // Further, we can collapse min/max/and/or of a broadcast down to
    // a narrower broadcast.
    // TODO: There are other rules we could apply here if they ever
    // come up in practice:
    // - a horizontal min/max/add of a ramp is a different ramp
    // - horizontal add of a broadcast is a broadcast + multiply
    // - horizontal reduce of an shuffle_vectors may be simplifiable to the
    // underlying op on different shuffle_vectors calls
    switch (op->op) {
    case VectorReduce::Add: {
        auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type);
        if (rewrite(h_add(x * broadcast(y, arg_lanes), lanes), h_add(x, lanes) * broadcast(y, lanes)) ||
            rewrite(h_add(broadcast(x, arg_lanes) * y, lanes), h_add(y, lanes) * broadcast(x, lanes)) ||
            rewrite(h_add(broadcast(x, arg_lanes), lanes), broadcast(x * factor, lanes)) ||
            rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, lanes / c0), c0), lanes % c0 == 0) ||
            rewrite(h_add(broadcast(x, c0), lanes), broadcast(h_add(x, 1) * (c0 / lanes), lanes), c0 % lanes == 0) ||
            false) {
            return mutate(rewrite.result, info);
        }
        break;
    }
    case VectorReduce::Min: {
        auto rewrite = IRMatcher::rewriter(IRMatcher::h_min(value, lanes), op->type);
        if (rewrite(h_min(min(x, broadcast(y, arg_lanes)), lanes), min(h_min(x, lanes), broadcast(y, lanes))) ||
            rewrite(h_min(min(broadcast(x, arg_lanes), y), lanes), min(h_min(y, lanes), broadcast(x, lanes))) ||
            rewrite(h_min(max(x, broadcast(y, arg_lanes)), lanes), max(h_min(x, lanes), broadcast(y, lanes))) ||
            rewrite(h_min(max(broadcast(x, arg_lanes), y), lanes), max(h_min(y, lanes), broadcast(x, lanes))) ||
            rewrite(h_min(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
            rewrite(h_min(broadcast(x, c0), 1), h_min(x, 1)) ||
            rewrite(h_min(broadcast(x, c0), lanes), broadcast(h_min(x, lanes / c0), c0), lanes % c0 == 0) ||
            rewrite(h_min(ramp(x, y, arg_lanes), 1), x + min(y * (arg_lanes - 1), 0)) ||
            rewrite(h_min(ramp(x, y, arg_lanes), lanes), ramp(x + min(y * (factor - 1), 0), y * factor, lanes)) ||
            false) {
            return mutate(rewrite.result, info);
        }
        break;
    }
    case VectorReduce::Max: {
        auto rewrite = IRMatcher::rewriter(IRMatcher::h_max(value, lanes), op->type);
        if (rewrite(h_max(min(x, broadcast(y, arg_lanes)), lanes), min(h_max(x, lanes), broadcast(y, lanes))) ||
            rewrite(h_max(min(broadcast(x, arg_lanes), y), lanes), min(h_max(y, lanes), broadcast(x, lanes))) ||
            rewrite(h_max(max(x, broadcast(y, arg_lanes)), lanes), max(h_max(x, lanes), broadcast(y, lanes))) ||
            rewrite(h_max(max(broadcast(x, arg_lanes), y), lanes), max(h_max(y, lanes), broadcast(x, lanes))) ||
            rewrite(h_max(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
            rewrite(h_max(broadcast(x, c0), 1), h_max(x, 1)) ||
            rewrite(h_max(broadcast(x, c0), lanes), broadcast(h_max(x, lanes / c0), c0), lanes % c0 == 0) ||
            rewrite(h_max(ramp(x, y, arg_lanes), 1), x + max(y * (arg_lanes - 1), 0)) ||
            rewrite(h_max(ramp(x, y, arg_lanes), lanes), ramp(x + max(y * (factor - 1), 0), y * factor, lanes)) ||
            false) {
            return mutate(rewrite.result, info);
        }
        break;
    }
    case VectorReduce::And: {
        auto rewrite = IRMatcher::rewriter(IRMatcher::h_and(value, lanes), op->type);
        if (rewrite(h_and(x || broadcast(y, arg_lanes), lanes), h_and(x, lanes) || broadcast(y, lanes)) ||
            rewrite(h_and(broadcast(x, arg_lanes) || y, lanes), h_and(y, lanes) || broadcast(x, lanes)) ||
            rewrite(h_and(x && broadcast(y, arg_lanes), lanes), h_and(x, lanes) && broadcast(y, lanes)) ||
            rewrite(h_and(broadcast(x, arg_lanes) && y, lanes), h_and(y, lanes) && broadcast(x, lanes)) ||
            rewrite(h_and(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
            rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, lanes / c0), c0), lanes % c0 == 0) ||
            rewrite(h_and(broadcast(x, c0), lanes), broadcast(h_and(x, 1), lanes), c0 >= lanes) ||
            // "All lanes of (ramp cmp broadcast)" reduces to comparing
            // the extreme lane of the ramp against the scalar.
            (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
                                   x + max(y * (arg_lanes - 1), 0) < z)) ||
            (lanes == 1 && rewrite(h_and(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
                                   x + max(y * (arg_lanes - 1), 0) <= z)) ||
            (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
                                   x < y + min(z * (arg_lanes - 1), 0))) ||
            // Fixed: this pattern used to repeat the '<' pattern above
            // (unreachable) while yielding the '<=' result.
            (lanes == 1 && rewrite(h_and(broadcast(x, arg_lanes) <= ramp(y, z, arg_lanes), lanes),
                                   x <= y + min(z * (arg_lanes - 1), 0))) ||
            false) {
            return mutate(rewrite.result, info);
        }
        break;
    }
    case VectorReduce::Or: {
        auto rewrite = IRMatcher::rewriter(IRMatcher::h_or(value, lanes), op->type);
        if (rewrite(h_or(x || broadcast(y, arg_lanes), lanes), h_or(x, lanes) || broadcast(y, lanes)) ||
            rewrite(h_or(broadcast(x, arg_lanes) || y, lanes), h_or(y, lanes) || broadcast(x, lanes)) ||
            rewrite(h_or(x && broadcast(y, arg_lanes), lanes), h_or(x, lanes) && broadcast(y, lanes)) ||
            rewrite(h_or(broadcast(x, arg_lanes) && y, lanes), h_or(y, lanes) && broadcast(x, lanes)) ||
            rewrite(h_or(broadcast(x, arg_lanes), lanes), broadcast(x, lanes)) ||
            rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, lanes / c0), c0), lanes % c0 == 0) ||
            rewrite(h_or(broadcast(x, c0), lanes), broadcast(h_or(x, 1), lanes), c0 >= lanes) ||
            // type of arg_lanes is somewhat indeterminate
            rewrite(h_or(ramp(x, y, arg_lanes) < broadcast(z, arg_lanes), lanes),
                    x + min(y * (arg_lanes - 1), 0) < z) ||
            rewrite(h_or(ramp(x, y, arg_lanes) <= broadcast(z, arg_lanes), lanes),
                    x + min(y * (arg_lanes - 1), 0) <= z) ||
            rewrite(h_or(broadcast(x, arg_lanes) < ramp(y, z, arg_lanes), lanes),
                    x < y + max(z * (arg_lanes - 1), 0)) ||
            // Fixed: this pattern used to repeat the '<' pattern above
            // (unreachable) while yielding the '<=' result.
            rewrite(h_or(broadcast(x, arg_lanes) <= ramp(y, z, arg_lanes), lanes),
                    x <= y + max(z * (arg_lanes - 1), 0)) ||
            false) {
            return mutate(rewrite.result, info);
        }
        break;
    }
    default:
        break;
    }
    if (value.same_as(op->value)) {
        return op;
    } else {
        return VectorReduce::make(op->op, value, op->type.lanes());
    }
}
// Simplify a Variable reference. If tracked bounds pin the variable to
// a single value, fold it to a constant. Otherwise, substitute any
// pending let-replacement recorded for it, keeping the use counters
// (new_uses/old_uses) up to date so dead lets can be dropped later.
Expr Simplify::visit(const Variable *op, ExprInfo *info) {
    if (const ExprInfo *b = bounds_and_alignment_info.find(op->name)) {
        if (info) {
            *info = *b;
        }
        if (b->bounds.is_single_point()) {
            // The variable is known to be exactly this constant here.
            return make_const(op->type, b->bounds.min, nullptr);
        }
    } else if (info && !no_overflow_int(op->type)) {
        // Nothing tracked; fall back to the bounds implied by the type.
        info->bounds = ConstantInterval::bounds_of_type(op->type);
    }
    if (auto *v_info = var_info.shallow_find(op->name)) {
        // if replacement is defined, we should substitute it in (unless
        // it's a var that has been hidden by a nested scope).
        if (v_info->replacement.defined()) {
            internal_assert(v_info->replacement.type() == op->type)
                << "Cannot replace variable " << op->name
                << " of type " << op->type
                << " with expression of type " << v_info->replacement.type() << "\n";
            v_info->new_uses++;
            // We want to remutate the replacement, because we may be
            // injecting it into a context where it is known to be a
            // constant (e.g. due to an if).
            return mutate(v_info->replacement, info);
        } else {
            // This expression was not something deemed
            // substitutable - no replacement is defined.
            v_info->old_uses++;
            return op;
        }
    } else {
        // We never encountered a let that defines this var. Must
        // be a uniform. Don't touch it.
        return op;
    }
}
// Simplify a Ramp. Derives bounds and alignment for the whole vector
// from the base and stride, collapses zero-stride ramps to broadcasts,
// and flattens a ramp-of-ramps with a matching broadcast stride.
Expr Simplify::visit(const Ramp *op, ExprInfo *info) {
    ExprInfo base_info, stride_info;
    Expr base = mutate(op->base, &base_info);
    Expr stride = mutate(op->stride, &stride_info);
    const int lanes = op->lanes;
    if (info) {
        // Lane l is base + l * stride, with l in [0, lanes-1].
        info->bounds = base_info.bounds + stride_info.bounds * ConstantInterval(0, lanes - 1);
        // A ramp lane is b + l * s. Expanding b into mb * x + rb and s into ms * y + rs, we get:
        //   mb * x + rb + l * (ms * y + rs)
        // = mb * x + ms * l * y + rs * l + rb
        // = gcd(rs, ms, mb) * z + rb
        int64_t m = stride_info.alignment.modulus;
        m = gcd(m, stride_info.alignment.remainder);
        m = gcd(m, base_info.alignment.modulus);
        int64_t r = base_info.alignment.remainder;
        if (m != 0) {
            r = mod_imp(base_info.alignment.remainder, m);
        }
        info->alignment = {m, r};
        info->cast_to(op->type);
        info->trim_bounds_using_alignment();
    }
    // A somewhat torturous way to check if the stride is zero,
    // but it helps to have as many rules as possible written as
    // formal rewrites, so that they can be formally verified,
    // etc.
    auto rewrite = IRMatcher::rewriter(IRMatcher::ramp(base, stride, lanes), op->type);
    if (rewrite(ramp(x, 0, lanes), broadcast(x, lanes)) ||
        rewrite(ramp(ramp(x, c0, c2), broadcast(c1, c4), c3),
                ramp(x, c0, c2 * c3),
                // In the multiply below, it's important c0 is on the
                // right. When folding constants, binary ops take their type
                // from the RHS. c2 is an int64 lane count but c0 has the type
                // we want for the comparison.
                c1 == c2 * c0) ||
        false) {
        return mutate(rewrite.result, info);
    }
    if (base.same_as(op->base) &&
        stride.same_as(op->stride)) {
        // Unchanged; reuse the existing node.
        return op;
    } else {
        return Ramp::make(base, stride, op->lanes);
    }
}
// Simplify a Load. Replaces provably out-of-bounds unpredicated loads
// with `unreachable`, folds always-false predicates to zero, turns
// loads of broadcast/concat/interleave indices into the corresponding
// op on narrower loads, and tightens the stated alignment using what
// is known about the index.
Expr Simplify::visit(const Load *op, ExprInfo *info) {
    found_buffer_reference(op->name);
    if (info) {
        // We know nothing about the loaded value beyond its type.
        info->bounds = ConstantInterval::bounds_of_type(op->type);
    }
    Expr predicate = mutate(op->predicate, nullptr);
    ExprInfo index_info;
    Expr index = mutate(op->index, &index_info);
    // If an unpredicated load is fully out of bounds, replace it with an
    // unreachable intrinsic. This should only occur inside branches that make
    // the load unreachable, but perhaps the branch was hard to prove constant
    // true or false. This provides an alternative mechanism to simplify these
    // unreachable loads.
    if (is_const_one(op->predicate)) {
        string alloc_extent_name = op->name + ".total_extent_bytes";
        if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) {
            // Out of bounds if the index is negative or past the
            // allocation's total size in bytes.
            if (index_info.bounds < 0 ||
                index_info.bounds * op->type.bytes() > alloc_info->bounds) {
                in_unreachable = true;
                return unreachable(op->type);
            }
        }
    }
    // For vector loads, prefer the alignment of the ramp base when the
    // index is a ramp, intersected with the whole index's alignment and
    // the load's previously-stated alignment.
    ExprInfo base_info;
    if (const Ramp *r = index.as<Ramp>()) {
        mutate(r->base, &base_info);
    }
    base_info.alignment = ModulusRemainder::intersect(base_info.alignment, index_info.alignment);
    ModulusRemainder align = ModulusRemainder::intersect(op->alignment, base_info.alignment);
    const Broadcast *b_index = index.as<Broadcast>();
    const Shuffle *s_index = index.as<Shuffle>();
    if (is_const_zero(predicate)) {
        // Predicate is always false
        return make_zero(op->type);
    } else if (b_index && is_const_one(predicate)) {
        // Load of a broadcast should be broadcast of the load
        Expr new_index = b_index->value;
        int new_lanes = new_index.type().lanes();
        Expr load = Load::make(op->type.with_lanes(new_lanes), op->name, b_index->value,
                               op->image, op->param, const_true(new_lanes, nullptr), align);
        return Broadcast::make(load, b_index->lanes);
    } else if (s_index &&
               is_const_one(predicate) &&
               (s_index->is_concat() ||
                s_index->is_interleave())) {
        // Loads of concats/interleaves should be concats/interleaves of loads
        std::vector<Expr> loaded_vecs;
        for (const Expr &new_index : s_index->vectors) {
            int new_lanes = new_index.type().lanes();
            Expr load = Load::make(op->type.with_lanes(new_lanes), op->name, new_index,
                                   op->image, op->param, const_true(new_lanes, nullptr), ModulusRemainder{});
            loaded_vecs.emplace_back(std::move(load));
        }
        return Shuffle::make(loaded_vecs, s_index->indices);
    } else if (predicate.same_as(op->predicate) && index.same_as(op->index) && align == op->alignment) {
        // Nothing changed; reuse the existing node.
        return op;
    } else {
        return Load::make(op->type, op->name, index, op->image, op->param, predicate, align);
    }
}
} // namespace Internal
} // namespace Halide