Skip to content
7 changes: 6 additions & 1 deletion src/Lean/Meta/Closure.lean
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,15 @@ def mkAuxDefinitionFor (name : Name) (value : Expr) (zetaDelta : Bool := false)
-/
def mkAuxTheorem (type : Expr) (value : Expr) (zetaDelta : Bool := false) (kind? : Option Name := none) (cache := true) : MetaM Expr := do
let result ← Closure.mkValueTypeClosure type value zetaDelta
let name ← mkAuxLemma (kind? := kind?) (cache := cache) result.levelParams.toList result.type result.value
let (resType, resValue) ← profileitM Exception "share common exprs" (← getOptions) do
withTraceNode `Meta.Closure.maxSharing (fun _ => return m!"share common exprs") do
let es := ShareCommon.shareCommon' #[result.type, result.value]
return (es[0]!, es[1]!)
let name ← mkAuxLemma (kind? := kind?) (cache := cache) result.levelParams.toList resType resValue
return mkAppN (mkConst name result.levelArgs.toList) result.exprArgs

builtin_initialize
registerTraceClass `Meta.Closure
registerTraceClass `Meta.Closure.maxSharing

end Lean.Meta
14 changes: 5 additions & 9 deletions src/kernel/environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ Author: Leonardo de Moura
#include <limits>
#include "runtime/sstream.h"
#include "runtime/thread.h"
#include "runtime/sharecommon.h"
#include "util/map_foreach.h"
#include "util/io.h"
#include "kernel/environment.h"
Expand Down Expand Up @@ -194,15 +193,12 @@ environment environment::add_theorem(declaration const & d, bool check) const {
theorem_val const & v = d.to_theorem_val();
if (check) {
type_checker checker(*this, diag.get());
sharecommon_persistent_fn share;
expr val(share(v.get_value().raw()));
expr type(share(v.get_type().raw()));
check_constant_val(*this, v.to_constant_val(), checker);
if (!checker.is_prop(type))
throw theorem_type_is_not_prop(*this, v.get_name(), type);
check_no_metavar_no_fvar(*this, v.get_name(), val);
expr val_type = checker.check(val, v.get_lparams());
if (!checker.is_def_eq(val_type, type))
if (!checker.is_prop(v.get_type()))
throw theorem_type_is_not_prop(*this, v.get_name(), v.get_type());
check_no_metavar_no_fvar(*this, v.get_name(), v.get_value());
expr val_type = checker.check(v.get_value(), v.get_lparams());
if (!checker.is_def_eq(val_type, v.get_type()))
throw definition_type_mismatch_exception(*this, d, val_type);
}
return diag.update(add(constant_info(d)));
Expand Down
23 changes: 12 additions & 11 deletions src/kernel/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ extern "C" uint8 lean_expr_has_level_param(obj_arg e);
bool has_univ_param(expr const & e) { return lean_expr_has_level_param(e.to_obj_arg()); }

extern "C" unsigned lean_expr_loose_bvar_range(object * e);
unsigned get_loose_bvar_range(expr const & e) { return lean_expr_loose_bvar_range(e.to_obj_arg()); }
unsigned loose_bvar_range_core(expr const & e) { return lean_expr_loose_bvar_range(e.to_obj_arg()); }

extern "C" LEAN_EXPORT uint64_t lean_expr_mk_data(uint64_t hash, object * bvarRange, uint32_t approxDepth, uint8_t hasFVar, uint8_t hasExprMVar, uint8_t hasLevelMVar, uint8_t hasLevelParam) {
if (approxDepth > 255) approxDepth = 255;
Expand Down Expand Up @@ -418,12 +418,13 @@ expr lower_loose_bvars(expr const & e, unsigned s, unsigned d) {
if (d == 0 || s >= get_loose_bvar_range(e))
return e;
lean_assert(s >= d);
return replace(e, [=](expr const & e, unsigned offset) -> optional<expr> {
return replace(e,
[=](expr const & e, unsigned offset) {
unsigned s1 = s + offset;
return s1 < s /* overflow */ || s1 >= get_loose_bvar_range(e);
},
[=](expr const & e, unsigned offset) -> optional<expr> {
unsigned s1 = s + offset;
if (s1 < s)
return some_expr(e); // overflow, vidx can't be >= max unsigned
if (s1 >= get_loose_bvar_range(e))
return some_expr(e); // expression e does not contain bound variables with idx >= s1
if (is_bvar(e) && bvar_idx(e) >= s1) {
lean_assert(bvar_idx(e) >= offset + d);
return some_expr(mk_bvar(bvar_idx(e) - nat(d)));
Expand All @@ -448,12 +449,12 @@ extern "C" LEAN_EXPORT object * lean_expr_lower_loose_bvars(b_obj_arg e, b_obj_a
expr lift_loose_bvars(expr const & e, unsigned s, unsigned d) {
if (d == 0 || s >= get_loose_bvar_range(e))
return e;
return replace(e, [=](expr const & e, unsigned offset) -> optional<expr> {
return replace(e,
[=](expr const & e, unsigned offset) {
unsigned s1 = s + offset;
if (s1 < s)
return some_expr(e); // overflow, vidx can't be >= max unsigned
if (s1 >= get_loose_bvar_range(e))
return some_expr(e); // expression e does not contain bound variables with idx >= s1
return s1 < s /* overflow */ || s1 >= get_loose_bvar_range(e);
},
[=](expr const & e, unsigned offset) -> optional<expr> {
if (is_var(e) && bvar_idx(e) >= s + offset) {
return some_expr(mk_bvar(bvar_idx(e) + nat(d)));
} else {
Expand Down
8 changes: 7 additions & 1 deletion src/kernel/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ inline bool has_fvar(expr const & e) {
return r;
}
bool has_univ_param(expr const & e);
unsigned get_loose_bvar_range(expr const & e);
/* This is the implementation in Lean */
unsigned loose_bvar_range_core(expr const & e);
inline unsigned get_loose_bvar_range(expr const & e) {
unsigned r = static_cast<unsigned>(get_data(e) >> 44);
lean_assert(r == loose_bvar_range_core(e));
return r;
}

struct expr_hash { unsigned operator()(expr const & e) const { return hash(e); } };
struct expr_pair_hash {
Expand Down
35 changes: 21 additions & 14 deletions src/kernel/instantiate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@ namespace lean {
expr instantiate(expr const & a, unsigned s, unsigned n, expr const * subst) {
if (s >= get_loose_bvar_range(a) || n == 0)
return a;
return replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
return replace(a,
[=](expr const & m, unsigned offset) {
unsigned s1 = s + offset;
if (s1 < s)
return some_expr(m); // overflow, vidx can't be >= max unsigned
if (s1 >= get_loose_bvar_range(m))
return some_expr(m); // expression m does not contain loose bound variables with idx >= s1
return s1 < s /* overflow */ || s1 >= get_loose_bvar_range(m);
},
[=](expr const & m, unsigned offset) -> optional<expr> {
if (is_bvar(m)) {
unsigned s1 = s + offset;
nat const & vidx = bvar_idx(m);
if (vidx >= s1) {
unsigned h = s1 + n;
Expand Down Expand Up @@ -58,9 +59,11 @@ static object * lean_expr_instantiate_core(b_obj_arg a0, size_t n, object** subs
lean_inc(a0);
return a0;
}
expr r = replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
if (offset >= get_loose_bvar_range(m))
return some_expr(m); // expression m does not contain loose bound variables with idx >= offset
expr r = replace(a,
[=](expr const & m, unsigned offset) {
return offset >= get_loose_bvar_range(m);
},
[=](expr const & m, unsigned offset) -> optional<expr> {
if (is_bvar(m)) {
nat const & vidx = bvar_idx(m);
if (vidx >= offset) {
Expand Down Expand Up @@ -99,9 +102,11 @@ extern "C" LEAN_EXPORT object * lean_expr_instantiate_range(b_obj_arg a, b_obj_a
expr instantiate_rev(expr const & a, unsigned n, expr const * subst) {
if (!has_loose_bvars(a))
return a;
return replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
if (offset >= get_loose_bvar_range(m))
return some_expr(m); // expression m does not contain loose bound variables with idx >= offset
return replace(a,
[=](expr const & m, unsigned offset) {
return offset >= get_loose_bvar_range(m);
},
[=](expr const & m, unsigned offset) -> optional<expr> {
if (is_bvar(m)) {
nat const & vidx = bvar_idx(m);
if (vidx >= offset) {
Expand All @@ -123,9 +128,11 @@ static object * lean_expr_instantiate_rev_core(object * a0, size_t n, object **
lean_inc(a0);
return a0;
}
expr r = replace(a, [=](expr const & m, unsigned offset) -> optional<expr> {
if (offset >= get_loose_bvar_range(m))
return some_expr(m); // expression m does not contain loose bound variables with idx >= offset
expr r = replace(a,
[=](expr const & m, unsigned offset) {
return offset >= get_loose_bvar_range(m);
},
[=](expr const & m, unsigned offset) -> optional<expr> {
if (is_bvar(m)) {
nat const & vidx = bvar_idx(m);
if (vidx >= offset) {
Expand Down
83 changes: 15 additions & 68 deletions src/kernel/replace_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,75 +13,22 @@ Author: Leonardo de Moura

namespace lean {

class replace_rec_fn {
struct key_hasher {
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
return hash((size_t)p.first >> 3, p.second);
}
};
lean::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
bool m_use_cache;

expr save_result(expr const & e, unsigned offset, expr r, bool shared) {
if (shared)
m_cache.insert(mk_pair(mk_pair(e.raw(), offset), r));
return r;
}

expr apply(expr const & e, unsigned offset) {
bool shared = false;
if (m_use_cache && !is_likely_unshared(e)) {
auto it = m_cache.find(mk_pair(e.raw(), offset));
if (it != m_cache.end())
return it->second;
shared = true;
}
if (optional<expr> r = m_f(e, offset)) {
return save_result(e, offset, std::move(*r), shared);
} else {
switch (e.kind()) {
case expr_kind::Const: case expr_kind::Sort:
case expr_kind::BVar: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return save_result(e, offset, e, shared);
case expr_kind::MData: {
expr new_e = apply(mdata_expr(e), offset);
return save_result(e, offset, update_mdata(e, new_e), shared);
}
case expr_kind::Proj: {
expr new_e = apply(proj_expr(e), offset);
return save_result(e, offset, update_proj(e, new_e), shared);
}
case expr_kind::App: {
expr new_f = apply(app_fn(e), offset);
expr new_a = apply(app_arg(e), offset);
return save_result(e, offset, update_app(e, new_f, new_a), shared);
}
case expr_kind::Pi: case expr_kind::Lambda: {
expr new_d = apply(binding_domain(e), offset);
expr new_b = apply(binding_body(e), offset+1);
return save_result(e, offset, update_binding(e, new_d, new_b), shared);
}
case expr_kind::Let: {
expr new_t = apply(let_type(e), offset);
expr new_v = apply(let_value(e), offset);
expr new_b = apply(let_body(e), offset+1);
return save_result(e, offset, update_let(e, new_t, new_v, new_b), shared);
}
}
lean_unreachable();
}
}
public:
template<typename F>
replace_rec_fn(F const & f, bool use_cache):m_f(f), m_use_cache(use_cache) {}

expr operator()(expr const & e) { return apply(e, 0); }
};

// Out-of-line `std::function`-typed overloads, declared in replace_fn.h. Defer
// to the templated implementation in the header, instantiated once per
// `std::function` shape. Hot callers should bypass these and use the templated
// `replace<F>` (or `replace<Skip, F>`) directly to avoid the per-call indirect
// dispatch through `std::function`.
expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f, bool use_cache) {
return replace_rec_fn(f, use_cache)(e);
return replace_rec_fn<std::function<optional<expr>(expr const &, unsigned)>>(f, use_cache)(e);
}

expr replace(expr const & e,
std::function<bool(expr const &, unsigned)> const & skip,
std::function<optional<expr>(expr const &, unsigned)> const & f,
bool use_cache) {
using F = std::function<optional<expr>(expr const &, unsigned)>;
using S = std::function<bool(expr const &, unsigned)>;
return replace_rec_fn<F, S>(skip, f, use_cache)(e);
}

class replace_fn {
Expand Down
Loading
Loading